diff --git a/vitookit/datasets/ffcv_transform.py b/vitookit/datasets/ffcv_transform.py
new file mode 100644
index 0000000000000000000000000000000000000000..89d749e18474bf6bb8deeb490582271d6d415bc9
--- /dev/null
+++ b/vitookit/datasets/ffcv_transform.py
@@ -0,0 +1,396 @@
+"""
+Copyright (c) Meta Platforms, Inc. and affiliates.
+All rights reserved.
+This source code is licensed under the license found in the
+LICENSE file in the root directory of this source tree.
+Copied from https://github.com/facebookresearch/FFCV-SSL/blob/6458e33f0753e7a35bc639517a763350a0fc2177/ffcv/transforms/colorjitter.py
+"""
+
+import numpy as np
+from collections.abc import Sequence
+from typing import Callable, Optional, Tuple
+from dataclasses import replace
+from ffcv.pipeline.allocation_query import AllocationQuery
+from ffcv.pipeline.operation import Operation
+from ffcv.pipeline.state import State
+from ffcv.pipeline.compiler import Compiler
+import numba as nb
+import numbers
+import math
+import random
+from numba import njit, jit
+import gin
+from cv2 import GaussianBlur
+from scipy.ndimage import gaussian_filter
+
+from ffcv.transforms import ToTensor, ToDevice, ToTorchImage, NormalizeImage, RandomHorizontalFlip, View
+from ffcv.fields.decoders import IntDecoder, RandomResizedCropRGBImageDecoder, SimpleRGBImageDecoder, CenterCropRGBImageDecoder
+
+import torch
+import torchvision.transforms as tfms
+from torch import nn
+
+IMAGENET_MEAN = np.array([0.485, 0.456, 0.406]) * 255
+IMAGENET_STD = np.array([0.229, 0.224, 0.225]) * 255
+
+@njit(parallel=False, fastmath=True, inline="always")
+def apply_cj(
+    im,
+    apply_bri,
+    bri_ratio,
+    apply_cont,
+    cont_ratio,
+    apply_sat,
+    sat_ratio,
+    apply_hue,
+    hue_factor,
+):
+    # BT.601 luma, reused by the contrast and saturation adjustments
+    gray = (
+        np.float32(0.2989) * im[..., 0]
+        + np.float32(0.5870) * im[..., 1]
+        + np.float32(0.1140) * im[..., 2]
+    )
+    one = np.float32(1)
+    # Brightness
+    if apply_bri:
+        im = im * bri_ratio
+
+    # Contrast
+    if apply_cont:
+        im = cont_ratio * im + (one - cont_ratio) * np.float32(gray.mean())
+
+    # Saturation
+    if apply_sat:
+        im[..., 0] = sat_ratio * im[..., 0] + (one - sat_ratio) * gray
+        im[..., 1] = sat_ratio * im[..., 1] + (one - sat_ratio) * gray
+        im[..., 2] = sat_ratio * im[..., 2] + (one - sat_ratio) * gray
+
+    # Hue: rotate the RGB cube around the gray diagonal by hue_factor * 2*pi
+    if apply_hue:
+        hue_factor_radians = hue_factor * 2.0 * np.pi
+        cosA = np.cos(hue_factor_radians)
+        sinA = np.sin(hue_factor_radians)
+        v1, v2, v3 = 1.0 / 3.0, np.sqrt(1.0 / 3.0), (1.0 - cosA)
+        hue_matrix = [
+            [
+                cosA + v3 / 3.0,
+                v1 * v3 - v2 * sinA,
+                v1 * v3 + v2 * sinA,
+            ],
+            [
+                v1 * v3 + v2 * sinA,
+                cosA + v1 * v3,
+                v1 * v3 - v2 * sinA,
+            ],
+            [
+                v1 * v3 - v2 * sinA,
+                v1 * v3 + v2 * sinA,
+                cosA + v1 * v3,
+            ],
+        ]
+        hue_matrix = np.array(hue_matrix, dtype=np.float64).T
+        for row in nb.prange(im.shape[0]):
+            im[row] = im[row] @ hue_matrix
+    return np.clip(im, 0, 255).astype(np.uint8)
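
The saturation branch above is the same blend-with-luma that torchvision uses. A quick standalone sanity check of that arithmetic (illustrative only, not part of the patch; shapes and the factor 1.5 are arbitrary):

    import numpy as np
    import torch
    from torchvision.transforms import functional as TF

    rgb = np.random.randint(0, 256, (32, 32, 3), dtype=np.uint8)
    sat = 1.5
    gray = 0.2989 * rgb[..., 0] + 0.5870 * rgb[..., 1] + 0.1140 * rgb[..., 2]
    ours = np.clip(sat * rgb + (1 - sat) * gray[..., None], 0, 255).astype(np.uint8)

    # torchvision expects CHW tensors
    ref = TF.adjust_saturation(torch.from_numpy(rgb).permute(2, 0, 1), sat)
    ref = ref.permute(1, 2, 0).numpy()
    print(np.abs(ours.astype(int) - ref.astype(int)).max())

The two results agree up to a count or two per pixel, since torchvision truncates the intermediate luma to uint8 while apply_cj keeps it in float.
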
+
+class RandomColorJitter(Operation):
+    """Apply ColorJitter with probability jitter_prob.
+    Operates on raw uint8 arrays (not tensors), with values in [0, 255].
+
+    See https://github.com/pytorch/vision/blob/28557e0cfe9113a5285330542264f03e4ba74535/torchvision/transforms/functional_tensor.py#L165
+    and https://sanje2v.wordpress.com/2021/01/11/accelerating-data-transforms/
+
+    Parameters
+    ----------
+    jitter_prob : float
+        The probability with which to apply ColorJitter.
+    brightness : float or tuple of float (min, max)
+        How much to jitter brightness. brightness_factor is chosen uniformly
+        from [max(0, 1 - brightness), 1 + brightness] or the given [min, max].
+        Should be non-negative numbers.
+    contrast : float or tuple of float (min, max)
+        How much to jitter contrast. contrast_factor is chosen uniformly from
+        [max(0, 1 - contrast), 1 + contrast] or the given [min, max].
+        Should be non-negative numbers.
+    saturation : float or tuple of float (min, max)
+        How much to jitter saturation. saturation_factor is chosen uniformly
+        from [max(0, 1 - saturation), 1 + saturation] or the given [min, max].
+        Should be non-negative numbers.
+    hue : float or tuple of float (min, max)
+        How much to jitter hue. hue_factor is chosen uniformly from
+        [-hue, hue] or the given [min, max].
+        Should have 0 <= hue <= 0.5 or -0.5 <= min <= max <= 0.5.
+    """
+
+    def __init__(
+        self,
+        jitter_prob=0.5,
+        brightness=0.8,
+        contrast=0.4,
+        saturation=0.4,
+        hue=0.2,
+        seed=None,
+    ):
+        super().__init__()
+        self.jitter_prob = jitter_prob
+
+        self.brightness = self._check_input(brightness, "brightness")
+        self.contrast = self._check_input(contrast, "contrast")
+        self.saturation = self._check_input(saturation, "saturation")
+        self.hue = self._check_input(hue, "hue", center=0, bound=(-0.5, 0.5))
+        self.seed = seed
+        assert 0 <= self.jitter_prob <= 1
+
+    def _check_input(
+        self, value, name, center=1, bound=(0, float("inf")), clip_first_on_zero=True
+    ):
+        if isinstance(value, numbers.Number):
+            if value < 0:
+                raise ValueError(
+                    f"If {name} is a single number, it must be non-negative."
+                )
+            value = [center - float(value), center + float(value)]
+            if clip_first_on_zero:
+                value[0] = max(value[0], 0.0)
+        elif isinstance(value, (tuple, list)) and len(value) == 2:
+            if not bound[0] <= value[0] <= value[1] <= bound[1]:
+                raise ValueError(f"{name} values should be between {bound}")
+        else:
+            raise TypeError(
+                f"{name} should be a single number or a list/tuple with length 2."
+            )
+
+        # If value is 0 or (1., 1.) for brightness/contrast/saturation,
+        # or (0., 0.) for hue, the corresponding adjustment is skipped.
+        if value[0] == value[1] == center:
+            setattr(self, f"apply_{name}", False)
+        else:
+            setattr(self, f"apply_{name}", True)
+        return tuple(value)
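
For reference, the factor ranges _check_input derives from the defaults (a hypothetical snippet, assuming the class above is importable; outputs shown approximately because of float arithmetic):

    rcj = RandomColorJitter()   # defaults: brightness=0.8, contrast=0.4, saturation=0.4, hue=0.2
    print(rcj.brightness)       # ~(0.2, 1.8): factor ~ U(max(0, 1-0.8), 1+0.8)
    print(rcj.contrast)         # ~(0.6, 1.4)
    print(rcj.saturation)       # ~(0.6, 1.4)
    print(rcj.hue)              # ~(-0.2, 0.2): centered at 0, bounded to [-0.5, 0.5]
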
+
+    def generate_code(self) -> Callable:
+        my_range = Compiler.get_iterator()
+
+        jitter_prob = self.jitter_prob
+
+        apply_bri = self.apply_brightness
+        bri = self.brightness
+
+        apply_cont = self.apply_contrast
+        cont = self.contrast
+
+        apply_sat = self.apply_saturation
+        sat = self.saturation
+
+        apply_hue = self.apply_hue
+        hue = self.hue
+
+        seed = self.seed
+        if seed is None:
+
+            def color_jitter(images, _):
+                for i in my_range(images.shape[0]):
+                    if np.random.rand() > jitter_prob:
+                        continue
+
+                    images[i] = apply_cj(
+                        images[i].astype("float64"),
+                        apply_bri,
+                        np.random.uniform(bri[0], bri[1]),
+                        apply_cont,
+                        np.random.uniform(cont[0], cont[1]),
+                        apply_sat,
+                        np.random.uniform(sat[0], sat[1]),
+                        apply_hue,
+                        np.random.uniform(hue[0], hue[1]),
+                    )
+                return images
+
+            color_jitter.is_parallel = True
+            return color_jitter
+
+        # Seeded variant: draw all random factors up front from `seed + counter`
+        # so augmentations are reproducible across runs yet differ per batch.
+        def color_jitter(images, _, counter):
+            random.seed(seed + counter)
+            N = images.shape[0]
+            values = np.zeros(N)
+            bris = np.zeros(N)
+            conts = np.zeros(N)
+            sats = np.zeros(N)
+            hues = np.zeros(N)
+            for i in range(N):
+                values[i] = np.float32(random.uniform(0, 1))
+                bris[i] = np.float32(random.uniform(bri[0], bri[1]))
+                conts[i] = np.float32(random.uniform(cont[0], cont[1]))
+                sats[i] = np.float32(random.uniform(sat[0], sat[1]))
+                hues[i] = np.float32(random.uniform(hue[0], hue[1]))
+            for i in my_range(N):
+                if values[i] > jitter_prob:
+                    continue
+                images[i] = apply_cj(
+                    images[i].astype("float64"),
+                    apply_bri,
+                    bris[i],
+                    apply_cont,
+                    conts[i],
+                    apply_sat,
+                    sats[i],
+                    apply_hue,
+                    hues[i],
+                )
+            return images
+
+        color_jitter.is_parallel = True
+        color_jitter.with_counter = True
+        return color_jitter
+
+    def declare_state_and_memory(
+        self, previous_state: State
+    ) -> Tuple[State, Optional[AllocationQuery]]:
+        return (replace(previous_state, jit_mode=True), None)
+
+
+class Grayscale(Operation):
+    """Convert images to 3-channel grayscale using BT.601 luma weights.
+    Operates on raw arrays (not tensors); the luma is written back into all
+    three channels, so the output shape is unchanged. Takes no parameters
+    and is applied to every image in the batch.
+    """
+
+    def __init__(self):
+        super().__init__()
+
+    def generate_code(self) -> Callable:
+        my_range = Compiler.get_iterator()
+
+        def grayscale(images, _):
+            for i in my_range(images.shape[0]):
+                images[i] = (
+                    0.2989 * images[i, ..., 0:1]
+                    + 0.5870 * images[i, ..., 1:2]
+                    + 0.1140 * images[i, ..., 2:3]
+                )
+            return images
+
+        grayscale.is_parallel = True
+        return grayscale
+
+    def declare_state_and_memory(
+        self, previous_state: State
+    ) -> Tuple[State, Optional[AllocationQuery]]:
+        return (previous_state, None)
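
Note that the (H, W, 1) luma assigns into the (H, W, 3) slot by broadcasting, so the result stays three-channel (unlike torchvision's single-channel Grayscale). A minimal check of that behavior (shapes are arbitrary; not part of the patch):

    import numpy as np

    imgs = np.random.randint(0, 256, (2, 8, 8, 3), dtype=np.uint8)
    out = imgs.copy()
    luma = (0.2989 * out[0, ..., 0:1] + 0.5870 * out[0, ..., 1:2] + 0.1140 * out[0, ..., 2:3])
    out[0] = luma  # broadcasts (H, W, 1) into (H, W, 3), truncating back to uint8
    assert out.shape == imgs.shape
    assert (out[0, ..., 0] == out[0, ..., 1]).all() and (out[0, ..., 1] == out[0, ..., 2]).all()
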
+
+
+class Solarization(Operation):
+    """Solarize images by inverting all pixel values at or above a threshold.
+    Operates on raw uint8 arrays and is applied to every image in the batch.
+
+    Parameters
+    ----------
+    threshold : float
+        All pixels equal to or above this value are inverted (default: 128).
+    """
+
+    def __init__(
+        self, threshold: float = 128,
+    ):
+        super().__init__()
+        self.threshold = threshold
+
+    def generate_code(self) -> Callable:
+        my_range = Compiler.get_iterator()
+        threshold = self.threshold
+
+        def solarize(images, _):
+            for i in my_range(images.shape[0]):
+                mask = images[i] >= threshold
+                images[i] = np.where(mask, 255 - images[i], images[i])
+            return images
+
+        solarize.is_parallel = True
+        return solarize
+
+    def declare_state_and_memory(self, previous_state: State) -> Tuple[State, Optional[AllocationQuery]]:
+        # No updates to state or extra memory necessary!
+        return previous_state, None
+
+
+class ThreeAugmentation(Operation):
+    """DeiT III-style augmentation: for each image, pick one of
+    solarization, grayscale, or Gaussian blur at random."""
+
+    def __init__(
+        self, threshold=128, radius_min=0.1, radius_max=2.
+    ):
+        super().__init__()
+        self.threshold = threshold
+        self.radius_min = radius_min
+        self.radius_max = radius_max
+
+    def generate_code(self) -> Callable:
+        my_range = Compiler.get_iterator()
+        threshold = self.threshold
+        radius_min = self.radius_min
+        radius_max = self.radius_max
+
+        def randchoice(images, _):
+            for i in my_range(images.shape[0]):
+                idx = random.randint(0, 2)
+                if idx == 0:
+                    # solarize
+                    mask = images[i] >= threshold
+                    images[i] = np.where(mask, 255 - images[i], images[i])
+                elif idx == 1:
+                    # grayscale
+                    images[i] = (
+                        0.2989 * images[i, ..., 0:1]
+                        + 0.5870 * images[i, ..., 1:2]
+                        + 0.1140 * images[i, ..., 2:3]
+                    )
+                else:
+                    # TODO: GaussianBlur. Note the blur branch is currently a
+                    # no-op: the filter call below is disabled, so roughly one
+                    # third of the images pass through unchanged.
+                    radius = np.random.uniform(radius_min, radius_max)
+                    # images[i] = gaussian_filter(images[i], radius)
+            return images
+        # randchoice.is_parallel = True
+
+        return randchoice
+
+    def declare_state_and_memory(self, previous_state: State) -> Tuple[State, Optional[AllocationQuery]]:
+        # No updates to state or extra memory necessary!
+        return previous_state, None
+
+
+@gin.configurable
+def ThreeAugmentPipeline(img_size=224, scale=(0.08, 1), color_jitter=None):
+    """FFCV pipelines for the ThreeAugment recipe: random resized crop and
+    horizontal flip, one of {solarize, grayscale, blur}, optional color
+    jitter, then normalization."""
+    image_pipeline = (
+        # first_tfl
+        [RandomResizedCropRGBImageDecoder((img_size, img_size), scale=scale),
+         RandomHorizontalFlip(), ] +
+        # second_tfl
+        [ThreeAugmentation(), ] +
+        ([RandomColorJitter(jitter_prob=1, brightness=color_jitter, contrast=color_jitter,
+                            saturation=color_jitter, hue=0, seed=None)] if color_jitter else []) +
+        # final_tfl
+        [
+            NormalizeImage(IMAGENET_MEAN, IMAGENET_STD, np.float16),
+            ToTensor(),
+            # ToDevice(torch.device('cuda')),
+            ToTorchImage(),
+        ])
+    label_pipeline = [IntDecoder(), ToTensor(), ToDevice(torch.device('cuda')), View(-1)]
+    # Pipeline for each data field
+    pipelines = {
+        'image': image_pipeline,
+        'label': label_pipeline
+    }
+    return pipelines
+
+
+if __name__ == "__main__":
+    from ffcv import Loader
+    path = "/users/jw02425/projects/EfficientSSL/data/IN1K_val_1000.ffcv"
+    loader = Loader(path, batch_size=4,
+                    pipelines=ThreeAugmentPipeline())
+
+    for batch in loader:
+        print(batch)
+        break
\ No newline at end of file
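
The smoke test above expects a pre-written `.ffcv` file. For completeness, a sketch of producing one with FFCV's DatasetWriter (the dataset path, output name, and field settings are placeholders, not part of this patch):

    import torchvision
    from ffcv.writer import DatasetWriter
    from ffcv.fields import RGBImageField, IntField

    ds = torchvision.datasets.ImageFolder('/path/to/imagenet/val')
    writer = DatasetWriter('IN1K_val.ffcv', {
        'image': RGBImageField(max_resolution=512, jpeg_quality=90),  # downscale/re-encode large images
        'label': IntField(),
    }, num_workers=16)
    writer.from_indexed_dataset(ds)
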
diff --git a/vitookit/datasets/transform.py b/vitookit/datasets/transform.py
index 33750245324bf46bd19c875bf7db3ec5e7f9800a..4fdfb0f384d9e218e1fd047bc1a7da72f2697084 100644
--- a/vitookit/datasets/transform.py
+++ b/vitookit/datasets/transform.py
@@ -6,7 +6,7 @@ from torchvision import transforms
 import torch
 import gin
 
-from ffcv.transforms import ToTensor, ToDevice, ToTorchImage, NormalizeImage, RandomHorizontalFlip
+from ffcv.transforms import ToTensor, ToDevice, ToTorchImage, NormalizeImage, RandomHorizontalFlip, View
 from ffcv.fields.decoders import IntDecoder, RandomResizedCropRGBImageDecoder, SimpleRGBImageDecoder, CenterCropRGBImageDecoder
 
 IMAGENET_MEAN = np.array([0.485, 0.456, 0.406]) * 255
@@ -39,7 +39,7 @@ def ValPipeline(img_size=224, ratio=224/256):
         ToDevice(torch.device('cuda')),
         ToTorchImage(),
     ]
-    label_pipeline = [IntDecoder(), ToTensor(), ToDevice(torch.device('cuda'))]
+    label_pipeline = [IntDecoder(), ToTensor(), ToDevice(torch.device('cuda')), View(-1)]
     # Pipeline for each data field
     pipelines = {
         'image': image_pipeline,
@@ -143,35 +143,6 @@ class GrayScale(object):
         else:
             return img
 
-from torchvision.transforms import functional as F
-class RandomResizedCrop(transforms.RandomResizedCrop):
-    """
-    RandomResizedCrop for matching TF/TPU implementation: no for-loop is used.
-    This may lead to results different with torchvision's version.
-    Following BYOL's TF code:
-    https://github.com/deepmind/deepmind-research/blob/master/byol/utils/dataset.py#L206
-    """
-    @staticmethod
-    def get_params(img, scale, ratio):
-        width, height = F._get_image_size(img)
-        area = height * width
-
-        target_area = area * torch.empty(1).uniform_(scale[0], scale[1]).item()
-        log_ratio = torch.log(torch.tensor(ratio))
-        aspect_ratio = torch.exp(
-            torch.empty(1).uniform_(log_ratio[0], log_ratio[1])
-        ).item()
-
-        w = int(round(math.sqrt(target_area * aspect_ratio)))
-        h = int(round(math.sqrt(target_area / aspect_ratio)))
-
-        w = min(w, width)
-        h = min(h, height)
-
-        i = torch.randint(0, height - h + 1, size=(1,)).item()
-        j = torch.randint(0, width - w + 1, size=(1,)).item()
-
-        return i, j, h, w
 
 def three_augmentation(args = None):
     img_size = args.input_size
@@ -188,7 +159,7 @@ def three_augmentation(args = None):
         ]
     else:
         primary_tfl = [
-            RandomResizedCrop(
+            transforms.RandomResizedCrop(
                 img_size, scale=scale, interpolation=transforms.InterpolationMode.BICUBIC),
             transforms.RandomHorizontalFlip()
         ]
@@ -206,4 +177,5 @@ def three_augmentation(args = None):
             mean=torch.tensor(mean), std=torch.tensor(std))
     ]
-    return transforms.Compose(primary_tfl+secondary_tfl+final_tfl)
\ No newline at end of file
+    return transforms.Compose(primary_tfl+secondary_tfl+final_tfl)
+
diff --git a/vitookit/evaluation/eval_cls.py b/vitookit/evaluation/eval_cls.py
index b5cf20eab66a45738d975132fe3c93465f56e5c6..d06c688632b723d5ea4deadf2613218a5006904d 100644
--- a/vitookit/evaluation/eval_cls.py
+++ b/vitookit/evaluation/eval_cls.py
@@ -196,7 +196,7 @@ def train_one_epoch(model: torch.nn.Module, criterion: torch.nn.Module,
 
         with torch.cuda.amp.autocast():
             outputs = model(samples)
-            loss = criterion(samples, outputs, targets)
+            loss = criterion(outputs, targets)
 
         loss /= accum_iter
         loss_scaler(loss, optimizer, clip_grad=max_norm,
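
The `View(-1)` appended to the label pipelines matters because FFCV's `IntDecoder` yields labels of shape (batch, 1), while cross-entropy expects class indices of shape (batch,). A minimal illustration (tensor values made up):

    import torch
    import torch.nn.functional as F

    logits = torch.randn(4, 1000)
    labels = torch.tensor([[3], [7], [1], [42]])   # (B, 1), as IntDecoder produces
    # F.cross_entropy(logits, labels)              # would raise: target must be shape (B,)
    loss = F.cross_entropy(logits, labels.view(-1))
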
diff --git a/vitookit/evaluation/eval_cls_ffcv.py b/vitookit/evaluation/eval_cls_ffcv.py
new file mode 100644
index 0000000000000000000000000000000000000000..9a477d16fb5ec0ac310b2b0396d0e87f8ac607f5
--- /dev/null
+++ b/vitookit/evaluation/eval_cls_ffcv.py
@@ -0,0 +1,427 @@
+#!/usr/bin/env python
+# Copyright (c) ByteDance, Inc. and its affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+"""
+Mostly copy-paste from the DeiT library:
+https://github.com/facebookresearch/deit/blob/main/main.py
+"""
+from PIL import Image  # hack to avoid `CXXABI_1.3.9' not found error
+
+import argparse
+import datetime
+import numpy as np
+import time
+import torch
+import torch.nn as nn
+import torch.backends.cudnn as cudnn
+import torch.distributed as dist
+import json
+import os
+import math
+import sys
+import copy
+import scipy.io as scio
+from vitookit.datasets.ffcv_transform import ThreeAugmentPipeline
+from vitookit.datasets.transform import ValPipeline, three_augmentation
+from vitookit.utils.helper import *
+from vitookit.utils import misc
+from vitookit.models.build_model import build_model
+from vitookit.datasets import build_dataset
+import wandb
+
+from pathlib import Path
+from typing import Iterable, Optional
+from torch.nn import functional as F
+
+from timm.models import create_model
+from timm.loss import LabelSmoothingCrossEntropy, SoftTargetCrossEntropy
+from timm.scheduler import create_scheduler
+from timm.optim import create_optimizer
+from timm.utils import NativeScaler, get_state_dict, ModelEma, accuracy
+from timm.data import Mixup, create_transform
+from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
+from timm.layers import trunc_normal_
+
+from ffcv import Loader
+from ffcv.loader import OrderOption
+
+
+def get_args_parser():
+    parser = argparse.ArgumentParser('DeiT training and evaluation script', add_help=False)
+    parser.add_argument('--batch_size', default=128, type=int,
+                        help='Batch size per GPU (effective batch size is batch_size * accum_iter * # gpus)')
+    parser.add_argument('--accum_iter', default=1, type=int,
+                        help='Accumulate gradient iterations (for increasing the effective batch size under memory constraints)')
+    parser.add_argument('--epochs', default=100, type=int)
+    parser.add_argument('--ckpt_freq', default=5, type=int)
+
+    # Model parameters
+    parser.add_argument("--compile", action='store_true', default=False, help="compile model with PyTorch 2.0")
+    parser.add_argument("--prefix", default=None, type=str, help="prefix of the model name")
+
+    parser.add_argument('--input_size', default=224, type=int, help='images input size')
+    parser.add_argument('-w', '--pretrained_weights', default='', type=str, help="""Path to pretrained
+        weights to evaluate. Set to `download` to automatically load the pretrained DINO from url.
+        Otherwise the model is randomly initialized.""")
+    parser.add_argument("--checkpoint_key", default=None, type=str, help='Key to use in the checkpoint (example: "teacher")')
+    parser.add_argument('--drop', type=float, default=0.0, metavar='PCT',
+                        help='Dropout rate (default: 0.)')
+    parser.add_argument('--attn_drop_rate', type=float, default=0.0, metavar='PCT',
+                        help='Attention dropout rate (default: 0.)')
+    parser.add_argument('--drop_path', type=float, default=0.1, metavar='PCT',
+                        help='Drop path rate (default: 0.1)')
+
+    # Optimizer parameters
+    parser.add_argument('--opt', default='adamw', type=str, metavar='OPTIMIZER',
+                        help='Optimizer (default: "adamw")')
+    parser.add_argument('--opt_eps', default=1e-8, type=float, metavar='EPSILON',
+                        help='Optimizer Epsilon (default: 1e-8)')
+    parser.add_argument('--opt_betas', default=None, type=float, nargs='+', metavar='BETA',
+                        help='Optimizer Betas (default: None, use opt default)')
+    parser.add_argument('--clip_grad', type=float, default=None, metavar='NORM',
+                        help='Clip gradient norm (default: None, no clipping)')
+    parser.add_argument('--momentum', type=float, default=0.9, metavar='M',
+                        help='SGD momentum (default: 0.9)')
+    parser.add_argument('--weight_decay', type=float, default=0.05,
+                        help='weight decay (default: 0.05)')
+    parser.add_argument('--layer_decay', type=float, default=0.75)
+
+    # Learning rate schedule parameters
+    parser.add_argument('--sched', default='cosine', type=str, metavar='SCHEDULER',
+                        help='LR scheduler (default: "cosine")')
+    parser.add_argument('--lr', type=float, default=5e-4, metavar='LR',
+                        help='learning rate (default: 5e-4)')
+    parser.add_argument('--lr_noise', type=float, nargs='+', default=None, metavar='pct, pct',
+                        help='learning rate noise on/off epoch percentages')
+    parser.add_argument('--min_lr', type=float, default=1e-6, metavar='LR',
+                        help='lower lr bound for cyclic schedulers that hit 0 (default: 1e-6)')
+    parser.add_argument('--warmup_epochs', type=int, default=5, metavar='N',
+                        help='epochs to warmup LR, if scheduler supports')
+    parser.add_argument('--decay_rate', '--dr', type=float, default=0.1, metavar='RATE',
+                        help='LR decay rate (default: 0.1)')
+
+    # Augmentation parameters
+    parser.add_argument('--ThreeAugment', action='store_true', default=False)  # 3augment
+    parser.add_argument('--src', action='store_true', default=False,
+                        help="Use Simple Random Crop (SRC) instead of Random Resized Crop (RRC). Use SRC when there is less risk of overfitting, such as on ImageNet-21k.")
+    parser.add_argument('--color_jitter', type=float, default=None, metavar='PCT',
+                        help='Color jitter factor (enabled only when not using Auto/RandAug)')
+    parser.add_argument('--aa', type=str, default='rand-m9-mstd0.5-inc1', metavar='NAME',
+                        help='Use AutoAugment policy. "v0" or "original". (default: rand-m9-mstd0.5-inc1)')
+    parser.add_argument('--smoothing', type=float, default=0.1, help='Label smoothing (default: 0.1)')
+
+    # * Random Erase params
+    parser.add_argument('--reprob', type=float, default=0.25, metavar='PCT',
+                        help='Random erase prob (default: 0.25)')
+    parser.add_argument('--remode', type=str, default='pixel',
+                        help='Random erase mode (default: "pixel")')
+    parser.add_argument('--recount', type=int, default=1,
+                        help='Random erase count (default: 1)')
+    parser.add_argument('--resplit', action='store_true', default=False,
+                        help='Do not random erase first (clean) augmentation split')
+
+    # * Mixup params
+    parser.add_argument('--mixup', type=float, default=0.8,
+                        help='mixup alpha, mixup enabled if > 0. (default: 0.8)')
+    parser.add_argument('--cutmix', type=float, default=1.0,
+                        help='cutmix alpha, cutmix enabled if > 0. (default: 1.0)')
+    parser.add_argument('--cutmix_minmax', type=float, nargs='+', default=None,
+                        help='cutmix min/max ratio, overrides alpha and enables cutmix if set (default: None)')
+    parser.add_argument('--mixup_prob', type=float, default=1.0,
+                        help='Probability of performing mixup or cutmix when either/both is enabled')
+    parser.add_argument('--mixup_switch_prob', type=float, default=0.5,
+                        help='Probability of switching to cutmix when both mixup and cutmix enabled')
+    parser.add_argument('--mixup_mode', type=str, default='batch',
+                        help='How to apply mixup/cutmix params. Per "batch", "pair", or "elem"')
+
+    # * Finetuning params
+    parser.add_argument('--disable_weight_decay_on_bias_norm', action='store_true', default=False)
+    parser.add_argument('--init_scale', default=1.0, type=float)
+
+    # Dataset parameters
+    parser.add_argument('--train_path', type=str, required=True, help='path to train dataset')
+    parser.add_argument('--val_path', type=str, required=True, help='path to test dataset')
+    parser.add_argument('--nb_classes', type=int, default=1000, help='number of classes')
+
+    parser.add_argument('--output_dir', default=None, type=str,
+                        help='path where to save, empty for no saving')
+    parser.add_argument('--device', default='cuda',
+                        help='device to use for training / testing')
+    parser.add_argument('--seed', default=0, type=int)
+    parser.add_argument('--resume', default='', help='resume from checkpoint')
+    parser.add_argument('--start_epoch', default=0, type=int, metavar='N',
+                        help='start epoch')
+    parser.add_argument('--eval', action='store_true', help='Perform evaluation only')
+    parser.add_argument('--dist_eval', action='store_true', default=False, help='Enabling distributed evaluation')
+    parser.add_argument('--num_workers', default=10, type=int)
+    parser.add_argument('--pin_mem', action='store_true',
+                        help='Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU.')
+    parser.add_argument('--no_pin_mem', action='store_false', dest='pin_mem',
+                        help='Disable pinned CPU memory')
+    parser.set_defaults(pin_mem=True)
+
+    # distributed training parameters
+    parser.add_argument("--local_rank", default=0, type=int, help="Please ignore and do not set this argument.")
+    parser.add_argument('--dist_url', default='env://', help='url used to set up distributed training')
+
+    return parser
+
+
+def train_one_epoch(model: torch.nn.Module, criterion: torch.nn.Module,
+                    data_loader: Iterable, optimizer: torch.optim.Optimizer,
+                    device: torch.device, epoch: int, loss_scaler, lr_scheduler, max_norm: float = 0,
+                    mixup_fn: Optional[Mixup] = None,
+                    ):
+    model.train(True)
+    metric_logger = misc.MetricLogger(delimiter="  ")
+    metric_logger.add_meter('lr', misc.SmoothedValue(window_size=1, fmt='{value:.6f}'))
+    header = 'Epoch: [{}]'.format(epoch)
+    print_freq = max(len(data_loader) // 20, 20)
+
+    accum_iter = args.accum_iter  # relies on module-level `args` set in __main__
+    for itr, (samples, targets) in enumerate(metric_logger.log_every(data_loader, print_freq, header)):
+        samples = samples.to(device, non_blocking=True)
+        targets = targets.to(device, non_blocking=True)
+
+        # fractional-epoch LR schedule
+        lr_scheduler.step(epoch + itr / len(data_loader))
+        if mixup_fn is not None:
+            samples, targets = mixup_fn(samples, targets)
+
+        with torch.cuda.amp.autocast():
+            outputs = model(samples)
+            loss = criterion(outputs, targets)
+
+        loss /= accum_iter
+        # step the optimizer only every accum_iter iterations
+        loss_scaler(loss, optimizer, clip_grad=max_norm,
+                    parameters=model.parameters(), create_graph=False,
+                    need_update=(itr + 1) % accum_iter == 0)
+        if (itr + 1) % accum_iter == 0:
+            optimizer.zero_grad()
+
+        torch.cuda.synchronize()
+        # log metrics
+        loss_value = loss.item()
+
+        if not math.isfinite(loss_value):
+            print("Loss is {}, stopping training".format(loss_value))
+            sys.exit(1)
+        # this attribute is added by timm on one optimizer (adahessian)
+        # is_second_order = hasattr(optimizer, 'is_second_order') and optimizer.is_second_order
+
+        # if model_ema is not None:
+        #     model_ema.update(model)
+
+        if wandb.run:
+            wandb.log({'train/loss': loss_value})
+        metric_logger.update(loss=loss_value)
+        metric_logger.update(lr=optimizer.param_groups[0]["lr"])
+    # gather the stats from all processes
+    metric_logger.synchronize_between_processes()
+    print("Averaged stats:", metric_logger)
+    return {k: meter.global_avg for k, meter in metric_logger.meters.items()}
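
The accumulation logic above is folded into timm's NativeScaler via `need_update`; the equivalent pattern spelled out with a bare torch.cuda.amp.GradScaler looks like this (a self-contained sketch with made-up model and data, not code from this patch):

    import torch

    model = torch.nn.Linear(8, 2).cuda()
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
    scaler = torch.cuda.amp.GradScaler()
    accum_iter = 4
    loader = [(torch.randn(16, 8), torch.randint(0, 2, (16,))) for _ in range(8)]

    for itr, (x, y) in enumerate(loader):
        x, y = x.cuda(), y.cuda()
        with torch.cuda.amp.autocast():
            loss = torch.nn.functional.cross_entropy(model(x), y) / accum_iter
        scaler.scale(loss).backward()      # gradients accumulate across micro-batches
        if (itr + 1) % accum_iter == 0:
            scaler.step(optimizer)         # unscales gradients, then optimizer.step()
            scaler.update()
            optimizer.zero_grad()

Dividing the loss by accum_iter keeps the accumulated gradient an average rather than a sum, so the effective batch size grows without changing the gradient scale.
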
+
+
+@torch.no_grad()
+def evaluate(data_loader, model, device):
+    criterion = torch.nn.CrossEntropyLoss()
+
+    metric_logger = misc.MetricLogger(delimiter="  ")
+    header = 'Test:'
+
+    # switch to evaluation mode
+    model.eval()
+
+    for images, target in metric_logger.log_every(data_loader, 10, header):
+        images = images.to(device, non_blocking=True)
+        target = target.to(device, non_blocking=True, dtype=torch.long)
+
+        # compute output
+        with torch.cuda.amp.autocast():
+            output = model(images)
+            loss = criterion(output, target)
+
+        acc1, acc5 = accuracy(output, target, topk=(1, 5))
+
+        batch_size = images.shape[0]
+        metric_logger.update(loss=loss.item())
+        metric_logger.meters['acc1'].update(acc1.item(), n=batch_size)
+        metric_logger.meters['acc5'].update(acc5.item(), n=batch_size)
+    # gather the stats from all processes
+    metric_logger.synchronize_between_processes()
+    print('* Acc@1 {top1.global_avg:.3f} Acc@5 {top5.global_avg:.3f} loss {losses.global_avg:.3f}'
+          .format(top1=metric_logger.acc1, top5=metric_logger.acc5, losses=metric_logger.loss))
+
+    return {k: meter.global_avg for k, meter in metric_logger.meters.items()}
+
+
+def main(args):
+    misc.init_distributed_mode(args)
+
+    print(args)
+
+    device = torch.device(args.device)
+
+    # fix the seed for reproducibility
+    misc.fix_random_seeds(args.seed)
+
+    cudnn.benchmark = True
+
+    # QUASI_RANDOM keeps memory bounded for single-process runs;
+    # RANDOM is used when the loader runs distributed.
+    order = OrderOption.RANDOM if args.distributed else OrderOption.QUASI_RANDOM
+    data_loader_train = Loader(args.train_path, pipelines=ThreeAugmentPipeline(),
+                               batch_size=args.batch_size, num_workers=args.num_workers,
+                               order=order, distributed=args.distributed, seed=args.seed)
+
+    data_loader_val = Loader(args.val_path, pipelines=ValPipeline(),
+                             batch_size=args.batch_size, num_workers=args.num_workers,
+                             distributed=args.distributed, seed=args.seed)
+
+    mixup_fn = None
+    mixup_active = args.mixup > 0 or args.cutmix > 0. or args.cutmix_minmax is not None
+    if mixup_active:
+        print("Mixup is activated!")
+        mixup_fn = Mixup(
+            mixup_alpha=args.mixup, cutmix_alpha=args.cutmix, cutmix_minmax=args.cutmix_minmax,
+            prob=args.mixup_prob, switch_prob=args.mixup_switch_prob, mode=args.mixup_mode,
+            label_smoothing=args.smoothing, num_classes=args.nb_classes)
+
+    # build the model and load weights to evaluate
+    model = build_model(num_classes=args.nb_classes, drop_path_rate=args.drop_path)
+    if args.pretrained_weights:
+        load_pretrained_weights(model, args.pretrained_weights, checkpoint_key=args.checkpoint_key, prefix=args.prefix)
+    if args.compile:
+        model = torch.compile(model)
+    # re-initialize the classification head for finetuning
+    trunc_normal_(model.head.weight, std=2e-5)
+    print("Model built.")
+
+    model.to(device)
+
+    model_without_ddp = model
+    n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad)
+    print('number of params:', n_parameters)
+
+    eff_batch_size = args.batch_size * args.accum_iter * misc.get_world_size()
+    linear_scaled_lr = args.lr * eff_batch_size / 256.0
+
+    print("base lr: %.2e" % args.lr)
+    print("actual lr: %.2e" % linear_scaled_lr)
+    args.lr = linear_scaled_lr
+
+    print("accumulate grad iterations: %d" % args.accum_iter)
+    print("effective batch size: %d" % eff_batch_size)
+
+    if args.distributed:
+        model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu])
+        model_without_ddp = model.module
+    optimizer = create_optimizer(args, model_without_ddp)
+    loss_scaler = NativeScaler()
+
+    lr_scheduler, _ = create_scheduler(args, optimizer)
+
+    if mixup_fn is not None:
+        # smoothing is handled with mixup label transform
+        criterion = SoftTargetCrossEntropy()
+    elif args.smoothing > 0.:
+        criterion = LabelSmoothingCrossEntropy(smoothing=args.smoothing)
+    else:
+        criterion = torch.nn.CrossEntropyLoss()
+
+    print("criterion = %s" % str(criterion))
+
+    output_dir = Path(args.output_dir) if args.output_dir else None
+    if args.resume:
+        run_variables = {"args": dict(), "epoch": 0}
+        restart_from_checkpoint(args.resume,
+                                optimizer=optimizer,
+                                model=model_without_ddp,
+                                scaler=loss_scaler,
+                                run_variables=run_variables)
+        # args = run_variables['args']
+        args.start_epoch = run_variables["epoch"] + 1
+
+    if args.eval:
+        test_stats = evaluate(data_loader_val, model_without_ddp, device)
+        print(f"Accuracy of the network on the test images: {test_stats['acc1']:.1f}%")
+        if args.output_dir and misc.is_main_process():
+            with (output_dir / "log.txt").open("a") as f:
+                f.write(json.dumps(test_stats) + "\n")
+        exit(0)
+
+    print(f"Start training for {args.epochs} epochs from {args.start_epoch}")
+    start_time = time.time()
+    max_accuracy = 0.0
+    if args.output_dir and misc.is_main_process():
+        try:
+            wandb.init(job_type='finetune', dir=args.output_dir, resume=True,
+                       config=args.__dict__)
+        except Exception:
+            pass
+
+    for epoch in range(args.start_epoch, args.epochs):
+
+        train_stats = train_one_epoch(
+            model, criterion, data_loader_train,
+            optimizer, device, epoch, loss_scaler, lr_scheduler,
+            args.clip_grad, mixup_fn,
+        )
+
+        checkpoint_paths = ['checkpoint.pth']
+
+        if epoch % args.ckpt_freq == 0 or epoch == args.epochs - 1:
+            test_stats = evaluate(data_loader_val, model, device)
+            print(f"Accuracy of the network on test images: {test_stats['acc1']:.1f}%")
+
+            if test_stats["acc1"] >= max_accuracy:
+                # only ever keep the best checkpoint so far
+                checkpoint_paths += ['checkpoint_best.pth']
+
+            max_accuracy = max(max_accuracy, test_stats["acc1"])
+            print(f'Max accuracy: {max_accuracy:.2f}%')
+
+            log_stats = {**{f'train/{k}': v for k, v in train_stats.items()},
+                         **{f'test/{k}': v for k, v in test_stats.items()},
+                         'epoch': epoch,
+                         'n_parameters': n_parameters}
+        else:
+            log_stats = {**{f'train/{k}': v for k, v in train_stats.items()},
+                         'epoch': epoch,
+                         'n_parameters': n_parameters}
+
+        # only save checkpoints and logs on rank 0
+        if output_dir and misc.is_main_process():
+            if epoch % args.ckpt_freq == 0 or epoch == args.epochs - 1:
+                for checkpoint_path in checkpoint_paths:
+                    misc.save_on_master({
+                        'model': model_without_ddp.state_dict(),
+                        'optimizer': optimizer.state_dict(),
+                        'lr_scheduler': lr_scheduler.state_dict(),
+                        'epoch': epoch,
+                        'scaler': loss_scaler.state_dict(),
+                        'args': args,
+                    }, output_dir / checkpoint_path)
+
+            if wandb.run:
+                wandb.log(log_stats)
+            with (output_dir / "log.txt").open("a") as f:
+                f.write(json.dumps(log_stats) + "\n")
+
+    total_time = time.time() - start_time
+    total_time_str = str(datetime.timedelta(seconds=int(total_time)))
+    print('Training time {}'.format(total_time_str))
+
+
+if __name__ == '__main__':
+    parser = argparse.ArgumentParser('DeiT training and evaluation script', parents=[get_args_parser()])
+    args = aug_parse(parser)
+    main(args)
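
A typical launch, assuming the repo's distributed init picks up torchrun's environment variables; the paths, GPU count, and checkpoint name below are placeholders, not values from this patch:

    torchrun --nproc_per_node=8 vitookit/evaluation/eval_cls_ffcv.py \
        --train_path data/IN1K_train.ffcv --val_path data/IN1K_val.ffcv \
        --batch_size 128 --epochs 100 --output_dir outputs/finetune \
        -w checkpoints/pretrained.pth
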