Commit 107b54da authored by Wu, Jiantao (PG/R - Comp Sci & Elec Eng)

change dres

parent 5405b0aa
from ffcv import Loader
import gin

from .ffcv_transform import ThreeAugmentPipeline


@gin.configurable
class DynamicResolution:
    """Ramps input resolution (and crop scale / color jitter) up over training."""

    def __init__(self, start_ramp=gin.REQUIRED, end_ramp=gin.REQUIRED,
                 scheme=1):
        schemes = {
            1: [
                dict(res=160, lower_scale=0.08, upper_scale=1, color_jitter=False),
                dict(res=192, lower_scale=0.08, upper_scale=1, color_jitter=False),
                dict(res=224, lower_scale=0.08, upper_scale=1, color_jitter=True),
            ],
            2: [
                dict(res=160, lower_scale=0.96, upper_scale=1, color_jitter=False),
                dict(res=192, lower_scale=0.38, upper_scale=1, color_jitter=False),
                dict(res=224, lower_scale=0.08, upper_scale=1, color_jitter=True),
            ],
            3: [
                dict(res=160, lower_scale=0.08, upper_scale=0.4624, color_jitter=False),
                dict(res=192, lower_scale=0.08, upper_scale=0.7056, color_jitter=False),
                dict(res=224, lower_scale=0.08, upper_scale=1, color_jitter=True),
            ],
            4: [
                dict(res=160, lower_scale=0.20, upper_scale=0.634, color_jitter=False),
                dict(res=192, lower_scale=0.137, upper_scale=0.81, color_jitter=False),
                dict(res=224, lower_scale=0.08, upper_scale=1, color_jitter=True),
            ],
            5: [
                dict(res=160, lower_scale=0.08, upper_scale=1, color_jitter=True),
                dict(res=192, lower_scale=0.08, upper_scale=1, color_jitter=True),
                dict(res=224, lower_scale=0.08, upper_scale=1, color_jitter=True),
            ],
        }
        self.scheme = schemes[scheme]
        self.start_ramp = start_ramp
        self.end_ramp = end_ramp

    def get_config(self, epoch):
        # Use the first stage before the ramp starts, the last stage after it ends,
        # and step through the stages linearly (by index) in between.
        if epoch <= self.start_ramp:
            return self.scheme[0]
        elif epoch >= self.end_ramp:
            return self.scheme[-1]
        else:
            i = (epoch - self.start_ramp) * (len(self.scheme) - 1) // (self.end_ramp - self.start_ramp)
            return self.scheme[i]

    def __call__(self, loader: Loader, epoch, is_ffcv=False):
        config = self.get_config(epoch)
        print(", ".join(f"{k}={v}" for k, v in config.items()))
        img_size = config['res']
        lower_scale = config['lower_scale']
        upper_scale = config['upper_scale']
        color_jitter = config['color_jitter']
        if is_ffcv:
            # FFCV loader: rebuild the augmentation pipeline and recompile it.
            pipelines = ThreeAugmentPipeline(img_size, scale=(lower_scale, upper_scale),
                                             color_jitter=color_jitter)
            loader.compile_pipeline(pipelines)
        else:
            # TODO: change dres
            # Torchvision loader: patch the RandomResizedCrop transform in place.
            decoder = loader.dataset.transforms.transform.transforms[0]
            decoder.size = (img_size, img_size)
            decoder.scale = (lower_scale, upper_scale)
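
For context only (this sketch is not part of the commit): one way the class above can be driven, assuming gin bindings for the two gin.REQUIRED ramp epochs and an already-constructed FFCV data_loader_train; the binding values and the epoch count below are illustrative placeholders.

# Hypothetical gin bindings (example values, not taken from this repo):
#   DynamicResolution.start_ramp = 10
#   DynamicResolution.end_ramp   = 60
#   DynamicResolution.scheme     = 1
dres = DynamicResolution()                  # start_ramp/end_ramp/scheme filled in by gin
for epoch in range(100):                    # placeholder epoch count
    # Re-derive this epoch's (resolution, crop scale, color jitter) setting and
    # recompile the FFCV pipeline before training on it.
    dres(data_loader_train, epoch, is_ffcv=True)
    # ... run one training epoch here ...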
@@ -36,84 +36,6 @@ IMAGENET_MEAN = np.array([0.485, 0.456, 0.406]) * 255
IMAGENET_STD = np.array([0.229, 0.224, 0.225]) * 255
@gin.configurable
class DynamicResolution:
    def __init__(self, start_ramp=gin.REQUIRED, end_ramp=gin.REQUIRED,
                 scheme=1):
        schemes = {
            1: [
                dict(res=160, lower_scale=0.08, upper_scale=1),
                dict(res=192, lower_scale=0.08, upper_scale=1),
                dict(res=224, lower_scale=0.08, upper_scale=1),
            ],
            2: [
                dict(res=160, lower_scale=0.96, upper_scale=1),
                dict(res=192, lower_scale=0.38, upper_scale=1),
                dict(res=224, lower_scale=0.08, upper_scale=1),
            ],
            3: [
                dict(res=160, lower_scale=0.08, upper_scale=0.4624),
                dict(res=192, lower_scale=0.08, upper_scale=0.7056),
                dict(res=224, lower_scale=0.08, upper_scale=1),
            ],
            4: [
                dict(res=160, lower_scale=0.20, upper_scale=0.634),
                dict(res=192, lower_scale=0.137, upper_scale=0.81),
                dict(res=224, lower_scale=0.08, upper_scale=1),
            ],
            5: [
                dict(res=160, lower_scale=0.20, upper_scale=1),
                dict(res=192, lower_scale=0.137, upper_scale=1),
                dict(res=224, lower_scale=0.08, upper_scale=1),
            ],
        }
        self.scheme = schemes[scheme]
        self.start_ramp = start_ramp
        self.end_ramp = end_ramp

    def get_resolution(self, epoch):
        if epoch <= (self.start_ramp + self.end_ramp) / 2:
            return self.scheme[0]['res']
        elif epoch <= self.end_ramp:
            return self.scheme[1]['res']
        else:
            return self.scheme[2]['res']

    def get_lower_scale(self, epoch):
        if epoch <= (self.start_ramp + self.end_ramp) / 2:
            return self.scheme[0]['lower_scale']
        elif epoch <= self.end_ramp:
            return self.scheme[1]['lower_scale']
        else:
            return self.scheme[2]['lower_scale']

    def get_upper_scale(self, epoch):
        if epoch <= (self.start_ramp + self.end_ramp) / 2:
            return self.scheme[0]['upper_scale']
        elif epoch <= self.end_ramp:
            return self.scheme[1]['upper_scale']
        else:
            return self.scheme[2]['upper_scale']

    def __call__(self, loader, epoch, is_ffcv=False):
        img_size = self.get_resolution(epoch)
        lower_scale = self.get_lower_scale(epoch)
        upper_scale = self.get_upper_scale(epoch)
        print(f"resolution = {img_size}, scale = ({lower_scale:.2f},{upper_scale:.2f})")
        if is_ffcv:
            pipeline = loader.pipeline_specs['image']
            if pipeline.decoder.output_size[0] != img_size:
                pipeline.decoder.scale = (lower_scale, upper_scale)
                pipeline.decoder.output_size = (img_size, img_size)
                loader.generate_code()
        else:
            decoder = loader.dataset.transforms.transform.transforms[0]
            decoder.size = (img_size, img_size)
            decoder.scale = (lower_scale, upper_scale)

@gin.configurable
def SimplePipeline(img_size=224, scale=(0.2, 1), ratio=(3.0/4.0, 4.0/3.0)):
    image_pipeline = [
......
......
@@ -117,6 +117,6 @@ class SA1BDataset:
        m = maskUtils.decode(rle)
        return m

if __name__ == '__main__':
    data = SA1BDataset('/vol/research/datasets/still/SA-1B',
                       ['sa_258805', 'sa_2601447', 'sa_6313499', 'sa_2592075', 'sa_258406',
                        'sa_2601844', 'sa_6353015', 'sa_10958302', 'sa_6344773'])  # if you want a subset of the dataset
    img, mask, class_ids = data[0]
\ No newline at end of file
@@ -5,7 +5,7 @@ import numpy as np
from torchvision import transforms
import torch
import gin
from torch import nn
class PermutePatch(object):
    """
@@ -103,6 +103,9 @@ class GrayScale(object):
        else:
            return img

class ThreeAugmentation(nn.Module):
    def __init__(self):
        super().__init__()

def three_augmentation(args=None):
    img_size = args.input_size
......
......
@@ -43,7 +43,7 @@ from timm.layers import trunc_normal_
from ffcv import Loader
from ffcv.loader import OrderOption
from datasets.dres import DynamicResolution
def get_args_parser():
    parser = argparse.ArgumentParser('DeiT training and evaluation script', add_help=False)
@@ -342,12 +342,6 @@ def main(args):
    for epoch in range(args.start_epoch, args.epochs):
        if dres:
            if epoch == dres.end_ramp:
                print("enhance augmentation!", epoch, dres.end_ramp)
                ## enhance augmentation, see efficientnetv2
                data_loader_train = Loader(args.train_path, pipelines=ThreeAugmentPipeline(color_jitter=0.3),
                                           batch_size=args.batch_size, num_workers=args.num_workers,
                                           order=order, distributed=args.distributed, seed=args.seed)
            dres(data_loader_train, epoch, True)

        train_stats = train_one_epoch(
......
......