
Commit 7c58309

fix: support multi-gpu training
1 parent e5b6afc commit 7c58309

2 files changed (+19 -17 lines)


changelog.md

Lines changed: 2 additions & 0 deletions

@@ -80,6 +80,8 @@
 - for computing correct loss means when accumulating gradients over multiple mini-mini-batches
 - for computing correct loss means in multi-GPU setups, since these stats are synchronized and accumulated across GPUs
 
+- Support multi-GPU training via Hugging Face `accelerate`, with the EDS-NLP `Stream` API taking the `WORLD_SIZE` and `LOCAL_RANK` environment variables into account
+
 ## v0.13.1
 
 ### Added
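
For context, here is a minimal sketch (not part of this commit) of the distributed environment the new changelog entry refers to. `WORLD_SIZE` and `LOCAL_RANK` are the standard variables set by launchers such as `torchrun` or `accelerate launch`; how exactly the EDS-NLP `Stream` API consumes them is not shown in this diff, so the snippet only illustrates what the variables mean.

import os

import torch

# Set by the launcher (e.g. `accelerate launch` or `torchrun`); the defaults
# make the same code work in a plain single-process run.
world_size = int(os.environ.get("WORLD_SIZE", "1"))  # total number of processes
local_rank = int(os.environ.get("LOCAL_RANK", "0"))  # GPU index on this machine

if torch.cuda.is_available():
    torch.cuda.set_device(local_rank)  # pin this process to its own GPU

print(f"local rank {local_rank} of {world_size} total processes")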

edsnlp/training/trainer.py

Lines changed: 17 additions & 17 deletions

@@ -18,6 +18,7 @@
 
 import torch
 from accelerate import Accelerator
+from accelerate.utils import gather_object
 from confit import validate_arguments
 from confit.utils.random import set_seed
 from rich_logger import RichTablePrinter
@@ -35,10 +36,7 @@
 from edsnlp.utils.span_getters import get_spans
 from edsnlp.utils.typing import AsList
 
-from .optimizer import (  # noqa: F401
-    LinearSchedule,
-    ScheduledOptimizer,
-)
+from .optimizer import LinearSchedule, ScheduledOptimizer
 
 LOGGER_FIELDS = {
     "step": {},
@@ -540,15 +538,17 @@ def train(
         batches = list(flatten(batches))
 
         # Synchronize stats between sub-batches across workers
-        input_stats = {}
+        batch_stats = {}
         for b in batches:
-            fill_flat_stats(b, result=input_stats)
-        input_stats = list(flatten(accelerator.gather([input_stats])))
-        input_stats = {k: sum(v) for k, v in ld_to_dl(input_stats).items()}
+            fill_flat_stats(b, result=batch_stats)
+        batch_stats = {
+            k: sum(v)
+            for k, v in ld_to_dl(gather_object([batch_stats])).items()
+        }
         for b in batches:
-            set_flat_stats(b, input_stats)
+            set_flat_stats(b, batch_stats)
 
-        output_stats = defaultdict(lambda: 0.0)
+        res_stats = defaultdict(lambda: 0.0)
         for batch, batch_pipe_names in zip(batches, batches_pipe_names):
             loss = torch.zeros((), device=accelerator.device)
             with nlp.cache():
@@ -566,25 +566,25 @@ def train(
                             or isinstance(v, torch.Tensor)
                             and v.ndim == 0
                         ):
-                            output_stats[k] += float(v)
+                            res_stats[k] += float(v)
                     if torch.isnan(loss):
                         raise ValueError(f"NaN loss at component {name}")
                     del k, v, res, pipe
             accelerator.backward(loss)
             del loss
 
         # Sync output stats after forward such as losses, supports, etc.
-        output_stats = list(flatten(accelerator.gather([output_stats])))
-        output_stats = {
-            k: sum(v) for k, v in ld_to_dl(output_stats).items()
+        res_stats = {
+            k: sum(v)
+            for k, v in ld_to_dl(gather_object([dict(res_stats)])).items()
         }
         if is_main_process:
-            for k, v in input_stats.items():
+            for k, v in batch_stats.items():
                 cumulated_data[k] += v
-            for k, v in output_stats.items():
+            for k, v in res_stats.items():
                 cumulated_data[k] += v
 
-        del input_stats, output_stats
+        del batch_stats, res_stats
         accelerator.clip_grad_norm_(grad_params, max_grad_norm)
         accel_optim.step()
 
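
The substantive change above is replacing `accelerator.gather`, which is meant for tensors, with `accelerate.utils.gather_object`, which can gather arbitrary picklable objects such as the plain dicts of floats that `fill_flat_stats` produces (the diff also wraps the defaultdict in `dict(...)` before gathering, presumably because a `defaultdict(lambda: 0.0)` cannot be pickled). A minimal sketch of the pattern, with made-up stat names rather than the trainer's real keys:

from accelerate import Accelerator
from accelerate.utils import gather_object

accelerator = Accelerator()  # initializes the distributed state, if any

# Hypothetical per-process stats; in trainer.py these are filled by fill_flat_stats.
local_stats = {"samples": 8.0, "words": 1530.0}

# Each process contributes its one-element list; gather_object concatenates them,
# so every process receives one dict per worker (just the local dict when not distributed).
all_stats = gather_object([local_stats])

# Sum each key across workers, mirroring the ld_to_dl(...) / sum(v) step above.
merged = {}
for stats in all_stats:
    for k, v in stats.items():
        merged[k] = merged.get(k, 0.0) + v

Summing the raw supports rather than averaging per-GPU means is what makes the loss means correct, as the changelog notes: with made-up numbers, one worker with a loss sum of 12 over 100 words and another with 3 over 10 words should give 15 / 110 ≈ 0.136, whereas averaging the two per-worker means would give (0.12 + 0.30) / 2 = 0.21.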
