diff --git a/dirtorch/nets/__init__.py b/dirtorch/nets/__init__.py index 1df2bdadf6f3b62e113df7ecbd4cdb3ee927bbd8..29d844db6cce639f04a72357539e609e5031a6ab 100644 --- a/dirtorch/nets/__init__.py +++ b/dirtorch/nets/__init__.py @@ -34,7 +34,7 @@ def create_model(arch, pretrained='', delete_fc=False, *args, **kwargs): optional arguments ''' # creating model - if arch not in globals(): + if arch not in model_names: raise NameError("unknown model architecture '%s'\nSelect one in %s" % ( arch, ','.join(model_names))) model = globals()[arch](*args, **kwargs)