Skip to content

Commit 114482a

Browse files
percevalwThomzoy
authored andcommitted
fix: remove tuning study pickle from repo and rebuild them during the tests
1 parent 702ec25 commit 114482a

File tree

10 files changed

+50
-341
lines changed

10 files changed

+50
-341
lines changed

edsnlp/tune.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -606,11 +606,12 @@ def tune(
606606
hyperparameters: Dict[str, HyperparameterConfig],
607607
output_dir: str,
608608
checkpoint_dir: str,
609-
gpu_hours: confloat(gt=0) = DEFAULT_GPU_HOUR,
609+
gpu_hours: Optional[confloat(gt=0)] = DEFAULT_GPU_HOUR,
610610
n_trials: Optional[conint(gt=0)] = None,
611611
two_phase_tuning: bool = False,
612612
seed: int = 42,
613613
metric="ner.micro.f",
614+
keep_checkpoint: bool = False,
614615
):
615616
"""
616617
Perform hyperparameter tuning for a model using Optuna.
@@ -647,6 +648,10 @@ def tune(
647648
Default is False.
648649
seed : int, optional
649650
Random seed for reproducibility. Default is 42.
651+
metric : str, optional
652+
Metric used to evaluate trials. Default is "ner.micro.f".
653+
keep_checkpoint : bool, optional
654+
If True, keeps the checkpoint file after tuning. Default is False.
650655
"""
651656
setup_logging()
652657
viz = is_plotly_install()
@@ -665,6 +670,7 @@ def tune(
665670
logger.info(f"Elapsed trials: {elapsed_trials}")
666671

667672
if not is_fixed_n_trials:
673+
gpu_hours = gpu_hours or DEFAULT_GPU_HOUR
668674
if not study:
669675
logger.info(f"Computing number of trials for {gpu_hours} hours of GPU.")
670676
study = optimize(
@@ -734,7 +740,7 @@ def tune(
734740
f"Tuning completed. Results available in {output_dir}. Deleting checkpoint."
735741
)
736742
checkpoint_file = os.path.join(checkpoint_dir, CHECKPOINT)
737-
if os.path.exists(checkpoint_file):
743+
if os.path.exists(checkpoint_file) and not keep_checkpoint:
738744
os.remove(checkpoint_file)
739745

740746

-5.95 KB
Binary file not shown.
-6.47 KB
Binary file not shown.

tests/tuning/test_checkpoints/two_phase_gpu_hour/config.yml

Lines changed: 0 additions & 127 deletions
This file was deleted.

tests/tuning/test_checkpoints/two_phase_gpu_hour/results_summary.txt

Lines changed: 0 additions & 13 deletions
This file was deleted.
-5.69 KB
Binary file not shown.

tests/tuning/test_checkpoints/two_phase_n_trials/config.yml

Lines changed: 0 additions & 127 deletions
This file was deleted.

tests/tuning/test_checkpoints/two_phase_n_trials/results_summary.txt

Lines changed: 0 additions & 13 deletions
This file was deleted.
-6.36 KB
Binary file not shown.

0 commit comments

Comments
 (0)