Description
Motivation
When using MultiSyncDataCollector with split_trajs=True, it is surprisingly difficult to obtain a batch of N complete trajectories.
Many RL algorithms (Monte Carlo / REINFORCE, episodic PPO rollouts, imitation learning, some RLHF setups) require full episodes padded into a single batch.
Right now users need to manually:
- track traj_ids,
- rebuild trajectories across collector iterations,
- detect completion via mask/done,
- pad variable-length episodes,
- optionally discard “mixed-policy” episodes when the policy is updated.
This logic is both error-prone and reimplemented repeatedly in user code.
There is currently no high-level utility in TorchRL that provides “give me exactly N full episodes as a padded TensorDict”, even though the underlying collectors already expose all the needed information.
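For concreteness, here is a rough sketch of the glue code this currently requires (the key layout, i.e. ("collector", "traj_ids"), ("collector", "mask") and ("next", "done"), is my understanding of what split_trajs=True produces; treat the snippet as illustrative rather than as a reference implementation):

```python
# Illustrative manual re-gluing of split trajectory segments across collector iterations.
from collections import defaultdict

import torch
from tensordict import pad_sequence

segments = defaultdict(list)   # traj_id -> list of (unpadded) segments
episodes = []                  # completed episodes

for batch in collector:        # with split_trajs=True: shape [num_segments, max_steps]
    for seg in batch.unbind(0):
        valid = seg["collector", "mask"].reshape(-1)   # valid-step mask for this segment
        tid = seg["collector", "traj_ids"][valid][0].item()
        segments[tid].append(seg[valid])               # drop padding before stitching
        if seg["next", "done"][valid].any():           # the episode actually ended here
            episodes.append(torch.cat(segments.pop(tid), dim=0))
    if len(episodes) >= 32:
        break

# re-pad the stitched episodes into a single [N, max_len, ...] batch
# (return_mask adds a valid-step mask, per tensordict's pad_sequence)
batch_of_episodes = pad_sequence(episodes[:32], return_mask=True)
```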
Solution
I propose a new utility, tentatively named TrajectoryBatcher, in torchrl.collectors.utils.
It would wrap any TorchRL collector and yield batches of exactly num_trajectories full episodes, reconstructed across collector steps and padded to max_length. Example:
```python
batcher = TrajectoryBatcher(
    collector,
    num_trajectories=32,
    strict_on_policy=True,
)
batch = next(iter(batcher))
batch["observation"].shape  # [32, max_len, obs_dim]
batch["action"].shape       # padded
batch["mask"].shape         # valid-step mask
```

Features:
- reconstruct trajectories split across collector iterations (using traj_ids)
- optional strict on-policy mode (“burn” ongoing trajectories after a policy update)
- padded output using standard TensorDict conventions
- compatible with multi-worker collectors without resetting environments
I already have a working prototype and can contribute the implementation, documentation, and unit tests.
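For discussion, a condensed sketch of the wrapper I have in mind is below (not the final implementation: the collector key names and the use of tensordict.pad_sequence for padding are assumptions that can be swapped for whatever conventions the maintainers prefer):

```python
# Condensed sketch of the proposed utility; the actual prototype handles more edge cases.
from collections import defaultdict

import torch
from tensordict import pad_sequence


class TrajectoryBatcher:
    def __init__(self, collector, num_trajectories, strict_on_policy=False):
        # assumes the wrapped collector yields raw (unsplit) batches, i.e. split_trajs=False
        self.collector = collector
        self.num_trajectories = num_trajectories
        self.strict_on_policy = strict_on_policy
        self._ongoing = defaultdict(list)   # traj_id -> list of step tensordicts
        self._complete = []                 # finished, unpadded episodes

    def mark_policy_update(self):
        # in strict on-policy mode, "burn" episodes started under the previous policy
        if self.strict_on_policy:
            self._ongoing.clear()

    def _ingest(self, rollout):
        # flatten (workers, time) into single steps and route them by trajectory id
        for step in rollout.reshape(-1).unbind(0):
            tid = step["collector", "traj_ids"].item()
            self._ongoing[tid].append(step)
            if step["next", "done"].any():
                self._complete.append(torch.stack(self._ongoing.pop(tid), dim=0))

    def __iter__(self):
        for rollout in self.collector:
            self._ingest(rollout)
            while len(self._complete) >= self.num_trajectories:
                episodes = self._complete[: self.num_trajectories]
                self._complete = self._complete[self.num_trajectories :]
                # pad to the longest episode and expose a valid-step mask
                yield pad_sequence(episodes, return_mask=True)
```

Iteration is lazy: episodes are accumulated across collector iterations and only yielded once num_trajectories of them have completed, so no environment ever needs to be reset early.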
Alternatives
- Users manually stitch trajectories in Python (complex and repetitive).
- Resetting environments at every update is possible but costly and problematic with multi-process collectors.
- Using frames_per_batch as a proxy for episodes is unreliable and does not guarantee complete trajectories.
These alternatives are not ideal and introduce unnecessary boilerplate for common RL workflows.
Additional context
This feature would make it much easier to train Monte Carlo / episodic algorithms on top of TorchRL collectors, and aligns with patterns commonly found in other RL libraries (e.g., SB3 rollouts, RLlib sample batches, Acme episodes).
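As a small usage illustration (assuming the padded batch from the example above, a mask of shape [N, T], and the usual ("next", "reward") key), masked Monte Carlo returns then reduce to a few lines:

```python
# Illustrative: discounted returns over a padded [N, T, ...] batch of full episodes.
import torch

gamma = 0.99
# "mask" assumed to have shape [N, T]; padded steps contribute zero reward
reward = batch["next", "reward"].squeeze(-1) * batch["mask"]
returns = torch.zeros_like(reward)
running = torch.zeros_like(reward[:, 0])
for t in reversed(range(reward.shape[1])):
    running = reward[:, t] + gamma * running
    returns[:, t] = running
```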
Checklist
- I have checked that there is no similar issue in the repo (required)