diff --git a/rl/models/utils/actor_critic_policy.py b/rl/models/utils/actor_critic_policy.py index ef45f99fd41ceb8542e859f3aed87aba18b243eb..5def0baa81fa441720055ee758784d99b97a969f 100644 --- a/rl/models/utils/actor_critic_policy.py +++ b/rl/models/utils/actor_critic_policy.py @@ -3,6 +3,8 @@ from typing import Optional import torch from stable_baselines3.common.policies import ActorCriticPolicy as SB3_ACP from stable_baselines3.common.preprocessing import preprocess_obs +from stable_baselines3.common.distributions import Distribution +from stable_baselines3.common.torch_layers import BaseFeaturesExtractor from rich import print, inspect import rl.models