diff --git a/vitookit/datasets/ffcv_transform.py b/vitookit/datasets/ffcv_transform.py index 1d2a722ed187050abe050c35baee589cc0db47d5..ffb51ac1045d71895cf1db4c998d5b39227ad322 100644 --- a/vitookit/datasets/ffcv_transform.py +++ b/vitookit/datasets/ffcv_transform.py @@ -575,7 +575,7 @@ def convolution(image: np.ndarray, kernel: list | tuple, output: np.ndarray) -> return output -class ThreeAugmentation(Operation): +class FFCVThreeAugmentation(Operation): def __init__( self, threshold=128, radius_min=0.1, radius_max=2. ): @@ -621,20 +621,7 @@ class ThreeAugmentation(Operation): mem_alloc = AllocationQuery(previous_state.shape,dtype=previous_state.dtype) return replace(previous_state,jit_mode=True), mem_alloc -class RandomChoice(nn.Module): - """Apply single transformation randomly picked from a list. This transform does not support torchscript.""" - - def __init__(self, transforms, p=None): - super().__init__() - self.transforms = transforms - self.p = p - def __call__(self, *args): - t = random.choices(self.transforms, weights=self.p)[0] - return t(*args) - - def __repr__(self) -> str: - return f"{super().__repr__()}(p={self.p})" from torchvision.transforms import InterpolationMode @@ -652,32 +639,50 @@ class RandDownSampling(nn.Module): down = F.resize(x,(nh,hw),interpolation=InterpolationMode.BICUBIC) up = F.resize(down,(h,w),interpolation=InterpolationMode.BICUBIC) return up - + +class ThreeAugmentation(nn.Module): + """Apply single transformation randomly picked from a list. This transform does not support torchscript.""" + + def __init__(self, ): + super().__init__() + self.guassian_blur = tfms.GaussianBlur(5,sigma=(0.1,2)) + + + def __call__(self, x): + mb = math.ceil(len(x)/3) + perm = torch.randperm(3) + for i in range(3): + if perm[i] == 0: + x[i*mb:(i+1)*mb] = self.guassian_blur(x[i*mb:(i+1)*mb]) + elif perm[i] == 1: + x[i*mb:(i+1)*mb] = F.solarize(x[i*mb:(i+1)*mb], 0) + else: + x[i*mb:(i+1)*mb] = F.rgb_to_grayscale(x[i*mb:(i+1)*mb],1) + return x + + def __repr__(self) -> str: + return f"{super().__repr__()}(p={self.p})" + @gin.configurable def ThreeAugmentPipeline(img_size=224,scale=(0.08,1), color_jitter=None,downscale=1): """ ThreeAugmentPipeline """ image_pipeline = ( - # first_tfl - [ RandomResizedCropRGBImageDecoder((img_size, img_size), scale=scale,), - RandomHorizontalFlip(),]+ - # second_tfl - ( [RandomColorJitter(jitter_prob=0.8, brightness=color_jitter, contrast=color_jitter, saturation=color_jitter, hue=0, seed=None)] if color_jitter else []) + - # final_tfl - [ - NormalizeImage(IMAGENET_MEAN, IMAGENET_STD, np.float32), - ToTensor(), - # ToDevice(torch.device('cuda')), - ToTorchImage(), - # ThreeAugmentation() - RandomChoice([ - tfms.RandomSolarize(0,1), - tfms.RandomGrayscale(1), - tfms.GaussianBlur(5,sigma=(0.1,2)), - ]), - RandDownSampling((downscale,1)), - ]) + # first_tfl + [ RandomResizedCropRGBImageDecoder((img_size, img_size), scale=scale,), + RandomHorizontalFlip(),]+ + # second_tfl + ( [RandomColorJitter(jitter_prob=1, brightness=color_jitter, contrast=color_jitter, saturation=color_jitter,)] if color_jitter else []) + + # final_tfl + [ + NormalizeImage(IMAGENET_MEAN, IMAGENET_STD, np.float32), + ToTensor(), + # ToDevice(torch.device('cuda')), + ToTorchImage(), + ThreeAugmentation(), + ]) + label_pipeline = [IntDecoder(), ToTensor(),View(-1)] # Pipeline for each data field pipelines = {