Commit 5822516e authored by Wu, Jiantao (PG/R - Comp Sci & Elec Eng)

cls3: rm useless args and set color_jitter to None

parent d0ceaa86
@@ -72,11 +72,6 @@ def get_args_parser():
parser.add_argument('--drop_path', type=float, default=0.1, metavar='PCT',
help='Drop path rate (default: 0.1)')
parser.add_argument('--model_ema', action='store_true')
parser.add_argument('--no_model_ema', action='store_false', dest='model_ema')
parser.set_defaults(model_ema=False)
parser.add_argument('--model_ema_decay', type=float, default=0.99996, help='')
parser.add_argument('--model_ema_force_cpu', action='store_true', default=False, help='')
# Optimizer parameters
parser.add_argument('--opt', default='adamw', type=str, metavar='OPTIMIZER',
@@ -99,23 +94,12 @@ def get_args_parser():
help='learning rate (default: 5e-4)')
parser.add_argument('--lr_noise', type=float, nargs='+', default=None, metavar='pct, pct',
help='learning rate noise on/off epoch percentages')
parser.add_argument('--lr_noise_pct', type=float, default=0.67, metavar='PERCENT',
help='learning rate noise limit percent (default: 0.67)')
parser.add_argument('--lr_noise_std', type=float, default=1.0, metavar='STDDEV',
help='learning rate noise std-dev (default: 1.0)')
parser.add_argument('--warmup_lr', type=float, default=1e-6, metavar='LR',
help='warmup learning rate (default: 1e-6)')
parser.add_argument('--min_lr', type=float, default=1e-5, metavar='LR',
parser.add_argument('--min_lr', type=float, default=1e-6, metavar='LR',
help='lower lr bound for cyclic schedulers that hit 0 (1e-5)')
parser.add_argument('--decay_epochs', type=float, default=30, metavar='N',
help='epoch interval to decay LR')
parser.add_argument('--warmup_epochs', type=int, default=5, metavar='N',
help='epochs to warmup LR, if scheduler supports')
parser.add_argument('--cooldown_epochs', type=int, default=10, metavar='N',
help='epochs to cooldown LR at min_lr, after cyclic schedule ends')
parser.add_argument('--patience_epochs', type=int, default=10, metavar='N',
help='patience epochs for Plateau LR scheduler (default: 10)')
parser.add_argument('--decay_rate', '--dr', type=float, default=0.1, metavar='RATE',
help='LR decay rate (default: 0.1)')
@@ -123,18 +107,11 @@ def get_args_parser():
parser.add_argument('--ThreeAugment', action='store_true', default=False) #3augment
parser.add_argument('--src',action='store_true', default=False,
help="Use Simple Random Crop (SRC) or Random Resized Crop (RRC). Use SRC when there is less risk of overfitting, such as on ImageNet-21k.")
parser.add_argument('--color_jitter', type=float, default=0.4, metavar='PCT',
help='Color jitter factor (default: 0.4)')
parser.add_argument('--color_jitter', type=float, default=None, metavar='PCT',
help='Color jitter factor (enabled only when not using Auto/RandAug)')
parser.add_argument('--aa', type=str, default='rand-m9-mstd0.5-inc1', metavar='NAME',
help='Use AutoAugment policy. "v0" or "original". " + \
"(default: rand-m9-mstd0.5-inc1)'),
help='Use AutoAugment policy. "v0" or "original" (default: rand-m9-mstd0.5-inc1)')
parser.add_argument('--smoothing', type=float, default=0.1, help='Label smoothing (default: 0.1)')
parser.add_argument('--train_interpolation', type=str, default='bicubic',
help='Training interpolation (random, bilinear, bicubic default: "bicubic")')
parser.add_argument('--repeated_aug', action='store_true')
parser.add_argument('--no_repeated_aug', action='store_false', dest='repeated_aug')
parser.set_defaults(repeated_aug=True)
# * Random Erase params
parser.add_argument('--reprob', type=float, default=0.25, metavar='PCT',
@@ -160,14 +137,6 @@ def get_args_parser():
parser.add_argument('--mixup_mode', type=str, default='batch',
help='How to apply mixup/cutmix params. Per "batch", "pair", or "elem"')
# Distillation parameters
parser.add_argument('--teacher_model', default='regnety_160', type=str, metavar='MODEL',
help='Name of teacher model to train (default: "regnety_160")')
parser.add_argument('--teacher_path', type=str, default='')
parser.add_argument('--distillation_type', default='none', choices=['none', 'soft', 'hard'], type=str, help="")
parser.add_argument('--distillation_alpha', default=0.5, type=float, help="")
parser.add_argument('--distillation_tau', default=1.0, type=float, help="")
# * Finetuning params
parser.add_argument('--disable_weight_decay_on_bias_norm', action='store_true', default=False)
parser.add_argument('--init_scale', default=1.0, type=float)
@@ -205,120 +174,10 @@ def get_args_parser():
return parser
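(Not part of this commit.) For orientation, the parser returned above is normally consumed by a small entry point; a minimal sketch, assuming the usual argparse pattern and that main() is defined later in this file (the __main__ guard itself is an assumption, not shown in this diff):

# Hypothetical entry point -- illustrative only, not taken from this repository.
if __name__ == '__main__':
    args = get_args_parser().parse_args()
    main(args)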
class RASampler(torch.utils.data.Sampler):
"""Sampler that restricts data loading to a subset of the dataset for distributed,
with repeated augmentation.
It ensures that each augmented version of a sample will be visible to a
different process (GPU).
Heavily based on torch.utils.data.DistributedSampler
"""
def __init__(self, dataset, num_replicas=None, rank=None, shuffle=True):
if num_replicas is None:
if not dist.is_available():
raise RuntimeError("Requires distributed package to be available")
num_replicas = dist.get_world_size()
if rank is None:
if not dist.is_available():
raise RuntimeError("Requires distributed package to be available")
rank = dist.get_rank()
self.dataset = dataset
self.num_replicas = num_replicas
self.rank = rank
self.epoch = 0
self.num_samples = int(math.ceil(len(self.dataset) * 3.0 / self.num_replicas))
self.total_size = self.num_samples * self.num_replicas
# self.num_selected_samples = int(math.ceil(len(self.dataset) / self.num_replicas))
self.num_selected_samples = int(math.floor(len(self.dataset) // 256 * 256 / self.num_replicas))
self.shuffle = shuffle
def __iter__(self):
# deterministically shuffle based on epoch
g = torch.Generator()
g.manual_seed(self.epoch)
if self.shuffle:
indices = torch.randperm(len(self.dataset), generator=g).tolist()
else:
indices = list(range(len(self.dataset)))
# add extra samples to make it evenly divisible
indices = [ele for ele in indices for i in range(3)]
indices += indices[:(self.total_size - len(indices))]
assert len(indices) == self.total_size
# subsample
indices = indices[self.rank:self.total_size:self.num_replicas]
assert len(indices) == self.num_samples
return iter(indices[:self.num_selected_samples])
def __len__(self):
return self.num_selected_samples
def set_epoch(self, epoch):
self.epoch = epoch
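(Not part of this commit.) For context on the class removed above: RASampler is a drop-in replacement for torch.utils.data.DistributedSampler that repeats each index three times so every process trains on a differently augmented copy of the same images, and set_epoch must be called once per epoch to reseed the shuffle. A minimal usage sketch; dataset_train, world_size, rank and num_epochs are placeholder names, not identifiers guaranteed by this file:

# Hypothetical usage of RASampler -- illustrative only.
sampler_train = RASampler(dataset_train, num_replicas=world_size, rank=rank, shuffle=True)
loader_train = torch.utils.data.DataLoader(
    dataset_train, sampler=sampler_train, batch_size=64,
    num_workers=8, pin_memory=True, drop_last=True)

for epoch in range(num_epochs):
    sampler_train.set_epoch(epoch)  # new deterministic permutation each epoch
    for samples, targets in loader_train:
        pass  # each rank sees a different repeated-augmentation slice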
class DistillationLoss(torch.nn.Module):
"""
This module wraps a standard criterion and adds an extra knowledge distillation loss by
taking a teacher model prediction and using it as additional supervision.
"""
def __init__(self, base_criterion: torch.nn.Module, teacher_model: torch.nn.Module,
distillation_type: str, alpha: float, tau: float):
super().__init__()
self.base_criterion = base_criterion
self.teacher_model = teacher_model
assert distillation_type in ['none', 'soft', 'hard']
self.distillation_type = distillation_type
self.alpha = alpha
self.tau = tau
def forward(self, inputs, outputs, labels):
"""
Args:
inputs: The original inputs that are fed to the teacher model
outputs: the outputs of the model to be trained. It is expected to be
either a Tensor, or a Tuple[Tensor, Tensor], with the original output
in the first position and the distillation predictions as the second output
labels: the labels for the base criterion
"""
outputs_kd = None
if not isinstance(outputs, torch.Tensor):
# assume that the model outputs a tuple of [outputs, outputs_kd]
outputs, outputs_kd = outputs
base_loss = self.base_criterion(outputs, labels)
if self.distillation_type == 'none':
return base_loss
if outputs_kd is None:
raise ValueError("When knowledge distillation is enabled, the model is "
"expected to return a Tuple[Tensor, Tensor] with the output of the "
"class_token and the dist_token")
# don't backprop through the teacher
with torch.no_grad():
teacher_outputs = self.teacher_model(inputs)
if self.distillation_type == 'soft':
T = self.tau
# taken from https://github.com/peterliht/knowledge-distillation-pytorch/blob/master/model/net.py#L100
# with slight modifications
distillation_loss = F.kl_div(
F.log_softmax(outputs_kd / T, dim=1),
F.log_softmax(teacher_outputs / T, dim=1),
reduction='sum',
log_target=True
) * (T * T) / outputs_kd.numel()
elif self.distillation_type == 'hard':
distillation_loss = F.cross_entropy(outputs_kd, teacher_outputs.argmax(dim=1))
loss = base_loss * (1 - self.alpha) + distillation_loss * self.alpha
return loss
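(Not part of this commit.) The DistillationLoss wrapper removed above follows the DeiT recipe: it wraps the base criterion and is called with the raw samples as well, so the teacher can be queried inside forward(). A minimal wiring sketch; the timm teacher construction, the label-smoothing value, and the names model, samples and targets are assumptions for illustration:

# Hypothetical wiring of DistillationLoss -- illustrative only.
import timm
import torch

base_criterion = torch.nn.CrossEntropyLoss(label_smoothing=0.1)
teacher_model = timm.create_model('regnety_160', pretrained=True)
teacher_model.eval()

criterion = DistillationLoss(base_criterion, teacher_model,
                             distillation_type='soft', alpha=0.5, tau=1.0)

outputs = model(samples)                      # Tensor or (class_out, dist_out)
loss = criterion(samples, outputs, targets)   # base loss blended with KD loss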
def train_one_epoch(model: torch.nn.Module, criterion: DistillationLoss,
def train_one_epoch(model: torch.nn.Module, criterion: torch.nn.Module,
data_loader: Iterable, optimizer: torch.optim.Optimizer,
device: torch.device, epoch: int, loss_scaler,lr_scheduler, max_norm: float = 0,
model_ema: Optional[ModelEma] = None, mixup_fn: Optional[Mixup] = None,
mixup_fn: Optional[Mixup] = None,
):
model.train(True)
metric_logger = misc.MetricLogger(delimiter=" ")
@@ -337,7 +196,7 @@ def train_one_epoch(model: torch.nn.Module, criterion: DistillationLoss,
with torch.cuda.amp.autocast():
outputs = model(samples)
loss = criterion(outputs, targets)
loss = criterion(samples, outputs, targets)
loss /= accum_iter
loss_scaler(loss, optimizer, clip_grad=max_norm,
@@ -401,14 +260,12 @@ def evaluate(data_loader, model, device):
return {k: meter.global_avg for k, meter in metric_logger.meters.items()}
def main(args):
misc.init_distributed_mode(args)
print(args)
if args.distillation_type != 'none' and args.pretrained_weights and not args.eval:
raise NotImplementedError("Finetuning with distillation not yet supported")
device = torch.device(args.device)
# fix the seed for reproducibility
@@ -429,21 +286,16 @@ def main(args):
if True: # args.distributed:
num_tasks = misc.get_world_size()
global_rank = misc.get_rank()
if args.repeated_aug:
sampler_train = RASampler(
dataset_train, num_replicas=num_tasks, rank=global_rank, shuffle=True
)
else:
sampler_train = torch.utils.data.DistributedSampler(
dataset_train, num_replicas=num_tasks, rank=global_rank, shuffle=True
)
sampler_train = torch.utils.data.DistributedSampler(
dataset_train, num_replicas=num_tasks, rank=global_rank, shuffle=True
)
if args.dist_eval:
if len(dataset_val) % num_tasks != 0:
print('Warning: Enabling distributed evaluation with an eval dataset not divisible by process number. '
'This will slightly alter validation results as extra duplicate entries are added to achieve '
'equal num of samples per-process.')
sampler_val = torch.utils.data.DistributedSampler(
dataset_val, num_replicas=num_tasks, rank=global_rank, shuffle=False)
dataset_val, num_replicas=num_tasks, rank=global_rank, shuffle=True)
else:
sampler_val = torch.utils.data.SequentialSampler(dataset_val)
else:
@@ -488,19 +340,7 @@ def main(args):
model.to(device)
model_ema = None
if args.model_ema:
# Important to create EMA model after cuda(), DP wrapper, and AMP but before SyncBN and DDP wrapper
model_ema = ModelEma(
model,
decay=args.model_ema_decay,
device='cpu' if args.model_ema_force_cpu else '',
resume='')
model_without_ddp = model
if True:
model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu])
model_without_ddp = model.module
n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad)
print('number of params:', n_parameters)
@@ -515,6 +355,9 @@ def main(args):
print("accumulate grad iterations: %d" % args.accum_iter)
print("effective batch size: %d" % eff_batch_size)
if args.distributed:
model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu])
model_without_ddp = model.module
optimizer = create_optimizer(args, model_without_ddp)
loss_scaler = NativeScaler()
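(Not part of this commit.) The NativeScaler created here is invoked later as loss_scaler(loss, optimizer, clip_grad=max_norm, ...). Its implementation is not shown in this diff; the class below is only an approximation of the common AMP recipe such a helper typically follows, with a hypothetical name:

# Approximate, illustrative stand-in for a NativeScaler-style helper.
import torch

class SimpleLossScaler:
    def __init__(self):
        self._scaler = torch.cuda.amp.GradScaler()

    def __call__(self, loss, optimizer, clip_grad=None, parameters=None, update_grad=True):
        self._scaler.scale(loss).backward()           # backward on the scaled loss
        if update_grad:
            if clip_grad is not None and parameters is not None:
                self._scaler.unscale_(optimizer)      # unscale before clipping
                torch.nn.utils.clip_grad_norm_(parameters, clip_grad)
            self._scaler.step(optimizer)              # skipped if grads overflowed
            self._scaler.update()
            optimizer.zero_grad()

    def state_dict(self):
        return self._scaler.state_dict()

    def load_state_dict(self, state_dict):
        self._scaler.load_state_dict(state_dict)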
@@ -548,7 +391,7 @@ def main(args):
if args.output_dir and misc.is_main_process():
with (output_dir / "log.txt").open("a") as f:
f.write(json.dumps(test_stats) + "\n")
return
exit(0)
print(f"Start training for {args.epochs} epochs from {args.start_epoch}")
start_time = time.time()
@@ -566,18 +409,18 @@ def main(args):
train_stats = train_one_epoch(
model, criterion, data_loader_train,
optimizer, device, epoch, loss_scaler,lr_scheduler,
args.clip_grad, model_ema, mixup_fn,
args.clip_grad, mixup_fn,
)
checkpoint_paths = [output_dir / 'checkpoint.pth']
checkpoint_paths = ['checkpoint.pth']
if epoch%10==0 or epoch==args.epochs-1:
if epoch%args.ckpt_freq==0 or epoch==args.epochs-1:
test_stats = evaluate(data_loader_val, model, device)
print(f"Accuracy of the network on the {len(dataset_val)} test images: {test_stats['acc1']:.1f}%")
if (test_stats["acc1"] >= max_accuracy):
# only keep the best checkpoint seen so far
checkpoint_paths += [output_dir / 'checkpoint_best.pth']
checkpoint_paths += [ 'checkpoint_best.pth']
max_accuracy = max(max_accuracy, test_stats["acc1"])
@@ -593,7 +436,7 @@ def main(args):
'n_parameters': n_parameters}
# only save checkpoint on rank 0
if args.output_dir and misc.is_main_process():
if output_dir and misc.is_main_process():
if epoch%args.ckpt_freq==0 or epoch==args.epochs-1:
for checkpoint_path in checkpoint_paths:
misc.save_on_master({
@@ -603,7 +446,7 @@ def main(args):
'epoch': epoch,
'scaler': loss_scaler.state_dict(),
'args': args,
}, checkpoint_path)
}, output_dir / checkpoint_path)
if wandb.run: wandb.log(log_stats)
with (output_dir / "log.txt").open("a") as f:
......