Skip to content
Snippets Groups Projects
Commit 421621d9 authored by gent's avatar gent
Browse files

stronger regularization for progressive training

parent abec089b
No related branches found
No related tags found
No related merge requests found
...@@ -38,7 +38,7 @@ IMAGENET_STD = np.array([0.229, 0.224, 0.225]) * 255 ...@@ -38,7 +38,7 @@ IMAGENET_STD = np.array([0.229, 0.224, 0.225]) * 255
@gin.configurable @gin.configurable
class DynamicResolution: class DynamicResolution:
def __init__(self, start_ramp=gin.REQUIRED, end_ramp=gin.REQUIRED, def __init__(self, start_ramp=gin.REQUIRED, end_ramp=gin.REQUIRED,
res=(224,224,0),lower_scale=(0.2,0.2,0), upper_scale=(1,1,0)): res=(224,224,0),lower_scale=(0.08,0.08,0), upper_scale=(1,1,0)):
self.start_ramp = start_ramp self.start_ramp = start_ramp
self.end_ramp = end_ramp self.end_ramp = end_ramp
self.resolution = res self.resolution = res
...@@ -61,7 +61,7 @@ class DynamicResolution: ...@@ -61,7 +61,7 @@ class DynamicResolution:
return rv return rv
interp = np.interp([epoch], [self.start_ramp, self.end_ramp], [lv, rv]) interp = np.interp([epoch], [self.start_ramp, self.end_ramp], [lv, rv])
num_steps = int(np.round((interp[0]-lv) / step)) num_steps = int(((interp[0]-lv) / step))
return num_steps*step + lv return num_steps*step + lv
def get_lower_scale(self, epoch): def get_lower_scale(self, epoch):
...@@ -80,7 +80,7 @@ class DynamicResolution: ...@@ -80,7 +80,7 @@ class DynamicResolution:
interp = np.interp([epoch], [self.start_ramp, self.end_ramp], [lv, rv]) interp = np.interp([epoch], [self.start_ramp, self.end_ramp], [lv, rv])
num_steps = int(np.round((interp[0]-lv) / step)) num_steps = int(((interp[0]-lv) / step))
return num_steps*step + lv return num_steps*step + lv
def get_upper_scale(self, epoch): def get_upper_scale(self, epoch):
...@@ -98,14 +98,16 @@ class DynamicResolution: ...@@ -98,14 +98,16 @@ class DynamicResolution:
return rv return rv
interp = np.interp([epoch], [self.start_ramp, self.end_ramp], [lv, rv]) interp = np.interp([epoch], [self.start_ramp, self.end_ramp], [lv, rv])
num_steps = int(np.round((interp[0]-lv) / step)) num_steps = int(((interp[0]-lv) / step))
return num_steps*step + lv return num_steps*step + lv
def __call__(self, loader, epoch,is_ffcv=False): def __call__(self, loader, epoch,is_ffcv=False):
img_size = self.get_resolution(epoch) img_size = self.get_resolution(epoch)
lower_scale = self.get_lower_scale(epoch) lower_scale = self.get_lower_scale(epoch)
upper_scale = self.get_upper_scale(epoch) upper_scale = self.get_upper_scale(epoch)
print(f"resolution = {img_size}, scale = ({lower_scale:.2f},{upper_scale:.2f}) ") print(f"resolution = {img_size}, scale = ({lower_scale:.2f},{upper_scale:.2f}) ")
if is_ffcv: if is_ffcv:
pipeline=loader.pipeline_specs['image'] pipeline=loader.pipeline_specs['image']
if pipeline.decoder.output_size[0] != img_size: if pipeline.decoder.output_size[0] != img_size:
...@@ -581,7 +583,7 @@ def ThreeAugmentPipeline(img_size=224,scale=(0.08,1), color_jitter=None,downscal ...@@ -581,7 +583,7 @@ def ThreeAugmentPipeline(img_size=224,scale=(0.08,1), color_jitter=None,downscal
[ RandomResizedCropRGBImageDecoder((img_size, img_size), scale=scale,), [ RandomResizedCropRGBImageDecoder((img_size, img_size), scale=scale,),
RandomHorizontalFlip(),]+ RandomHorizontalFlip(),]+
# second_tfl # second_tfl
( [RandomColorJitter(jitter_prob=1, brightness=color_jitter, contrast=color_jitter, saturation=color_jitter, hue=0, seed=None)] if color_jitter else []) + ( [RandomColorJitter(jitter_prob=0.8, brightness=color_jitter, contrast=color_jitter, saturation=color_jitter, hue=0, seed=None)] if color_jitter else []) +
# final_tfl # final_tfl
[ [
NormalizeImage(IMAGENET_MEAN, IMAGENET_STD, np.float32), NormalizeImage(IMAGENET_MEAN, IMAGENET_STD, np.float32),
......
...@@ -346,6 +346,12 @@ def main(args): ...@@ -346,6 +346,12 @@ def main(args):
for epoch in range(args.start_epoch, args.epochs): for epoch in range(args.start_epoch, args.epochs):
if dres: if dres:
if epoch == dres.end_ramp:
## enhance augmentation, see efficientnetv2
data_loader_train = Loader(args.train_path, pipelines=ThreeAugmentPipeline(color_jitter=0.4),
batch_size=args.batch_size, num_workers=args.num_workers,
order=order, distributed=args.distributed,seed=args.seed)
dres(data_loader_train,epoch,True) dres(data_loader_train,epoch,True)
train_stats = train_one_epoch( train_stats = train_one_epoch(
model, criterion, data_loader_train, model, criterion, data_loader_train,
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment