Skip to content
Snippets Groups Projects
Commit 813e9fb5 authored by gent's avatar gent
Browse files

Add RandDownSampling configurable class

parent e09a7015
No related branches found
No related tags found
No related merge requests found
......@@ -476,6 +476,7 @@ class RandomChoice(nn.Module):
from torchvision.transforms import InterpolationMode
@gin.configurable
class RandDownSampling(nn.Module):
def __init__(self, r=(3/4,1)) -> None:
super().__init__()
......@@ -483,13 +484,15 @@ class RandDownSampling(nn.Module):
def forward(self,x):
h, w = x.shape[-2:]
r = random.uniform(*self.r)
if r == 1:
return x
nh,hw = int(h*r),int(w*r)
down = F.resize(x,(nh,hw),interpolation=InterpolationMode.BICUBIC)
up = F.resize(down,(h,w),interpolation=InterpolationMode.BICUBIC)
return up
@gin.configurable
def ThreeAugmentPipeline(img_size=224,scale=(0.08,1), color_jitter=None,downsampling=False):
def ThreeAugmentPipeline(img_size=224,scale=(0.08,1), color_jitter=None,downscale=1):
"""
ThreeAugmentPipeline
"""
......@@ -498,7 +501,6 @@ def ThreeAugmentPipeline(img_size=224,scale=(0.08,1), color_jitter=None,downsamp
[ RandomResizedCropRGBImageDecoder((img_size, img_size), scale=scale,),
RandomHorizontalFlip(),]+
# second_tfl
# [ ThreeAugmentation(),] +
( [RandomColorJitter(jitter_prob=1, brightness=color_jitter, contrast=color_jitter, saturation=color_jitter, hue=0, seed=None)] if color_jitter else []) +
# final_tfl
[
......@@ -511,9 +513,9 @@ def ThreeAugmentPipeline(img_size=224,scale=(0.08,1), color_jitter=None,downsamp
tfms.RandomSolarize(0,1),
tfms.RandomGrayscale(1),
tfms.GaussianBlur(5,sigma=(0.1,2)),
])
] +
[RandDownSampling((3/4,1))] if downsampling else [])
]),
RandDownSampling((downscale,1)),
])
label_pipeline = [IntDecoder(), ToTensor(),View(-1)]
# Pipeline for each data field
pipelines = {
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment