Open
Labels
bug, good first issue, help wanted
Description
🐛 Bug
When MaskablePPO exits early because of target_kl, n_updates is still incremented by self.n_epochs instead of being incremented only for the epochs that actually ran. For example, if training exits early at epoch 5 of 10, n_updates is increased by 10 when it should be increased by 5.
To fix:
On line 413 of ppo_mask.py, `self._n_updates += self.n_epochs` should be changed to `self._n_updates += 1` and moved to line 409, inside the epoch loop, to match regular PPO.
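To illustrate the difference in counting, here is a minimal standalone sketch. It does not use the library: `count_updates`, `stop_after`, and `increment_in_loop` are hypothetical names made up for this example, and the `break` is only a stand-in for the target_kl early exit.

```python
def count_updates(n_epochs, stop_after=None, increment_in_loop=True):
    """Simulate the epoch loop of one train() call and return the update count.

    stop_after: number of epochs that run before the (simulated) KL early exit,
                or None if training never stops early.
    increment_in_loop: True  -> proposed fix, add 1 per epoch inside the loop
                       False -> reported behavior, add n_epochs after the loop
    """
    n_updates = 0
    for epoch in range(n_epochs):
        if increment_in_loop:
            n_updates += 1  # count each epoch as it runs
        if stop_after is not None and epoch + 1 >= stop_after:
            break  # stand-in for the target_kl early exit
    if not increment_in_loop:
        n_updates += n_epochs  # full count, regardless of early exit
    return n_updates


reported = count_updates(10, stop_after=5, increment_in_loop=False)
proposed = count_updates(10, stop_after=5, increment_in_loop=True)
print(reported, proposed)
```

Without an early exit both variants give the same count, so the change only affects runs where target_kl triggers.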
To Reproduce
```python
from sb3_contrib import MaskablePPO
from sb3_contrib.common.maskable.policies import MaskableActorCriticPolicy
from sb3_contrib.common.envs import InvalidActionEnvDiscrete

env = InvalidActionEnvDiscrete(dim=10, n_invalid_actions=3)
model = MaskablePPO(
    policy=MaskableActorCriticPolicy,
    env=env,
    verbose=1,
    target_kl=0.0003,  # set low to ensure early stop
)
# Train
model.learn(total_timesteps=100_000)
```
Relevant log output / Error message
System Info
No response
Checklist
- I have checked that there is no similar issue in the repo
- I have read the documentation
- I have provided a minimal and working example to reproduce the bug
- I've used the markdown code blocks for both code and stack traces.