From 1dc8d9bd3589bb0eb73d67b4c8c594ba1f09e27c Mon Sep 17 00:00:00 2001
From: Celyn Walters <celyn.walters@surrey.ac.uk>
Date: Mon, 23 Jan 2023 10:05:37 +0000
Subject: [PATCH] Remove legacy code

---
 rl/environments/cartpole.py | 4 +---
 rl/environments/skiing.py   | 2 +-
 2 files changed, 2 insertions(+), 4 deletions(-)

diff --git a/rl/environments/cartpole.py b/rl/environments/cartpole.py
index 84d3c16..15c6f05 100644
--- a/rl/environments/cartpole.py
+++ b/rl/environments/cartpole.py
@@ -21,9 +21,7 @@ class CartPoleEvents(EventEnv, CartPoleEnv):
 		Args:
 			args (argparse.Namespace): Parsed arguments, depends on which specific env we're using.
 			event_image (bool, optional): Accuumlates events into an event image. Defaults to False.
-			return_rgb (bool, optional): _description_. Defaults to False.
 		"""
-		self.return_rgb = return_rgb
 		self.output_width = output_width
 		self.output_height = output_height
 		self.updatedPolicy = False # Used for logging whenever the policy is updated
@@ -308,7 +306,7 @@ class CartPoleRGB(CartPoleEnv):
 	# ----------------------------------------------------------------------------------------------
 	def reset(self, *, seed: Optional[int] = None, options: Optional[dict] = None) -> np.ndarray:
 		"""
-		Resets the environment, and also the model (if defined).
+		Resets the environment.
 
 		Args:
 			seed (int, optional): The seed that is used to initialize the environment's PRNG. Defaults to None.
diff --git a/rl/environments/skiing.py b/rl/environments/skiing.py
index 3b8f022..7c24379 100644
--- a/rl/environments/skiing.py
+++ b/rl/environments/skiing.py
@@ -26,7 +26,7 @@ class SkiingEvents(AtariEnv):
 		# https://github.com/DLR-RM/rl-trained-agents/blob/1e2a45e5d06efd6cc15da6cf2d1939d72dcbdf87/ppo/PongNoFrameskip-v4_1/PongNoFrameskip-v4/config.yml
 		# and refer to https://github.com/DLR-RM/rl-trained-agents/blob/1e2a45e5d06efd6cc15da6cf2d1939d72dcbdf87/ppo/PongNoFrameskip-v4_1/PongNoFrameskip-v4/args.yml
 		# # parser.set_defaults(steps=10000000)
-		parser.set_defaults(n_steps=128 * 8) # n_envs = 8, rollout buffer size is n_steps * n_envs
+		parser.set_defaults(n_steps=128) # rollout buffer full size is n_steps * n_envs
 		parser.set_defaults(n_epochs=4)
 		parser.set_defaults(ent_coef=0.01)
 		parser.set_defaults(lr=2.5e-4)
-- 
GitLab