From e722d75a417fc9612916e7446dc8e968f620801d Mon Sep 17 00:00:00 2001 From: akbaig Date: Sun, 23 Mar 2025 04:55:22 +0200 Subject: [PATCH] feat: maskable recurrent ppo --- sb3_contrib/__init__.py | 2 + .../common/maskable_recurrent/__init__.py | 0 .../common/maskable_recurrent/buffers.py | 388 ++++++++++++ .../common/maskable_recurrent/policies.py | 579 ++++++++++++++++++ .../common/maskable_recurrent/type_aliases.py | 32 + sb3_contrib/ppo_mask_recurrent/__init__.py | 4 + sb3_contrib/ppo_mask_recurrent/policies.py | 9 + .../ppo_mask_recurrent/ppo_mask_recurrent.py | 412 +++++++++++++ 8 files changed, 1426 insertions(+) create mode 100644 sb3_contrib/common/maskable_recurrent/__init__.py create mode 100644 sb3_contrib/common/maskable_recurrent/buffers.py create mode 100644 sb3_contrib/common/maskable_recurrent/policies.py create mode 100644 sb3_contrib/common/maskable_recurrent/type_aliases.py create mode 100644 sb3_contrib/ppo_mask_recurrent/__init__.py create mode 100644 sb3_contrib/ppo_mask_recurrent/policies.py create mode 100644 sb3_contrib/ppo_mask_recurrent/ppo_mask_recurrent.py diff --git a/sb3_contrib/__init__.py b/sb3_contrib/__init__.py index 2aa7a19b..690f0f91 100644 --- a/sb3_contrib/__init__.py +++ b/sb3_contrib/__init__.py @@ -4,6 +4,7 @@ from sb3_contrib.crossq import CrossQ from sb3_contrib.ppo_mask import MaskablePPO from sb3_contrib.ppo_recurrent import RecurrentPPO +from sb3_contrib.ppo_mask_recurrent import MaskableRecurrentPPO from sb3_contrib.qrdqn import QRDQN from sb3_contrib.tqc import TQC from sb3_contrib.trpo import TRPO @@ -21,4 +22,5 @@ "CrossQ", "MaskablePPO", "RecurrentPPO", + "MaskableRecurrentPPO" ] diff --git a/sb3_contrib/common/maskable_recurrent/__init__.py b/sb3_contrib/common/maskable_recurrent/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/sb3_contrib/common/maskable_recurrent/buffers.py b/sb3_contrib/common/maskable_recurrent/buffers.py new file mode 100644 index 00000000..93731907 --- /dev/null +++ 
def pad(
    seq_start_indices: np.ndarray,
    seq_end_indices: np.ndarray,
    device: th.device,
    tensor: np.ndarray,
    padding_value: float = 0.0,
) -> th.Tensor:
    """
    Chunk sequences and pad them to have constant dimensions.

    :param seq_start_indices: Indices of the transitions that start a sequence
    :param seq_end_indices: Indices of the transitions that end a sequence (inclusive)
    :param device: PyTorch device
    :param tensor: Array of shape (batch_size, *tensor_shape)
    :param padding_value: Value used to pad sequence to the same length
        (zero padding by default)
    :return: (n_seq, max_length, *tensor_shape)
    """
    # One tensor per sequence; ``end`` is inclusive, hence the ``+ 1``
    seq = [th.tensor(tensor[start : end + 1], device=device) for start, end in zip(seq_start_indices, seq_end_indices)]
    return th.nn.utils.rnn.pad_sequence(seq, batch_first=True, padding_value=padding_value)


def pad_and_flatten(
    seq_start_indices: np.ndarray,
    seq_end_indices: np.ndarray,
    device: th.device,
    tensor: np.ndarray,
    padding_value: float = 0.0,
) -> th.Tensor:
    """
    Pad and flatten the sequences of scalar values,
    while keeping the sequence order.
    From (batch_size, 1) to (n_seq, max_length, 1) -> (n_seq * max_length,)

    :param seq_start_indices: Indices of the transitions that start a sequence
    :param seq_end_indices: Indices of the transitions that end a sequence (inclusive)
    :param device: PyTorch device (cpu, gpu, ...)
    :param tensor: Array of per-transition scalar values, indexed by transition along its first axis
    :param padding_value: Value used to pad sequence to the same length
        (zero padding by default)
    :return: (n_seq * max_length,) aka (padded_batch_size,)
    """
    return pad(seq_start_indices, seq_end_indices, device, tensor, padding_value).flatten()


def create_sequencers(
    episode_starts: np.ndarray,
    env_change: np.ndarray,
    device: th.device,
) -> tuple[np.ndarray, Callable, Callable]:
    """
    Create the utility function to chunk data into
    sequences and pad them to create fixed size tensors.

    :param episode_starts: Flags marking the transitions where an episode starts
    :param env_change: Flags marking the transitions where the data collected
        come from a different env (when using multiple env for data collection)
    :param device: PyTorch device
    :return: Indices of the transitions that start a sequence,
        pad and pad_and_flatten utilities tailored for this batch
        (sequence starts and ends indices are fixed)
    """
    # A new sequence begins at every episode start and at every env change
    seq_start = np.logical_or(episode_starts, env_change).flatten()
    # First index is always the beginning of a sequence
    seq_start[0] = True
    # ``seq_start`` is already boolean, no need to compare against True
    seq_start_indices = np.where(seq_start)[0]
    # A sequence ends right before the next one starts; the last sequence ends
    # on the last transition of the batch. The end indices are inclusive, so
    # the final one is ``len - 1`` (using ``len`` only worked because slicing
    # clamps out-of-range bounds).
    last_index = len(episode_starts) - 1
    seq_end_indices = np.concatenate([seq_start_indices[1:] - 1, np.array([last_index])])

    # Create padding method for this minibatch
    # to avoid repeating arguments (seq_start_indices, seq_end_indices)
    local_pad = partial(pad, seq_start_indices, seq_end_indices, device)
    local_pad_and_flatten = partial(pad_and_flatten, seq_start_indices, seq_end_indices, device)
    return seq_start_indices, local_pad, local_pad_and_flatten
