
[Bug]: MaskablePPO inaccurate update counting when target_kl exits early #292

@Sean-Fuhrman

🐛 Bug

When MaskablePPO exits early because of target_kl, n_updates is still incremented by self.n_epochs instead of by the number of epochs that actually ran. For example, if training stops at epoch 5 of 10, n_updates increases by 10 when it should increase by 5.

To fix:
On line 413 of ppo_mask.py, `self._n_updates += self.n_epochs` should be changed to `self._n_updates += 1` and moved to line 409, inside the epoch loop, to match regular PPO.
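To illustrate the difference between the two counting schemes, here is a minimal standalone sketch. It is not the sb3_contrib code: `count_updates`, `kl_per_epoch`, and `count_inside_loop` are hypothetical names, each epoch is reduced to a single KL value, and the `1.5 * target_kl` stopping threshold is an assumption made for this example.

```python
def count_updates(kl_per_epoch, target_kl, n_epochs, count_inside_loop):
    """Toy model of the epoch loop; returns how much n_updates would grow."""
    n_updates = 0
    for epoch in range(n_epochs):
        # Assumed early-stop rule for this sketch: KL exceeds 1.5 * target_kl.
        stop_early = kl_per_epoch[epoch] > 1.5 * target_kl
        if count_inside_loop:
            # Proposed behaviour: count each epoch that was started.
            n_updates += 1
        if stop_early:
            break
    if not count_inside_loop:
        # Reported behaviour: always add n_epochs, even after an early exit.
        n_updates += n_epochs
    return n_updates


# Hypothetical KL trace: the 5th epoch (index 4) crosses the threshold.
kl_trace = [0.0001] * 4 + [0.001] + [0.0001] * 5

reported = count_updates(kl_trace, target_kl=0.0003, n_epochs=10, count_inside_loop=False)
proposed = count_updates(kl_trace, target_kl=0.0003, n_epochs=10, count_inside_loop=True)
print(reported, proposed)
```

With this trace, the reported scheme adds 10 while counting inside the loop adds 5, which is the discrepancy described above.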

To Reproduce

from sb3_contrib import MaskablePPO
from sb3_contrib.common.maskable.policies import MaskableActorCriticPolicy
from sb3_contrib.common.envs import InvalidActionEnvDiscrete

env = InvalidActionEnvDiscrete(dim=10, n_invalid_actions=3)

model = MaskablePPO(
    policy=MaskableActorCriticPolicy,
    env=env,
    verbose=1,
    target_kl=0.0003,  # set low to ensure early stopping
)

# Train
model.learn(total_timesteps=100_000)

Labels: bug, good first issue, help wanted