diff --git a/bin/viz_ffcv.py b/bin/viz_ffcv.py index 314ead8d69907f3de7e3ed604cca8c31c587cd96..d979f31177dd66dddeb6353e0200fe242c6e1739 100644 --- a/bin/viz_ffcv.py +++ b/bin/viz_ffcv.py @@ -6,6 +6,9 @@ from ffcv.transforms import ToTensor, ToDevice, ToTorchImage, NormalizeImage,Ran from ffcv.fields.decoders import CenterCropRGBImageDecoder, IntDecoder from torchvision.utils import save_image import numpy as np +import time +from tqdm import tqdm + IMAGENET_MEAN = np.array([0.485, 0.456, 0.406]) * 255 IMAGENET_STD = np.array([0.229, 0.224, 0.225]) * 255 @@ -18,7 +21,7 @@ parser.add_argument('--img_size', type=int, default=224) args = parser.parse_args() image_pipeline = [ - CenterCropRGBImageDecoder((args.img_size, args.img_size), 1), + CenterCropRGBImageDecoder((args.img_size, args.img_size), 224/256), ToTensor(), ToTorchImage(), ] @@ -29,8 +32,14 @@ pipelines = { 'label': label_pipeline, } loader = Loader(args.data_path, pipelines=pipelines, - batch_size=args.num_samples, num_workers=10, + batch_size=args.num_samples, num_workers=10, ) + +start = time.time() +for batch in tqdm(loader): + pass +print('Time taken', time.time()-start, 's to load', len(loader.indices), 'samples') + nrow=int(np.sqrt(args.num_samples)) for batch in loader: x, y = batch diff --git a/vitookit/datasets/transform.py b/vitookit/datasets/transform.py index 2532c1d8531e9fb2aa87b8998b7119a307c7ad54..33750245324bf46bd19c875bf7db3ec5e7f9800a 100644 --- a/vitookit/datasets/transform.py +++ b/vitookit/datasets/transform.py @@ -31,7 +31,7 @@ def SimplePipeline(img_size=224,scale=(0.2,1), ratio=(3.0/4.0, 4.0/3.0)): return pipelines @gin.configurable -def ValPipeline(img_size=224,ratio= 4.0/3.0): +def ValPipeline(img_size=224,ratio= 224/256): image_pipeline = [ CenterCropRGBImageDecoder((img_size, img_size), ratio), NormalizeImage(IMAGENET_MEAN, IMAGENET_STD, np.float16), diff --git a/vitookit/evaluation/eval_cls.py b/vitookit/evaluation/eval_cls.py index a3775db7029aca7c683b3f43972fe3ccddac1da6..e6d5b97f2c11f133d1dc5a824cefba13198115c1 100644 --- a/vitookit/evaluation/eval_cls.py +++ b/vitookit/evaluation/eval_cls.py @@ -54,7 +54,7 @@ def get_args_parser(): parser.add_argument('--accum_iter', default=1, type=int, help='Accumulate gradient iterations (for increasing the effective batch size under memory constraints)') parser.add_argument('--epochs', default=100, type=int) - parser.add_argument('--ckpt_freq', default=10, type=int) + parser.add_argument('--ckpt_freq', default=5, type=int) # Model parameters parser.add_argument("--compile", action='store_true', default=False, help="compile model with PyTorch 2.0") @@ -337,12 +337,12 @@ def train_one_epoch(model: torch.nn.Module, criterion: DistillationLoss, with torch.cuda.amp.autocast(): outputs = model(samples) - loss = criterion(samples, outputs, targets) + loss = criterion(outputs, targets) loss /= accum_iter loss_scaler(loss, optimizer, clip_grad=max_norm, parameters=model.parameters(), create_graph=False, - update_grad=(itr + 1) % accum_iter == 0) + need_update=(itr + 1) % accum_iter == 0) if (itr + 1) % accum_iter == 0: optimizer.zero_grad() @@ -529,31 +529,6 @@ def main(args): criterion = torch.nn.CrossEntropyLoss() print("criterion = %s" % str(criterion)) - - teacher_model = None - # if args.distillation_type != 'none': - # assert args.teacher_path, 'need to specify teacher-path when using distillation' - # print(f"Creating teacher model: {args.teacher_model}") - # teacher_model = create_model( - # args.teacher_model, - # pretrained=False, - # 
num_classes=args.nb_classes, - # head_type=args.head_type, - # ) - # if args.teacher_path.startswith('https'): - # checkpoint = torch.hub.load_state_dict_from_url( - # args.teacher_path, map_location='cpu', check_hash=True) - # else: - # checkpoint = torch.load(args.teacher_path, map_location='cpu') - # teacher_model.load_state_dict(checkpoint['model']) - # teacher_model.to(device) - # teacher_model.eval() - - # wrap the criterion in our custom DistillationLoss, which - # just dispatches to the original criterion if args.distillation_type is 'none' - # criterion = DistillationLoss( - # criterion, teacher_model, args.distillation_type, args.distillation_alpha, args.distillation_tau - # ) output_dir = Path(args.output_dir) if args.output_dir else None @@ -563,9 +538,8 @@ def main(args): optimizer=optimizer, model=model_without_ddp, scaler=loss_scaler, - model_ema=model_ema, run_variables=run_variables) - args = run_variables['args'] + # args = run_variables['args'] args.start_epoch = run_variables["epoch"] + 1 if args.eval: @@ -627,7 +601,6 @@ def main(args): 'optimizer': optimizer.state_dict(), 'lr_scheduler': lr_scheduler.state_dict(), 'epoch': epoch, - 'model_ema': get_state_dict(model_ema), 'scaler': loss_scaler.state_dict(), 'args': args, }, checkpoint_path) diff --git a/vitookit/evaluation/eval_distill.py b/vitookit/evaluation/eval_distill.py new file mode 100644 index 0000000000000000000000000000000000000000..b4f163caf4567d9bc0981ee3ea7b5710eb2fc6b5 --- /dev/null +++ b/vitookit/evaluation/eval_distill.py @@ -0,0 +1,651 @@ +#!/usr/bin/env python +# Copyright (c) ByteDance, Inc. and its affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +""" +Mostly copy-paste from DEiT library: +https://github.com/facebookresearch/deit/blob/main/main.py +""" +from PIL import Image # hack to avoid `CXXABI_1.3.9' not found error + +import argparse +import datetime +import numpy as np +import time +import torch +import torch.nn as nn +import torch.backends.cudnn as cudnn +import torch.distributed as dist +import json +import os +import math +import sys +import copy +import scipy.io as scio +from vitookit.datasets.transform import three_augmentation +from vitookit.utils.helper import * +from vitookit.utils import misc +from vitookit.models.build_model import build_model +from vitookit.datasets import build_dataset +import wandb + + +from pathlib import Path +from typing import Iterable, Optional +from torch.nn import functional as F + +from timm.models import create_model +from timm.loss import LabelSmoothingCrossEntropy, SoftTargetCrossEntropy +from timm.scheduler import create_scheduler +from timm.optim import create_optimizer +from timm.utils import NativeScaler, get_state_dict, ModelEma, accuracy +from timm.data import Mixup, create_transform +from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD +from timm.layers import trunc_normal_ + + +def get_args_parser(): + parser = argparse.ArgumentParser('DeiT training and evaluation script', add_help=False) + parser.add_argument('--batch_size', default=128, type=int, + help='Batch size per GPU (effective batch size is batch_size * accum_iter * # gpus') + parser.add_argument('--accum_iter', default=1, type=int, + help='Accumulate gradient iterations (for increasing the effective batch size under memory constraints)') + parser.add_argument('--epochs', default=100, type=int) + parser.add_argument('--ckpt_freq', default=5, 
type=int) + + # Model parameters + parser.add_argument("--compile", action='store_true', default=False, help="compile model with PyTorch 2.0") + parser.add_argument("--prefix", default=None, type=str, help="prefix of the model name") + + parser.add_argument('--input_size', default=224, type=int, help='images input size') + parser.add_argument('-w', '--pretrained_weights', default='', type=str, help="""Path to pretrained + weights to evaluate. Set to `download` to automatically load the pretrained DINO from url. + Otherwise the model is randomly initialized""") + parser.add_argument("--checkpoint_key", default=None, type=str, help='Key to use in the checkpoint (example: "teacher")') + parser.add_argument('--drop', type=float, default=0.0, metavar='PCT', + help='Dropout rate (default: 0.)') + parser.add_argument('--attn_drop_rate', type=float, default=0.0, metavar='PCT', + help='Attention dropout rate (default: 0.)') + parser.add_argument('--drop_path', type=float, default=0.1, metavar='PCT', + help='Drop path rate (default: 0.1)') + + parser.add_argument('--model_ema', action='store_true') + parser.add_argument('--no_model_ema', action='store_false', dest='model_ema') + parser.set_defaults(model_ema=False) + parser.add_argument('--model_ema_decay', type=float, default=0.99996, help='') + parser.add_argument('--model_ema_force_cpu', action='store_true', default=False, help='') + + # Optimizer parameters + parser.add_argument('--opt', default='adamw', type=str, metavar='OPTIMIZER', + help='Optimizer (default: "adamw"') + parser.add_argument('--opt_eps', default=1e-8, type=float, metavar='EPSILON', + help='Optimizer Epsilon (default: 1e-8)') + parser.add_argument('--opt_betas', default=None, type=float, nargs='+', metavar='BETA', + help='Optimizer Betas (default: None, use opt default)') + parser.add_argument('--clip_grad', type=float, default=None, metavar='NORM', + help='Clip gradient norm (default: None, no clipping)') + parser.add_argument('--momentum', type=float, default=0.9, metavar='M', + help='SGD momentum (default: 0.9)') + parser.add_argument('--weight_decay', type=float, default=0.05, + help='weight decay (default: 0.05)') + parser.add_argument('--layer_decay', type=float, default=0.75) + # Learning rate schedule parameters + parser.add_argument('--sched', default='cosine', type=str, metavar='SCHEDULER', + help='LR scheduler (default: "cosine"') + parser.add_argument('--lr', type=float, default=5e-4, metavar='LR', + help='learning rate (default: 5e-4)') + parser.add_argument('--lr_noise', type=float, nargs='+', default=None, metavar='pct, pct', + help='learning rate noise on/off epoch percentages') + parser.add_argument('--lr_noise_pct', type=float, default=0.67, metavar='PERCENT', + help='learning rate noise limit percent (default: 0.67)') + parser.add_argument('--lr_noise_std', type=float, default=1.0, metavar='STDDEV', + help='learning rate noise std-dev (default: 1.0)') + parser.add_argument('--warmup_lr', type=float, default=1e-6, metavar='LR', + help='warmup learning rate (default: 1e-6)') + parser.add_argument('--min_lr', type=float, default=1e-5, metavar='LR', + help='lower lr bound for cyclic schedulers that hit 0 (1e-5)') + + parser.add_argument('--decay_epochs', type=float, default=30, metavar='N', + help='epoch interval to decay LR') + parser.add_argument('--warmup_epochs', type=int, default=5, metavar='N', + help='epochs to warmup LR, if scheduler supports') + parser.add_argument('--cooldown_epochs', type=int, default=10, metavar='N', + help='epochs to cooldown LR at 
min_lr, after cyclic schedule ends') + parser.add_argument('--patience_epochs', type=int, default=10, metavar='N', + help='patience epochs for Plateau LR scheduler (default: 10') + parser.add_argument('--decay_rate', '--dr', type=float, default=0.1, metavar='RATE', + help='LR decay rate (default: 0.1)') + + # Augmentation parameters + parser.add_argument('--ThreeAugment', action='store_true', default=False) #3augment + parser.add_argument('--src',action='store_true', default=False, + help="Use Simple Random Crop (SRC) or Random Resized Crop (RRC). Use SRC when there is less risk of overfitting, such as on ImageNet-21k.") + parser.add_argument('--color_jitter', type=float, default=0.4, metavar='PCT', + help='Color jitter factor (default: 0.4)') + parser.add_argument('--aa', type=str, default='rand-m9-mstd0.5-inc1', metavar='NAME', + help='Use AutoAugment policy. "v0" or "original". " + \ + "(default: rand-m9-mstd0.5-inc1)'), + parser.add_argument('--smoothing', type=float, default=0.1, help='Label smoothing (default: 0.1)') + parser.add_argument('--train_interpolation', type=str, default='bicubic', + help='Training interpolation (random, bilinear, bicubic default: "bicubic")') + + parser.add_argument('--repeated_aug', action='store_true') + parser.add_argument('--no_repeated_aug', action='store_false', dest='repeated_aug') + parser.set_defaults(repeated_aug=True) + + # * Random Erase params + parser.add_argument('--reprob', type=float, default=0.25, metavar='PCT', + help='Random erase prob (default: 0.25)') + parser.add_argument('--remode', type=str, default='pixel', + help='Random erase mode (default: "pixel")') + parser.add_argument('--recount', type=int, default=1, + help='Random erase count (default: 1)') + parser.add_argument('--resplit', action='store_true', default=False, + help='Do not random erase first (clean) augmentation split') + + # * Mixup params + parser.add_argument('--mixup', type=float, default=0.8, + help='mixup alpha, mixup enabled if > 0. (default: 0.8)') + parser.add_argument('--cutmix', type=float, default=1.0, + help='cutmix alpha, cutmix enabled if > 0. (default: 1.0)') + parser.add_argument('--cutmix_minmax', type=float, nargs='+', default=None, + help='cutmix min/max ratio, overrides alpha and enables cutmix if set (default: None)') + parser.add_argument('--mixup_prob', type=float, default=1.0, + help='Probability of performing mixup or cutmix when either/both is enabled') + parser.add_argument('--mixup_switch_prob', type=float, default=0.5, + help='Probability of switching to cutmix when both mixup and cutmix enabled') + parser.add_argument('--mixup_mode', type=str, default='batch', + help='How to apply mixup/cutmix params. 
Per "batch", "pair", or "elem"') + + # Distillation parameters + parser.add_argument('--teacher_model', default='regnety_160', type=str, metavar='MODEL', + help='Name of teacher model to train (default: "regnety_160"') + parser.add_argument('--teacher_path', type=str, default='') + parser.add_argument('--distillation_type', default='none', choices=['none', 'soft', 'hard'], type=str, help="") + parser.add_argument('--distillation_alpha', default=0.5, type=float, help="") + parser.add_argument('--distillation_tau', default=1.0, type=float, help="") + + # * Finetuning params + parser.add_argument('--disable_weight_decay_on_bias_norm', action='store_true', default=False) + parser.add_argument('--init_scale', default=1.0, type=float) + + # Dataset parameters + parser.add_argument('--data_location', default='/datasets01/imagenet_full_size/061417/', type=str, + help='dataset path') + parser.add_argument('--data_set', default='IN1K', + type=str, help='ImageNet dataset path') + parser.add_argument('--inat_category', default='name', + choices=['kingdom', 'phylum', 'class', 'order', 'supercategory', 'family', 'genus', 'name'], + type=str, help='semantic granularity') + + parser.add_argument('--output_dir', default=None, type=str, + help='path where to save, empty for no saving') + parser.add_argument('--device', default='cuda', + help='device to use for training / testing') + parser.add_argument('--seed', default=0, type=int) + parser.add_argument('--resume', default='', help='resume from checkpoint') + parser.add_argument('--start_epoch', default=0, type=int, metavar='N', + help='start epoch') + parser.add_argument('--eval', action='store_true', help='Perform evaluation only') + parser.add_argument('--dist_eval', action='store_true', default=False, help='Enabling distributed evaluation') + parser.add_argument('--num_workers', default=10, type=int) + parser.add_argument('--pin_mem', action='store_true', + help='Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU.') + parser.add_argument('--no_pin_mem', action='store_false', dest='pin_mem', + help='') + parser.set_defaults(pin_mem=True) + + # distributed training parameters + parser.add_argument("--local_rank", default=0, type=int, help="Please ignore and do not set this argument.") + parser.add_argument('--dist_url', default='env://', help='url used to set up distributed training') + + return parser + + + +class RASampler(torch.utils.data.Sampler): + """Sampler that restricts data loading to a subset of the dataset for distributed, + with repeated augmentation. 
+ It ensures that different each augmented version of a sample will be visible to a + different process (GPU) + Heavily based on torch.utils.data.DistributedSampler + """ + + def __init__(self, dataset, num_replicas=None, rank=None, shuffle=True): + if num_replicas is None: + if not dist.is_available(): + raise RuntimeError("Requires distributed package to be available") + num_replicas = dist.get_world_size() + if rank is None: + if not dist.is_available(): + raise RuntimeError("Requires distributed package to be available") + rank = dist.get_rank() + self.dataset = dataset + self.num_replicas = num_replicas + self.rank = rank + self.epoch = 0 + self.num_samples = int(math.ceil(len(self.dataset) * 3.0 / self.num_replicas)) + self.total_size = self.num_samples * self.num_replicas + # self.num_selected_samples = int(math.ceil(len(self.dataset) / self.num_replicas)) + self.num_selected_samples = int(math.floor(len(self.dataset) // 256 * 256 / self.num_replicas)) + self.shuffle = shuffle + + def __iter__(self): + # deterministically shuffle based on epoch + g = torch.Generator() + g.manual_seed(self.epoch) + if self.shuffle: + indices = torch.randperm(len(self.dataset), generator=g).tolist() + else: + indices = list(range(len(self.dataset))) + + # add extra samples to make it evenly divisible + indices = [ele for ele in indices for i in range(3)] + indices += indices[:(self.total_size - len(indices))] + assert len(indices) == self.total_size + + # subsample + indices = indices[self.rank:self.total_size:self.num_replicas] + assert len(indices) == self.num_samples + + return iter(indices[:self.num_selected_samples]) + + def __len__(self): + return self.num_selected_samples + + def set_epoch(self, epoch): + self.epoch = epoch + +class DistillationLoss(torch.nn.Module): + """ + This module wraps a standard criterion and adds an extra knowledge distillation loss by + taking a teacher model prediction and using it as additional supervision. + """ + def __init__(self, base_criterion: torch.nn.Module, teacher_model: torch.nn.Module, + distillation_type: str, alpha: float, tau: float): + super().__init__() + self.base_criterion = base_criterion + self.teacher_model = teacher_model + assert distillation_type in ['none', 'soft', 'hard'] + self.distillation_type = distillation_type + self.alpha = alpha + self.tau = tau + + def forward(self, inputs, outputs, labels): + """ + Args: + inputs: The original inputs that are feed to the teacher model + outputs: the outputs of the model to be trained. 
It is expected to be + either a Tensor, or a Tuple[Tensor, Tensor], with the original output + in the first position and the distillation predictions as the second output + labels: the labels for the base criterion + """ + outputs_kd = None + if not isinstance(outputs, torch.Tensor): + # assume that the model outputs a tuple of [outputs, outputs_kd] + outputs, outputs_kd = outputs + base_loss = self.base_criterion(outputs, labels) + if self.distillation_type == 'none': + return base_loss + + if outputs_kd is None: + raise ValueError("When knowledge distillation is enabled, the model is " + "expected to return a Tuple[Tensor, Tensor] with the output of the " + "class_token and the dist_token") + # don't backprop throught the teacher + with torch.no_grad(): + teacher_outputs = self.teacher_model(inputs) + + if self.distillation_type == 'soft': + T = self.tau + # taken from https://github.com/peterliht/knowledge-distillation-pytorch/blob/master/model/net.py#L100 + # with slight modifications + distillation_loss = F.kl_div( + F.log_softmax(outputs_kd / T, dim=1), + F.log_softmax(teacher_outputs / T, dim=1), + reduction='sum', + log_target=True + ) * (T * T) / outputs_kd.numel() + elif self.distillation_type == 'hard': + distillation_loss = F.cross_entropy(outputs_kd, teacher_outputs.argmax(dim=1)) + + loss = base_loss * (1 - self.alpha) + distillation_loss * self.alpha + return loss + +def train_one_epoch(model: torch.nn.Module, criterion: DistillationLoss, + data_loader: Iterable, optimizer: torch.optim.Optimizer, + device: torch.device, epoch: int, loss_scaler,lr_scheduler, max_norm: float = 0, + model_ema: Optional[ModelEma] = None, mixup_fn: Optional[Mixup] = None, + ): + model.train(True) + metric_logger = misc.MetricLogger(delimiter=" ") + metric_logger.add_meter('lr', misc.SmoothedValue(window_size=1, fmt='{value:.6f}')) + header = 'Epoch: [{}]'.format(epoch) + print_freq = max(len(data_loader)//20,20) + + accum_iter = args.accum_iter + for itr,(samples, targets) in enumerate(metric_logger.log_every(data_loader, print_freq, header)): + samples = samples.to(device, non_blocking=True) + targets = targets.to(device, non_blocking=True) + + lr_scheduler.step(epoch+itr/len(data_loader)) + if mixup_fn is not None: + samples, targets = mixup_fn(samples, targets) + + with torch.cuda.amp.autocast(): + outputs = model(samples) + loss = criterion(samples, outputs, targets) + + loss /= accum_iter + loss_scaler(loss, optimizer, clip_grad=max_norm, + parameters=model.parameters(), create_graph=False, + need_update=(itr + 1) % accum_iter == 0) + if (itr + 1) % accum_iter == 0: + optimizer.zero_grad() + + torch.cuda.synchronize() + # log metrics + loss_value = loss.item() + + if not math.isfinite(loss_value): + print("Loss is {}, stopping training".format(loss_value)) + sys.exit(1) + # this attribute is added by timm on one optimizer (adahessian) + # is_second_order = hasattr(optimizer, 'is_second_order') and optimizer.is_second_order + + # if model_ema is not None: + # model_ema.update(model) + + if wandb.run: + wandb.log({'train/loss':loss}) + metric_logger.update(loss=loss_value) + metric_logger.update(lr=optimizer.param_groups[0]["lr"]) + # gather the stats from all processes + metric_logger.synchronize_between_processes() + print("Averaged stats:", metric_logger) + return {k: meter.global_avg for k, meter in metric_logger.meters.items()} + + +@torch.no_grad() +def evaluate(data_loader, model, device): + criterion = torch.nn.CrossEntropyLoss() + + metric_logger = misc.MetricLogger(delimiter=" ") + 
header = 'Test:' + + # switch to evaluation mode + model.eval() + + for images, target in metric_logger.log_every(data_loader, 10, header): + images = images.to(device, non_blocking=True) + target = target.to(device, non_blocking=True, dtype=torch.long) + + # compute output + with torch.cuda.amp.autocast(): + output = model(images) + loss = criterion(output, target) + + acc1, acc5 = accuracy(output, target, topk=(1, 5)) + + batch_size = images.shape[0] + metric_logger.update(loss=loss.item()) + metric_logger.meters['acc1'].update(acc1.item(), n=batch_size) + metric_logger.meters['acc5'].update(acc5.item(), n=batch_size) + # gather the stats from all processes + metric_logger.synchronize_between_processes() + print('* Acc@1 {top1.global_avg:.3f} Acc@5 {top5.global_avg:.3f} loss {losses.global_avg:.3f}' + .format(top1=metric_logger.acc1, top5=metric_logger.acc5, losses=metric_logger.loss)) + + return {k: meter.global_avg for k, meter in metric_logger.meters.items()} + +def main(args): + misc.init_distributed_mode(args) + + print(args) + + if args.distillation_type != 'none' and args.pretrained_weights and not args.eval: + raise NotImplementedError("Finetuning with distillation not yet supported") + + device = torch.device(args.device) + + # fix the seed for reproducibility + misc.fix_random_seeds(args.seed) + + cudnn.benchmark = True + + if args.ThreeAugment: + transform = three_augmentation(args) + else: + transform = None + dataset_train, args.nb_classes = build_dataset(args=args, is_train=True, trnsfrm=transform) + dataset_val, _ = build_dataset(is_train=False, args=args) + + + print("Load dataset:", dataset_train) + + if True: # args.distributed: + num_tasks = misc.get_world_size() + global_rank = misc.get_rank() + if args.repeated_aug: + sampler_train = RASampler( + dataset_train, num_replicas=num_tasks, rank=global_rank, shuffle=True + ) + else: + sampler_train = torch.utils.data.DistributedSampler( + dataset_train, num_replicas=num_tasks, rank=global_rank, shuffle=True + ) + if args.dist_eval: + if len(dataset_val) % num_tasks != 0: + print('Warning: Enabling distributed evaluation with an eval dataset not divisible by process number. ' + 'This will slightly alter validation results as extra duplicate entries are added to achieve ' + 'equal num of samples per-process.') + sampler_val = torch.utils.data.DistributedSampler( + dataset_val, num_replicas=num_tasks, rank=global_rank, shuffle=False) + else: + sampler_val = torch.utils.data.SequentialSampler(dataset_val) + else: + sampler_train = torch.utils.data.RandomSampler(dataset_train) + sampler_val = torch.utils.data.SequentialSampler(dataset_val) + + data_loader_train = torch.utils.data.DataLoader( + dataset_train, sampler=sampler_train, + batch_size=args.batch_size, + num_workers=args.num_workers, + pin_memory=args.pin_mem, + drop_last=True, + ) + + data_loader_val = torch.utils.data.DataLoader( + dataset_val, sampler=sampler_val, + batch_size=int(1.5 * args.batch_size), + num_workers=args.num_workers, + pin_memory=args.pin_mem, + drop_last=False + ) + + mixup_fn = None + mixup_active = args.mixup > 0 or args.cutmix > 0. 
or args.cutmix_minmax is not None + if mixup_active: + print("Mixup is activated!") + mixup_fn = Mixup( + mixup_alpha=args.mixup, cutmix_alpha=args.cutmix, cutmix_minmax=args.cutmix_minmax, + prob=args.mixup_prob, switch_prob=args.mixup_switch_prob, mode=args.mixup_mode, + label_smoothing=args.smoothing, num_classes=args.nb_classes) + + + print(f"Model built.") + # load weights to evaluate + + model = build_model(num_classes=args.nb_classes,drop_path_rate=args.drop_path,) + if args.pretrained_weights: + load_pretrained_weights(model, args.pretrained_weights, checkpoint_key=args.checkpoint_key, prefix=args.prefix) + if args.compile: + return torch.compile(model) + trunc_normal_(model.head.weight, std=2e-5) + + model.to(device) + + model_ema = None + if args.model_ema: + # Important to create EMA model after cuda(), DP wrapper, and AMP but before SyncBN and DDP wrapper + model_ema = ModelEma( + model, + decay=args.model_ema_decay, + device='cpu' if args.model_ema_force_cpu else '', + resume='') + + model_without_ddp = model + if True: + model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu]) + model_without_ddp = model.module + n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad) + print('number of params:', n_parameters) + + eff_batch_size = args.batch_size * args.accum_iter * misc.get_world_size() + + linear_scaled_lr = args.lr * eff_batch_size / 256.0 + + print("base lr: %.2e" % args.lr ) + print("actual lr: %.2e" % linear_scaled_lr) + args.lr = linear_scaled_lr + + print("accumulate grad iterations: %d" % args.accum_iter) + print("effective batch size: %d" % eff_batch_size) + + optimizer = create_optimizer(args, model_without_ddp) + loss_scaler = NativeScaler() + + lr_scheduler, _ = create_scheduler(args, optimizer) + + criterion = LabelSmoothingCrossEntropy() + + if mixup_active: + # smoothing is handled with mixup label transform + criterion = SoftTargetCrossEntropy() + elif args.smoothing: + criterion = LabelSmoothingCrossEntropy(smoothing=args.smoothing) + else: + criterion = torch.nn.CrossEntropyLoss() + + if args.bce_loss: + criterion = torch.nn.BCEWithLogitsLoss() + print("criterion = %s" % str(criterion)) + + teacher_model = None + if args.distillation_type != 'none': + assert args.teacher_path, 'need to specify teacher-path when using distillation' + print(f"Creating teacher model: {args.teacher_model}") + teacher_model = create_model( + args.teacher_model, + pretrained=False, + num_classes=args.nb_classes, + head_type=args.head_type, + ) + if args.teacher_path.startswith('https'): + checkpoint = torch.hub.load_state_dict_from_url( + args.teacher_path, map_location='cpu', check_hash=True) + else: + checkpoint = torch.load(args.teacher_path, map_location='cpu') + teacher_model.load_state_dict(checkpoint['model']) + teacher_model.to(device) + teacher_model.eval() + + # wrap the criterion in our custom DistillationLoss, which + # just dispatches to the original criterion if args.distillation_type is 'none' + criterion = DistillationLoss( + criterion, teacher_model, args.distillation_type, args.distillation_alpha, args.distillation_tau + ) + + output_dir = Path(args.output_dir) if args.output_dir else None + if args.resume: + run_variables={"args":dict(),"epoch":0} + restart_from_checkpoint(args.resume, + optimizer=optimizer, + model=model_without_ddp, + scaler=loss_scaler, + model_ema=model_ema, + run_variables=run_variables) + # args = run_variables['args'] + args.start_epoch = run_variables["epoch"] + 1 + + if args.eval: + 
test_stats = evaluate(data_loader_val, model_without_ddp, device) + print(f"Accuracy of the network on the {len(dataset_val)} test images: {test_stats['acc1']:.1f}%") + if args.output_dir and misc.is_main_process(): + with (output_dir / "log.txt").open("a") as f: + f.write(json.dumps(test_stats) + "\n") + return + + print(f"Start training for {args.epochs} epochs from {args.start_epoch}") + start_time = time.time() + max_accuracy = 0.0 + if args.output_dir and misc.is_main_process(): + try: + wandb.init(job_type='distill',dir=args.output_dir,resume=True, + config=args.__dict__) + except: + pass + + for epoch in range(args.start_epoch, args.epochs): + data_loader_train.sampler.set_epoch(epoch) + + train_stats = train_one_epoch( + model, criterion, data_loader_train, + optimizer, device, epoch, loss_scaler,lr_scheduler, + args.clip_grad, model_ema, mixup_fn, + ) + + checkpoint_paths = [output_dir / 'checkpoint.pth'] + + if epoch%10==0 or epoch==args.epochs-1: + test_stats = evaluate(data_loader_val, model, device) + print(f"Accuracy of the network on the {len(dataset_val)} test images: {test_stats['acc1']:.1f}%") + + if (test_stats["acc1"] >= max_accuracy): + # always only save best checkpoint till now + checkpoint_paths += [output_dir / 'checkpoint_best.pth'] + + + max_accuracy = max(max_accuracy, test_stats["acc1"]) + print(f'Max accuracy: {max_accuracy:.2f}%') + + log_stats = {**{f'train/{k}': v for k, v in train_stats.items()}, + **{f'test/{k}': v for k, v in test_stats.items()}, + 'epoch': epoch, + 'n_parameters': n_parameters} + else: + log_stats = {**{f'train/{k}': v for k, v in train_stats.items()}, + 'epoch': epoch, + 'n_parameters': n_parameters} + + # only save checkpoint on rank 0 + if args.output_dir and misc.is_main_process(): + if epoch%args.ckpt_freq==0 or epoch==args.epochs-1: + for checkpoint_path in checkpoint_paths: + misc.save_on_master({ + 'model': model_without_ddp.state_dict(), + 'optimizer': optimizer.state_dict(), + 'lr_scheduler': lr_scheduler.state_dict(), + 'epoch': epoch, + 'model_ema': get_state_dict(model_ema), + 'scaler': loss_scaler.state_dict(), + 'args': args, + }, checkpoint_path) + + if wandb.run: wandb.log(log_stats) + with (output_dir / "log.txt").open("a") as f: + f.write(json.dumps(log_stats) + "\n") + + + total_time = time.time() - start_time + total_time_str = str(datetime.timedelta(seconds=int(total_time))) + print('Training time {}'.format(total_time_str)) + + +if __name__ == '__main__': + parser = argparse.ArgumentParser('DeiT training and evaluation script', parents=[get_args_parser()]) + args = aug_parse(parser) + main(args) diff --git a/vitookit/evaluation/eval_linear.py b/vitookit/evaluation/eval_linear.py index 598f8dd8b263464526143928216870bddb096ef8..8a2ec0eea121f5ce162e83c417b0aa80b563d6e1 100644 --- a/vitookit/evaluation/eval_linear.py +++ b/vitookit/evaluation/eval_linear.py @@ -41,7 +41,7 @@ def get_args_parser(): parser.add_argument('--batch_size', default=512, type=int, help='Batch size per GPU (effective batch size is batch_size * accum_iter * # gpus') parser.add_argument('--epochs', default=90, type=int) - parser.add_argument('--ckpt_freq', default=10, type=int) + parser.add_argument('--ckpt_freq', default=5, type=int) parser.add_argument('--accum_iter', default=1, type=int, help='Accumulate gradient iterations (for increasing the effective batch size under memory constraints)') @@ -251,7 +251,7 @@ def main(args): model=model_without_ddp, scaler=loss_scaler, run_variables=run_variables) - args = run_variables['args'] + # args = 
run_variables['args'] args.start_epoch = run_variables["epoch"] + 1 print("Resuming from epoch %d" % args.start_epoch) diff --git a/vitookit/evaluation/eval_linear_ffcv.py b/vitookit/evaluation/eval_linear_ffcv.py index 3e50d031f57b601e7a08ade85ec4b49c888e41b9..5368cc36fa1a4b30cbdeb968844c936fbe2c036b 100644 --- a/vitookit/evaluation/eval_linear_ffcv.py +++ b/vitookit/evaluation/eval_linear_ffcv.py @@ -43,7 +43,7 @@ def get_args_parser(): parser.add_argument('--batch_size', default=512, type=int, help='Batch size per GPU (effective batch size is batch_size * accum_iter * # gpus') parser.add_argument('--epochs', default=90, type=int) - parser.add_argument('--ckpt_freq', default=10, type=int) + parser.add_argument('--ckpt_freq', default=5, type=int) parser.add_argument('--accum_iter', default=1, type=int, help='Accumulate gradient iterations (for increasing the effective batch size under memory constraints)') @@ -136,15 +136,8 @@ def main(args): np.random.seed(seed) cudnn.benchmark = True - order = OrderOption.RANDOM if args.distributed else OrderOption.QUASI_RANDOM - - - data_loader_train = Loader(args.train_path, pipelines=SimplePipeline(), - batch_size=args.batch_size, num_workers=args.num_workers, - order=order, distributed=args.distributed,seed=args.seed) - - data_loader_val = Loader(args.val_path, pipelines=ValPipeline(ratio=1), + data_loader_val = Loader(args.val_path, pipelines=ValPipeline(), batch_size=args.batch_size, num_workers=args.num_workers, distributed=args.distributed,seed=args.seed) @@ -206,6 +199,7 @@ def main(args): print("criterion = %s" % str(criterion)) + if args.resume: run_variables={"args":dict(),"epoch":0} restart_from_checkpoint(args.resume, @@ -213,9 +207,10 @@ def main(args): model=model_without_ddp, scaler=loss_scaler, run_variables=run_variables) - args = run_variables['args'] + # args = run_variables['args'] args.start_epoch = run_variables["epoch"] + 1 - + print("resume from epoch %d" % args.start_epoch) + if args.eval: test_stats = evaluate(data_loader_val, model, device) print(f"Accuracy of the network on the test images: {test_stats['acc1']:.1f}%") @@ -226,7 +221,13 @@ def main(args): max_accuracy = 0.0 output_dir = Path(args.output_dir) if args.output_dir else None - + + + order = OrderOption.RANDOM if args.distributed else OrderOption.QUASI_RANDOM + data_loader_train = Loader(args.train_path, pipelines=SimplePipeline(), + batch_size=args.batch_size, num_workers=args.num_workers, + order=order, distributed=args.distributed,seed=args.seed) + for epoch in range(args.start_epoch, args.epochs): train_stats = train_one_epoch( @@ -238,19 +239,19 @@ def main(args): 'epoch': epoch, 'n_parameters': n_parameters} - ckpt_path = output_dir / 'checkpoint.pth' + ckpt_path = 'checkpoint.pth' if epoch%args.ckpt_freq==0 or epoch == args.epochs-1: test_stats = evaluate(data_loader_val, model, device) print(f"Accuracy of the network on the test images: {test_stats['acc1']:.1f}%") print(f'Max accuracy: {max_accuracy:.2f}%') log_stats.update({f'test/{k}': v for k, v in test_stats.items()}) if max_accuracy < test_stats["acc1"]: - ckpt_path = output_dir / 'best_checkpoint.pth' + ckpt_path = 'best_checkpoint.pth' max_accuracy = max(max_accuracy, test_stats["acc1"]) else: test_stats={} - if args.output_dir and misc.is_main_process(): + if output_dir and misc.is_main_process(): if epoch % args.ckpt_freq == 0 or epoch == args.epochs-1: misc.save_on_master({ 'model': model_without_ddp.state_dict(), @@ -258,7 +259,7 @@ def main(args): 'epoch': epoch, 'scaler': 
loss_scaler.state_dict(),
                 'args': args,
-            }, ckpt_path)
+            }, output_dir / ckpt_path)
         if wandb.run: wandb.log(log_stats)
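
For reference, the 224/256 value now passed as the crop ratio to CenterCropRGBImageDecoder (in bin/viz_ffcv.py and in ValPipeline) corresponds to the standard ImageNet evaluation recipe of resizing the short side to 256 and center-cropping 224. A minimal torchvision sketch of the same protocol, shown only as a rough equivalent of the FFCV decoder (which crops in the source image and then resizes, rather than the reverse):

import torchvision.transforms as transforms

img_size = 224
crop_ratio = 224 / 256  # same value as the decoder's ratio argument
eval_transform = transforms.Compose([
    transforms.Resize(int(round(img_size / crop_ratio))),  # short side -> 256
    transforms.CenterCrop(img_size),                       # 224x224 center crop
    transforms.ToTensor(),
])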
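
The DistillationLoss added in vitookit/evaluation/eval_distill.py wraps a base criterion and blends it with a teacher-supervised term; with distillation_type='none' it simply dispatches to the base criterion. A minimal usage sketch, assuming the repository is importable as the vitookit package; the toy nn.Linear student/teacher and the tensor shapes are illustrative stand-ins, not part of the patch:

import torch
import torch.nn as nn
from vitookit.evaluation.eval_distill import DistillationLoss  # assumes vitookit is installed

student = nn.Linear(16, 10)          # hypothetical student producing class logits
teacher = nn.Linear(16, 10).eval()   # hypothetical frozen teacher

criterion = DistillationLoss(
    base_criterion=nn.CrossEntropyLoss(),
    teacher_model=teacher,
    distillation_type='hard',        # one of 'none' | 'soft' | 'hard'
    alpha=0.5,
    tau=1.0,
)

inputs = torch.randn(4, 16)
labels = torch.randint(0, 10, (4,))
logits = student(inputs)

# With distillation enabled the wrapped model is expected to return a
# (class_logits, dist_logits) tuple; the same logits are reused here for brevity.
loss = criterion(inputs, (logits, logits), labels)
loss.backward()

This mirrors the call in train_one_epoch of eval_distill.py, where the loss is computed as criterion(samples, outputs, targets).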