diff --git a/vitookit/evaluation/eval_cls.py b/vitookit/evaluation/eval_cls.py
index e6d5b97f2c11f133d1dc5a824cefba13198115c1..37d5d68d9122742e29ae25dcddf409eeb560bde6 100644
--- a/vitookit/evaluation/eval_cls.py
+++ b/vitookit/evaluation/eval_cls.py
@@ -72,11 +72,6 @@ def get_args_parser():
     parser.add_argument('--drop_path', type=float, default=0.1, metavar='PCT',
                         help='Drop path rate (default: 0.1)')
 
-    parser.add_argument('--model_ema', action='store_true')
-    parser.add_argument('--no_model_ema', action='store_false', dest='model_ema')
-    parser.set_defaults(model_ema=False)
-    parser.add_argument('--model_ema_decay', type=float, default=0.99996, help='')
-    parser.add_argument('--model_ema_force_cpu', action='store_true', default=False, help='')
 
     # Optimizer parameters
     parser.add_argument('--opt', default='adamw', type=str, metavar='OPTIMIZER',
@@ -99,23 +94,12 @@ def get_args_parser():
                         help='learning rate (default: 5e-4)')
     parser.add_argument('--lr_noise', type=float, nargs='+', default=None, metavar='pct, pct',
                         help='learning rate noise on/off epoch percentages')
-    parser.add_argument('--lr_noise_pct', type=float, default=0.67, metavar='PERCENT',
-                        help='learning rate noise limit percent (default: 0.67)')
-    parser.add_argument('--lr_noise_std', type=float, default=1.0, metavar='STDDEV',
-                        help='learning rate noise std-dev (default: 1.0)')
-    parser.add_argument('--warmup_lr', type=float, default=1e-6, metavar='LR',
-                        help='warmup learning rate (default: 1e-6)')
-    parser.add_argument('--min_lr', type=float, default=1e-5, metavar='LR',
+
+    parser.add_argument('--min_lr', type=float, default=1e-6, metavar='LR',
                         help='lower lr bound for cyclic schedulers that hit 0 (1e-5)')
 
-    parser.add_argument('--decay_epochs', type=float, default=30, metavar='N',
-                        help='epoch interval to decay LR')
     parser.add_argument('--warmup_epochs', type=int, default=5, metavar='N',
                         help='epochs to warmup LR, if scheduler supports')
-    parser.add_argument('--cooldown_epochs', type=int, default=10, metavar='N',
-                        help='epochs to cooldown LR at min_lr, after cyclic schedule ends')
-    parser.add_argument('--patience_epochs', type=int, default=10, metavar='N',
-                        help='patience epochs for Plateau LR scheduler (default: 10')
     parser.add_argument('--decay_rate', '--dr', type=float, default=0.1, metavar='RATE',
                         help='LR decay rate (default: 0.1)')
 
@@ -123,18 +107,11 @@ def get_args_parser():
     parser.add_argument('--ThreeAugment', action='store_true', default=False) #3augment
     parser.add_argument('--src',action='store_true', default=False,
                         help="Use Simple Random Crop (SRC) or Random Resized Crop (RRC). Use SRC when there is less risk of overfitting, such as on ImageNet-21k.")
-    parser.add_argument('--color_jitter', type=float, default=0.4, metavar='PCT',
-                        help='Color jitter factor (default: 0.4)')
+    parser.add_argument('--color_jitter', type=float, default=None, metavar='PCT',
+                        help='Color jitter factor (enabled only when not using Auto/RandAug)')
     parser.add_argument('--aa', type=str, default='rand-m9-mstd0.5-inc1', metavar='NAME',
-                        help='Use AutoAugment policy. "v0" or "original". " + \
-                             "(default: rand-m9-mstd0.5-inc1)'),
+                        help='Use AutoAugment policy. "v0" or "original". " + "(default: rand-m9-mstd0.5-inc1)'),
     parser.add_argument('--smoothing', type=float, default=0.1, help='Label smoothing (default: 0.1)')
-    parser.add_argument('--train_interpolation', type=str, default='bicubic',
-                        help='Training interpolation (random, bilinear, bicubic default: "bicubic")')
-
-    parser.add_argument('--repeated_aug', action='store_true')
-    parser.add_argument('--no_repeated_aug', action='store_false', dest='repeated_aug')
-    parser.set_defaults(repeated_aug=True)
 
     # * Random Erase params
     parser.add_argument('--reprob', type=float, default=0.25, metavar='PCT',
@@ -160,14 +137,6 @@ def get_args_parser():
     parser.add_argument('--mixup_mode', type=str, default='batch',
                         help='How to apply mixup/cutmix params. Per "batch", "pair", or "elem"')
 
-    # Distillation parameters
-    parser.add_argument('--teacher_model', default='regnety_160', type=str, metavar='MODEL',
-                        help='Name of teacher model to train (default: "regnety_160"')
-    parser.add_argument('--teacher_path', type=str, default='')
-    parser.add_argument('--distillation_type', default='none', choices=['none', 'soft', 'hard'], type=str, help="")
-    parser.add_argument('--distillation_alpha', default=0.5, type=float, help="")
-    parser.add_argument('--distillation_tau', default=1.0, type=float, help="")
-
     # * Finetuning params
     parser.add_argument('--disable_weight_decay_on_bias_norm', action='store_true', default=False)
     parser.add_argument('--init_scale', default=1.0, type=float)
@@ -205,120 +174,10 @@ def get_args_parser():
 
     return parser
 
-
-class RASampler(torch.utils.data.Sampler):
-    """Sampler that restricts data loading to a subset of the dataset for distributed,
-    with repeated augmentation.
-    It ensures that different each augmented version of a sample will be visible to a
-    different process (GPU)
-    Heavily based on torch.utils.data.DistributedSampler
-    """
-
-    def __init__(self, dataset, num_replicas=None, rank=None, shuffle=True):
-        if num_replicas is None:
-            if not dist.is_available():
-                raise RuntimeError("Requires distributed package to be available")
-            num_replicas = dist.get_world_size()
-        if rank is None:
-            if not dist.is_available():
-                raise RuntimeError("Requires distributed package to be available")
-            rank = dist.get_rank()
-        self.dataset = dataset
-        self.num_replicas = num_replicas
-        self.rank = rank
-        self.epoch = 0
-        self.num_samples = int(math.ceil(len(self.dataset) * 3.0 / self.num_replicas))
-        self.total_size = self.num_samples * self.num_replicas
-        # self.num_selected_samples = int(math.ceil(len(self.dataset) / self.num_replicas))
-        self.num_selected_samples = int(math.floor(len(self.dataset) // 256 * 256 / self.num_replicas))
-        self.shuffle = shuffle
-
-    def __iter__(self):
-        # deterministically shuffle based on epoch
-        g = torch.Generator()
-        g.manual_seed(self.epoch)
-        if self.shuffle:
-            indices = torch.randperm(len(self.dataset), generator=g).tolist()
-        else:
-            indices = list(range(len(self.dataset)))
-
-        # add extra samples to make it evenly divisible
-        indices = [ele for ele in indices for i in range(3)]
-        indices += indices[:(self.total_size - len(indices))]
-        assert len(indices) == self.total_size
-
-        # subsample
-        indices = indices[self.rank:self.total_size:self.num_replicas]
-        assert len(indices) == self.num_samples
-
-        return iter(indices[:self.num_selected_samples])
-
-    def __len__(self):
-        return self.num_selected_samples
-
-    def set_epoch(self, epoch):
-        self.epoch = epoch
-
-class DistillationLoss(torch.nn.Module):
-    """
-    This module wraps a standard criterion and adds an extra knowledge distillation loss by
-    taking a teacher model prediction and using it as additional supervision.
-    """
-    def __init__(self, base_criterion: torch.nn.Module, teacher_model: torch.nn.Module,
-                 distillation_type: str, alpha: float, tau: float):
-        super().__init__()
-        self.base_criterion = base_criterion
-        self.teacher_model = teacher_model
-        assert distillation_type in ['none', 'soft', 'hard']
-        self.distillation_type = distillation_type
-        self.alpha = alpha
-        self.tau = tau
-
-    def forward(self, inputs, outputs, labels):
-        """
-        Args:
-            inputs: The original inputs that are feed to the teacher model
-            outputs: the outputs of the model to be trained. It is expected to be
-                either a Tensor, or a Tuple[Tensor, Tensor], with the original output
-                in the first position and the distillation predictions as the second output
-            labels: the labels for the base criterion
-        """
-        outputs_kd = None
-        if not isinstance(outputs, torch.Tensor):
-            # assume that the model outputs a tuple of [outputs, outputs_kd]
-            outputs, outputs_kd = outputs
-        base_loss = self.base_criterion(outputs, labels)
-        if self.distillation_type == 'none':
-            return base_loss
-
-        if outputs_kd is None:
-            raise ValueError("When knowledge distillation is enabled, the model is "
-                             "expected to return a Tuple[Tensor, Tensor] with the output of the "
-                             "class_token and the dist_token")
-        # don't backprop throught the teacher
-        with torch.no_grad():
-            teacher_outputs = self.teacher_model(inputs)
-
-        if self.distillation_type == 'soft':
-            T = self.tau
-            # taken from https://github.com/peterliht/knowledge-distillation-pytorch/blob/master/model/net.py#L100
-            # with slight modifications
-            distillation_loss = F.kl_div(
-                F.log_softmax(outputs_kd / T, dim=1),
-                F.log_softmax(teacher_outputs / T, dim=1),
-                reduction='sum',
-                log_target=True
-            ) * (T * T) / outputs_kd.numel()
-        elif self.distillation_type == 'hard':
-            distillation_loss = F.cross_entropy(outputs_kd, teacher_outputs.argmax(dim=1))
-
-        loss = base_loss * (1 - self.alpha) + distillation_loss * self.alpha
-        return loss
-
-def train_one_epoch(model: torch.nn.Module, criterion: DistillationLoss,
+def train_one_epoch(model: torch.nn.Module, criterion: torch.nn.Module,
                     data_loader: Iterable, optimizer: torch.optim.Optimizer,
                     device: torch.device, epoch: int, loss_scaler,lr_scheduler, max_norm: float = 0,
-                    model_ema: Optional[ModelEma] = None, mixup_fn: Optional[Mixup] = None,
+                    mixup_fn: Optional[Mixup] = None,
                     ):
     model.train(True)
     metric_logger = misc.MetricLogger(delimiter="  ")
@@ -337,7 +196,7 @@ def train_one_epoch(model: torch.nn.Module, criterion: DistillationLoss,
 
         with torch.cuda.amp.autocast():
             outputs = model(samples)
-            loss = criterion(outputs, targets)
+            loss = criterion(samples, outputs, targets)
 
         loss /= accum_iter
         loss_scaler(loss, optimizer, clip_grad=max_norm,
@@ -401,14 +260,12 @@ def evaluate(data_loader, model, device):
 
     return {k: meter.global_avg for k, meter in metric_logger.meters.items()}
 
+
 def main(args):
     misc.init_distributed_mode(args)
 
     print(args)
 
-    if args.distillation_type != 'none' and args.pretrained_weights and not args.eval:
-        raise NotImplementedError("Finetuning with distillation not yet supported")
-
     device = torch.device(args.device)
 
     # fix the seed for reproducibility
@@ -429,21 +286,16 @@ def main(args):
     if True: # args.distributed:
         num_tasks = misc.get_world_size()
         global_rank = misc.get_rank()
-        if args.repeated_aug:
-            sampler_train = RASampler(
-                dataset_train, num_replicas=num_tasks, rank=global_rank, shuffle=True
-            )
-        else:
-            sampler_train = torch.utils.data.DistributedSampler(
-                dataset_train, num_replicas=num_tasks, rank=global_rank, shuffle=True
-            )
+        sampler_train = torch.utils.data.DistributedSampler(
+            dataset_train, num_replicas=num_tasks, rank=global_rank, shuffle=True
+        )
         if args.dist_eval:
             if len(dataset_val) % num_tasks != 0:
                 print('Warning: Enabling distributed evaluation with an eval dataset not divisible by process number. '
                       'This will slightly alter validation results as extra duplicate entries are added to achieve '
                       'equal num of samples per-process.')
             sampler_val = torch.utils.data.DistributedSampler(
-                dataset_val, num_replicas=num_tasks, rank=global_rank, shuffle=False)
+                dataset_val, num_replicas=num_tasks, rank=global_rank, shuffle=True)
         else:
             sampler_val = torch.utils.data.SequentialSampler(dataset_val)
     else:
@@ -488,19 +340,7 @@ def main(args):
 
     model.to(device)
 
-    model_ema = None
-    if args.model_ema:
-        # Important to create EMA model after cuda(), DP wrapper, and AMP but before SyncBN and DDP wrapper
-        model_ema = ModelEma(
-            model,
-            decay=args.model_ema_decay,
-            device='cpu' if args.model_ema_force_cpu else '',
-            resume='')
-
     model_without_ddp = model
-    if True:
-        model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu])
-        model_without_ddp = model.module
     n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad)
     print('number of params:', n_parameters)
 
@@ -515,6 +355,9 @@ def main(args):
     print("accumulate grad iterations: %d" % args.accum_iter)
     print("effective batch size: %d" % eff_batch_size)
 
+    if args.distributed:
+        model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu])
+        model_without_ddp = model.module
     optimizer = create_optimizer(args, model_without_ddp)
     loss_scaler = NativeScaler()
 
@@ -548,7 +391,7 @@ def main(args):
         if args.output_dir and misc.is_main_process():
             with (output_dir / "log.txt").open("a") as f:
                 f.write(json.dumps(test_stats) + "\n")
-        return
+        exit(0)
 
     print(f"Start training for {args.epochs} epochs from {args.start_epoch}")
     start_time = time.time()
@@ -566,18 +409,18 @@ def main(args):
         train_stats = train_one_epoch(
             model, criterion, data_loader_train,
             optimizer, device, epoch, loss_scaler,lr_scheduler,
-            args.clip_grad, model_ema, mixup_fn,
+            args.clip_grad, mixup_fn,
         )
 
-        checkpoint_paths = [output_dir / 'checkpoint.pth']
+        checkpoint_paths = ['checkpoint.pth']
 
-        if epoch%10==0 or epoch==args.epochs-1:
+        if epoch%args.ckpt_freq==0 or epoch==args.epochs-1:
             test_stats = evaluate(data_loader_val, model, device)
             print(f"Accuracy of the network on the {len(dataset_val)} test images: {test_stats['acc1']:.1f}%")
 
             if (test_stats["acc1"] >= max_accuracy):
                 # always only save best checkpoint till now
-                checkpoint_paths += [output_dir / 'checkpoint_best.pth']
+                checkpoint_paths += [ 'checkpoint_best.pth']
 
             max_accuracy = max(max_accuracy, test_stats["acc1"])
 
@@ -593,7 +436,7 @@ def main(args):
                      'n_parameters': n_parameters}
 
         # only save checkpoint on rank 0
-        if args.output_dir and misc.is_main_process():
+        if output_dir and misc.is_main_process():
            if epoch%args.ckpt_freq==0 or epoch==args.epochs-1:
                 for checkpoint_path in checkpoint_paths:
                     misc.save_on_master({
@@ -603,7 +446,7 @@ def main(args):
                         'epoch': epoch,
                         'scaler': loss_scaler.state_dict(),
                         'args': args,
-                    }, checkpoint_path)
+                    }, output_dir / checkpoint_path)
             if wandb.run:
                 wandb.log(log_stats)
             with (output_dir / "log.txt").open("a") as f: