diff --git a/vitookit/datasets/ffcv_transform.py b/vitookit/datasets/ffcv_transform.py index ffb51ac1045d71895cf1db4c998d5b39227ad322..ae25390193a116313c31c1c03eb0b1d2c7a7a36a 100644 --- a/vitookit/datasets/ffcv_transform.py +++ b/vitookit/datasets/ffcv_transform.py @@ -24,7 +24,7 @@ import gin from cv2 import GaussianBlur from scipy.ndimage import gaussian_filter -from ffcv.transforms import ToTensor, ToDevice, ToTorchImage, NormalizeImage,RandomHorizontalFlip, View +from ffcv.transforms import ToTensor, ToDevice, ToTorchImage, NormalizeImage,RandomHorizontalFlip, View, Convert from ffcv.fields.decoders import IntDecoder, RandomResizedCropRGBImageDecoder, SimpleRGBImageDecoder, CenterCropRGBImageDecoder import torch @@ -114,101 +114,16 @@ class DynamicResolution: decoder.size=(img_size,img_size) decoder.scale = (lower_scale,upper_scale) - -@gin.configurable -class DynamicResolution1: - def __init__(self, start_ramp=gin.REQUIRED, end_ramp=gin.REQUIRED, - res=(224,224,0),lower_scale=(0.08,0.08,0), upper_scale=(1,1,0)): - self.start_ramp = start_ramp - self.end_ramp = end_ramp - self.resolution = res - self.lower_scale = lower_scale - self.upper_scale = upper_scale - - - def get_resolution(self, epoch): - - lv,rv, step = self.resolution - - if step==0: - assert lv==rv - return rv - - if epoch <= self.start_ramp: - return lv - - if epoch >= self.end_ramp: - return rv - - interp = np.interp([epoch], [self.start_ramp, self.end_ramp], [lv, rv]) - num_steps = int(((interp[0]-lv) / step)) - return num_steps*step + lv - - def get_lower_scale(self, epoch): - - lv,rv, step = self.lower_scale - - if step==0: - assert lv==rv - return rv - - if epoch <= self.start_ramp: - return lv - - if epoch >= self.end_ramp: - return rv - - - interp = np.interp([epoch], [self.start_ramp, self.end_ramp], [lv, rv]) - num_steps = int(((interp[0]-lv) / step)) - return num_steps*step + lv - - def get_upper_scale(self, epoch): - - lv,rv, step = self.upper_scale - - if step==0: - assert lv==rv - return rv - - if epoch <= self.start_ramp: - return lv - - if epoch >= self.end_ramp: - return rv - - interp = np.interp([epoch], [self.start_ramp, self.end_ramp], [lv, rv]) - num_steps = int(((interp[0]-lv) / step)) - return num_steps*step + lv - - - def __call__(self, loader, epoch,is_ffcv=False): - img_size = self.get_resolution(epoch) - lower_scale = self.get_lower_scale(epoch) - upper_scale = self.get_upper_scale(epoch) - print(f"resolution = {img_size}, scale = ({lower_scale:.2f},{upper_scale:.2f}) ") - - if is_ffcv: - pipeline=loader.pipeline_specs['image'] - if pipeline.decoder.output_size[0] != img_size: - pipeline.decoder.scale = (lower_scale,upper_scale) - pipeline.decoder.output_size = (img_size,img_size) - loader.generate_code() - else: - decoder = loader.dataset.transforms.transform.transforms[0] - decoder.size=(img_size,img_size) - decoder.scale = (lower_scale,upper_scale) - - @gin.configurable def SimplePipeline(img_size=224,scale=(0.2,1), ratio=(3.0/4.0, 4.0/3.0)): image_pipeline = [ - RandomResizedCropRGBImageDecoder((img_size, img_size), scale=scale,ratio=ratio), - RandomHorizontalFlip(), - NormalizeImage(IMAGENET_MEAN, IMAGENET_STD, np.float16), + RandomResizedCropRGBImageDecoder((img_size, img_size), scale=scale,ratio=ratio,), + RandomHorizontalFlip(), ToTensor(), - ToDevice(torch.device('cuda')), + ToDevice(torch.device('cuda')), ToTorchImage(), + Convert(torch.float16), + tfms.Normalize(mean=[0.485*255, 0.456*255, 0.406*255], std=[0.229*255, 0.224*255, 0.225*255], inplace=True), ] label_pipeline = [IntDecoder(), ToTensor(),ToDevice(torch.device('cuda'))] # Pipeline for each data field @@ -222,10 +137,11 @@ def SimplePipeline(img_size=224,scale=(0.2,1), ratio=(3.0/4.0, 4.0/3.0)): def ValPipeline(img_size=224,ratio= 224/256): image_pipeline = [ CenterCropRGBImageDecoder((img_size, img_size), ratio), - NormalizeImage(IMAGENET_MEAN, IMAGENET_STD, np.float16), ToTensor(), ToDevice(torch.device('cuda')), ToTorchImage(), + Convert(torch.float16), + tfms.Normalize(mean=[0.485*255, 0.456*255, 0.406*255], std=[0.229*255, 0.224*255, 0.225*255], inplace=True), ] label_pipeline = [IntDecoder(), ToTensor(),ToDevice(torch.device('cuda')),View(-1)] # Pipeline for each data field @@ -575,54 +491,6 @@ def convolution(image: np.ndarray, kernel: list | tuple, output: np.ndarray) -> return output -class FFCVThreeAugmentation(Operation): - def __init__( - self, threshold=128, radius_min=0.1, radius_max=2. - ): - super().__init__() - self.threshold = threshold - self.radius_min = radius_min - self.radius_max = radius_max - - def generate_code(self) -> Callable: - my_range = Compiler.get_iterator() - threshold = self.threshold - radius_min = self.radius_min - radius_max = self.radius_max - - def randchoice(images, dst): - for i in my_range(images.shape[0]): - idx = random.randint(0, 2) - if idx == 0: - # solarize - mask = images[i] >= threshold - dst[i] = np.where(mask, 255 - images[i], images[i]) - elif idx == 1: - # grayscale - dst[i] = ( - 0.2989 * images[i, ..., 0:1] - + 0.5870 * images[i, ..., 1:2] - + 0.1140 * images[i, ..., 2:3] - ) - else: - sigma = np.random.uniform(radius_min, radius_max) - kernel_width = int(3 * sigma) - if kernel_width % 2 == 0: - kernel_width = kernel_width + 1 # make sure kernel width only sth 3,5,7 etc - kernel = generate_gaussian_filter(sigma,filter_shape=(kernel_width, kernel_width)) - convolution(images[i], kernel, dst[i]) - - return dst - randchoice.parallel = True - return randchoice - - def declare_state_and_memory(self, previous_state: State) -> Tuple[State, Optional[AllocationQuery]]: - # No updates to state or extra memory necessary! - mem_alloc = AllocationQuery(previous_state.shape,dtype=previous_state.dtype) - return replace(previous_state,jit_mode=True), mem_alloc - - - from torchvision.transforms import InterpolationMode @gin.configurable @@ -664,7 +532,7 @@ class ThreeAugmentation(nn.Module): return f"{super().__repr__()}(p={self.p})" @gin.configurable -def ThreeAugmentPipeline(img_size=224,scale=(0.08,1), color_jitter=None,downscale=1): +def ThreeAugmentPipeline(img_size=224,scale=(0.08,1), color_jitter=None): """ ThreeAugmentPipeline """