diff --git a/sb3_contrib/ppo_mask/ppo_mask.py b/sb3_contrib/ppo_mask/ppo_mask.py
index a1be8cfd..474bab5c 100644
--- a/sb3_contrib/ppo_mask/ppo_mask.py
+++ b/sb3_contrib/ppo_mask/ppo_mask.py
@@ -407,10 +407,10 @@ def train(self) -> None:
                 th.nn.utils.clip_grad_norm_(self.policy.parameters(), self.max_grad_norm)
                 self.policy.optimizer.step()
 
+            self._n_updates += 1
             if not continue_training:
                 break
 
-        self._n_updates += self.n_epochs
         explained_var = explained_variance(self.rollout_buffer.values.flatten(), self.rollout_buffer.returns.flatten())
 
         # Logs
diff --git a/sb3_contrib/ppo_recurrent/ppo_recurrent.py b/sb3_contrib/ppo_recurrent/ppo_recurrent.py
index aaf2756a..adfe7727 100644
--- a/sb3_contrib/ppo_recurrent/ppo_recurrent.py
+++ b/sb3_contrib/ppo_recurrent/ppo_recurrent.py
@@ -416,10 +416,10 @@ def train(self) -> None:
                 th.nn.utils.clip_grad_norm_(self.policy.parameters(), self.max_grad_norm)
                 self.policy.optimizer.step()
 
+            self._n_updates += 1
             if not continue_training:
                 break
 
-        self._n_updates += self.n_epochs
         explained_var = explained_variance(self.rollout_buffer.values.flatten(), self.rollout_buffer.returns.flatten())
 
         # Logs
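
Both hunks apply the same fix: `_n_updates` was previously advanced by `n_epochs` in one lump after the epoch loop, which over-counts whenever `continue_training` goes False (e.g. when the approximate KL divergence exceeds the `target_kl` threshold) and training stops before all epochs run. Incrementing inside the loop records only the epochs that actually completed. Below is a minimal sketch of the difference, with the gradient steps elided and a hypothetical `kl_too_large` standing in for the target-KL check; names mirror the diff above.

```python
# Sketch of the new accounting, assuming a train loop shaped like the
# one in the diff; `kl_too_large` is a hypothetical stand-in for the
# target_kl early-stopping check.

n_epochs = 10
_n_updates = 0

def kl_too_large(epoch: int) -> bool:
    # Hypothetical: pretend the approximate KL exceeds target_kl on epoch 3.
    return epoch >= 3

for epoch in range(n_epochs):
    # ... one full pass of minibatch gradient updates would run here ...
    continue_training = not kl_too_large(epoch)

    _n_updates += 1  # new: count each epoch that actually completed
    if not continue_training:
        break

print(_n_updates)  # 4 -- epochs 0..3 ran before early stopping
# The old accounting did `_n_updates += n_epochs` after the loop instead,
# which would report 10 here regardless of the early break.
```

With the old accounting, the logged `train/n_updates` metric could overstate the number of optimization epochs performed; after this change it matches the work actually done.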