diff --git a/gym_wrapper.py b/gym_wrapper.py new file mode 100644 index 0000000000000000000000000000000000000000..2f1dd5e83e0f4eecb67d1ac01da3486278e8c10e --- /dev/null +++ b/gym_wrapper.py @@ -0,0 +1,34 @@ +import gym + +class GymWrapper: + def __init__(self, env): + self.env = env + self.new_step_api = self._check_new_step_api() + + def _check_new_step_api(self): + # Check if the environment uses the new step API + try: + reset_result = self.env.reset() + return isinstance(reset_result, tuple) and len(reset_result) == 2 + except Exception: + return False + + def reset(self): + if self.new_step_api: + return self.env.reset()[0] + else: + return self.env.reset() + + def step(self, action): + if self.new_step_api: + next_state, reward, terminated, truncated, info = self.env.step(action) + done = terminated or truncated + return next_state, reward, done, info + else: + return self.env.step(action) + + def render(self): + return self.env.render() + + def close(self): + self.env.close() \ No newline at end of file