diff --git a/changelog.md b/changelog.md
index 241515e97..89e08e38f 100644
--- a/changelog.md
+++ b/changelog.md
@@ -1,5 +1,11 @@
 # Changelog

+## Unreleased
+
+### Added
+- New parameter `pruning_params` to `edsnlp.tune` to control trial pruning during tuning.
+- New parameter `split_by_values` to the `eds.span_attribute` metric to report metrics separately for each attribute value.
+
 ## v0.19.0 (2025-10-04)

 šŸ“¢ EDS-NLP will drop support for Python 3.7, 3.8 and 3.9 support in the next major release (v0.20.0), in October 2025. Please upgrade to Python 3.10 or later.
diff --git a/edsnlp/metrics/span_attribute.py b/edsnlp/metrics/span_attribute.py
index d701813e0..62a9c9dc5 100644
--- a/edsnlp/metrics/span_attribute.py
+++ b/edsnlp/metrics/span_attribute.py
@@ -41,7 +41,7 @@
 import warnings
 from collections import defaultdict
-from typing import Any, Dict, Optional
+from typing import Any, Dict, Optional, Sequence, Union

 from edsnlp import registry
 from edsnlp.metrics import Examples, average_precision, make_examples, prf

@@ -57,6 +57,7 @@ def span_attribute_metric(
     default_values: Dict = {},
     micro_key: str = "micro",
     filter_expr: Optional[str] = None,
+    split_by_values: Optional[Union[str, Sequence[str]]] = None,
     **kwargs: Any,
 ):
     if "qualifiers" in kwargs:
@@ -80,6 +81,8 @@ def span_attribute_metric(
     if filter_expr is not None:
         filter_fn = eval(f"lambda doc: {filter_expr}")
         examples = [eg for eg in examples if filter_fn(eg.reference)]
+    if isinstance(split_by_values, str):
+        split_by_values = [split_by_values]
     labels = defaultdict(lambda: (set(), set(), dict()))
     labels["micro"] = (set(), set(), dict())
     total_pred_count = 0
@@ -108,9 +111,15 @@ def span_attribute_metric(
                 if (top_val or include_falsy) and default_values[attr] != top_val:
                     labels[attr][2][(eg_idx, beg, end, attr, top_val)] = top_p
                     labels[micro_key][2][(eg_idx, beg, end, attr, top_val)] = top_p
+                    if split_by_values and attr in split_by_values:
+                        key = f"{attr}:{top_val}"
+                        labels[key][2][(eg_idx, beg, end, attr, top_val)] = top_p
                 if (value or include_falsy) and default_values[attr] != value:
                     labels[micro_key][0].add((eg_idx, beg, end, attr, value))
                     labels[attr][0].add((eg_idx, beg, end, attr, value))
+                    if split_by_values and attr in split_by_values:
+                        key = f"{attr}:{value}"
+                        labels[key][0].add((eg_idx, beg, end, attr, value))

         doc_spans = get_spans(eg.reference, span_getter)
         for span in doc_spans:
@@ -124,6 +133,9 @@ def span_attribute_metric(
                 if (value or include_falsy) and default_values[attr] != value:
                     labels[micro_key][1].add((eg_idx, beg, end, attr, value))
                     labels[attr][1].add((eg_idx, beg, end, attr, value))
+                    if split_by_values and attr in split_by_values:
+                        key = f"{attr}:{value}"
+                        labels[key][1].add((eg_idx, beg, end, attr, value))

     if total_pred_count != total_gold_count:
         raise ValueError(
@@ -133,7 +145,7 @@ def span_attribute_metric(
             "predicted by another NER pipe in your model."
         )

-    return {
+    metrics = {
         name: {
             **prf(pred, gold),
             "ap": average_precision(pred_with_prob, gold),
@@ -141,6 +153,17 @@
         for name, (pred, gold, pred_with_prob) in labels.items()
     }

+    if split_by_values:
+        for attr in split_by_values:
+            submetrics = {"micro": metrics[attr]}
+            for key in list(metrics.keys()):
+                if key.startswith(f"{attr}:"):
+                    val = key.split(":", 1)[1]
+                    submetrics[val] = metrics.pop(key)
+            metrics[attr] = submetrics
+
+    return metrics
+

 @registry.metrics.register(
     "eds.span_attribute",
@@ -230,7 +253,10 @@ class SpanAttributeMetric:
         Key under which to store the micro‐averaged results across all attributes.
     filter_expr : Optional[str]
         A Python expression (using `doc`) to filter which examples are scored.
-
+    split_by_values : Optional[Union[str, Sequence[str]]]
+        One or more attributes for which metrics should be reported separately for
+        each attribute value. If `None` (default), metrics are only computed at the
+        attribute level. Useful when attributes are multiclass.
     Returns
     -------
     Dict[str, Dict[str, float]]
@@ -258,6 +284,7 @@ def __init__(
         include_falsy: bool = False,
         micro_key: str = "micro",
         filter_expr: Optional[str] = None,
+        split_by_values: Optional[Union[str, Sequence[str]]] = None,
     ):
         if qualifiers is not None:
             warnings.warn(
@@ -270,6 +297,7 @@ def __init__(
         self.include_falsy = include_falsy
         self.micro_key = micro_key
         self.filter_expr = filter_expr
+        self.split_by_values = split_by_values

     __init__.__doc__ = span_attribute_metric.__doc__

@@ -296,6 +324,7 @@ def __call__(self, *examples: Any):
             include_falsy=self.include_falsy,
             micro_key=self.micro_key,
             filter_expr=self.filter_expr,
+            split_by_values=self.split_by_values,
         )


diff --git a/edsnlp/tune.py b/edsnlp/tune.py
index 0f56e0194..381ba693a 100644
--- a/edsnlp/tune.py
+++ b/edsnlp/tune.py
@@ -271,7 +271,7 @@ def update_config(
     return config


-def objective_with_param(config, tuned_parameters, trial, metric):
+def objective_with_param(config, tuned_parameters, trial, metric, pruning_params):
     kwargs, _ = update_config(config, tuned_parameters, trial=trial)
     seed = random.randint(0, 2**32 - 1)
     set_seed(seed)
@@ -282,8 +282,9 @@ def on_validation_callback(all_metrics):
         for key in metric:
             score = score[key]
         trial.report(score, step)
-        if trial.should_prune():
-            raise optuna.TrialPruned()
+        if pruning_params:
+            if trial.should_prune():
+                raise optuna.TrialPruned()

     try:
         nlp = train(**kwargs, on_validation_callback=on_validation_callback)
@@ -299,15 +300,30 @@ def on_validation_callback(all_metrics):


 def optimize(
-    config_path, tuned_parameters, n_trials, metric, checkpoint_dir, study=None
+    config_path,
+    tuned_parameters,
+    n_trials,
+    metric,
+    checkpoint_dir,
+    pruning_params,
+    study=None,
 ):
     def objective(trial):
-        return objective_with_param(config_path, tuned_parameters, trial, metric)
+        return objective_with_param(
+            config_path, tuned_parameters, trial, metric, pruning_params
+        )

     if not study:
+        pruner = None
+        if pruning_params:
+            n_startup_trials = pruning_params.get("n_startup_trials", 5)
+            n_warmup_steps = pruning_params.get("n_warmup_steps", 5)
+            pruner = MedianPruner(
+                n_startup_trials=n_startup_trials, n_warmup_steps=n_warmup_steps
+            )
         study = optuna.create_study(
             direction="maximize",
-            pruner=MedianPruner(n_startup_trials=5, n_warmup_steps=2),
+            pruner=pruner,
             sampler=TPESampler(seed=random.randint(0, 2**32 - 1)),
         )
     study.optimize(
@@ -444,6 +460,7 @@ def tune_two_phase(
     is_fixed_n_trials: bool = False,
     gpu_hours: float = 1.0,
     skip_phase_1: bool = False,
+    pruning_params: Optional[Dict[str, int]] = None,
 ) -> None:
     """
     Perform two-phase hyperparameter tuning using Optuna.
@@ -505,6 +522,7 @@ def tune_two_phase(
         n_trials_1,
         metric,
         checkpoint_dir,
+        pruning_params,
         study,
     )
     best_params_phase_1, importances = process_results(
@@ -551,6 +569,7 @@ def tune_two_phase(
         n_trials_2,
         metric,
         checkpoint_dir,
+        pruning_params,
         study,
     )

@@ -612,6 +631,7 @@ def tune(
     seed: int = 42,
     metric="ner.micro.f",
     keep_checkpoint: bool = False,
+    pruning_params: Optional[Dict[str, int]] = None,
 ):
     """
     Perform hyperparameter tuning for a model using Optuna.
@@ -652,6 +672,11 @@ def tune(
         Metric used to evaluate trials. Default is "ner.micro.f".
     keep_checkpoint : bool, optional
         If True, keeps the checkpoint file after tuning. Default is False.
+    pruning_params : dict, optional
+        A dictionary of pruning parameters:
+        - "n_startup_trials": number of trials to complete before pruning is enabled.
+        - "n_warmup_steps": number of validation steps per trial before pruning can occur.
+        Default is None, meaning no pruning.
     """
     setup_logging()
     viz = is_plotly_install()
@@ -679,6 +704,7 @@ def tune(
             n_trials=1,
             metric=metric,
             checkpoint_dir=checkpoint_dir,
+            pruning_params=pruning_params,
         )
         n_trials = compute_n_trials(gpu_hours, compute_time_per_trial(study)) - 1
     else:
@@ -708,6 +734,7 @@ def tune(
             is_fixed_n_trials=is_fixed_n_trials,
             gpu_hours=gpu_hours,
             skip_phase_1=skip_phase_1,
+            pruning_params=pruning_params,
         )
     else:
         logger.info("Starting single-phase tuning.")
@@ -717,6 +744,7 @@ def tune(
             n_trials,
             metric,
             checkpoint_dir,
+            pruning_params,
             study,
         )
         if not is_fixed_n_trials:
@@ -732,6 +760,7 @@ def tune(
                 n_trials,
                 metric,
                 checkpoint_dir,
+                pruning_params,
                 study,
             )
     process_results(study, output_dir, viz, config, config_path, hyperparameters)
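
For reviewers, here is a minimal standalone sketch of the regrouping that `split_by_values` triggers on the dict returned by `span_attribute_metric`: flat `attr:value` entries are folded under their attribute, with the original attribute-level scores kept under `"micro"`. The attribute name `negation`, its values, and the scores are illustrative assumptions, not taken from the diff.

```python
# Standalone sketch of the post-processing added when split_by_values is set.
# Attribute name, values, and scores are dummy/illustrative.
split_by_values = ["negation"]
metrics = {
    "micro": {"f": 0.90},           # global micro average
    "negation": {"f": 0.88},        # attribute-level scores
    "negation:True": {"f": 0.81},   # per-value scores produced during counting
    "negation:False": {"f": 0.93},
}

for attr in split_by_values:
    submetrics = {"micro": metrics[attr]}
    for key in list(metrics.keys()):
        if key.startswith(f"{attr}:"):
            submetrics[key.split(":", 1)[1]] = metrics.pop(key)
    metrics[attr] = submetrics

print(metrics["negation"])
# {'micro': {'f': 0.88}, 'True': {'f': 0.81}, 'False': {'f': 0.93}}
```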
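Likewise, a minimal sketch of how `pruning_params` is mapped onto Optuna's `MedianPruner` inside `optimize` (the fallback value of 5 for both keys comes from the diff; the concrete values below are illustrative):

```python
import optuna
from optuna.pruners import MedianPruner
from optuna.samplers import TPESampler

# Both keys are optional; the diff falls back to 5 when a key is missing.
pruning_params = {"n_startup_trials": 10, "n_warmup_steps": 3}  # illustrative values

pruner = None
if pruning_params:
    pruner = MedianPruner(
        n_startup_trials=pruning_params.get("n_startup_trials", 5),
        n_warmup_steps=pruning_params.get("n_warmup_steps", 5),
    )

study = optuna.create_study(
    direction="maximize",
    pruner=pruner,  # with pruning_params=None the objective also skips
                    # trial.should_prune(), so no trial is ever pruned
    sampler=TPESampler(seed=42),
)
```

Note that with the default `pruning_params=None`, pruning becomes opt-in, whereas the previous code always used `MedianPruner(n_startup_trials=5, n_warmup_steps=2)` and checked `trial.should_prune()` at every validation step.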