class MaskableRecurrentRolloutBuffer(RecurrentRolloutBuffer):
    """
    Rollout buffer that stores the LSTM cell and hidden states
    together with one action mask per stored transition.

    :param buffer_size: Max number of element in the buffer
    :param observation_space: Observation space
    :param action_space: Action space (Discrete, MultiDiscrete or MultiBinary)
    :param hidden_state_shape: Shape of the buffer that will collect lstm states
    :param device: PyTorch device
    :param gae_lambda: Factor for trade-off of bias vs variance for Generalized Advantage Estimator
    :param gamma: Discount factor
    :param n_envs: Number of parallel environments
    """

    def __init__(
        self,
        buffer_size: int,
        observation_space: spaces.Space,
        action_space: spaces.Space,
        hidden_state_shape: tuple[int, int, int, int],
        device: Union[th.device, str] = "auto",
        gae_lambda: float = 1,
        gamma: float = 0.99,
        n_envs: int = 1,
    ):
        # Declare the attribute *before* running the parent constructor.
        # The parent constructor chain is expected to call ``reset()`` (SB3's
        # ``RolloutBuffer`` does), which allocates the mask array; assigning
        # ``None`` afterwards would throw that allocation away. This ordering
        # matches ``MaskableRecurrentDictRolloutBuffer``.
        self.action_masks = None
        super().__init__(
            buffer_size, observation_space, action_space, hidden_state_shape, device, gae_lambda, gamma, n_envs=n_envs
        )

    def reset(self) -> None:
        """
        Compute the flat mask size for the action space, allocate an
        all-ones mask array and reset the parent buffer.

        :raises ValueError: if the action space is not Discrete, MultiDiscrete or MultiBinary
        """
        if isinstance(self.action_space, spaces.Discrete):
            mask_dims = self.action_space.n
        elif isinstance(self.action_space, spaces.MultiDiscrete):
            mask_dims = sum(self.action_space.nvec)
        elif isinstance(self.action_space, spaces.MultiBinary):
            mask_dims = 2 * self.action_space.n  # One mask per binary outcome
        else:
            raise ValueError(f"Unsupported action space {type(self.action_space)}")

        self.mask_dims = mask_dims
        # All ones: every slot starts with the same mask value for each action
        self.action_masks = np.ones((self.buffer_size, self.n_envs, self.mask_dims), dtype=np.float32)

        super().reset()

    def add(self, *args, lstm_states: RNNStates, action_masks: Optional[np.ndarray] = None, **kwargs) -> None:
        """
        Store one transition, including its action masks.

        :param lstm_states: LSTM cell and hidden state
        :param action_masks: Masks applied to constrain the choice of possible actions.
            When ``None`` the slot is not written and keeps its current content.
        """
        if action_masks is not None:
            # Must be written before the parent call, which receives the
            # transition itself (NOTE(review): the parent is presumed to
            # advance ``self.pos`` - confirm).
            self.action_masks[self.pos] = action_masks.reshape((self.n_envs, self.mask_dims))
        super().add(*args, lstm_states=lstm_states, **kwargs)

    def get(self, batch_size: Optional[int] = None) -> Generator[MaskableRecurrentRolloutBufferSamples, None, None]:
        """
        Yield minibatches of padded sequences covering the whole buffer.

        :param batch_size: Number of transitions per minibatch,
            the whole buffer is returned at once when ``None``
        :return: Generator of padded rollout samples (with action masks)
        """
        assert self.full, "Rollout buffer must be full before sampling from it"

        # Prepare the data
        if not self.generator_ready:
            # hidden_state_shape = (self.n_steps, lstm.num_layers, self.n_envs, lstm.hidden_size)
            # swap first to (self.n_steps, self.n_envs, lstm.num_layers, lstm.hidden_size)
            for tensor in ["hidden_states_pi", "cell_states_pi", "hidden_states_vf", "cell_states_vf"]:
                self.__dict__[tensor] = self.__dict__[tensor].swapaxes(1, 2)

            # flatten but keep the sequence order
            # 1. (n_steps, n_envs, *tensor_shape) -> (n_envs, n_steps, *tensor_shape)
            # 2. (n_envs, n_steps, *tensor_shape) -> (n_envs * n_steps, *tensor_shape)
            for tensor in [
                "observations",
                "actions",
                "values",
                "log_probs",
                "advantages",
                "returns",
                "hidden_states_pi",
                "cell_states_pi",
                "hidden_states_vf",
                "cell_states_vf",
                "episode_starts",
                "action_masks",
            ]:
                self.__dict__[tensor] = self.swap_and_flatten(self.__dict__[tensor])
            self.generator_ready = True

        n_transitions = self.buffer_size * self.n_envs

        # Return everything, don't create minibatches
        if batch_size is None:
            batch_size = n_transitions

        # Sampling strategy that allows any mini batch size but requires
        # more complexity and use of padding
        # Trick to shuffle a bit: keep the sequence order
        # but split the indices in two
        split_index = np.random.randint(n_transitions)
        indices = np.arange(n_transitions)
        indices = np.concatenate((indices[split_index:], indices[:split_index]))

        env_change = np.zeros(n_transitions).reshape(self.buffer_size, self.n_envs)
        # Flag first timestep as change of environment
        env_change[0, :] = 1.0
        env_change = self.swap_and_flatten(env_change)

        start_idx = 0
        while start_idx < n_transitions:
            batch_inds = indices[start_idx : start_idx + batch_size]
            yield self._get_samples(batch_inds, env_change)
            start_idx += batch_size

    def _get_samples(
        self,
        batch_inds: np.ndarray,
        env_change: np.ndarray,
        env: Optional[VecNormalize] = None,
    ) -> MaskableRecurrentRolloutBufferSamples:
        """
        Build one padded minibatch from the flattened buffers.

        :param batch_inds: Indices of the transitions of this minibatch
        :param env_change: Flattened flags marking where the data switch to another env
        :param env: Unused here
        :return: Padded samples, including the padded action masks
        """
        # Retrieve sequence starts and utility function
        self.seq_start_indices, self.pad, self.pad_and_flatten = create_sequencers(
            self.episode_starts[batch_inds], env_change[batch_inds], self.device
        )

        # Number of sequences
        n_seq = len(self.seq_start_indices)
        max_length = self.pad(self.actions[batch_inds]).shape[1]
        padded_batch_size = n_seq * max_length
        # We retrieve the lstm hidden states that will allow
        # to properly initialize the LSTM at the beginning of each sequence
        lstm_states_pi = (
            # 1. (n_envs * n_steps, n_layers, dim) -> (batch_size, n_layers, dim)
            # 2. (batch_size, n_layers, dim)  -> (n_seq, n_layers, dim)
            # 3. (n_seq, n_layers, dim) -> (n_layers, n_seq, dim)
            self.hidden_states_pi[batch_inds][self.seq_start_indices].swapaxes(0, 1),
            self.cell_states_pi[batch_inds][self.seq_start_indices].swapaxes(0, 1),
        )
        lstm_states_vf = (
            # (n_envs * n_steps, n_layers, dim) -> (n_layers, n_seq, dim)
            self.hidden_states_vf[batch_inds][self.seq_start_indices].swapaxes(0, 1),
            self.cell_states_vf[batch_inds][self.seq_start_indices].swapaxes(0, 1),
        )
        lstm_states_pi = (self.to_torch(lstm_states_pi[0]).contiguous(), self.to_torch(lstm_states_pi[1]).contiguous())
        lstm_states_vf = (self.to_torch(lstm_states_vf[0]).contiguous(), self.to_torch(lstm_states_vf[1]).contiguous())

        return MaskableRecurrentRolloutBufferSamples(
            # (batch_size, obs_dim) -> (n_seq, max_length, obs_dim) -> (n_seq * max_length, obs_dim)
            observations=self.pad(self.observations[batch_inds]).reshape((padded_batch_size, *self.obs_shape)),
            actions=self.pad(self.actions[batch_inds]).reshape((padded_batch_size,) + self.actions.shape[1:]),
            old_values=self.pad_and_flatten(self.values[batch_inds]),
            old_log_prob=self.pad_and_flatten(self.log_probs[batch_inds]),
            advantages=self.pad_and_flatten(self.advantages[batch_inds]),
            returns=self.pad_and_flatten(self.returns[batch_inds]),
            lstm_states=RNNStates(lstm_states_pi, lstm_states_vf),
            episode_starts=self.pad_and_flatten(self.episode_starts[batch_inds]),
            # 1 for real transitions, 0 (the padding value) for padded ones
            mask=self.pad_and_flatten(np.ones_like(self.returns[batch_inds])),
            action_masks=self.pad(self.action_masks[batch_inds]).reshape((padded_batch_size, self.mask_dims)),
        )
