diff --git a/vitookit/datasets/build_dataset.py b/vitookit/datasets/build_dataset.py index 34ae0357a72fd544fe4acc6a4b7db92dc978ac54..b64203fd1e17a1b56665fb4f198cadbbad23ad66 100644 --- a/vitookit/datasets/build_dataset.py +++ b/vitookit/datasets/build_dataset.py @@ -78,6 +78,10 @@ def build_dataset(args, is_train, trnsfrm=None,): trnsfrm.transforms.insert(-2,transforms.Grayscale(num_output_channels=3)) dataset = datasets.Omniglot(args.data_location,transform=tfm,download=True) nb_classes = 1623 + + elif args.data_set == 'INAT': + dataset = INatDataset(args.data_path, train=is_train, year=2018, + transform=tfm) else: print('dataloader of {} is not implemented .. please add the dataloader under datasets folder.'.format(args.data_set))