From 8526296ba02afc255aa6871dd4c2749b49ec865e Mon Sep 17 00:00:00 2001 From: Dongli He Date: Mon, 17 Jun 2024 19:38:48 +0400 Subject: [PATCH 1/4] case-insensitive arg --- main.py | 1 + 1 file changed, 1 insertion(+) diff --git a/main.py b/main.py index 12564f6b..95d35548 100644 --- a/main.py +++ b/main.py @@ -77,6 +77,7 @@ def add_all_arguments(parser): "--optimizer", default="adam", choices=["adam", "adamw", "adamax", "sgd"], + type=str.lower, help="Optimizer (default: %(default)s)", ) parser.add_argument( From 1207d0c90674d20b4c0325ef265aeb513b367b1c Mon Sep 17 00:00:00 2001 From: Dongli He Date: Mon, 17 Jun 2024 20:10:58 +0400 Subject: [PATCH 2/4] optimize arguments in Model --- libmultilabel/nn/model.py | 24 ++++++++---------------- libmultilabel/nn/nn_utils.py | 12 +++--------- main.py | 4 +++- torch_trainer.py | 4 +--- 4 files changed, 15 insertions(+), 29 deletions(-) diff --git a/libmultilabel/nn/model.py b/libmultilabel/nn/model.py index 1f0ab95f..d8502c7c 100644 --- a/libmultilabel/nn/model.py +++ b/libmultilabel/nn/model.py @@ -15,10 +15,8 @@ class MultiLabelModel(L.LightningModule): Args: num_classes (int): Total number of classes. - learning_rate (float, optional): Learning rate for optimizer. Defaults to 0.0001. optimizer (str, optional): Optimizer name (i.e., sgd, adam, or adamw). Defaults to 'adam'. - momentum (float, optional): Momentum factor for SGD only. Defaults to 0.9. - weight_decay (int, optional): Weight decay factor. Defaults to 0. + optimizer_config (dict, optional): Optimizer parameters. The keys in the dictionary should match the parameter names defined by PyTorch for the optimizer. metric_threshold (float, optional): The decision value threshold over which a label is predicted as positive. Defaults to 0.5. monitor_metrics (list, optional): Metrics to monitor while validating. Defaults to None. log_path (str): Path to a directory holding the log files and models. 
@@ -30,10 +28,8 @@ class MultiLabelModel(L.LightningModule): def __init__( self, num_classes, - learning_rate=0.0001, optimizer="adam", - momentum=0.9, - weight_decay=0, + optimizer_config=None, lr_scheduler=None, scheduler_config=None, val_metric=None, @@ -43,15 +39,13 @@ def __init__( multiclass=False, silent=False, save_k_predictions=0, - **kwargs + **kwargs, ): super().__init__() # optimizer - self.learning_rate = learning_rate self.optimizer = optimizer - self.momentum = momentum - self.weight_decay = weight_decay + self.optimizer_config = optimizer_config if optimizer_config is not None else {} # lr_scheduler self.lr_scheduler = lr_scheduler @@ -78,15 +72,13 @@ def configure_optimizers(self): parameters = [p for p in self.parameters() if p.requires_grad] optimizer_name = self.optimizer if optimizer_name == "sgd": - optimizer = optim.SGD( - parameters, self.learning_rate, momentum=self.momentum, weight_decay=self.weight_decay - ) + optimizer = optim.SGD(parameters, **self.optimizer_config) elif optimizer_name == "adam": - optimizer = optim.Adam(parameters, weight_decay=self.weight_decay, lr=self.learning_rate) + optimizer = optim.Adam(parameters, **self.optimizer_config) elif optimizer_name == "adamw": - optimizer = optim.AdamW(parameters, weight_decay=self.weight_decay, lr=self.learning_rate) + optimizer = optim.AdamW(parameters, **self.optimizer_config) elif optimizer_name == "adamax": - optimizer = optim.Adamax(parameters, weight_decay=self.weight_decay, lr=self.learning_rate) + optimizer = optim.Adamax(parameters, **self.optimizer_config) else: raise RuntimeError("Unsupported optimizer: {self.optimizer}") diff --git a/libmultilabel/nn/nn_utils.py b/libmultilabel/nn/nn_utils.py index a4ac82c2..e6e75784 100644 --- a/libmultilabel/nn/nn_utils.py +++ b/libmultilabel/nn/nn_utils.py @@ -41,10 +41,8 @@ def init_model( embed_vecs=None, init_weight=None, log_path=None, - learning_rate=0.0001, optimizer="adam", - momentum=0.9, - weight_decay=0, + 
optimizer_config=None, lr_scheduler=None, scheduler_config=None, val_metric=None, @@ -69,10 +67,8 @@ def init_model( For example, the `init_weight` of `torch.nn.init.kaiming_uniform_` is `kaiming_uniform`. Defaults to None. log_path (str): Path to a directory holding the log files and models. - learning_rate (float, optional): Learning rate for optimizer. Defaults to 0.0001. optimizer (str, optional): Optimizer name (i.e., sgd, adam, or adamw). Defaults to 'adam'. - momentum (float, optional): Momentum factor for SGD only. Defaults to 0.9. - weight_decay (int, optional): Weight decay factor. Defaults to 0. + optimizer_config (dict, optional): Optimizer parameters. The keys in the dictionary should match the parameter names defined by PyTorch for the optimizer. lr_scheduler (str, optional): Name of the learning rate scheduler. Defaults to None. scheduler_config (dict, optional): The configuration for learning rate scheduler. Defaults to None. val_metric (str, optional): The metric to select the best model for testing. Used by some of the schedulers. Defaults to None. 
@@ -102,10 +98,8 @@ def init_model( word_dict=word_dict, network=network, log_path=log_path, - learning_rate=learning_rate, optimizer=optimizer, - momentum=momentum, - weight_decay=weight_decay, + optimizer_config=optimizer_config, lr_scheduler=lr_scheduler, scheduler_config=scheduler_config, val_metric=val_metric, diff --git a/main.py b/main.py index 95d35548..40c5ea81 100644 --- a/main.py +++ b/main.py @@ -77,7 +77,6 @@ def add_all_arguments(parser): "--optimizer", default="adam", choices=["adam", "adamw", "adamax", "sgd"], - type=str.lower, help="Optimizer (default: %(default)s)", ) parser.add_argument( @@ -266,6 +265,9 @@ def get_config(): args.early_stopping_metric = args.val_metric if not hasattr(args, "scheduler_config"): args.scheduler_config = None + args.optimizer_config = {"lr": args.learning_rate, "weight_decay": args.weight_decay} + if args.optimizer == "sgd": + args.optimizer_config["momentum"] = args.momentum config = AttributeDict(vars(args)) config.run_name = "{}_{}_{}".format( diff --git a/torch_trainer.py b/torch_trainer.py index 8dc259b5..bfb6370c 100644 --- a/torch_trainer.py +++ b/torch_trainer.py @@ -189,10 +189,8 @@ def _setup_model( embed_vecs=embed_vecs, init_weight=self.config.init_weight, log_path=log_path, - learning_rate=self.config.learning_rate, optimizer=self.config.optimizer, - momentum=self.config.momentum, - weight_decay=self.config.weight_decay, + optimizer_config=self.config.optimizer_config, lr_scheduler=self.config.lr_scheduler, scheduler_config=self.config.scheduler_config, val_metric=self.config.val_metric, From 027e5464728511d839a8c8df0d0a5ccf64dcd975 Mon Sep 17 00:00:00 2001 From: Dongli He Date: Mon, 17 Jun 2024 20:16:09 +0400 Subject: [PATCH 3/4] add docstring for lr_scheduler --- libmultilabel/nn/model.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/libmultilabel/nn/model.py b/libmultilabel/nn/model.py index d8502c7c..f878ea38 100644 --- a/libmultilabel/nn/model.py +++ b/libmultilabel/nn/model.py @@ -17,6 +17,8 
@@ class MultiLabelModel(L.LightningModule): num_classes (int): Total number of classes. optimizer (str, optional): Optimizer name (i.e., sgd, adam, or adamw). Defaults to 'adam'. optimizer_config (dict, optional): Optimizer parameters. The keys in the dictionary should match the parameter names defined by PyTorch for the optimizer. + lr_scheduler (str, optional): Learning rate scheduler. Defaults to None, i.e., no learning rate scheduler. Currently, the only supported lr_scheduler is 'ReduceLROnPlateau'. + scheduler_config (dict, optional): Learning rate scheduler parameters. The keys in the dictionary should match the parameter names defined by PyTorch for the learning rate scheduler. metric_threshold (float, optional): The decision value threshold over which a label is predicted as positive. Defaults to 0.5. monitor_metrics (list, optional): Metrics to monitor while validating. Defaults to None. log_path (str): Path to a directory holding the log files and models. From 4c4df39995db3100fe6a9777d220ed839df0ea5d Mon Sep 17 00:00:00 2001 From: He Dongli Date: Fri, 21 Jun 2024 06:59:58 +0000 Subject: [PATCH 4/4] fix f-string --- libmultilabel/nn/model.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/libmultilabel/nn/model.py b/libmultilabel/nn/model.py index f878ea38..4bde03d7 100644 --- a/libmultilabel/nn/model.py +++ b/libmultilabel/nn/model.py @@ -82,7 +82,7 @@ def configure_optimizers(self): elif optimizer_name == "adamax": optimizer = optim.Adamax(parameters, **self.optimizer_config) else: - raise RuntimeError("Unsupported optimizer: {self.optimizer}") + raise RuntimeError(f"Unsupported optimizer: {self.optimizer}") if self.lr_scheduler: if self.lr_scheduler == "ReduceLROnPlateau":