diff --git a/.gitignore b/.gitignore
index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..4c4915b1cc49c368023011b3d70de4a536fe29de 100644
--- a/.gitignore
+++ b/.gitignore
@@ -0,0 +1,6 @@
+**/.mypy_cache/
+**/__pycache__/
+**/*.egg-info/
+
+data/
+output/
diff --git a/README.md b/README.md
index 7e99e890dcd8a64c6876c209c9ab1011db7821ff..f782b346fcf1d1c6da5c4a3395b9b117e53fa3ce 100644
--- a/README.md
+++ b/README.md
@@ -40,6 +40,22 @@ This repository is built with python 3 (>=3.7) on pytorch lightning.
 
 ## Usage
 
+Remember to activate your conda environment if you have one.
+
+Train:
+
+```bash
+./train_evreflex.py TRAIN_DIR VAL_DIR OUTPUT_DIR
+```
+
+where `TRAIN_DIR` and `VAL_DIR` can be LMDB directories, or text files containing a list of LMDB directories.
+
+`OUTPUT_DIR` is a tensorboard directory, so to view the output:
+
+```bash
+tensorboard --bind_all --load_fast=true --logdir OUTPUT_DIR
+```
+
 ## Dataset
 
 ### Source files
diff --git a/evreflex/__init__.py b/evreflex/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/evreflex/datasets/EVReflex.py b/evreflex/datasets/EVReflex.py
new file mode 100755
index 0000000000000000000000000000000000000000..2a344c84123a9626c9cff6e45a387005dcc6df35
--- /dev/null
+++ b/evreflex/datasets/EVReflex.py
@@ -0,0 +1,286 @@
+#!/usr/bin/env python
+import torch
+from torch.utils.data import Dataset
+import torchvision.transforms.functional as F
+from PIL import Image
+import numpy as np
+from pathlib import Path
+from natsort import natsorted
+import torchvision.transforms as tfs
+
+from evreflex.datasets.database import ImageDatabase, ArrayDatabase#, PointcloudDatabase
+
+from typing import Optional
+
+TRAIN = 0.7
+VAL = 0.15
+TEST = 0.15 # Not actually used, it's the remainer of the other two
+
+# ==================================================================================================
+class EVReflex(Dataset):
+	# ----------------------------------------------------------------------------------------------
+	def __init__(
+		self,
+		data_path: Path,
+		split: Optional[str] = None,
+		count_only: bool = False,
+		time_only: bool = False,
+		skip_frames: bool = False,
+		transforms=[[]] * 4,
+		vis: bool = True,
+		subset: Optional[bool] = None,
+		include_img: bool = False,
+		include_ev: bool = False,
+		dynamic: bool = False,
+	):
+		self.data_path = data_path
+		self.flow_path = data_path / "flow"
+		self.pcl_path = data_path / "pcl"
+		self.danger_path = data_path / "danger"
+		self.img_path = data_path / "img"
+		self.ev_path = data_path / "event"
+		self.split = split
+		self.count_only = count_only
+		self.time_only = time_only
+		self.skip_frames = skip_frames
+		self.transforms = transforms
+		self.vis = vis
+		self.subset = subset
+		self.dynamic = dynamic
+
+		self.db_flow = ArrayDatabase(self.flow_path)
+		self.db_pcl = ArrayDatabase(self.pcl_path)
+		self.db_danger = ArrayDatabase(self.danger_path)
+		self.db_img = ImageDatabase(self.img_path, mode="L") if include_img else None
+		self.db_ev = ArrayDatabase(self.ev_path) if include_ev else None
+
+		if self.count_only or self.time_only:
+			raise NotImplementedError("See `_read_evs()`")
+
+	# ----------------------------------------------------------------------------------------------
+	def __len__(self):
+		if self.db_img:
+			self.length = min(len(self.db_flow), len(self.db_pcl), len(self.db_danger), len(self.db_img))
+		else:
+			self.length = min(len(self.db_flow), len(self.db_pcl), len(self.db_danger))
+
+		if self.split is None:
+			length = self.length
+		elif self.split == "train":
+			length = int(self.length * TRAIN)
+		elif self.split == "val":
+			length = int(self.length * VAL)
+		elif self.split == "test":
+			length = self.length - int(self.length * TRAIN) - int(self.length * VAL) - 1
+
+		if self.subset is not None:
+			assert self.subset <= 1.0
+			return int(length * self.subset)
+
+		# FIX check if this is still needed. May have been fixed in merged LMDBs
+		# Accounts for danger keys starting at "2", or to provide sequential depth images with each index?
+		length -= 1
+
+		return length
+
+	# ----------------------------------------------------------------------------------------------
+	def __getitem__(self, index):
+		if self.split == "val":
+			index = index + int(self.length * TRAIN)
+		elif self.split == "test":
+			index = index + int(self.length * TRAIN) + int(self.length * VAL)
+		# EVDodge:
+		# It's the stacking of {E+, E-, E_t}:
+		# 1. Per-pixel average number of _positive_ event triggers in the spatio-temporal window spanned between t_0 and t_0 + d_t
+		# 2. Per-pixel average number of _negative_ event triggers in the spatio-temporal window spanned between t_0 and t_0 + d_t
+		# 3. Average trigger time per pixel
+
+		# FIX check if this is still needed. May have been fixed in merged LMDBs
+		# Accounts for danger keys starting at "2", or to provide sequential depth images with each index?
+		# index += 1
+		index += 2
+		flow = self.db_flow[index]
+		points = self.db_pcl[index]
+		if self.dynamic:
+			points2 = self.db_pcl[index + 1]
+			points = np.vstack([points, points2])
+		danger = self.db_danger[index]
+		if self.db_img:
+			img = self.db_img[index]
+		else:
+			img = torch.empty(0)
+		if self.db_ev:
+			# NOTE this follows EVDodge's format
+			events = self.db_ev[index]
+			event_count_images = torch.from_numpy(events[0].astype(np.int16))
+			event_time_images = torch.from_numpy(events[1].astype(np.float32))
+			image_times = torch.from_numpy(events[2].astype(np.float64))
+			event_count_image1, event_time_image1 = self.__read_evs(event_count_images, event_time_images, 0)
+			events1 = torch.stack([event_count_image1[0], event_count_image1[1], event_time_image1[0]])
+			event_count_image2, event_time_image2 = self.__read_evs(event_count_images, event_time_images, 1)
+			events2 = torch.stack([event_count_image2[0], event_count_image2[1], event_time_image2[0]])
+			events = torch.vstack([events1, events2])
+		else:
+			events = torch.empty(0)
+
+		for transform in self.transforms[0]:
+			flow = transform(flow)
+		for transform in self.transforms[1]:
+			points = transform(points)
+		for transform in self.transforms[2]:
+			danger = transform(danger)
+		# FIX this is badly written code
+		if isinstance(img, Image.Image):
+			for transform in self.transforms[3]:
+				img = transform(img)
+		if isinstance(events, np.ndarray):
+			for transform in self.transforms[4]:
+				events = transform(events)
+
+		return flow, points, danger, img, events
+
+	# ----------------------------------------------------------------------------------------------
+	def _read_evs(self, ev_count_imgs, ev_time_imgs, n_frames):
+		# ev_count_imgs = ev_count_imgs.reshape(shape).type(torch.float32)
+		ev_count_img = ev_count_imgs[:n_frames, :, :, :]
+		ev_count_img = torch.sum(ev_count_img, dim=0).type(torch.float32)
+		ev_count_img = ev_count_img.permute(2, 0, 1)
+
+		# ev_time_imgs = ev_time_imgs.reshape(shape).type(torch.float32)
+		ev_time_img = ev_time_imgs[:n_frames, :, :, :]
+		ev_time_img = torch.max(ev_time_img, dim=0)[0]
+
+		ev_time_img /= torch.max(ev_time_img)
+		ev_time_img = ev_time_img.permute(2, 0, 1)
+
+		"""
+		if self.count_only:
+			ev_img = ev_count_img
+		elif self.time_only:
+			ev_img = ev_time_img
+		else:
+			ev_img = torch.cat([ev_count_img, ev_time_img], dim=2)
+
+		ev_img = ev_img.permute(2,0,1).type(torch.float32)
+		"""
+
+		return ev_count_img, ev_time_img
+
+	# ----------------------------------------------------------------------------------------------
+	def __read_evs(self, ev_count_imgs, ev_time_imgs, frame):
+		ev_count_img = ev_count_imgs[frame:frame + 1, :, :, :]
+		ev_count_img = torch.sum(ev_count_img, dim=0).type(torch.float32)
+		ev_count_img = ev_count_img.permute(2, 0, 1)
+
+		ev_time_img = ev_time_imgs[frame:frame + 1, :, :, :]
+		ev_time_img = torch.max(ev_time_img, dim=-1, keepdim=True)[0].squeeze(0)
+
+		ev_time_img /= torch.max(ev_time_img)
+		ev_time_img = ev_time_img.permute(2, 0, 1)
+
+		"""
+		if self.count_only:
+			ev_img = ev_count_img
+		elif self.time_only:
+			ev_img = ev_time_img
+		else:
+			ev_img = torch.cat([ev_count_img, ev_time_img], dim=2)
+
+		ev_img = ev_img.permute(2,0,1).type(torch.float32)
+		"""
+
+		return ev_count_img, ev_time_img
+
+
+	# ----------------------------------------------------------------------------------------------
+	# Argument `points` has 3 channels (dimension 1)
+	# Position (X, Y, Z) (dimension 3 == 0)
+	# colour (C, _, _) (dimension 3 == 1)
+	# If sequential, dimension 3 == 2, 3 is the same, for next frame t+1
+	@staticmethod
+	def points_2_depth(points, sequential: bool):
+		if sequential:
+			depth1 = points[:, 2:3, :, 0]
+			depth2 = points[:, 2:3, :, 2]
+			depth1 = depth1.reshape(*depth1.shape[:2], 260, 346)
+			depth2 = depth2.reshape(*depth2.shape[:2], 260, 346)
+			depth = torch.hstack([depth1, depth2]) # [B, 2, H, W]
+		else:
+			depth = points[:, 2:3, :, 0]
+			depth = depth.reshape(*depth.shape[:2], 260, 346)
+
+		return depth
+
+
+# ==================================================================================================
+if __name__ == "__main__":
+	import colored_traceback.auto
+	import argparse
+	from torch.utils.data import DataLoader
+	from kellog import info, warning, error, debug
+	import cv2
+	import open3d as o3d
+
+	parser = argparse.ArgumentParser()
+	parser.add_argument("data_path", type=Path, help="Path to EVReflex dataset, either LMDB directory or `.txt` list of relative paths")
+	args = parser.parse_args()
+
+	if args.data_path.is_file():
+		with open(args.data_path, "r") as txt:
+			dirs = txt.read().splitlines()
+		dataset = torch.utils.data.ConcatDataset(
+			[EVReflex(
+				data_path=args.data_path.parent / d,
+				split="train",
+				include_img=True,
+				include_ev=True,
+				transforms=[
+					[],
+					[],
+					[tfs.ToTensor()],
+					[tfs.ToTensor()],
+					[],
+				]
+			) for d in dirs]
+		)
+	elif args.data_path.is_dir():
+		dataset = EVReflex(
+			data_path=args.data_path,
+			split="train",
+			include_img=True,
+			include_ev=True,
+			transforms=[
+				[],
+				[],
+				[tfs.ToTensor()],
+				[tfs.ToTensor()],
+				[],
+			]
+		)
+	else:
+		raise ValueError(f"Expected data_path '{args.data_path}' to be a file or directory")
+
+	dataloader = DataLoader(dataset=dataset, batch_size=1, shuffle=False)
+	info(f"Length: {len(dataloader)}")
+	info("Press 'q' or escape to quit, any other key to advance")
+	for i, (flow, pcl, danger, img) in enumerate(dataloader):
+		print(f"Showing {i + 1}/{len(dataloader)}")
+		flow = flow[0].numpy()
+		danger = danger[0].numpy()
+		pcl = pcl[0]
+		flow = flow[0].squeeze() # x
+		danger = danger.squeeze()
+
+		# Output is white but reports no error if I don't do this
+		pcl[torch.isinf(pcl)] = 0
+		pcl[torch.isnan(pcl)] = 0
+
+		pcd = o3d.geometry.PointCloud()
+		pcd.points = o3d.utility.Vector3dVector(pcl[0])
+		pcd.colors = o3d.utility.Vector3dVector(pcl[1])
+		o3d.visualization.draw_geometries([pcd])
+		cv2.imshow("flow", flow)
+		cv2.imshow("danger", danger)
+		key = cv2.waitKey(0)
+		if key == ord("q") or key == 27: # 27 = escape
+			break
diff --git a/evreflex/datasets/__init__.py b/evreflex/datasets/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..617ffc6a9c5e50c701565cb2d0f3cacf50bcf848
--- /dev/null
+++ b/evreflex/datasets/__init__.py
@@ -0,0 +1,3 @@
+"""Datasets for Dataloaders, and lmdb functions"""
+from .database import ImageDatabase, ArrayDatabase, PointcloudDatabase
+from .EVReflex import EVReflex
diff --git a/evreflex/datasets/database.py b/evreflex/datasets/database.py
new file mode 100644
index 0000000000000000000000000000000000000000..3dd2259a69f68292a01dd54531fdb343c257be7b
--- /dev/null
+++ b/evreflex/datasets/database.py
@@ -0,0 +1,135 @@
+#!/usr/bin/env python3
+from pathlib import Path
+import io
+import lmdb
+import pickle
+from PIL import Image
+from PIL import ImageFile
+import numpy as np
+import open3d as o3d
+import tempfile
+from kellog import info, warning, error, debug
+
+ImageFile.LOAD_TRUNCATED_IMAGES = True
+
+# ==================================================================================================
+class Database(object):
+	# ----------------------------------------------------------------------------------------------
+	def __init__(self, path: Path):
+		self.path = path
+		if not self.path.exists():
+			raise FileNotFoundError(self.path)
+		self.db = lmdb.open(
+			path=str(self.path),
+			readonly=True,
+			readahead=False,
+			max_spare_txns=128,
+			lock=False,
+		)
+		with self.db.begin() as txn:
+			keys = pickle.loads(txn.get(key=pickle.dumps("keys")))
+		self.keys = set(keys)
+
+	# ----------------------------------------------------------------------------------------------
+	def __iter__(self):
+		return iter(self.keys)
+
+	# ----------------------------------------------------------------------------------------------
+	def __len__(self):
+		return len(self.keys)
+
+	# ----------------------------------------------------------------------------------------------
+	def __getitem__(self, item):
+		key = pickle.dumps(item)
+		with self.db.begin() as txn:
+			value = txn.get(key)
+
+		return value
+
+	# ----------------------------------------------------------------------------------------------
+	def __del__(self):
+		self.db.close()
+
+
+# ==================================================================================================
+class ImageDatabase(Database):
+	# ----------------------------------------------------------------------------------------------
+	def __init__(self, path: Path, mode: str="RGB"):
+		super().__init__(path)
+		self.mode = mode
+
+	# ----------------------------------------------------------------------------------------------
+	def __getitem__(self, item):
+		try:
+			key = pickle.dumps(str(item))
+			with self.db.begin() as txn:
+				image = Image.open(io.BytesIO(txn.get(key))).convert(mode=self.mode)
+		except OSError:
+			error(f"Failed to read '{self.path.stem}' database at index {item}")
+			raise
+
+		return image
+
+
+# ==================================================================================================
+class ArrayDatabase(Database):
+	# ----------------------------------------------------------------------------------------------
+	def __init__(self, path: Path):
+		super().__init__(path)
+
+	# ----------------------------------------------------------------------------------------------
+	def __getitem__(self, item):
+		try:
+			key = pickle.dumps(str(item))
+			with self.db.begin() as txn:
+				with io.BytesIO(txn.get(key)) as f:
+					array = np.load(f, allow_pickle=True, encoding="bytes")
+					if isinstance(array, np.ndarray):
+						pass
+					elif isinstance(array, np.lib.npyio.NpzFile):
+						array = array.f.arr_0
+					else:
+						raise RuntimeError(f"Unexpected type array type '{type(array)}'")
+		except OSError:
+			error(f"Failed to read '{self.path.stem}' database at index {item}")
+			raise
+
+		return array
+
+
+# ==================================================================================================
+class PointcloudDatabase(Database):
+	# ----------------------------------------------------------------------------------------------
+	def __init__(self, path: Path):
+		raise NotImplementedError("Use ArrayDatabase and numpy instead...")
+		super().__init__(path)
+
+	# ----------------------------------------------------------------------------------------------
+	def __getitem__(self, item):
+		key = pickle.dumps(str(item))
+		with self.db.begin() as txn:
+			# o3d.io can't read from open files, grr.
+			with tempfile.NamedTemporaryFile(suffix=".pcd") as f:
+			# with tempfile.SpooledTemporaryFile() as f:
+				f.write(txn.get(key))
+				f.seek(0)
+				pcl = o3d.io.read_point_cloud(f.name)
+
+		return np.asarray(pcl.points)
+
+
+# ==================================================================================================
+class LabelDatabase(Database):
+	# ----------------------------------------------------------------------------------------------
+	def __init__(self, path: Path):
+		super().__init__(path)
+		raise NotImplementedError
+
+	# ----------------------------------------------------------------------------------------------
+	def __getitem__(self, item):
+		# key = pickle.dumps(str(item))
+		key = pickle.dumps(item)
+		with self.db.begin() as txn:
+			label = pickle.loads(txn.get(key))
+
+		return label
diff --git a/evreflex/models/EVReflex.py b/evreflex/models/EVReflex.py
new file mode 100755
index 0000000000000000000000000000000000000000..5547f742d0feb7853d0d20023967b94ce7e86e61
--- /dev/null
+++ b/evreflex/models/EVReflex.py
@@ -0,0 +1,372 @@
+#!/usr/bin/env python
+import pytorch_lightning as pl
+import torch
+from torch import nn
+import torch.nn.functional as F
+from torch.utils.data import DataLoader, ConcatDataset
+import torchvision.transforms as tfs
+from pathlib import Path
+import time
+from kellog import info, warning, error, debug
+import evreflex.utils
+from PIL import Image
+
+from typing import Sequence, Optional
+
+from evreflex.datasets import EVReflex
+from evreflex.models.layers import ConvolutionBlock, ResidualConvolutionBlock, ResidualTransposedConvolutionBlock, TransposedConvolutionBlock
+from evreflex import utils
+
+eps = 1e-12
+
+# ==================================================================================================
+class EVReflexNet(pl.LightningModule):
+	# ----------------------------------------------------------------------------------------------
+	def __init__(
+		self,
+		blocks: Sequence[int],
+		operation: str,
+		base_channels: int = 8,
+		train_path: Optional[Path] = None,
+		val_path: Optional[Path] = None,
+		test_path: Optional[Path] = None,
+		batch_size: int = 1,
+		workers: int = 0,
+		lr: float = 0.01,
+		subset: Optional[bool] = None
+	):
+		super().__init__()
+		self.operation = operation
+		self.train_path = train_path
+		self.val_path = val_path
+		self.test_path = test_path
+		self.batch_size = batch_size
+		self.workers = workers
+		self.lr = lr
+		self.subset = subset
+
+		# example_input_array is required to log the graph to tensorboard
+		if self.operation == "depth":
+			num_classes = 1
+			self.example_input_array = (torch.zeros(1, 1, 260, 344), torch.empty(0))
+		elif self.operation == "flow":
+			num_classes = 2
+			self.example_input_array = (torch.empty(0), torch.zeros(1, 2, 260, 344))
+		elif self.operation == "both":
+			num_classes = 3
+			self.example_input_array = (torch.zeros(1, 1, 260, 344), torch.zeros(1, 2, 260, 344))
+		elif self.operation == "dynamic_depth":
+			num_classes = 2
+			self.example_input_array = (torch.zeros(1, 2, 260, 344), torch.empty(0))
+		elif self.operation == "dynamic_both":
+			num_classes = 4
+			self.example_input_array = (torch.zeros(1, 2, 260, 344), torch.zeros(1, 2, 260, 344))
+		else:
+			raise NotImplementedError(f"{self.operation} is not a known operation")
+
+		# Point cloud
+		layers = []
+		layers += [ConvolutionBlock(
+			num_layers=blocks[0],
+			in_channels=num_classes,
+			out_channels=base_channels,
+		)]
+		layers += [ResidualConvolutionBlock(
+			num_layers=blocks[1],
+			in_channels=base_channels,
+			mid_channels=base_channels,
+			out_channels=base_channels * 2 ** 2,
+		)]
+		layers += [ResidualTransposedConvolutionBlock(
+			num_layers=blocks[1],
+			in_channels=base_channels * 2 ** 2,
+			mid_channels=base_channels,
+			out_channels=base_channels,
+		)]
+		layers += [TransposedConvolutionBlock(
+			num_layers=blocks[0],
+			in_channels=base_channels,
+			out_channels=1,
+		)]
+
+		# self.layers = nn.Sequential(*layers)
+		self.layers = layers
+		self.l1 = layers[0]
+		self.l2 = layers[1]
+		self.l3 = layers[2]
+		self.l4 = layers[3]
+		self.num_classes = num_classes
+		self.out_channels = layers[-2].out_channels
+
+		self.criterion = nn.MSELoss(reduction="none")
+
+		self.elapsed = 0 # Seconds
+		self.steps = 0 # Train steps
+		self.elapsedInterval = 10 # Seconds
+		self.elapsedSteps = 0 # Counter
+		self.start = None
+
+	# ----------------------------------------------------------------------------------------------
+	def configure_optimizers(self):
+		warning("Actually choose a sensible optimiser") # TODO
+		optimizer = torch.optim.Adam(self.parameters(), lr=self.lr)
+		lr_scheduler = {
+			"scheduler": torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer),
+			"monitor": "train_loss",
+		}
+
+		return [optimizer], [lr_scheduler]
+
+	# ----------------------------------------------------------------------------------------------
+	# Inference
+	def forward(self, depth, flow):
+		if self.operation in ["depth", "dynamic_depth"]:
+			x = self.l1(depth)
+		elif self.operation == "flow":
+			x = self.l1(flow)
+		elif self.operation in ["both", "dynamic_both"]:
+			x = self.l1(torch.hstack([depth, flow]))
+		x = self.l2(x)
+		x = self.l3(x)
+		x = self.l4(x)
+
+		return x
+
+	# ----------------------------------------------------------------------------------------------
+	def on_train_start(self):
+		if self.logger:
+			tensorboard = self.logger.experiment
+			tensorboard.add_text("Git revision", evreflex.utils.get_git_rev())
+
+	# ----------------------------------------------------------------------------------------------
+	def step(self, batch, batch_idx, dataset):
+		flow, points, danger, img, events = batch # FIX img is just visualised, but it's using GPU memory here
+
+		points = points.permute(0, 3, 2, 1)
+		depth = dataset.points_2_depth(points, self.operation in ["dynamic_depth", "dynamic_both"])
+
+		# Janky bit to make resolution divisible by 4
+		depth = depth[:, :, :, 1:-1]
+		danger = danger[:, :, :, 1:-1]
+		flow = flow[:, :, :, 1:-1]
+		if isinstance(img, Image.Image):
+			img = img[:, :, :, 1:-1]
+		if events.numel():
+			events = events[:, :3, :, 1:-1]
+
+		depth[torch.isnan(depth)] = 0 # There may be NaNs! So we need to set them to something else
+
+		output = self(depth, flow)
+
+		return output, flow, depth, danger, img, events
+
+	# ----------------------------------------------------------------------------------------------
+	def training_step(self, batch, batch_idx):
+		if self.start is not None:
+			self.elapsed += time.time() - self.start
+			self.steps += 1
+			if self.elapsed >= self.elapsedInterval:
+				# self.log("iteration_speed", self.steps / self.elapsed, on_epoch=False, on_step=True)
+				if self.logger:
+					tensorboard = self.logger.experiment
+					tensorboard.add_scalar("sample_speed", self.steps / self.elapsed * self.batch_size, global_step=self.elapsedSteps)
+				self.elapsed = 0
+				self.steps = 0
+				self.elapsedSteps += 1
+		self.start = time.time()
+
+		output, flow, depth, danger, img, events = self.step(batch, batch_idx, self.train_dataloader().dataset)
+		loss = self.criterion(output, danger)
+
+		if batch_idx == 0 and self.logger:
+		# if batch_idx == 0 and self.logger and self.current_epoch % 10 == 0:
+			flow = flow.detach().cpu()
+			output = output.detach().cpu()
+			tensorboard = self.logger.experiment
+			tensorboard.add_images("OP", output.clamp(0, 1), global_step=self.current_epoch)
+			depth[depth == 0] = float("Inf") # So when we invert the visualisation the background is black
+			tensorboard.add_images("Inverse depth", ((1 / depth) / 5).clamp(0, 1), global_step=self.current_epoch) # Black/white polarity
+			tensorboard.add_images("Optical flow", utils.flow_viz(flow), global_step=self.current_epoch)
+			tensorboard.add_images("Ground truth TTI", danger.clamp(0, 1), global_step=self.current_epoch)
+			if isinstance(img, Image.Image):
+				tensorboard.add_images("img", img, global_step=self.current_epoch)
+			tensorboard.add_images("events", events, global_step=self.current_epoch)
+
+		loss = loss.mean() # Do this if reduce in criterion is "none"
+		self.log("train_loss", loss, on_epoch=True, on_step=False)
+		if torch.isnan(loss).any() or torch.isinf(loss).any():
+			error("NaNs in train loss!")
+
+		return loss
+
+	# ----------------------------------------------------------------------------------------------
+	def validation_step(self, batch, batch_idx):
+		output, flow, depth, danger, img, events = self.step(batch, batch_idx, self.val_dataloader().dataset)
+		loss = self.criterion(output, danger)
+		loss = loss.mean() # Do this if reduce in criterion is "none"
+		self.log(f"val_loss", loss, on_epoch=True, on_step=False)
+		if torch.isnan(loss).any() or torch.isinf(loss).any():
+			error("NaNs in val loss!")
+
+		return loss
+
+	# ----------------------------------------------------------------------------------------------
+	def test_step(self, batch, batch_idx):
+		output, flow, depth, danger, img, events = self.step(batch, batch_idx, self.test_dataloader().dataset)
+
+		ffv = 100 # Fudge factor for visualisation
+		thresh = 0.001
+
+		if self.logger:
+			tensorboard = self.logger.experiment
+			output_ = output * ffv
+			output_[output_ < 0] = 0
+			output_[output_ > 1] = 1
+			danger_ = danger * ffv
+			danger_[danger_ < 0] = 0
+			danger_[danger_ > 1] = 1
+			for i, sample in enumerate(output_):
+				tensorboard.add_image("OP", sample, global_step=(batch_idx * events.shape[0]) + i)
+			for i, sample in enumerate(danger_):
+				tensorboard.add_image("GT", sample, global_step=(batch_idx * events.shape[0]) + i)
+			for i, sample in enumerate(events):
+				tensorboard.add_image("events", sample, global_step=(batch_idx * events.shape[0]) + i)
+
+		outputBinary = output.clone()
+		outputBinary[outputBinary < 0] = 0
+		# tp = tfs.ToPILImage()
+		threshDanger = torch.where(danger > thresh, 1, 0).byte()
+		threshOutput = torch.where(outputBinary > thresh, 1, 0).byte()
+
+		if self.logger:
+			tensorboard = self.logger.experiment
+			tensorboard.add_images("OP thresh", threshOutput * 255, global_step=batch_idx)
+			tensorboard.add_images("GT thresh", threshDanger * 255, global_step=batch_idx)
+
+		iou = utils.calc_iou(threshOutput, threshDanger)
+
+		loss = self.criterion(output, danger)
+		lossDepth = self.criterion(output, 1 / (depth[:, :1] + eps))
+		self.log(f"test_loss", loss.mean(), on_epoch=True, on_step=False) # Do this if reduce in criterion is "none"
+		self.log(f"test_loss_depth", lossDepth.mean(), on_epoch=True, on_step=False) # Do this if reduce in criterion is "none"
+		self.log(f"test_thresh_iou", iou.mean(), on_epoch=True, on_step=False)
+		# TODO grid search best depth threshold to use as IoU baseline?
+		# self.log(f"test_thresh_iou_depth", iou.mean(), on_epoch=True, on_step=False)
+
+		# return iou.mean() # Does nothing?
+
+	# ----------------------------------------------------------------------------------------------
+	def get_dataset(self, dataPath: Path, split: str) -> EVReflex:
+		return EVReflex(
+			data_path=dataPath,
+			split=split,
+			transforms=[
+				[],
+				[],
+				[tfs.ToTensor()],
+				[tfs.ToTensor()],
+				[],
+			],
+			subset=self.subset,
+			include_ev=True,
+			dynamic=self.operation in ["depth", "dynamic_depth"],
+		)
+
+	# ----------------------------------------------------------------------------------------------
+	def train_dataloader(self) -> DataLoader:
+		if self.train_path.is_file():
+			with open(self.train_path, "r") as txt:
+				dirs = txt.read().splitlines()
+			dataset = ConcatDataset([self.get_dataset(self.train_path.parent / d) for d in dirs])
+		else:
+			dataset = self.get_dataset(self.train_path, "train")
+
+		return DataLoader(
+			dataset,
+			batch_size=self.batch_size,
+			shuffle=True,
+			num_workers=self.workers,
+			pin_memory=self.device != torch.device("cpu")
+		)
+
+	# ----------------------------------------------------------------------------------------------
+	def val_dataloader(self) -> DataLoader:
+		if self.val_path.is_file():
+			with open(self.val_path, "r") as txt:
+				dirs = txt.read().splitlines()
+			dataset = ConcatDataset([self.get_dataset(self.val_path.parent / d) for d in dirs])
+		else:
+			dataset = self.get_dataset(self.val_path, "val")
+
+		return DataLoader(
+			dataset,
+			batch_size=self.batch_size,
+			shuffle=False,
+			num_workers=self.workers,
+			pin_memory=self.device != torch.device("cpu")
+		)
+
+	# ----------------------------------------------------------------------------------------------
+	def test_dataloader(self) -> DataLoader:
+		if self.test_path.is_file():
+			with open(self.test_path, "r") as txt:
+				dirs = txt.read().splitlines()
+			dataset = ConcatDataset([self.get_dataset(self.test_path.parent / d) for d in dirs])
+		else:
+			dataset = self.get_dataset(self.test_path, "test")
+
+		return DataLoader(
+			dataset,
+			batch_size=self.batch_size,
+			shuffle=False,
+			num_workers=self.workers,
+			pin_memory=self.device != torch.device("cpu")
+		)
+
+
+# ==================================================================================================
+if __name__ == "__main__":
+	import colored_traceback.auto
+	import argparse
+
+	parser = argparse.ArgumentParser()
+	parser.add_argument("--data_dir", type=Path, help="Path to EVReflex dataset")
+	args = parser.parse_args()
+
+	# Blocks:
+	# - Number of layers in the first ConvolutionBlock
+	# - Number of layers in the subsequent ResidualConvolutionBlocks
+	model = EVReflexNet(
+		blocks=[3, 2],
+		# num_classes=40,
+		num_classes=4, # DEBUG
+		base_channels=8,
+	)
+
+	if not args.data_dir:
+		warning("Using random input, specify `--data_dir` to load from a dataset")
+		# points = torch.rand(1, 40, 89960)
+		points = torch.rand(1, 4, 89960)
+		flow = torch.rand(1, 2, 260, 346)
+	else:
+		dataset = EVReflex(
+			data_path=args.data_dir,
+			transforms=[
+				[],
+				[],
+				[tfs.ToTensor()],
+				[tfs.ToTensor()],
+				[],
+			]
+		)
+		dataloader = torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=True)
+		for sample in dataloader:
+			flow, points, danger, img, events = sample
+			# debug(list(flow.shape), list(points.shape), list(danger.shape))
+			break
+	debug(f"Input shape: {points.shape}")
+
+	points = points.reshape([*points.shape[:2], 260, 346])
+	points = points[:, :, :, 1:-1] # Crop to make divisible by 4
+	output = model(points, flow)
+	debug(f"Output shape: {output.shape}")
diff --git a/evreflex/models/__init__.py b/evreflex/models/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..5f676e039557ba7c58ec95dd9e5f74b0bc0bc786
--- /dev/null
+++ b/evreflex/models/__init__.py
@@ -0,0 +1,2 @@
+"""Pytorch Lightning models."""
+from .EVReflex import EVReflexNet
diff --git a/evreflex/models/layers.py b/evreflex/models/layers.py
new file mode 100755
index 0000000000000000000000000000000000000000..f0464d0951e4acf6b99b0df25a71db0602708e5c
--- /dev/null
+++ b/evreflex/models/layers.py
@@ -0,0 +1,416 @@
+#!/usr/bin/env python3
+from torch import nn, Tensor
+
+# ==================================================================================================
+class Layer(nn.Module):
+	# ----------------------------------------------------------------------------------------------
+	def __init__(self):
+		super().__init__()
+
+	# ----------------------------------------------------------------------------------------------
+	def forward(self, x: Tensor) -> Tensor:
+		return self.layers(x)
+
+
+# ==================================================================================================
+class ConvolutionLayer(Layer):
+	def __init__(
+		self,
+		in_channels: int,
+		out_channels: int,
+		kernel_size: int,
+		stride: int = 1,
+		padding: int = 0,
+		norm: bool = True,
+		activation: bool = True,
+	):
+		super().__init__()
+
+		self.in_channels = in_channels
+		self.out_channels = out_channels
+		self.kernel_size = kernel_size
+		self.stride = stride
+		self.padding = padding
+		self.norm = norm
+		self.activation = activation
+
+		layers = []
+		layer = nn.Conv2d(
+			in_channels=in_channels,
+			out_channels=out_channels,
+			kernel_size=kernel_size,
+			stride=stride,
+			padding=padding,
+			bias=not norm,
+		)
+		layers += [layer]
+		if norm:
+			layer = nn.BatchNorm2d(num_features=out_channels)
+			layers += [layer]
+		if activation:
+			layer = nn.CELU(inplace=True)
+			layers += [layer]
+
+		self.layers = nn.Sequential(*layers)
+
+
+# ==================================================================================================
+class TransposedConvolutionLayer(Layer):
+	def __init__(
+		self,
+		in_channels: int,
+		out_channels: int,
+		kernel_size: int,
+		stride: int = 1,
+		padding: int = 0,
+		output_padding: int = 0,
+		norm: bool = True,
+		activation: bool = True,
+	):
+		super().__init__()
+
+		self.in_channels = in_channels
+		self.out_channels = out_channels
+		self.kernel_size = kernel_size
+		self.stride = stride
+		self.padding = padding
+		self.output_padding = output_padding
+		self.norm = norm
+		self.activation = activation
+
+		layers = []
+		layer = nn.ConvTranspose2d(
+			in_channels=in_channels,
+			out_channels=out_channels,
+			kernel_size=kernel_size,
+			stride=stride,
+			padding=padding,
+			output_padding=output_padding,
+			bias=not norm,
+		)
+		layers += [layer]
+		if norm:
+			layer = nn.BatchNorm2d(num_features=out_channels)
+			layers += [layer]
+		if activation:
+			layer = nn.CELU(inplace=True)
+			layers += [layer]
+
+		self.layers = nn.Sequential(*layers)
+
+
+# ==================================================================================================
+class Identity(nn.Module):
+	# ----------------------------------------------------------------------------------------------
+	def __init__(self):
+		super().__init__()
+
+	# ----------------------------------------------------------------------------------------------
+	def forward(self, x: Tensor) -> Tensor:
+		return x
+
+
+# ==================================================================================================
+class ConvolutionBlock(Layer):
+	# ----------------------------------------------------------------------------------------------
+	def __init__(self, num_layers: int, in_channels: int, out_channels: int):
+		super(Layer, self).__init__()
+
+		self.num_layers = num_layers
+		self.in_channels = in_channels
+		self.out_channels = out_channels
+
+		layers = [
+			ConvolutionLayer(
+				in_channels=in_channels,
+				out_channels=out_channels,
+				kernel_size=3,
+				stride=2,
+				padding=1,
+			)
+		] + [
+			ConvolutionLayer(
+				in_channels=out_channels,
+				out_channels=out_channels,
+				kernel_size=3,
+				padding=1,
+			)
+			for _ in range(1, num_layers)
+		]
+
+		self.layers = nn.Sequential(*layers)
+
+
+# ==================================================================================================
+class TransposedConvolutionBlock(Layer):
+	# ----------------------------------------------------------------------------------------------
+	def __init__(self, num_layers: int, in_channels: int, out_channels: int):
+		super(Layer, self).__init__()
+
+		self.num_layers = num_layers
+		self.in_channels = in_channels
+		self.out_channels = out_channels
+
+		layers = [
+			TransposedConvolutionLayer(
+				in_channels=in_channels,
+				out_channels=out_channels,
+				kernel_size=3,
+				stride=2,
+				padding=1,
+				output_padding=1,
+			)
+		] + [
+			TransposedConvolutionLayer(
+				in_channels=out_channels,
+				out_channels=out_channels,
+				kernel_size=3,
+				padding=1,
+			)
+			for _ in range(1, num_layers)
+		]
+
+		self.layers = nn.Sequential(*layers)
+
+
+# ==================================================================================================
+class ConvolutionBranch(Layer):
+	# ----------------------------------------------------------------------------------------------
+	def __init__(self, in_channels: int, mid_channels: int, out_channels: int, stride: int):
+		super().__init__()
+
+		self.in_channels = in_channels
+		self.mid_channels = mid_channels
+		self.out_channels = out_channels
+		self.stride = stride
+
+		layers = [
+			ConvolutionLayer(
+				in_channels=in_channels, out_channels=mid_channels, kernel_size=1
+			),
+			ConvolutionLayer(
+				in_channels=mid_channels,
+				out_channels=mid_channels,
+				kernel_size=3,
+				stride=stride,
+				padding=1,
+			),
+			ConvolutionLayer(
+				in_channels=mid_channels,
+				out_channels=out_channels,
+				kernel_size=1,
+				activation=False,
+			),
+		]
+
+		self.layers = nn.Sequential(*layers)
+
+
+# ==================================================================================================
+class TransposedConvolutionBranch(Layer):
+	# ----------------------------------------------------------------------------------------------
+	def __init__(self, in_channels: int, mid_channels: int, out_channels: int, stride: int):
+		super().__init__()
+
+		self.in_channels = in_channels
+		self.mid_channels = mid_channels
+		self.out_channels = out_channels
+		self.stride = stride
+
+		layers = [
+			TransposedConvolutionLayer(
+				in_channels=in_channels, out_channels=mid_channels, kernel_size=1
+			),
+			TransposedConvolutionLayer(
+				in_channels=mid_channels,
+				out_channels=mid_channels,
+				kernel_size=3,
+				stride=stride,
+				padding=1,
+				output_padding=stride - 1,
+			),
+			TransposedConvolutionLayer(
+				in_channels=mid_channels,
+				out_channels=out_channels,
+				kernel_size=1,
+				activation=False,
+			),
+
+		]
+
+		self.layers = nn.Sequential(*layers)
+
+
+# ==================================================================================================
+class ResidualConvolutionBranch(Layer):
+	# ----------------------------------------------------------------------------------------------
+	def __init__(self, in_channels: int, out_channels: int, stride: int):
+		super().__init__()
+
+		self.in_channels = in_channels
+		self.out_channels = out_channels
+		self.stride = stride
+
+		if stride == 1 and in_channels == out_channels:
+			layers = [Identity()]
+		else:
+			layers = [
+				ConvolutionLayer(
+					in_channels=in_channels,
+					out_channels=out_channels,
+					kernel_size=stride,
+					stride=stride,
+					activation=False,
+				)
+			]
+
+		self.layers = nn.Sequential(*layers)
+
+
+# ==================================================================================================
+class ResidualTransposedConvolutionBranch(Layer):
+	# ----------------------------------------------------------------------------------------------
+	def __init__(self, in_channels: int, out_channels: int, stride: int):
+		super().__init__()
+
+		self.in_channels = in_channels
+		self.out_channels = out_channels
+		self.stride = stride
+
+		if stride == 1 and in_channels == out_channels:
+			layers = [Identity()]
+		else:
+			layers = [
+				TransposedConvolutionLayer(
+					in_channels=in_channels,
+					out_channels=out_channels,
+					kernel_size=stride,
+					stride=stride,
+					activation=False,
+				)
+			]
+
+		self.layers = nn.Sequential(*layers)
+
+
+# ==================================================================================================
+class ResidualLayer(nn.Module):
+	# ----------------------------------------------------------------------------------------------
+	def __init__(self):
+		super().__init__()
+
+	# ----------------------------------------------------------------------------------------------
+	def forward(self, x: Tensor) -> Tensor:
+		# s1 = x.shape
+		x = self.branch_a(x) + self.branch_b(x)
+		x = self.activation(x)
+		# print(f"{type(self).__name__}: {list(s1[1:])} -> {list(x.shape[1:])}")
+		return x
+
+
+# ==================================================================================================
+class ResidualConvolutionLayer(ResidualLayer):
+	# ----------------------------------------------------------------------------------------------
+	def __init__(self, in_channels: int, mid_channels: int, out_channels: int, stride: int):
+		super().__init__()
+
+		self.in_channels = in_channels
+		self.mid_channels = mid_channels
+		self.out_channels = out_channels
+		self.stride = stride
+
+		self.branch_a = ConvolutionBranch(
+			in_channels=in_channels,
+			mid_channels=mid_channels,
+			out_channels=out_channels,
+			stride=stride,
+		)
+		self.branch_b = ResidualConvolutionBranch(
+			in_channels=in_channels, out_channels=out_channels, stride=stride
+		)
+		self.activation = nn.CELU(inplace=True)
+
+
+# ==================================================================================================
+class ResidualTransposedConvolutionLayer(ResidualLayer):
+	# ----------------------------------------------------------------------------------------------
+	def __init__(self, in_channels: int, mid_channels: int, out_channels: int, stride: int):
+		super().__init__()
+
+		self.in_channels = in_channels
+		self.mid_channels = mid_channels
+		self.out_channels = out_channels
+		self.stride = stride
+
+		self.branch_a = TransposedConvolutionBranch(
+			in_channels=in_channels,
+			mid_channels=mid_channels,
+			out_channels=out_channels,
+			stride=stride,
+		)
+		self.branch_b = ResidualTransposedConvolutionBranch(
+			in_channels=in_channels, out_channels=out_channels, stride=stride
+		)
+		self.activation = nn.CELU(inplace=True)
+
+
+# ==================================================================================================
+class ResidualConvolutionBlock(Layer):
+	# ----------------------------------------------------------------------------------------------
+	def __init__(self, num_layers: int, in_channels: int, mid_channels: int, out_channels: int):
+		super(Layer, self).__init__()
+
+		self.num_layers = num_layers
+		self.in_channels = in_channels
+		self.mid_channels = mid_channels
+		self.out_channels = out_channels
+
+		layers = [
+			ResidualConvolutionLayer(
+				in_channels=in_channels,
+				mid_channels=mid_channels,
+				out_channels=out_channels,
+				stride=2,
+			)
+		] + [
+			ResidualConvolutionLayer(
+				in_channels=out_channels,
+				mid_channels=mid_channels,
+				out_channels=out_channels,
+				stride=1,
+			)
+			for _ in range(1, num_layers)
+		]
+
+		self.layers = nn.Sequential(*layers)
+
+
+# ==================================================================================================
+class ResidualTransposedConvolutionBlock(Layer):
+	# ----------------------------------------------------------------------------------------------
+	def __init__(self, num_layers: int, in_channels: int, mid_channels: int, out_channels: int):
+		super(Layer, self).__init__()
+
+		self.num_layers = num_layers
+		self.in_channels = in_channels
+		self.mid_channels = mid_channels
+		self.out_channels = out_channels
+
+		layers = [
+			ResidualTransposedConvolutionLayer(
+				in_channels=in_channels,
+				mid_channels=mid_channels,
+				out_channels=out_channels,
+				stride=2,
+			)
+		] + [
+			ResidualTransposedConvolutionLayer(
+				in_channels=out_channels,
+				mid_channels=mid_channels,
+				out_channels=out_channels,
+				stride=1,
+			)
+			for _ in range(1, num_layers)
+		]
+
+		self.layers = nn.Sequential(*layers)
diff --git a/evreflex/tools/merge_lmdbs.py b/evreflex/tools/merge_lmdbs.py
new file mode 100755
index 0000000000000000000000000000000000000000..35acf12f4588b517a7dfb076863b11134cc9d110
--- /dev/null
+++ b/evreflex/tools/merge_lmdbs.py
@@ -0,0 +1,107 @@
+#!/usr/bin/env python
+import argparse
+from pathlib import Path
+import lmdb
+import pickle
+from tqdm import tqdm
+from kellog import info, warning, error, debug
+from natsort import natsorted
+
+dataTypes = ["danger", "event", "flow", "img", "pcl", "seg"]
+skip = [
+	"scene0000_00_vh_clean_2",
+	"scene0080_00_vh_clean_2",
+	"scene0116_00_vh_clean_2",
+	"scene0135_00_vh_clean_2",
+	"scene0150_00_vh_clean_2",
+	"scene0153_00_vh_clean_2",
+	"scene0182_00_vh_clean_2",
+	"scene0195_00_vh_clean_2",
+	"scene0237_00_vh_clean_2",
+	"scene0250_00_vh_clean_2",
+	"scene0415_00_vh_clean_2",
+	"scene0444_00_vh_clean_2",
+	"scene0448_00_vh_clean_2",
+	"scene0457_00_vh_clean_2",
+	"scene0485_00_vh_clean_2",
+	"scene0570_00_vh_clean_2",
+	"scene0583_00_vh_clean_2",
+	"scene0676_00_vh_clean_2",
+	"scene0684_00_vh_clean_2",
+]
+
+# ==================================================================================================
+def main(args):
+	global dataTypes # Not fully sure why I need this here
+	args.output_path.mkdir(parents=True, exist_ok=True)
+
+	scenes = natsorted(list(args.input_path.glob("scene*_00_vh_clean_2")))
+	# scenes = scenes[:5] # DEBUG
+	# scenes = scenes[:2] # DEBUG
+
+	if args.data_type is not None:
+		dataTypes = [args.data_type]
+
+	if not args.force and any([(args.output_path / dataType).exists() for dataType in dataTypes]):
+		warning(f"Output lmdb(s) exists in '{args.output_path}'")
+		ans = input("Overwrite? [Y/n] ").lower()
+		if ans not in ["", "y"]:
+			warning("Not overwriting, aborting")
+			quit(0)
+	info(f"Saving output lmdbs in '{args.output_path}'")
+
+	bar = tqdm(dataTypes, position=0, leave=True)
+	for dataType in bar:
+		outputKeys = set()
+		outputDB = args.output_path / dataType
+		bar1 = tqdm(scenes, position=1, leave=True)
+		with lmdb.open(path=str(outputDB), map_size=2 ** 40) as outputEnv:
+			for scene in bar1:
+				if scene.name in skip:
+					warning(f"Skipping '{scene.name}'")
+					continue
+				bar.set_description(dataType) # Update screen at the same time as the lower bar
+				bar1.set_description(scene.name)
+				inputDB = scene / dataType
+				with lmdb.open(str(inputDB), readonly=True) as inputEnv:
+					with inputEnv.begin() as inputTxn:
+						inKeys = pickle.loads(inputTxn.get(key=pickle.dumps("keys")))
+						for i, key in enumerate(tqdm(inKeys, position=2, leave=True)):
+							if key == "1":
+								warning("Skipping first key because it's not present in danger lmdb...")
+								continue
+							if i == len(inKeys) - 1 and dataType == "danger":
+								warning("Skipping last key in danger lmdb because it's extra...")
+								continue
+							inputKey = pickle.dumps(key)
+							outputKey = str(len(outputKeys))
+							outputKeys.add(outputKey)
+							outputKey = pickle.dumps(outputKey)
+							with outputEnv.begin(write=True) as outputTxn:
+								outputTxn.put(
+									key=outputKey,
+									value=inputTxn.get(inputKey),
+									dupdata=False,
+								)
+			with outputEnv.begin(write=True) as outputTxn:
+				outputTxn.put(
+					key=pickle.dumps("keys"),
+					value=pickle.dumps(outputKeys),
+					dupdata=False,
+				)
+
+
+# ==================================================================================================
+def parse_args():
+	parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
+	parser.add_argument("input_path", type=Path, help="Input LMDB(s) directory")
+	parser.add_argument("output_path", type=Path, help="Output LMDB directory")
+	parser.add_argument("-d", "--data_type", choices=dataTypes, type=str, help="If specified, limit to a particular data type", default=None)
+	parser.add_argument("-f", "--force", action="store_true", help="Overwrite output LMDB if any part of it exists")
+
+	return parser.parse_args()
+
+
+# ==================================================================================================
+if __name__ == "__main__":
+	main(parse_args())
diff --git a/evreflex/tools/rosbags_to_danger_lmdb.py b/evreflex/tools/rosbags_to_danger_lmdb.py
new file mode 100755
index 0000000000000000000000000000000000000000..7ec1b691861ffeea0c1cf587d08e9006965aff0a
--- /dev/null
+++ b/evreflex/tools/rosbags_to_danger_lmdb.py
@@ -0,0 +1,476 @@
+#!/usr/bin/env python
+import argparse
+from pathlib import Path
+import re
+from sys import prefix
+from natsort import natsorted
+import lmdb
+import pickle
+from kellog import info, warning, error, debug
+try:
+	from cv_bridge import CvBridge
+except ModuleNotFoundError as e:
+	error(e)
+	error("Probably can't find ROS, did you unset PYTHONPATH?")
+	quit(1)
+from tqdm import tqdm
+from rosbag import Bag
+import yaml
+import rospy
+import numpy as np
+import cv2
+import open3d as o3d
+
+from typing import Tuple
+
+# TODO do this dynamically/with argparse
+# This is EVDodge's values, as used in simulator
+width = 346
+height = 260
+fx = 172.9999924379297
+fy = 172.9999924379297
+cx = 173.0
+cy = 130.0
+labels = [i + 1 for i in range(40)] # Match the values in model conversion script...
+bridge = CvBridge()
+eps = 1e-15
+framerate = 30 # FIX use actual image times instead
+maxVelFilter = 5 # m/s to be ignored
+
+# ==================================================================================================
+def main(args):
+	if args.input_path.is_file():
+		outputDir = args.input_path.parent / "lmdb" / args.input_path.stem
+		outputDir.mkdir(parents=True, exist_ok=True)
+	elif args.input_path.is_dir():
+		outputDir = args.input_path / "lmdb"
+		outputDir.mkdir(parents=True, exist_ok=True)
+	else:
+		raise TypeError(f"Input path '{args.input_path.name}' is not a normal file or directory")
+	# outputImg = outputDir / "img"
+	# outputFlow = outputDir / "flow"
+	# outputSeg = outputDir / "seg"
+	# outputDepth = outputDir / "depth"
+	# outputPCL = outputDir / "pcl"
+	# outputEvent = outputDir / "event"
+	outputDanger = outputDir / "danger"
+	# if not args.force and any([o.exists() for o in [outputImg, outputFlow, outputSeg, outputDepth, outputEvent, outputDanger]]):
+	# if not args.force and any([o.exists() for o in [outputImg, outputFlow, outputSeg, outputPCL, outputEvent, outputDanger]]):
+	if not args.force and any([o.exists() for o in [outputDanger]]):
+		warning(f"Output lmdb(s) exists in '{outputDir}'")
+		ans = input("Overwrite? [Y/n] ").lower()
+		if ans not in ["", "y"]:
+			warning("Not overwriting, aborting")
+			quit(0)
+	info(f"Saving output lmdbs in '{outputDir}'")
+
+	keys_img = set()
+	keys_flow = set()
+	keys_seg = set()
+	keys_pcl = set()
+	keys_event = set()
+	keys_danger = set()
+	# with lmdb.open(path=str(outputImg), map_size=2 ** 40) as lmdb_img, \
+	# 	lmdb.open(path=str(outputFlow), map_size=2 ** 40) as lmdb_flow, \
+	# 	lmdb.open(path=str(outputSeg), map_size=2 ** 40) as lmdb_seg, \
+	# 	lmdb.open(path=str(outputPCL), map_size=2 ** 40) as lmdb_pcl, \
+	# 	lmdb.open(path=str(outputEvent), map_size=2 ** 40) as lmdb_event, \
+	# 	lmdb.open(path=str(outputDanger), map_size=2 ** 40) as lmdb_danger:
+	with lmdb.open(path=str(outputDanger), map_size=2 ** 40) as lmdb_danger:
+		if args.input_path.is_file():
+			bagPath = args.input_path
+			with Bag(str(bagPath.resolve()), "r") as bag:
+				process_bag(
+					bag,
+					outputDir,
+					# (lmdb_img, lmdb_flow, lmdb_seg, lmdb_pcl, lmdb_event, lmdb_danger),
+					lmdb_danger,
+					# (keys_img, keys_flow, keys_seg, keys_pcl, keys_event, keys_danger),
+					keys_danger,
+					args
+				)
+		elif args.input_path.is_dir():
+			# Filter to valid bag names in case there are any extra bags in here (e.g. from debugging)
+			glob = [f for f in args.input_path.rglob("*.bag") if re.search(r"scene[0-9][0-9][0-9][0-9]_[0-9][0-9]_vh_clean_2.bag$", f.name)]
+			barBags = tqdm(total=len(glob), position=1)
+			for bagPath in natsorted(glob):
+				barBags.set_description(bagPath.name)
+				with Bag(str(bagPath.resolve()), "r") as bag:
+					process_bag(
+						bag,
+						outputDir,
+						# (lmdb_img, lmdb_flow, lmdb_seg, lmdb_pcl, lmdb_event, lmdb_danger),
+						# (keys_img, keys_flow, keys_seg, keys_pcl, keys_event, keys_danger),
+						lmdb_danger,
+						keys_danger,
+						args
+					)
+		m = min([len(s) for s in [keys_img, keys_flow, keys_seg, keys_pcl, keys_event, keys_danger]])
+		for env, keys in zip(
+			# [lmdb_img, lmdb_flow, lmdb_seg, lmdb_pcl, lmdb_event, lmdb_danger],
+			# [keys_img, keys_flow, keys_seg, keys_pcl, keys_event, keys_danger]
+			[lmdb_danger],
+			[keys_danger]
+		):
+			# Dirty hack to make the lists the same
+			keys = natsorted(keys)
+			keys = keys[1:m]
+			with env.begin(write=True) as txn:
+				txn.put(
+					key=pickle.dumps("keys"),
+					value=pickle.dumps(keys),
+					dupdata=False,
+				)
+
+
+# ==================================================================================================
+def process_bag(bag, outputDir: Path, lmdbs, keys, args):
+	# lmdb_img, lmdb_flow, lmdb_seg, lmdb_pcl, lmdb_event, lmdb_danger = lmdbs
+	# keys_img, keys_flow, keys_seg, keys_pcl, keys_event, keys_danger = keys
+	lmdb_danger = lmdbs
+	keys_danger = keys
+	imgIter = 0
+	eventImgIter = 0
+	firstImgTime = -1
+	events = []
+	imgTimes = []
+	eventCountImgs = []
+	eventTimeImgs = []
+	eventImgTimes = []
+	width = None # 346
+	height = None # 260
+
+	bagInfo = yaml.load(bag._get_yaml_info(), Loader=yaml.FullLoader)
+	topics = [t["topic"] for t in bagInfo["topics"] if t["topic"] in ["/cam0/image_raw", "/cam0/events", "/cam0/optic_flow", "/cam0/image_alpha", "/cam0/depthmap"]]
+	tStart = bag.get_start_time()
+	tStartRos = rospy.Time(tStart)
+
+	recvd = {}
+	seg = None
+	depth = None
+	prevDepth = None
+	prevFlow = None
+	flow = None
+	total = sum([topic["messages"] for topic in bagInfo["topics"] if topic["topic"] in ["/cam0/image_raw", "/cam0/events", "/cam0/optic_flow", "/cam0/image_alpha", "/cam0/depthmap"]])
+	# FIX total doesn't take into account start time in read_messages()
+	if args.debug:
+		warning("DEBUG MODE ENABLED, only doing first 0.5 seconds")
+		bar = tqdm(bag.read_messages(topics=topics, end_time=rospy.Time(tStart + 0.5)), total=total)
+	else:
+		bar = tqdm(bag.read_messages(topics=topics), total=total)
+	for topic, msg, t in bar:
+		stamp = msg.header.stamp.to_sec()
+		if topic != "/cam0/events":
+			bar.set_description(f"{imgIter} ({msg.header.stamp.to_sec():.2f}s)")
+		if topic != "/cam0/events" and stamp not in recvd:
+			if len(recvd) >= 10:
+				error("Problem with accumulation... timestamps:")
+				for r, v in recvd.items():
+					error(r, v)
+				quit(1)
+			recvd[stamp] = [False, False, False, False]
+		if topic == "/cam0/image_raw":
+			if recvd[stamp][0]:
+				error("Duplicate raw!")
+			else:
+				recvd[stamp][0] = True
+		if topic == "/cam0/image_raw":
+			if not width:
+				width = msg.width
+			if not height:
+				height = msg.height
+			# img = np.asarray(bridge.imgmsg_to_cv2(msg, msg.encoding))
+			# imgPath = save_img(img, imgIter, outputDir, prefix="img")
+			# save_to_lmdb(lmdb_img, imgPath, keys_img, args.keep)
+			time = msg.header.stamp
+			if imgIter > 0:
+				imgTimes.append(time)
+			else:
+				firstImgTime = time
+				eventImgTimes.append(time.to_sec())
+				# filter events we added previously
+				events = filter_events(events, eventImgTimes[-1] - tStart)
+		elif topic == "/cam0/optic_flow":
+			if recvd[stamp][1]:
+				error("Duplicate flow!")
+			else:
+				recvd[stamp][1] = True
+			flowPath, flow = save_flow(msg, imgIter, outputDir)
+			# save_to_lmdb(lmdb_flow, flowPath, keys_flow, args.keep)
+		elif topic == "/cam0/image_alpha":
+			if recvd[stamp][2]:
+				error("Duplicate alpha!")
+			else:
+				recvd[stamp][2] = True
+			seg = np.asarray(bridge.imgmsg_to_cv2(msg, msg.encoding))
+			seg = seg * np.in1d(seg, labels).reshape(seg.shape) # Remove non-class values
+			# segPath = save_img(seg, imgIter, outputDir, prefix="seg", isRGB=False)
+			# save_to_lmdb(lmdb_seg, segPath, keys_seg, args.keep)
+		elif topic == "/cam0/depthmap":
+			if recvd[stamp][3]:
+				error("Duplicate depth!")
+			else:
+				recvd[stamp][3] = True
+			depth = np.asarray(bridge.imgmsg_to_cv2(msg, msg.encoding))
+		elif topic == "/cam0/events" and msg.events:
+			# Add events to list.
+			for event in msg.events:
+				ts = event.ts
+				event = [event.x,
+						event.y,
+						(ts - tStartRos).to_sec(),
+						(float(event.polarity) - 0.5) * 2]
+				# Add event if it was after the first img or we haven't seen the first img
+				if firstImgTime == -1 or ts > firstImgTime:
+					events.append(event)
+			# if len(imgTimes) >= args.max_aug and events[-1][2] > (imgTimes[args.max_aug - 1] - tStartRos).to_sec():
+			# 	eventImgIter, keys_event = save_events(
+			# 		events,
+			# 		outputDir,
+			# 		imgTimes,
+			# 		eventCountImgs,
+			# 		eventTimeImgs,
+			# 		eventImgTimes,
+			# 		width,
+			# 		height,
+			# 		args.max_aug,
+			# 		args.n_skip,
+			# 		eventImgIter,
+			# 		tStartRos,
+			# 		lmdb_event,
+			# 		keys_event,
+			# 		args.keep
+			# 	)
+					# bagPath,
+		if topic != "/cam0/events" and recvd[stamp] == [True, True, True, True]:
+			# pclPath = save_pcl(depth, seg, imgIter, outputDir)
+			# save_to_lmdb(lmdb_pcl, pclPath, keys_pcl, args.keep)
+			dangerPath = generate_danger(depth, prevDepth, imgIter, outputDir, flow, prevFlow)
+			save_to_lmdb(lmdb_danger, dangerPath, keys_danger, args.keep, offset=1)
+			seg = None
+			prevDepth = depth
+			prevFlow = flow
+			depth = None
+			imgIter += 1
+			recvd.pop(stamp)
+	if len(recvd) > 0:
+		error(f"Unfinished timestamps: {len(recvd)} (out of {imgIter})")
+
+
+# ==================================================================================================
+def save_img(img, iter, outDir, prefix: str, isRGB: bool = True) -> Path:
+	if isRGB:
+		img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
+	outPath = outDir / f"{prefix}{iter:05d}.png"
+	cv2.imwrite(str(outPath), img)
+
+	return outPath
+
+
+# ==================================================================================================
+def save_events(events, outDir, imgTimes, eventCountImgs, eventTimeImgs,
+		eventImgTimes, width, height, max_aug, n_skip, eventImgIter, tStartRos, lmdb, keys, keep) -> Tuple[int, set]:
+	eventIter = 0
+	cutoffEventIter = 0
+	imgIter = 0
+	currImgTime = (imgTimes[imgIter] - tStartRos).to_sec()
+
+	eventCountImg = np.zeros((height, width, 2), dtype=np.uint16)
+	eventTimeImg = np.zeros((height, width, 2), dtype=np.float32)
+
+	while imgIter < len(imgTimes) and events[-1][2] > currImgTime:
+		x = events[eventIter][0]
+		y = events[eventIter][1]
+		t = events[eventIter][2]
+
+		if t > currImgTime:
+			eventCountImgs.append(eventCountImg)
+			eventCountImg = np.zeros((height, width, 2), dtype=np.uint16)
+			eventTimeImgs.append(eventTimeImg)
+			eventTimeImg = np.zeros((height, width, 2), dtype=np.float32)
+			cutoffEventIter = eventIter
+			eventImgTimes.append(imgTimes[imgIter].to_sec())
+			imgIter += n_skip
+			if (imgIter < len(imgTimes)):
+				currImgTime = (imgTimes[imgIter] - tStartRos).to_sec()
+
+		if events[eventIter][3] > 0:
+			eventCountImg[y, x, 0] += 1
+			eventTimeImg[y, x, 0] = t
+		else:
+			eventCountImg[y, x, 1] += 1
+			eventTimeImg[y, x, 1] = t
+
+		eventIter += 1
+
+	del imgTimes[:imgIter]
+	del events[:cutoffEventIter]
+
+	if len(eventCountImgs) >= max_aug:
+		n_to_save = len(eventCountImgs) - max_aug + 1
+		for i in range(n_to_save):
+			imgTimesOut = np.array(eventImgTimes[i:i + max_aug + 1])
+			imgTimesOut = imgTimesOut.astype(np.float64)
+			eventTimeImgsNP = np.array(eventTimeImgs[i:i + max_aug], dtype=np.float32)
+			eventTimeImgsNP -= imgTimesOut[0] - tStartRos.to_sec()
+			eventTimeImgsNP = np.clip(eventTimeImgsNP, a_min=0, a_max=None)
+
+			now = np.array([np.array(eventCountImgs[i:i + max_aug]), eventTimeImgsNP, imgTimesOut], dtype=object)
+			outPath = outDir / f"event{eventImgIter:05d}.npz"
+			np.savez_compressed(outPath, now)
+			save_to_lmdb(lmdb, outPath, keys, keep)
+			eventImgIter += n_skip
+
+		del eventCountImgs[:n_to_save]
+		del eventTimeImgs[:n_to_save]
+		del eventImgTimes[:n_to_save]
+	return eventImgIter, keys
+
+
+# ==================================================================================================
+def save_flow(msg, iter, outDir):
+	# It's row-major order
+	flow = np.array([
+		np.array(msg.flow_x).reshape([msg.height, msg.width]),
+		np.array(msg.flow_y).reshape([msg.height, msg.width])
+	])
+	# outPath = outDir / f"flow{iter:05d}.npz"
+	outPath = None
+	# np.savez_compressed(str(outPath), flow.astype(np.float32))
+
+	return outPath, flow
+
+
+# ==================================================================================================
+def save_pcl(depth, seg, iter, outDir: Path) -> Path:
+	# Duplicating channels because o3d doesn't seem to be able to do greyscale
+	depth = o3d.geometry.Image(depth)
+	seg = o3d.geometry.Image(cv2.cvtColor(seg, cv2.COLOR_GRAY2RGB))
+	rgbd = o3d.geometry.RGBDImage.create_from_color_and_depth(seg, depth, depth_scale=1.0, depth_trunc=50.0, convert_rgb_to_intensity=False)
+	intrinsic = o3d.camera.PinholeCameraIntrinsic(width, height, fx, fy, cx, cy)
+	# pcl = o3d.geometry.PointCloud.create_from_depth_image(depth, intrinsic)
+	pcl = o3d.geometry.PointCloud.create_from_rgbd_image(rgbd, intrinsic, project_valid_depth_only=False)
+	# File extension dictates the format
+	# NOTE: `compressed` does not seem to have an effect when using 'xyzrgb'
+	# o3d.io.write_point_cloud(str(path / f"pcl{iter:05d}.pcd"), pcl, compressed=True)
+	outPath = outDir / f"pcl{iter:05d}.npz"
+	np.savez_compressed(str(outPath), np.array([pcl.points, pcl.colors], dtype=np.float32))
+
+	return outPath
+
+
+# ==================================================================================================
+def filter_events(events, ts):
+	"""Removes all events with timestamp lower than the specified one
+
+	Args:
+		events (list): the list of events in form of (x, y, t, p)
+		ts (float): the timestamp to split events
+
+	Return:
+		(list): a list of events with timestamp above the threshold
+	"""
+	tss = np.array([e[2] for e in events])
+	idxArray = np.argsort(tss) # I hope it"s not needed
+	i = np.searchsorted(tss[idxArray], ts)
+	return [events[k] for k in idxArray[i:]]
+
+
+# ==================================================================================================
+def generate_danger(depth, prevDepth, iter: int, outDir, flow, prevFlow) -> Path:
+	if prevDepth is None and prevFlow is None:
+		warning("prevDepth and prevFlow are None")
+		return None
+
+	# Flow is pixels/second floating point
+	depth[np.isinf(depth)] = -1
+	prevDepth[np.isinf(prevDepth)] = -1
+
+	# cv2.remap LOOKS UP every pixel with a coordinate, we can't 'apply' the optical flow as a transform directly.
+	# So the optical flow needs to be registered with the target depth for that frame, otherwise we try to look up pixels for which have no optical flow!
+	# So... we need the future depth frame and its optical flow to tell us how that depth maps to this frame
+	# Or more accurately, we need the current depth frame and its optical flow to tell us how that depth frame maps the the PREVIOUS depth frame
+	coords = np.indices(prevDepth.shape).transpose(1, 2, 0).astype(np.float32)
+	coords = coords[:, :, ::-1]
+	coords -= (flow / framerate).transpose(1, 2, 0)
+	prevDepthWarped = cv2.remap(prevDepth, coords, None, interpolation=cv2.INTER_NEAREST, borderMode=cv2.BORDER_CONSTANT, borderValue=-1)
+	# Although, the optical flow VALUES I think are not perfectly aligned with the time frame (maybe they tell us about the instantaneous velocity, not between the last two frames)
+	# The issue with this is that a few discontinuities are propogated and duplicated.
+	# First we can definitely re-exclude where we don't currently have depth
+	prevDepthWarped[depth <= 0] = 0
+	danger = prevDepthWarped - depth # Moving closer at m/frame
+	danger = danger * framerate # Moving closer at m/s
+	danger[danger < 0] = 0 # We only care about positive values
+
+	# We don't know these values, allow downstream tasks to ignore these
+	# To filter the most severe remaining positive discontinuity problems, things getting closer very quickly is unrealistic
+	danger[danger > maxVelFilter] = -1 # m/s
+	danger = danger / (depth + eps) # Danger is now 1/s
+	danger[depth <= 0] = -1
+	danger[prevDepthWarped <= 0] = -1
+
+	outPath = outDir / f"danger{iter:05d}.npz"
+	np.savez_compressed(str(outPath), danger.astype(np.float32))
+
+	# DEBUG stuff
+	# cv2.imwrite(str(outPath), danger)
+	# danger = danger * 255
+	# danger[danger > 255] = 255
+	# outPath0 = outDir / f"danger{iter:05d}_.png"
+	# cv2.imwrite(str(outPath0), danger)
+	# depth = depth * 255
+	# depth[depth > 255] = 255
+	# outPath1 = outDir / f"depth{iter:05d}.png"
+	# cv2.imwrite(str(outPath1), depth)
+	# outPath2 = outDir / f"prevDepthWarped{iter:05d}.png"
+	# prevDepthWarped = prevDepthWarped * 255
+	# prevDepthWarped[prevDepthWarped > 255] = 255
+	# cv2.imwrite(str(outPath2), prevDepthWarped)
+
+	return outPath
+
+
+# ==================================================================================================
+# def save_to_lmdb(env, path: Path, iter: int, bagPath: Path, keys: set, keep: bool = False):
+# def save_to_lmdb(env, path: Path, iter: int, keys: set, keep: bool = False):
+def save_to_lmdb(env, path: Path, keys: set, keep: bool = False, offset: int = 0):
+	if path is None:
+		warning("Cannot write empty path")
+		return
+	# key = f"{bagPath.stem}/{iter}"
+	key = str(len(keys) + offset)
+	keys.add(key)
+	with env.begin(write=True) as txn:
+		with open(path, mode="rb") as file:
+			txn.put(
+				key=pickle.dumps(key),
+				value=file.read(),
+				dupdata=False,
+			)
+	if not keep:
+		path.unlink()
+
+
+# ==================================================================================================
+def parse_args():
+	parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
+	parser.add_argument("input_path", type=Path,
+						help="Directory to search for ROS bags, or path to a ROS bag")
+	parser.add_argument("--max_aug",
+						help="Maximum number of images to combine for augmentation.",
+						type=int,
+						default=6)
+	parser.add_argument("--n_skip",
+						help="Maximum number of images to combine for augmentation.",
+						type=int,
+						default=1)
+	parser.add_argument("-f", "--force", action="store_true", help="Overwrite output LMDB if any part of it exists")
+	parser.add_argument("-d", "--debug", action="store_true", help="Only do 0.5 seconds as a test")
+	parser.add_argument("-k", "--keep", action="store_true", help="Don't remove intermediate files")
+
+	return parser.parse_args()
+
+
+# ==================================================================================================
+if __name__ == "__main__":
+	main(parse_args())
diff --git a/evreflex/tools/scannet_ply2obj.py b/evreflex/tools/scannet_ply2obj.py
new file mode 100755
index 0000000000000000000000000000000000000000..72759b0e6225fea7d762e6fbd63bd94a34a8c18f
--- /dev/null
+++ b/evreflex/tools/scannet_ply2obj.py
@@ -0,0 +1,328 @@
+import bpy
+from pathlib import Path
+import numpy as np
+from kellog import info, warning, error, debug
+from tqdm import tqdm
+from natsort import natsorted
+import numpy_indexed as npi
+import subprocess
+import tempfile
+import os
+import cv2
+import math
+from io_scene_obj import import_obj
+
+categories = np.array([
+	(0, 0, 0), # IGNORE
+	(174, 199, 232), # wall
+	(152, 223, 138), # floor
+	(31, 119, 180), # cabinet
+	(255, 187, 120), # bed
+	(188, 189, 34), # chair
+	(140, 86, 75), # sofa
+	(255, 152, 150), # table
+	(214, 39, 40), # door
+	(197, 176, 213), # window
+	(148, 103, 189), # bookshelf
+	(196, 156, 148), # picture
+	(23, 190, 207), # counter
+	(178, 76, 76), # blinds
+	(247, 182, 210), # desk
+	(66, 188, 102), # shelves
+	(219, 219, 141), # curtain
+	(140, 57, 197), # dresser
+	(202, 185, 52), # pillow
+	(51, 176, 203), # mirror
+	(200, 54, 131), # floor mat
+	(92, 193, 61), # clothes
+	(78, 71, 183), # ceiling
+	(172, 114, 82), # books
+	(255, 127, 14), # fridge
+	(91, 163, 138), # tv
+	(153, 98, 156), # paper
+	(140, 153, 101), # towel
+	(158, 218, 229), # shower curtain
+	(100, 125, 154), # box
+	(178, 127, 135), # white board
+	(120, 185, 128), # person
+	(146, 111, 194), # night stand
+	(44, 160, 44), # toilet
+	(112, 128, 144), # sink
+	(96, 207, 209), # lamp
+	(227, 119, 194), # bathtub
+	(213, 92, 176), # bag
+	(94, 106, 211), # other struct
+	(82, 84, 163), # other furniture
+	(100, 85, 144), # other prop
+], dtype=np.uint8)
+# Because OpenCV does BGR down there
+categories = np.array([[c[2], c[1], c[0]] for c in categories], dtype=np.uint8)
+
+
+# ==================================================================================================
+def iterate(path: Path):
+	# ("*/") matches files too
+	scenes = [i for i in natsorted(path.glob("scene*")) if i.is_dir()]
+	if not len(scenes):
+		raise RuntimeError("No directories found starting with 'scene'")
+	for scene in tqdm(scenes):
+		if not scene.is_dir():
+			continue
+		try:
+			process(scene)
+		except Exception as e:
+			error(e)
+			res = 8192
+			warning(f"Retrying scene with a higher resolution texture ({res})")
+			try:
+				process(scene, res)
+			except Exception as e:
+				error(e)
+				res = 16384
+				warning(f"Retrying scene with a higher resolution texture ({res})")
+				process(scene, res)
+		# break
+	debug("DONE")
+
+
+# ==================================================================================================
+def process(scene, resolution: int = 4096):
+	debug(scene)
+	clear_scene()
+	rgbPath = scene / f"{scene.name}_vh_clean_2.ply"
+	labelsPath = scene / f"{scene.name}_vh_clean_2.labels.ply"
+	if not rgbPath.exists() or not labelsPath.exists():
+		if not rgbPath.exists():
+			raise RuntimeError(f"Cannot find '{rgbPath}'")
+		if not labelsPath.exists():
+			raise RuntimeError(f"Cannot find '{labelsPath}'")
+
+	map_uvs(labelsPath, rgbPath, size=resolution)
+	os.remove(labelsPath.with_suffix(".obj"))
+	os.remove(f'{labelsPath.with_suffix(".obj")}.mtl')
+	model = import_model(rgbPath.with_suffix(".obj"))
+	centre_object_xy(model[1])
+	bpy.ops.object.origin_set(type="ORIGIN_CENTER_OF_MASS")
+	correct_texture(model[1], labelsPath.with_suffix(".png"))
+	merge_textures(labelsPath.with_suffix(".png"), rgbPath.with_suffix(".png"))
+	os.remove(labelsPath.with_suffix(".png"))
+
+	debug("\tExporting centred mesh")
+	bpy.ops.export_scene.obj(
+		filepath=str(rgbPath.with_suffix(".obj")),
+		check_existing=False,
+		axis_forward="Y",
+		axis_up="Z",
+		use_triangles=True,
+		keep_vertex_order=True,
+		global_scale=1.0,
+		path_mode="COPY"
+	)
+
+
+# ==================================================================================================
+def clear_scene():
+	debug("\tClearing scene")
+	bpy.ops.object.select_all(action="DESELECT")
+	# Blender version >= 2.8x?
+	for obj in bpy.context.scene.objects:
+		if obj.type == "MESH":
+			obj.select_set(True)
+		else:
+			obj.select_set(False)
+	bpy.ops.object.delete()
+	clean_deleted()
+
+
+# ==================================================================================================
+def clean_deleted():
+	# Remove unused blocks without closing and reopening blender
+	for block in bpy.data.meshes:
+		if block.users == 0:
+			bpy.data.meshes.remove(block)
+	for block in bpy.data.materials:
+		if block.users == 0:
+			bpy.data.materials.remove(block)
+	for block in bpy.data.textures:
+		if block.users == 0:
+			bpy.data.textures.remove(block)
+	for block in bpy.data.images:
+		if block.users == 0:
+			bpy.data.images.remove(block)
+
+
+# ==================================================================================================
+def map_uvs(labelsPath: Path, rgbPath: Path, size: int = 4096) -> Path:
+	debug("\tMapping UVs")
+	with tempfile.NamedTemporaryFile(mode="w+") as tmp:
+		tmp.write(generate_script(labelsPath.with_suffix(".png"), size=size))
+		tmp.seek(0)
+		# Generate RGB texture AND export the mesh as a .obj file
+		with subprocess.Popen(["meshlabserver", "-i", labelsPath, "-o", labelsPath.with_suffix(".obj"), "-m", "wt", "-s", tmp.name]) as ps:
+			pass
+	with tempfile.NamedTemporaryFile(mode="w+") as tmp:
+		tmp.write(generate_script(rgbPath.with_suffix(".png"), size=size))
+		tmp.seek(0)
+		# Generate RGB texture
+		with subprocess.Popen(["meshlabserver", "-i", rgbPath, "-o", rgbPath.with_suffix(".obj"), "-m", "wt", "-s", tmp.name]) as ps:
+			pass
+
+
+# ==================================================================================================
+def import_model(path: Path):
+	debug(f"\tImporting '{path}'")
+	# bpy.ops.import_mesh.obj(filepath=str(path))
+	context = bpy.context
+	import_obj.load(context, filepath=str(path))
+
+	return get_model()
+
+
+# ==================================================================================================
+def get_model():
+	scene = bpy.data.scenes["Scene"]
+	objects = scene.objects.items()
+	models = [o for o in objects if o[0].startswith("scene")]
+	assert(len(models) == 1)
+
+	return models[0]
+
+
+# ==================================================================================================
+def vertices_equal(a: bpy.types.Mesh, b: bpy.types.Mesh) -> bool:
+	"""Equal `Mesh.polygons.vertices` using __eq__ operator return False, so this function does it manually."""
+	return all([c == l for ap, bp in zip(a.polygons, b.polygons) for c, l in zip(ap.vertices, bp.vertices)])
+
+
+# ==================================================================================================
+def verify_models(colour_model, label_model):
+	debug("\tChecking for equal number of polygons...")
+	if len(colour_model.polygons) != len(label_model.polygons):
+		raise ValueError(f"Number of polygons do not match: {len(colour_model.polygons)} RGB vs. {len(label_model.polygons)} label")
+
+	debug("\tChecking for equal vertices...")
+	if not vertices_equal(colour_model, label_model):
+		raise ValueError(f"Vertices do not match!")
+
+
+# ==================================================================================================
+def centre_object_xy(model):
+	debug("\tCentering mesh on XY")
+	bpy.context.view_layer.objects.active = model
+	bpy.ops.object.origin_set(type='ORIGIN_CENTER_OF_VOLUME')
+	bpy.context.object.location.x = 0.0
+	bpy.context.object.location.y = 0.0
+	bpy.ops.object.transform_apply()
+
+
+# ==================================================================================================
+def correct_texture(model, labelsPath):
+	debug("\tCorrecting texture interpolation")
+	me = model.data
+	assert(len(me.loops) % 3 == 0)
+	img = cv2.imread(str(labelsPath))
+
+	bg = np.zeros_like(img)
+	# for i in tqdm(range(0, len(me.loops), 3)):
+	for i in range(0, len(me.loops), 3):
+		t1 = me.uv_layers.active.data[i].uv
+		t2 = me.uv_layers.active.data[i + 1].uv
+		t3 = me.uv_layers.active.data[i + 2].uv
+
+		# `round()` seems to do the best job
+		x1 = math.floor(img.shape[1] * t1.x)
+		x2 = math.floor(img.shape[1] * t2.x)
+		x3 = math.floor(img.shape[1] * t3.x)
+		y1 = img.shape[0] - round(img.shape[0] * t1.y)
+		y2 = img.shape[0] - round(img.shape[0] * t2.y)
+		y3 = img.shape[0] - round(img.shape[0] * t3.y)
+		c1 = img[y1, x1]
+		c2 = img[y2, x2]
+		c3 = img[y3, x3]
+
+		if (c1 == c2).all() and (c1 == c2).all() and (c1 == c3).all():
+			c = c1.tolist()
+		if (c1 == c2).all() or (c1 == c3).all():
+			c = c1.tolist()
+		elif (c2 == c3).all():
+			c = c2.tolist()
+		else:
+			c = c1.tolist()
+
+		contours = np.array([[x1, y1], [x2, y2], [x3, y3]]).astype(int)
+		cv2.drawContours(bg, [contours], 0, c, thickness=2)
+		cv2.fillPoly(bg, pts=[contours], color=c)
+	cv2.imwrite(str(labelsPath), bg)
+
+
+# ==================================================================================================
+def merge_textures(labelPath, rgbPath):
+	debug("\tMerging textures")
+	labels = cv2.imread(str(labelPath))
+	rgb = cv2.imread(str(rgbPath))
+
+	labels_ = labels.reshape([labels.shape[0] * labels.shape[1], 3])
+
+	indices = npi.indices(categories, labels_, missing=0)
+	# debug(np.unique(indices))
+	labels__ = indices.reshape([*labels.shape[:2], 1]).astype(np.uint8)
+	img = np.dstack([rgb, labels__])
+	cv2.imwrite(str(rgbPath), img)
+
+
+# ==================================================================================================
+class OT_TestOpenFilebrowser(bpy.types.Operator):
+	bl_idname = "test.open_filebrowser"
+	bl_label = "Choose"
+
+	# Instead of inheriting from ImportHelper
+	directory: bpy.props.StringProperty(
+		name="Path",
+		description="Path to directory"
+	)
+	# ----------------------------------------------------------------------------------------------
+	def invoke(self, context, event):
+		"""Called when registering the class."""
+		context.window_manager.fileselect_add(self)
+
+		return {"RUNNING_MODAL"}
+
+	# ----------------------------------------------------------------------------------------------
+	def execute(self, context):
+		"""Called when the file browser is submitted"""
+		iterate(Path(self.directory))
+
+		bpy.utils.unregister_class(self.__class__)
+
+		return {"FINISHED"}
+
+
+# ==================================================================================================
+if __name__ == "__main__":
+	print("--------------------------------")
+	bpy.context.scene.render.engine = "CYCLES"
+	bpy.utils.register_class(OT_TestOpenFilebrowser)
+	bpy.ops.test.open_filebrowser("INVOKE_DEFAULT")
+
+
+# ==================================================================================================
+def generate_script(path: Path, size: int = 4096) -> str:
+	return f"""\
+<!DOCTYPE FilterScript>
+<FilterScript>
+ <filter name="Parametrization: Trivial Per-Triangle">
+  <Param tooltip="Indicates how many triangles have to be put on each line (every quad contains two triangles)&#xa;Leave 0 for automatic calculation" isxmlparam="0" description="Quads per line" value="0" name="sidedim" type="RichInt"/>
+  <Param tooltip="Gives an indication on how big the texture is" isxmlparam="0" description="Texture Dimension (px)" value="{size}" name="textdim" type="RichInt"/>
+  <Param tooltip="Specifies how many pixels to be left between triangles in parametrization domain" isxmlparam="0" description="Inter-Triangle border (px)" value="4" name="border" type="RichInt"/>
+  <Param tooltip="Choose space optimizing to map smaller faces into smaller triangles in parametrizazion domain" enum_cardinality="2" isxmlparam="0" description="Method" enum_val1="Space-optimizing" value="0" name="method" enum_val0="Basic" type="RichEnum"/>
+ </filter>
+ <filter name="Transfer: Vertex Color to Texture">
+  <Param name="textName" type="RichString" tooltip="The texture file to be created" value="{path.name}" description="Texture file" isxmlparam="0"/>
+  <Param name="textW" type="RichInt" tooltip="The texture width" value="{size}" description="Texture width (px)" isxmlparam="0"/>
+  <Param name="textH" type="RichInt" tooltip="The texture height" value="{size}" description="Texture height (px)" isxmlparam="0"/>
+  <Param name="overwrite" type="RichBool" tooltip="if current mesh has a texture will be overwritten (with provided texture dimension)" value="false" description="Overwrite texture" isxmlparam="0"/>
+  <Param name="assign" type="RichBool" tooltip="assign the newly created texture" value="true" description="Assign texture" isxmlparam="0"/>
+  <Param name="pullpush" type="RichBool" tooltip="if enabled the unmapped texture space is colored using a pull push filling algorithm, if false is set to black" value="true" description="Fill texture" isxmlparam="0"/>
+ </filter>
+</FilterScript>
+"""
diff --git a/evreflex/train_evreflex.py b/evreflex/train_evreflex.py
new file mode 100755
index 0000000000000000000000000000000000000000..ccd17f731647f4c0ab09cff7a4a2fbcbd46529ac
--- /dev/null
+++ b/evreflex/train_evreflex.py
@@ -0,0 +1,99 @@
+#!/usr/bin/env python3
+import colored_traceback.auto
+import argparse
+from pathlib import Path
+from torchvision import transforms
+import pytorch_lightning as pl
+from pytorch_lightning.callbacks import ModelCheckpoint, LearningRateMonitor
+from pytorch_lightning import loggers
+import setproctitle
+from kellog import info, warning, error, debug
+
+from evreflex import utils
+from evreflex.models import EVReflexNet
+
+# ==================================================================================================
+def main(args):
+	setproctitle.setproctitle(args.proctitle)
+	model = EVReflexNet(
+		operation=args.operation,
+		train_path=args.train_path,
+		val_path=args.val_path,
+		test_path=None,
+		batch_size=args.batch_size,
+		workers=args.workers,
+		lr=args.lr,
+		blocks=[3, 2],
+		subset=args.subset,
+	)
+	if args.version is not None:
+		import torch
+		checkpointPath = args.out_dir / Path(__file__).stem / f"version_{args.version}" / "checkpoints" / "last.ckpt"
+		debug(checkpointPath)
+		debug(checkpointPath.exists())
+	else:
+		checkpointPath = None
+	# `monitor` should correspond to something logged with `self.log`
+	# `monitor` will apply in `validation_step()` if it's there, else `training_step()`
+	checkpoint_callback = ModelCheckpoint(
+		# monitor="train_loss",
+		monitor="val_loss",
+		filename="{epoch:02d} {train_loss:.2f}",
+		save_top_k=3,
+		save_last=True
+	)
+	if args.out_dir is None:
+		warning("'out_dir' was not specified, not logging!")
+		logger = None
+		lr_callback = None
+		callbacks = None
+	else:
+		logger = loggers.TensorBoardLogger(
+			name=Path(__file__).stem,
+			save_dir=args.out_dir,
+			log_graph=True,
+			version=args.version
+		)
+		# Default placeholder "hp_metric" neccessary so that the hyperparameters are written to file
+		logger.log_hyperparams(args)
+		logger.log_hyperparams(utils.get_system_info())
+		lr_callback = LearningRateMonitor(logging_interval="epoch")
+		callbacks = [checkpoint_callback, lr_callback]
+	trainer = pl.Trainer(
+		logger=logger,
+		gpus=0 if args.cpu else -1,
+		# resume_from_checkpoint=Path(),
+		callbacks=callbacks,
+		limit_train_batches=args.train_lim,
+		limit_val_batches=args.val_lim,
+		overfit_batches=args.overfit,
+		resume_from_checkpoint=checkpointPath,
+	)
+
+	trainer.fit(model)
+
+
+# ==================================================================================================
+def parse_args():
+	parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
+	parser.add_argument("train_path", type=Path, help="Path to EVReflex dataset, either LMDB directory or `.txt` list of relative paths")
+	parser.add_argument("val_path", type=Path, help="Path to EVReflex dataset, either LMDB directory or `.txt` list of relative paths")
+	parser.add_argument("operation", type=str, choices=["depth", "flow", "both", "dynamic_depth", "dynamic_both"], help="Network flavour")
+	parser.add_argument("--out_dir", type=Path, help="Location to store checkpoints and logs", default=None)
+	parser.add_argument("--proctitle", type=str, help="Process title", default=Path(__file__).name)
+	parser.add_argument("--cpu", action="store_true", help="Run on CPU even if GPUs are available")
+	parser.add_argument("-b", "--batch_size", type=int, help="Batch size", default=8)
+	parser.add_argument("--lr", type=float, help="Initial learning rate", default=0.01)
+	parser.add_argument("-w", "--workers", type=int, help="Number of workers for data loader", default=4)
+	parser.add_argument("--train_lim", type=float, help="Proportion of train epoch length to use", default=1.0)
+	parser.add_argument("--val_lim", type=float, help="Proportion of val epoch length to use", default=1.0)
+	parser.add_argument("--overfit", type=float, help="Proportion of train dataset to use", default=0.0)
+	parser.add_argument("--subset", type=float, help="Proportion of train dataset to use", default=None)
+	parser.add_argument("-v", "--version", type=int, help="Try to continue training from this version", default=None)
+
+	return parser.parse_args()
+
+
+# ==================================================================================================
+if __name__ == "__main__":
+	main(parse_args())
diff --git a/evreflex/utils/__init__.py b/evreflex/utils/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..f8a6c3515279740f31696a6e61d463ca711e7bac
--- /dev/null
+++ b/evreflex/utils/__init__.py
@@ -0,0 +1,8 @@
+"""Useful functions."""
+from .utils import get_system_info
+from .utils import get_gpu_info
+from .utils import get_git_rev
+from .utils import proctitle
+from .utils import flow_viz
+from .utils import downsample_image_tensor
+from .utils import calc_iou
diff --git a/evreflex/utils/utils.py b/evreflex/utils/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..5a8112e9c60ce6b245b66eabe83bfb36113ca0b1
--- /dev/null
+++ b/evreflex/utils/utils.py
@@ -0,0 +1,108 @@
+#!/usr/bin/env python3
+import torch
+import platform
+from pathlib import Path
+from git import Repo
+import inspect
+import time
+import setproctitle
+import numpy as np
+import cv2
+
+eps = 1e-12
+
+# ==================================================================================================
+def get_system_info():
+	return {
+		"hostname": platform.uname()[1],
+		"gpus": get_gpu_info(),
+	}
+
+
+# ==================================================================================================
+def get_gpu_info():
+	gpus = []
+	for gpu in range(torch.cuda.device_count()):
+		properties = torch.cuda.get_device_properties(f"cuda:{gpu}")
+		gpus.append({
+			"name": properties.name,
+			"memory": round(properties.total_memory / 1e9, 2),
+			"capability": f"{properties.major}.{properties.minor}",
+		})
+
+	return gpus
+
+
+#===================================================================================================
+def get_git_rev(cwd=Path(inspect.stack()[1][1]).parent): # Parent of called script by default
+	repo = Repo(cwd)
+	sha = repo.head.commit.hexsha
+	output = repo.git.rev_parse(sha, short=7)
+	if repo.is_dirty():
+		output += " (dirty)"
+	output += " - " + time.strftime("%a %d/%m/%Y %H:%M", time.gmtime(repo.head.commit.committed_date))
+
+	return output
+
+
+# ==================================================================================================
+def proctitle(title: str = None):
+	import __main__
+	if title:
+		setproctitle.setproctitle(title)
+	else:
+		setproctitle.setproctitle(Path(__main__.__file__).name)
+
+
+# ==================================================================================================
+def flow_viz(tensor: torch.Tensor) -> torch.Tensor:
+	assert tensor.dim() == 4
+	assert tensor.shape[1] == 2
+	flowImg = torch.zeros([tensor.shape[0], 3, tensor.shape[2], tensor.shape[3]], dtype=torch.uint8)
+	for i, f in enumerate(flowImg):
+		flowImg[i] = torch.tensor(_flow_viz_np(tensor[i, 0, ...].numpy(), tensor[i, 1, ...].numpy())).permute(2, 0, 1)
+
+	return flowImg
+
+
+# ==================================================================================================
+def _flow_viz_np(flow_x, flow_y):
+	flows = np.stack((flow_x, flow_y), axis=2)
+	flows[np.isinf(flows)] = 0
+	flows[np.isnan(flows)] = 0
+	mag = np.linalg.norm(flows, axis=2)
+	ang = np.arctan2(flow_y, flow_x)
+	p = cv2.normalize(mag, None, 0, 255, cv2.NORM_MINMAX)
+	ang += np.pi
+	ang *= 180. / np.pi / 2.
+	ang = ang.astype(np.uint8)
+	hsv = np.zeros([flow_x.shape[0], flow_x.shape[1], 3], dtype=np.uint8)
+	hsv[:, :, 0] = ang
+	hsv[:, :, 1] = 255
+	hsv[:, :, 2] = cv2.normalize(mag, None, 0, 255, cv2.NORM_MINMAX)
+	flow_rgb = cv2.cvtColor(hsv, cv2.COLOR_HSV2BGR)
+
+	return flow_rgb
+
+
+# ==================================================================================================
+def downsample_image_tensor(tensor):
+	assert(tensor.dim() >= 1)
+	assert(tensor.shape[-1] % 2 == 0 and tensor.shape[-2] % 2 == 0)
+	down = tensor[..., ::2, ::2]
+	up = down.repeat_interleave(2, -2).repeat_interleave(2, -1)
+
+	return up
+
+
+# ==================================================================================================
+def calc_iou(tensor, target, mask=None):
+	intersection = (tensor & target).float() # AND
+	union = (tensor | target).float() # OR
+	# Subtract ignored pixels from both intersection and union
+	if mask is not None:
+		intersection[mask] = 0
+		union[mask] = 0
+	iou = (intersection.sum((-2, -1)) + eps) / (union.sum((-2, -1)) + eps)
+
+	return iou
diff --git a/requirements.txt b/requirements.txt
new file mode 100644
index 0000000000000000000000000000000000000000..c4df353cbec199a178a66b8d9bfd0a5b1c681725
--- /dev/null
+++ b/requirements.txt
@@ -0,0 +1,13 @@
+numpy
+pytorch
+pytorch-lightning
+lmdb
+opencv-python
+Pillow-SIMD
+open3d
+tqdm
+natsort
+setproctitle
+colored_traceback
+kellog
+gitpython
diff --git a/setup.py b/setup.py
new file mode 100644
index 0000000000000000000000000000000000000000..4a9a0f5301a7838688bc7e9a440f7c22f0731563
--- /dev/null
+++ b/setup.py
@@ -0,0 +1,15 @@
+import setuptools
+
+setuptools.setup(
+	name="evreflex",
+	version="0.0.1",
+	author="Celyn Walters",
+	author_email="celyn.walters@surrey.ac.uk",
+	url="https://gitlab.eps.surrey.ac.uk/cw0071/evreflex/",
+	packages=setuptools.find_packages(),
+	classifiers=[
+		"Programming Language :: Python :: 3",
+		"Operating System :: OS Independent",
+	],
+	python_requires=">=3.6",
+)