diff --git a/torch_geometric/explain/algorithm/graphmask_explainer.py b/torch_geometric/explain/algorithm/graphmask_explainer.py index 31d9a8e24644..ecca54d6144b 100644 --- a/torch_geometric/explain/algorithm/graphmask_explainer.py +++ b/torch_geometric/explain/algorithm/graphmask_explainer.py @@ -86,9 +86,9 @@ def __init__( epochs: int = 100, lr: float = 0.01, penalty_scaling: int = 5, - lambda_optimizer_lr: int = 1e-2, - init_lambda: int = 0.55, - allowance: int = 0.03, + lambda_optimizer_lr: float = 1e-2, + init_lambda: float = 0.55, + allowance: float = 0.03, allow_multiple_explanations: bool = False, log: bool = True, **kwargs,