From a51b9f17c9a5eadf6ca45f188898fd79991beb92 Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Fri, 26 Sep 2025 11:31:00 +0000 Subject: [PATCH 1/4] Initial plan From e3f21902cfe7c5208417f102383713c4bfd43bbb Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Fri, 26 Sep 2025 11:45:38 +0000 Subject: [PATCH 2/4] Add check_vecenv function with comprehensive tests and documentation Co-authored-by: araffin <1973948+araffin@users.noreply.github.com> --- docs/misc/changelog.rst | 1 + stable_baselines3/common/env_checker.py | 255 ++++++++++++++++++++++++ tests/test_env_checker.py | 192 +++++++++++++++++- 3 files changed, 447 insertions(+), 1 deletion(-) diff --git a/docs/misc/changelog.rst b/docs/misc/changelog.rst index 73c2ab202..a173ff40d 100644 --- a/docs/misc/changelog.rst +++ b/docs/misc/changelog.rst @@ -12,6 +12,7 @@ Breaking Changes: New Features: ^^^^^^^^^^^^^ - ``RolloutBuffer`` and ``DictRolloutBuffer`` now uses the actual observation / action space ``dtype`` (instead of float32), this should save memory (@Trenza1ore) +- Added ``check_vecenv()`` function to check that a VecEnv follows the VecEnv API and is compatible with Stable-Baselines3 (@copilot) Bug Fixes: ^^^^^^^^^^ diff --git a/stable_baselines3/common/env_checker.py b/stable_baselines3/common/env_checker.py index 0e9bd05ff..f1d0d48c3 100644 --- a/stable_baselines3/common/env_checker.py +++ b/stable_baselines3/common/env_checker.py @@ -7,6 +7,7 @@ from stable_baselines3.common.preprocessing import check_for_nested_spaces, is_image_space_channels_first from stable_baselines3.common.vec_env import DummyVecEnv, VecCheckNan +from stable_baselines3.common.vec_env.base_vec_env import VecEnv def _is_oneof_space(space: spaces.Space) -> bool: @@ -537,3 +538,257 @@ def check_env(env: gym.Env, warn: bool = True, skip_render_check: bool = True) - _check_nan(env) except NotImplementedError: pass + + +def _check_vecenv_spaces(vec_env: VecEnv) -> None: + """ + Check that the VecEnv has valid observation and action spaces. + """ + assert hasattr(vec_env, "observation_space"), "VecEnv must have an observation_space attribute" + assert hasattr(vec_env, "action_space"), "VecEnv must have an action_space attribute" + assert hasattr(vec_env, "num_envs"), "VecEnv must have a num_envs attribute" + + assert isinstance( + vec_env.observation_space, spaces.Space + ), "The observation space must inherit from gymnasium.spaces" + assert isinstance(vec_env.action_space, spaces.Space), "The action space must inherit from gymnasium.spaces" + assert isinstance(vec_env.num_envs, int) and vec_env.num_envs > 0, "num_envs must be a positive integer" + + +def _check_vecenv_reset(vec_env: VecEnv) -> Any: + """ + Check that VecEnv reset method works correctly and returns properly shaped observations. + """ + try: + obs = vec_env.reset() + except Exception as e: + raise RuntimeError(f"VecEnv reset() failed: {e}") from e + + # Check observation shape matches expected vectorized shape + if isinstance(vec_env.observation_space, spaces.Box): + assert isinstance(obs, np.ndarray), f"For Box observation space, reset() must return np.ndarray, got {type(obs)}" + expected_shape = (vec_env.num_envs,) + vec_env.observation_space.shape + assert obs.shape == expected_shape, ( + f"Expected observation shape {expected_shape}, got {obs.shape}. " + f"VecEnv observations should have batch dimension first." + ) + elif isinstance(vec_env.observation_space, spaces.Dict): + assert isinstance(obs, dict), f"For Dict observation space, reset() must return dict, got {type(obs)}" + for key, space in vec_env.observation_space.spaces.items(): + assert key in obs, f"Missing key '{key}' in observation dict" + if isinstance(space, spaces.Box): + expected_shape = (vec_env.num_envs,) + space.shape + assert obs[key].shape == expected_shape, ( + f"Expected observation['{key}'] shape {expected_shape}, got {obs[key].shape}" + ) + elif isinstance(vec_env.observation_space, spaces.Discrete): + assert isinstance(obs, np.ndarray), f"For Discrete observation space, reset() must return np.ndarray, got {type(obs)}" + expected_shape = (vec_env.num_envs,) + assert obs.shape == expected_shape, f"Expected observation shape {expected_shape}, got {obs.shape}" + + return obs + + +def _check_vecenv_step(vec_env: VecEnv, obs: Any) -> None: + """ + Check that VecEnv step method works correctly and returns properly shaped values. + """ + # Generate valid actions + if isinstance(vec_env.action_space, spaces.Box): + actions = np.array([vec_env.action_space.sample() for _ in range(vec_env.num_envs)]) + elif isinstance(vec_env.action_space, spaces.Discrete): + actions = np.array([vec_env.action_space.sample() for _ in range(vec_env.num_envs)]) + elif isinstance(vec_env.action_space, spaces.MultiDiscrete): + actions = np.array([vec_env.action_space.sample() for _ in range(vec_env.num_envs)]) + elif isinstance(vec_env.action_space, spaces.MultiBinary): + actions = np.array([vec_env.action_space.sample() for _ in range(vec_env.num_envs)]) + else: + actions = np.array([vec_env.action_space.sample() for _ in range(vec_env.num_envs)]) + + try: + obs, rewards, dones, infos = vec_env.step(actions) + except Exception as e: + raise RuntimeError(f"VecEnv step() failed: {e}") from e + + # Check rewards + assert isinstance(rewards, np.ndarray), f"step() must return rewards as np.ndarray, got {type(rewards)}" + assert rewards.shape == (vec_env.num_envs,), f"Expected rewards shape ({vec_env.num_envs},), got {rewards.shape}" + + # Check dones + assert isinstance(dones, np.ndarray), f"step() must return dones as np.ndarray, got {type(dones)}" + assert dones.shape == (vec_env.num_envs,), f"Expected dones shape ({vec_env.num_envs},), got {dones.shape}" + assert dones.dtype == bool, f"dones must have dtype bool, got {dones.dtype}" + + # Check infos + assert isinstance(infos, (list, tuple)), f"step() must return infos as list or tuple, got {type(infos)}" + assert len(infos) == vec_env.num_envs, f"Expected infos length {vec_env.num_envs}, got {len(infos)}" + for i, info in enumerate(infos): + assert isinstance(info, dict), f"infos[{i}] must be dict, got {type(info)}" + + # Check observation shape consistency (similar to reset) + if isinstance(vec_env.observation_space, spaces.Box): + assert isinstance(obs, np.ndarray), f"For Box observation space, step() must return np.ndarray, got {type(obs)}" + expected_shape = (vec_env.num_envs,) + vec_env.observation_space.shape + assert obs.shape == expected_shape, ( + f"Expected observation shape {expected_shape}, got {obs.shape}. " + f"VecEnv observations should have batch dimension first." + ) + elif isinstance(vec_env.observation_space, spaces.Dict): + assert isinstance(obs, dict), f"For Dict observation space, step() must return dict, got {type(obs)}" + for key, space in vec_env.observation_space.spaces.items(): + assert key in obs, f"Missing key '{key}' in observation dict" + if isinstance(space, spaces.Box): + expected_shape = (vec_env.num_envs,) + space.shape + assert obs[key].shape == expected_shape, ( + f"Expected observation['{key}'] shape {expected_shape}, got {obs[key].shape}" + ) + + +def _check_vecenv_unsupported_spaces(observation_space: spaces.Space, action_space: spaces.Space) -> bool: + """ + Emit warnings when the observation space or action space used is not supported by Stable-Baselines + for VecEnv. This is a VecEnv-specific version of _check_unsupported_spaces. + + :return: True if return value tests should be skipped. + """ + should_skip = graph_space = sequence_space = False + if isinstance(observation_space, spaces.Dict): + nested_dict = False + for key, space in observation_space.spaces.items(): + if isinstance(space, spaces.Dict): + nested_dict = True + elif isinstance(space, spaces.Graph): + graph_space = True + elif isinstance(space, spaces.Sequence): + sequence_space = True + _check_non_zero_start(space, "observation", key) + + if nested_dict: + warnings.warn( + "Nested observation spaces are not supported by Stable Baselines3 " + "(Dict spaces inside Dict space). " + "You should flatten it to have only one level of keys." + "For example, `dict(space1=dict(space2=Box(), space3=Box()), spaces4=Discrete())` " + "is not supported but `dict(space2=Box(), spaces3=Box(), spaces4=Discrete())` is." + ) + + if isinstance(observation_space, spaces.MultiDiscrete) and len(observation_space.nvec.shape) > 1: + warnings.warn( + f"The MultiDiscrete observation space uses a multidimensional array {observation_space.nvec} " + "which is currently not supported by Stable-Baselines3. " + "Please convert it to a 1D array using a wrapper: " + "https://github.com/DLR-RM/stable-baselines3/issues/1836." + ) + + if isinstance(observation_space, spaces.Tuple): + warnings.warn( + "The observation space is a Tuple, " + "this is currently not supported by Stable Baselines3. " + "However, you can convert it to a Dict observation space " + "(cf. https://gymnasium.farama.org/api/spaces/composite/#dict). " + "which is supported by SB3." + ) + # Check for Sequence spaces inside Tuple + for space in observation_space.spaces: + if isinstance(space, spaces.Sequence): + sequence_space = True + elif isinstance(space, spaces.Graph): + graph_space = True + + # Check for Sequence spaces inside OneOf + if _is_oneof_space(observation_space): + warnings.warn( + "OneOf observation space is not supported by Stable-Baselines3. " + "Note: The checks for returned values are skipped." + ) + should_skip = True + + _check_non_zero_start(observation_space, "observation") + + if isinstance(observation_space, spaces.Sequence) or sequence_space: + warnings.warn( + "Sequence observation space is not supported by Stable-Baselines3. " + "You can pad your observation to have a fixed size instead.\n" + "Note: The checks for returned values are skipped." + ) + should_skip = True + + if isinstance(observation_space, spaces.Graph) or graph_space: + warnings.warn( + "Graph observation space is not supported by Stable-Baselines3. " + "Note: The checks for returned values are skipped." + ) + should_skip = True + + _check_non_zero_start(action_space, "action") + + if not _is_numpy_array_space(action_space): + warnings.warn( + "The action space is not based off a numpy array. Typically this means it's either a Dict or Tuple space. " + "This type of action space is currently not supported by Stable Baselines 3. You should try to flatten the " + "action using a wrapper." + ) + return should_skip + + +def check_vecenv(vec_env: VecEnv, warn: bool = True) -> None: + """ + Check that a VecEnv follows the VecEnv API and is compatible with Stable-Baselines3. + + This checker verifies that: + - The VecEnv has proper observation_space, action_space, and num_envs attributes + - The reset() method returns observations with correct vectorized shape + - The step() method returns observations, rewards, dones, and infos with correct shapes + - All return values have the expected types and dimensions + + :param vec_env: The vectorized environment to check + :param warn: Whether to output additional warnings mainly related to + the interaction with Stable Baselines + """ + assert isinstance(vec_env, VecEnv), ( + "Your environment must inherit from stable_baselines3.common.vec_env.VecEnv" + ) + + # ============= Check basic VecEnv attributes ================ + _check_vecenv_spaces(vec_env) + + # Define aliases for convenience + observation_space = vec_env.observation_space + action_space = vec_env.action_space + + # Warn the user if needed - reuse existing space checking logic + if warn: + should_skip = _check_vecenv_unsupported_spaces(observation_space, action_space) + if should_skip: + warnings.warn("VecEnv contains unsupported spaces, skipping some checks") + return + + obs_spaces = observation_space.spaces if isinstance(observation_space, spaces.Dict) else {"": observation_space} + for key, space in obs_spaces.items(): + if isinstance(space, spaces.Box): + _check_box_obs(space, key) + + # Check for the action space + if isinstance(action_space, spaces.Box) and ( + np.any(np.abs(action_space.low) != np.abs(action_space.high)) + or np.any(action_space.low != -1) + or np.any(action_space.high != 1) + ): + warnings.warn( + "We recommend you to use a symmetric and normalized Box action space (range=[-1, 1]) " + "cf. https://stable-baselines3.readthedocs.io/en/master/guide/rl_tips.html" + ) + + if isinstance(action_space, spaces.Box): + assert np.all( + np.isfinite(np.array([action_space.low, action_space.high])) + ), "Continuous action space must have a finite lower and upper bound" + + if isinstance(action_space, spaces.Box) and action_space.dtype != np.dtype(np.float32): + warnings.warn( + f"Your action space has dtype {action_space.dtype}, we recommend using np.float32 to avoid cast errors." + ) + + # ============ Check the VecEnv methods =============== + obs = _check_vecenv_reset(vec_env) + _check_vecenv_step(vec_env, obs) diff --git a/tests/test_env_checker.py b/tests/test_env_checker.py index 3b0fd179e..eb844c9de 100644 --- a/tests/test_env_checker.py +++ b/tests/test_env_checker.py @@ -5,7 +5,8 @@ import pytest from gymnasium import spaces -from stable_baselines3.common.env_checker import check_env +from stable_baselines3.common.env_checker import check_env, check_vecenv +from stable_baselines3.common.vec_env import DummyVecEnv, VecEnv class ActionDictTestEnv(gym.Env): @@ -271,3 +272,192 @@ def test_check_env_oneof(): with pytest.warns(Warning, match=r"OneOf.*not supported"): check_env(env, warn=True) + + +class BrokenVecEnv: + """A broken VecEnv that doesn't inherit from VecEnv.""" + + def __init__(self): + self.num_envs = 2 + self.observation_space = spaces.Box(low=-1.0, high=1.0, shape=(3,)) + self.action_space = spaces.Discrete(2) + + +class MissingAttributeVecEnv(VecEnv): + """A VecEnv missing required attributes.""" + + def __init__(self): + # Intentionally not calling super().__init__ + pass + + def reset(self): + pass + + def step_async(self, actions): + pass + + def step_wait(self): + pass + + def close(self): + pass + + def get_attr(self, attr_name, indices=None): + pass + + def set_attr(self, attr_name, value, indices=None): + pass + + def env_method(self, method_name, *method_args, indices=None, **method_kwargs): + pass + + def env_is_wrapped(self, wrapper_class, indices=None): + return [False] * getattr(self, 'num_envs', 1) + + +class WrongShapeVecEnv(VecEnv): + """A VecEnv that returns wrong-shaped observations.""" + + def __init__(self): + super().__init__( + num_envs=2, + observation_space=spaces.Box(low=-1.0, high=1.0, shape=(3,)), + action_space=spaces.Discrete(2) + ) + + def reset(self): + # Return wrong shape (should be (2, 3) but return (3,)) + return np.zeros(3) + + def step_async(self, actions): + pass + + def step_wait(self): + # Return wrong shapes + obs = np.zeros(3) # Should be (2, 3) + rewards = np.zeros(3) # Should be (2,) + dones = np.zeros(3) # Should be (2,) + infos = [{}] # Should be [{}, {}] - list or tuple with 2 elements + return obs, rewards, dones, infos + + def close(self): + pass + + def get_attr(self, attr_name, indices=None): + return [None] * self.num_envs + + def set_attr(self, attr_name, value, indices=None): + pass + + def env_method(self, method_name, *method_args, indices=None, **method_kwargs): + return [None] * self.num_envs + + def env_is_wrapped(self, wrapper_class, indices=None): + return [False] * self.num_envs + + +def test_check_vecenv_basic(): + """Test basic VecEnv checker functionality with a working VecEnv.""" + + def make_env(): + return gym.make('CartPole-v1') + + vec_env = DummyVecEnv([make_env for _ in range(2)]) + + try: + # Should pass without issues + check_vecenv(vec_env, warn=True) + finally: + vec_env.close() + + +def test_check_vecenv_not_vecenv(): + """Test that check_vecenv raises error for non-VecEnv objects.""" + + broken_env = BrokenVecEnv() + + with pytest.raises(AssertionError, match="must inherit from.*VecEnv"): + check_vecenv(broken_env) + + +def test_check_vecenv_missing_attributes(): + """Test that check_vecenv raises error for VecEnv with missing attributes.""" + + broken_env = MissingAttributeVecEnv() + + with pytest.raises(AssertionError, match="must have.*attribute"): + check_vecenv(broken_env) + + +def test_check_vecenv_wrong_shapes(): + """Test that check_vecenv catches wrong-shaped observations and returns.""" + + broken_env = WrongShapeVecEnv() + + try: + with pytest.raises(AssertionError, match="Expected observation shape"): + check_vecenv(broken_env) + finally: + broken_env.close() + + +def test_check_vecenv_dict_space(): + """Test VecEnv checker with Dict observation space.""" + + class DictEnv(gym.Env): + def __init__(self): + self.observation_space = spaces.Dict({ + 'observation': spaces.Box(low=-1.0, high=1.0, shape=(4,)), + 'achieved_goal': spaces.Box(low=-1.0, high=1.0, shape=(2,)), + }) + self.action_space = spaces.Discrete(2) + + def reset(self, *, seed=None, options=None): + return { + 'observation': np.zeros(4), + 'achieved_goal': np.zeros(2), + }, {} + + def step(self, action): + obs = { + 'observation': np.zeros(4), + 'achieved_goal': np.zeros(2), + } + return obs, 0.0, False, False, {} + + def make_dict_env(): + return DictEnv() + + vec_env = DummyVecEnv([make_dict_env for _ in range(2)]) + + try: + check_vecenv(vec_env, warn=True) + finally: + vec_env.close() + + +def test_check_vecenv_warnings(): + """Test that check_vecenv emits appropriate warnings.""" + + class BoxActionEnv(gym.Env): + def __init__(self): + self.observation_space = spaces.Box(low=-1.0, high=1.0, shape=(4,)) + # Asymmetric action space should trigger warning + self.action_space = spaces.Box(low=-2.0, high=3.0, shape=(2,)) + + def reset(self, *, seed=None, options=None): + return np.zeros(4), {} + + def step(self, action): + return np.zeros(4), 0.0, False, False, {} + + def make_box_env(): + return BoxActionEnv() + + vec_env = DummyVecEnv([make_box_env for _ in range(2)]) + + try: + with pytest.warns(UserWarning, match="symmetric and normalized Box action space"): + check_vecenv(vec_env, warn=True) + finally: + vec_env.close() From 005ab7f56d941e0c47d171532b85d116bd6d9a49 Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Fri, 26 Sep 2025 12:27:46 +0000 Subject: [PATCH 3/4] Refactor VecEnv checker: move to separate files and add documentation Co-authored-by: araffin <1973948+araffin@users.noreply.github.com> --- docs/guide/vec_envs.rst | 53 +++ stable_baselines3/common/env_checker.py | 255 -------------- stable_baselines3/common/vec_env/__init__.py | 4 + stable_baselines3/common/vec_env_checker.py | 352 +++++++++++++++++++ tests/test_env_checker.py | 192 +--------- tests/test_vec_env_checker.py | 196 +++++++++++ 6 files changed, 606 insertions(+), 446 deletions(-) create mode 100644 stable_baselines3/common/vec_env_checker.py create mode 100644 tests/test_vec_env_checker.py diff --git a/docs/guide/vec_envs.rst b/docs/guide/vec_envs.rst index da009e2c7..d968e67f1 100644 --- a/docs/guide/vec_envs.rst +++ b/docs/guide/vec_envs.rst @@ -183,6 +183,59 @@ This callback can then be used to safely modify environment attributes during tr it calls the environment setter method. +Checking VecEnv Implementation +----------------------------- + +When implementing custom vectorized environments, it's easy to make mistakes that can lead to hard-to-debug issues. +To help with this, Stable-Baselines3 provides a ``check_vecenv`` function that validates your VecEnv implementation +and checks for common issues. + +The ``check_vecenv`` function verifies: + +* The VecEnv properly inherits from ``stable_baselines3.common.vec_env.VecEnv`` +* Required attributes (``num_envs``, ``observation_space``, ``action_space``) are present and valid +* The ``reset()`` method returns observations with the correct vectorized shape (batch dimension first) +* The ``step()`` method returns properly shaped observations, rewards, dones, and infos +* All return values have the expected types and dimensions +* Compatibility with Stable-Baselines3 algorithms + +**Usage:** + +.. code-block:: python + + from stable_baselines3.common.vec_env import DummyVecEnv + from stable_baselines3.common.vec_env_checker import check_vecenv + import gymnasium as gym + + def make_env(): + return gym.make('CartPole-v1') + + # Create your VecEnv + vec_env = DummyVecEnv([make_env for _ in range(4)]) + + # Check the VecEnv implementation + check_vecenv(vec_env, warn=True) + + vec_env.close() + +**When to use:** + +* When implementing a custom VecEnv class +* When debugging issues with vectorized environments +* When contributing new VecEnv implementations to ensure they follow the API +* As a sanity check before training to catch potential issues early + +**Note:** Similar to ``check_env`` for single environments, ``check_vecenv`` is particularly useful during development +and debugging. It helps catch common vectorization mistakes like incorrect batch dimensions, wrong return types, or +missing required methods. + + +VecEnv Checker +~~~~~~~~~~~~~~ + +.. autofunction:: stable_baselines3.common.vec_env_checker.check_vecenv + + Vectorized Environments Wrappers -------------------------------- diff --git a/stable_baselines3/common/env_checker.py b/stable_baselines3/common/env_checker.py index f1d0d48c3..0e9bd05ff 100644 --- a/stable_baselines3/common/env_checker.py +++ b/stable_baselines3/common/env_checker.py @@ -7,7 +7,6 @@ from stable_baselines3.common.preprocessing import check_for_nested_spaces, is_image_space_channels_first from stable_baselines3.common.vec_env import DummyVecEnv, VecCheckNan -from stable_baselines3.common.vec_env.base_vec_env import VecEnv def _is_oneof_space(space: spaces.Space) -> bool: @@ -538,257 +537,3 @@ def check_env(env: gym.Env, warn: bool = True, skip_render_check: bool = True) - _check_nan(env) except NotImplementedError: pass - - -def _check_vecenv_spaces(vec_env: VecEnv) -> None: - """ - Check that the VecEnv has valid observation and action spaces. - """ - assert hasattr(vec_env, "observation_space"), "VecEnv must have an observation_space attribute" - assert hasattr(vec_env, "action_space"), "VecEnv must have an action_space attribute" - assert hasattr(vec_env, "num_envs"), "VecEnv must have a num_envs attribute" - - assert isinstance( - vec_env.observation_space, spaces.Space - ), "The observation space must inherit from gymnasium.spaces" - assert isinstance(vec_env.action_space, spaces.Space), "The action space must inherit from gymnasium.spaces" - assert isinstance(vec_env.num_envs, int) and vec_env.num_envs > 0, "num_envs must be a positive integer" - - -def _check_vecenv_reset(vec_env: VecEnv) -> Any: - """ - Check that VecEnv reset method works correctly and returns properly shaped observations. - """ - try: - obs = vec_env.reset() - except Exception as e: - raise RuntimeError(f"VecEnv reset() failed: {e}") from e - - # Check observation shape matches expected vectorized shape - if isinstance(vec_env.observation_space, spaces.Box): - assert isinstance(obs, np.ndarray), f"For Box observation space, reset() must return np.ndarray, got {type(obs)}" - expected_shape = (vec_env.num_envs,) + vec_env.observation_space.shape - assert obs.shape == expected_shape, ( - f"Expected observation shape {expected_shape}, got {obs.shape}. " - f"VecEnv observations should have batch dimension first." - ) - elif isinstance(vec_env.observation_space, spaces.Dict): - assert isinstance(obs, dict), f"For Dict observation space, reset() must return dict, got {type(obs)}" - for key, space in vec_env.observation_space.spaces.items(): - assert key in obs, f"Missing key '{key}' in observation dict" - if isinstance(space, spaces.Box): - expected_shape = (vec_env.num_envs,) + space.shape - assert obs[key].shape == expected_shape, ( - f"Expected observation['{key}'] shape {expected_shape}, got {obs[key].shape}" - ) - elif isinstance(vec_env.observation_space, spaces.Discrete): - assert isinstance(obs, np.ndarray), f"For Discrete observation space, reset() must return np.ndarray, got {type(obs)}" - expected_shape = (vec_env.num_envs,) - assert obs.shape == expected_shape, f"Expected observation shape {expected_shape}, got {obs.shape}" - - return obs - - -def _check_vecenv_step(vec_env: VecEnv, obs: Any) -> None: - """ - Check that VecEnv step method works correctly and returns properly shaped values. - """ - # Generate valid actions - if isinstance(vec_env.action_space, spaces.Box): - actions = np.array([vec_env.action_space.sample() for _ in range(vec_env.num_envs)]) - elif isinstance(vec_env.action_space, spaces.Discrete): - actions = np.array([vec_env.action_space.sample() for _ in range(vec_env.num_envs)]) - elif isinstance(vec_env.action_space, spaces.MultiDiscrete): - actions = np.array([vec_env.action_space.sample() for _ in range(vec_env.num_envs)]) - elif isinstance(vec_env.action_space, spaces.MultiBinary): - actions = np.array([vec_env.action_space.sample() for _ in range(vec_env.num_envs)]) - else: - actions = np.array([vec_env.action_space.sample() for _ in range(vec_env.num_envs)]) - - try: - obs, rewards, dones, infos = vec_env.step(actions) - except Exception as e: - raise RuntimeError(f"VecEnv step() failed: {e}") from e - - # Check rewards - assert isinstance(rewards, np.ndarray), f"step() must return rewards as np.ndarray, got {type(rewards)}" - assert rewards.shape == (vec_env.num_envs,), f"Expected rewards shape ({vec_env.num_envs},), got {rewards.shape}" - - # Check dones - assert isinstance(dones, np.ndarray), f"step() must return dones as np.ndarray, got {type(dones)}" - assert dones.shape == (vec_env.num_envs,), f"Expected dones shape ({vec_env.num_envs},), got {dones.shape}" - assert dones.dtype == bool, f"dones must have dtype bool, got {dones.dtype}" - - # Check infos - assert isinstance(infos, (list, tuple)), f"step() must return infos as list or tuple, got {type(infos)}" - assert len(infos) == vec_env.num_envs, f"Expected infos length {vec_env.num_envs}, got {len(infos)}" - for i, info in enumerate(infos): - assert isinstance(info, dict), f"infos[{i}] must be dict, got {type(info)}" - - # Check observation shape consistency (similar to reset) - if isinstance(vec_env.observation_space, spaces.Box): - assert isinstance(obs, np.ndarray), f"For Box observation space, step() must return np.ndarray, got {type(obs)}" - expected_shape = (vec_env.num_envs,) + vec_env.observation_space.shape - assert obs.shape == expected_shape, ( - f"Expected observation shape {expected_shape}, got {obs.shape}. " - f"VecEnv observations should have batch dimension first." - ) - elif isinstance(vec_env.observation_space, spaces.Dict): - assert isinstance(obs, dict), f"For Dict observation space, step() must return dict, got {type(obs)}" - for key, space in vec_env.observation_space.spaces.items(): - assert key in obs, f"Missing key '{key}' in observation dict" - if isinstance(space, spaces.Box): - expected_shape = (vec_env.num_envs,) + space.shape - assert obs[key].shape == expected_shape, ( - f"Expected observation['{key}'] shape {expected_shape}, got {obs[key].shape}" - ) - - -def _check_vecenv_unsupported_spaces(observation_space: spaces.Space, action_space: spaces.Space) -> bool: - """ - Emit warnings when the observation space or action space used is not supported by Stable-Baselines - for VecEnv. This is a VecEnv-specific version of _check_unsupported_spaces. - - :return: True if return value tests should be skipped. - """ - should_skip = graph_space = sequence_space = False - if isinstance(observation_space, spaces.Dict): - nested_dict = False - for key, space in observation_space.spaces.items(): - if isinstance(space, spaces.Dict): - nested_dict = True - elif isinstance(space, spaces.Graph): - graph_space = True - elif isinstance(space, spaces.Sequence): - sequence_space = True - _check_non_zero_start(space, "observation", key) - - if nested_dict: - warnings.warn( - "Nested observation spaces are not supported by Stable Baselines3 " - "(Dict spaces inside Dict space). " - "You should flatten it to have only one level of keys." - "For example, `dict(space1=dict(space2=Box(), space3=Box()), spaces4=Discrete())` " - "is not supported but `dict(space2=Box(), spaces3=Box(), spaces4=Discrete())` is." - ) - - if isinstance(observation_space, spaces.MultiDiscrete) and len(observation_space.nvec.shape) > 1: - warnings.warn( - f"The MultiDiscrete observation space uses a multidimensional array {observation_space.nvec} " - "which is currently not supported by Stable-Baselines3. " - "Please convert it to a 1D array using a wrapper: " - "https://github.com/DLR-RM/stable-baselines3/issues/1836." - ) - - if isinstance(observation_space, spaces.Tuple): - warnings.warn( - "The observation space is a Tuple, " - "this is currently not supported by Stable Baselines3. " - "However, you can convert it to a Dict observation space " - "(cf. https://gymnasium.farama.org/api/spaces/composite/#dict). " - "which is supported by SB3." - ) - # Check for Sequence spaces inside Tuple - for space in observation_space.spaces: - if isinstance(space, spaces.Sequence): - sequence_space = True - elif isinstance(space, spaces.Graph): - graph_space = True - - # Check for Sequence spaces inside OneOf - if _is_oneof_space(observation_space): - warnings.warn( - "OneOf observation space is not supported by Stable-Baselines3. " - "Note: The checks for returned values are skipped." - ) - should_skip = True - - _check_non_zero_start(observation_space, "observation") - - if isinstance(observation_space, spaces.Sequence) or sequence_space: - warnings.warn( - "Sequence observation space is not supported by Stable-Baselines3. " - "You can pad your observation to have a fixed size instead.\n" - "Note: The checks for returned values are skipped." - ) - should_skip = True - - if isinstance(observation_space, spaces.Graph) or graph_space: - warnings.warn( - "Graph observation space is not supported by Stable-Baselines3. " - "Note: The checks for returned values are skipped." - ) - should_skip = True - - _check_non_zero_start(action_space, "action") - - if not _is_numpy_array_space(action_space): - warnings.warn( - "The action space is not based off a numpy array. Typically this means it's either a Dict or Tuple space. " - "This type of action space is currently not supported by Stable Baselines 3. You should try to flatten the " - "action using a wrapper." - ) - return should_skip - - -def check_vecenv(vec_env: VecEnv, warn: bool = True) -> None: - """ - Check that a VecEnv follows the VecEnv API and is compatible with Stable-Baselines3. - - This checker verifies that: - - The VecEnv has proper observation_space, action_space, and num_envs attributes - - The reset() method returns observations with correct vectorized shape - - The step() method returns observations, rewards, dones, and infos with correct shapes - - All return values have the expected types and dimensions - - :param vec_env: The vectorized environment to check - :param warn: Whether to output additional warnings mainly related to - the interaction with Stable Baselines - """ - assert isinstance(vec_env, VecEnv), ( - "Your environment must inherit from stable_baselines3.common.vec_env.VecEnv" - ) - - # ============= Check basic VecEnv attributes ================ - _check_vecenv_spaces(vec_env) - - # Define aliases for convenience - observation_space = vec_env.observation_space - action_space = vec_env.action_space - - # Warn the user if needed - reuse existing space checking logic - if warn: - should_skip = _check_vecenv_unsupported_spaces(observation_space, action_space) - if should_skip: - warnings.warn("VecEnv contains unsupported spaces, skipping some checks") - return - - obs_spaces = observation_space.spaces if isinstance(observation_space, spaces.Dict) else {"": observation_space} - for key, space in obs_spaces.items(): - if isinstance(space, spaces.Box): - _check_box_obs(space, key) - - # Check for the action space - if isinstance(action_space, spaces.Box) and ( - np.any(np.abs(action_space.low) != np.abs(action_space.high)) - or np.any(action_space.low != -1) - or np.any(action_space.high != 1) - ): - warnings.warn( - "We recommend you to use a symmetric and normalized Box action space (range=[-1, 1]) " - "cf. https://stable-baselines3.readthedocs.io/en/master/guide/rl_tips.html" - ) - - if isinstance(action_space, spaces.Box): - assert np.all( - np.isfinite(np.array([action_space.low, action_space.high])) - ), "Continuous action space must have a finite lower and upper bound" - - if isinstance(action_space, spaces.Box) and action_space.dtype != np.dtype(np.float32): - warnings.warn( - f"Your action space has dtype {action_space.dtype}, we recommend using np.float32 to avoid cast errors." - ) - - # ============ Check the VecEnv methods =============== - obs = _check_vecenv_reset(vec_env) - _check_vecenv_step(vec_env, obs) diff --git a/stable_baselines3/common/vec_env/__init__.py b/stable_baselines3/common/vec_env/__init__.py index ac49a0469..6ec840b1d 100644 --- a/stable_baselines3/common/vec_env/__init__.py +++ b/stable_baselines3/common/vec_env/__init__.py @@ -13,6 +13,9 @@ from stable_baselines3.common.vec_env.vec_transpose import VecTransposeImage from stable_baselines3.common.vec_env.vec_video_recorder import VecVideoRecorder +# Avoid circular import by importing the vec_env_checker here +from stable_baselines3.common.vec_env_checker import check_vecenv + VecEnvWrapperT = TypeVar("VecEnvWrapperT", bound=VecEnvWrapper) @@ -98,6 +101,7 @@ def sync_envs_normalization(env: VecEnv, eval_env: VecEnv) -> None: "VecNormalize", "VecTransposeImage", "VecVideoRecorder", + "check_vecenv", "is_vecenv_wrapped", "sync_envs_normalization", "unwrap_vec_normalize", diff --git a/stable_baselines3/common/vec_env_checker.py b/stable_baselines3/common/vec_env_checker.py new file mode 100644 index 000000000..aa5b09e82 --- /dev/null +++ b/stable_baselines3/common/vec_env_checker.py @@ -0,0 +1,352 @@ +import warnings +from typing import Any, Union + +import numpy as np +from gymnasium import spaces + +from stable_baselines3.common.vec_env.base_vec_env import VecEnv + + +def _is_oneof_space(space: spaces.Space) -> bool: + """ + Return True if the provided space is a OneOf space, + False if not or if the current version of Gym doesn't support this space. + """ + try: + return isinstance(space, spaces.OneOf) # type: ignore[attr-defined] + except AttributeError: + # Gym < v1.0 + return False + + +def _is_numpy_array_space(space: spaces.Space) -> bool: + """ + Returns False if provided space is not representable as a single numpy array + (e.g. Dict and Tuple spaces return False) + """ + return not isinstance(space, (spaces.Dict, spaces.Tuple)) + + +def _starts_at_zero(space: Union[spaces.Discrete, spaces.MultiDiscrete]) -> bool: + """ + Return False if a (Multi)Discrete space has a non-zero start. + """ + return np.allclose(space.start, np.zeros_like(space.start)) + + +def _check_non_zero_start(space: spaces.Space, space_type: str = "observation", key: str = "") -> None: + """ + :param space: Observation or action space + :param space_type: information about whether it is an observation or action space + (for the warning message) + :param key: When the observation space comes from a Dict space, we pass the + corresponding key to have more precise warning messages. Defaults to "". + """ + if isinstance(space, (spaces.Discrete, spaces.MultiDiscrete)) and not _starts_at_zero(space): + maybe_key = f"(key='{key}')" if key else "" + warnings.warn( + f"{type(space).__name__} {space_type} space {maybe_key} with a non-zero start (start={space.start}) " + "is not supported by Stable-Baselines3. " + "You can use a wrapper (see https://stable-baselines3.readthedocs.io/en/master/guide/custom_env.html) " + f"or update your {space_type} space." + ) + + +def _check_image_input(observation_space: spaces.Box, key: str = "") -> None: + """ + Check that the input will be compatible with Stable-Baselines + when the observation is apparently an image. + + :param observation_space: Observation space + :param key: When the observation space comes from a Dict space, we pass the + corresponding key to have more precise warning messages. Defaults to "". + """ + if observation_space.dtype != np.uint8: + warnings.warn( + f"It seems that your observation {key} is an image but its `dtype` " + f"is ({observation_space.dtype}) whereas it has to be `np.uint8`. " + "If your observation is not an image, we recommend you to flatten the observation " + "to have only a 1D vector" + ) + + if np.any(observation_space.low != 0) or np.any(observation_space.high != 255): + warnings.warn( + f"It seems that your observation space {key} is an image but the " + "upper and lower bounds are not in [0, 255]. " + "Because the CNN policy normalize automatically the observation " + "you may encounter issue if the values are not in that range." + ) + + +def _check_box_obs(observation_space: spaces.Box, key: str = "") -> None: + """ + Check that the observation space is correctly formatted + when dealing with a ``Box()`` space. In particular, it checks: + - that the dimensions are big enough when it is an image, and that the type matches + - that the observation has an expected shape (warn the user if not) + """ + # If image, check the low and high values, the type and the number of channels + # and the shape (minimal value) + if len(observation_space.shape) == 3: + _check_image_input(observation_space, key) + + if len(observation_space.shape) not in [1, 3]: + warnings.warn( + f"Your observation {key} has an unconventional shape (neither an image, nor a 1D vector). " + "We recommend you to flatten the observation " + "to have only a 1D vector or use a custom policy to properly process the data." + ) + + +def _check_vecenv_spaces(vec_env: VecEnv) -> None: + """ + Check that the VecEnv has valid observation and action spaces. + """ + assert hasattr(vec_env, "observation_space"), "VecEnv must have an observation_space attribute" + assert hasattr(vec_env, "action_space"), "VecEnv must have an action_space attribute" + assert hasattr(vec_env, "num_envs"), "VecEnv must have a num_envs attribute" + + assert isinstance( + vec_env.observation_space, spaces.Space + ), "The observation space must inherit from gymnasium.spaces" + assert isinstance(vec_env.action_space, spaces.Space), "The action space must inherit from gymnasium.spaces" + assert isinstance(vec_env.num_envs, int) and vec_env.num_envs > 0, "num_envs must be a positive integer" + + +def _check_vecenv_reset(vec_env: VecEnv) -> Any: + """ + Check that VecEnv reset method works correctly and returns properly shaped observations. + """ + try: + obs = vec_env.reset() + except Exception as e: + raise RuntimeError(f"VecEnv reset() failed: {e}") from e + + # Check observation shape matches expected vectorized shape + if isinstance(vec_env.observation_space, spaces.Box): + assert isinstance(obs, np.ndarray), f"For Box observation space, reset() must return np.ndarray, got {type(obs)}" + expected_shape = (vec_env.num_envs,) + vec_env.observation_space.shape + assert obs.shape == expected_shape, ( + f"Expected observation shape {expected_shape}, got {obs.shape}. " + f"VecEnv observations should have batch dimension first." + ) + elif isinstance(vec_env.observation_space, spaces.Dict): + assert isinstance(obs, dict), f"For Dict observation space, reset() must return dict, got {type(obs)}" + for key, space in vec_env.observation_space.spaces.items(): + assert key in obs, f"Missing key '{key}' in observation dict" + if isinstance(space, spaces.Box): + expected_shape = (vec_env.num_envs,) + space.shape + assert obs[key].shape == expected_shape, ( + f"Expected observation['{key}'] shape {expected_shape}, got {obs[key].shape}" + ) + elif isinstance(vec_env.observation_space, spaces.Discrete): + assert isinstance(obs, np.ndarray), f"For Discrete observation space, reset() must return np.ndarray, got {type(obs)}" + expected_shape = (vec_env.num_envs,) + assert obs.shape == expected_shape, f"Expected observation shape {expected_shape}, got {obs.shape}" + + return obs + + +def _check_vecenv_step(vec_env: VecEnv, obs: Any) -> None: + """ + Check that VecEnv step method works correctly and returns properly shaped values. + """ + # Generate valid actions + if isinstance(vec_env.action_space, spaces.Box): + actions = np.array([vec_env.action_space.sample() for _ in range(vec_env.num_envs)]) + elif isinstance(vec_env.action_space, spaces.Discrete): + actions = np.array([vec_env.action_space.sample() for _ in range(vec_env.num_envs)]) + elif isinstance(vec_env.action_space, spaces.MultiDiscrete): + actions = np.array([vec_env.action_space.sample() for _ in range(vec_env.num_envs)]) + elif isinstance(vec_env.action_space, spaces.MultiBinary): + actions = np.array([vec_env.action_space.sample() for _ in range(vec_env.num_envs)]) + else: + actions = np.array([vec_env.action_space.sample() for _ in range(vec_env.num_envs)]) + + try: + obs, rewards, dones, infos = vec_env.step(actions) + except Exception as e: + raise RuntimeError(f"VecEnv step() failed: {e}") from e + + # Check rewards + assert isinstance(rewards, np.ndarray), f"step() must return rewards as np.ndarray, got {type(rewards)}" + assert rewards.shape == (vec_env.num_envs,), f"Expected rewards shape ({vec_env.num_envs},), got {rewards.shape}" + + # Check dones + assert isinstance(dones, np.ndarray), f"step() must return dones as np.ndarray, got {type(dones)}" + assert dones.shape == (vec_env.num_envs,), f"Expected dones shape ({vec_env.num_envs},), got {dones.shape}" + assert dones.dtype == bool, f"dones must have dtype bool, got {dones.dtype}" + + # Check infos + assert isinstance(infos, (list, tuple)), f"step() must return infos as list or tuple, got {type(infos)}" + assert len(infos) == vec_env.num_envs, f"Expected infos length {vec_env.num_envs}, got {len(infos)}" + for i, info in enumerate(infos): + assert isinstance(info, dict), f"infos[{i}] must be dict, got {type(info)}" + + # Check observation shape consistency (similar to reset) + if isinstance(vec_env.observation_space, spaces.Box): + assert isinstance(obs, np.ndarray), f"For Box observation space, step() must return np.ndarray, got {type(obs)}" + expected_shape = (vec_env.num_envs,) + vec_env.observation_space.shape + assert obs.shape == expected_shape, ( + f"Expected observation shape {expected_shape}, got {obs.shape}. " + f"VecEnv observations should have batch dimension first." + ) + elif isinstance(vec_env.observation_space, spaces.Dict): + assert isinstance(obs, dict), f"For Dict observation space, step() must return dict, got {type(obs)}" + for key, space in vec_env.observation_space.spaces.items(): + assert key in obs, f"Missing key '{key}' in observation dict" + if isinstance(space, spaces.Box): + expected_shape = (vec_env.num_envs,) + space.shape + assert obs[key].shape == expected_shape, ( + f"Expected observation['{key}'] shape {expected_shape}, got {obs[key].shape}" + ) + + +def _check_vecenv_unsupported_spaces(observation_space: spaces.Space, action_space: spaces.Space) -> bool: + """ + Emit warnings when the observation space or action space used is not supported by Stable-Baselines + for VecEnv. This is a VecEnv-specific version of _check_unsupported_spaces. + + :return: True if return value tests should be skipped. + """ + should_skip = graph_space = sequence_space = False + if isinstance(observation_space, spaces.Dict): + nested_dict = False + for key, space in observation_space.spaces.items(): + if isinstance(space, spaces.Dict): + nested_dict = True + elif isinstance(space, spaces.Graph): + graph_space = True + elif isinstance(space, spaces.Sequence): + sequence_space = True + _check_non_zero_start(space, "observation", key) + + if nested_dict: + warnings.warn( + "Nested observation spaces are not supported by Stable Baselines3 " + "(Dict spaces inside Dict space). " + "You should flatten it to have only one level of keys." + "For example, `dict(space1=dict(space2=Box(), space3=Box()), spaces4=Discrete())` " + "is not supported but `dict(space2=Box(), spaces3=Box(), spaces4=Discrete())` is." + ) + + if isinstance(observation_space, spaces.MultiDiscrete) and len(observation_space.nvec.shape) > 1: + warnings.warn( + f"The MultiDiscrete observation space uses a multidimensional array {observation_space.nvec} " + "which is currently not supported by Stable-Baselines3. " + "Please convert it to a 1D array using a wrapper: " + "https://github.com/DLR-RM/stable-baselines3/issues/1836." + ) + + if isinstance(observation_space, spaces.Tuple): + warnings.warn( + "The observation space is a Tuple, " + "this is currently not supported by Stable Baselines3. " + "However, you can convert it to a Dict observation space " + "(cf. https://gymnasium.farama.org/api/spaces/composite/#dict). " + "which is supported by SB3." + ) + # Check for Sequence spaces inside Tuple + for space in observation_space.spaces: + if isinstance(space, spaces.Sequence): + sequence_space = True + elif isinstance(space, spaces.Graph): + graph_space = True + + # Check for Sequence spaces inside OneOf + if _is_oneof_space(observation_space): + warnings.warn( + "OneOf observation space is not supported by Stable-Baselines3. " + "Note: The checks for returned values are skipped." + ) + should_skip = True + + _check_non_zero_start(observation_space, "observation") + + if isinstance(observation_space, spaces.Sequence) or sequence_space: + warnings.warn( + "Sequence observation space is not supported by Stable-Baselines3. " + "You can pad your observation to have a fixed size instead.\n" + "Note: The checks for returned values are skipped." + ) + should_skip = True + + if isinstance(observation_space, spaces.Graph) or graph_space: + warnings.warn( + "Graph observation space is not supported by Stable-Baselines3. " + "Note: The checks for returned values are skipped." + ) + should_skip = True + + _check_non_zero_start(action_space, "action") + + if not _is_numpy_array_space(action_space): + warnings.warn( + "The action space is not based off a numpy array. Typically this means it's either a Dict or Tuple space. " + "This type of action space is currently not supported by Stable Baselines 3. You should try to flatten the " + "action using a wrapper." + ) + return should_skip + + +def check_vecenv(vec_env: VecEnv, warn: bool = True) -> None: + """ + Check that a VecEnv follows the VecEnv API and is compatible with Stable-Baselines3. + + This checker verifies that: + - The VecEnv has proper observation_space, action_space, and num_envs attributes + - The reset() method returns observations with correct vectorized shape + - The step() method returns observations, rewards, dones, and infos with correct shapes + - All return values have the expected types and dimensions + + :param vec_env: The vectorized environment to check + :param warn: Whether to output additional warnings mainly related to + the interaction with Stable Baselines + """ + assert isinstance(vec_env, VecEnv), ( + "Your environment must inherit from stable_baselines3.common.vec_env.VecEnv" + ) + + # ============= Check basic VecEnv attributes ================ + _check_vecenv_spaces(vec_env) + + # Define aliases for convenience + observation_space = vec_env.observation_space + action_space = vec_env.action_space + + # Warn the user if needed - reuse existing space checking logic + if warn: + should_skip = _check_vecenv_unsupported_spaces(observation_space, action_space) + if should_skip: + warnings.warn("VecEnv contains unsupported spaces, skipping some checks") + return + + obs_spaces = observation_space.spaces if isinstance(observation_space, spaces.Dict) else {"": observation_space} + for key, space in obs_spaces.items(): + if isinstance(space, spaces.Box): + _check_box_obs(space, key) + + # Check for the action space + if isinstance(action_space, spaces.Box) and ( + np.any(np.abs(action_space.low) != np.abs(action_space.high)) + or np.any(action_space.low != -1) + or np.any(action_space.high != 1) + ): + warnings.warn( + "We recommend you to use a symmetric and normalized Box action space (range=[-1, 1]) " + "cf. https://stable-baselines3.readthedocs.io/en/master/guide/rl_tips.html" + ) + + if isinstance(action_space, spaces.Box): + assert np.all( + np.isfinite(np.array([action_space.low, action_space.high])) + ), "Continuous action space must have a finite lower and upper bound" + + if isinstance(action_space, spaces.Box) and action_space.dtype != np.dtype(np.float32): + warnings.warn( + f"Your action space has dtype {action_space.dtype}, we recommend using np.float32 to avoid cast errors." + ) + + # ============ Check the VecEnv methods =============== + obs = _check_vecenv_reset(vec_env) + _check_vecenv_step(vec_env, obs) \ No newline at end of file diff --git a/tests/test_env_checker.py b/tests/test_env_checker.py index eb844c9de..3b0fd179e 100644 --- a/tests/test_env_checker.py +++ b/tests/test_env_checker.py @@ -5,8 +5,7 @@ import pytest from gymnasium import spaces -from stable_baselines3.common.env_checker import check_env, check_vecenv -from stable_baselines3.common.vec_env import DummyVecEnv, VecEnv +from stable_baselines3.common.env_checker import check_env class ActionDictTestEnv(gym.Env): @@ -272,192 +271,3 @@ def test_check_env_oneof(): with pytest.warns(Warning, match=r"OneOf.*not supported"): check_env(env, warn=True) - - -class BrokenVecEnv: - """A broken VecEnv that doesn't inherit from VecEnv.""" - - def __init__(self): - self.num_envs = 2 - self.observation_space = spaces.Box(low=-1.0, high=1.0, shape=(3,)) - self.action_space = spaces.Discrete(2) - - -class MissingAttributeVecEnv(VecEnv): - """A VecEnv missing required attributes.""" - - def __init__(self): - # Intentionally not calling super().__init__ - pass - - def reset(self): - pass - - def step_async(self, actions): - pass - - def step_wait(self): - pass - - def close(self): - pass - - def get_attr(self, attr_name, indices=None): - pass - - def set_attr(self, attr_name, value, indices=None): - pass - - def env_method(self, method_name, *method_args, indices=None, **method_kwargs): - pass - - def env_is_wrapped(self, wrapper_class, indices=None): - return [False] * getattr(self, 'num_envs', 1) - - -class WrongShapeVecEnv(VecEnv): - """A VecEnv that returns wrong-shaped observations.""" - - def __init__(self): - super().__init__( - num_envs=2, - observation_space=spaces.Box(low=-1.0, high=1.0, shape=(3,)), - action_space=spaces.Discrete(2) - ) - - def reset(self): - # Return wrong shape (should be (2, 3) but return (3,)) - return np.zeros(3) - - def step_async(self, actions): - pass - - def step_wait(self): - # Return wrong shapes - obs = np.zeros(3) # Should be (2, 3) - rewards = np.zeros(3) # Should be (2,) - dones = np.zeros(3) # Should be (2,) - infos = [{}] # Should be [{}, {}] - list or tuple with 2 elements - return obs, rewards, dones, infos - - def close(self): - pass - - def get_attr(self, attr_name, indices=None): - return [None] * self.num_envs - - def set_attr(self, attr_name, value, indices=None): - pass - - def env_method(self, method_name, *method_args, indices=None, **method_kwargs): - return [None] * self.num_envs - - def env_is_wrapped(self, wrapper_class, indices=None): - return [False] * self.num_envs - - -def test_check_vecenv_basic(): - """Test basic VecEnv checker functionality with a working VecEnv.""" - - def make_env(): - return gym.make('CartPole-v1') - - vec_env = DummyVecEnv([make_env for _ in range(2)]) - - try: - # Should pass without issues - check_vecenv(vec_env, warn=True) - finally: - vec_env.close() - - -def test_check_vecenv_not_vecenv(): - """Test that check_vecenv raises error for non-VecEnv objects.""" - - broken_env = BrokenVecEnv() - - with pytest.raises(AssertionError, match="must inherit from.*VecEnv"): - check_vecenv(broken_env) - - -def test_check_vecenv_missing_attributes(): - """Test that check_vecenv raises error for VecEnv with missing attributes.""" - - broken_env = MissingAttributeVecEnv() - - with pytest.raises(AssertionError, match="must have.*attribute"): - check_vecenv(broken_env) - - -def test_check_vecenv_wrong_shapes(): - """Test that check_vecenv catches wrong-shaped observations and returns.""" - - broken_env = WrongShapeVecEnv() - - try: - with pytest.raises(AssertionError, match="Expected observation shape"): - check_vecenv(broken_env) - finally: - broken_env.close() - - -def test_check_vecenv_dict_space(): - """Test VecEnv checker with Dict observation space.""" - - class DictEnv(gym.Env): - def __init__(self): - self.observation_space = spaces.Dict({ - 'observation': spaces.Box(low=-1.0, high=1.0, shape=(4,)), - 'achieved_goal': spaces.Box(low=-1.0, high=1.0, shape=(2,)), - }) - self.action_space = spaces.Discrete(2) - - def reset(self, *, seed=None, options=None): - return { - 'observation': np.zeros(4), - 'achieved_goal': np.zeros(2), - }, {} - - def step(self, action): - obs = { - 'observation': np.zeros(4), - 'achieved_goal': np.zeros(2), - } - return obs, 0.0, False, False, {} - - def make_dict_env(): - return DictEnv() - - vec_env = DummyVecEnv([make_dict_env for _ in range(2)]) - - try: - check_vecenv(vec_env, warn=True) - finally: - vec_env.close() - - -def test_check_vecenv_warnings(): - """Test that check_vecenv emits appropriate warnings.""" - - class BoxActionEnv(gym.Env): - def __init__(self): - self.observation_space = spaces.Box(low=-1.0, high=1.0, shape=(4,)) - # Asymmetric action space should trigger warning - self.action_space = spaces.Box(low=-2.0, high=3.0, shape=(2,)) - - def reset(self, *, seed=None, options=None): - return np.zeros(4), {} - - def step(self, action): - return np.zeros(4), 0.0, False, False, {} - - def make_box_env(): - return BoxActionEnv() - - vec_env = DummyVecEnv([make_box_env for _ in range(2)]) - - try: - with pytest.warns(UserWarning, match="symmetric and normalized Box action space"): - check_vecenv(vec_env, warn=True) - finally: - vec_env.close() diff --git a/tests/test_vec_env_checker.py b/tests/test_vec_env_checker.py new file mode 100644 index 000000000..3ce542e23 --- /dev/null +++ b/tests/test_vec_env_checker.py @@ -0,0 +1,196 @@ +import gymnasium as gym +import numpy as np +import pytest +from gymnasium import spaces + +from stable_baselines3.common.vec_env import DummyVecEnv, VecEnv +from stable_baselines3.common.vec_env_checker import check_vecenv + + +class BrokenVecEnv: + """A broken VecEnv that doesn't inherit from VecEnv.""" + + def __init__(self): + self.num_envs = 2 + self.observation_space = spaces.Box(low=-1.0, high=1.0, shape=(3,)) + self.action_space = spaces.Discrete(2) + + +class MissingAttributeVecEnv(VecEnv): + """A VecEnv missing required attributes.""" + + def __init__(self): + # Intentionally not calling super().__init__ + pass + + def reset(self): + pass + + def step_async(self, actions): + pass + + def step_wait(self): + pass + + def close(self): + pass + + def get_attr(self, attr_name, indices=None): + pass + + def set_attr(self, attr_name, value, indices=None): + pass + + def env_method(self, method_name, *method_args, indices=None, **method_kwargs): + pass + + def env_is_wrapped(self, wrapper_class, indices=None): + return [False] * getattr(self, 'num_envs', 1) + + +class WrongShapeVecEnv(VecEnv): + """A VecEnv that returns wrong-shaped observations.""" + + def __init__(self): + super().__init__( + num_envs=2, + observation_space=spaces.Box(low=-1.0, high=1.0, shape=(3,)), + action_space=spaces.Discrete(2) + ) + + def reset(self): + # Return wrong shape (should be (2, 3) but return (3,)) + return np.zeros(3) + + def step_async(self, actions): + pass + + def step_wait(self): + # Return wrong shapes + obs = np.zeros(3) # Should be (2, 3) + rewards = np.zeros(3) # Should be (2,) + dones = np.zeros(3) # Should be (2,) + infos = [{}] # Should be [{}, {}] - list or tuple with 2 elements + return obs, rewards, dones, infos + + def close(self): + pass + + def get_attr(self, attr_name, indices=None): + return [None] * self.num_envs + + def set_attr(self, attr_name, value, indices=None): + pass + + def env_method(self, method_name, *method_args, indices=None, **method_kwargs): + return [None] * self.num_envs + + def env_is_wrapped(self, wrapper_class, indices=None): + return [False] * self.num_envs + + +def test_check_vecenv_basic(): + """Test basic VecEnv checker functionality with a working VecEnv.""" + + def make_env(): + return gym.make('CartPole-v1') + + vec_env = DummyVecEnv([make_env for _ in range(2)]) + + try: + # Should pass without issues + check_vecenv(vec_env, warn=True) + finally: + vec_env.close() + + +def test_check_vecenv_not_vecenv(): + """Test that check_vecenv raises error for non-VecEnv objects.""" + + broken_env = BrokenVecEnv() + + with pytest.raises(AssertionError, match="must inherit from.*VecEnv"): + check_vecenv(broken_env) + + +def test_check_vecenv_missing_attributes(): + """Test that check_vecenv raises error for VecEnv with missing attributes.""" + + broken_env = MissingAttributeVecEnv() + + with pytest.raises(AssertionError, match="must have.*attribute"): + check_vecenv(broken_env) + + +def test_check_vecenv_wrong_shapes(): + """Test that check_vecenv catches wrong-shaped observations and returns.""" + + broken_env = WrongShapeVecEnv() + + try: + with pytest.raises(AssertionError, match="Expected observation shape"): + check_vecenv(broken_env) + finally: + broken_env.close() + + +def test_check_vecenv_dict_space(): + """Test VecEnv checker with Dict observation space.""" + + class DictEnv(gym.Env): + def __init__(self): + self.observation_space = spaces.Dict({ + 'observation': spaces.Box(low=-1.0, high=1.0, shape=(4,)), + 'achieved_goal': spaces.Box(low=-1.0, high=1.0, shape=(2,)), + }) + self.action_space = spaces.Discrete(2) + + def reset(self, *, seed=None, options=None): + return { + 'observation': np.zeros(4), + 'achieved_goal': np.zeros(2), + }, {} + + def step(self, action): + obs = { + 'observation': np.zeros(4), + 'achieved_goal': np.zeros(2), + } + return obs, 0.0, False, False, {} + + def make_dict_env(): + return DictEnv() + + vec_env = DummyVecEnv([make_dict_env for _ in range(2)]) + + try: + check_vecenv(vec_env, warn=True) + finally: + vec_env.close() + + +def test_check_vecenv_warnings(): + """Test that check_vecenv emits appropriate warnings.""" + + class BoxActionEnv(gym.Env): + def __init__(self): + self.observation_space = spaces.Box(low=-1.0, high=1.0, shape=(4,)) + # Asymmetric action space should trigger warning + self.action_space = spaces.Box(low=-2.0, high=3.0, shape=(2,)) + + def reset(self, *, seed=None, options=None): + return np.zeros(4), {} + + def step(self, action): + return np.zeros(4), 0.0, False, False, {} + + def make_box_env(): + return BoxActionEnv() + + vec_env = DummyVecEnv([make_box_env for _ in range(2)]) + + try: + with pytest.warns(UserWarning, match="symmetric and normalized Box action space"): + check_vecenv(vec_env, warn=True) + finally: + vec_env.close() \ No newline at end of file From ff2c000b17f1ecd50889b92896ff1f58cb373b93 Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Fri, 26 Sep 2025 12:55:20 +0000 Subject: [PATCH 4/4] Fix linting, type checking, and improve VecEnv checker implementation Co-authored-by: araffin <1973948+araffin@users.noreply.github.com> --- docs/guide/vec_envs.rst | 2 +- stable_baselines3/common/vec_env/__init__.py | 2 +- stable_baselines3/common/vec_env_checker.py | 252 +++---------------- tests/test_vec_env_checker.py | 102 ++++---- 4 files changed, 94 insertions(+), 264 deletions(-) diff --git a/docs/guide/vec_envs.rst b/docs/guide/vec_envs.rst index d968e67f1..c682fbe41 100644 --- a/docs/guide/vec_envs.rst +++ b/docs/guide/vec_envs.rst @@ -184,7 +184,7 @@ it calls the environment setter method. Checking VecEnv Implementation ------------------------------ +------------------------------ When implementing custom vectorized environments, it's easy to make mistakes that can lead to hard-to-debug issues. To help with this, Stable-Baselines3 provides a ``check_vecenv`` function that validates your VecEnv implementation diff --git a/stable_baselines3/common/vec_env/__init__.py b/stable_baselines3/common/vec_env/__init__.py index 6ec840b1d..bcbecc62f 100644 --- a/stable_baselines3/common/vec_env/__init__.py +++ b/stable_baselines3/common/vec_env/__init__.py @@ -13,7 +13,7 @@ from stable_baselines3.common.vec_env.vec_transpose import VecTransposeImage from stable_baselines3.common.vec_env.vec_video_recorder import VecVideoRecorder -# Avoid circular import by importing the vec_env_checker here +# Avoid circular import by importing the vec_env_checker here from stable_baselines3.common.vec_env_checker import check_vecenv VecEnvWrapperT = TypeVar("VecEnvWrapperT", bound=VecEnvWrapper) diff --git a/stable_baselines3/common/vec_env_checker.py b/stable_baselines3/common/vec_env_checker.py index aa5b09e82..c614a5815 100644 --- a/stable_baselines3/common/vec_env_checker.py +++ b/stable_baselines3/common/vec_env_checker.py @@ -1,103 +1,13 @@ import warnings -from typing import Any, Union +from typing import Any import numpy as np from gymnasium import spaces +from stable_baselines3.common.env_checker import _check_box_obs, _check_unsupported_spaces from stable_baselines3.common.vec_env.base_vec_env import VecEnv -def _is_oneof_space(space: spaces.Space) -> bool: - """ - Return True if the provided space is a OneOf space, - False if not or if the current version of Gym doesn't support this space. - """ - try: - return isinstance(space, spaces.OneOf) # type: ignore[attr-defined] - except AttributeError: - # Gym < v1.0 - return False - - -def _is_numpy_array_space(space: spaces.Space) -> bool: - """ - Returns False if provided space is not representable as a single numpy array - (e.g. Dict and Tuple spaces return False) - """ - return not isinstance(space, (spaces.Dict, spaces.Tuple)) - - -def _starts_at_zero(space: Union[spaces.Discrete, spaces.MultiDiscrete]) -> bool: - """ - Return False if a (Multi)Discrete space has a non-zero start. - """ - return np.allclose(space.start, np.zeros_like(space.start)) - - -def _check_non_zero_start(space: spaces.Space, space_type: str = "observation", key: str = "") -> None: - """ - :param space: Observation or action space - :param space_type: information about whether it is an observation or action space - (for the warning message) - :param key: When the observation space comes from a Dict space, we pass the - corresponding key to have more precise warning messages. Defaults to "". - """ - if isinstance(space, (spaces.Discrete, spaces.MultiDiscrete)) and not _starts_at_zero(space): - maybe_key = f"(key='{key}')" if key else "" - warnings.warn( - f"{type(space).__name__} {space_type} space {maybe_key} with a non-zero start (start={space.start}) " - "is not supported by Stable-Baselines3. " - "You can use a wrapper (see https://stable-baselines3.readthedocs.io/en/master/guide/custom_env.html) " - f"or update your {space_type} space." - ) - - -def _check_image_input(observation_space: spaces.Box, key: str = "") -> None: - """ - Check that the input will be compatible with Stable-Baselines - when the observation is apparently an image. - - :param observation_space: Observation space - :param key: When the observation space comes from a Dict space, we pass the - corresponding key to have more precise warning messages. Defaults to "". - """ - if observation_space.dtype != np.uint8: - warnings.warn( - f"It seems that your observation {key} is an image but its `dtype` " - f"is ({observation_space.dtype}) whereas it has to be `np.uint8`. " - "If your observation is not an image, we recommend you to flatten the observation " - "to have only a 1D vector" - ) - - if np.any(observation_space.low != 0) or np.any(observation_space.high != 255): - warnings.warn( - f"It seems that your observation space {key} is an image but the " - "upper and lower bounds are not in [0, 255]. " - "Because the CNN policy normalize automatically the observation " - "you may encounter issue if the values are not in that range." - ) - - -def _check_box_obs(observation_space: spaces.Box, key: str = "") -> None: - """ - Check that the observation space is correctly formatted - when dealing with a ``Box()`` space. In particular, it checks: - - that the dimensions are big enough when it is an image, and that the type matches - - that the observation has an expected shape (warn the user if not) - """ - # If image, check the low and high values, the type and the number of channels - # and the shape (minimal value) - if len(observation_space.shape) == 3: - _check_image_input(observation_space, key) - - if len(observation_space.shape) not in [1, 3]: - warnings.warn( - f"Your observation {key} has an unconventional shape (neither an image, nor a 1D vector). " - "We recommend you to flatten the observation " - "to have only a 1D vector or use a custom policy to properly process the data." - ) - - def _check_vecenv_spaces(vec_env: VecEnv) -> None: """ Check that the VecEnv has valid observation and action spaces. @@ -108,24 +18,25 @@ def _check_vecenv_spaces(vec_env: VecEnv) -> None: assert isinstance( vec_env.observation_space, spaces.Space - ), "The observation space must inherit from gymnasium.spaces" - assert isinstance(vec_env.action_space, spaces.Space), "The action space must inherit from gymnasium.spaces" - assert isinstance(vec_env.num_envs, int) and vec_env.num_envs > 0, "num_envs must be a positive integer" + ), f"The observation space must inherit from gymnasium.spaces, got {type(vec_env.observation_space)}" + assert isinstance( + vec_env.action_space, spaces.Space + ), f"The action space must inherit from gymnasium.spaces, got {type(vec_env.action_space)}" + assert ( + isinstance(vec_env.num_envs, int) and vec_env.num_envs > 0 + ), f"num_envs must be a positive integer, got {vec_env.num_envs} (type: {type(vec_env.num_envs)})" def _check_vecenv_reset(vec_env: VecEnv) -> Any: """ Check that VecEnv reset method works correctly and returns properly shaped observations. """ - try: - obs = vec_env.reset() - except Exception as e: - raise RuntimeError(f"VecEnv reset() failed: {e}") from e + obs = vec_env.reset() # Check observation shape matches expected vectorized shape if isinstance(vec_env.observation_space, spaces.Box): assert isinstance(obs, np.ndarray), f"For Box observation space, reset() must return np.ndarray, got {type(obs)}" - expected_shape = (vec_env.num_envs,) + vec_env.observation_space.shape + expected_shape = (vec_env.num_envs, *vec_env.observation_space.shape) assert obs.shape == expected_shape, ( f"Expected observation shape {expected_shape}, got {obs.shape}. " f"VecEnv observations should have batch dimension first." @@ -135,10 +46,10 @@ def _check_vecenv_reset(vec_env: VecEnv) -> Any: for key, space in vec_env.observation_space.spaces.items(): assert key in obs, f"Missing key '{key}' in observation dict" if isinstance(space, spaces.Box): - expected_shape = (vec_env.num_envs,) + space.shape - assert obs[key].shape == expected_shape, ( - f"Expected observation['{key}'] shape {expected_shape}, got {obs[key].shape}" - ) + expected_shape = (vec_env.num_envs, *space.shape) + assert ( + obs[key].shape == expected_shape + ), f"Expected observation['{key}'] shape {expected_shape}, got {obs[key].shape}" elif isinstance(vec_env.observation_space, spaces.Discrete): assert isinstance(obs, np.ndarray), f"For Discrete observation space, reset() must return np.ndarray, got {type(obs)}" expected_shape = (vec_env.num_envs,) @@ -152,21 +63,9 @@ def _check_vecenv_step(vec_env: VecEnv, obs: Any) -> None: Check that VecEnv step method works correctly and returns properly shaped values. """ # Generate valid actions - if isinstance(vec_env.action_space, spaces.Box): - actions = np.array([vec_env.action_space.sample() for _ in range(vec_env.num_envs)]) - elif isinstance(vec_env.action_space, spaces.Discrete): - actions = np.array([vec_env.action_space.sample() for _ in range(vec_env.num_envs)]) - elif isinstance(vec_env.action_space, spaces.MultiDiscrete): - actions = np.array([vec_env.action_space.sample() for _ in range(vec_env.num_envs)]) - elif isinstance(vec_env.action_space, spaces.MultiBinary): - actions = np.array([vec_env.action_space.sample() for _ in range(vec_env.num_envs)]) - else: - actions = np.array([vec_env.action_space.sample() for _ in range(vec_env.num_envs)]) - - try: - obs, rewards, dones, infos = vec_env.step(actions) - except Exception as e: - raise RuntimeError(f"VecEnv step() failed: {e}") from e + actions = np.array([vec_env.action_space.sample() for _ in range(vec_env.num_envs)]) + + obs, rewards, dones, infos = vec_env.step(actions) # Check rewards assert isinstance(rewards, np.ndarray), f"step() must return rewards as np.ndarray, got {type(rewards)}" @@ -186,7 +85,7 @@ def _check_vecenv_step(vec_env: VecEnv, obs: Any) -> None: # Check observation shape consistency (similar to reset) if isinstance(vec_env.observation_space, spaces.Box): assert isinstance(obs, np.ndarray), f"For Box observation space, step() must return np.ndarray, got {type(obs)}" - expected_shape = (vec_env.num_envs,) + vec_env.observation_space.shape + expected_shape = (vec_env.num_envs, *vec_env.observation_space.shape) assert obs.shape == expected_shape, ( f"Expected observation shape {expected_shape}, got {obs.shape}. " f"VecEnv observations should have batch dimension first." @@ -196,121 +95,52 @@ def _check_vecenv_step(vec_env: VecEnv, obs: Any) -> None: for key, space in vec_env.observation_space.spaces.items(): assert key in obs, f"Missing key '{key}' in observation dict" if isinstance(space, spaces.Box): - expected_shape = (vec_env.num_envs,) + space.shape - assert obs[key].shape == expected_shape, ( - f"Expected observation['{key}'] shape {expected_shape}, got {obs[key].shape}" - ) + expected_shape = (vec_env.num_envs, *space.shape) + assert ( + obs[key].shape == expected_shape + ), f"Expected observation['{key}'] shape {expected_shape}, got {obs[key].shape}" + + +class _DummyVecEnvForSpaceCheck: + """Dummy class to pass to _check_unsupported_spaces function.""" + + def __init__(self, observation_space: spaces.Space, action_space: spaces.Space): + self.observation_space = observation_space + self.action_space = action_space def _check_vecenv_unsupported_spaces(observation_space: spaces.Space, action_space: spaces.Space) -> bool: """ Emit warnings when the observation space or action space used is not supported by Stable-Baselines - for VecEnv. This is a VecEnv-specific version of _check_unsupported_spaces. + for VecEnv. Reuses the existing _check_unsupported_spaces function. :return: True if return value tests should be skipped. """ - should_skip = graph_space = sequence_space = False - if isinstance(observation_space, spaces.Dict): - nested_dict = False - for key, space in observation_space.spaces.items(): - if isinstance(space, spaces.Dict): - nested_dict = True - elif isinstance(space, spaces.Graph): - graph_space = True - elif isinstance(space, spaces.Sequence): - sequence_space = True - _check_non_zero_start(space, "observation", key) - - if nested_dict: - warnings.warn( - "Nested observation spaces are not supported by Stable Baselines3 " - "(Dict spaces inside Dict space). " - "You should flatten it to have only one level of keys." - "For example, `dict(space1=dict(space2=Box(), space3=Box()), spaces4=Discrete())` " - "is not supported but `dict(space2=Box(), spaces3=Box(), spaces4=Discrete())` is." - ) - - if isinstance(observation_space, spaces.MultiDiscrete) and len(observation_space.nvec.shape) > 1: - warnings.warn( - f"The MultiDiscrete observation space uses a multidimensional array {observation_space.nvec} " - "which is currently not supported by Stable-Baselines3. " - "Please convert it to a 1D array using a wrapper: " - "https://github.com/DLR-RM/stable-baselines3/issues/1836." - ) - - if isinstance(observation_space, spaces.Tuple): - warnings.warn( - "The observation space is a Tuple, " - "this is currently not supported by Stable Baselines3. " - "However, you can convert it to a Dict observation space " - "(cf. https://gymnasium.farama.org/api/spaces/composite/#dict). " - "which is supported by SB3." - ) - # Check for Sequence spaces inside Tuple - for space in observation_space.spaces: - if isinstance(space, spaces.Sequence): - sequence_space = True - elif isinstance(space, spaces.Graph): - graph_space = True - - # Check for Sequence spaces inside OneOf - if _is_oneof_space(observation_space): - warnings.warn( - "OneOf observation space is not supported by Stable-Baselines3. " - "Note: The checks for returned values are skipped." - ) - should_skip = True - - _check_non_zero_start(observation_space, "observation") - - if isinstance(observation_space, spaces.Sequence) or sequence_space: - warnings.warn( - "Sequence observation space is not supported by Stable-Baselines3. " - "You can pad your observation to have a fixed size instead.\n" - "Note: The checks for returned values are skipped." - ) - should_skip = True - - if isinstance(observation_space, spaces.Graph) or graph_space: - warnings.warn( - "Graph observation space is not supported by Stable-Baselines3. " - "Note: The checks for returned values are skipped." - ) - should_skip = True - - _check_non_zero_start(action_space, "action") - - if not _is_numpy_array_space(action_space): - warnings.warn( - "The action space is not based off a numpy array. Typically this means it's either a Dict or Tuple space. " - "This type of action space is currently not supported by Stable Baselines 3. You should try to flatten the " - "action using a wrapper." - ) - return should_skip + # Create a dummy env object to pass to the existing function + dummy_env = _DummyVecEnvForSpaceCheck(observation_space, action_space) + return _check_unsupported_spaces(dummy_env, observation_space, action_space) # type: ignore[arg-type] def check_vecenv(vec_env: VecEnv, warn: bool = True) -> None: """ Check that a VecEnv follows the VecEnv API and is compatible with Stable-Baselines3. - + This checker verifies that: - The VecEnv has proper observation_space, action_space, and num_envs attributes - The reset() method returns observations with correct vectorized shape - The step() method returns observations, rewards, dones, and infos with correct shapes - All return values have the expected types and dimensions - + :param vec_env: The vectorized environment to check - :param warn: Whether to output additional warnings mainly related to + :param warn: Whether to output additional warnings mainly related to the interaction with Stable Baselines """ - assert isinstance(vec_env, VecEnv), ( - "Your environment must inherit from stable_baselines3.common.vec_env.VecEnv" - ) + assert isinstance(vec_env, VecEnv), "Your environment must inherit from stable_baselines3.common.vec_env.VecEnv" # ============= Check basic VecEnv attributes ================ _check_vecenv_spaces(vec_env) - # Define aliases for convenience + # Define aliases for convenience observation_space = vec_env.observation_space action_space = vec_env.action_space @@ -349,4 +179,4 @@ def check_vecenv(vec_env: VecEnv, warn: bool = True) -> None: # ============ Check the VecEnv methods =============== obs = _check_vecenv_reset(vec_env) - _check_vecenv_step(vec_env, obs) \ No newline at end of file + _check_vecenv_step(vec_env, obs) diff --git a/tests/test_vec_env_checker.py b/tests/test_vec_env_checker.py index 3ce542e23..f9444d1cc 100644 --- a/tests/test_vec_env_checker.py +++ b/tests/test_vec_env_checker.py @@ -9,7 +9,7 @@ class BrokenVecEnv: """A broken VecEnv that doesn't inherit from VecEnv.""" - + def __init__(self): self.num_envs = 2 self.observation_space = spaces.Box(low=-1.0, high=1.0, shape=(3,)) @@ -18,53 +18,51 @@ def __init__(self): class MissingAttributeVecEnv(VecEnv): """A VecEnv missing required attributes.""" - + def __init__(self): # Intentionally not calling super().__init__ pass - + def reset(self): pass - + def step_async(self, actions): pass - + def step_wait(self): pass - + def close(self): pass - + def get_attr(self, attr_name, indices=None): pass - + def set_attr(self, attr_name, value, indices=None): pass - + def env_method(self, method_name, *method_args, indices=None, **method_kwargs): pass - + def env_is_wrapped(self, wrapper_class, indices=None): - return [False] * getattr(self, 'num_envs', 1) + return [False] * getattr(self, "num_envs", 1) class WrongShapeVecEnv(VecEnv): """A VecEnv that returns wrong-shaped observations.""" - + def __init__(self): super().__init__( - num_envs=2, - observation_space=spaces.Box(low=-1.0, high=1.0, shape=(3,)), - action_space=spaces.Discrete(2) + num_envs=2, observation_space=spaces.Box(low=-1.0, high=1.0, shape=(3,)), action_space=spaces.Discrete(2) ) - + def reset(self): # Return wrong shape (should be (2, 3) but return (3,)) return np.zeros(3) - + def step_async(self, actions): pass - + def step_wait(self): # Return wrong shapes obs = np.zeros(3) # Should be (2, 3) @@ -72,31 +70,31 @@ def step_wait(self): dones = np.zeros(3) # Should be (2,) infos = [{}] # Should be [{}, {}] - list or tuple with 2 elements return obs, rewards, dones, infos - + def close(self): pass - + def get_attr(self, attr_name, indices=None): return [None] * self.num_envs - + def set_attr(self, attr_name, value, indices=None): pass - + def env_method(self, method_name, *method_args, indices=None, **method_kwargs): return [None] * self.num_envs - + def env_is_wrapped(self, wrapper_class, indices=None): return [False] * self.num_envs def test_check_vecenv_basic(): """Test basic VecEnv checker functionality with a working VecEnv.""" - + def make_env(): - return gym.make('CartPole-v1') + return gym.make("CartPole-v1") vec_env = DummyVecEnv([make_env for _ in range(2)]) - + try: # Should pass without issues check_vecenv(vec_env, warn=True) @@ -106,27 +104,27 @@ def make_env(): def test_check_vecenv_not_vecenv(): """Test that check_vecenv raises error for non-VecEnv objects.""" - + broken_env = BrokenVecEnv() - - with pytest.raises(AssertionError, match="must inherit from.*VecEnv"): + + with pytest.raises(AssertionError, match=r"must inherit from.*VecEnv"): check_vecenv(broken_env) def test_check_vecenv_missing_attributes(): """Test that check_vecenv raises error for VecEnv with missing attributes.""" - + broken_env = MissingAttributeVecEnv() - - with pytest.raises(AssertionError, match="must have.*attribute"): + + with pytest.raises(AssertionError, match=r"must have.*attribute"): check_vecenv(broken_env) def test_check_vecenv_wrong_shapes(): """Test that check_vecenv catches wrong-shaped observations and returns.""" - + broken_env = WrongShapeVecEnv() - + try: with pytest.raises(AssertionError, match="Expected observation shape"): check_vecenv(broken_env) @@ -136,25 +134,27 @@ def test_check_vecenv_wrong_shapes(): def test_check_vecenv_dict_space(): """Test VecEnv checker with Dict observation space.""" - + class DictEnv(gym.Env): def __init__(self): - self.observation_space = spaces.Dict({ - 'observation': spaces.Box(low=-1.0, high=1.0, shape=(4,)), - 'achieved_goal': spaces.Box(low=-1.0, high=1.0, shape=(2,)), - }) + self.observation_space = spaces.Dict( + { + "observation": spaces.Box(low=-1.0, high=1.0, shape=(4,)), + "achieved_goal": spaces.Box(low=-1.0, high=1.0, shape=(2,)), + } + ) self.action_space = spaces.Discrete(2) - + def reset(self, *, seed=None, options=None): return { - 'observation': np.zeros(4), - 'achieved_goal': np.zeros(2), + "observation": np.zeros(4), + "achieved_goal": np.zeros(2), }, {} - + def step(self, action): obs = { - 'observation': np.zeros(4), - 'achieved_goal': np.zeros(2), + "observation": np.zeros(4), + "achieved_goal": np.zeros(2), } return obs, 0.0, False, False, {} @@ -162,7 +162,7 @@ def make_dict_env(): return DictEnv() vec_env = DummyVecEnv([make_dict_env for _ in range(2)]) - + try: check_vecenv(vec_env, warn=True) finally: @@ -171,16 +171,16 @@ def make_dict_env(): def test_check_vecenv_warnings(): """Test that check_vecenv emits appropriate warnings.""" - + class BoxActionEnv(gym.Env): def __init__(self): self.observation_space = spaces.Box(low=-1.0, high=1.0, shape=(4,)) # Asymmetric action space should trigger warning self.action_space = spaces.Box(low=-2.0, high=3.0, shape=(2,)) - + def reset(self, *, seed=None, options=None): return np.zeros(4), {} - + def step(self, action): return np.zeros(4), 0.0, False, False, {} @@ -188,9 +188,9 @@ def make_box_env(): return BoxActionEnv() vec_env = DummyVecEnv([make_box_env for _ in range(2)]) - + try: with pytest.warns(UserWarning, match="symmetric and normalized Box action space"): check_vecenv(vec_env, warn=True) finally: - vec_env.close() \ No newline at end of file + vec_env.close()