diff --git a/gdl/models/DECA.py b/gdl/models/DECA.py
index 5905cb38e778cfd0dec88a9382527818aec37564..2079829f3de25eac3605bd8b9636eaf71536fdb7 100644
--- a/gdl/models/DECA.py
+++ b/gdl/models/DECA.py
@@ -3225,6 +3225,7 @@ class ExpDECA(DECA):
         deca_code_list_copy = deca_code_list.copy()
 
+        # self.E_mica.cfg.model.n_shape
 
         #TODO: clean this if-else block up
         if self.config.exp_deca_global_pose and self.config.exp_deca_jaw_pose:
 
@@ -3284,3 +3285,144 @@ class ExpDECA(DECA):
 
         self.D_detail.eval()
         return self
+
+
+class EMICA(ExpDECA):
+
+    def __init__(self, config):
+        self.use_mica_shape_dim = True
+        # self.use_mica_shape_dim = False
+        from .mica.config import get_cfg_defaults
+        self.mica_cfg = get_cfg_defaults()
+        super().__init__(config)
+
+    def _create_model(self):
+        # 1) Initialize DECA
+        super()._create_model()
+
+        from .mica.mica import MICA
+        #TODO: MICA instantiates its own FLAME
+        # 1) This is redundant - get rid of it
+        # 2) Make sure it's the same FLAME as EMOCA
+        if Path(self.config.mica_model_path).exists():
+            mica_path = self.config.mica_model_path
+        else:
+            from gdl.utils.other import get_path_to_assets
+            mica_path = get_path_to_assets() / self.config.mica_model_path
+            assert mica_path.exists(), f"MICA model path does not exist: '{mica_path}'"
+
+        self.mica_cfg.pretrained_model_path = str(mica_path)
+        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+        self.E_mica = MICA(self.mica_cfg, device, str(mica_path), instantiate_flame=False)
+        # E_mica is pretrained and must stay frozen
+        self.E_mica.requires_grad_(False)
+        self.E_mica.testing = True
+
+        # preprocessing for MICA
+        if self.config.mica_preprocessing:
+            from insightface.app import FaceAnalysis
+            self.app = FaceAnalysis(name='antelopev2', providers=['CUDAExecutionProvider'])
+            self.app.prepare(ctx_id=0, det_size=(224, 224))
+
+    def _get_num_shape_params(self):
+        if self.use_mica_shape_dim:
+            return self.mica_cfg.model.n_shape
+        return self.config.n_shape
+
+    def _get_coarse_trainable_parameters(self):
+        # MICA is not trainable, so its parameters are deliberately not added here
+        return super()._get_coarse_trainable_parameters()
+
+    def train(self, mode: bool = True):
+        super().train(mode)
+        self.E_mica.train(False)  # MICA is pretrained and is kept in eval mode at all times
+
+    def _encode_flame(self, images):
+        if self.config.mica_preprocessing:
+            mica_image = self._dirty_image_preprocessing(images)
+        else:
+            mica_image = F.interpolate(images, (112, 112), mode='bilinear', align_corners=False)
+
+        deca_code, exp_deca_code = super()._encode_flame(images)
+        mica_code = self.E_mica.encode(images, mica_image)
+        mica_code = self.E_mica.decode(mica_code, predict_vertices=False)
+        return deca_code, exp_deca_code, mica_code['pred_shape_code']
+
+    def _dirty_image_preprocessing(self, input_image):
+        # breaks whatever gradient flow may have gone into the image creation process
+        from gdl.models.mica.detector import get_center, get_arcface_input
+        from insightface.app.common import Face
+
+        image = input_image.detach().clone().cpu().numpy() * 255.
+        # b,c,h,w to b,h,w,c
+        image = image.transpose((0, 2, 3, 1))
+
+        min_det_score = 0.5
+        image_list = list(image)
+        aligned_image_list = []
+        for img in image_list:
+            bboxes, kpss = self.app.det_model.detect(img, max_num=0, metric='default')
+            if bboxes.shape[0] == 0:
+                raise RuntimeError("No faces detected")
+            idx = get_center(bboxes, img)
+            bbox = bboxes[idx, 0:4]
+            det_score = bboxes[idx, 4]
+            # if det_score < min_det_score:
+            #     continue
+            kps = None
+            if kpss is not None:
+                kps = kpss[idx]
+
+            face = Face(bbox=bbox, kps=kps, det_score=det_score)
+            blob, aimg = get_arcface_input(face, img)
+            aligned_image_list.append(aimg)
+        aligned_images = np.array(aligned_image_list)
+        # b,h,w,c to b,c,h,w
+        aligned_images = aligned_images.transpose((0, 3, 1, 2))
+        # to torch, on the input image's device
+        aligned_images = torch.from_numpy(aligned_images).to(input_image.device)
+        return aligned_images
+
+    def decompose_code(self, code):
+        deca_code = code[0]
+        expdeca_code = code[1]
+        mica_code = code[2]
+
+        code_list, deca_code_list_copy = super().decompose_code((deca_code, expdeca_code))
+
+        id_idx = 0  # identity is the first part of the vector
+        # assert self.config.n_shape == mica_code.shape[-1]
+        # assert code_list[id_idx].shape[-1] == mica_code.shape[-1]
+        if self.use_mica_shape_dim:
+            code_list[id_idx] = mica_code
+        else:
+            code_list[id_idx] = mica_code[..., :self.config.n_shape]
+        return code_list, deca_code_list_copy
+
+
+def instantiate_deca(cfg, stage, prefix, checkpoint=None, checkpoint_kwargs=None):
+    """
+    Instantiates a DecaModule, either from config or from a checkpoint.
+    """
+    if checkpoint is None:
+        deca = DecaModule(cfg.model, cfg.learning, cfg.inout, prefix)
+        if cfg.model.resume_training:
+            # This loads the DECA model weights from the original DECA release
+            print("[WARNING] Loading EMOCA checkpoint pretrained by the old code")
+            deca.deca._load_old_checkpoint()
+    else:
+        checkpoint_kwargs = checkpoint_kwargs or {}
+        deca = DecaModule.load_from_checkpoint(checkpoint_path=checkpoint, strict=False, **checkpoint_kwargs)
+        mode = stage == 'train'
+        deca.reconfigure(cfg.model, cfg.inout, cfg.learning, prefix, downgrade_ok=True, train=mode)
+    return deca
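
The heart of the new EMICA.decompose_code is swapping DECA's predicted identity (shape) code for MICA's before FLAME decoding. Below is a minimal standalone sketch of that substitution, with made-up dimensions (300 MICA vs. 100 DECA shape parameters) and a toy list standing in for the real decomposed code vector; none of these values come from the patch itself.

import torch

# Hypothetical sizes for illustration only; the real values come from
# mica_cfg.model.n_shape and config.n_shape respectively.
n_shape_mica = 300
n_shape_deca = 100
use_mica_shape_dim = True

# Toy stand-in for the decomposed DECA code list; identity comes first,
# followed by the remaining code groups (tex, exp, pose, ...).
code_list = [torch.randn(2, n_shape_deca), torch.randn(2, 50), torch.randn(2, 50)]
mica_shape_code = torch.randn(2, n_shape_mica)

id_idx = 0  # identity is the first part of the vector
if use_mica_shape_dim:
    # adopt MICA's full shape vector; FLAME must then be built with n_shape_mica
    code_list[id_idx] = mica_shape_code
else:
    # truncate MICA's prediction to the shape dimension DECA was configured with
    code_list[id_idx] = mica_shape_code[..., :n_shape_deca]

assert code_list[id_idx].shape[-1] == (n_shape_mica if use_mica_shape_dim else n_shape_deca)

This is also why _get_num_shape_params is overridden: the rest of the model must be built for whichever shape dimensionality ends up in the code list.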
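_dirty_image_preprocessing follows MICA's demo pipeline: run insightface's antelopev2 detector, pick the most central face, and warp it to the 112x112 ArcFace crop via the repo's get_arcface_input helper. A sketch of just the detector call the loop relies on; the synthetic input frame and the CPU fallback provider are assumptions, not part of the patch.

import numpy as np
from insightface.app import FaceAnalysis

# Same detector setup as in _create_model; CPUExecutionProvider is added here
# as an assumed fallback for machines without CUDA.
app = FaceAnalysis(name='antelopev2',
                   providers=['CUDAExecutionProvider', 'CPUExecutionProvider'])
app.prepare(ctx_id=0, det_size=(224, 224))

# Synthetic h,w,c float frame in [0, 255], standing in for one batch element.
img = (np.random.rand(224, 224, 3) * 255.).astype(np.float32)

# The raw detector call used in the loop: bboxes is (N, 5) with the detection
# score in the last column, kpss holds the 5-point landmarks (or None).
bboxes, kpss = app.det_model.detect(img, max_num=0, metric='default')
print(bboxes.shape, None if kpss is None else kpss.shape)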
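Finally, instantiate_deca is the single entry point for both construction paths. A hedged usage sketch, assuming an OmegaConf config that carries the model/learning/inout groups the function reads; the YAML filename, checkpoint path, and empty prefix are placeholders, not values from the repo.

from omegaconf import OmegaConf

from gdl.models.DECA import instantiate_deca

# Placeholder config file; the real project composes cfg from its own YAML hierarchy.
cfg = OmegaConf.load('emica_config.yaml')

# From scratch: builds a DecaModule from config; if cfg.model.resume_training is
# set, it also pulls in the original DECA release weights via _load_old_checkpoint().
deca = instantiate_deca(cfg, stage='train', prefix='')

# From a checkpoint: strict=False tolerates keys that older checkpoints lack
# (e.g. the new E_mica submodule), then reconfigure() applies the current config.
deca = instantiate_deca(cfg, stage='test', prefix='', checkpoint='last.ckpt')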