Commit 0febb0ea authored by cocoalex00

Merge branch 'edenn+' of https://gitlab.surrey.ac.uk/ah02299/edenn into edenn+

parents 225b3456 fc594f4e
@@ -189,9 +189,9 @@ def train_model(args, model_id, dataset_id):
     lr_monitor = LearningRateMonitor(logging_interval='step')
     if os.path.exists(cpt_path_full): # 1. Checkpoint available
-        trainer = pl.Trainer(accumulate_grad_batches = int(args.accumulate_grads/n_gpus) if args.accumulate_grads > 0 else 0, strategy = "ddp" if args.gpus > 1 else "auto", max_epochs = args.epochs, resume_from_checkpoint = cpt_path_full, accelerator = "gpu" if device is "cuda" else "cpu", devices = args.gpus, num_nodes = 1, callbacks = [checkpointCall, lr_monitor], logger = wandb_logger, log_every_n_steps = 1)#, overfit_batches = 1)
+        trainer = pl.Trainer(overfit_batches = 1, accumulate_grad_batches = int(args.accumulate_grads/n_gpus) if args.accumulate_grads > 0 else 0, strategy = "ddp" if args.gpus > 1 else "auto", max_epochs = args.epochs, resume_from_checkpoint = cpt_path_full, accelerator = "gpu" if device is "cuda" else "cpu", devices = args.gpus, num_nodes = 1, callbacks = [checkpointCall, lr_monitor], logger = wandb_logger, log_every_n_steps = 1)#, overfit_batches = 1)
     else:
-        trainer = pl.Trainer(accumulate_grad_batches = int(args.accumulate_grads/n_gpus) if args.accumulate_grads > 0 else 0, strategy = "ddp" if args.gpus > 1 else "auto", max_epochs = args.epochs, accelerator = "gpu" if device is "cuda" else "cpu", devices = args.gpus, num_nodes = 1, callbacks = [checkpointCall, lr_monitor], logger = wandb_logger, log_every_n_steps = 1)#, overfit_batches = 1)
+        trainer = pl.Trainer(overfit_batches = 1, accumulate_grad_batches = int(args.accumulate_grads/n_gpus) if args.accumulate_grads > 0 else 0, strategy = "ddp" if args.gpus > 1 else "auto", max_epochs = args.epochs, accelerator = "gpu" if device is "cuda" else "cpu", devices = args.gpus, num_nodes = 1, callbacks = [checkpointCall, lr_monitor], logger = wandb_logger, log_every_n_steps = 1)#, overfit_batches = 1)
     # accumulate_grad_batches = int(args.accumulate_grads/n_gpus) if args.accumulate_grads > 0 else 0
     # Sigterm handler for condor
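The net effect of the merge is to enable overfit_batches = 1 in both pl.Trainer constructions, i.e. train and validate on a single batch as a debugging sanity check. Since the two calls differ only in whether resume_from_checkpoint is passed, they could be collapsed into one helper. Below is a minimal consolidation sketch, not part of the commit: build_trainer is a hypothetical name, and args, device, n_gpus, cpt_path_full, checkpointCall, lr_monitor and wandb_logger are assumed to exist as in the surrounding file. The sketch also corrects two likely bugs in the committed lines: strings compared with `is` rather than `==`, and a fallback of accumulate_grad_batches = 0, which Lightning rejects (it expects an integer >= 1). Note that resume_from_checkpoint is a Lightning 1.x argument; in 2.x the checkpoint path is passed to trainer.fit(ckpt_path=...) instead.

import os
import pytorch_lightning as pl

def build_trainer(args, device, n_gpus, cpt_path_full, callbacks, wandb_logger):
    """Build the Trainer used by train_model (hypothetical helper)."""
    return pl.Trainer(
        # Train and validate on a single batch: the debugging switch this
        # commit turns on in both branches.
        overfit_batches=1,
        # The committed code falls back to 0, but Lightning requires an
        # integer >= 1 here, so fall back to 1 (no accumulation) instead.
        accumulate_grad_batches=max(1, int(args.accumulate_grads / n_gpus))
        if args.accumulate_grads > 0 else 1,
        strategy="ddp" if args.gpus > 1 else "auto",
        max_epochs=args.epochs,
        # Resume only when a checkpoint file actually exists; None starts
        # from scratch, replacing the if/else around the two original calls.
        resume_from_checkpoint=cpt_path_full if os.path.exists(cpt_path_full) else None,
        # The committed lines use `device is "cuda"`, which compares object
        # identity and is unreliable for strings; `==` is the correct test.
        accelerator="gpu" if device == "cuda" else "cpu",
        devices=args.gpus,
        num_nodes=1,
        callbacks=callbacks,
        logger=wandb_logger,
        log_every_n_steps=1,
    )

With this helper, both branches reduce to trainer = build_trainer(args, device, n_gpus, cpt_path_full, [checkpointCall, lr_monitor], wandb_logger), and switching overfit_batches back off when debugging is done becomes a one-line change.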