From 4a5694867a5d0f37df7a4394ac8f0d97ddb67a16 Mon Sep 17 00:00:00 2001
From: Jon Almazan <jon.almazan@gmail.com>
Date: Fri, 9 Aug 2019 12:44:39 +0100
Subject: [PATCH] Refactor: move shared helpers to utils.common and simplify whitening

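Move the helpers duplicated across extract_features.py and test_dir.py
(typename, tonumpy, matmul, pool) into dirtorch/utils/common.py and drop the
now-unused hash() helper. Simplify model loading: load_model() attaches the
checkpoint's whole 'pca' dict and the caller selects the entry matching
--whiten, which removes on-the-fly PCA learning (learn_whiten) and the
--center-bias option from extract_features.py. Also stop setting a global
umask at import time and normalize formatting to PEP 8.
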
---
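Notes: the consolidated helpers keep their previous behavior; both scripts now
import them from dirtorch.utils.common. A minimal usage sketch, assuming the
patched package is on PYTHONPATH (tensor shapes are illustrative):

    import torch
    from dirtorch.utils.common import tonumpy, matmul, pool

    # one (B, D) descriptor tensor per transform chain, e.g. three scales
    descs = [torch.randn(4, 128) for _ in range(3)]
    pooled = pool(descs, pooling='gem', gemp=3)  # signed GeM over chains -> (4, 128)
    scores = matmul(tonumpy(pooled), pooled)     # numpy/torch mix -> (4, 4) numpy scores
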
 dirtorch/extract_features.py |  94 ++++-------------------
 dirtorch/test_dir.py         | 144 +++++++++++++----------------------
 dirtorch/utils/common.py     | 104 ++++++++++++++++++-------
 3 files changed, 147 insertions(+), 195 deletions(-)

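Note on the whitening flow: checkpoints must now already contain a PCA for the
requested whitening dataset; nothing is learned at evaluation time. A sketch of
the new selection path, assuming a checkpoint whose 'pca' dict is keyed by
dataset name ('checkpoint.pt', 'Landmarks_clean' and the whitening values below
are illustrative):

    from dirtorch.test_dir import load_model
    from dirtorch.utils import common
    from dirtorch.utils.common import tonumpy

    net = load_model('checkpoint.pt', iscuda=False)  # attaches the full 'pca' dict, if present
    net.pca = net.pca['Landmarks_clean']             # what __main__ does for --whiten
    whiten = {'whitenp': 0.25, 'whitenv': None, 'whitenm': 1.0}
    # eval_model()/extract_features() then apply:
    #   common.whiten_features(tonumpy(descs), net.pca, **whiten)
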
diff --git a/dirtorch/extract_features.py b/dirtorch/extract_features.py
index 521ef0a..c57ce5e 100644
--- a/dirtorch/extract_features.py
+++ b/dirtorch/extract_features.py
@@ -1,5 +1,5 @@
 import sys
-import os; os.umask(7)  # group permisions but that's all
+import os
 import os.path as osp
 import pdb
 
@@ -11,48 +11,20 @@ import torch.nn.functional as F
 
 from dirtorch.utils.convenient import mkdir
 from dirtorch.utils import common
+from dirtorch.utils.common import tonumpy, matmul, pool
 from dirtorch.utils.pytorch_loader import get_loader
 
 import dirtorch.test_dir as test
 import dirtorch.nets as nets
 import dirtorch.datasets as datasets
+import dirtorch.datasets.downloader as dl
 
 import pickle as pkl
 import hashlib
 
-def hash(x):
-    m = hashlib.md5()
-    m.update(str(x).encode('utf-8'))
-    return m.hexdigest()
-
-def typename(x):
-    return type(x).__module__
-
-def tonumpy(x):
-    if typename(x) == torch.__name__:
-        return x.cpu().numpy()
-    else:
-        return x
-
-
-def pool(x, pooling='mean', gemp=3):
-    if len(x) == 1: return x[0]
-    x = torch.stack(x, dim=0)
-    if pooling == 'mean':
-        return torch.mean(x, dim=0)
-    elif pooling == 'gem':
-        def sympow(x, p, eps=1e-6):
-            s = torch.sign(x)
-            return (x*s).clamp(min=eps).pow(p) * s
-        x = sympow(x,gemp)
-        x = torch.mean(x, dim=0)
-        return sympow(x, 1/gemp)
-    else:
-        raise ValueError("Bad pooling mode: "+str(pooling))
-
 
 def extract_features(db, net, trfs, pooling='mean', gemp=3, detailed=False, whiten=None,
-               threads=8, batch_size=16, output=None, dbg=()):
+                     threads=8, batch_size=16, output=None, dbg=()):
     """ Extract features from trained model (network) on a given dataset.
     """
     print("\n>> Extracting features...")
@@ -69,11 +41,12 @@ def extract_features(db, net, trfs, pooling='mean', gemp=3, detailed=False, whit
 
     for trfs in trfs_list:
         kw = dict(iscuda=net.iscuda, threads=threads, batch_size=batch_size, same_size='Pad' in trfs or 'Crop' in trfs)
-        bdescs.append( test.extract_image_features(db, trfs, net, desc="DB", **kw) )
+        bdescs.append(test.extract_image_features(db, trfs, net, desc="DB", **kw))
 
         # extract query feats
         if query_db is not None:
-            qdescs.append( bdescs[-1] if db is query_db else test.extract_image_features(query_db, trfs, net, desc="query", **kw) )
+            qdescs.append(bdescs[-1] if db is query_db
+                          else test.extract_image_features(query_db, trfs, net, desc="query", **kw))
 
     # pool from multiple transforms (scales)
     bdescs = tonumpy(F.normalize(pool(bdescs, pooling, gemp), p=2, dim=1))
@@ -95,31 +68,17 @@ def extract_features(db, net, trfs, pooling='mean', gemp=3, detailed=False, whit
     print('Features extracted.')
 
 
-def load_model( path, iscuda, whiten=None ):
+def load_model(path, iscuda):
     checkpoint = common.load_checkpoint(path, iscuda)
     net = nets.create_model(pretrained="", **checkpoint['model_options'])
     net = common.switch_model_to_cuda(net, iscuda, checkpoint)
     net.load_state_dict(checkpoint['state_dict'])
     net.preprocess = checkpoint.get('preprocess', net.preprocess)
-    if whiten is not None and 'pca' in checkpoint:
-        if whiten in checkpoint['pca']:
-            net.pca = checkpoint['pca'][whiten]
+    if 'pca' in checkpoint:
+        net.pca = checkpoint['pca']
     return net
 
 
-def learn_whiten( dataset, net, trfs='', pooling='mean', threads=8, batch_size=16):
-    descs = []
-    trfs_list = [trfs] if isinstance(trfs, str) else trfs
-    for trfs in trfs_list:
-        kw = dict(iscuda=net.iscuda, threads=threads, batch_size=batch_size, same_size='Pad' in trfs or 'Crop' in trfs)
-        descs.append( extract_image_features(dataset, trfs, net, desc="PCA", **kw) )
-    # pool from multiple transforms (scales)
-    descs = F.normalize(pool(descs, pooling), p=2, dim=1)
-    # learn pca with whiten
-    pca = common.learn_pca(descs.cpu().numpy(), whiten=True)
-    return pca
-
-
 if __name__ == '__main__':
     import argparse
     parser = argparse.ArgumentParser(description='Evaluate a model')
@@ -130,7 +89,6 @@ if __name__ == '__main__':
     parser.add_argument('--trfs', type=str, required=False, default='', nargs='+', help='test transforms (can be several)')
     parser.add_argument('--pooling', type=str, default="gem", help='pooling scheme if several trf chains')
     parser.add_argument('--gemp', type=int, default=3, help='GeM pooling power')
-    parser.add_argument('--center-bias', type=float, default=0, help='enforce some center bias')
 
     parser.add_argument('--out-json', type=str, default="", help='path to output json')
     parser.add_argument('--detailed', action='store_true', help='return detailed evaluation')
@@ -152,37 +110,17 @@ if __name__ == '__main__':
     dataset = datasets.create(args.dataset)
     print("Dataset:", dataset)
 
-    net = load_model(args.checkpoint, args.iscuda, args.whiten)
-
-    if args.center_bias:
-        assert hasattr(net,'center_bias')
-        net.center_bias = args.center_bias
-        if hasattr(net, 'module') and hasattr(net.module,'center_bias'):
-            net.module.center_bias = args.center_bias
-
-    if args.whiten and not hasattr(net, 'pca'):
-        # Learn PCA if necessary
-        if os.path.exists(args.whiten):
-            with open(args.whiten, 'rb') as f:
-                net.pca = pkl.load(f)
-        else:
-            pca_path = '_'.join([args.checkpoint, args.whiten, args.pooling, hash(args.trfs), 'pca.pkl'])
-            db = datasets.create(args.whiten)
-            print('Dataset for learning the PCA with whitening:', db)
-            pca = learn_whiten(db, net, pooling=args.pooling, trfs=args.trfs, threads=args.threads)
-
-            chk = torch.load(args.checkpoint, map_location=lambda storage, loc: storage)
-            if 'pca' not in chk: chk['pca'] = {}
-            chk['pca'][args.whiten] = pca
-            torch.save(chk, args.checkpoint)
-
-            net.pca = pca
+    net = load_model(args.checkpoint, args.iscuda)
 
     if args.whiten:
+        net.pca = net.pca[args.whiten]
         args.whiten = {'whitenp': args.whitenp, 'whitenv': args.whitenv, 'whitenm': args.whitenm}
+    else:
+        net.pca = None
+        args.whiten = None
 
     # Evaluate
     res = extract_features(dataset, net, args.trfs, pooling=args.pooling, gemp=args.gemp, detailed=args.detailed,
-        threads=args.threads, dbg=args.dbg, whiten=args.whiten, output=args.output)
+                           threads=args.threads, dbg=args.dbg, whiten=args.whiten, output=args.output)
 
 
diff --git a/dirtorch/test_dir.py b/dirtorch/test_dir.py
index 4c986c6..e67663e 100644
--- a/dirtorch/test_dir.py
+++ b/dirtorch/test_dir.py
@@ -1,5 +1,5 @@
 import sys
-import os; os.umask(7)  # group permisions but that's all
+import os
 import os.path as osp
 import pdb
 
@@ -11,6 +11,7 @@ import torch.nn.functional as F
 
 from dirtorch.utils.convenient import mkdir
 from dirtorch.utils import common
+from dirtorch.utils.common import tonumpy, matmul, pool
 from dirtorch.utils.pytorch_loader import get_loader
 import dirtorch.nets as nets
 import dirtorch.datasets as datasets
@@ -20,33 +21,10 @@ import pickle as pkl
 import hashlib
 
 
-def hash(x):
-    m = hashlib.md5()
-    m.update(str(x).encode('utf-8'))
-    return m.hexdigest()
-
-def typename(x):
-    return type(x).__module__
-
-def tonumpy(x):
-    if typename(x) == torch.__name__:
-        return x.cpu().numpy()
-    else:
-        return x
-
-def matmul(A, B):
-    if typename(A) == np.__name__:
-        B = tonumpy(B)
-        scores = np.dot(A, B.T)
-    elif typename(B) == torch.__name__:
-        scores = torch.matmul(A, B.t()).cpu().numpy()
-    else:
-        raise TypeError("matrices must be either numpy or torch type")
-    return scores
-
 def expand_descriptors(descs, db=None, alpha=0, k=0):
     assert k >= 0 and alpha >= 0, 'k and alpha must be non-negative'
-    if k == 0: return descs
+    if k == 0:
+        return descs
     descs = tonumpy(descs)
     n = descs.shape[0]
     db_descs = tonumpy(db if db is not None else descs)
@@ -58,30 +36,31 @@ def expand_descriptors(descs, db=None, alpha=0, k=0):
     idx = np.argpartition(sim, int(-k), axis=1)[:, int(-k):]
     descs_aug = np.zeros_like(descs)
     for i in range(n):
-        new_q = np.vstack([db_descs[j, :] * sim[i,j]**alpha for j in idx[i]])
+        new_q = np.vstack([db_descs[j, :] * sim[i, j]**alpha for j in idx[i]])
         new_q = np.vstack([descs[i], new_q])
         new_q = np.mean(new_q, axis=0)
         descs_aug[i] = new_q / np.linalg.norm(new_q)
 
     return descs_aug
 
-def extract_image_features( dataset, transforms, net, ret_imgs=False, same_size=False, flip=None,
-                            desc="Extract feats...", iscuda=True, threads=8, batch_size=8):
+
+def extract_image_features(dataset, transforms, net, ret_imgs=False, same_size=False, flip=None,
+                           desc="Extract feats...", iscuda=True, threads=8, batch_size=8):
     """ Extract image features for a given dataset.
         Output is 2-dimensional (B, D)
     """
     if not same_size:
         batch_size = 1
-        old_benchmark = torch.backends.cudnn.benchmark # speed-up cudnn
-        torch.backends.cudnn.benchmark = False  # will speed-up a lot for different image sizes
+        old_benchmark = torch.backends.cudnn.benchmark
+        torch.backends.cudnn.benchmark = False  # re-benchmarks per input shape, slow when image sizes vary
 
-    loader = get_loader( dataset, trf_chain=transforms, preprocess=net.preprocess, iscuda=iscuda,
-                         output=['img'], batch_size=batch_size, threads=threads,
-                         shuffle=False) # VERY IMPORTANT !!!!!!
+    loader = get_loader(dataset, trf_chain=transforms, preprocess=net.preprocess, iscuda=iscuda,
+                        output=['img'], batch_size=batch_size, threads=threads, shuffle=False)  # order must match dataset indices
 
-    if hasattr(net,'eval'): net.eval()
+    if hasattr(net, 'eval'):
+        net.eval()
 
-    tocpu = (lambda x: x.cpu()) if ret_imgs=='cpu' else (lambda x:x)
+    tocpu = (lambda x: x.cpu()) if ret_imgs == 'cpu' else (lambda x: x)
 
     img_feats = []
     trf_images = []
@@ -93,40 +72,28 @@ def extract_image_features( dataset, transforms, net, ret_imgs=False, same_size=
                     imgs[i] = imgs[i].flip(2)
             imgs = common.variables(inputs[:1], net.iscuda)[0]
             desc = net(imgs)
-            if ret_imgs: trf_images.append( tocpu(imgs.detach()) )
+            if ret_imgs:
+                trf_images.append(tocpu(imgs.detach()))
             del imgs
             del inputs
-            if len(desc.shape) == 1: desc.unsqueeze_(0)
-            img_feats.append( desc.detach() )
+            if len(desc.shape) == 1:
+                desc.unsqueeze_(0)
+            img_feats.append(desc.detach())
 
     img_feats = torch.cat(img_feats, dim=0)
-    if len(img_feats.shape) == 1: img_feats.unsqueeze_(0)
+    if len(img_feats.shape) == 1:
+        img_feats.unsqueeze_(0)
 
     if not same_size:
         torch.backends.cudnn.benchmark = old_benchmark
 
     if ret_imgs:
-        if same_size: trf_images = torch.cat(trf_images, dim=0)
+        if same_size:
+            trf_images = torch.cat(trf_images, dim=0)
         return trf_images, img_feats
     return img_feats
 
 
-def pool(x, pooling='mean', gemp=3):
-    if len(x) == 1: return x[0]
-    x = torch.stack(x, dim=0)
-    if pooling == 'mean':
-        return torch.mean(x, dim=0)
-    elif pooling == 'gem':
-        def sympow(x, p, eps=1e-6):
-            s = torch.sign(x)
-            return (x*s).clamp(min=eps).pow(p) * s
-        x = sympow(x,gemp)
-        x = torch.mean(x, dim=0)
-        return sympow(x, 1/gemp)
-    else:
-        raise ValueError("Bad pooling mode: "+str(pooling))
-
-
 def eval_model(db, net, trfs, pooling='mean', gemp=3, detailed=False, whiten=None,
                aqe=None, adba=None, threads=8, batch_size=16, save_feats=None,
                load_feats=None, load_distractors=None, dbg=()):
@@ -145,10 +112,10 @@ def eval_model(db, net, trfs, pooling='mean', gemp=3, detailed=False, whiten=Non
 
         for trfs in trfs_list:
             kw = dict(iscuda=net.iscuda, threads=threads, batch_size=batch_size, same_size='Pad' in trfs or 'Crop' in trfs)
-            bdescs.append( extract_image_features(db, trfs, net, desc="DB", **kw) )
+            bdescs.append(extract_image_features(db, trfs, net, desc="DB", **kw))
 
             # extract query feats
-            qdescs.append( bdescs[-1] if db is query_db else extract_image_features(query_db, trfs, net, desc="query", **kw) )
+            qdescs.append(bdescs[-1] if db is query_db else extract_image_features(query_db, trfs, net, desc="query", **kw))
 
         # pool from multiple transforms (scales)
         bdescs = F.normalize(pool(bdescs, pooling, gemp), p=2, dim=1)
@@ -165,9 +132,10 @@ def eval_model(db, net, trfs, pooling='mean', gemp=3, detailed=False, whiten=Non
         exit()
 
     if load_distractors:
-        ddescs = [ np.load(os.path.join(load_distractors, '%d.bdescs.npy' % i)) for i in tqdm.tqdm(range(0,1000), 'Distractors') ]
+        ddescs = [np.load(os.path.join(load_distractors, '%d.bdescs.npy' % i))
+                  for i in tqdm.tqdm(range(0, 1000), 'Distractors')]
         bdescs = np.concatenate([tonumpy(bdescs)] + ddescs)
-        qdescs = tonumpy(qdescs) # so matmul below can work
+        qdescs = tonumpy(qdescs)
 
     if whiten is not None:
         bdescs = common.whiten_features(tonumpy(bdescs), net.pca, **whiten)
@@ -186,55 +154,48 @@ def eval_model(db, net, trfs, pooling='mean', gemp=3, detailed=False, whiten=Non
     res = {}
 
     try:
-        aps = [db.eval_query_AP(q, s) for q,s in enumerate(tqdm.tqdm(scores,desc='AP'))]
+        aps = [db.eval_query_AP(q, s) for q, s in enumerate(tqdm.tqdm(scores, desc='AP'))]
         if not isinstance(aps[0], dict):
             aps = [float(e) for e in aps]
-            if detailed: res['APs'] = aps
-            res['mAP'] = float(np.mean([e for e in aps if e>=0])) # Queries with no relevants have an AP of -1
+            if detailed:
+                res['APs'] = aps
+            # Queries with no relevant images have an AP of -1
+            res['mAP'] = float(np.mean([e for e in aps if e >= 0]))
         else:
             modes = aps[0].keys()
             for mode in modes:
                 apst = [float(e[mode]) for e in aps]
-                if detailed: res['APs'+'-'+mode] = apst
-                res['mAP'+'-'+mode] = float(np.mean([e for e in apst if e>=0])) # Queries with no relevants have an AP of -1
+                if detailed:
+                    res['APs'+'-'+mode] = apst
+                # Queries with no relevant images have an AP of -1
+                res['mAP'+'-'+mode] = float(np.mean([e for e in apst if e >= 0]))
 
     except NotImplementedError:
         print(" AP not implemented!")
 
     try:
-        tops = [db.eval_query_top(q,s) for q,s in enumerate(tqdm.tqdm(scores,desc='top1'))]
-        if detailed: res['tops'] = tops
+        tops = [db.eval_query_top(q, s) for q, s in enumerate(tqdm.tqdm(scores, desc='top1'))]
+        if detailed:
+            res['tops'] = tops
         for k in tops[0]:
-            res['top%d'%k] = float(np.mean([top[k] for top in tops]))
+            res['top%d' % k] = float(np.mean([top[k] for top in tops]))
     except NotImplementedError:
         pass
 
     return res
 
 
-def load_model( path, iscuda ):
+def load_model(path, iscuda):
     checkpoint = common.load_checkpoint(path, iscuda)
     net = nets.create_model(pretrained="", **checkpoint['model_options'])
     net = common.switch_model_to_cuda(net, iscuda, checkpoint)
     net.load_state_dict(checkpoint['state_dict'])
     net.preprocess = checkpoint.get('preprocess', net.preprocess)
-    if 'pca' in checkpoint: net.pca = checkpoint.get('pca')
+    if 'pca' in checkpoint:
+        net.pca = checkpoint['pca']
     return net
 
 
-def learn_whiten( dataset, net, trfs='', pooling='mean', threads=8, batch_size=16):
-    descs = []
-    trfs_list = [trfs] if isinstance(trfs, str) else trfs
-    for trfs in trfs_list:
-        kw = dict(iscuda=net.iscuda, threads=threads, batch_size=batch_size, same_size='Pad' in trfs or 'Crop' in trfs)
-        descs.append( extract_image_features(dataset, trfs, net, desc="PCA", **kw) )
-    # pool from multiple transforms (scales)
-    descs = F.normalize(pool(descs, pooling), p=2, dim=1)
-    # learn pca with whiten
-    pca = common.learn_pca(descs.cpu().numpy(), whiten=True)
-    return pca
-
-
 if __name__ == '__main__':
     import argparse
     parser = argparse.ArgumentParser(description='Evaluate a model')
@@ -267,8 +228,10 @@ if __name__ == '__main__':
 
     args = parser.parse_args()
     args.iscuda = common.torch_set_gpu(args.gpu)
-    if args.aqe is not None: args.aqe = {'k': args.aqe[0], 'alpha': args.aqe[1]}
-    if args.adba is not None: args.adba = {'k': args.adba[0], 'alpha': args.adba[1]}
+    if args.aqe is not None:
+        args.aqe = {'k': args.aqe[0], 'alpha': args.aqe[1]}
+    if args.adba is not None:
+        args.adba = {'k': args.adba[0], 'alpha': args.adba[1]}
 
     dl.download_dataset(args.dataset)
 
@@ -286,9 +249,10 @@ if __name__ == '__main__':
 
     # Evaluate
     res = eval_model(dataset, net, args.trfs, pooling=args.pooling, gemp=args.gemp, detailed=args.detailed,
-        threads=args.threads, dbg=args.dbg, whiten=args.whiten,  aqe=args.aqe, adba=args.adba,
-        save_feats=args.save_feats, load_feats=args.load_feats, load_distractors=args.load_distractors)
-    print(' * '+'\n * '.join(['%s = %g'%p for p in res.items()]))
+                     threads=args.threads, dbg=args.dbg, whiten=args.whiten, aqe=args.aqe, adba=args.adba,
+                     save_feats=args.save_feats, load_feats=args.load_feats,
+                     load_distractors=args.load_distractors)
+    print(' * ' + '\n * '.join(['%s = %g' % p for p in res.items()]))
 
     if args.out_json:
         # write to file
@@ -298,7 +262,7 @@ if __name__ == '__main__':
             data = {}
         data[args.dataset] = res
         mkdir(args.out_json)
-        open(args.out_json,'w').write(json.dumps(data, indent=1))
+        open(args.out_json, 'w').write(json.dumps(data, indent=1))
         print("saved to "+args.out_json)
 
 
diff --git a/dirtorch/utils/common.py b/dirtorch/utils/common.py
index 3d300f6..cf03917 100644
--- a/dirtorch/utils/common.py
+++ b/dirtorch/utils/common.py
@@ -6,6 +6,9 @@ from collections import OrderedDict
 import numpy as np
 import sklearn.decomposition
 
+# torch is required unconditionally by the shared helpers below (F is not used)
+import torch
+
 try:
     import torch
     import torch.nn as nn
@@ -13,13 +16,52 @@ except ImportError:
     pass
 
 
+def typename(x):
+    return type(x).__module__
+
+
+def tonumpy(x):
+    if typename(x) == torch.__name__:
+        return x.cpu().numpy()
+    else:
+        return x
+
+
+def matmul(A, B):
+    if typename(A) == np.__name__:
+        B = tonumpy(B)
+        scores = np.dot(A, B.T)
+    elif typename(B) == torch.__name__:
+        scores = torch.matmul(A, B.t()).cpu().numpy()
+    else:
+        raise TypeError("matrices must be either numpy or torch type")
+    return scores
+
+
+def pool(x, pooling='mean', gemp=3):
+    if len(x) == 1:
+        return x[0]
+    x = torch.stack(x, dim=0)
+    if pooling == 'mean':
+        return torch.mean(x, dim=0)
+    elif pooling == 'gem':
+        def sympow(x, p, eps=1e-6):
+            s = torch.sign(x)
+            return (x*s).clamp(min=eps).pow(p) * s
+        x = sympow(x, gemp)
+        x = torch.mean(x, dim=0)
+        return sympow(x, 1/gemp)
+    else:
+        raise ValueError("Bad pooling mode: "+str(pooling))
+
+
 def torch_set_gpu(gpus, seed=None, randomize=True):
     if type(gpus) is int:
             gpus = [gpus]
 
     assert gpus, 'error: empty gpu list, use --gpu N N ...'
 
-    cuda = all(gpu>=0 for gpu in gpus)
+    cuda = all(gpu >= 0 for gpu in gpus)
 
     if cuda:
         if any(gpu >= 1000 for gpu in gpus):
@@ -28,12 +70,12 @@ def torch_set_gpu(gpus, seed=None, randomize=True):
         else:
             os.environ['CUDA_VISIBLE_DEVICES'] = ','.join([str(gpu) for gpu in gpus])
         assert cuda and torch.cuda.is_available(), "%s has GPUs %s unavailable" % (
-            os.environ['HOSTNAME'],os.environ['CUDA_VISIBLE_DEVICES'])
-        torch.backends.cudnn.benchmark = True # speed-up cudnn
-        torch.backends.cudnn.fastest = True # even more speed-up?
-        print( 'Launching on GPUs ' + os.environ['CUDA_VISIBLE_DEVICES'] )
+            os.environ['HOSTNAME'], os.environ['CUDA_VISIBLE_DEVICES'])
+        torch.backends.cudnn.benchmark = True
+        torch.backends.cudnn.fastest = True
+        print('Launching on GPUs ' + os.environ['CUDA_VISIBLE_DEVICES'])
     else:
-        print( 'Launching on >> CPU <<' )
+        print('Launching on >> CPU <<')
 
     torch_set_seed(seed, cuda, randomize=randomize)
     return cuda
@@ -58,19 +100,21 @@ def torch_set_seed(seed, cuda, randomize=True):
 def save_checkpoint(state, is_best, filename):
     try:
         dirs = os.path.split(filename)[0]
-        if not os.path.isdir(dirs): os.makedirs(dirs)
+        if not os.path.isdir(dirs):
+            os.makedirs(dirs)
         torch.save(state, filename)
         if is_best:
             filenamebest = filename+'.best'
             shutil.copyfile(filename, filenamebest)
             filename = filenamebest
-        print( "saving to "+filename )
+        print("saving to "+filename)
     except:
-        print( "Error: Could not save checkpoint at %s, skipping" % filename )
+        print("Error: Could not save checkpoint at %s, skipping" % filename)
 
 
 def load_checkpoint(filename, iscuda=False):
-    if not filename: return None
+    if not filename:
+        return None
     assert os.path.isfile(filename), "=> no checkpoint found at '%s'" % filename
     checkpoint = torch.load(filename, map_location=lambda storage, loc: storage)
     print("=> loading checkpoint '%s'" % filename, end='')
@@ -80,7 +124,7 @@ def load_checkpoint(filename, iscuda=False):
     print()
 
     new_dict = OrderedDict()
-    for k,v in list(checkpoint['state_dict'].items()):
+    for k, v in list(checkpoint['state_dict'].items()):
         if k.startswith('module.'):
             k = k[7:]
         new_dict[k] = v
@@ -93,10 +137,10 @@ def load_checkpoint(filename, iscuda=False):
                     if iscuda and torch.is_tensor(v):
                         state[k] = v.cuda()
         except RuntimeError as e:
-            print("RuntimeError:",e,"(machine %s, GPU %s)"%(
-                os.environ['HOSTNAME'],os.environ['CUDA_VISIBLE_DEVICES']),
+            print("RuntimeError:", e, "(machine %s, GPU %s)" % (
+                os.environ['HOSTNAME'], os.environ['CUDA_VISIBLE_DEVICES']),
                 file=sys.stderr)
-            sys.exit(1) # error
+            sys.exit(1)
 
     return checkpoint
 
@@ -104,14 +148,15 @@ def load_checkpoint(filename, iscuda=False):
 def switch_model_to_cuda(model, iscuda=True, checkpoint=None):
     if iscuda:
         if checkpoint:
-            checkpoint['state_dict'] = {'module.'+k:v for k,v in checkpoint['state_dict'].items()}
+            checkpoint['state_dict'] = {'module.' + k: v for k, v in checkpoint['state_dict'].items()}
         try:
             model = torch.nn.DataParallel(model)
 
             # copy attributes automatically
             for var in dir(model.module):
-                if var.startswith('_'): continue
-                val = getattr(model.module,var)
+                if var.startswith('_'):
+                    continue
+                val = getattr(model.module, var)
                 if isinstance(val, (bool, int, float, str, dict)) or \
                    (callable(val) and var.startswith('get_')):
                     setattr(model, var, val)
@@ -119,10 +164,10 @@ def switch_model_to_cuda(model, iscuda=True, checkpoint=None):
             model.cuda()
             model.isasync = True
         except RuntimeError as e:
-            print("RuntimeError:",e,"(machine %s, GPU %s)"%(
-                os.environ['HOSTNAME'],os.environ['CUDA_VISIBLE_DEVICES']),
+            print("RuntimeError:", e, "(machine %s, GPU %s)" % (
+                os.environ['HOSTNAME'], os.environ['CUDA_VISIBLE_DEVICES']),
                 file=sys.stderr)
-            sys.exit(1) # error
+            sys.exit(1)
 
     model.iscuda = iscuda
     return model
@@ -139,7 +184,8 @@ def model_size(model):
 
 def freeze_batch_norm(model, freeze=True, only_running=False):
     model.freeze_bn = bool(freeze)
-    if not freeze: return
+    if not freeze:
+        return
 
     for m in model.modules():
         if isinstance(m, nn.BatchNorm2d):
@@ -157,14 +203,16 @@ def variables(inputs, iscuda, not_on_gpu=[]):
     '''
     inputs_var = []
 
-    for i,x in enumerate(inputs):
-        if i not in not_on_gpu and not isinstance(x, (tuple,list)):
-            if iscuda: x = x.cuda(non_blocking=True)
+    for i, x in enumerate(inputs):
+        if i not in not_on_gpu and not isinstance(x, (tuple, list)):
+            if iscuda:
+                x = x.cuda(non_blocking=True)
             x = torch.autograd.Variable(x)
         inputs_var.append(x)
 
     return inputs_var
 
+
 def learn_pca(X, n_components=None, whiten=False, use_sklearn=True):
     ''' Learn Principal Component Analysis
 
@@ -180,16 +228,16 @@ def learn_pca(X, n_components=None, whiten=False, use_sklearn=True):
         pca = sklearn.decomposition.PCA(n_components=n_components, svd_solver='full', whiten=whiten)
         pca.fit(X)
     else:
-        fudge=1E-8
+        fudge = 1E-8
         means = np.mean(X, axis=0)
         X = X - means
 
         # get the covariance matrix
-        Xcov = np.dot(X.T,X)
+        Xcov = np.dot(X.T, X)
 
         # eigenvalue decomposition of the covariance matrix
         d, V = np.linalg.eigh(Xcov)
-        d[d<0] = fudge
+        d[d < 0] = fudge
 
         # a fudge factor can be used so that eigenvectors associated with
         # small eigenvalues do not get overamplified.
@@ -205,6 +253,7 @@ def learn_pca(X, n_components=None, whiten=False, use_sklearn=True):
 
     return pca
 
+
 def transform(pca, X, whitenp=0.5, whitenv=None, whitenm=1.0, use_sklearn=True):
     if use_sklearn:
         # https://github.com/scikit-learn/scikit-learn/blob/master/sklearn/decomposition/base.py#L99
@@ -218,6 +267,7 @@ def transform(pca, X, whitenp=0.5, whitenv=None, whitenm=1.0, use_sklearn=True):
         X_transformed = np.dot(X, pca['W'])
     return X_transformed
 
+
 def whiten_features(X, pca, l2norm=True, whitenp=0.5, whitenv=None, whitenm=1.0, use_sklearn=True):
     res = transform(pca, X, whitenp=whitenp, whitenv=whitenv, whitenm=whitenm, use_sklearn=use_sklearn)
     if l2norm:
-- 
GitLab