diff --git a/README.md b/README.md
index 01b52983850902ae08cd9abe64d8781b2c0dc63d..e7ef64a31197d11a3b4c8942f5f0e73644190138 100644
--- a/README.md
+++ b/README.md
@@ -23,8 +23,7 @@ kellog \
 gitpython \
 setproctitle \
 opencv-python \
-# h5py \
-gym[all]
+gymnasium \
 stable-baselines3 \
 msgpack \
 rospkg
diff --git a/catkin_ws/src/rl/scripts/aigym.py b/catkin_ws/src/rl/scripts/aigym.py
index d7e9d82e837d0f05563dfcfdcb7194c7b8c0a449..745e337e44fc575e6b4a0a667202ed89d2fc9377 100755
--- a/catkin_ws/src/rl/scripts/aigym.py
+++ b/catkin_ws/src/rl/scripts/aigym.py
@@ -10,7 +10,7 @@ import tf2_ros
 from sensor_msgs.msg import Image
 from cv_bridge import CvBridge, CvBridgeError
 
-import gym
+import gymnasium as gym
 from rich import print, inspect
 
 import rl # Registers custom environments
diff --git a/rl/callbacks/policy_update_callback.py b/rl/callbacks/policy_update_callback.py
index 697a2cd25e2c991c5f1a4c7fdfd3919a342e8218..e862d1e7ba6b3c30d4a31396f0fbab91c71878c9 100644
--- a/rl/callbacks/policy_update_callback.py
+++ b/rl/callbacks/policy_update_callback.py
@@ -1,5 +1,5 @@
 from stable_baselines3.common.callbacks import BaseCallback
-import gym
+import gymnasium as gym
 
 # ==================================================================================================
 class PolicyUpdateCallback(BaseCallback):
diff --git a/rl/callbacks/trial_eval_callback.py b/rl/callbacks/trial_eval_callback.py
index 4a8109a61a23eb82d73e88e401e50c63781047f3..bbeac12d6ad25b42dd4fefe210fb83e7e6b0ea9f 100644
--- a/rl/callbacks/trial_eval_callback.py
+++ b/rl/callbacks/trial_eval_callback.py
@@ -2,7 +2,7 @@ from pathlib import Path
 from typing import Optional
 
 from stable_baselines3.common.callbacks import EvalCallback
-import gym
+import gymnasium as gym
 import optuna
 
 # ==================================================================================================
diff --git a/rl/environments/__init__.py b/rl/environments/__init__.py
index 98b2225965bcd54137fe5dc40b9d28463acaca6f..a1647670c5fe6d88bdc818e5e1b77ba84a689778 100644
--- a/rl/environments/__init__.py
+++ b/rl/environments/__init__.py
@@ -4,7 +4,7 @@ from .pong import PongEvents
 from .freeway import FreewayEvents
 from .skiing import SkiingEvents
 
-from gym.envs.registration import register#, EnvSpec
+from gymnasium.envs.registration import register#, EnvSpec
 
 # ARBITRARY arguments are passed to EnvSpec
 version = "v0"
diff --git a/rl/environments/cartpole.py b/rl/environments/cartpole.py
index dc93124719445f8acd49388b1ae89997e2dbd4cd..c0522fb824b6dfbbe96f1c1cec8b4c332123b906 100644
--- a/rl/environments/cartpole.py
+++ b/rl/environments/cartpole.py
@@ -1,9 +1,9 @@
 import argparse
 from typing import Optional
 
-import gym
-from gym.envs.classic_control.cartpole import CartPoleEnv
-from gym import spaces
+import gymnasium as gym
+from gymnasium.envs.classic_control.cartpole import CartPoleEnv
+from gymnasium import spaces
 import numpy as np
 import cv2
 import pygame, pygame.gfxdraw
@@ -30,6 +30,7 @@ class CartPoleEvents(EventEnv, CartPoleEnv):
 		CartPoleEnv.__init__(self, render_mode="rgb_array")
 		EventEnv.__init__(self, self.output_width, self.output_height, args, event_image) # type: ignore
 		self.iter = 0
+		self.failReason = None
 
 	# ----------------------------------------------------------------------------------------------
 	@staticmethod
@@ -57,27 +58,41 @@ class CartPoleEvents(EventEnv, CartPoleEnv):
 			action (int): Which action to perform this step.
 
 		Returns:
-			tuple[np.ndarray, float, bool, bool, dict]: Step returns.
+			tuple[np.ndarray, float, bool, bool, dict]: Events observation, reward, terminated, truncated, info.
 		"""
 		_, reward, terminated, truncated, _ = super().step(action) # type: ignore
 		events = self.observe()
-		info = super().get_info()
-		info["failReason"] = None
 
 		x, _, theta, _ = self.state
 		if x < -self.x_threshold:
-			info["failReason"] = "too_far_left"
+			self.failReason = "too_far_left"
 		elif x > self.x_threshold:
-			info["failReason"] = "too_far_right"
+			self.failReason = "too_far_right"
 		elif theta < -self.theta_threshold_radians:
-			info["failReason"] = "pole_fell_left"
+			self.failReason = "pole_fell_left"
 		elif theta > self.theta_threshold_radians:
-			info["failReason"] = "pole_fell_right"
+			self.failReason = "pole_fell_right"
+		else:
+			self.failReason = None
 
 		if terminated:
 			# Monitor only writes a line when an episode is terminated
 			self.updatedPolicy = False
-		return events.numpy(), reward, terminated, truncated, info
+
+		return events.numpy(), reward, terminated, truncated, self.get_info()
+
+	# ----------------------------------------------------------------------------------------------
+	def get_info(self) -> dict:
+		"""
+		Return a created dictionary for the step info.
+
+		Returns:
+			dict: Key-value pairs for the step info.
+		"""
+		info = EventEnv.get_info(self)
+		info["failReason"] = self.failReason
+
+		return info
 
 	# ----------------------------------------------------------------------------------------------
 	def resize(self, rgb: np.ndarray) -> np.ndarray:
@@ -224,6 +239,7 @@ class CartPoleRGB(CartPoleEnv):
 		# FIX: I should normalise my observation space (well, both), but not sure how to for event tensor
 		self.shape = [3, self.output_height, self.output_width]
 		self.observation_space = spaces.Box(low=0, high=255, shape=self.shape, dtype=np.uint8)
+		self.failReason = None
 
 	# ----------------------------------------------------------------------------------------------
 	@staticmethod
@@ -250,19 +266,6 @@ class CartPoleRGB(CartPoleEnv):
 		"""
 		self.updatedPolicy = True
 
-	# ----------------------------------------------------------------------------------------------
-	def get_info(self) -> dict:
-		"""
-		Return a created dictionary for the step info.
-
-		Returns:
-			dict: Key-value pairs for the step info.
- """ - return { - "state": self.state, # Used later for bootstrap loss - "updatedPolicy": int(self.updatedPolicy), - } - # ---------------------------------------------------------------------------------------------- def step(self, action: int) -> tuple[np.ndarray, float, bool, bool, dict]: """ @@ -279,25 +282,38 @@ class CartPoleRGB(CartPoleEnv): rgb = self.resize(rgb) rgb = np.transpose(rgb, (2, 0, 1)) # HWC -> CHW - info = self.get_info() - info["failReason"] = None x, _, theta, _ = self.state if x < -self.x_threshold: - info["failReason"] = "too_far_left" + self.failReason = "too_far_left" elif x > self.x_threshold: - info["failReason"] = "too_far_right" + self.failReason = "too_far_right" elif theta < -self.theta_threshold_radians: - info["failReason"] = "pole_fell_left" + self.failReason = "pole_fell_left" elif theta > self.theta_threshold_radians: - info["failReason"] = "pole_fell_right" + self.failReason = "pole_fell_right" + else: + self.failReason = None if terminated: # Monitor only writes a line when an episode is terminated self.updatedPolicy = False - return rgb, reward, terminated, truncated, info + return rgb, reward, terminated, truncated, self.get_info() + + # ---------------------------------------------------------------------------------------------- + def get_info(self) -> dict: + """ + Return a created dictionary for the step info. + + Returns: + dict: Key-value pairs for the step info. + """ + info = EventEnv.get_info(self) + info["failReason"] = self.failReason + + return info # ---------------------------------------------------------------------------------------------- - def reset(self, *, seed: Optional[int] = None, options: Optional[dict] = None) -> tuple[np.ndarray, Optional[dict]]: + def reset(self, *, seed: Optional[int] = None, options: Optional[dict] = None) -> np.ndarray: """ Resets the environment, and also the model (if defined). @@ -306,16 +322,15 @@ class CartPoleRGB(CartPoleEnv): options (dict, optional): Additional information to specify how the environment is reset. Defaults to None. Returns: - tuple[np.ndarray, Optional[dict]]: First observation and optionally info about the step. + np.ndarray: First observation. """ super().reset(seed=seed, options=options) # NOTE: Not using the output - info = self.get_info() rgb = self.render() rgb = self.resize(rgb) rgb = np.transpose(rgb, (2, 0, 1)) # HWC -> CHW - return rgb, info + return rgb, self.get_info() # ---------------------------------------------------------------------------------------------- def resize(self, rgb: np.ndarray) -> np.ndarray: @@ -361,8 +376,10 @@ class CartPoleRGB(CartPoleEnv): self.screen = pygame.display.set_mode( (self.screen_width, self.screen_height) ) - else: # self.render_mode == "rgb_array" + elif self.render_mode == "rgb_array": self.screen = pygame.Surface((self.screen_width, self.screen_height)) + else: + raise ValueError(f"Invalid render mode? 
'self.render_mode'") if self.clock is None: self.clock = pygame.time.Clock() diff --git a/rl/environments/mountaincar.py b/rl/environments/mountaincar.py index 0045a424ee6d74e4c342fef485d9d37f1ef6f857..15a0c2708fef8baf797e63bbb260b3fb5a180900 100644 --- a/rl/environments/mountaincar.py +++ b/rl/environments/mountaincar.py @@ -2,11 +2,11 @@ import argparse import math from typing import Optional -from gym.envs.classic_control.mountain_car import MountainCarEnv -from gym import spaces +from gymnasium.envs.classic_control.mountain_car import MountainCarEnv +from gymnasium import spaces import numpy as np import pygame, pygame.gfxdraw -import gym +import gymnasium as gym from rich import print, inspect from rl.environments.utils import EventEnv diff --git a/rl/environments/utils/atari_env.py b/rl/environments/utils/atari_env.py index d3f7105aef7d521cb89fe5d9736b217f6c6647f2..b898638853873268ee5109287bf4cb89da6cbd16 100644 --- a/rl/environments/utils/atari_env.py +++ b/rl/environments/utils/atari_env.py @@ -6,7 +6,7 @@ from typing import Optional import numpy as np import torch -from gym import spaces +from gymnasium import spaces from ale_py.env.gym import AtariEnv as SB3_AtariEnv from atariari.benchmark.wrapper import ram2label diff --git a/rl/environments/utils/event_env.py b/rl/environments/utils/event_env.py index ce60f065b8499b9fa296a47140323df5f061e0ca..42e64c4d33d5af84ee3744e755a54da6f53732a7 100644 --- a/rl/environments/utils/event_env.py +++ b/rl/environments/utils/event_env.py @@ -4,8 +4,8 @@ import time import argparse from typing import Optional -import gym -from gym import spaces +import gymnasium as gym +from gymnasium import spaces import numpy as np import torch from rich import print, inspect @@ -84,14 +84,13 @@ class EventEnv(gym.Env): tuple[np.ndarray, Optional[dict]]: First observation and optionally info about the step. """ super().reset(seed=seed, options=options) # NOTE: Not using the output - info = self.get_info() self.observe(wait=False) # Initialise ESIM; Need two frames to get a difference to generate events self.events = self.observe() self.iter = 0 - return torch.zeros(*self.shape, dtype=torch.uint8).numpy(), info + return torch.zeros(*self.shape, dtype=torch.uint8).numpy(), self.get_info() # ---------------------------------------------------------------------------------------------- def observe(self, rgb: Optional[np.ndarray] = None, wait: bool = True) -> Optional[torch.Tensor]: @@ -163,11 +162,9 @@ class EventEnv(gym.Env): Returns: dict: Key-value pairs for the step info. 
""" - rgb = self.resize(self.render()) if self.return_rgb else None return { "state": self.state, # Used later for bootstrap loss "updatedPolicy": int(self.updatedPolicy), - "rgb": rgb, } # ---------------------------------------------------------------------------------------------- diff --git a/rl/environments/utils/skip_cutscenes.py b/rl/environments/utils/skip_cutscenes.py index eea47052c4f6d7225198304907b3c9f3f53f08fe..c8114a83de7da9d4efe3b14c14ede25fb28fed83 100644 --- a/rl/environments/utils/skip_cutscenes.py +++ b/rl/environments/utils/skip_cutscenes.py @@ -1,4 +1,4 @@ -import gym +import gymnasium as gym import ale_py import numpy as np from stable_baselines3.common.type_aliases import GymObs, GymStepReturn diff --git a/rl/models/edenn.py b/rl/models/edenn.py index 42170868d2a35a56cad7cd8d14646cdbcc40cd48..1a8a1f26467336d4d6adc619cb240148ff989568 100755 --- a/rl/models/edenn.py +++ b/rl/models/edenn.py @@ -4,7 +4,7 @@ import math from typing import Optional import torch -from gym import spaces +from gymnasium import spaces from stable_baselines3.common.torch_layers import BaseFeaturesExtractor from rich import print, inspect @@ -188,7 +188,7 @@ class EDeNN(BaseFeaturesExtractor): # ================================================================================================== if __name__ == "__main__": - import gym + import gymnasium as gym from rich import print, inspect import rl diff --git a/rl/models/naturecnn.py b/rl/models/naturecnn.py index 4fcb5bdc60258e8e89c929144cc2d4843099638e..f2ef22f3c3d6a6f0994eb803019c882db63e7e9c 100755 --- a/rl/models/naturecnn.py +++ b/rl/models/naturecnn.py @@ -6,14 +6,14 @@ Changed colours to increase contrast. import argparse from stable_baselines3.common.torch_layers import NatureCNN as SB3_NatureCNN -import gym.spaces +import gymnasium.spaces from torch import nn import torch as th # ================================================================================================== class NatureCNN(SB3_NatureCNN): # ---------------------------------------------------------------------------------------------- - def __init__(self, observation_space: gym.spaces.Box, features_dim: int = 512): + def __init__(self, observation_space: gymnasium.spaces.Box, features_dim: int = 512): """ CNN from DQN nature paper: Mnih, Volodymyr, et al. @@ -22,7 +22,7 @@ class NatureCNN(SB3_NatureCNN): Subclassed to remove assertion on observation space. Args: - observation_space (gym.spaces.Box): Observation space. + observation_space (gymnasium.spaces.Box): Observation space. features_dim (int, optional): Number of features extracted. This corresponds to the number of unit for the last layer. Defaults to 512. 
""" super(SB3_NatureCNN, self).__init__(observation_space, features_dim) # Grandparent class @@ -62,7 +62,7 @@ class NatureCNN(SB3_NatureCNN): # ================================================================================================== if __name__ == "__main__": - import gym + import gymnasium as gym import torch from rich import print, inspect @@ -81,7 +81,7 @@ if __name__ == "__main__": if env_choice == "CartPoleRGB-v0": kwargs += dict(args=args) - # from gym.envs.classic_control.cartpole import CartPoleEnv + # from gymnasium.envs.classic_control.cartpole import CartPoleEnv env = gym.make( env_choice, **kwargs, diff --git a/rl/models/snn.py b/rl/models/snn.py index 5d866e04403bb0aa64fd28c8cf7cef2158e90c3c..82a7b3bebd5036d7cede63cd48cbc6a392ab6db8 100755 --- a/rl/models/snn.py +++ b/rl/models/snn.py @@ -4,7 +4,7 @@ import dataclasses from stable_baselines3.common.torch_layers import BaseFeaturesExtractor import torch -from gym import spaces +from gymnasium import spaces from slayerSNN.slayer import spikeLayer from rich import print, inspect @@ -147,7 +147,7 @@ class SNN(BaseFeaturesExtractor): # ================================================================================================== if __name__ == "__main__": - import gym + import gymnasium as gym from rich import print, inspect import rl diff --git a/rl/models/utils/ppo.py b/rl/models/utils/ppo.py index 52841f657fa7162c6d1be15a65b5f2d670af91db..c4fc3cd3876e091dedf0c33dd270d0ae6a44fc2b 100644 --- a/rl/models/utils/ppo.py +++ b/rl/models/utils/ppo.py @@ -5,7 +5,7 @@ from typing import Optional, Dict, Any import numpy as np import torch import torch.nn.functional as F -import gym +import gymnasium as gym from stable_baselines3 import PPO as SB3_PPO from stable_baselines3.common.callbacks import BaseCallback from stable_baselines3.common.vec_env import VecEnv diff --git a/rl/objective.py b/rl/objective.py index c5379ee08d0f02f39ed6a84dc79eb2a30a9863fa..543763046be2ec884f3aa9439eebcc9f730a5126 100644 --- a/rl/objective.py +++ b/rl/objective.py @@ -5,7 +5,7 @@ from typing import Optional import optuna.trial import optuna.exceptions -import gym +import gymnasium as gym import torch import wandb from wandb.integration.sb3 import WandbCallback diff --git a/tests/framestack.py b/tests/framestack.py index eddd174d95a3914ecb74f323e1e8b02e30fc4aa8..35ba55c5d93b4f74175cd26ed8475ac47dc20b1b 100755 --- a/tests/framestack.py +++ b/tests/framestack.py @@ -1,11 +1,11 @@ #!/usr/bin/env python3 -import gym.spaces.box +import gymnasium.spaces.box import stable_baselines3.common.vec_env.stacked_observations import torch # ================================================================================================== def main(): - os = gym.spaces.box.Box(0, 1, (3, 480, 640)) + os = gymnasium.spaces.box.Box(0, 1, (3, 480, 640)) # env = gym.make("CartPole-v1") # env.reset() # obs, _, _, _, info = env.step(0) @@ -21,8 +21,8 @@ def main(): # ================================================================================================== if __name__ == "__main__": # main() - import gym - import gym.wrappers.frame_stack + import gymnasium as gym + import gymnasium.wrappers.frame_stack env = gym.make("CarRacing-v2") def get_shape(): diff --git a/tests/pong.py b/tests/pong.py index b7e0b0d9ed53136b836fdd397f0b4980963a0239..6c71e7cc8ebced0af4571cd8d9228582392f577a 100755 --- a/tests/pong.py +++ b/tests/pong.py @@ -12,10 +12,10 @@ from stable_baselines3.common.torch_layers import NatureCNN from stable_baselines3.common.env_util import make_vec_env 
diff --git a/tests/pong.py b/tests/pong.py
index b7e0b0d9ed53136b836fdd397f0b4980963a0239..6c71e7cc8ebced0af4571cd8d9228582392f577a 100755
--- a/tests/pong.py
+++ b/tests/pong.py
@@ -12,10 +12,10 @@ from stable_baselines3.common.torch_layers import NatureCNN
 from stable_baselines3.common.env_util import make_vec_env
 from stable_baselines3.common.atari_wrappers import NoopResetEnv, MaxAndSkipEnv, EpisodicLifeEnv, WarpFrame, ClipRewardEnv
 from stable_baselines3.common.callbacks import EvalCallback
-import gym
+import gymnasium as gym
 import numpy as np
 import cv2
-from gym import spaces
+from gymnasium import spaces
 from rich import print, inspect
 
 import rl
diff --git a/tests/visualise.py b/tests/visualise.py
index c53717e5ff26bf0d7812889dd5307c9f6300c45d..eae95f6fc177774967292a4dfabaf39294dccd54 100755
--- a/tests/visualise.py
+++ b/tests/visualise.py
@@ -6,7 +6,7 @@ import time
 from typing import Optional
 
 import torch
-import gym
+import gymnasium as gym
 import cv2
 import numpy as np
 from tqdm import tqdm
diff --git a/tools/vis_gradients.py b/tools/vis_gradients.py
index 70f4f490591e4dbd7f48faec9a7108e9696b213c..1c2cf2400a71d9b1bf4f8575e74f63cf3fe46e56 100755
--- a/tools/vis_gradients.py
+++ b/tools/vis_gradients.py
@@ -3,7 +3,7 @@ import argparse
 from pathlib import Path
 
 import torchviz
-import gym
+import gymnasium as gym
 
 import rl
 from rl.models import Estimator, PPO_mod, A2C_mod
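
Migration note: the substantive change throughout this patch is the Gymnasium API, not just the import rename. A minimal sketch of the contract the patched environments now follow (the built-in `CartPole-v1` is used for illustration; only the patched environments set `failReason`):

	import gymnasium as gym

	env = gym.make("CartPole-v1")
	obs, info = env.reset(seed=0)  # reset() returns (observation, info)
	terminated = truncated = False
	while not (terminated or truncated):
		# step() returns five values: gym's old `done` flag is split into
		# `terminated` (reached a terminal state) and `truncated` (e.g. time limit)
		obs, reward, terminated, truncated, info = env.step(env.action_space.sample())
	print(info.get("failReason"))  # set by the patched envs; absent for built-in CartPole
	env.close()
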