diff --git a/.gitignore b/.gitignore
new file mode 100644
index 0000000000000000000000000000000000000000..5d2af0d67735d6ad75cf65a7a7d8174781795c72
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1,4 @@
+
+.idea/
+
+__pycache__/
diff --git a/env.yml b/env.yml
index 3bfc0b376295f5748cc68b706b438f7460b733b9..c75e09f37a14b8bb9929b8261033f1c86e827a60 100644
--- a/env.yml
+++ b/env.yml
@@ -7,6 +7,7 @@ dependencies:
   - numpy
   - opencv=4.5.0=py38_2
   - pandas
+  - matplotlib==3.3.4
   - pillow
   - pip=20.2.4
   - py-opencv=4.5.0
diff --git a/example.py b/example.py
index 84de92b90683bd4395fc9e748038c34e48fa2d30..f0308eb43e51c618914287d8eb76e8528277e7c5 100644
--- a/example.py
+++ b/example.py
@@ -1,6 +1,7 @@
 from model import Model
 import torch.nn.functional as F
 from utils import *
+import matplotlib
 
 
 def inference(original, query):
@@ -12,7 +13,7 @@ def inference(original, query):
         if pred == 0:
             grid[i] *= 0
 
-    heatmap = short_summary_image(dewarped_query.squeeze(0), prediction=grid)
+    heatmap = short_summary_image(dewarped_query.squeeze(0), prediction=grid, size=original.size)
     heatmap.show()
     print(f'The image is {model.cls2name[cls.item()]}.')
 
@@ -21,11 +22,11 @@ if __name__ == '__main__':
     model = Model.load_from_checkpoint('weights/best.ckpt')
     model.eval().cuda()
 
-    original = Image.open('example/original.jpg').convert('RGB')
-    tampered = Image.open('example/tampered.jpg').convert('RGB')
-    tampered_benign = Image.open('example/tampered_benign.jpg').convert('RGB')
+    original = Image.open('example/nonsquare_real.jpg').convert('RGB')
+    tampered = Image.open('example/nonsquare_tampered.jpg').convert('RGB')
+    # tampered_benign = Image.open('example/tampered_benign.jpg').convert('RGB')
 
     inference(original, original)
     inference(original, tampered)
-    inference(original, tampered_benign)
+    # inference(original, tampered_benign)
 
diff --git a/example/nonsquare_real.jpg b/example/nonsquare_real.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..daf645e7f7fae570b939250267d9b2a73d9c4ae6
Binary files /dev/null and b/example/nonsquare_real.jpg differ
diff --git a/example/nonsquare_tampered.jpg b/example/nonsquare_tampered.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..50c42472091cfe4d5a9049a88703e218c331d78f
Binary files /dev/null and b/example/nonsquare_tampered.jpg differ
diff --git a/model.py b/model.py
index 703f1f1a9149ed65106a9dced509e1dab4a845dc..552149447b7e45c5df419d0b799e385ed2881e26 100644
--- a/model.py
+++ b/model.py
@@ -15,7 +15,7 @@ class Model(pl.LightningModule):
         # Hyperparameters
         self.hparams = hparams
 
-        self.cls2name = {0: 'same', 1: 'tampered', 2: 'different'}
+        self.cls2name = {0: 'not tampered', 1: 'tampered', 2: 'different'}
 
         self.raft = RAFT(argparse.Namespace(alternate_corr=False, mixed_precision=False, small=False))
         self.raft = torch.nn.DataParallel(self.raft)
diff --git a/utils.py b/utils.py
index e6f3cda6858355d51bc44f02a0e50814d35fcbaf..9adf926455f85d0594a4c561175dca8ca261fe19 100644
--- a/utils.py
+++ b/utils.py
@@ -97,11 +97,11 @@ def mask_processing(x, use_t=True):
     return x
 
 
-def grid_to_heatmap(grid, cmap='jet', size=1024):
+def grid_to_heatmap(grid, size, cmap='jet'):
     # TODO: pad grid with zeros to remove side stickiness ?
 
     mask = TF.to_pil_image(grid.view(7, 7))
-    mask = mask.resize((size, size), Image.BICUBIC)
+    mask = mask.resize(size, Image.BICUBIC)
     mask = Image.eval(mask, mask_processing)
 
     # Heatmap
@@ -158,22 +158,23 @@ def preprocess_pil(original, query):
     return torch.vstack((original, query)).unsqueeze(0)
 
 
-def short_summary_image(img, target=None, prediction=None, size=1024):
+def short_summary_image(img, size, target=None, prediction=None):
     # Photoshopped image
     if not Image.isImageType(img):
         img = unnormalise(img)
-        img = TF.to_pil_image(img).resize((size, size))
+        img = TF.to_pil_image(img)
+    img = img.resize(size)
 
     # Heatmap of target
     if target is not None:
-        heatmap, mask = grid_to_heatmap(target, cmap='winter')
+        heatmap, mask = grid_to_heatmap(target, cmap='winter', size=img.size)
         img.paste(heatmap, (0, 0), mask)
 
     # Heatmap of prediction
     if prediction is not None:
         prediction -= prediction.min()
         prediction = prediction / prediction.max()
-        heatmap, mask = grid_to_heatmap(prediction, cmap='Wistia')
+        heatmap, mask = grid_to_heatmap(prediction, cmap='Wistia', size=img.size)
         img.paste(heatmap, (0, 0), mask)
 
     return img