diff --git a/dirtorch/datasets/generic.py b/dirtorch/datasets/generic.py
index c0915ade63e6df95af6696f22c4c0537fc101115..bc0f14d2f4f5f8bd7e48495f8ca6292aa04bb30a 100644
--- a/dirtorch/datasets/generic.py
+++ b/dirtorch/datasets/generic.py
@@ -69,12 +69,14 @@ class ImageListLabels(LabelledDataset):
 
     def get_label(self, i, toint=False):
         label = self.labels[i]
-        if toint: label = self.cls_idx[ label ]
+        if toint:
+            label = self.cls_idx[label]
         return label
 
     def get_query_db(self):
         return self
 
+
 class ImageListLabelsQ(ImageListLabels):
     ''' Two list of images with labels: one for the dataset and one for the queries.
 
@@ -125,7 +127,7 @@ class ImageListRelevants(Dataset):
 
         Input: path to the pickle file
     """
-    def __init__(self, gt_file, root=None, img_dir = 'jpg', ext='.jpg'):
+    def __init__(self, gt_file, root=None, img_dir='jpg', ext='.jpg'):
         self.root = root
         self.img_dir = img_dir
 
@@ -146,17 +148,25 @@ class ImageListRelevants(Dataset):
         self.nquery = len(self.qimgs)
 
     def get_relevants(self, qimg_idx, mode='classic'):
-        if mode=='classic': rel = self.relevants[qimg_idx]
-        elif mode=='easy': rel = self.easy[qimg_idx]
-        elif mode=='medium': rel = self.easy[qimg_idx] + self.hard[qimg_idx]
-        elif mode=='hard': rel = self.hard[qimg_idx]
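+        # 'classic' keeps the dataset's original relevance list; 'easy'/'medium'/'hard'
+        # follow the Revisited Oxford/Paris split, where 'medium' counts both easy and
+        # hard images as relevant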
+        if mode == 'classic':
+            rel = self.relevants[qimg_idx]
+        elif mode == 'easy':
+            rel = self.easy[qimg_idx]
+        elif mode == 'medium':
+            rel = self.easy[qimg_idx] + self.hard[qimg_idx]
+        elif mode == 'hard':
+            rel = self.hard[qimg_idx]
         return rel
 
     def get_junk(self, qimg_idx, mode='classic'):
-        if mode=='classic': junk = self.junk[qimg_idx]
-        elif mode=='easy': junk = self.junk[qimg_idx] + self.hard[qimg_idx]
-        elif mode=='medium': junk = self.junk[qimg_idx]
-        elif mode=='hard': junk = self.junk[qimg_idx] + self.easy[qimg_idx]
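+        # images of the excluded difficulty are marked as junk (ignored during
+        # evaluation, see get_query_groundtruth) rather than counted as negatives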
+        if mode == 'classic':
+            junk = self.junk[qimg_idx]
+        elif mode == 'easy':
+            junk = self.junk[qimg_idx] + self.hard[qimg_idx]
+        elif mode == 'medium':
+            junk = self.junk[qimg_idx]
+        elif mode == 'hard':
+            junk = self.junk[qimg_idx] + self.easy[qimg_idx]
         return junk
 
     def get_query_filename(self, qimg_idx, root=None):
@@ -175,38 +185,42 @@ class ImageListRelevants(Dataset):
         return ImageListROIs(self.root, self.img_dir, self.qimgs, self.qroi)
 
     def get_query_groundtruth(self, query_idx, what='AP', mode='classic'):
-        res = -np.ones(self.nimg, dtype=np.int8) # all negatives
-        res[self.get_relevants(query_idx, mode)] = 1 # positives
-        res[self.get_junk(query_idx, mode)] = 0 # junk
+        # negatives
+        res = -np.ones(self.nimg, dtype=np.int8)
+        # positives
+        res[self.get_relevants(query_idx, mode)] = 1
+        # junk
+        res[self.get_junk(query_idx, mode)] = 0
         return res
 
     def eval_query_AP(self, query_idx, scores):
         """ Evaluates AP for a given query.
         """
-        from ..utils.evaluation import compute_AP
+        from ..utils.evaluation import compute_average_precision
         if self.relevants:
-            gt = self.get_query_groundtruth(query_idx, 'AP') # labels in {-1, 0, 1}
-            if gt.shape != scores.shape:
-                # TODO: Get this number in a less hacky way. This was the number of non-corrupted distractors
-                gt = np.concatenate([gt, np.full((976089,), fill_value=-1)])
+            gt = self.get_query_groundtruth(query_idx, 'AP')  # labels in {-1, 0, 1}
             assert gt.shape == scores.shape, "scores should have shape %s" % str(gt.shape)
             assert -1 <= gt.min() and gt.max() <= 1, "bad ground-truth labels"
             keep = (gt != 0)  # remove null labels
-            return compute_AP(gt[keep]>0, scores[keep])
+
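+            # rank the kept labels by descending score; the zero-indexed ranks of the
+            # positives are what compute_average_precision expects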
+            gt, scores = gt[keep], scores[keep]
+            gt_sorted = gt[np.argsort(scores)[::-1]]
+            positive_rank = np.where(gt_sorted == 1)[0]
+            return compute_average_precision(positive_rank)
         else:
             d = {}
             for mode in ('easy', 'medium', 'hard'):
-                gt = self.get_query_groundtruth(query_idx, 'AP', mode) # labels in {-1, 0, 1}
-                if gt.shape != scores.shape:
-                    # TODO: Get this number in a less hacky way. This was the number of non-corrupted distractors
-                    gt = np.concatenate([gt, np.full((976089,), fill_value=-1)])
+                gt = self.get_query_groundtruth(query_idx, 'AP', mode)  # labels in {-1, 0, 1}
                 assert gt.shape == scores.shape, "scores should have shape %s" % str(gt.shape)
                 assert -1 <= gt.min() and gt.max() <= 1, "bad ground-truth labels"
                 keep = (gt != 0)  # remove null labels
-                if sum(gt[keep]>0) == 0: #exclude queries with no relevants from the evaluation
+                if sum(gt[keep] > 0) == 0:  # exclude queries with no relevants from the evaluation
                     d[mode] = -1
                 else:
-                    d[mode] = compute_AP(gt[keep]>0, scores[keep])
+                    gt2, scores2 = gt[keep], scores[keep]
+                    gt_sorted = gt2[np.argsort(scores2)[::-1]]
+                    positive_rank = np.where(gt_sorted == 1)[0]
+                    d[mode] = compute_average_precision(positive_rank)
             return d
 
 
@@ -235,6 +249,7 @@ class ImageListROIs(Dataset):
             img = img.resize(resize, Image.ANTIALIAS if np.prod(resize) < np.prod(img.size) else Image.BICUBIC)
         return img
 
+
 def not_none(label):
     return label is not None
 
@@ -256,10 +271,12 @@ class ImageClusters(LabelledDataset):
 
         for img, cls in data.items():
             assert type(img) is str
-            if not filter(cls): continue
-            if type(cls) not in (str,int,type(None)): continue
-            self.imgs.append( img )
-            self.labels.append( cls )
+            if not filter(cls):
+                continue
+            if type(cls) not in (str, int, type(None)):
+                continue
+            self.imgs.append(img)
+            self.labels.append(cls)
 
         self.find_classes()
         self.nimg = len(self.imgs)
@@ -270,7 +287,8 @@ class ImageClusters(LabelledDataset):
 
     def get_label(self, i, toint=False):
         label = self.labels[i]
-        if toint: label = self.cls_idx[ label ]
+        if toint:
+            label = self.cls_idx[label]
         return label
 
 
diff --git a/dirtorch/test_dir.py b/dirtorch/test_dir.py
index e67663eae5f9bb1e7b90e5bfa14baa8f952ff486..31df69033c9b8be01e8e91bc0214f7654e109a01 100644
--- a/dirtorch/test_dir.py
+++ b/dirtorch/test_dir.py
@@ -96,7 +96,7 @@ def extract_image_features(dataset, transforms, net, ret_imgs=False, same_size=F
 
 def eval_model(db, net, trfs, pooling='mean', gemp=3, detailed=False, whiten=None,
                aqe=None, adba=None, threads=8, batch_size=16, save_feats=None,
-               load_feats=None, load_distractors=None, dbg=()):
+               load_feats=None, dbg=()):
     """ Evaluate a trained model (network) on a given dataset.
     The dataset is supposed to contain the evaluation code.
     """
@@ -122,20 +122,16 @@ def eval_model(db, net, trfs, pooling='mean', gemp=3, detailed=False, whiten=Non
         qdescs = F.normalize(pool(qdescs, pooling, gemp), p=2, dim=1)
     else:
         bdescs = np.load(os.path.join(load_feats, 'feats.bdescs.npy'))
-        qdescs = np.load(os.path.join(load_feats, 'feats.qdescs.npy'))
+        if query_db is not db:
+            qdescs = np.load(os.path.join(load_feats, 'feats.qdescs.npy'))
+        else:
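+            # the query set is the database itself (get_query_db() returned self),
+            # so reuse the database descriptors instead of loading a separate file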
+            qdescs = bdescs
 
     if save_feats:
-        mkdir(save_feats, isfile=True)
-        np.save(save_feats+'.bdescs', bdescs.cpu().numpy())
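+        # save_feats now names a directory rather than a file prefix; descriptors are
+        # written as feats.*.npy inside it (the same layout --load-feats reads above)
+        # and evaluation continues instead of calling exit()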
+        mkdir(save_feats)
+        np.save(os.path.join(save_feats, 'feats.bdescs.npy'), bdescs.cpu().numpy())
         if query_db is not db:
-            np.save(save_feats+'.qdescs', qdescs.cpu().numpy())
-        exit()
-
-    if load_distractors:
-        ddescs = [np.load(os.path.join(load_distractors, '%d.bdescs.npy' % i))
-                  for i in tqdm.tqdm(range(0, 1000), 'Distractors')]
-        bdescs = np.concatenate([tonumpy(bdescs)] + ddescs)
-        qdescs = tonumpy(qdescs)
+            np.save(os.path.join(save_feats, 'feats.qdescs.npy'), qdescs.cpu().numpy())
 
     if whiten is not None:
         bdescs = common.whiten_features(tonumpy(bdescs), net.pca, **whiten)
@@ -169,7 +165,6 @@ def eval_model(db, net, trfs, pooling='mean', gemp=3, detailed=False, whiten=Non
                     res['APs'+'-'+mode] = apst
                 # Queries with no relevants have an AP of -1
                 res['mAP'+'-'+mode] = float(np.mean([e for e in apst if e >= 0]))
-
     except NotImplementedError:
         print(" AP not implemented!")
 
@@ -210,7 +205,6 @@ if __name__ == '__main__':
     parser.add_argument('--out-json', type=str, default="", help='path to output json')
     parser.add_argument('--detailed', action='store_true', help='return detailed evaluation')
     parser.add_argument('--save-feats', type=str, default="", help='path to output features')
-    parser.add_argument('--load-distractors', type=str, default="", help='path to load distractors from')
     parser.add_argument('--load-feats', type=str, default="", help='path to load features from')
 
     parser.add_argument('--threads', type=int, default=8, help='number of thread workers')
@@ -250,8 +244,7 @@ if __name__ == '__main__':
     # Evaluate
     res = eval_model(dataset, net, args.trfs, pooling=args.pooling, gemp=args.gemp, detailed=args.detailed,
                      threads=args.threads, dbg=args.dbg, whiten=args.whiten, aqe=args.aqe, adba=args.adba,
-                     save_feats=args.save_feats, load_feats=args.load_feats,
-                     load_distractors=args.load_distractors)
+                     save_feats=args.save_feats, load_feats=args.load_feats)
     print(' * ' + '\n * '.join(['%s = %g' % p for p in res.items()]))
 
     if args.out_json:
@@ -264,6 +257,3 @@ if __name__ == '__main__':
         mkdir(args.out_json)
         open(args.out_json, 'w').write(json.dumps(data, indent=1))
         print("saved to "+args.out_json)
-
-
-
diff --git a/dirtorch/utils/evaluation.py b/dirtorch/utils/evaluation.py
index c4db204c6f55d38b4e8f6a038ff037d736c31f6e..040d79d3fd23fef32464090421a08ffa6902b7c2 100644
--- a/dirtorch/utils/evaluation.py
+++ b/dirtorch/utils/evaluation.py
@@ -7,11 +7,11 @@ import torch
 
 def accuracy_topk(output, target, topk=(1,)):
     """Computes the precision@k for the specified values of k
-    
+
     output: torch.FloatTensoror np.array(float)
             shape = B * L [* H * W]
             L: number of possible labels
-    
+
     target: torch.IntTensor or np.array(int)
             shape = B     [* H * W]
             ground-truth labels
@@ -20,34 +20,74 @@ def accuracy_topk(output, target, topk=(1,)):
         pred = (-output).argsort(axis=1)
         target = np.expand_dims(target, axis=1)
         correct = (pred == target)
-        
+
         res = []
         for k in topk:
-            correct_k = correct[:,:k].sum()
+            correct_k = correct[:, :k].sum()
             res.append(correct_k / target.size)
-    
+
     if isinstance(output, torch.Tensor):
         _, pred = output.topk(max(topk), 1, True, True)
         correct = pred.eq(target.unsqueeze(1))
 
         res = []
         for k in topk:
-            correct_k = correct[:,:k].float().view(-1).sum(0)
+            correct_k = correct[:, :k].float().view(-1).sum(0)
             res.append(correct_k.mul_(1.0 / target.numel()))
 
     return res
 
 
-
 def compute_AP(label, score):
     from sklearn.metrics import average_precision_score
     return average_precision_score(label, score)
 
 
+def compute_average_precision(positive_ranks):
+    """
+    Extracted from: https://github.com/tensorflow/models/blob/master/research/delf/delf/python/detect_to_retrieve/dataset.py
+
+    Computes average precision according to dataset convention.
+    It assumes that `positive_ranks` contains the ranks for all expected positive
+    index images to be retrieved. If `positive_ranks` is empty, returns
+    `average_precision` = 0.
+    Note that average precision computation here does NOT use the finite sum
+    method (see
+    https://en.wikipedia.org/wiki/Evaluation_measures_(information_retrieval)#Average_precision)
+    which is common in information retrieval literature. Instead, the method
+    implemented here integrates over the precision-recall curve by averaging two
+    adjacent precision points, then multiplying by the recall step. This is the
+    convention for the Revisited Oxford/Paris datasets.
+    Args:
+        positive_ranks: Sorted 1D NumPy integer array, zero-indexed.
+    Returns:
+        average_precision: Float.
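+
+    Illustrative doctest (added for clarity; not part of the original DELF
+    docstring): two positives retrieved at zero-indexed ranks 0 and 2 give
+    (1.0 + 1.0)/2 * 0.5 + (1/2 + 2/3)/2 * 0.5 = 19/24.
+
+        >>> round(compute_average_precision([0, 2]), 4)
+        0.7917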
+    """
+    average_precision = 0.0
+
+    num_expected_positives = len(positive_ranks)
+    if not num_expected_positives:
+        return average_precision
+
+    recall_step = 1.0 / num_expected_positives
+    for i, rank in enumerate(positive_ranks):
+        if not rank:
+            left_precision = 1.0
+        else:
+            left_precision = i / rank
+
+        right_precision = (i + 1) / (rank + 1)
+        average_precision += (left_precision + right_precision) * recall_step / 2
+
+    return average_precision
+
+
 def compute_average_precision_quantized(labels, idx, step=0.01):
     recall_checkpoints = np.arange(0, 1, step)
+
     def mymax(x, default):
         return np.max(x) if len(x) else default
+
     Nrel = np.sum(labels)
     if Nrel == 0:
         return 0
@@ -58,8 +98,7 @@ def compute_average_precision_quantized(labels, idx, step=0.01):
     return np.mean(precs)
 
 
-
 def pixelwise_iou(output, target):
-    """ For each image, for each label, compute the IoU between 
+    """ For each image, for each label, compute the IoU between
     """
     assert False