[Feature Request] **TrajectoryBatcher** high-level util for complete episode batching #3234

@Kirire

Description

Motivation

When using MultiSyncDataCollector with split_trajs=True, it is surprisingly difficult to obtain a batch of N complete trajectories.
Many RL algorithms (Monte Carlo / REINFORCE, episodic PPO rollouts, imitation learning, some RLHF setups) require full episodes padded into a single batch.

Right now users need to manually:

  • track traj_ids,
  • rebuild trajectories across collector iterations,
  • detect completion via mask / done,
  • pad variable-length episodes,
  • optionally discard “mixed-policy” episodes when the policy is updated.

This logic is error-prone and is reimplemented repeatedly in user code.
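For reference, the bookkeeping above looks roughly like the following. This is a simplified sketch using plain Python records rather than TensorDicts, and all names (collect_full_episodes, step_stream) are hypothetical, not TorchRL API:

```python
from collections import defaultdict

def collect_full_episodes(step_stream, num_trajectories):
    """Group a flat stream of step records into complete episodes.
    Each record is a dict like {"traj_id": int, "obs": ..., "done": bool}.
    A real implementation would keep pulling from the collector until
    enough episodes are complete."""
    in_progress = defaultdict(list)   # traj_id -> steps collected so far
    complete = []
    for step in step_stream:
        tid = step["traj_id"]
        in_progress[tid].append(step)
        if step["done"]:              # episode finished: move it out
            complete.append(in_progress.pop(tid))
            if len(complete) == num_trajectories:
                return complete
    return complete

# Toy stream: two trajectories interleaved, as a multi-worker collector yields them.
stream = [
    {"traj_id": 0, "obs": 0.0, "done": False},
    {"traj_id": 1, "obs": 1.0, "done": False},
    {"traj_id": 0, "obs": 0.1, "done": True},
    {"traj_id": 1, "obs": 1.1, "done": True},
]
episodes = collect_full_episodes(iter(stream), num_trajectories=2)
```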

There is currently no high-level utility in TorchRL that provides “give me exactly N full episodes as a padded TensorDict”, even though the underlying collectors already expose all the needed information.

Solution

I propose a new utility, tentatively named TrajectoryBatcher, in torchrl.collectors.utils.

It would wrap any TorchRL collector and yield batches of exactly num_trajectories full episodes, reconstructed across collector iterations and padded to the length of the longest episode in the batch (max_length). Example:

batcher = TrajectoryBatcher(
    collector,
    num_trajectories=32,
    strict_on_policy=True,
)

batch = next(iter(batcher))
batch["observation"].shape  # [32, max_len, obs_dim]
batch["action"].shape       # [32, max_len, act_dim], padded past each episode's end
batch["mask"].shape         # [32, max_len], True at valid (non-padding) steps
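The padding and mask semantics implied by the shapes above can be sketched in a few lines. This operates on plain Python lists for clarity; pad_episodes is a hypothetical helper, not part of TorchRL:

```python
def pad_episodes(episodes, pad_value=0.0):
    """Pad variable-length episodes to a common length and build the
    matching validity mask, mirroring the [N, max_len] batch layout."""
    max_len = max(len(ep) for ep in episodes)
    padded = [ep + [pad_value] * (max_len - len(ep)) for ep in episodes]
    mask = [[True] * len(ep) + [False] * (max_len - len(ep)) for ep in episodes]
    return padded, mask

# Two episodes of lengths 3 and 1, padded to max_len = 3.
obs, mask = pad_episodes([[1.0, 2.0, 3.0], [4.0]])
# obs  -> [[1.0, 2.0, 3.0], [4.0, 0.0, 0.0]]
# mask -> [[True, True, True], [True, False, False]]
```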

Features:

  • reconstruct trajectories split across collector iterations (using traj_ids)
  • optional strict on-policy mode (“burn” ongoing trajectories after a policy update)
  • padded output using standard TensorDict conventions
  • compatible with multi-worker collectors without resetting environments
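The strict on-policy "burn" boils down to: remember which policy version each episode started under, and discard any episode that spans a weight update. A minimal sketch of that rule, with hypothetical names:

```python
class OnPolicyFilter:
    """Track the policy version each in-flight episode started under
    and 'burn' (drop) episodes that straddle a policy update."""
    def __init__(self):
        self.policy_version = 0
        self.start_version = {}   # traj_id -> version at episode start

    def on_policy_update(self):
        self.policy_version += 1  # called after each weight sync

    def keep(self, traj_id, done):
        # Record the version the trajectory started under (first step only).
        self.start_version.setdefault(traj_id, self.policy_version)
        if not done:
            return True           # still accumulating steps
        started = self.start_version.pop(traj_id)
        return started == self.policy_version  # burn mixed-policy episodes

f = OnPolicyFilter()
f.keep(0, done=False)        # traj 0 starts under version 0
f.on_policy_update()         # weights change mid-episode
ok = f.keep(0, done=True)    # traj 0 finishes -> burned (ok is False)
```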

I already have a working prototype and can contribute the implementation, documentation, and unit tests.

Alternatives

  • Users manually stitch trajectories in Python (complex and repetitive).
  • Resetting environments at every update is possible but costly and problematic with multi-process collectors.
  • Using frames_per_batch as a proxy for episodes is unreliable and does not guarantee complete trajectories.

These alternatives are not ideal and introduce unnecessary boilerplate for common RL workflows.

Additional context

This feature would make it much easier to train Monte Carlo / episodic algorithms on top of TorchRL collectors, and aligns with patterns commonly found in other RL libraries (e.g., SB3 rollouts, RLlib sample batches, Acme episodes).

Checklist

  • I have checked that there is no similar issue in the repo (required)

Labels

enhancement (New feature or request)