Skip to content
Snippets Groups Projects
Commit 421621d9 authored by gent
Browse files

stronger regularization for progressive training

parent abec089b
No related branches found
No related tags found
No related merge requests found
......@@ -38,7 +38,7 @@ IMAGENET_STD = np.array([0.229, 0.224, 0.225]) * 255
@gin.configurable
class DynamicResolution:
def __init__(self, start_ramp=gin.REQUIRED, end_ramp=gin.REQUIRED,
res=(224,224,0),lower_scale=(0.2,0.2,0), upper_scale=(1,1,0)):
res=(224,224,0),lower_scale=(0.08,0.08,0), upper_scale=(1,1,0)):
self.start_ramp = start_ramp
self.end_ramp = end_ramp
self.resolution = res
......@@ -61,7 +61,7 @@ class DynamicResolution:
return rv
interp = np.interp([epoch], [self.start_ramp, self.end_ramp], [lv, rv])
num_steps = int(np.round((interp[0]-lv) / step))
num_steps = int(((interp[0]-lv) / step))
return num_steps*step + lv
def get_lower_scale(self, epoch):
......@@ -80,7 +80,7 @@ class DynamicResolution:
interp = np.interp([epoch], [self.start_ramp, self.end_ramp], [lv, rv])
num_steps = int(np.round((interp[0]-lv) / step))
num_steps = int(((interp[0]-lv) / step))
return num_steps*step + lv
def get_upper_scale(self, epoch):
......@@ -98,14 +98,16 @@ class DynamicResolution:
return rv
interp = np.interp([epoch], [self.start_ramp, self.end_ramp], [lv, rv])
num_steps = int(np.round((interp[0]-lv) / step))
num_steps = int(((interp[0]-lv) / step))
return num_steps*step + lv
def __call__(self, loader, epoch,is_ffcv=False):
img_size = self.get_resolution(epoch)
lower_scale = self.get_lower_scale(epoch)
upper_scale = self.get_upper_scale(epoch)
print(f"resolution = {img_size}, scale = ({lower_scale:.2f},{upper_scale:.2f}) ")
if is_ffcv:
pipeline=loader.pipeline_specs['image']
if pipeline.decoder.output_size[0] != img_size:
......@@ -581,7 +583,7 @@ def ThreeAugmentPipeline(img_size=224,scale=(0.08,1), color_jitter=None,downscal
[ RandomResizedCropRGBImageDecoder((img_size, img_size), scale=scale,),
RandomHorizontalFlip(),]+
# second_tfl
( [RandomColorJitter(jitter_prob=1, brightness=color_jitter, contrast=color_jitter, saturation=color_jitter, hue=0, seed=None)] if color_jitter else []) +
( [RandomColorJitter(jitter_prob=0.8, brightness=color_jitter, contrast=color_jitter, saturation=color_jitter, hue=0, seed=None)] if color_jitter else []) +
# final_tfl
[
NormalizeImage(IMAGENET_MEAN, IMAGENET_STD, np.float32),
......
......@@ -346,6 +346,12 @@ def main(args):
for epoch in range(args.start_epoch, args.epochs):
if dres:
if epoch == dres.end_ramp:
## enhance augmentation, see efficientnetv2
data_loader_train = Loader(args.train_path, pipelines=ThreeAugmentPipeline(color_jitter=0.4),
batch_size=args.batch_size, num_workers=args.num_workers,
order=order, distributed=args.distributed,seed=args.seed)
dres(data_loader_train,epoch,True)
train_stats = train_one_epoch(
model, criterion, data_loader_train,
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment