
Commit 350e8d9 (parent: 74ca846)

refacto: prefer optim group lists instead of a single dict selector -> group

7 files changed: +115, -29 lines

changelog.md

Lines changed: 2 additions & 0 deletions
@@ -10,6 +10,7 @@
 - LinearSchedule (mostly used for LR scheduling) now allows a `end_value` parameter to configure if the learning rate should decay to zero or another value.
 - New `eds.explode` pipe that splits one document into multiple documents, one per span yielded by its `span_getter` parameter, each new document containing exactly that single span.
 - New `Training a span classifier` tutorial, and reorganized deep-learning docs
+- `ScheduledOptimizer` now warns when a parameter selector does not match any parameter.
 
 ## Fixed
 
@@ -24,6 +25,7 @@
 
 - Sections cues in `eds.history` are now section titles, and not the full section.
 - :boom: Validation metrics are now found under the root field `validation` in the training logs (e.g. `metrics['validation']['ner']['micro']['f']`)
+- It is now recommended to define the optimizer groups of `ScheduledOptimizer` as a list of dicts of optim hyper-parameters, each containing a `selector` regex key, rather than as a single dict with selectors as keys and dicts of optim hyper-parameters as values. This allows for more flexibility in defining the optimizer groups and is more consistent with the rest of the EDS-NLP API. It also makes it easier to reference group values from other places in config files, since their path no longer contains a complex regex string. See the updated training tutorials for more details.
 
 ## v0.17.2 (2025-06-25)
 
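As a minimal sketch of the change described in the last changelog entry above (values are illustrative and adapted loosely from the docstring example in edsnlp/training/optimizer.py further down this page), the two formats compare as follows, and the new list form can be addressed by index rather than by the regex string itself:

# Old format: a single dict mapping a regex selector to its hyper-parameters
# (False meaning "exclude these parameters from optimization").
old_groups = {
    "bias": False,
    "^ner[.]embedding": {"lr": 5e-5},
    "": {"weight_decay": 0.01},
}

# New recommended format: a list of dicts, each carrying its own `selector`
# key and, optionally, `exclude: True`.
new_groups = [
    {"selector": "bias", "exclude": True},
    {"selector": "^ner[.]embedding", "lr": 5e-5},
    {"selector": "", "weight_decay": 0.01},
]

# With a list, a group's values can be referenced by position instead of by the
# regex string, which is what the changelog entry means by simpler config paths.
print(old_groups["^ner[.]embedding"]["lr"])  # path contains the regex
print(new_groups[1]["lr"])                   # path is just an index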

docs/tutorials/training-ner.md

Lines changed: 2 additions & 2 deletions
@@ -120,14 +120,14 @@ Visit the [`edsnlp.train` documentation][edsnlp.training.trainer.train] for a li
   groups:
     # Assign parameters starting with transformer (ie the parameters of the transformer component)
     # to a first group
-    "^transformer":
+    - selector: "ner[.]embedding[.]embedding"
       lr:
         '@schedules': linear
         "warmup_rate": 0.1
         "start_value": 0
         "max_value": 5e-5
     # And every other parameters to the second group
-    "":
+    - selector: ".*"
       lr:
         '@schedules': linear
         "warmup_rate": 0.1

docs/tutorials/training-span-classifier.md

Lines changed: 4 additions & 2 deletions
@@ -126,13 +126,15 @@ Visit the [`edsnlp.train` documentation][edsnlp.training.trainer.train] for a li
   "@core": optimizer !draft # (2)!
   optim: torch.optim.AdamW
   groups:
-    'biopsy_classifier[.]embedding':
+    # Small learning rate for the pretrained transformer model
+    - selector: 'biopsy_classifier[.]embedding[.]embedding'
       lr:
         '@schedules': linear
         warmup_rate: 0.1
        start_value: 0.
         max_value: 5e-5
-    '.*':
+    # Larger learning rate for the rest of the model
+    - selector: '.*'
      lr:
         '@schedules': linear
         warmup_rate: 0.1

edsnlp/training/optimizer.py

Lines changed: 55 additions & 21 deletions
@@ -1,4 +1,5 @@
 import importlib
+import warnings
 from collections import defaultdict
 from typing import (
     Any,
@@ -166,7 +167,9 @@ def __init__(
         optim: Union[torch.optim.Optimizer, Type[torch.optim.Optimizer], str],
         module: Optional[Union[PipelineProtocol, torch.nn.Module]] = None,
         total_steps: Optional[int] = None,
-        groups: Optional[Dict[str, Union[Dict, Literal[False]]]] = None,
+        groups: Optional[
+            Union[List[Dict], Dict[str, Union[Dict, Literal[False]]]]
+        ] = None,
         init_schedules: bool = True,
         **kwargs,
     ):
@@ -183,13 +186,17 @@ def __init__(
         optim = ScheduledOptimizer(
             cls="adamw",
             module=model,
-            groups={
+            groups=[
                 # Exclude all parameters matching 'bias' from optimization.
-                "bias": False,
-                # Parameters starting with 'transformer' receive this learning rate
+                {
+                    "selector": "bias",
+                    "exclude": True,
+                },
+                # Parameters of the NER module's embedding receive this learning rate
                 # schedule. If a parameter matches both 'transformer' and 'ner',
-                # the 'transformer' settings take precedence due to the order.
-                "^transformer": {
+                # the first group's settings take precedence due to the order.
+                {
+                    "selector": "^ner[.]embedding",
                     "lr": {
                         "@schedules": "linear",
                         "start_value": 0.0,
@@ -199,7 +206,8 @@ def __init__(
                 },
                 # Parameters starting with 'ner' receive this learning rate schedule,
                 # unless a 'lr' value has already been set by an earlier selector.
-                "^ner": {
+                {
+                    "selector": "^ner",
                     "lr": {
                         "@schedules": "linear",
                         "start_value": 0.0,
@@ -209,10 +217,11 @@ def __init__(
                 },
                 # Apply a weight_decay of 0.01 to all parameters not excluded.
                 # This setting doesn't conflict with others and applies to all.
-                "": {
+                {
+                    "selector": "",
                     "weight_decay": 0.01,
                 },
-            },
+            ],
             total_steps=1000,
         )
         ```
@@ -221,24 +230,28 @@ def __init__(
         ----------
         optim : Union[str, Type[torch.optim.Optimizer], torch.optim.Optimizer]
             The optimizer to use. If a string (like "adamw") or a type to instantiate,
-            the`module` and `groups` must be provided.
+            the `module` and `groups` must be provided.
         module : Optional[Union[PipelineProtocol, torch.nn.Module]]
             The module to optimize. Usually the `nlp` pipeline object.
         total_steps : Optional[int]
             The total number of steps, used for schedules.
-        groups : Optional[Dict[str, Group]]
-            The groups to optimize. The key is a regex selector to match parameters in
-            `module.named_parameters()` and the value is a dictionary with the keys
-            `params` and `schedules`.
-
-            The matching is performed by running `regex.search(selector, name)` so you
-            do not have to match the full name. Note that the order of dict keys
-            matter. If a parameter name matches multiple selectors, the
+        groups : Optional[List[Group]]
+            The groups to optimize. Each group is a dictionary containing:
+
+            - a regex `selector` key to match the parameters of that group by their
+              names (as listed by `nlp.named_parameters()`)
+            - several other keys that define the optimizer hyper-parameters for that
+              group, such as `lr`, `weight_decay`, etc. The value for these keys can
+              be a `Schedule` instance or a simple value
+            - an `exclude` key that can be set to True to exclude matching parameters
+
+            The matching is performed by running `regex.search(selector, name)` so you
+            do not have to match the full name. Note that the order of the groups
+            matters. If a parameter name matches multiple selectors, the
             configurations of these selectors are combined in reverse order (from the
             last matched selector to the first), allowing later selectors to complete
-            options from earlier ones. If a selector maps to `False`, any parameters
-            matching it are excluded from optimization and not included in any parameter
-            group.
+            options from earlier ones. If a group contains `exclude=True`, any
+            parameter matching its selector is excluded from optimization.
         """
         should_instantiate_optim = isinstance(optim, str) or isinstance(optim, type)
         if should_instantiate_optim and (groups is None or module is None):
@@ -257,6 +270,15 @@ def __init__(
         if should_instantiate_optim:
            named_parameters = list(module.named_parameters())
             groups = Config.resolve(groups, registry=edsnlp.registry)
+
+            # New groups format: convert the list of group dicts to the legacy mapping
+            if isinstance(groups, list):
+                groups = [dict(g) for g in groups]
+                groups = {
+                    g.pop("selector"): g if not g.get("exclude") else False
+                    for g in groups
+                }
+
             groups = {
                 sel: dict(group) if group else False for sel, group in groups.items()
             }
@@ -268,8 +290,20 @@ def __init__(
                 )
             )
             groups_to_params = defaultdict(lambda: [])
+            empty_selectors = {sel for sel in groups}
             for params, group in param_to_groups.items():
                 groups_to_params[group].append(params)
+                for sel in group:
+                    empty_selectors.discard(sel)
+
+            if empty_selectors:
+                warnings.warn(
+                    f"Selectors {list(empty_selectors)} did not match any parameters."
+                )
+                warnings.warn(
+                    "For reference, here are the parameters of the module:\n"
+                    + "\n".join("- " + name for name, _ in named_parameters)
+                )
 
             cliques = []
             for selectors, params in groups_to_params.items():
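The docstring above describes the matching and merging rules in prose. The following standalone sketch illustrates them; `resolve_param_settings` is a hypothetical helper, not EDS-NLP's implementation (which also uses the `regex` package where plain `re` suffices here): selectors are applied with a regex search, matched groups are merged in reverse order so the first match wins on conflicts, and `exclude=True` drops the parameter entirely.

import re
from typing import Dict, List, Optional


def resolve_param_settings(name: str, groups: List[Dict]) -> Optional[Dict]:
    """Merged hyper-parameters for one parameter name, or None if it is excluded."""
    matched = [g for g in groups if re.search(g["selector"], name)]
    if any(g.get("exclude") for g in matched):
        return None  # excluded from optimization, not placed in any param group
    settings: Dict = {}
    # Reverse order: the first matched (more specific) group is applied last and wins.
    for group in reversed(matched):
        settings.update(
            {k: v for k, v in group.items() if k not in ("selector", "exclude")}
        )
    return settings or None


groups = [
    {"selector": "bias", "exclude": True},
    {"selector": "^ner[.]embedding", "lr": 5e-5},
    {"selector": "", "weight_decay": 0.01},
]

print(resolve_param_settings("ner.embedding.weight", groups))  # {'weight_decay': 0.01, 'lr': 5e-05}
print(resolve_param_settings("ner.linear.bias", groups))       # None (excluded)
print(resolve_param_settings("other.weight", groups))          # {'weight_decay': 0.01}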

tests/training/ner_qlf_diff_bert_config.yml

Lines changed: 4 additions & 2 deletions
@@ -64,8 +64,10 @@ optimizer:
   optim: torch.optim.AdamW
   module: ${ nlp }
   groups:
-    "^transformer": false
-    ".*":
+    # Transformer
+    - selector: "ner[.]embedding[.]embedding"
+      exclude: true
+    - selector: ".*"
       lr:
         "@schedules": linear
         start_value: 1e-3

tests/training/ner_qlf_same_bert_config.yml

Lines changed: 4 additions & 2 deletions
@@ -60,8 +60,10 @@ optimizer:
   optim: AdamW
   module: ${ nlp }
   groups:
-    "^transformer": false
-    ".*":
+    # Transformer
+    - selector: "transformer"
+      exclude: true
+    - selector: ".*"
       lr: 1e-3
 
 # 📚 DATA
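For intuition about what the `exclude: true` groups in the two configs above amount to, here is a rough plain-torch sketch; the module names are made up, and `ScheduledOptimizer` additionally resolves selectors by regex and wires in schedules. Excluded parameters simply never end up in any optimizer parameter group.

import re

import torch

# Toy module standing in for a pipeline: "transformer" mimics the pretrained
# encoder, "ner" mimics the task head. Names are illustrative only.
model = torch.nn.ModuleDict(
    {
        "transformer": torch.nn.Linear(8, 8),
        "ner": torch.nn.Linear(8, 4),
    }
)

# Roughly `- selector: "transformer"` + `exclude: true`, followed by
# `- selector: ".*"` with lr 1e-3: transformer parameters are simply left out.
included = [
    p for name, p in model.named_parameters() if not re.search("transformer", name)
]
optimizer = torch.optim.AdamW([{"params": included, "lr": 1e-3}])

print(len(optimizer.param_groups))       # 1
print(sum(p.numel() for p in included))  # only the "ner" parameters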

tests/training/test_optimizer.py

Lines changed: 44 additions & 0 deletions
@@ -68,6 +68,28 @@ def net():
                 "weight_decay": 0.0,
             },
         },
+        # New groups format
+        [
+            {
+                "selector": "fc1[.].*",
+                "lr": {
+                    "@schedules": "linear",
+                    "start_value": 0.0,
+                    "max_value": 0.1,
+                    "warmup_rate": 0.2,
+                },
+                "weight_decay": 0.01,
+            },
+            {
+                "selector": "fc2[.]bias",
+                "exclude": True,
+            },
+            {
+                "selector": "",
+                "lr": 0.0001,
+                "weight_decay": 0.0,
+            },
+        ],
     ],
 )
 def test_old_parameter_selection(net, groups):
@@ -172,3 +194,25 @@ def test_repr(net):
     optim.initialize()
 
     assert "ScheduledOptimizer[AdamW]" in repr(optim)
+
+
+def test_warn_empty_selector(net):
+    with pytest.warns(
+        UserWarning,
+        match="Selectors ['fc3[.].*'] did not match any parameters.",
+    ):
+        ScheduledOptimizer(
+            optim="adamw",
+            module=net,
+            groups=[
+                {
+                    "selector": "fc3[.].*",
+                    "lr": 0.1,
+                    "weight_decay": 0.01,
+                    "schedules": LinearSchedule(start_value=0.0, warmup_rate=0.2),
+                },
+                {"selector": "fc2[.]bias", "exclude": True},
+                {"selector": "", "lr": 0.0001, "weight_decay": 0.0},
+            ],
+            total_steps=10,
+        )
