Conversation

@eternalNight (Contributor) commented Dec 15, 2025

Summary

This PR, authored by @hahaha3210, introduces ParallelState, a class that manages process groups for an arbitrary combination of parallel strategies including TP, EP, PP and DP.

As discussed in #7680, the primary approach borrows the process-group creation logic from https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/parallel_state.py while encapsulating the state (i.e., process groups, ranks, and world sizes) in a class. This design enables multiple, independent parallelism configurations to coexist within a single process, which is particularly valuable in scenarios involving multiple models, such as reinforcement learning (RL) workflows. ParallelState objects can be created prior to calling deepspeed.initialize, so that process groups are available to custom modules, such as UlyssesSPAttentionHF, at an early stage.
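
To illustrate the Megatron-style grouping logic the PR borrows, here is a minimal, self-contained sketch (all names are hypothetical; the actual ParallelState API in the PR may differ) that enumerates TP/DP/PP rank groups for a given world size, assuming the usual layout where the TP dimension varies fastest:

```python
from itertools import product

def enumerate_groups(world_size, tp, pp):
    """Enumerate TP/DP/PP rank groups, Megatron-style.

    Assumes rank = tp_rank + dp_rank * tp + pp_rank * tp * dp,
    i.e. TP varies fastest, then DP, then PP.
    """
    assert world_size % (tp * pp) == 0
    dp = world_size // (tp * pp)
    # Ranks that communicate along one dimension share the other two coordinates.
    tp_groups = [[p * tp * dp + d * tp + t for t in range(tp)]
                 for p, d in product(range(pp), range(dp))]
    dp_groups = [[p * tp * dp + d * tp + t for d in range(dp)]
                 for p, t in product(range(pp), range(tp))]
    pp_groups = [[p * tp * dp + d * tp + t for p in range(pp)]
                 for d, t in product(range(dp), range(tp))]
    return tp_groups, dp_groups, pp_groups

tp_groups, dp_groups, pp_groups = enumerate_groups(8, tp=2, pp=2)
print(tp_groups)  # [[0, 1], [2, 3], [4, 5], [6, 7]]
print(dp_groups)  # [[0, 2], [1, 3], [4, 6], [5, 7]]
print(pp_groups)  # [[0, 4], [1, 5], [2, 6], [3, 7]]
```

In the real implementation each of these rank lists would back a call to dist.new_group, and a ParallelState instance would hold the resulting groups instead of storing them in module-level globals as Megatron-LM does.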

Compatibility between ParallelState and the current process-group management facilities (including deepspeed.runtime.sequence_parallel.parallel_state_sp and deepspeed.utils.groups) is tested in test_mpu.py.

Opens

  1. Support for Ulysses SP is yet to be added.
  2. Support creating a ParallelState from a config object rather than specifying each parallel dimension explicitly.
  3. Are the wrappers in parallel_state_deepspeed.py necessary? If so, is there a more concise way to implement its APIs, which share similar code patterns?
  4. Are GLOO process groups necessary for DeepSpeed? If not, we can strip them from the draft.
  5. Tweaking NCCL options requires ProcessGroupNCCL.Options from torch.distributed, which deepspeed.comm does not provide today. Should we add that to deepspeed.comm, or make the format-checking script allow that specific use of torch.distributed?

Signed-off-by: Jikang Mo <mojikang.mjk@alibaba-inc.com>
Signed-off-by: Junjie Mao <junjie.mao@linux.alibaba.com>
@sfc-gh-truwase (Collaborator) commented:

@stas00 @tohtana @delock FYI

@delock (Collaborator) commented Dec 16, 2025

I like the idea of putting the parallel dimensions in a single place rather than relying on users reading deep into the documentation to figure out how to turn on each parallelism dimension. I also agree with the open that the class could be created from a config object. Does it make sense to have the config in the config.json file, or would a separate config file be more flexible?

@eternalNight (Contributor, Author) commented:

> I like the idea of putting the parallel dimensions in a single place rather than relying on users reading deep into the documentation to figure out how to turn on each parallelism dimension. I also agree with the open that the class could be created from a config object. Does it make sense to have the config in the config.json file, or would a separate config file be more flexible?

My rough idea is to reuse the current config.json, which already provides the dimensions of the various parallel techniques. Moving parallelism-related configs to a separate file would be such a big change that it could pose a major obstacle to users trying to upgrade.
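
As a rough sketch of that direction (open item 2), the dimensions could be pulled from a DeepSpeed-style config dict; the key names below are purely illustrative and not the PR's actual schema:

```python
# Hypothetical sketch: derive parallel dimensions from a DeepSpeed-style
# config dict. Key names and defaults here are illustrative only.
def parallel_dims_from_config(ds_config):
    """Extract TP/PP/EP sizes from a config dict, defaulting to 1."""
    def size_of(section):
        return ds_config.get(section, {}).get("size", 1)
    return {
        "tensor_parallel_size": size_of("tensor_parallel"),
        "pipeline_parallel_size": size_of("pipeline_parallel"),
        "expert_parallel_size": size_of("expert_parallel"),
    }

dims = parallel_dims_from_config({"tensor_parallel": {"size": 4}})
print(dims["tensor_parallel_size"])  # 4
```

A hypothetical ParallelState.from_config classmethod could then forward these values to the explicit constructor, keeping the two construction paths consistent.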
