Skip to content

Conversation

@alektebel
Copy link

Description

This PR fixes incorrect n_updates counting in both MaskablePPO and RecurrentPPO when early stopping is triggered due to the target_kl threshold.

The Problem:

  • Previously, n_updates was always incremented by the full n_epochs value (default: 10) regardless of actual epochs completed
  • When target_kl triggered early stopping (e.g., at epoch 5), n_updates would still increment by 10 instead of 5
  • This caused inaccurate learning rate scheduling, progress tracking, and inconsistent behavior with regular PPO

The Solution:

  • Move n_updates increment inside the epoch loop to count only completed epochs
  • Each epoch now increments n_updates by 1, matching the behavior of base PPO
  • Early stopping now correctly reflects actual training progress

Files Changed:

  • sb3_contrib/ppo_mask/ppo_mask.py
  • sb3_contrib/ppo_recurrent/ppo_recurrent.py

Testing

The fix has been verified with:

  • Reproduction cases showing the bug in both algorithms
  • Logs demonstrating correct counting after the fix
  • Early stopping at various steps (0-7) to validate accurate incrementing

Impact

  • Accurate learning rate scheduling
  • Proper training progress tracking
  • Consistent behavior with Stable-Baselines3 PPO
  • No breaking changes - only affects counting logic

…pping

- Count actual epochs completed instead of full n_epochs
- Fix affects both algorithms when target_kl triggers early stopping
- Ensures accurate learning rate scheduling and progress tracking
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant