Skip to content
Snippets Groups Projects
Commit 5405b0aa authored by Wu, Jiantao (PG/R - Comp Sci & Elec Eng)'s avatar Wu, Jiantao (PG/R - Comp Sci & Elec Eng)
Browse files

optimize transform

parent 7cb28a45
No related branches found
No related tags found
No related merge requests found
......@@ -24,7 +24,7 @@ import gin
from cv2 import GaussianBlur
from scipy.ndimage import gaussian_filter
from ffcv.transforms import ToTensor, ToDevice, ToTorchImage, NormalizeImage,RandomHorizontalFlip, View
from ffcv.transforms import ToTensor, ToDevice, ToTorchImage, NormalizeImage,RandomHorizontalFlip, View, Convert
from ffcv.fields.decoders import IntDecoder, RandomResizedCropRGBImageDecoder, SimpleRGBImageDecoder, CenterCropRGBImageDecoder
import torch
......@@ -114,101 +114,16 @@ class DynamicResolution:
decoder.size=(img_size,img_size)
decoder.scale = (lower_scale,upper_scale)
@gin.configurable
class DynamicResolution1:
    """Ramp image resolution and RandomResizedCrop scale bounds over epochs.

    Each of ``res``, ``lower_scale`` and ``upper_scale`` is a
    ``(start_value, end_value, step)`` triple: between ``start_ramp`` and
    ``end_ramp`` epochs the value is linearly interpolated from start to
    end, then quantized down to the nearest ``step``-multiple above the
    start value. A step of 0 means the value is constant (start == end).
    """

    def __init__(self, start_ramp=gin.REQUIRED, end_ramp=gin.REQUIRED,
                 res=(224, 224, 0), lower_scale=(0.08, 0.08, 0), upper_scale=(1, 1, 0)):
        self.start_ramp = start_ramp      # epoch at which ramping begins
        self.end_ramp = end_ramp          # epoch at which ramping ends
        self.resolution = res             # (start, end, step) image side length
        self.lower_scale = lower_scale    # (start, end, step) crop-scale lower bound
        self.upper_scale = upper_scale    # (start, end, step) crop-scale upper bound

    def _ramp_value(self, epoch, schedule):
        """Shared ramp logic for all three schedules.

        Interpolates ``lv -> rv`` across [start_ramp, end_ramp], truncating
        the interpolated offset to whole ``step`` increments above ``lv``.
        """
        lv, rv, step = schedule
        if step == 0:
            # A zero step encodes a constant schedule.
            assert lv == rv
            return rv
        if epoch <= self.start_ramp:
            return lv
        if epoch >= self.end_ramp:
            return rv
        interp = np.interp([epoch], [self.start_ramp, self.end_ramp], [lv, rv])
        # Truncate toward lv so the value only advances in whole steps.
        num_steps = int((interp[0] - lv) / step)
        return num_steps * step + lv

    def get_resolution(self, epoch):
        """Image side length (pixels) to use at ``epoch``."""
        return self._ramp_value(epoch, self.resolution)

    def get_lower_scale(self, epoch):
        """Lower bound of the random-resized-crop scale at ``epoch``."""
        return self._ramp_value(epoch, self.lower_scale)

    def get_upper_scale(self, epoch):
        """Upper bound of the random-resized-crop scale at ``epoch``."""
        return self._ramp_value(epoch, self.upper_scale)

    def __call__(self, loader, epoch, is_ffcv=False):
        """Apply this epoch's resolution and scale to ``loader``'s decoder.

        For FFCV loaders the pipeline is only regenerated when the output
        size actually changes; otherwise the first transform of a
        torchvision-style dataset is assumed to be the resized-crop decoder.
        """
        img_size = self.get_resolution(epoch)
        lower_scale = self.get_lower_scale(epoch)
        upper_scale = self.get_upper_scale(epoch)
        print(f"resolution = {img_size}, scale = ({lower_scale:.2f},{upper_scale:.2f}) ")
        if is_ffcv:
            pipeline = loader.pipeline_specs['image']
            if pipeline.decoder.output_size[0] != img_size:
                pipeline.decoder.scale = (lower_scale, upper_scale)
                pipeline.decoder.output_size = (img_size, img_size)
                loader.generate_code()
        else:
            decoder = loader.dataset.transforms.transform.transforms[0]
            decoder.size = (img_size, img_size)
            decoder.scale = (lower_scale, upper_scale)
@gin.configurable
def SimplePipeline(img_size=224,scale=(0.2,1), ratio=(3.0/4.0, 4.0/3.0)):
image_pipeline = [
RandomResizedCropRGBImageDecoder((img_size, img_size), scale=scale,ratio=ratio),
RandomHorizontalFlip(),
NormalizeImage(IMAGENET_MEAN, IMAGENET_STD, np.float16),
RandomResizedCropRGBImageDecoder((img_size, img_size), scale=scale,ratio=ratio,),
RandomHorizontalFlip(),
ToTensor(),
ToDevice(torch.device('cuda')),
ToDevice(torch.device('cuda')),
ToTorchImage(),
Convert(torch.float16),
tfms.Normalize(mean=[0.485*255, 0.456*255, 0.406*255], std=[0.229*255, 0.224*255, 0.225*255], inplace=True),
]
label_pipeline = [IntDecoder(), ToTensor(),ToDevice(torch.device('cuda'))]
# Pipeline for each data field
......@@ -222,10 +137,11 @@ def SimplePipeline(img_size=224,scale=(0.2,1), ratio=(3.0/4.0, 4.0/3.0)):
def ValPipeline(img_size=224,ratio= 224/256):
image_pipeline = [
CenterCropRGBImageDecoder((img_size, img_size), ratio),
NormalizeImage(IMAGENET_MEAN, IMAGENET_STD, np.float16),
ToTensor(),
ToDevice(torch.device('cuda')),
ToTorchImage(),
Convert(torch.float16),
tfms.Normalize(mean=[0.485*255, 0.456*255, 0.406*255], std=[0.229*255, 0.224*255, 0.225*255], inplace=True),
]
label_pipeline = [IntDecoder(), ToTensor(),ToDevice(torch.device('cuda')),View(-1)]
# Pipeline for each data field
......@@ -575,54 +491,6 @@ def convolution(image: np.ndarray, kernel: list | tuple, output: np.ndarray) ->
return output
class FFCVThreeAugmentation(Operation):
    """FFCV pipeline Operation implementing a "3-Augment"-style transform:
    for each image in the batch, uniformly pick one of {solarize, grayscale,
    gaussian blur} and write the result into the destination buffer.
    """

    def __init__(
        self, threshold=128, radius_min=0.1, radius_max=2.
    ):
        # threshold: pixel value at/above which solarization inverts (255 - v).
        # radius_min / radius_max: bounds of the uniformly sampled blur sigma.
        super().__init__()
        self.threshold = threshold
        self.radius_min = radius_min
        self.radius_max = radius_max

    def generate_code(self) -> Callable:
        """Return the per-batch transform compiled by FFCV's Compiler."""
        my_range = Compiler.get_iterator()
        # Bind hyperparameters to locals so the closure below does not
        # capture `self` (needed for FFCV's JIT compilation of the body).
        threshold = self.threshold
        radius_min = self.radius_min
        radius_max = self.radius_max

        def randchoice(images, dst):
            # images/dst: batches indexed as (N, H, W, C) — presumably uint8
            # straight from the decoder; TODO(review): confirm dtype upstream.
            for i in my_range(images.shape[0]):
                # Uniformly choose one of the three augmentations per image.
                idx = random.randint(0, 2)
                if idx == 0:
                    # solarize: invert every pixel at or above the threshold.
                    mask = images[i] >= threshold
                    dst[i] = np.where(mask, 255 - images[i], images[i])
                elif idx == 1:
                    # grayscale: BT.601 luma weights; the size-1 trailing
                    # slice broadcasts the single channel back over dst's
                    # channel axis.
                    dst[i] = (
                        0.2989 * images[i, ..., 0:1]
                        + 0.5870 * images[i, ..., 1:2]
                        + 0.1140 * images[i, ..., 2:3]
                    )
                else:
                    # gaussian blur with a randomly drawn sigma.
                    sigma = np.random.uniform(radius_min, radius_max)
                    kernel_width = int(3 * sigma)
                    if kernel_width % 2 == 0:
                        kernel_width = kernel_width + 1  # force an odd width (3, 5, 7, ...) so the kernel is centered
                    kernel = generate_gaussian_filter(sigma,filter_shape=(kernel_width, kernel_width))
                    convolution(images[i], kernel, dst[i])
            return dst
        randchoice.parallel = True
        return randchoice

    def declare_state_and_memory(self, previous_state: State) -> Tuple[State, Optional[AllocationQuery]]:
        # Request a destination buffer with the same shape/dtype as the
        # input batch, and enable jit_mode so FFCV compiles `randchoice`.
        mem_alloc = AllocationQuery(previous_state.shape,dtype=previous_state.dtype)
        return replace(previous_state,jit_mode=True), mem_alloc
from torchvision.transforms import InterpolationMode
@gin.configurable
......@@ -664,7 +532,7 @@ class ThreeAugmentation(nn.Module):
return f"{super().__repr__()}(p={self.p})"
@gin.configurable
def ThreeAugmentPipeline(img_size=224,scale=(0.08,1), color_jitter=None,downscale=1):
def ThreeAugmentPipeline(img_size=224,scale=(0.08,1), color_jitter=None):
"""
ThreeAugmentPipeline
"""
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment