diff --git a/gdl/models/DECA.py b/gdl/models/DECA.py
index 5905cb38e778cfd0dec88a9382527818aec37564..2079829f3de25eac3605bd8b9636eaf71536fdb7 100644
--- a/gdl/models/DECA.py
+++ b/gdl/models/DECA.py
@@ -3225,6 +3225,7 @@ class ExpDECA(DECA):
 
         deca_code_list_copy = deca_code_list.copy()
 
+        # NOTE: in the EMICA subclass below, the shape code length comes from self.E_mica.cfg.model.n_shape
 
         #TODO: clean this if-else block up
         if self.config.exp_deca_global_pose and self.config.exp_deca_jaw_pose:
@@ -3284,3 +3285,144 @@ class ExpDECA(DECA):
             self.D_detail.eval()
         return self
 
+
+
+class EMICA(ExpDECA):
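+    """
+    ExpDECA variant with an additional frozen, pretrained MICA encoder whose
+    predicted identity (shape) code replaces the one regressed by the coarse encoder.
+    """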
+
+    def __init__(self, config):
+        # if True, the identity code takes MICA's shape dimensionality rather than config.n_shape
+        self.use_mica_shape_dim = True
+        from .mica.config import get_cfg_defaults
+        self.mica_cfg = get_cfg_defaults()
+        super().__init__(config)
+  
+    def _create_model(self):
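+        """Builds the ExpDECA submodules first, then attaches the frozen MICA encoder."""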
+        # 1) Initialize DECA
+        super()._create_model()
+        from .mica.mica import MICA
+        #TODO: MICA uses FLAME  
+        # 1) This is redundant - get rid of it 
+        # 2) Make sure it's the same FLAME as EMOCA
+        if Path(self.config.mica_model_path).exists(): 
+            mica_path = self.config.mica_model_path 
+        else:
+            from gdl.utils.other import get_path_to_assets
+            mica_path = get_path_to_assets() / self.config.mica_model_path  
+            assert mica_path.exists(), f"MICA model path does not exist: '{mica_path}'"
+
+        self.mica_cfg.pretrained_model_path = str(mica_path)
+        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+        self.E_mica = MICA(self.mica_cfg, device, str(mica_path), instantiate_flame=False)
+        # E_mica stays frozen (no gradients)
+        self.E_mica.requires_grad_(False)
+        self.E_mica.testing = True
+
+        # preprocessing for MICA
+        if self.config.mica_preprocessing:
+            from insightface.app import FaceAnalysis
+            self.app = FaceAnalysis(name='antelopev2', providers=['CUDAExecutionProvider'])
+            self.app.prepare(ctx_id=0, det_size=(224, 224))
+
+
+    def _get_num_shape_params(self):
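+        """Returns the shape code dimensionality: MICA's if use_mica_shape_dim is set, otherwise config.n_shape."""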
+        if self.use_mica_shape_dim:
+            return self.mica_cfg.model.n_shape
+        return self.config.n_shape
+
+    def _get_coarse_trainable_parameters(self):
+        # MICA is frozen, so its parameters are deliberately not added to the trainable set
+        return super()._get_coarse_trainable_parameters()
+
+
+    def train(self, mode: bool = True):
+        super().train(mode)
+        self.E_mica.train(False) # MICA is pretrained and will be set to EVAL at all times 
+
+
+    def _encode_flame(self, images):
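+        """Runs the (Exp)DECA encoders and the frozen MICA encoder; returns (deca_code, exp_deca_code, mica_shape_code)."""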
+        
+        if self.config.mica_preprocessing:
+            mica_image = self._dirty_image_preprocessing(images)
+        else: 
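+            # MICA expects 112x112 inputs; without the face detector we simply resize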
+            mica_image = F.interpolate(images, (112,112), mode='bilinear', align_corners=False)
+
+        deca_code, exp_deca_code = super()._encode_flame(images)
+        mica_code = self.E_mica.encode(images, mica_image) 
+        mica_code = self.E_mica.decode(mica_code, predict_vertices=False)
+        return deca_code, exp_deca_code, mica_code['pred_shape_code']
+ 
+    def _dirty_image_preprocessing(self, input_image): 
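+        """Detects and aligns faces with insightface to produce the cropped input format MICA was trained on."""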
+        # note: this breaks whatever gradient flow may have gone into creating the input image
+        from gdl.models.mica.detector import get_center, get_arcface_input
+        from insightface.app.common import Face
+        
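+        # scale [0, 1] float pixels to the [0, 255] range the detector operates on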
+        image = input_image.detach().clone().cpu().numpy() * 255. 
+        # b,c,h,w to b,h,w,c
+        image = image.transpose((0,2,3,1))
+    
+        image_list = list(image)
+        aligned_image_list = []
+        for i, img in enumerate(image_list):
+            bboxes, kpss = self.app.det_model.detect(img, max_num=0, metric='default')
+            if bboxes.shape[0] == 0:
+                # no face found - fail loudly rather than feed MICA an unaligned crop
+                raise RuntimeError("No faces detected")
+            # index of the detection closest to the image center
+            # (pass the single image, not the whole batch, and don't shadow the loop index)
+            j = get_center(bboxes, img)
+            bbox = bboxes[j, 0:4]
+            det_score = bboxes[j, 4]
+            kps = None
+            if kpss is not None:
+                kps = kpss[j]
+
+            face = Face(bbox=bbox, kps=kps, det_score=det_score)
+            blob, aimg = get_arcface_input(face, img)
+            aligned_image_list.append(aimg)
+        aligned_images = np.array(aligned_image_list)
+        # b,h,w,c to b,c,h,w
+        aligned_images = aligned_images.transpose((0,3,1,2))
+        # to torch to correct device 
+        aligned_images = torch.from_numpy(aligned_images).to(input_image.device)
+        return aligned_images
+
+    def decompose_code(self, code): 
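+        """Decomposes the concatenated codes as in ExpDECA, then overrides the identity (shape) code with MICA's prediction."""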
+        deca_code = code[0]
+        expdeca_code = code[1]
+        mica_code = code[2]
+
+        code_list, deca_code_list_copy = super().decompose_code((deca_code, expdeca_code))
+
+        id_idx = 0  # identity (shape) comes first in the code list
+        # note: mica_code.shape[-1] need not equal self.config.n_shape, hence the branch below
+        if self.use_mica_shape_dim:
+            code_list[id_idx] = mica_code
+        else: 
+            code_list[id_idx] = mica_code[..., :self.config.n_shape]
+        return code_list, deca_code_list_copy
+
+
+def instantiate_deca(cfg, stage, prefix, checkpoint=None, checkpoint_kwargs=None):
+    """
+    Function that instantiates a DecaModule from checkpoint or config
+    """
+
+    if checkpoint is None:
+        deca = DecaModule(cfg.model, cfg.learning, cfg.inout, prefix)
+        if cfg.model.resume_training:
+            # This loads the DECA model weights from the original DECA release
+            print("[WARNING] Loading EMOCA checkpoint pretrained by the old code")
+            deca.deca._load_old_checkpoint()
+    else:
+        checkpoint_kwargs = checkpoint_kwargs or {}
+        deca = DecaModule.load_from_checkpoint(checkpoint_path=checkpoint, strict=False, **checkpoint_kwargs)
+        mode = stage == 'train'
+        deca.reconfigure(cfg.model, cfg.inout, cfg.learning, prefix, downgrade_ok=True, train=mode)
+    return deca