Skip to content
Snippets Groups Projects
Commit e09a7015 authored by gent's avatar gent
Browse files

add RandDownSampling for compression shift

parent 20b883cb
Branches main
No related tags found
No related merge requests found
......@@ -474,8 +474,22 @@ class RandomChoice(nn.Module):
def __repr__(self) -> str:
return f"{super().__repr__()}(p={self.p})"
from torchvision.transforms import InterpolationMode
class RandDownSampling(nn.Module):
def __init__(self, r=(3/4,1)) -> None:
super().__init__()
self.r = r
def forward(self,x):
h, w = x.shape[-2:]
r = random.uniform(*self.r)
nh,hw = int(h*r),int(w*r)
down = F.resize(x,(nh,hw),interpolation=InterpolationMode.BICUBIC)
up = F.resize(down,(h,w),interpolation=InterpolationMode.BICUBIC)
return up
@gin.configurable
def ThreeAugmentPipeline(img_size=224,scale=(0.08,1), color_jitter=None):
def ThreeAugmentPipeline(img_size=224,scale=(0.08,1), color_jitter=None,downsampling=False):
"""
ThreeAugmentPipeline
"""
......@@ -496,9 +510,10 @@ def ThreeAugmentPipeline(img_size=224,scale=(0.08,1), color_jitter=None):
RandomChoice([
tfms.RandomSolarize(0,1),
tfms.RandomGrayscale(1),
tfms.GaussianBlur(7,sigma=(0.1,2)),
tfms.GaussianBlur(5,sigma=(0.1,2)),
])
])
] +
[RandDownSampling((3/4,1))] if downsampling else [])
label_pipeline = [IntDecoder(), ToTensor(),View(-1)]
# Pipeline for each data field
pipelines = {
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment