Skip to content
Snippets Groups Projects
Commit 67acff7f authored by Walters, Celyn Dr (Comp Sci & Elec Eng)'s avatar Walters, Celyn Dr (Comp Sci & Elec Eng)
Browse files

Backawrd pass per training epoch

parent bc549c47
No related branches found
No related tags found
No related merge requests found
......@@ -193,7 +193,6 @@ class PPO(SB3_PPO):
approx_kl_divs = []
# Do a complete pass on the rollout buffer
all_losses = 0
for rollout_data in self.rollout_buffer.get(batch_size=self.n_envs): # Sampling has to be done at this size later
actions = rollout_data.actions
if isinstance(self.action_space, gym.spaces.Discrete):
......@@ -277,14 +276,12 @@ class PPO(SB3_PPO):
print(f"Early stopping at step {epoch} due to reaching max kl: {approx_kl_div:.2f}")
break
all_losses += loss
# Optimization step
self.policy.optimizer.zero_grad()
all_losses.backward()
# Clip grad norm
torch.nn.utils.clip_grad_norm_(self.policy.parameters(), self.max_grad_norm)
self.policy.optimizer.step()
# Optimization step
self.policy.optimizer.zero_grad()
loss.backward()
# Clip grad norm
torch.nn.utils.clip_grad_norm_(self.policy.parameters(), self.max_grad_norm)
self.policy.optimizer.step()
if not continue_training:
break
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment