diff --git a/dirtorch/utils/common.py b/dirtorch/utils/common.py index cf03917bd4ae5130dfaa2e52ba348578d015049a..a4303495b05821dcb55acac348debdf5e5573c8a 100644 --- a/dirtorch/utils/common.py +++ b/dirtorch/utils/common.py @@ -189,12 +189,15 @@ def freeze_batch_norm(model, freeze=True, only_running=False): for m in model.modules(): if isinstance(m, nn.BatchNorm2d): - m.eval() # Eval mode freezes the running mean and std + # Eval mode freezes the running mean and std + m.eval() for param in m.named_parameters(): if only_running: - param[1].requires_grad = True # Weight and bias can be updated + # Weight and bias can be updated + param[1].requires_grad = True else: - param[1].requires_grad = False # Freeze the weight and bias + # Freeze the weight and bias + param[1].requires_grad = False def variables(inputs, iscuda, not_on_gpu=[]): @@ -213,47 +216,6 @@ def variables(inputs, iscuda, not_on_gpu=[]): return inputs_var -def learn_pca(X, n_components=None, whiten=False, use_sklearn=True): - ''' Learn Principal Component Analysis - - input: - X: input matrix with size samples x features - n_components: number of components to keep - whiten: applies feature whitening - - output: - PCA: weights and means of the PCA learned - ''' - if use_sklearn: - pca = sklearn.decomposition.PCA(n_components=n_components, svd_solver='full', whiten=whiten) - pca.fit(X) - else: - fudge = 1E-8 - means = np.mean(X, axis=0) - X = X - means - - # get the covariance matrix - Xcov = np.dot(X.T, X) - - # eigenvalue decomposition of the covariance matrix - d, V = np.linalg.eigh(Xcov) - d[d < 0] = fudge - - # a fudge factor can be used so that eigenvectors associated with - # small eigenvalues do not get overamplified. - D = np.diag(1. / np.sqrt(d+fudge)) - - # whitening matrix - W = np.dot(np.dot(V, D), V.T) - - # multiply by the whitening matrix - X_white = np.dot(X, W) - - pca = {'W': W, 'means': means} - - return pca - - def transform(pca, X, whitenp=0.5, whitenv=None, whitenm=1.0, use_sklearn=True): if use_sklearn: # https://github.com/scikit-learn/scikit-learn/blob/master/sklearn/decomposition/base.py#L99