diff --git a/alex_trainL.py b/alex_trainL.py
index e97c33e86520354c16725845f52208a5e64052cb..084d953265b3bdd8b62ac7e5960d2dbb979facf4 100644
--- a/alex_trainL.py
+++ b/alex_trainL.py
@@ -189,9 +189,9 @@ def train_model(args, model_id, dataset_id):
         
         lr_monitor = LearningRateMonitor(logging_interval='step')
         if os.path.exists(cpt_path_full): # 1. Checkpoint available
-            trainer = pl.Trainer(accumulate_grad_batches = int(args.accumulate_grads/n_gpus) if args.accumulate_grads > 0 else 0, strategy = "ddp" if args.gpus > 1 else "auto", max_epochs = args.epochs, resume_from_checkpoint =cpt_path_full, accelerator = "gpu" if device is "cuda" else "cpu", devices =args.gpus, num_nodes = 1, callbacks = [checkpointCall, lr_monitor],logger = wandb_logger, log_every_n_steps = 1)#, overfit_batches = 1)
+            trainer = pl.Trainer(overfit_batches = 1, accumulate_grad_batches = int(args.accumulate_grads/n_gpus) if args.accumulate_grads > 0 else 1, strategy = "ddp" if args.gpus > 1 else "auto", max_epochs = args.epochs, resume_from_checkpoint = cpt_path_full, accelerator = "gpu" if device == "cuda" else "cpu", devices = args.gpus, num_nodes = 1, callbacks = [checkpointCall, lr_monitor], logger = wandb_logger, log_every_n_steps = 1)
         else:
-            trainer = pl.Trainer(accumulate_grad_batches = int(args.accumulate_grads/n_gpus) if args.accumulate_grads > 0 else 0, strategy = "ddp" if args.gpus > 1 else "auto", max_epochs = args.epochs, accelerator = "gpu" if device is "cuda" else "cpu", devices =args.gpus, num_nodes = 1, callbacks = [checkpointCall, lr_monitor],logger = wandb_logger, log_every_n_steps= 1)#, overfit_batches= 1)
+            trainer = pl.Trainer(overfit_batches = 1, accumulate_grad_batches = int(args.accumulate_grads/n_gpus) if args.accumulate_grads > 0 else 1, strategy = "ddp" if args.gpus > 1 else "auto", max_epochs = args.epochs, accelerator = "gpu" if device == "cuda" else "cpu", devices = args.gpus, num_nodes = 1, callbacks = [checkpointCall, lr_monitor], logger = wandb_logger, log_every_n_steps = 1)
         
         # accumulate_grad_batches = int(args.accumulate_grads/n_gpus) if args.accumulate_grads > 0 else 0
         # Sigterm handler for condor
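
For context, the change above turns on overfit_batches=1, which restricts training (and validation) to a single batch so the run can be used to check that the model is able to overfit before launching a full training job. A minimal sketch of the same Trainer setup follows, with the two near-identical constructor calls collapsed into one; the names args, device, n_gpus, cpt_path_full, checkpointCall, lr_monitor and wandb_logger are taken from the surrounding train_model function, trainer_kwargs is just a local convenience introduced here, and the sketch assumes a PyTorch Lightning version in which resume_from_checkpoint is still accepted by the Trainer constructor.

    import os
    import pytorch_lightning as pl

    # Keyword arguments shared by both branches of the checkpoint check.
    # overfit_batches=1 limits training and validation to a single batch
    # (a debugging aid); remove it for a real training run.
    trainer_kwargs = dict(
        overfit_batches=1,
        # accumulate_grad_batches must be a positive integer, so fall back to 1.
        accumulate_grad_batches=max(1, int(args.accumulate_grads / n_gpus)) if args.accumulate_grads > 0 else 1,
        strategy="ddp" if args.gpus > 1 else "auto",
        max_epochs=args.epochs,
        accelerator="gpu" if device == "cuda" else "cpu",
        devices=args.gpus,
        num_nodes=1,
        callbacks=[checkpointCall, lr_monitor],
        logger=wandb_logger,
        log_every_n_steps=1,
    )

    # Resume only when a checkpoint file is actually present.
    if os.path.exists(cpt_path_full):
        trainer_kwargs["resume_from_checkpoint"] = cpt_path_full

    trainer = pl.Trainer(**trainer_kwargs)

Keeping the shared keyword set in one place means new arguments such as overfit_batches only have to be added once, instead of being edited into both constructor calls as in the hunk above.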