From e536f79df5dc53a0d0f481f2082e921bbe35ab40 Mon Sep 17 00:00:00 2001 From: Sarthak Dayal Date: Sun, 20 Jul 2025 01:42:36 -0500 Subject: [PATCH 1/3] Refactor maskable categoricals and improve logit updates --- sb3_contrib/common/maskable/distributions.py | 97 ++++++++++++-------- 1 file changed, 59 insertions(+), 38 deletions(-) diff --git a/sb3_contrib/common/maskable/distributions.py b/sb3_contrib/common/maskable/distributions.py index d2a92ae0..3d0e9aea 100644 --- a/sb3_contrib/common/maskable/distributions.py +++ b/sb3_contrib/common/maskable/distributions.py @@ -1,5 +1,5 @@ from abc import ABC, abstractmethod -from typing import Optional, TypeVar, Union +from typing import Optional, TypeVar, Union, List, Tuple import numpy as np import torch as th @@ -7,7 +7,7 @@ from stable_baselines3.common.distributions import Distribution from torch import nn from torch.distributions import Categorical -from torch.distributions.utils import logits_to_probs +from torch.distributions.utils import probs_to_logits SelfMaskableCategoricalDistribution = TypeVar("SelfMaskableCategoricalDistribution", bound="MaskableCategoricalDistribution") SelfMaskableMultiCategoricalDistribution = TypeVar( @@ -16,6 +16,29 @@ MaybeMasks = Union[th.Tensor, np.ndarray, None] +def _mask_logits(logits: th.Tensor, mask: MaybeMasks, neg_inf: float) -> th.Tensor: + """ + Eliminate chosen categorical outcomes by setting their logits to `neg_inf`. + + :param logits: A tensor of unnormalized log probabilities (logits) for the categorical distribution. + The shape should be compatible with the mask. + + :param mask: An optional boolean ndarray of compatible shape with the distribution. + If True, the corresponding choice's logit value is preserved. If False, it is set + to a large negative value, resulting in near 0 probability. If mask is None, any + previously applied masking is removed, and the original logits are restored. 
+ + :param neg_inf: The value to use for masked logits, typically negative infinity + to ensure the masked actions have zero (or near-zero) probability when passed + through a softmax or categorical distribution. + """ + + if mask is None: + return logits + mask_t = th.as_tensor(mask, dtype=th.bool, device=logits.device).reshape(logits.shape) + return th.where(mask_t, logits, th.tensor(neg_inf, dtype=logits.dtype, device=logits.device)) + + class MaskableCategorical(Categorical): """ Modified PyTorch Categorical distribution with support for invalid action masking. @@ -39,49 +62,47 @@ def __init__( validate_args: Optional[bool] = None, masks: MaybeMasks = None, ): - self.masks: Optional[th.Tensor] = None - super().__init__(probs, logits, validate_args) - self._original_logits = self.logits - self.apply_masking(masks) + # Validate that exactly one of probs or logits is provided + if (probs is None) == (logits is None): + raise ValueError("Specify exactly one of probs or logits but not both.") + + # If probs provided, convert it to logits + if probs is not None: + logits = probs_to_logits(probs) + + # Save pristine logits for later masking + self._original_logits = logits.detach().clone() + self._neg_inf = float("-inf") + self.masks = None if masks is None else th.as_tensor(masks, dtype=th.bool, device=logits.device).reshape( + logits.shape + ) + masked_logits = _mask_logits(logits, self.masks, self._neg_inf) + super().__init__(logits=masked_logits, validate_args=validate_args) def apply_masking(self, masks: MaybeMasks) -> None: - """ - Eliminate ("mask out") chosen categorical outcomes by setting their probability to 0. - - :param masks: An optional boolean ndarray of compatible shape with the distribution. - If True, the corresponding choice's logit value is preserved. If False, it is set - to a large negative value, resulting in near 0 probability. If masks is None, any - previously applied masking is removed, and the original logits are restored. 
- """ - - if masks is not None: - device = self.logits.device - self.masks = th.as_tensor(masks, dtype=th.bool, device=device).reshape(self.logits.shape) - HUGE_NEG = th.tensor(-1e8, dtype=self.logits.dtype, device=device) - - logits = th.where(self.masks, self._original_logits, HUGE_NEG) - else: + if masks is None: self.masks = None logits = self._original_logits - + else: + self.masks = th.as_tensor(masks, dtype=th.bool, device=self._original_logits.device).reshape( + self._original_logits.shape + ) + logits = _mask_logits(self._original_logits, self.masks, self._neg_inf) # Reinitialize with updated logits super().__init__(logits=logits) - # self.probs may already be cached, so we must force an update - self.probs = logits_to_probs(self.logits) - def entropy(self) -> th.Tensor: if self.masks is None: return super().entropy() - # Highly negative logits don't result in 0 probs, so we must replace - # with 0s to ensure 0 contribution to the distribution's entropy, since - # masked actions possess no uncertainty. 
- device = self.logits.device - p_log_p = self.logits * self.probs - p_log_p = th.where(self.masks, p_log_p, th.tensor(0.0, device=device)) - return -p_log_p.sum(-1) - + # Prevent numerical issues with masked logits + min_real = th.finfo(self.logits.dtype).min + logits = self.logits.clone() + mask = (~self.masks) | (~logits.isfinite()) + logits = logits.masked_fill(mask, min_real) + logits = logits - logits.logsumexp(-1, keepdim=True) + probs = logits.exp() + return -(logits * probs).sum(-1) class MaskableDistribution(Distribution, ABC): @abstractmethod @@ -157,7 +178,7 @@ def actions_from_params(self, action_logits: th.Tensor, deterministic: bool = Fa self.proba_distribution(action_logits) return self.get_actions(deterministic=deterministic) - def log_prob_from_params(self, action_logits: th.Tensor) -> tuple[th.Tensor, th.Tensor]: + def log_prob_from_params(self, action_logits: th.Tensor) -> Tuple[th.Tensor, th.Tensor]: actions = self.actions_from_params(action_logits) log_prob = self.log_prob(actions) return actions, log_prob @@ -174,9 +195,9 @@ class MaskableMultiCategoricalDistribution(MaskableDistribution): :param action_dims: List of sizes of discrete action spaces """ - def __init__(self, action_dims: list[int]): + def __init__(self, action_dims: List[int]): super().__init__() - self.distributions: list[MaskableCategorical] = [] + self.distributions: List[MaskableCategorical] = [] self.action_dims = action_dims def proba_distribution_net(self, latent_dim: int) -> nn.Module: @@ -232,7 +253,7 @@ def actions_from_params(self, action_logits: th.Tensor, deterministic: bool = Fa self.proba_distribution(action_logits) return self.get_actions(deterministic=deterministic) - def log_prob_from_params(self, action_logits: th.Tensor) -> tuple[th.Tensor, th.Tensor]: + def log_prob_from_params(self, action_logits: th.Tensor) -> Tuple[th.Tensor, th.Tensor]: actions = self.actions_from_params(action_logits) log_prob = self.log_prob(actions) return actions, log_prob From 
d698af53f936293d1b1ac956ba13ce75b4a5f12b Mon Sep 17 00:00:00 2001 From: Sarthak Dayal Date: Sun, 20 Jul 2025 11:53:47 -0500 Subject: [PATCH 2/3] Update changelog with bug fix --- docs/misc/changelog.rst | 1 + 1 file changed, 1 insertion(+) diff --git a/docs/misc/changelog.rst b/docs/misc/changelog.rst index d6df17e1..eaa90a78 100644 --- a/docs/misc/changelog.rst +++ b/docs/misc/changelog.rst @@ -17,6 +17,7 @@ New Features: Bug Fixes: ^^^^^^^^^^ - Use the ``FloatSchedule`` and ``LinearSchedule`` classes instead of lambdas in the ARS, PPO, and QRDQN implementations to improve model portability across different operating systems +- Fixed a bug in the ``MaskableCategoricalDistribution`` and ``MaskableMultiCategoricalDistribution`` classes where the ``apply_masking`` method was not correctly handling the masks for multi-dimensional action spaces Deprecations: ^^^^^^^^^^^^^ From fc8601f7a96049cc450fe825ae72cbe23d445b34 Mon Sep 17 00:00:00 2001 From: Sarthak Dayal Date: Sun, 20 Jul 2025 12:11:34 -0500 Subject: [PATCH 3/3] Fix formatting and linting issues --- sb3_contrib/common/maskable/distributions.py | 27 ++++++++++---------- 1 file changed, 13 insertions(+), 14 deletions(-) diff --git a/sb3_contrib/common/maskable/distributions.py b/sb3_contrib/common/maskable/distributions.py index 3d0e9aea..3e59ee3a 100644 --- a/sb3_contrib/common/maskable/distributions.py +++ b/sb3_contrib/common/maskable/distributions.py @@ -1,5 +1,5 @@ from abc import ABC, abstractmethod -from typing import Optional, TypeVar, Union, List, Tuple +from typing import Optional, TypeVar, Union import numpy as np import torch as th @@ -19,7 +19,7 @@ def _mask_logits(logits: th.Tensor, mask: MaybeMasks, neg_inf: float) -> th.Tensor: """ Eliminate chosen categorical outcomes by setting their logits to `neg_inf`. - + :param logits: A tensor of unnormalized log probabilities (logits) for the categorical distribution. The shape should be compatible with the mask. 
@@ -27,9 +27,9 @@ def _mask_logits(logits: th.Tensor, mask: MaybeMasks, neg_inf: float) -> th.Tens If True, the corresponding choice's logit value is preserved. If False, it is set to a large negative value, resulting in near 0 probability. If mask is None, any previously applied masking is removed, and the original logits are restored. - + :param neg_inf: The value to use for masked logits, typically negative infinity - to ensure the masked actions have zero (or near-zero) probability when passed + to ensure the masked actions have zero (or near-zero) probability when passed through a softmax or categorical distribution. """ @@ -65,17 +65,15 @@ def __init__( # Validate that exactly one of probs or logits is provided if (probs is None) == (logits is None): raise ValueError("Specify exactly one of probs or logits but not both.") - + # If probs provided, convert it to logits - if probs is not None: + if logits is None: logits = probs_to_logits(probs) - + # Save pristine logits for later masking self._original_logits = logits.detach().clone() self._neg_inf = float("-inf") - self.masks = None if masks is None else th.as_tensor(masks, dtype=th.bool, device=logits.device).reshape( - logits.shape - ) + self.masks = None if masks is None else th.as_tensor(masks, dtype=th.bool, device=logits.device).reshape(logits.shape) masked_logits = _mask_logits(logits, self.masks, self._neg_inf) super().__init__(logits=masked_logits, validate_args=validate_args) @@ -104,6 +102,7 @@ def entropy(self) -> th.Tensor: probs = logits.exp() return -(logits * probs).sum(-1) + class MaskableDistribution(Distribution, ABC): @abstractmethod def apply_masking(self, masks: MaybeMasks) -> None: @@ -178,7 +177,7 @@ def actions_from_params(self, action_logits: th.Tensor, deterministic: bool = Fa self.proba_distribution(action_logits) return self.get_actions(deterministic=deterministic) - def log_prob_from_params(self, action_logits: th.Tensor) -> Tuple[th.Tensor, th.Tensor]: + def 
log_prob_from_params(self, action_logits: th.Tensor) -> tuple[th.Tensor, th.Tensor]: actions = self.actions_from_params(action_logits) log_prob = self.log_prob(actions) return actions, log_prob @@ -195,9 +194,9 @@ class MaskableMultiCategoricalDistribution(MaskableDistribution): :param action_dims: List of sizes of discrete action spaces """ - def __init__(self, action_dims: List[int]): + def __init__(self, action_dims: list[int]): super().__init__() - self.distributions: List[MaskableCategorical] = [] + self.distributions: list[MaskableCategorical] = [] self.action_dims = action_dims def proba_distribution_net(self, latent_dim: int) -> nn.Module: @@ -253,7 +252,7 @@ def actions_from_params(self, action_logits: th.Tensor, deterministic: bool = Fa self.proba_distribution(action_logits) return self.get_actions(deterministic=deterministic) - def log_prob_from_params(self, action_logits: th.Tensor) -> Tuple[th.Tensor, th.Tensor]: + def log_prob_from_params(self, action_logits: th.Tensor) -> tuple[th.Tensor, th.Tensor]: actions = self.actions_from_params(action_logits) log_prob = self.log_prob(actions) return actions, log_prob