diff --git a/dirtorch/loss.py b/dirtorch/loss.py
index 1b84555405c0a1c00d8477d1d829267aec1b74d3..756268e1d6602dec62ff94276328a3be84cfbc68 100644
--- a/dirtorch/loss.py
+++ b/dirtorch/loss.py
@@ -1,6 +1,7 @@
 import numpy as np
 import torch
 import torch.nn as nn
+import torch.nn.functional as F
 
 
 class APLoss (nn.Module):
@@ -24,37 +25,38 @@ class APLoss (nn.Module):
         self.max = max
         gap = max - min
         assert gap > 0
-        # init quantizer = non-learnable (fixed) convolution
+        # Initialize the quantizer as a fixed (non-trainable) convolution
         self.quantizer = q = nn.Conv1d(1, 2*nq, kernel_size=1, bias=True)
         q.weight = nn.Parameter(q.weight.detach(), requires_grad=False)
         q.bias = nn.Parameter(q.bias.detach(), requires_grad=False)
         a = (nq-1) / gap
-        # first half = lines passing to (min+x,1) and (min+x+1/a,0) with x = {nq-1..0}*gap/(nq-1)
+        # First half: lines passing through (min+x, 1) and (min+x+1/a, 0) with x = {nq-1..0}*gap/(nq-1)
         q.weight[:nq] = -a
-        q.bias[:nq] = torch.from_numpy(a*min + np.arange(nq, 0, -1)) # b = 1 + a*(min+x)
-        # first half = lines passing to (min+x,1) and (min+x-1/a,0) with x = {nq-1..0}*gap/(nq-1)
+        q.bias[:nq] = torch.from_numpy(a*min + np.arange(nq, 0, -1))  # b = 1 + a*(min+x)
+        # Second half: lines passing through (min+x, 1) and (min+x-1/a, 0) with x = {nq-1..0}*gap/(nq-1)
         q.weight[nq:] = a
-        q.bias[nq:] = torch.from_numpy(np.arange(2-nq, 2, 1) - a*min) # b = 1 - a*(min+x)
-        # first and last one are special: just horizontal straight line
+        q.bias[nq:] = torch.from_numpy(np.arange(2-nq, 2, 1) - a*min)  # b = 1 - a*(min+x)
+        # First and last bins are special: just a horizontal line
         q.weight[0] = q.weight[-1] = 0
         q.bias[0] = q.bias[-1] = 1
 
     def forward(self, x, label, qw=None, ret='1-mAP'):
-        assert x.shape == label.shape # N x M
+        assert x.shape == label.shape  # N x M
         N, M = x.shape
-        # quantize all predictions
+        # Quantize all predictions
         q = self.quantizer(x.unsqueeze(1))
-        q = torch.min(q[:,:self.nq], q[:,self.nq:]).clamp(min=0) # N x Q x M
+        q = torch.min(q[:, :self.nq], q[:, self.nq:]).clamp(min=0)  # N x Q x M
 
-        nbs = q.sum(dim=-1) # number of samples N x Q = c
-        rec = (q * label.view(N,1,M).float()).sum(dim=-1) # number of correct samples = c+ N x Q
-        prec = rec.cumsum(dim=-1) / (1e-16 + nbs.cumsum(dim=-1)) # precision
-        rec /= rec.sum(dim=-1).unsqueeze(1) # norm in [0,1]
+        nbs = q.sum(dim=-1)  # number of samples  N x Q = c
+        rec = (q * label.view(N, 1, M).float()).sum(dim=-1)  # number of correct samples = c+  N x Q
+        prec = rec.cumsum(dim=-1) / (1e-16 + nbs.cumsum(dim=-1))  # precision
+        rec /= rec.sum(dim=-1).unsqueeze(1)  # norm in [0,1]
 
-        ap = (prec * rec).sum(dim=-1) # per-image AP
+        ap = (prec * rec).sum(dim=-1)  # per-image AP
 
         if ret == '1-mAP':
-            if qw is not None: ap *= qw # query weights
+            if qw is not None:
+                ap *= qw  # query weights
             return 1 - ap.mean()
         elif ret == 'AP':
             assert qw is None
@@ -63,7 +65,8 @@ class APLoss (nn.Module):
             raise ValueError("Bad return type for APLoss(): %s" % str(ret))
 
     def measures(self, x, gt, loss=None):
-        if loss is None: loss = self.forward(x, gt)
+        if loss is None:
+            loss = self.forward(x, gt)
         return {'loss_ap': float(loss)}
 
 
@@ -89,37 +92,37 @@ class TAPLoss (APLoss):
            M: size of the descs;
            Q: number of bins (nq);
         '''
-        assert x.shape == label.shape # N x M
+        assert x.shape == label.shape  # N x M
         N, M = x.shape
         label = label.float()
         Np = label.sum(dim=-1, keepdim=True)
 
-        # quantize all predictions
+        # Quantize all predictions
         q = self.quantizer(x.unsqueeze(1))
-        q = torch.min(q[:,:self.nq], q[:,self.nq:]).clamp(min=0) # N x Q x M
+        q = torch.min(q[:, :self.nq], q[:, self.nq:]).clamp(min=0)  # N x Q x M
 
-        c = q.sum(dim=-1) # number of samples N x Q = nbs on APLoss
-        cp = (q * label.view(N,1,M)).sum(dim=-1) # N x Q number of correct samples = rec on APLoss
-        C = c.cumsum(dim=-1)
+        c = q.sum(dim=-1)  # number of samples  N x Q = nbs on APLoss
+        cp = (q * label.view(N, 1, M)).sum(dim=-1)  # N x Q number of correct samples = rec on APLoss
+        C = c.cumsum(dim=-1)
         Cp = cp.cumsum(dim=-1)
 
-        zeros = torch.zeros(N,1).to(x.device)
-        C_1d = torch.cat((zeros, C[:,:-1]), dim=-1)
-        Cp_1d = torch.cat((zeros, Cp[:,:-1]), dim=-1)
+        zeros = torch.zeros(N, 1).to(x.device)
+        C_1d = torch.cat((zeros, C[:, :-1]), dim=-1)
+        Cp_1d = torch.cat((zeros, Cp[:, :-1]), dim=-1)
 
         if self.simplified:
             aps = cp * (Cp_1d+Cp+1) / (C_1d+C+1) / Np
         else:
-            #pdb.set_trace()
             eps = 1e-8
-            ratio = (cp-1).clamp(min=0)/( (c-1).clamp(min=0) +eps)
-            aps = cp*( c*ratio + (Cp_1d+1-ratio*(C_1d+1))*torch.log( (C+1)/(C_1d+1) ) )/(c+eps)/Np
+            ratio = (cp - 1).clamp(min=0) / ((c - 1).clamp(min=0) + eps)
+            aps = cp * (c * ratio + (Cp_1d + 1 - ratio * (C_1d + 1)) * torch.log((C + 1) / (C_1d + 1))) / (c + eps) / Np
         aps = aps.sum(dim=-1)
 
-        assert aps.numel() == N, pdb.set_trace()
+        assert aps.numel() == N
 
         if ret == '1-mAP':
-            if qw is not None: aps *= qw # query weights
+            if qw is not None:
+                aps *= qw  # query weights
             return 1 - aps.mean()
         elif ret == 'AP':
             assert qw is None
@@ -128,7 +131,8 @@ class TAPLoss (APLoss):
             raise ValueError("Bad return type for APLoss(): %s" % str(ret))
 
     def measures(self, x, gt, loss=None):
-        if loss is None: loss = self.forward(x, gt)
+        if loss is None:
+            loss = self.forward(x, gt)
         return {'loss_tap'+('s' if self.simplified else ''): float(loss)}
 
 
@@ -141,7 +145,7 @@ class TripletMarginLoss(nn.TripletMarginLoss):
         return max(0, dp - dn + self.margin)
 
-class TripletLogExpLoss(Module):
+class TripletLogExpLoss(nn.Module):
     r"""Creates a criterion that measures the triplet loss given an input
     tensors x1, x2, x3.
 
     This is used for measuring a relative similarity between samples. A triplet
@@ -200,7 +204,7 @@
         return loss
 
     def eval_func(self, dp, dn):
-        return np.log(1 + np.exp(d_p - d_n))
+        return np.log(1 + np.exp(dp - dn))
 
 
 def sim_to_dist(scores):
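
Usage sketch (editor's addition, not part of the patch): the snippet below shows how the patched APLoss can be exercised end to end. The constructor arguments (nq, min, max) are assumed from the attributes set in __init__ above and may not match the repository's actual defaults; the shapes and labels are invented for illustration.

    import torch
    from dirtorch.loss import APLoss

    criterion = APLoss(nq=25, min=0, max=1)        # assumed signature, see note above

    N, M = 4, 100                                  # 4 queries, 100 database items each
    scores = torch.rand(N, M, requires_grad=True)  # similarity scores within [min, max]
    labels = (torch.rand(N, M) > 0.9).float()      # binary relevance ground truth

    loss = criterion(scores, labels)               # default ret='1-mAP' returns 1 - mean AP
    loss.backward()                                # quantizer is frozen; gradients reach the scores

TAPLoss follows the same call pattern; its simplified flag switches between the exact and simplified per-bin AP terms shown in the hunk above.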