class MaskableRecurrentDictRolloutBuffer(RecurrentDictRolloutBuffer):
    """
    Dict Rollout buffer used in on-policy algorithms like A2C/PPO.
    Extends the RecurrentDictRolloutBuffer with the storage of one
    action mask per transition.

    :param buffer_size: Max number of element in the buffer
    :param observation_space: Observation space
    :param action_space: Action space (Discrete, MultiDiscrete or MultiBinary)
    :param hidden_state_shape: Shape of the buffer that will collect lstm states
    :param device: PyTorch device
    :param gae_lambda: Factor for trade-off of bias vs variance for Generalized Advantage Estimator
        Equivalent to classic advantage when set to 1.
    :param gamma: Discount factor
    :param n_envs: Number of parallel environments
    """

    def __init__(
        self,
        buffer_size: int,
        observation_space: spaces.Space,
        action_space: spaces.Space,
        hidden_state_shape: tuple[int, int, int, int],
        device: Union[th.device, str] = "auto",
        gae_lambda: float = 1,
        gamma: float = 0.99,
        n_envs: int = 1,
    ):
        # Declared before the parent constructor runs so that a ``reset()``
        # triggered by it (NOTE(review): presumed, confirm against the parent)
        # is not overwritten afterwards.
        self.action_masks = None
        super().__init__(buffer_size, observation_space, action_space, hidden_state_shape, device, gae_lambda, gamma, n_envs=n_envs)

    def reset(self) -> None:
        """
        Compute the flat mask size for the action space, allocate an
        all-ones mask array and reset the parent buffer.

        :raises ValueError: if the action space is not Discrete, MultiDiscrete or MultiBinary
        """
        if isinstance(self.action_space, spaces.Discrete):
            mask_dims = self.action_space.n
        elif isinstance(self.action_space, spaces.MultiDiscrete):
            mask_dims = sum(self.action_space.nvec)
        elif isinstance(self.action_space, spaces.MultiBinary):
            mask_dims = 2 * self.action_space.n  # One mask per binary outcome
        else:
            raise ValueError(f"Unsupported action space {type(self.action_space)}")

        self.mask_dims = mask_dims
        self.action_masks = np.ones((self.buffer_size, self.n_envs, self.mask_dims), dtype=np.float32)

        super().reset()

    def add(self, *args, lstm_states: RNNStates, action_masks: Optional[np.ndarray] = None, **kwargs) -> None:
        """
        Store one transition, including its action masks.

        :param lstm_states: LSTM cell and hidden state
        :param action_masks: Masks applied to constrain the choice of possible actions.
            When ``None`` the slot is not written and keeps its current content.
        """
        if action_masks is not None:
            # Written at the current position before delegating to the parent
            self.action_masks[self.pos] = action_masks.reshape((self.n_envs, self.mask_dims))
        super().add(*args, lstm_states=lstm_states, **kwargs)

    def get(self, batch_size: Optional[int] = None) -> Generator[MaskableRecurrentDictRolloutBufferSamples, None, None]:
        """
        Yield minibatches of padded sequences covering the whole buffer.

        :param batch_size: Number of transitions per minibatch,
            the whole buffer is returned at once when ``None``
        :return: Generator of padded rollout samples (with action masks)
        """
        assert self.full, "Rollout buffer must be full before sampling from it"

        # Prepare the data (done once, guarded by ``generator_ready``)
        if not self.generator_ready:
            # hidden_state_shape = (self.n_steps, lstm.num_layers, self.n_envs, lstm.hidden_size)
            # swap first to (self.n_steps, self.n_envs, lstm.num_layers, lstm.hidden_size)
            for tensor in ["hidden_states_pi", "cell_states_pi", "hidden_states_vf", "cell_states_vf"]:
                self.__dict__[tensor] = self.__dict__[tensor].swapaxes(1, 2)

            # Observations are stored per key, flatten each entry separately
            for key, obs in self.observations.items():
                self.observations[key] = self.swap_and_flatten(obs)

            # Flatten every other per-transition array, action masks included
            for tensor in [
                "actions",
                "values",
                "log_probs",
                "advantages",
                "returns",
                "hidden_states_pi",
                "cell_states_pi",
                "hidden_states_vf",
                "cell_states_vf",
                "episode_starts",
                "action_masks"
            ]:
                self.__dict__[tensor] = self.swap_and_flatten(self.__dict__[tensor])
            self.generator_ready = True

        # Return everything, don't create minibatches
        if batch_size is None:
            batch_size = self.buffer_size * self.n_envs

        # Trick to shuffle a bit: keep the sequence order
        # but split the indices in two
        split_index = np.random.randint(self.buffer_size * self.n_envs)
        indices = np.arange(self.buffer_size * self.n_envs)
        indices = np.concatenate((indices[split_index:], indices[:split_index]))

        env_change = np.zeros(self.buffer_size * self.n_envs).reshape(self.buffer_size, self.n_envs)
        # Flag first timestep as change of environment
        env_change[0, :] = 1.0
        env_change = self.swap_and_flatten(env_change)

        start_idx = 0
        while start_idx < self.buffer_size * self.n_envs:
            batch_inds = indices[start_idx : start_idx + batch_size]
            yield self._get_samples(batch_inds, env_change)
            start_idx += batch_size

    def _get_samples(
        self,
        batch_inds: np.ndarray,
        env_change: np.ndarray,
        env: Optional[VecNormalize] = None,
    ) -> MaskableRecurrentDictRolloutBufferSamples:
        """
        Build one padded minibatch from the flattened buffers.

        :param batch_inds: Indices of the transitions of this minibatch
        :param env_change: Flattened flags marking where the data switch to another env
        :param env: Unused here
        :return: Padded samples, including the padded action masks
        """
        # Retrieve sequence starts and utility function
        self.seq_start_indices, self.pad, self.pad_and_flatten = create_sequencers(
            self.episode_starts[batch_inds], env_change[batch_inds], self.device
        )

        # Number of sequences and size of the batch once every sequence is padded
        n_seq = len(self.seq_start_indices)
        max_length = self.pad(self.actions[batch_inds]).shape[1]
        padded_batch_size = n_seq * max_length
        # We retrieve the lstm hidden states that will allow
        # to properly initialize the LSTM at the beginning of each sequence
        lstm_states_pi = (
            # (n_envs * n_steps, n_layers, dim) -> (n_layers, n_seq, dim)
            self.hidden_states_pi[batch_inds][self.seq_start_indices].swapaxes(0, 1),
            self.cell_states_pi[batch_inds][self.seq_start_indices].swapaxes(0, 1),
        )
        lstm_states_vf = (
            # (n_envs * n_steps, n_layers, dim) -> (n_layers, n_seq, dim)
            self.hidden_states_vf[batch_inds][self.seq_start_indices].swapaxes(0, 1),
            self.cell_states_vf[batch_inds][self.seq_start_indices].swapaxes(0, 1),
        )
        lstm_states_pi = (self.to_torch(lstm_states_pi[0]).contiguous(), self.to_torch(lstm_states_pi[1]).contiguous())
        lstm_states_vf = (self.to_torch(lstm_states_vf[0]).contiguous(), self.to_torch(lstm_states_vf[1]).contiguous())

        # Pad each observation entry, then merge the (n_seq, max_length) axes
        observations = {key: self.pad(obs[batch_inds]) for (key, obs) in self.observations.items()}
        observations = {key: obs.reshape((padded_batch_size,) + self.obs_shape[key]) for (key, obs) in observations.items()}

        return MaskableRecurrentDictRolloutBufferSamples(
            observations=observations,
            actions=self.pad(self.actions[batch_inds]).reshape((padded_batch_size,) + self.actions.shape[1:]),
            old_values=self.pad_and_flatten(self.values[batch_inds]),
            old_log_prob=self.pad_and_flatten(self.log_probs[batch_inds]),
            advantages=self.pad_and_flatten(self.advantages[batch_inds]),
            returns=self.pad_and_flatten(self.returns[batch_inds]),
            lstm_states=RNNStates(lstm_states_pi, lstm_states_vf),
            episode_starts=self.pad_and_flatten(self.episode_starts[batch_inds]),
            # Built from an all-ones array, so padded positions carry the padding value
            mask=self.pad_and_flatten(np.ones_like(self.returns[batch_inds])),
            action_masks=self.pad(self.action_masks[batch_inds]).reshape((padded_batch_size, self.mask_dims))
        )
+ returns=self.pad_and_flatten(self.returns[batch_inds]), + lstm_states=RNNStates(lstm_states_pi, lstm_states_vf), + episode_starts=self.pad_and_flatten(self.episode_starts[batch_inds]), + mask=self.pad_and_flatten(np.ones_like(self.returns[batch_inds])), + action_masks=self.pad(self.action_masks[batch_inds]).reshape((padded_batch_size, self.mask_dims)) + ) diff --git a/sb3_contrib/common/maskable_recurrent/policies.py b/sb3_contrib/common/maskable_recurrent/policies.py new file mode 100644 index 00000000..213b56f3 --- /dev/null +++ b/sb3_contrib/common/maskable_recurrent/policies.py @@ -0,0 +1,579 @@ +from functools import partial +from typing import Any, Optional, Union + +import numpy as np +import torch as th +from gymnasium import spaces +from sb3_contrib.common.maskable.distributions import MaskableDistribution, make_masked_proba_distribution +from stable_baselines3.common.policies import ActorCriticPolicy +from sb3_contrib.common.recurrent.policies import RecurrentActorCriticPolicy +from stable_baselines3.common.torch_layers import ( + BaseFeaturesExtractor, + CombinedExtractor, + FlattenExtractor, + MlpExtractor, + NatureCNN, +) +from stable_baselines3.common.type_aliases import Schedule +from stable_baselines3.common.utils import zip_strict +from torch import nn + +from sb3_contrib.common.recurrent.type_aliases import RNNStates + + +class MaskableRecurrentActorCriticPolicy(RecurrentActorCriticPolicy): + """ + Recurrent policy class for actor-critic algorithms (has both policy and value prediction). + To be used with A2C, PPO and the likes. + It assumes that both the actor and the critic LSTM + have the same architecture. + + :param observation_space: Observation space + :param action_space: Action space + :param lr_schedule: Learning rate schedule (could be constant) + :param net_arch: The specification of the policy and value networks. 
+ :param activation_fn: Activation function + :param ortho_init: Whether to use or not orthogonal initialization + :param use_sde: Whether to use State Dependent Exploration or not + :param log_std_init: Initial value for the log standard deviation + :param full_std: Whether to use (n_features x n_actions) parameters + for the std instead of only (n_features,) when using gSDE + :param use_expln: Use ``expln()`` function instead of ``exp()`` to ensure + a positive standard deviation (cf paper). It allows to keep variance + above zero and prevent it from growing too fast. In practice, ``exp()`` is usually enough. + :param squash_output: Whether to squash the output using a tanh function, + this allows to ensure boundaries when using gSDE. + :param features_extractor_class: Features extractor to use. + :param features_extractor_kwargs: Keyword arguments + to pass to the features extractor. + :param share_features_extractor: If True, the features extractor is shared between the policy and value networks. + :param normalize_images: Whether to normalize images or not, + dividing by 255.0 (True by default) + :param optimizer_class: The optimizer to use, + ``th.optim.Adam`` by default + :param optimizer_kwargs: Additional keyword arguments, + excluding the learning rate, to pass to the optimizer + :param lstm_hidden_size: Number of hidden units for each LSTM layer. + :param n_lstm_layers: Number of LSTM layers. + :param shared_lstm: Whether the LSTM is shared between the actor and the critic + (in that case, only the actor gradient is used) + By default, the actor and the critic have two separate LSTM. + :param enable_critic_lstm: Use a seperate LSTM for the critic. + :param lstm_kwargs: Additional keyword arguments to pass the the LSTM + constructor. 
+ """ + + def __init__( + self, + observation_space: spaces.Space, + action_space: spaces.Space, + lr_schedule: Schedule, + net_arch: Optional[Union[list[int], dict[str, list[int]]]] = None, + activation_fn: type[nn.Module] = nn.Tanh, + ortho_init: bool = True, + use_sde: bool = False, + log_std_init: float = 0.0, + full_std: bool = True, + use_expln: bool = False, + squash_output: bool = False, + features_extractor_class: type[BaseFeaturesExtractor] = FlattenExtractor, + features_extractor_kwargs: Optional[dict[str, Any]] = None, + share_features_extractor: bool = True, + normalize_images: bool = True, + optimizer_class: type[th.optim.Optimizer] = th.optim.Adam, + optimizer_kwargs: Optional[dict[str, Any]] = None, + lstm_hidden_size: int = 256, + n_lstm_layers: int = 1, + shared_lstm: bool = False, + enable_critic_lstm: bool = True, + lstm_kwargs: Optional[dict[str, Any]] = None, + ): + super().__init__( + observation_space, + action_space, + lr_schedule, + net_arch, + activation_fn, + ortho_init, + use_sde, + log_std_init, + full_std, + use_expln, + squash_output, + features_extractor_class, + features_extractor_kwargs, + share_features_extractor, + normalize_images, + optimizer_class, + optimizer_kwargs, + lstm_hidden_size, + n_lstm_layers, + shared_lstm, + enable_critic_lstm, + lstm_kwargs + ) + # Action distribution + self.action_dist = make_masked_proba_distribution(action_space) + self._build(lr_schedule) + + def _build(self, lr_schedule: Schedule) -> None: + """ + Create the networks and the optimizer. 
+ + :param lr_schedule: Learning rate schedule + lr_schedule(1) is the initial learning rate + """ + self._build_mlp_extractor() + + self.action_net = self.action_dist.proba_distribution_net(latent_dim=self.mlp_extractor.latent_dim_pi) + self.value_net = nn.Linear(self.mlp_extractor.latent_dim_vf, 1) + + # Init weights: use orthogonal initialization + # with small initial weight for the output + if self.ortho_init: + # TODO: check for features_extractor + # Values from stable-baselines. + # features_extractor/mlp values are + # originally from openai/baselines (default gains/init_scales). + module_gains = { + self.features_extractor: np.sqrt(2), + self.mlp_extractor: np.sqrt(2), + self.action_net: 0.01, + self.value_net: 1, + } + if not self.share_features_extractor: + # Note(antonin): this is to keep SB3 results + # consistent, see GH#1148 + del module_gains[self.features_extractor] + module_gains[self.pi_features_extractor] = np.sqrt(2) + module_gains[self.vf_features_extractor] = np.sqrt(2) + + for module, gain in module_gains.items(): + module.apply(partial(self.init_weights, gain=gain)) + + # Setup optimizer with initial learning rate + self.optimizer = self.optimizer_class(self.parameters(), lr=lr_schedule(1), **self.optimizer_kwargs) + + def forward( + self, + obs: th.Tensor, + lstm_states: RNNStates, + episode_starts: th.Tensor, + deterministic: bool = False, + action_masks: Optional[np.ndarray] = None, + ) -> tuple[th.Tensor, th.Tensor, th.Tensor, RNNStates]: + """ + Forward pass in all the networks (actor and critic) + + :param obs: Observation. Observation + :param lstm_states: The last hidden and memory states for the LSTM. + :param episode_starts: Whether the observations correspond to new episodes + or not (we reset the lstm states in that case). 
+ :param deterministic: Whether to sample or use deterministic actions + :return: action, value and log probability of the action + """ + # Preprocess the observation if needed + features = self.extract_features(obs) + if self.share_features_extractor: + pi_features = vf_features = features # alis + else: + pi_features, vf_features = features + # latent_pi, latent_vf = self.mlp_extractor(features) + latent_pi, lstm_states_pi = self._process_sequence(pi_features, lstm_states.pi, episode_starts, self.lstm_actor) + if self.lstm_critic is not None: + latent_vf, lstm_states_vf = self._process_sequence(vf_features, lstm_states.vf, episode_starts, self.lstm_critic) + elif self.shared_lstm: + # Re-use LSTM features but do not backpropagate + latent_vf = latent_pi.detach() + lstm_states_vf = (lstm_states_pi[0].detach(), lstm_states_pi[1].detach()) + else: + # Critic only has a feedforward network + latent_vf = self.critic(vf_features) + lstm_states_vf = lstm_states_pi + + latent_pi = self.mlp_extractor.forward_actor(latent_pi) + latent_vf = self.mlp_extractor.forward_critic(latent_vf) + + # Evaluate the values for the given observations + values = self.value_net(latent_vf) + distribution = self._get_action_dist_from_latent(latent_pi) + if action_masks is not None: + distribution.apply_masking(action_masks) + actions = distribution.get_actions(deterministic=deterministic) + log_prob = distribution.log_prob(actions) + return actions, values, log_prob, RNNStates(lstm_states_pi, lstm_states_vf) + + def _get_action_dist_from_latent(self, latent_pi: th.Tensor) -> MaskableDistribution: + """ + Retrieve action distribution given the latent codes. 
+ + :param latent_pi: Latent code for the actor + :return: Action distribution + """ + action_logits = self.action_net(latent_pi) + return self.action_dist.proba_distribution(action_logits=action_logits) + + def get_distribution( + self, + obs: th.Tensor, + lstm_states: tuple[th.Tensor, th.Tensor], + episode_starts: th.Tensor, + action_masks: Optional[np.ndarray] = None + ) -> tuple[MaskableDistribution, tuple[th.Tensor, ...]]: + """ + Get the current policy distribution given the observations. + + :param obs: Observation. + :param lstm_states: The last hidden and memory states for the LSTM. + :param episode_starts: Whether the observations correspond to new episodes + or not (we reset the lstm states in that case). + :return: the action distribution and new hidden states. + """ + # Call the method from the parent of the parent class + features = super(ActorCriticPolicy, self).extract_features(obs, self.pi_features_extractor) + latent_pi, lstm_states = self._process_sequence(features, lstm_states, episode_starts, self.lstm_actor) + latent_pi = self.mlp_extractor.forward_actor(latent_pi) + distribution = self._get_action_dist_from_latent(latent_pi) + if action_masks is not None: + distribution.apply_masking(action_masks) + return distribution, lstm_states + + def predict_values( + self, + obs: th.Tensor, + lstm_states: tuple[th.Tensor, th.Tensor], + episode_starts: th.Tensor, + ) -> th.Tensor: + """ + Get the estimated values according to the current policy given the observations. + + :param obs: Observation. + :param lstm_states: The last hidden and memory states for the LSTM. + :param episode_starts: Whether the observations correspond to new episodes + or not (we reset the lstm states in that case). + :return: the estimated values. 
+ """ + # Call the method from the parent of the parent class + features = super(ActorCriticPolicy, self).extract_features(obs, self.vf_features_extractor) + + if self.lstm_critic is not None: + latent_vf, lstm_states_vf = self._process_sequence(features, lstm_states, episode_starts, self.lstm_critic) + elif self.shared_lstm: + # Use LSTM from the actor + latent_pi, _ = self._process_sequence(features, lstm_states, episode_starts, self.lstm_actor) + latent_vf = latent_pi.detach() + else: + latent_vf = self.critic(features) + + latent_vf = self.mlp_extractor.forward_critic(latent_vf) + return self.value_net(latent_vf) + + def evaluate_actions( + self, obs: th.Tensor, actions: th.Tensor, lstm_states: RNNStates, episode_starts: th.Tensor, action_masks: Optional[np.ndarray] = None + ) -> tuple[th.Tensor, th.Tensor, th.Tensor]: + """ + Evaluate actions according to the current policy, + given the observations. + + :param obs: Observation. + :param actions: + :param lstm_states: The last hidden and memory states for the LSTM. + :param episode_starts: Whether the observations correspond to new episodes + or not (we reset the lstm states in that case). + :param action_masks: Action masks to apply to the action distribution + :return: estimated value, log likelihood of taking those actions + and entropy of the action distribution. 
+ """ + # Preprocess the observation if needed + features = self.extract_features(obs) + if self.share_features_extractor: + pi_features = vf_features = features # alias + else: + pi_features, vf_features = features + latent_pi, _ = self._process_sequence(pi_features, lstm_states.pi, episode_starts, self.lstm_actor) + if self.lstm_critic is not None: + latent_vf, _ = self._process_sequence(vf_features, lstm_states.vf, episode_starts, self.lstm_critic) + elif self.shared_lstm: + latent_vf = latent_pi.detach() + else: + latent_vf = self.critic(vf_features) + + latent_pi = self.mlp_extractor.forward_actor(latent_pi) + latent_vf = self.mlp_extractor.forward_critic(latent_vf) + + distribution = self._get_action_dist_from_latent(latent_pi) + if action_masks is not None: + distribution.apply_masking(action_masks) + log_prob = distribution.log_prob(actions) + values = self.value_net(latent_vf) + return values, log_prob, distribution.entropy() + + def _predict( + self, + observation: th.Tensor, + lstm_states: tuple[th.Tensor, th.Tensor], + episode_starts: th.Tensor, + deterministic: bool = False, + action_masks: Optional[np.ndarray] = None + ) -> tuple[th.Tensor, tuple[th.Tensor, ...]]: + """ + Get the action according to the policy for a given observation. + + :param observation: + :param lstm_states: The last hidden and memory states for the LSTM. + :param episode_starts: Whether the observations correspond to new episodes + or not (we reset the lstm states in that case). 
def predict(
    self,
    observation: Union[np.ndarray, dict[str, np.ndarray]],
    state: Optional[tuple[np.ndarray, ...]] = None,
    episode_start: Optional[np.ndarray] = None,
    deterministic: bool = False,
    action_masks: Optional[np.ndarray] = None,
) -> tuple[np.ndarray, Optional[tuple[np.ndarray, ...]]]:
    """
    Get the policy action from an observation (and optional recurrent state).

    The observation is converted with ``obs_to_tensor`` and the result is
    returned as numpy arrays.

    :param observation: the input observation
    :param state: Pair of LSTM states (hidden, memory) as numpy arrays.
        When ``None``, both are initialized to zeros for every environment.
    :param episode_start: Per-environment flags marking the start of a new episode.
        When ``None``, all flags are ``False``.
    :param deterministic: Whether or not to return deterministic actions.
    :param action_masks: Action masks forwarded to ``_predict`` (can be ``None``).
    :return: the model's action and the next LSTM states as a pair of numpy arrays
    """
    # Inference only: eval mode affects batch norm / dropout
    self.set_training_mode(False)

    obs_tensor, is_vectorized = self.obs_to_tensor(observation)

    # The leading dimension of the (first) observation tensor is the number of envs
    if isinstance(obs_tensor, dict):
        reference_obs = next(iter(obs_tensor.values()))
    else:
        reference_obs = obs_tensor
    n_envs = reference_obs.shape[0]

    # state : (n_layers, n_envs, dim)
    if state is None:
        # One zero block per environment, stacked along the env axis
        zero_state = np.concatenate([np.zeros(self.lstm_hidden_state_shape)] * n_envs, axis=1)
        state = (zero_state, zero_state)

    if episode_start is None:
        episode_start = np.zeros(n_envs, dtype=bool)

    device = self.device
    with th.no_grad():
        # Move the recurrent state and the episode flags to PyTorch
        hidden = th.tensor(state[0], dtype=th.float32, device=device)
        memory = th.tensor(state[1], dtype=th.float32, device=device)
        starts = th.tensor(episode_start, dtype=th.float32, device=device)
        actions, next_states = self._predict(
            obs_tensor,
            lstm_states=(hidden, memory),
            episode_starts=starts,
            deterministic=deterministic,
            action_masks=action_masks,
        )
        next_states = (next_states[0].cpu().numpy(), next_states[1].cpu().numpy())

    # Back to numpy
    actions = actions.cpu().numpy()

    if isinstance(self.action_space, spaces.Box):
        if self.squash_output:
            # Squashed outputs must be rescaled to the action-space bounds
            actions = self.unscale_action(actions)
        else:
            # Unsquashed actions can be on an arbitrary scale (e.g. Gaussian samples):
            # clip them to avoid out-of-bound errors
            actions = np.clip(actions, self.action_space.low, self.action_space.high)

    if not is_vectorized:
        # Drop the batch dimension that obs_to_tensor added
        actions = actions.squeeze(axis=0)

    return actions, next_states
+ :param normalize_images: Whether to normalize images or not, + dividing by 255.0 (True by default) + :param optimizer_class: The optimizer to use, + ``th.optim.Adam`` by default + :param optimizer_kwargs: Additional keyword arguments, + excluding the learning rate, to pass to the optimizer + :param lstm_hidden_size: Number of hidden units for each LSTM layer. + :param n_lstm_layers: Number of LSTM layers. + :param shared_lstm: Whether the LSTM is shared between the actor and the critic. + By default, only the actor has a recurrent network. + :param enable_critic_lstm: Use a seperate LSTM for the critic. + :param lstm_kwargs: Additional keyword arguments to pass the the LSTM + constructor. + """ + + def __init__( + self, + observation_space: spaces.Space, + action_space: spaces.Space, + lr_schedule: Schedule, + net_arch: Optional[Union[list[int], dict[str, list[int]]]] = None, + activation_fn: type[nn.Module] = nn.Tanh, + ortho_init: bool = True, + use_sde: bool = False, + log_std_init: float = 0.0, + full_std: bool = True, + use_expln: bool = False, + squash_output: bool = False, + features_extractor_class: type[BaseFeaturesExtractor] = NatureCNN, + features_extractor_kwargs: Optional[dict[str, Any]] = None, + share_features_extractor: bool = True, + normalize_images: bool = True, + optimizer_class: type[th.optim.Optimizer] = th.optim.Adam, + optimizer_kwargs: Optional[dict[str, Any]] = None, + lstm_hidden_size: int = 256, + n_lstm_layers: int = 1, + shared_lstm: bool = False, + enable_critic_lstm: bool = True, + lstm_kwargs: Optional[dict[str, Any]] = None, + ): + super().__init__( + observation_space, + action_space, + lr_schedule, + net_arch, + activation_fn, + ortho_init, + use_sde, + log_std_init, + full_std, + use_expln, + squash_output, + features_extractor_class, + features_extractor_kwargs, + share_features_extractor, + normalize_images, + optimizer_class, + optimizer_kwargs, + lstm_hidden_size, + n_lstm_layers, + shared_lstm, + enable_critic_lstm, + 
class MaskableRecurrentMultiInputActorCriticPolicy(MaskableRecurrentActorCriticPolicy):
    """
    Maskable recurrent actor-critic policy for multi-input observations
    (has both policy and value prediction).

    Identical to ``MaskableRecurrentActorCriticPolicy`` except that the default
    features extractor is ``CombinedExtractor``. Exposed as ``MultiInputLstmPolicy``
    by ``sb3_contrib.ppo_mask_recurrent.policies``.

    :param observation_space: Observation space
    :param action_space: Action space
    :param lr_schedule: Learning rate schedule (could be constant)
    :param net_arch: The specification of the policy and value networks.
    :param activation_fn: Activation function
    :param ortho_init: Whether to use or not orthogonal initialization
    :param use_sde: Whether to use State Dependent Exploration or not
    :param log_std_init: Initial value for the log standard deviation
    :param full_std: Whether to use (n_features x n_actions) parameters
        for the std instead of only (n_features,) when using gSDE
    :param use_expln: Use ``expln()`` function instead of ``exp()`` to ensure
        a positive standard deviation (cf paper). It allows to keep variance
        above zero and prevent it from growing too fast. In practice, ``exp()`` is usually enough.
    :param squash_output: Whether to squash the output using a tanh function,
        this allows to ensure boundaries when using gSDE.
    :param features_extractor_class: Features extractor to use
        (``CombinedExtractor`` by default).
    :param features_extractor_kwargs: Keyword arguments
        to pass to the features extractor.
    :param share_features_extractor: If True, the features extractor is shared between the policy and value networks.
    :param normalize_images: Whether to normalize images or not,
         dividing by 255.0 (True by default)
    :param optimizer_class: The optimizer to use,
        ``th.optim.Adam`` by default
    :param optimizer_kwargs: Additional keyword arguments,
        excluding the learning rate, to pass to the optimizer
    :param lstm_hidden_size: Number of hidden units for each LSTM layer.
    :param n_lstm_layers: Number of LSTM layers.
    :param shared_lstm: Whether the LSTM is shared between the actor and the critic
        (``False`` by default).
    :param enable_critic_lstm: Use a separate LSTM for the critic (``True`` by default).
    :param lstm_kwargs: Additional keyword arguments to pass to the LSTM
        constructor.
    """

    def __init__(
        self,
        observation_space: spaces.Space,
        action_space: spaces.Space,
        lr_schedule: Schedule,
        net_arch: Optional[Union[list[int], dict[str, list[int]]]] = None,
        activation_fn: type[nn.Module] = nn.Tanh,
        ortho_init: bool = True,
        use_sde: bool = False,
        log_std_init: float = 0.0,
        full_std: bool = True,
        use_expln: bool = False,
        squash_output: bool = False,
        features_extractor_class: type[BaseFeaturesExtractor] = CombinedExtractor,
        features_extractor_kwargs: Optional[dict[str, Any]] = None,
        share_features_extractor: bool = True,
        normalize_images: bool = True,
        optimizer_class: type[th.optim.Optimizer] = th.optim.Adam,
        optimizer_kwargs: Optional[dict[str, Any]] = None,
        lstm_hidden_size: int = 256,
        n_lstm_layers: int = 1,
        shared_lstm: bool = False,
        enable_critic_lstm: bool = True,
        lstm_kwargs: Optional[dict[str, Any]] = None,
    ):
        # Every argument is forwarded positionally and unchanged, so the order
        # below must stay in sync with the parent constructor's parameter order.
        super().__init__(
            observation_space,
            action_space,
            lr_schedule,
            net_arch,
            activation_fn,
            ortho_init,
            use_sde,
            log_std_init,
            full_std,
            use_expln,
            squash_output,
            features_extractor_class,
            features_extractor_kwargs,
            share_features_extractor,
            normalize_images,
            optimizer_class,
            optimizer_kwargs,
            lstm_hidden_size,
            n_lstm_layers,
            shared_lstm,
            enable_critic_lstm,
            lstm_kwargs,
        )
class MaskableRecurrentDictRolloutBufferSamples(NamedTuple):
    """
    Mini-batch of rollout data for maskable recurrent policies with dict observations.

    Same fields, in the same order, as ``MaskableRecurrentRolloutBufferSamples``;
    the only difference is that ``observations`` is a ``TensorDict`` instead of
    a single tensor.
    """

    # NOTE: field order is the tuple's positional interface; do not reorder.
    observations: TensorDict
    actions: th.Tensor
    old_values: th.Tensor
    old_log_prob: th.Tensor
    advantages: th.Tensor
    returns: th.Tensor
    # Recurrent states handed to ``evaluate_actions`` together with ``episode_starts``
    lstm_states: RNNStates
    episode_starts: th.Tensor
    # Float mask; the training loop thresholds it (> 1e-8) to drop padded steps
    # NOTE(review): presumably produced by the sequence padding in the buffer; confirm
    mask: th.Tensor
    # Per-step action masks handed to ``evaluate_actions``
    action_masks: th.Tensor
class MaskableRecurrentPPO(RecurrentPPO):
    """
    Recurrent PPO (LSTM policy) with invalid-action masking.

    Builds on ``RecurrentPPO`` and overrides the parts that have to carry
    action masks: the rollout buffer type, rollout collection, prediction
    and the training step.
    """

    policy_aliases: ClassVar[dict[str, type[BasePolicy]]] = {
        "MlpLstmPolicy": MlpLstmPolicy,
        "CnnLstmPolicy": CnnLstmPolicy,
        "MultiInputLstmPolicy": MultiInputLstmPolicy,
    }

    def _setup_model(self) -> None:
        """
        Create the policy, the initial LSTM states, the rollout buffer and the clipping schedules.

        :raises ValueError: if the instantiated policy is not a
            ``MaskableRecurrentActorCriticPolicy``.
        """
        self._setup_lr_schedule()
        self.set_random_seed(self.seed)

        if isinstance(self.observation_space, spaces.Dict):
            buffer_cls = MaskableRecurrentDictRolloutBuffer
        else:
            buffer_cls = MaskableRecurrentRolloutBuffer

        self.policy = self.policy_class(
            self.observation_space,
            self.action_space,
            self.lr_schedule,
            use_sde=self.use_sde,
            **self.policy_kwargs,
        )
        self.policy = self.policy.to(self.device)

        # Validate the policy type before touching LSTM-specific attributes,
        # so that a wrong policy class yields this error and not an AttributeError.
        if not isinstance(self.policy, MaskableRecurrentActorCriticPolicy):
            raise ValueError("Policy must subclass MaskableRecurrentActorCriticPolicy")

        # We assume that LSTM for the actor and the critic
        # have the same architecture
        lstm = self.policy.lstm_actor

        single_hidden_state_shape = (lstm.num_layers, self.n_envs, lstm.hidden_size)
        # hidden and cell states for actor and critic
        self._last_lstm_states = RNNStates(
            (
                th.zeros(single_hidden_state_shape, device=self.device),
                th.zeros(single_hidden_state_shape, device=self.device),
            ),
            (
                th.zeros(single_hidden_state_shape, device=self.device),
                th.zeros(single_hidden_state_shape, device=self.device),
            ),
        )

        hidden_state_buffer_shape = (self.n_steps, lstm.num_layers, self.n_envs, lstm.hidden_size)

        self.rollout_buffer = buffer_cls(
            self.n_steps,
            self.observation_space,
            self.action_space,
            hidden_state_buffer_shape,
            self.device,
            gamma=self.gamma,
            gae_lambda=self.gae_lambda,
            n_envs=self.n_envs,
        )

        # Initialize schedules for policy/value clipping
        self.clip_range = get_schedule_fn(self.clip_range)
        if self.clip_range_vf is not None:
            if isinstance(self.clip_range_vf, (float, int)):
                assert self.clip_range_vf > 0, "`clip_range_vf` must be positive, pass `None` to deactivate vf clipping"

            self.clip_range_vf = get_schedule_fn(self.clip_range_vf)

    def collect_rollouts(
        self,
        env: VecEnv,
        callback: BaseCallback,
        rollout_buffer: RolloutBuffer,
        n_rollout_steps: int,
        use_masking: bool = True,
    ) -> bool:
        """
        Collect experiences using the current policy and fill a ``RolloutBuffer``.
        The term rollout here refers to the model-free notion and should not
        be used with the concept of rollout used in model-based RL or planning.

        :param env: The training environment
        :param callback: Callback that will be called at each step
            (and at the beginning and end of the rollout)
        :param rollout_buffer: Buffer to fill with rollouts
        :param n_rollout_steps: Number of experiences to collect per environment
        :param use_masking: Whether to query the environment for action masks
            and apply them when sampling actions
        :return: True if function returned with at least `n_rollout_steps`
            collected, False if callback terminated rollout prematurely.
        :raises ValueError: if ``use_masking`` is True but the environment
            does not support action masking.
        """
        assert isinstance(
            rollout_buffer, (MaskableRecurrentRolloutBuffer, MaskableRecurrentDictRolloutBuffer)
        ), f"{rollout_buffer} doesn't support maskable recurrent policy"

        assert self._last_obs is not None, "No previous observation was provided"

        # Fail early, before any state is modified, with an actionable message
        if use_masking and not is_masking_supported(env):
            raise ValueError("Environment does not support action masking. Consider using ActionMasker wrapper")

        # Switch to eval mode (this affects batch norm / dropout)
        self.policy.set_training_mode(False)

        # NOTE: local variable names in this method are exposed to callbacks
        # through ``callback.update_locals(locals())``; do not rename them.
        n_steps = 0
        action_masks = None
        rollout_buffer.reset()
        # Sample new weights for the state dependent exploration
        if self.use_sde:
            self.policy.reset_noise(env.num_envs)

        callback.on_rollout_start()

        lstm_states = deepcopy(self._last_lstm_states)

        while n_steps < n_rollout_steps:
            if self.use_sde and self.sde_sample_freq > 0 and n_steps % self.sde_sample_freq == 0:
                # Sample a new noise matrix
                self.policy.reset_noise(env.num_envs)

            with th.no_grad():
                # Convert to pytorch tensor or to TensorDict
                obs_tensor = obs_as_tensor(self._last_obs, self.device)
                episode_starts = th.tensor(self._last_episode_starts, dtype=th.float32, device=self.device)
                # Fetch the masks valid for the current observation
                if use_masking:
                    action_masks = get_action_masks(env)
                actions, values, log_probs, lstm_states = self.policy.forward(
                    obs_tensor, lstm_states, episode_starts, action_masks=action_masks
                )

            actions = actions.cpu().numpy()

            # Rescale and perform action
            clipped_actions = actions
            # Clip the actions to avoid out of bound error
            if isinstance(self.action_space, spaces.Box):
                clipped_actions = np.clip(actions, self.action_space.low, self.action_space.high)

            new_obs, rewards, dones, infos = env.step(clipped_actions)

            self.num_timesteps += env.num_envs

            # Give access to local variables
            callback.update_locals(locals())
            if not callback.on_step():
                return False

            self._update_info_buffer(infos, dones)
            n_steps += 1

            if isinstance(self.action_space, spaces.Discrete):
                # Reshape in case of discrete action
                actions = actions.reshape(-1, 1)

            # Handle timeout by bootstrapping with value function
            # see GitHub issue #633
            for idx, done_ in enumerate(dones):
                if (
                    done_
                    and infos[idx].get("terminal_observation") is not None
                    and infos[idx].get("TimeLimit.truncated", False)
                ):
                    terminal_obs = self.policy.obs_to_tensor(infos[idx]["terminal_observation"])[0]
                    with th.no_grad():
                        # Critic LSTM state of this single environment
                        terminal_lstm_state = (
                            lstm_states.vf[0][:, idx : idx + 1, :].contiguous(),
                            lstm_states.vf[1][:, idx : idx + 1, :].contiguous(),
                        )
                        episode_starts = th.tensor([False], dtype=th.float32, device=self.device)
                        terminal_value = self.policy.predict_values(terminal_obs, terminal_lstm_state, episode_starts)[0]
                    rewards[idx] += self.gamma * terminal_value

            # Store the LSTM states that were fed to the policy for this step
            # (``self._last_lstm_states``), not the ones it returned.
            rollout_buffer.add(
                self._last_obs,
                actions,
                rewards,
                self._last_episode_starts,
                values,
                log_probs,
                lstm_states=self._last_lstm_states,
                action_masks=action_masks,
            )

            self._last_obs = new_obs
            self._last_episode_starts = dones
            self._last_lstm_states = lstm_states

        with th.no_grad():
            # Compute value for the last timestep
            episode_starts = th.tensor(dones, dtype=th.float32, device=self.device)
            values = self.policy.predict_values(obs_as_tensor(new_obs, self.device), lstm_states.vf, episode_starts)

        rollout_buffer.compute_returns_and_advantage(last_values=values, dones=dones)

        callback.on_rollout_end()

        return True

    def predict(
        self,
        observation: np.ndarray,
        state: Optional[tuple[np.ndarray, ...]] = None,
        episode_start: Optional[np.ndarray] = None,
        deterministic: bool = False,
        action_masks: Optional[np.ndarray] = None,
    ) -> tuple[np.ndarray, Optional[tuple[np.ndarray, ...]]]:
        """
        Get the policy action from an observation (and optional hidden state).
        Delegates to ``self.policy.predict``.

        :param observation: the input observation
        :param state: The last hidden states (can be None, used in recurrent policies)
        :param episode_start: The last masks (can be None, used in recurrent policies)
            this correspond to beginning of episodes,
            where the hidden states of the RNN must be reset.
        :param deterministic: Whether or not to return deterministic actions.
        :param action_masks: Action masks forwarded to the policy (can be None)
        :return: the model's action and the next hidden state
            (used in recurrent policies)
        """
        return self.policy.predict(observation, state, episode_start, deterministic, action_masks=action_masks)

    def train(self) -> None:
        """
        Update policy using the currently gathered rollout buffer.
        """
        # Switch to train mode (this affects batch norm / dropout)
        self.policy.set_training_mode(True)
        # Update optimizer learning rate
        self._update_learning_rate(self.policy.optimizer)
        # Compute current clip range
        clip_range = self.clip_range(self._current_progress_remaining)
        # Optional: clip range for the value function
        if self.clip_range_vf is not None:
            clip_range_vf = self.clip_range_vf(self._current_progress_remaining)

        entropy_losses = []
        pg_losses, value_losses = [], []
        clip_fractions = []

        continue_training = True

        # train for n_epochs epochs
        for epoch in range(self.n_epochs):
            approx_kl_divs = []
            # Do a complete pass on the rollout buffer
            for rollout_data in self.rollout_buffer.get(self.batch_size):
                actions = rollout_data.actions
                if isinstance(self.action_space, spaces.Discrete):
                    # Convert discrete action from float to long
                    actions = rollout_data.actions.long().flatten()

                # Convert mask from float to bool
                mask = rollout_data.mask > 1e-8

                # Re-evaluate the stored actions under the same action masks
                # that were active when they were sampled
                values, log_prob, entropy = self.policy.evaluate_actions(
                    rollout_data.observations,
                    actions,
                    rollout_data.lstm_states,
                    rollout_data.episode_starts,
                    action_masks=rollout_data.action_masks,
                )

                values = values.flatten()
                # Normalize advantage
                advantages = rollout_data.advantages
                if self.normalize_advantage:
                    advantages = (advantages - advantages[mask].mean()) / (advantages[mask].std() + 1e-8)

                # ratio between old and new policy, should be one at the first iteration
                ratio = th.exp(log_prob - rollout_data.old_log_prob)

                # clipped surrogate loss
                policy_loss_1 = advantages * ratio
                policy_loss_2 = advantages * th.clamp(ratio, 1 - clip_range, 1 + clip_range)
                policy_loss = -th.mean(th.min(policy_loss_1, policy_loss_2)[mask])

                # Logging
                pg_losses.append(policy_loss.item())
                clip_fraction = th.mean((th.abs(ratio - 1) > clip_range).float()[mask]).item()
                clip_fractions.append(clip_fraction)

                if self.clip_range_vf is None:
                    # No clipping
                    values_pred = values
                else:
                    # Clip the difference between old and new value
                    # NOTE: this depends on the reward scaling
                    values_pred = rollout_data.old_values + th.clamp(
                        values - rollout_data.old_values, -clip_range_vf, clip_range_vf
                    )
                # Value loss using the TD(gae_lambda) target
                # Mask padded sequences
                value_loss = th.mean(((rollout_data.returns - values_pred) ** 2)[mask])

                value_losses.append(value_loss.item())

                # Entropy loss favor exploration
                if entropy is None:
                    # Approximate entropy when no analytical form
                    entropy_loss = -th.mean(-log_prob[mask])
                else:
                    entropy_loss = -th.mean(entropy[mask])

                entropy_losses.append(entropy_loss.item())

                loss = policy_loss + self.ent_coef * entropy_loss + self.vf_coef * value_loss

                # Calculate approximate form of reverse KL Divergence for early stopping
                # see issue #417: https://github.com/DLR-RM/stable-baselines3/issues/417
                # and discussion in PR #419: https://github.com/DLR-RM/stable-baselines3/pull/419
                # and Schulman blog: http://joschu.net/blog/kl-approx.html
                with th.no_grad():
                    log_ratio = log_prob - rollout_data.old_log_prob
                    approx_kl_div = th.mean(((th.exp(log_ratio) - 1) - log_ratio)[mask]).cpu().numpy()
                    approx_kl_divs.append(approx_kl_div)

                if self.target_kl is not None and approx_kl_div > 1.5 * self.target_kl:
                    continue_training = False
                    if self.verbose >= 1:
                        print(f"Early stopping at step {epoch} due to reaching max kl: {approx_kl_div:.2f}")
                    break

                # Optimization step
                self.policy.optimizer.zero_grad()
                loss.backward()
                # Clip grad norm
                th.nn.utils.clip_grad_norm_(self.policy.parameters(), self.max_grad_norm)
                self.policy.optimizer.step()

            if not continue_training:
                break

        self._n_updates += self.n_epochs
        explained_var = explained_variance(self.rollout_buffer.values.flatten(), self.rollout_buffer.returns.flatten())

        # Logs
        self.logger.record("train/entropy_loss", np.mean(entropy_losses))
        self.logger.record("train/policy_gradient_loss", np.mean(pg_losses))
        self.logger.record("train/value_loss", np.mean(value_losses))
        self.logger.record("train/approx_kl", np.mean(approx_kl_divs))
        self.logger.record("train/clip_fraction", np.mean(clip_fractions))
        self.logger.record("train/loss", loss.item())
        self.logger.record("train/explained_variance", explained_var)
        if hasattr(self.policy, "log_std"):
            self.logger.record("train/std", th.exp(self.policy.log_std).mean().item())

        self.logger.record("train/n_updates", self._n_updates, exclude="tensorboard")
        self.logger.record("train/clip_range", clip_range)
        if self.clip_range_vf is not None:
            self.logger.record("train/clip_range_vf", clip_range_vf)

    def learn(
        self: SelfMaskableRecurrentPPO,
        total_timesteps: int,
        callback: MaybeCallback = None,
        log_interval: int = 1,
        tb_log_name: str = "MaskableRecurrentPPO",
        reset_num_timesteps: bool = True,
        use_masking: bool = True,
        progress_bar: bool = False,
    ) -> SelfMaskableRecurrentPPO:
        """
        Alternate rollout collection and policy updates until ``total_timesteps`` is reached.

        :param total_timesteps: Number of environment steps to train for
        :param callback: Callback(s) passed to ``_setup_learn``
        :param log_interval: Dump the logs every ``log_interval`` iterations
            (``None`` disables the periodic dump)
        :param tb_log_name: Name of the run, passed to ``_setup_learn``
        :param reset_num_timesteps: Passed to ``_setup_learn``
        :param use_masking: Whether to apply action masks while collecting rollouts
        :param progress_bar: Passed to ``_setup_learn``
        :return: the trained model (``self``)
        """
        # NOTE: local names here are handed to callbacks via ``locals()`` below.
        iteration = 0

        total_timesteps, callback = self._setup_learn(
            total_timesteps,
            callback,
            reset_num_timesteps,
            tb_log_name,
            progress_bar,
        )

        callback.on_training_start(locals(), globals())

        assert self.env is not None

        while self.num_timesteps < total_timesteps:
            continue_training = self.collect_rollouts(self.env, callback, self.rollout_buffer, self.n_steps, use_masking)

            # The callback asked to stop: leave without another update
            if not continue_training:
                break

            iteration += 1
            self._update_current_progress_remaining(self.num_timesteps, total_timesteps)

            # Display training infos
            if log_interval is not None and iteration % log_interval == 0:
                self.dump_logs(iteration)

            self.train()

        callback.on_training_end()

        return self

    def _excluded_save_params(self) -> list[str]:
        """Return the parent's excluded save parameters plus ``_last_lstm_states``."""
        return super()._excluded_save_params() + ["_last_lstm_states"]  # noqa: RUF005