Skip to content
Snippets Groups Projects
Commit 4f3d42c9 authored by Wu, Jiantao (PG/R - Comp Sci & Elec Eng)'s avatar Wu, Jiantao (PG/R - Comp Sci & Elec Eng)
Browse files

fix threeaug

parent dc2a6f5e
No related branches found
No related tags found
No related merge requests found
...@@ -575,7 +575,7 @@ def convolution(image: np.ndarray, kernel: list | tuple, output: np.ndarray) -> ...@@ -575,7 +575,7 @@ def convolution(image: np.ndarray, kernel: list | tuple, output: np.ndarray) ->
return output return output
class ThreeAugmentation(Operation): class FFCVThreeAugmentation(Operation):
def __init__( def __init__(
self, threshold=128, radius_min=0.1, radius_max=2. self, threshold=128, radius_min=0.1, radius_max=2.
): ):
...@@ -621,20 +621,7 @@ class ThreeAugmentation(Operation): ...@@ -621,20 +621,7 @@ class ThreeAugmentation(Operation):
mem_alloc = AllocationQuery(previous_state.shape,dtype=previous_state.dtype) mem_alloc = AllocationQuery(previous_state.shape,dtype=previous_state.dtype)
return replace(previous_state,jit_mode=True), mem_alloc return replace(previous_state,jit_mode=True), mem_alloc
class RandomChoice(nn.Module):
"""Apply single transformation randomly picked from a list. This transform does not support torchscript."""
def __init__(self, transforms, p=None):
super().__init__()
self.transforms = transforms
self.p = p
def __call__(self, *args):
t = random.choices(self.transforms, weights=self.p)[0]
return t(*args)
def __repr__(self) -> str:
return f"{super().__repr__()}(p={self.p})"
from torchvision.transforms import InterpolationMode from torchvision.transforms import InterpolationMode
...@@ -652,32 +639,50 @@ class RandDownSampling(nn.Module): ...@@ -652,32 +639,50 @@ class RandDownSampling(nn.Module):
down = F.resize(x,(nh,hw),interpolation=InterpolationMode.BICUBIC) down = F.resize(x,(nh,hw),interpolation=InterpolationMode.BICUBIC)
up = F.resize(down,(h,w),interpolation=InterpolationMode.BICUBIC) up = F.resize(down,(h,w),interpolation=InterpolationMode.BICUBIC)
return up return up
class ThreeAugmentation(nn.Module):
"""Apply single transformation randomly picked from a list. This transform does not support torchscript."""
def __init__(self, ):
super().__init__()
self.guassian_blur = tfms.GaussianBlur(5,sigma=(0.1,2))
def __call__(self, x):
mb = math.ceil(len(x)/3)
perm = torch.randperm(3)
for i in range(3):
if perm[i] == 0:
x[i*mb:(i+1)*mb] = self.guassian_blur(x[i*mb:(i+1)*mb])
elif perm[i] == 1:
x[i*mb:(i+1)*mb] = F.solarize(x[i*mb:(i+1)*mb], 0)
else:
x[i*mb:(i+1)*mb] = F.rgb_to_grayscale(x[i*mb:(i+1)*mb],1)
return x
def __repr__(self) -> str:
return f"{super().__repr__()}(p={self.p})"
@gin.configurable @gin.configurable
def ThreeAugmentPipeline(img_size=224,scale=(0.08,1), color_jitter=None,downscale=1): def ThreeAugmentPipeline(img_size=224,scale=(0.08,1), color_jitter=None,downscale=1):
""" """
ThreeAugmentPipeline ThreeAugmentPipeline
""" """
image_pipeline = ( image_pipeline = (
# first_tfl # first_tfl
[ RandomResizedCropRGBImageDecoder((img_size, img_size), scale=scale,), [ RandomResizedCropRGBImageDecoder((img_size, img_size), scale=scale,),
RandomHorizontalFlip(),]+ RandomHorizontalFlip(),]+
# second_tfl # second_tfl
( [RandomColorJitter(jitter_prob=0.8, brightness=color_jitter, contrast=color_jitter, saturation=color_jitter, hue=0, seed=None)] if color_jitter else []) + ( [RandomColorJitter(jitter_prob=1, brightness=color_jitter, contrast=color_jitter, saturation=color_jitter,)] if color_jitter else []) +
# final_tfl # final_tfl
[ [
NormalizeImage(IMAGENET_MEAN, IMAGENET_STD, np.float32), NormalizeImage(IMAGENET_MEAN, IMAGENET_STD, np.float32),
ToTensor(), ToTensor(),
# ToDevice(torch.device('cuda')), # ToDevice(torch.device('cuda')),
ToTorchImage(), ToTorchImage(),
# ThreeAugmentation() ThreeAugmentation(),
RandomChoice([ ])
tfms.RandomSolarize(0,1),
tfms.RandomGrayscale(1),
tfms.GaussianBlur(5,sigma=(0.1,2)),
]),
RandDownSampling((downscale,1)),
])
label_pipeline = [IntDecoder(), ToTensor(),View(-1)] label_pipeline = [IntDecoder(), ToTensor(),View(-1)]
# Pipeline for each data field # Pipeline for each data field
pipelines = { pipelines = {
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment