Commit c8caa2ed authored by gent

set lr scale to 1 for eval cls

parent 5bc839af
#!/usr/bin/env python
"""
Example:
vitrun --nproc_per_node=3 eval_cls_ffcv.py --train_path $train_path --val_path $val_path --gin VisionTransformer.global_pool='\"avg\"' -w wandb:dlib/EfficientSSL/xsa4wubh --batch_size 360 --output_dir outputs/cls
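Evaluation-only variant (untested sketch; same arguments as above plus the --eval flag defined below, with placeholder paths):
vitrun --nproc_per_node=3 eval_cls_ffcv.py --train_path $train_path --val_path $val_path -w /path/to/checkpoint.pth --batch_size 360 --eval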
"""
from PIL import Image  # imported first as a workaround for the "CXXABI_1.3.9 not found" error
import argparse
import datetime
import numpy as np
import time
import torch
import torch.nn as nn
import torch.backends.cudnn as cudnn
import torch.distributed as dist
import json
import os
import math
import sys
import copy
import scipy.io as scio
from vitookit.datasets.ffcv_transform import *
from vitookit.utils.helper import *
from vitookit.utils import misc
from vitookit.models.build_model import build_model
import wandb
from pathlib import Path
from typing import Iterable, Optional
from torch.nn import functional as F
from timm.loss import LabelSmoothingCrossEntropy, SoftTargetCrossEntropy
from timm.scheduler import create_scheduler
from timm.optim import create_optimizer
from timm.utils import NativeScaler, get_state_dict, ModelEma, accuracy
from timm.data import Mixup
from timm.layers import trunc_normal_
from ffcv import Loader
from ffcv.loader import OrderOption
def get_args_parser():
parser = argparse.ArgumentParser('DeiT training and evaluation script', add_help=False)
parser.add_argument('--batch_size', default=128, type=int,
help='Batch size per GPU (effective batch size is batch_size * accum_iter * # gpus)')
parser.add_argument('--accum_iter', default=1, type=int,
help='Accumulate gradient iterations (for increasing the effective batch size under memory constraints)')
parser.add_argument('--epochs', default=100, type=int)
parser.add_argument('--ckpt_freq', default=5, type=int)
# Model parameters
parser.add_argument("--compile", action='store_true', default=False, help="compile model with PyTorch 2.0")
parser.add_argument("--prefix", default=None, type=str, help="prefix of the model name")
parser.add_argument('--input_size', default=224, type=int, help='images input size')
parser.add_argument('-w', '--pretrained_weights', default='', type=str, help="""Path to pretrained
weights to evaluate. Set to `download` to automatically load the pretrained DINO from url.
Otherwise the model is randomly initialized""")
parser.add_argument("--checkpoint_key", default=None, type=str, help='Key to use in the checkpoint (example: "teacher")')
parser.add_argument('--drop', type=float, default=0.0, metavar='PCT',
help='Dropout rate (default: 0.)')
parser.add_argument('--attn_drop_rate', type=float, default=0.0, metavar='PCT',
help='Attention dropout rate (default: 0.)')
parser.add_argument('--drop_path', type=float, default=0.1, metavar='PCT',
help='Drop path rate (default: 0.1)')
# Optimizer parameters
parser.add_argument('--opt', default='adamw', type=str, metavar='OPTIMIZER',
help='Optimizer (default: "adamw")')
parser.add_argument('--opt_eps', default=1e-8, type=float, metavar='EPSILON',
help='Optimizer Epsilon (default: 1e-8)')
parser.add_argument('--opt_betas', default=None, type=float, nargs='+', metavar='BETA',
help='Optimizer Betas (default: None, use opt default)')
parser.add_argument('--clip_grad', type=float, default=None, metavar='NORM',
help='Clip gradient norm (default: None, no clipping)')
parser.add_argument('--momentum', type=float, default=0.9, metavar='M',
help='SGD momentum (default: 0.9)')
parser.add_argument('--weight_decay', type=float, default=0.05,
help='weight decay (default: 0.05)')
parser.add_argument('--layer_decay', type=float, default=0.75)
# Learning rate schedule parameters
parser.add_argument('--sched', default='cosine', type=str, metavar='SCHEDULER',
help='LR scheduler (default: "cosine")')
parser.add_argument('--lr', type=float, default=5e-4, metavar='LR',
help='learning rate (default: 5e-4)')
parser.add_argument('--lr_noise', type=float, nargs='+', default=None, metavar='pct, pct',
help='learning rate noise on/off epoch percentages')
parser.add_argument('--min_lr', type=float, default=1e-6, metavar='LR',
help='lower lr bound for cyclic schedulers that hit 0 (default: 1e-6)')
parser.add_argument('--warmup_epochs', type=int, default=5, metavar='N',
help='epochs to warmup LR, if scheduler supports')
parser.add_argument('--decay_rate', '--dr', type=float, default=0.1, metavar='RATE',
help='LR decay rate (default: 0.1)')
# Augmentation parameters
parser.add_argument('--ThreeAugment', action='store_true', default=True)  # use the 3-Augment recipe (DeiT III); default True
parser.add_argument('--src',action='store_true', default=False,
help="Use Simple Random Crop (SRC) or Random Resized Crop (RRC). Use SRC when there is less risk of overfitting, such as on ImageNet-21k.")
parser.add_argument('--color_jitter', type=float, default=None, metavar='PCT',
help='Color jitter factor (enabled only when not using Auto/RandAug)')
parser.add_argument('--aa', type=str, default='rand-m9-mstd0.5-inc1', metavar='NAME',
help='Use AutoAugment policy. "v0" or "original" (default: rand-m9-mstd0.5-inc1)')
parser.add_argument('--smoothing', type=float, default=0.1, help='Label smoothing (default: 0.1)')
# * Random Erase params
parser.add_argument('--reprob', type=float, default=0.25, metavar='PCT',
help='Random erase prob (default: 0.25)')
parser.add_argument('--remode', type=str, default='pixel',
help='Random erase mode (default: "pixel")')
parser.add_argument('--recount', type=int, default=1,
help='Random erase count (default: 1)')
parser.add_argument('--resplit', action='store_true', default=False,
help='Do not random erase first (clean) augmentation split')
# * Mixup params
parser.add_argument('--mixup', type=float, default=0.8,
help='mixup alpha, mixup enabled if > 0. (default: 0.8)')
parser.add_argument('--cutmix', type=float, default=1.0,
help='cutmix alpha, cutmix enabled if > 0. (default: 1.0)')
parser.add_argument('--cutmix_minmax', type=float, nargs='+', default=None,
help='cutmix min/max ratio, overrides alpha and enables cutmix if set (default: None)')
parser.add_argument('--mixup_prob', type=float, default=1.0,
help='Probability of performing mixup or cutmix when either/both is enabled')
parser.add_argument('--mixup_switch_prob', type=float, default=0.5,
help='Probability of switching to cutmix when both mixup and cutmix enabled')
parser.add_argument('--mixup_mode', type=str, default='batch',
help='How to apply mixup/cutmix params. Per "batch", "pair", or "elem"')
# * Finetuning params
parser.add_argument('--disable_weight_decay_on_bias_norm', action='store_true', default=False)
parser.add_argument('--init_scale', default=1.0, type=float)
# Dataset parameters
parser.add_argument('--train_path', type=str, required=True, help='path to train dataset')
parser.add_argument('--val_path', type=str, required=True, help='path to test dataset')
parser.add_argument('--nb_classes', type=int, default=1000, help='number of classes')
parser.add_argument('--output_dir', default=None, type=str,
help='path where to save, empty for no saving')
parser.add_argument('--device', default='cuda',
help='device to use for training / testing')
parser.add_argument('--seed', default=0, type=int)
parser.add_argument('--resume', default='', help='resume from checkpoint')
parser.add_argument('--start_epoch', default=0, type=int, metavar='N',
help='start epoch')
parser.add_argument('--eval', action='store_true', help='Perform evaluation only')
parser.add_argument('--dist_eval', action='store_true', default=False, help='Enabling distributed evaluation')
parser.add_argument('--num_workers', default=10, type=int)
parser.add_argument('--pin_mem', action='store_true',
help='Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU.')
parser.add_argument('--no_pin_mem', action='store_false', dest='pin_mem',
help='Disable pinning CPU memory in the DataLoader.')
parser.set_defaults(pin_mem=True)
# distributed training parameters
parser.add_argument("--local_rank", default=0, type=int, help="Please ignore and do not set this argument.")
parser.add_argument('--dist_url', default='env://', help='url used to set up distributed training')
return parser
def train_one_epoch(model: torch.nn.Module, criterion: torch.nn.Module,
data_loader: Iterable, optimizer: torch.optim.Optimizer,
device: torch.device, epoch: int, loss_scaler,lr_scheduler, max_norm: float = 0,
mixup_fn: Optional[Mixup] = None,
):
model.train(True)
metric_logger = misc.MetricLogger(delimiter=" ")
metric_logger.add_meter('lr', misc.SmoothedValue(window_size=1, fmt='{value:.6f}'))
header = 'Epoch: [{}]'.format(epoch)
print_freq = max(len(data_loader)//20,20)
accum_iter = args.accum_iter
for itr,(samples, targets) in enumerate(metric_logger.log_every(data_loader, print_freq, header)):
samples = samples.to(device, non_blocking=True)
targets = targets.to(device, non_blocking=True)
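# step the LR scheduler with a fractional epoch so warmup/cosine decay
# advance smoothly every iteration rather than once per epoch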
lr_scheduler.step(epoch+itr/len(data_loader))
if mixup_fn is not None:
samples, targets = mixup_fn(samples, targets)
with torch.cuda.amp.autocast():
outputs = model(samples)
loss = criterion(outputs, targets)
loss /= accum_iter
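# gradient accumulation: the loss is pre-divided by accum_iter above, and the
# scaler only steps the optimizer when need_update is True (every accum_iter
# iterations); gradients are cleared right after each optimizer step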
loss_scaler(loss, optimizer, clip_grad=max_norm,
parameters=model.parameters(), create_graph=False,
need_update=(itr + 1) % accum_iter == 0)
if (itr + 1) % accum_iter == 0:
optimizer.zero_grad()
torch.cuda.synchronize()
# log metrics
loss_value = loss.item()
if not math.isfinite(loss_value):
print("Loss is {}, stopping training".format(loss_value))
sys.exit(1)
# this attribute is added by timm on one optimizer (adahessian)
# is_second_order = hasattr(optimizer, 'is_second_order') and optimizer.is_second_order
# if model_ema is not None:
# model_ema.update(model)
if wandb.run:
wandb.log({'train/loss': loss_value})
metric_logger.update(loss=loss_value)
metric_logger.update(lr=optimizer.param_groups[0]["lr"])
# gather the stats from all processes
metric_logger.synchronize_between_processes()
print("Averaged stats:", metric_logger)
return {k: meter.global_avg for k, meter in metric_logger.meters.items()}
@torch.no_grad()
def evaluate(data_loader, model, device):
criterion = torch.nn.CrossEntropyLoss()
metric_logger = misc.MetricLogger(delimiter=" ")
header = 'Test:'
# switch to evaluation mode
model.eval()
for images, target in metric_logger.log_every(data_loader, 10, header):
images = images.to(device, non_blocking=True)
target = target.to(device, non_blocking=True, dtype=torch.long)
# compute output
with torch.cuda.amp.autocast():
output = model(images)
loss = criterion(output, target)
acc1, acc5 = accuracy(output, target, topk=(1, 5))
batch_size = images.shape[0]
metric_logger.update(loss=loss.item())
metric_logger.meters['acc1'].update(acc1.item(), n=batch_size)
metric_logger.meters['acc5'].update(acc5.item(), n=batch_size)
# gather the stats from all processes
metric_logger.synchronize_between_processes()
print('* Acc@1 {top1.global_avg:.3f} Acc@5 {top5.global_avg:.3f} loss {losses.global_avg:.3f}'
.format(top1=metric_logger.acc1, top5=metric_logger.acc5, losses=metric_logger.loss))
return {k: meter.global_avg for k, meter in metric_logger.meters.items()}
def main(args):
misc.init_distributed_mode(args)
print(args)
device = torch.device(args.device)
# fix the seed for reproducibility
misc.fix_random_seeds(args.seed)
cudnn.benchmark = True
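# FFCV ordering: QUASI_RANDOM shuffles with a smaller memory footprint but (an
# assumption about current FFCV releases) is not supported together with
# distributed loading, so fall back to RANDOM when running distributed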
order = OrderOption.RANDOM if args.distributed else OrderOption.QUASI_RANDOM
data_loader_train = Loader(args.train_path, pipelines=ThreeAugmentPipeline(),
batch_size=args.batch_size, num_workers=args.num_workers,
order=order, distributed=args.distributed,seed=args.seed)
data_loader_val = Loader(args.val_path, pipelines=ValPipeline(),
batch_size=args.batch_size, num_workers=args.num_workers,
distributed=args.distributed,seed=args.seed)
mixup_fn = None
mixup_active = args.mixup > 0 or args.cutmix > 0. or args.cutmix_minmax is not None
if mixup_active:
print("Mixup is activated!")
mixup_fn = Mixup(
mixup_alpha=args.mixup, cutmix_alpha=args.cutmix, cutmix_minmax=args.cutmix_minmax,
prob=args.mixup_prob, switch_prob=args.mixup_switch_prob, mode=args.mixup_mode,
label_smoothing=args.smoothing, num_classes=args.nb_classes)
print(f"Model built.")
# load weights to evaluate
model = build_model(num_classes=args.nb_classes,drop_path_rate=args.drop_path,)
if args.pretrained_weights:
load_pretrained_weights(model, args.pretrained_weights, checkpoint_key=args.checkpoint_key, prefix=args.prefix)
if args.compile:
model = torch.compile(model)
import torch._dynamo
torch._dynamo.config.suppress_errors = True
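# re-initialize the classification head with a very small std so fine-tuning
# starts from an almost-zero classifier (common MAE/DeiT fine-tuning practice)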
trunc_normal_(model.head.weight, std=2e-5)
model.to(device)
model_without_ddp = model
n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad)
print('number of params:', n_parameters)
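# linear lr scaling rule: actual_lr = base_lr * eff_batch_size / 256
# e.g. with the docstring's 3 GPUs x 360 images per GPU and base lr 5e-4:
# 5e-4 * 1080 / 256 ≈ 2.1e-3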
eff_batch_size = args.batch_size * args.accum_iter * misc.get_world_size()
linear_scaled_lr = args.lr * eff_batch_size / 256.0
print("base lr: %.2e" % args.lr )
print("actual lr: %.2e" % linear_scaled_lr)
args.lr = linear_scaled_lr
print("accumulate grad iterations: %d" % args.accum_iter)
print("effective batch size: %d" % eff_batch_size)
if args.distributed:
model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu])
model_without_ddp = model.module
optimizer = create_optimizer(args, model_without_ddp)
# hack: keep the patch-embedding parameter group at the full learning rate
print([p.shape for p in optimizer.param_groups[1]['params']])  # sanity-check which params are in this group
optimizer.param_groups[1]['lr_scale'] = 1.0
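# Note (assumption about timm's scheduler): when updating parameter groups it
# multiplies the scheduled lr by each group's 'lr_scale', so forcing 1.0 here
# exempts this group from any layer-wise lr decay scaling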
loss_scaler = NativeScaler()
lr_scheduler, _ = create_scheduler(args, optimizer)
if mixup_fn is not None:
# smoothing is handled with mixup label transform
criterion = SoftTargetCrossEntropy()
elif args.smoothing > 0.:
criterion = LabelSmoothingCrossEntropy(smoothing=args.smoothing)
else:
criterion = torch.nn.CrossEntropyLoss()
print("criterion = %s" % str(criterion))
output_dir = Path(args.output_dir) if args.output_dir else None
if args.resume:
run_variables={"args":dict(),"epoch":0}
restart_from_checkpoint(args.resume,
optimizer=optimizer,
model=model_without_ddp,
scaler=loss_scaler,
run_variables=run_variables)
# args = run_variables['args']
args.start_epoch = run_variables["epoch"] + 1
if args.eval:
test_stats = evaluate(data_loader_val, model_without_ddp, device)
print(f"Accuracy of the network on the test images: {test_stats['acc1']:.1f}%")
if args.output_dir and misc.is_main_process():
with (output_dir / "log.txt").open("a") as f:
f.write(json.dumps(test_stats) + "\n")
exit(0)
print(f"Start training for {args.epochs} epochs from {args.start_epoch}")
start_time = time.time()
max_accuracy = 0.0
if args.output_dir and misc.is_main_process():
try:
wandb.init(job_type='finetune',dir=args.output_dir,resume=True,
config=args.__dict__)
except Exception:
pass  # wandb is optional; ignore init failures (e.g. offline runs)
for epoch in range(args.start_epoch, args.epochs):
train_stats = train_one_epoch(
model, criterion, data_loader_train,
optimizer, device, epoch, loss_scaler,lr_scheduler,
args.clip_grad, mixup_fn,
)
checkpoint_paths = ['checkpoint.pth']
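# 'checkpoint.pth' is always rewritten at save time; 'checkpoint_best.pth' is
# appended below only when the current accuracy matches or beats the best so far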
if epoch%args.ckpt_freq==0 or epoch==args.epochs-1:
test_stats = evaluate(data_loader_val, model, device)
print(f"Accuracy of the network on test images: {test_stats['acc1']:.1f}%")
if (test_stats["acc1"] >= max_accuracy):
# always only save best checkpoint till now
checkpoint_paths += [ 'checkpoint_best.pth']
max_accuracy = max(max_accuracy, test_stats["acc1"])
print(f'Max accuracy: {max_accuracy:.2f}%')
log_stats = {**{f'train/{k}': v for k, v in train_stats.items()},
**{f'test/{k}': v for k, v in test_stats.items()},
'epoch': epoch,
'n_parameters': n_parameters}
else:
log_stats = {**{f'train/{k}': v for k, v in train_stats.items()},
'epoch': epoch,
'n_parameters': n_parameters}
# only save checkpoint on rank 0
if output_dir and misc.is_main_process():
if epoch%args.ckpt_freq==0 or epoch==args.epochs-1:
for checkpoint_path in checkpoint_paths:
misc.save_on_master({
'model': model_without_ddp.state_dict(),
'optimizer': optimizer.state_dict(),
'lr_scheduler': lr_scheduler.state_dict(),
'epoch': epoch,
'scaler': loss_scaler.state_dict(),
'args': args,
}, output_dir / checkpoint_path)
if wandb.run: wandb.log(log_stats)
with (output_dir / "log.txt").open("a") as f:
f.write(json.dumps(log_stats) + "\n")
total_time = time.time() - start_time
total_time_str = str(datetime.timedelta(seconds=int(total_time)))
print('Training time {}'.format(total_time_str))
if __name__ == '__main__':
parser = argparse.ArgumentParser('DeiT training and evaluation script', parents=[get_args_parser()])
args = aug_parse(parser)
main(args)
@@ -15,7 +15,7 @@ class VisionTransformer(timm.models.vision_transformer.VisionTransformer):
super(VisionTransformer, self).__init__(global_pool=global_pool, dynamic_img_size=dynamic_img_size, **kwargs)
self.global_pool = global_pool
print("global_pool", global_pool, self.norm)
def vit_small(**kwargs):
model = VisionTransformer(
...