From 263e089ecc1181de993ec650ddc9e564537bea9b Mon Sep 17 00:00:00 2001
From: Jon Almazan <jon.almazan@gmail.com>
Date: Thu, 8 Aug 2019 11:54:16 +0100
Subject: [PATCH] fix net.pca bug

---
 dirtorch/test_dir.py | 53 +++++++++-----------------------------------
 1 file changed, 10 insertions(+), 43 deletions(-)

diff --git a/dirtorch/test_dir.py b/dirtorch/test_dir.py
index e01ac3f..4c986c6 100644
--- a/dirtorch/test_dir.py
+++ b/dirtorch/test_dir.py
@@ -19,6 +19,7 @@ import dirtorch.datasets.downloader as dl
 import pickle as pkl
 import hashlib
 
+
 def hash(x):
     m = hashlib.md5()
     m.update(str(x).encode('utf-8'))
@@ -84,12 +85,12 @@ def extract_image_features( dataset, transforms, net, ret_imgs=False, same_size=
 
     img_feats = []
     trf_images = []
-    with torch.no_grad(): # important to put it outside!
+    with torch.no_grad():
         for inputs in tqdm.tqdm(loader, desc, total=1+(len(dataset)-1)//batch_size):
             imgs = inputs[0]
             for i in range(len(imgs)):
                 if flip and flip.pop(0):
-                    imgs[i] = imgs[i].flip(2) # flip this image horizontally!
+                    imgs[i] = imgs[i].flip(2)
             imgs = common.variables(inputs[:1], net.iscuda)[0]
             desc = net(imgs)
             if ret_imgs: trf_images.append( tocpu(imgs.detach()) )
@@ -99,7 +100,7 @@ def extract_image_features( dataset, transforms, net, ret_imgs=False, same_size=
             img_feats.append( desc.detach() )
 
     img_feats = torch.cat(img_feats, dim=0)
-    if len(img_feats.shape) == 1: img_feats.unsqueeze_(0) # atleast_2d
+    if len(img_feats.shape) == 1: img_feats.unsqueeze_(0)
 
     if not same_size:
         torch.backends.cudnn.benchmark = old_benchmark
@@ -197,25 +198,6 @@ def eval_model(db, net, trfs, pooling='mean', gemp=3, detailed=False, whiten=Non
                 if detailed: res['APs'+'-'+mode] = apst
                 res['mAP'+'-'+mode] = float(np.mean([e for e in apst if e>=0])) # Queries with no relevants have an AP of -1
 
-        if 'ap' in dbg:
-            pdb.set_trace()
-            pyplot(globals())
-            for query in np.argsort(aps):
-                subplot_grid(20, 1)
-                pl.imshow(query_db.get_image(query))
-                qlabel = query_db.get_label(query)
-                pl.xlabel('#%d %s' % (query, qlabel))
-                pl_noticks()
-                ranked = np.argsort(scores[query])[::-1]
-                gt = db.get_query_groundtruth(query)[ranked]
-
-                for i,idx in enumerate(ranked):
-                    if i+2 > 20: break
-                    subplot_grid(20, i+2)
-                    pl.imshow(db.get_image(idx))
-                    pl.xlabel('#%d %s %g' % (idx, 'OK' if label==qlabel else 'BAD', scores[query,idx]))
-                    pl_noticks()
-            pdb.set_trace()
     except NotImplementedError:
         print(" AP not implemented!")
 
@@ -236,7 +218,7 @@ def load_model( path, iscuda ):
     net = common.switch_model_to_cuda(net, iscuda, checkpoint)
     net.load_state_dict(checkpoint['state_dict'])
     net.preprocess = checkpoint.get('preprocess', net.preprocess)
-    if 'pca' in checkpoint: net.pca = checkpoint.get('pca', net.pca)
+    if 'pca' in checkpoint: net.pca = checkpoint.get('pca')
     return net
 
 
@@ -263,7 +245,6 @@ if __name__ == '__main__':
     parser.add_argument('--trfs', type=str, required=False, default='', nargs='+', help='test transforms (can be several)')
     parser.add_argument('--pooling', type=str, default="gem", help='pooling scheme if several trf chains')
     parser.add_argument('--gemp', type=int, default=3, help='GeM pooling power')
-    parser.add_argument('--center-bias', type=float, default=0, help='enforce some center bias')
 
     parser.add_argument('--out-json', type=str, default="", help='path to output json')
     parser.add_argument('--detailed', action='store_true', help='return detailed evaluation')
@@ -296,26 +277,12 @@ if __name__ == '__main__':
 
     net = load_model(args.checkpoint, args.iscuda)
 
-    if args.center_bias:
-        assert hasattr(net,'center_bias')
-        net.center_bias = args.center_bias
-        if hasattr(net, 'module') and hasattr(net.module,'center_bias'):
-            net.module.center_bias = args.center_bias
-
-    if args.whiten and not hasattr(net, 'pca'):
-        # Learn PCA if necessary
-        if os.path.exists(args.whiten):
-            with open(args.whiten, 'rb') as f:
-                net.pca = pkl.load(f)
-        else:
-            pca_path = '_'.join([args.checkpoint, args.whiten, args.pooling, hash(args.trfs), 'pca.pkl'])
-            db = datasets.create(args.whiten)
-            print('Dataset for learning the PCA with whitening:', db)
-            net.pca = learn_whiten(db, net, pooling=args.pooling, trfs=args.trfs, threads=args.threads)
-            with open(pca_path, 'wb') as f:
-                pkl.dump(net.pca, f)
-
+    if args.whiten:
+        net.pca = net.pca[args.whiten]
         args.whiten = {'whitenp': args.whitenp, 'whitenv': args.whitenv, 'whitenm': args.whitenm}
+    else:
+        net.pca = None
+        args.whiten = None
 
     # Evaluate
     res = eval_model(dataset, net, args.trfs, pooling=args.pooling, gemp=args.gemp, detailed=args.detailed,
-- 
GitLab