99import warnings
1010from copy import deepcopy
1111from functools import partial , wraps
12- from typing import Any , Callable , Dict , Iterator , List , Optional , Tuple
12+ from typing import Any , Callable , Iterator
1313
1414import numpy as np
1515import torch
@@ -457,7 +457,7 @@ def __init__(
457457 self ,
458458 * ,
459459 device : DEVICE_TYPING = None ,
460- batch_size : Optional [ torch .Size ] = None ,
460+ batch_size : torch .Size | None = None ,
461461 run_type_checks : bool = False ,
462462 allow_done_after_reset : bool = False ,
463463 spec_locked : bool = True ,
@@ -568,10 +568,10 @@ def auto_specs_(
568568 policy : Callable [[TensorDictBase ], TensorDictBase ],
569569 * ,
570570 tensordict : TensorDictBase | None = None ,
571- action_key : NestedKey | List [NestedKey ] = "action" ,
572- done_key : NestedKey | List [NestedKey ] | None = None ,
573- observation_key : NestedKey | List [NestedKey ] = "observation" ,
574- reward_key : NestedKey | List [NestedKey ] = "reward" ,
571+ action_key : NestedKey | list [NestedKey ] = "action" ,
572+ done_key : NestedKey | list [NestedKey ] | None = None ,
573+ observation_key : NestedKey | list [NestedKey ] = "observation" ,
574+ reward_key : NestedKey | list [NestedKey ] = "reward" ,
575575 ):
576576 """Automatically sets the specifications (specs) of the environment based on a random rollout using a given policy.
577577
@@ -673,7 +673,7 @@ def auto_specs_(
673673 if full_action_spec is not None :
674674 self .full_action_spec = full_action_spec
675675 if full_done_spec is not None :
676- self .full_done_specs = full_done_spec
676+ self .full_done_spec = full_done_spec
677677 if full_observation_spec is not None :
678678 self .full_observation_spec = full_observation_spec
679679 if full_reward_spec is not None :
@@ -685,8 +685,7 @@ def auto_specs_(
685685
@wraps(check_env_specs_func)
def check_env_specs(self, *args, **kwargs):
    # Thin method wrapper around ``check_env_specs_func``; ``wraps`` copies
    # the wrapped function's metadata onto this method.
    #
    # The default is computed up-front (even when the caller supplies a
    # value) so that ``self._has_dynamic_specs`` is always read.
    default_contiguous = not self._has_dynamic_specs
    # An explicit ``return_contiguous`` passed by the caller always wins.
    if "return_contiguous" not in kwargs:
        kwargs["return_contiguous"] = default_contiguous
    return check_env_specs_func(self, *args, **kwargs)
691690
692691 check_env_specs .__doc__ = check_env_specs_func .__doc__
@@ -831,8 +830,7 @@ def ndim(self):
831830
832831 def append_transform (
833832 self ,
834- transform : "Transform" # noqa: F821
835- | Callable [[TensorDictBase ], TensorDictBase ],
833+ transform : Transform | Callable [[TensorDictBase ], TensorDictBase ], # noqa: F821
836834 ) -> EnvBase :
837835 """Returns a transformed environment where the callable/transform passed is applied.
838836
@@ -976,7 +974,7 @@ def output_spec(self, value: TensorSpec) -> None:
976974
977975 @property
978976 @_cache_value
979- def action_keys (self ) -> List [NestedKey ]:
977+ def action_keys (self ) -> list [NestedKey ]:
980978 """The action keys of an environment.
981979
982980 By default, there will only be one key named "action".
@@ -989,7 +987,7 @@ def action_keys(self) -> List[NestedKey]:
989987
990988 @property
991989 @_cache_value
992- def state_keys (self ) -> List [NestedKey ]:
990+ def state_keys (self ) -> list [NestedKey ]:
993991 """The state keys of an environment.
994992
995993 By default, there will only be one key named "state".
@@ -1186,7 +1184,7 @@ def full_action_spec(self, spec: Composite) -> None:
11861184 # Reward spec
11871185 @property
11881186 @_cache_value
1189- def reward_keys (self ) -> List [NestedKey ]:
1187+ def reward_keys (self ) -> list [NestedKey ]:
11901188 """The reward keys of an environment.
11911189
11921190 By default, there will only be one key named "reward".
@@ -1196,6 +1194,20 @@ def reward_keys(self) -> List[NestedKey]:
11961194 reward_keys = sorted (self .full_reward_spec .keys (True , True ), key = _repr_by_depth )
11971195 return reward_keys
11981196
@property
@_cache_value
def observation_keys(self) -> list[NestedKey]:
    """The observation keys of an environment.

    By default, there will only be one key named "observation".

    The keys are read from ``full_observation_spec.keys(True, True)`` and
    ordered with ``_repr_by_depth`` as the sort key, i.e. sorted by depth
    in the data tree.
    """
    return sorted(self.full_observation_spec.keys(True, True), key=_repr_by_depth)
1210+
11991211 @property
12001212 def reward_key (self ):
12011213 """The reward key of an environment.
@@ -1383,7 +1395,7 @@ def full_reward_spec(self, spec: Composite) -> None:
13831395 # done spec
13841396 @property
13851397 @_cache_value
1386- def done_keys (self ) -> List [NestedKey ]:
1398+ def done_keys (self ) -> list [NestedKey ]:
13871399 """The done keys of an environment.
13881400
13891401 By default, there will only be one key named "done".
@@ -2113,8 +2125,8 @@ def register_gym(
21132125 id : str ,
21142126 * ,
21152127 entry_point : Callable | None = None ,
2116- transform : " Transform" | None = None , # noqa: F821
2117- info_keys : List [NestedKey ] | None = None ,
2128+ transform : Transform | None = None , # noqa: F821
2129+ info_keys : list [NestedKey ] | None = None ,
21182130 backend : str = None ,
21192131 to_numpy : bool = False ,
21202132 reward_threshold : float | None = None ,
@@ -2303,8 +2315,8 @@ def _register_gym(
23032315 cls ,
23042316 id ,
23052317 entry_point : Callable | None = None ,
2306- transform : " Transform" | None = None , # noqa: F821
2307- info_keys : List [NestedKey ] | None = None ,
2318+ transform : Transform | None = None , # noqa: F821
2319+ info_keys : list [NestedKey ] | None = None ,
23082320 to_numpy : bool = False ,
23092321 reward_threshold : float | None = None ,
23102322 nondeterministic : bool = False ,
@@ -2345,8 +2357,8 @@ def _register_gym( # noqa: F811
23452357 cls ,
23462358 id ,
23472359 entry_point : Callable | None = None ,
2348- transform : " Transform" | None = None , # noqa: F821
2349- info_keys : List [NestedKey ] | None = None ,
2360+ transform : Transform | None = None , # noqa: F821
2361+ info_keys : list [NestedKey ] | None = None ,
23502362 to_numpy : bool = False ,
23512363 reward_threshold : float | None = None ,
23522364 nondeterministic : bool = False ,
@@ -2393,8 +2405,8 @@ def _register_gym( # noqa: F811
23932405 cls ,
23942406 id ,
23952407 entry_point : Callable | None = None ,
2396- transform : " Transform" | None = None , # noqa: F821
2397- info_keys : List [NestedKey ] | None = None ,
2408+ transform : Transform | None = None , # noqa: F821
2409+ info_keys : list [NestedKey ] | None = None ,
23982410 to_numpy : bool = False ,
23992411 reward_threshold : float | None = None ,
24002412 nondeterministic : bool = False ,
@@ -2446,8 +2458,8 @@ def _register_gym( # noqa: F811
24462458 cls ,
24472459 id ,
24482460 entry_point : Callable | None = None ,
2449- transform : " Transform" | None = None , # noqa: F821
2450- info_keys : List [NestedKey ] | None = None ,
2461+ transform : Transform | None = None , # noqa: F821
2462+ info_keys : list [NestedKey ] | None = None ,
24512463 to_numpy : bool = False ,
24522464 reward_threshold : float | None = None ,
24532465 nondeterministic : bool = False ,
@@ -2502,8 +2514,8 @@ def _register_gym( # noqa: F811
25022514 cls ,
25032515 id ,
25042516 entry_point : Callable | None = None ,
2505- transform : " Transform" | None = None , # noqa: F821
2506- info_keys : List [NestedKey ] | None = None ,
2517+ transform : Transform | None = None , # noqa: F821
2518+ info_keys : list [NestedKey ] | None = None ,
25072519 to_numpy : bool = False ,
25082520 reward_threshold : float | None = None ,
25092521 nondeterministic : bool = False ,
@@ -2560,8 +2572,8 @@ def _register_gym( # noqa: F811
25602572 cls ,
25612573 id ,
25622574 entry_point : Callable | None = None ,
2563- transform : " Transform" | None = None , # noqa: F821
2564- info_keys : List [NestedKey ] | None = None ,
2575+ transform : Transform | None = None , # noqa: F821
2576+ info_keys : list [NestedKey ] | None = None ,
25652577 to_numpy : bool = False ,
25662578 reward_threshold : float | None = None ,
25672579 nondeterministic : bool = False ,
@@ -2618,7 +2630,7 @@ def _reset(self, tensordict: TensorDictBase, **kwargs) -> TensorDictBase:
26182630
26192631 def reset (
26202632 self ,
2621- tensordict : Optional [ TensorDictBase ] = None ,
2633+ tensordict : TensorDictBase | None = None ,
26222634 ** kwargs ,
26232635 ) -> TensorDictBase :
26242636 """Resets the environment.
@@ -2727,8 +2739,8 @@ def numel(self) -> int:
27272739 return prod (self .batch_size )
27282740
27292741 def set_seed (
2730- self , seed : Optional [ int ] = None , static_seed : bool = False
2731- ) -> Optional [ int ] :
2742+ self , seed : int | None = None , static_seed : bool = False
2743+ ) -> int | None :
27322744 """Sets the seed of the environment and returns the next seed to be used (which is the input seed if a single environment is present).
27332745
27342746 Args:
@@ -2749,7 +2761,7 @@ def set_seed(
27492761 return seed
27502762
@abc.abstractmethod
def _set_seed(self, seed: int | None):
    """Abstract seeding hook that concrete environments must implement.

    Args:
        seed (int, optional): the seed value; ``None`` is accepted by the
            signature.

    Raises:
        NotImplementedError: always, in this base implementation.
    """
    # NOTE(review): presumably invoked by the public ``set_seed`` method;
    # the call site is not visible in this view -- confirm.
    raise NotImplementedError
27542766
27552767 def set_state (self ):
@@ -2764,7 +2776,26 @@ def _assert_tensordict_shape(self, tensordict: TensorDictBase) -> None:
27642776 f"got { tensordict .batch_size } and { self .batch_size } "
27652777 )
27662778
2767- def rand_action (self , tensordict : Optional [TensorDictBase ] = None ):
def all_actions(self, tensordict: TensorDictBase | None = None) -> TensorDictBase:
    """Enumerates every action allowed by the action spec.

    This only works in environments with fully discrete actions.

    Args:
        tensordict (TensorDictBase, optional): when provided, the
            environment is first reset through :meth:`~.reset` using this
            tensordict. When ``None``, no reset is performed.

    Returns:
        the result of ``full_action_spec.enumerate(use_mask=True)``: a
        tensordict object with the "action" entry updated with a batch of
        all possible actions, stacked together in the leading dimension.
    """
    reset_requested = tensordict is not None
    if reset_requested:
        self.reset(tensordict)

    # The spec is read only after the optional reset has run.
    action_spec = self.full_action_spec
    return action_spec.enumerate(use_mask=True)
2797+
2798+ def rand_action (self , tensordict : TensorDictBase | None = None ):
27682799 """Performs a random action given the action_spec attribute.
27692800
27702801 Args:
@@ -2798,7 +2829,7 @@ def rand_action(self, tensordict: Optional[TensorDictBase] = None):
27982829 tensordict .update (r )
27992830 return tensordict
28002831
2801- def rand_step (self , tensordict : Optional [ TensorDictBase ] = None ) -> TensorDictBase :
2832+ def rand_step (self , tensordict : TensorDictBase | None = None ) -> TensorDictBase :
28022833 """Performs a random step in the environment given the action_spec attribute.
28032834
28042835 Args:
@@ -2834,15 +2865,15 @@ def _has_dynamic_specs(self) -> bool:
28342865 def rollout (
28352866 self ,
28362867 max_steps : int ,
2837- policy : Optional [ Callable [[TensorDictBase ], TensorDictBase ]] = None ,
2838- callback : Optional [ Callable [[TensorDictBase , ...], Any ]] = None ,
2868+ policy : Callable [[TensorDictBase ], TensorDictBase ] | None = None ,
2869+ callback : Callable [[TensorDictBase , ...], Any ] | None = None ,
28392870 * ,
28402871 auto_reset : bool = True ,
28412872 auto_cast_to_device : bool = False ,
28422873 break_when_any_done : bool | None = None ,
28432874 break_when_all_done : bool | None = None ,
28442875 return_contiguous : bool | None = False ,
2845- tensordict : Optional [ TensorDictBase ] = None ,
2876+ tensordict : TensorDictBase | None = None ,
28462877 set_truncated : bool = False ,
28472878 out = None ,
28482879 trust_policy : bool = False ,
@@ -3364,7 +3395,7 @@ def _rollout_nonstop(
33643395
33653396 def step_and_maybe_reset (
33663397 self , tensordict : TensorDictBase
3367- ) -> Tuple [TensorDictBase , TensorDictBase ]:
3398+ ) -> tuple [TensorDictBase , TensorDictBase ]:
33683399 """Runs a step in the environment and (partially) resets it if needed.
33693400
33703401 Args:
@@ -3472,7 +3503,7 @@ def empty_cache(self):
34723503
34733504 @property
34743505 @_cache_value
3475- def reset_keys (self ) -> List [NestedKey ]:
3506+ def reset_keys (self ) -> list [NestedKey ]:
34763507 """Returns a list of reset keys.
34773508
34783509 Reset keys are keys that indicate partial reset, in batched, multitask or multiagent
@@ -3629,14 +3660,14 @@ class _EnvWrapper(EnvBase):
36293660 """
36303661
36313662 git_url : str = ""
3632- available_envs : Dict [str , Any ] = {}
3663+ available_envs : dict [str , Any ] = {}
36333664 libname : str = ""
36343665
36353666 def __init__ (
36363667 self ,
36373668 * args ,
36383669 device : DEVICE_TYPING = None ,
3639- batch_size : Optional [ torch .Size ] = None ,
3670+ batch_size : torch .Size | None = None ,
36403671 allow_done_after_reset : bool = False ,
36413672 spec_locked : bool = True ,
36423673 ** kwargs ,
@@ -3685,7 +3716,7 @@ def _sync_device(self):
36853716 return sync_func
36863717
36873718 @abc .abstractmethod
3688- def _check_kwargs (self , kwargs : Dict ):
3719+ def _check_kwargs (self , kwargs : dict ):
36893720 raise NotImplementedError
36903721
36913722 def __getattr__ (self , attr : str ) -> Any :
@@ -3711,7 +3742,7 @@ def __getattr__(self, attr: str) -> Any:
37113742 )
37123743
37133744 @abc .abstractmethod
3714- def _init_env (self ) -> Optional [ int ] :
3745+ def _init_env (self ) -> int | None :
37153746 """Runs all the necessary steps such that the environment is ready to use.
37163747
37173748 This step is intended to ensure that a seed is provided to the environment (if needed) and that the environment
@@ -3725,7 +3756,7 @@ def _init_env(self) -> Optional[int]:
37253756 raise NotImplementedError
37263757
37273758 @abc .abstractmethod
3728- def _build_env (self , ** kwargs ) -> " gym.Env" : # noqa: F821
3759+ def _build_env (self , ** kwargs ) -> gym .Env : # noqa: F821
37293760 """Creates an environment from the target library and stores it with the `_env` attribute.
37303761
37313762 When overwritten, this function should pass all the required kwargs to the env instantiation method.
@@ -3734,7 +3765,7 @@ def _build_env(self, **kwargs) -> "gym.Env": # noqa: F821
37343765 raise NotImplementedError
37353766
37363767 @abc .abstractmethod
3737- def _make_specs (self , env : " gym.Env" ) -> None : # noqa: F821
3768+ def _make_specs (self , env : gym .Env ) -> None : # noqa: F821
37383769 raise NotImplementedError
37393770
37403771 def close (self ) -> None :
@@ -3748,7 +3779,7 @@ def close(self) -> None:
37483779
37493780def make_tensordict (
37503781 env : _EnvWrapper ,
3751- policy : Optional [ Callable [[TensorDictBase , ...], TensorDictBase ]] = None ,
3782+ policy : Callable [[TensorDictBase , ...], TensorDictBase ] | None = None ,
37523783) -> TensorDictBase :
37533784 """Returns a zeroed-tensordict with fields matching those required for a full step (action selection and environment step) in the environment.
37543785
0 commit comments