diff --git a/gdl/models/DECA.py b/gdl/models/DECA.py index 2079829f3de25eac3605bd8b9636eaf71536fdb7..df5ea57dfe8d0d6f5b49e37e647b598c8dfc5d5c 100644 --- a/gdl/models/DECA.py +++ b/gdl/models/DECA.py @@ -234,7 +234,7 @@ class DecaModule(LightningModule): """ Initialize the au perceptual loss (not currently used in EMOCA) """ - if 'lipread_loss' in self.deca.config.keys(): + if 'lipread_loss' in self.deca.config.keys() and self.deca.config.lipread_loss.get('load', True): if self.lipread_loss is not None: force_override = True if 'force_override' in self.deca.config.lipread_loss.keys() \ and self.deca.config.lipread_loss.force_override else False diff --git a/gdl/models/EmoDECA.py b/gdl/models/EmoDECA.py index bdcc7c7ef033b2f3750337f5f9f2a2bf4d0cb9af..ef7f7bd832ebf9abcb1ab782729deac123dde48c 100644 --- a/gdl/models/EmoDECA.py +++ b/gdl/models/EmoDECA.py @@ -54,9 +54,12 @@ class EmoDECA(EmotionRecognitionBaseModule): "stage_name": "testing", } # instantiate the face net - self.deca = instantiate_deca(config.model.deca_cfg, deca_stage , "test", deca_checkpoint, deca_checkpoint_kwargs) - self.deca.inout_params.full_run_dir = config.inout.full_run_dir - self._setup_deca(False) + if bool(deca_checkpoint): + self.deca = instantiate_deca(config.model.deca_cfg, deca_stage , "test", deca_checkpoint, deca_checkpoint_kwargs) + self.deca.inout_params.full_run_dir = config.inout.full_run_dir + self._setup_deca(False) + else: + self.deca = None # which latent codes are being used in_size = 0