Skip to content
Snippets Groups Projects
Commit 97edd31f authored by gent's avatar gent
Browse files

dynamic resolution

parent bed04831
No related branches found
No related tags found
No related merge requests found
...@@ -37,43 +37,85 @@ IMAGENET_STD = np.array([0.229, 0.224, 0.225]) * 255 ...@@ -37,43 +37,85 @@ IMAGENET_STD = np.array([0.229, 0.224, 0.225]) * 255
@gin.configurable @gin.configurable
class DynamicResolution: class DynamicResolution:
def __init__(self, start_ramp=gin.REQUIRED, end_ramp=gin.REQUIRED, min_res=160, max_res=224, step=32): def __init__(self, start_ramp=gin.REQUIRED, end_ramp=gin.REQUIRED,
res=(224,224,0),lower_scale=(0.2,0.2,0), upper_scale=(1,1,0)):
self.start_ramp = start_ramp self.start_ramp = start_ramp
self.end_ramp = end_ramp self.end_ramp = end_ramp
self.min_res = min_res self.resolution = res
self.max_res = max_res self.lower_scale = lower_scale
self.step = step self.upper_scale = upper_scale
assert min_res <= max_res
def get_resolution(self, epoch): def get_resolution(self, epoch):
lv,rv, step = self.resolution
if step==0:
assert lv==rv
return rv
if epoch <= self.start_ramp:
return lv
if epoch >= self.end_ramp:
return rv
interp = np.interp([epoch], [self.start_ramp, self.end_ramp], [lv, rv])
num_steps = int(np.round((interp[0]-lv) / step))
return num_steps*step + lv
def get_lower_scale(self, epoch):
lv,rv, step = self.lower_scale
if step==0:
assert lv==rv
return rv
if epoch <= self.start_ramp:
return lv
if epoch >= self.end_ramp:
return rv
interp = np.interp([epoch], [self.start_ramp, self.end_ramp], [lv, rv])
num_steps = int(np.round((interp[0]-lv) / step))
return num_steps*step + lv
def get_upper_scale(self, epoch):
lv,rv, step = self.upper_scale
if step==0:
assert lv==rv
return rv
if epoch <= self.start_ramp: if epoch <= self.start_ramp:
return self.min_res return lv
if epoch >= self.end_ramp: if epoch >= self.end_ramp:
return self.max_res return rv
# otherwise, linearly interpolate to the nearest multiple of 32 interp = np.interp([epoch], [self.start_ramp, self.end_ramp], [lv, rv])
interp = np.interp([epoch], [self.start_ramp, self.end_ramp], [self.min_res, self.max_res]) num_steps = int(np.round((interp[0]-lv) / step))
final_res = int(np.round(interp[0] / self.step)) * self.step return num_steps*step + lv
return final_res
def __call__(self, loader, epoch,is_ffcv=False): def __call__(self, loader, epoch,is_ffcv=False):
img_size = self.get_resolution(epoch) img_size = self.get_resolution(epoch)
print("resolution = %s" % str(img_size)) lower_scale = self.get_lower_scale(epoch)
upper_scale = self.get_upper_scale(epoch)
print(f"resolution = {img_size}, scale = ({lower_scale:.2f},{upper_scale:.2f}) ")
if is_ffcv: if is_ffcv:
pipeline=loader.pipeline_specs['image'] pipeline=loader.pipeline_specs['image']
if pipeline.decoder.output_size[0] != img_size: if pipeline.decoder.output_size[0] != img_size:
k = img_size/self.max_res pipeline.decoder.scale = (lower_scale,upper_scale)
pipeline.decoder.scale = (0.2*k,1)
pipeline.decoder.output_size = (img_size,img_size) pipeline.decoder.output_size = (img_size,img_size)
loader.generate_code() loader.generate_code()
else: else:
decoder = loader.dataset.transforms.transform.transforms[0] decoder = loader.dataset.transforms.transform.transforms[0]
k = img_size/self.max_res
decoder.size=(img_size,img_size) decoder.size=(img_size,img_size)
decoder.scale = (0.2*k,1) decoder.scale = (lower_scale,upper_scale)
@gin.configurable @gin.configurable
......
...@@ -29,7 +29,7 @@ from vitookit.datasets.transform import three_augmentation ...@@ -29,7 +29,7 @@ from vitookit.datasets.transform import three_augmentation
from vitookit.utils.helper import * from vitookit.utils.helper import *
from vitookit.utils import misc from vitookit.utils import misc
from vitookit.models.build_model import build_model from vitookit.models.build_model import build_model
from vitookit.datasets.build_dataset import build_dataset from vitookit.datasets.build_dataset import build_dataset, build_transform
import wandb import wandb
...@@ -104,6 +104,7 @@ def get_args_parser(): ...@@ -104,6 +104,7 @@ def get_args_parser():
help='LR decay rate (default: 0.1)') help='LR decay rate (default: 0.1)')
# Augmentation parameters # Augmentation parameters
parser.add_argument('--ThreeAugment', action='store_true', default=False) #3augment parser.add_argument('--ThreeAugment', action='store_true', default=False) #3augment
parser.add_argument('--src',action='store_true', default=False, parser.add_argument('--src',action='store_true', default=False,
help="Use Simple Random Crop (SRC) or Random Resized Crop (RRC). Use SRC when there is less risk of overfitting, such as on ImageNet-21k.") help="Use Simple Random Crop (SRC) or Random Resized Crop (RRC). Use SRC when there is less risk of overfitting, such as on ImageNet-21k.")
...@@ -276,7 +277,8 @@ def main(args): ...@@ -276,7 +277,8 @@ def main(args):
if args.ThreeAugment: if args.ThreeAugment:
transform = three_augmentation(args) transform = three_augmentation(args)
else: else:
transform = None transform = build_transform(is_train=True, args=args)
dataset_train, args.nb_classes = build_dataset(args=args, is_train=True, trnsfrm=transform) dataset_train, args.nb_classes = build_dataset(args=args, is_train=True, trnsfrm=transform)
dataset_val, _ = build_dataset(is_train=False, args=args) dataset_val, _ = build_dataset(is_train=False, args=args)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment