Skip to content
Snippets Groups Projects
Commit bed04831 authored by gent's avatar gent
Browse files

Add DynamicResolution class and enable dynamic resolution option

parent 813e9fb5
No related branches found
No related tags found
No related merge requests found
...@@ -35,9 +35,49 @@ from torch import nn ...@@ -35,9 +35,49 @@ from torch import nn
IMAGENET_MEAN = np.array([0.485, 0.456, 0.406]) * 255 IMAGENET_MEAN = np.array([0.485, 0.456, 0.406]) * 255
IMAGENET_STD = np.array([0.229, 0.224, 0.225]) * 255 IMAGENET_STD = np.array([0.229, 0.224, 0.225]) * 255
@gin.configurable
class DynamicResolution:
    """Epoch-dependent image-resolution schedule applied to a data loader.

    The resolution is ``min_res`` up to and including epoch ``start_ramp``,
    ``max_res`` from epoch ``end_ramp`` onwards, and in between it is linearly
    interpolated and snapped to a multiple of ``step``.

    Calling the instance with a loader and an epoch writes the resolution for
    that epoch (and a proportionally scaled crop-scale range) onto the
    loader's image decoder / crop transform.

    Args:
        start_ramp: Last epoch that still uses ``min_res`` (gin-required).
        end_ramp: First epoch that uses ``max_res`` (gin-required).
        min_res: Resolution before the ramp.
        max_res: Resolution after the ramp.
        step: Granularity of the intermediate resolutions; must be positive.

    Raises:
        ValueError: If ``min_res > max_res`` or ``step <= 0``.
    """

    def __init__(self, start_ramp=gin.REQUIRED, end_ramp=gin.REQUIRED, min_res=160, max_res=224, step=32):
        # Explicit raises rather than `assert`: asserts are stripped under
        # `python -O`, which would let an invalid schedule through silently.
        if min_res > max_res:
            raise ValueError(
                f"min_res ({min_res}) must not exceed max_res ({max_res})"
            )
        if step <= 0:
            # A zero step would otherwise only fail mid-training, inside
            # get_resolution(), once the ramp is reached.
            raise ValueError(f"step must be positive, got {step}")
        self.start_ramp = start_ramp
        self.end_ramp = end_ramp
        self.min_res = min_res
        self.max_res = max_res
        self.step = step

    def get_resolution(self, epoch):
        """Return the image resolution scheduled for ``epoch``.

        Returns ``min_res`` when ``epoch <= start_ramp`` and ``max_res`` when
        ``epoch >= end_ramp``. Otherwise the linearly interpolated value is
        rounded to a multiple of ``step`` and clamped to
        ``[min_res, max_res]``.
        """
        if epoch <= self.start_ramp:
            return self.min_res
        if epoch >= self.end_ramp:
            return self.max_res

        # Linearly interpolate, then snap to the nearest multiple of `step`.
        interp = np.interp(
            epoch,
            [self.start_ramp, self.end_ramp],
            [self.min_res, self.max_res],
        )
        snapped = int(np.round(interp / self.step)) * self.step
        # When min_res/max_res are not themselves multiples of `step`, the
        # snapping can land outside the configured range; keep it inside.
        return min(max(snapped, self.min_res), self.max_res)

    def _crop_scale(self, img_size):
        """Return the crop-scale range ``(0.2 * img_size / max_res, 1)``."""
        # NOTE(review): 0.2 matches the default lower bound of `scale` in
        # SimplePipeline; presumably the two should stay in sync - confirm.
        k = img_size / self.max_res
        return (0.2 * k, 1)

    def __call__(self, loader, epoch, is_ffcv=False):
        """Apply the resolution scheduled for ``epoch`` to ``loader``.

        Args:
            loader: Data loader whose image decoder / crop transform is
                updated in place.
            epoch: Current epoch, passed to :meth:`get_resolution`.
            is_ffcv: When true, update ``loader.pipeline_specs['image']``'s
                decoder and call ``loader.generate_code()`` if the size
                changed; otherwise update the first transform of
                ``loader.dataset.transforms.transform``.
        """
        img_size = self.get_resolution(epoch)
        print(f"resolution = {img_size}")
        target = (img_size, img_size)

        if is_ffcv:
            pipeline = loader.pipeline_specs['image']
            # Only touch the decoder (and regenerate the loader code) when the
            # resolution actually changes for this epoch.
            if pipeline.decoder.output_size[0] != img_size:
                pipeline.decoder.scale = self._crop_scale(img_size)
                pipeline.decoder.output_size = target
                loader.generate_code()
        else:
            # assumes the first transform exposes `size` and `scale`
            # (e.g. a random-resized-crop) - TODO confirm against the dataset.
            decoder = loader.dataset.transforms.transform.transforms[0]
            decoder.size = target
            decoder.scale = self._crop_scale(img_size)
@gin.configurable @gin.configurable
def SimplePipeline(img_size=224,scale=(0.2,1), ratio=(3.0/4.0, 4.0/3.0),blur=False): def SimplePipeline(img_size=224,scale=(0.2,1), ratio=(3.0/4.0, 4.0/3.0)):
image_pipeline = [ image_pipeline = [
RandomResizedCropRGBImageDecoder((img_size, img_size), scale=scale,ratio=ratio), RandomResizedCropRGBImageDecoder((img_size, img_size), scale=scale,ratio=ratio),
RandomHorizontalFlip(), RandomHorizontalFlip(),
...@@ -46,8 +86,6 @@ def SimplePipeline(img_size=224,scale=(0.2,1), ratio=(3.0/4.0, 4.0/3.0),blur=Fal ...@@ -46,8 +86,6 @@ def SimplePipeline(img_size=224,scale=(0.2,1), ratio=(3.0/4.0, 4.0/3.0),blur=Fal
ToDevice(torch.device('cuda')), ToDevice(torch.device('cuda')),
ToTorchImage(), ToTorchImage(),
] ]
if blur:
image_pipeline.append(transforms.GaussianBlur(3))
label_pipeline = [IntDecoder(), ToTensor(),ToDevice(torch.device('cuda'))] label_pipeline = [IntDecoder(), ToTensor(),ToDevice(torch.device('cuda'))]
# Pipeline for each data field # Pipeline for each data field
pipelines = { pipelines = {
......
...@@ -51,6 +51,7 @@ def get_args_parser(): ...@@ -51,6 +51,7 @@ def get_args_parser():
help='Accumulate gradient iterations (for increasing the effective batch size under memory constraints)') help='Accumulate gradient iterations (for increasing the effective batch size under memory constraints)')
parser.add_argument('--epochs', default=100, type=int) parser.add_argument('--epochs', default=100, type=int)
parser.add_argument('--ckpt_freq', default=5, type=int) parser.add_argument('--ckpt_freq', default=5, type=int)
parser.add_argument("--dynamic_resolution", default=False, action="store_true", help="Use dynamic resolution.")
# Model parameters # Model parameters
parser.add_argument("--compile", action='store_true', default=False, help="compile model with PyTorch 2.0") parser.add_argument("--compile", action='store_true', default=False, help="compile model with PyTorch 2.0")
...@@ -354,6 +355,14 @@ def main(args): ...@@ -354,6 +355,14 @@ def main(args):
exit(0) exit(0)
print(f"Start training for {args.epochs} epochs from {args.start_epoch}") print(f"Start training for {args.epochs} epochs from {args.start_epoch}")
if args.dynamic_resolution:
import torch._dynamo
torch._dynamo.config.suppress_errors = True
dres = DynamicResolution()
else:
dres = None
start_time = time.time() start_time = time.time()
max_accuracy = 0.0 max_accuracy = 0.0
if args.output_dir and misc.is_main_process(): if args.output_dir and misc.is_main_process():
...@@ -364,7 +373,8 @@ def main(args): ...@@ -364,7 +373,8 @@ def main(args):
pass pass
for epoch in range(args.start_epoch, args.epochs): for epoch in range(args.start_epoch, args.epochs):
if dres:
dres(data_loader_train,epoch,True)
train_stats = train_one_epoch( train_stats = train_one_epoch(
model, criterion, data_loader_train, model, criterion, data_loader_train,
optimizer, device, epoch, loss_scaler,lr_scheduler, optimizer, device, epoch, loss_scaler,lr_scheduler,
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment