Fix the stray double comma (`num_heads=12,,`) in the VisionTransformer(...)
call inside vit_tiny(); the doubled comma is a Python syntax error.

NOTE(review): this patch had been flattened onto a single line, so the leading
indentation of the Python lines below is reconstructed (4/8 spaces) rather than
recovered -- confirm against the target file, or apply with
`git apply --ignore-whitespace`. The two blank context lines are inferred from
the hunk's declared length of 7 lines.

diff --git a/vitookit/models/vision_transformer.py b/vitookit/models/vision_transformer.py
index 8cd4ce9c9afad9268a43d3e64641323466f18c98..f9f15cd5e3f621dc18f76422c869edf0a2428cdf 100644
--- a/vitookit/models/vision_transformer.py
+++ b/vitookit/models/vision_transformer.py
@@ -19,7 +19,7 @@ class VisionTransformer(timm.models.vision_transformer.VisionTransformer):
 
 def vit_tiny(**kwargs):
     model = VisionTransformer(
-        patch_size=16,embed_dim=192,depth=12,num_heads=12,, mlp_ratio=4, qkv_bias=True,
+        patch_size=16,embed_dim=192,depth=12,num_heads=12, mlp_ratio=4, qkv_bias=True,
         norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
     return model
 