Fix Context Parallel validation checks #12446
@@ -44,11 +44,16 @@ class ContextParallelConfig:
     Args:
         ring_degree (`int`, *optional*, defaults to `1`):
-            Number of devices to use for ring attention within a context parallel region. Must be a divisor of the
-            total number of devices in the context parallel mesh.
+            Number of devices to use for Ring Attention. The sequence is split across devices. Each device computes
+            attention between its local Q and the KV chunks that are passed sequentially around the ring. Lower memory
+            (only holds 1/N of the KV at a time) and overlaps compute with communication, but requires N iterations to
+            see all tokens. Best for long sequences with limited memory/bandwidth. Must be a divisor of the total
+            number of devices in the context parallel mesh.
         ulysses_degree (`int`, *optional*, defaults to `1`):
-            Number of devices to use for ulysses attention within a context parallel region. Must be a divisor of the
-            total number of devices in the context parallel mesh.
+            Number of devices to use for Ulysses Attention. The sequence is split across devices. Each device computes
+            local QKV, then all-gathers all KV chunks to compute full attention in one pass. Higher memory (stores all
+            KV) and requires high-bandwidth all-to-all communication, but lower latency. Best for moderate sequences
+            with good interconnect bandwidth.
         convert_to_fp32 (`bool`, *optional*, defaults to `True`):
             Whether to convert output and LSE to float32 for ring attention numerical stability.
         rotate_method (`str`, *optional*, defaults to `"allgather"`):
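For a concrete sense of what the two degrees mean, here is a small illustration. The class and property names come from the diff above; the top-level `diffusers` import path is an assumption.

```python
# Illustrative only: how ring_degree / ulysses_degree map onto the context-parallel mesh.
from diffusers import ContextParallelConfig  # top-level import path assumed

ring_cfg = ContextParallelConfig(ring_degree=4)        # pure ring attention over 4 devices
ulysses_cfg = ContextParallelConfig(ulysses_degree=4)  # pure Ulysses attention over 4 devices

# mesh_shape is (ring_degree, ulysses_degree), matching mesh_dim_names ("ring", "ulysses")
print(ring_cfg.mesh_shape)     # (4, 1) -> KV chunks rotate around a 4-device ring
print(ulysses_cfg.mesh_shape)  # (1, 4) -> all-to-all gathers KV, full attention in one pass
```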
@@ -79,29 +84,46 @@ def __post_init__(self):
         if self.ulysses_degree is None:
             self.ulysses_degree = 1

+        if self.ring_degree == 1 and self.ulysses_degree == 1:
+            raise ValueError(
+                "Either ring_degree or ulysses_degree must be greater than 1 in order to use context parallel inference"
+            )
+        if self.ring_degree < 1 or self.ulysses_degree < 1:
+            raise ValueError("`ring_degree` and `ulysses_degree` must be greater than or equal to 1.")
+        if self.ring_degree > 1 and self.ulysses_degree > 1:
+            raise ValueError(
+                "Unified Ulysses-Ring attention is not yet supported. Please set either `ring_degree` or `ulysses_degree` to 1."
+            )
+        if self.rotate_method != "allgather":
+            raise NotImplementedError(
+                f"Only rotate_method='allgather' is supported for now, but got {self.rotate_method}."
+            )

     @property
     def mesh_shape(self) -> Tuple[int, int]:
         return (self.ring_degree, self.ulysses_degree)

     @property
     def mesh_dim_names(self) -> Tuple[str, str]:
         """Dimension names for the device mesh."""
         return ("ring", "ulysses")

     def setup(self, rank: int, world_size: int, device: torch.device, mesh: torch.distributed.device_mesh.DeviceMesh):
         self._rank = rank
         self._world_size = world_size
         self._device = device
         self._mesh = mesh
-        if self.ring_degree is None:
-            self.ring_degree = 1
-        if self.ulysses_degree is None:
-            self.ulysses_degree = 1
-        if self.rotate_method != "allgather":
-            raise NotImplementedError(
-                f"Only rotate_method='allgather' is supported for now, but got {self.rotate_method}."
-            )

         if self.ulysses_degree * self.ring_degree > world_size:
Member
Should we ever hit this line, as both cannot be set, right?

Collaborator (Author)
Both can be set technically, but currently both can't be > 1. Also, this is for cases where you have 3 GPUs available and you set something like ulysses_degree=1 and ring_degree=4 (more GPUs are being requested than the world size).

Member
Feels slightly confusing to me, but since we're erroring out early for unsupported …
             raise ValueError(
                 f"The product of `ring_degree` ({self.ring_degree}) and `ulysses_degree` ({self.ulysses_degree}) must not exceed the world size ({world_size})."
             )
-        if self._flattened_mesh is None:
-            self._flattened_mesh = self._mesh._flatten()
-        if self._ring_mesh is None:
-            self._ring_mesh = self._mesh["ring"]
-        if self._ulysses_mesh is None:
-            self._ulysses_mesh = self._mesh["ulysses"]
-        if self._ring_local_rank is None:
-            self._ring_local_rank = self._ring_mesh.get_local_rank()
-        if self._ulysses_local_rank is None:
-            self._ulysses_local_rank = self._ulysses_mesh.get_local_rank()
+        self._flattened_mesh = self._mesh._flatten()
+        self._ring_mesh = self._mesh["ring"]
+        self._ulysses_mesh = self._mesh["ulysses"]
+        self._ring_local_rank = self._ring_mesh.get_local_rank()
+        self._ulysses_local_rank = self._ulysses_mesh.get_local_rank()
Comment on lines +122 to +126

Member
Can't they be …?

Collaborator (Author)
They are internal attributes that are derived from the mesh, which is set through `setup()`. The guards are redundant; they would always be `None`.
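Taken together, the checks in `__post_init__` and `setup()` above accept only a handful of configurations. The following is a standalone restatement of those rules (not the library's code), to make the valid `(ring_degree, ulysses_degree)` combinations explicit; `world_size` stands in for the number of launched processes.

```python
# Mirrors the validation rules from the diff above, in the same order.
def validate(ring_degree: int, ulysses_degree: int, world_size: int) -> None:
    if ring_degree == 1 and ulysses_degree == 1:
        raise ValueError("Either ring_degree or ulysses_degree must be greater than 1.")
    if ring_degree < 1 or ulysses_degree < 1:
        raise ValueError("ring_degree and ulysses_degree must be >= 1.")
    if ring_degree > 1 and ulysses_degree > 1:
        raise ValueError("Unified Ulysses-Ring attention is not supported yet.")
    if ring_degree * ulysses_degree > world_size:
        raise ValueError("ring_degree * ulysses_degree must not exceed the world size.")

validate(4, 1, world_size=4)     # ok: pure ring attention
validate(1, 4, world_size=4)     # ok: pure Ulysses attention
# validate(1, 1, world_size=4)   # error: context parallelism not actually requested
# validate(2, 2, world_size=4)   # error: unified Ulysses-Ring not supported yet
# validate(1, 4, world_size=3)   # error: more devices requested than are available
```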
 @dataclass
@@ -119,22 +141,22 @@ class ParallelConfig:
     _rank: int = None
     _world_size: int = None
     _device: torch.device = None
-    _cp_mesh: torch.distributed.device_mesh.DeviceMesh = None
+    _mesh: torch.distributed.device_mesh.DeviceMesh = None

     def setup(
         self,
         rank: int,
         world_size: int,
         device: torch.device,
         *,
-        cp_mesh: Optional[torch.distributed.device_mesh.DeviceMesh] = None,
+        mesh: Optional[torch.distributed.device_mesh.DeviceMesh] = None,
     ):
         self._rank = rank
         self._world_size = world_size
         self._device = device
-        self._cp_mesh = cp_mesh
+        self._mesh = mesh
         if self.context_parallel_config is not None:
-            self.context_parallel_config.setup(rank, world_size, device, cp_mesh)
+            self.context_parallel_config.setup(rank, world_size, device, mesh)
 @dataclass(frozen=True)
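Both `setup()` methods above consume a `torch.distributed` `DeviceMesh` with the `("ring", "ulysses")` dimension names. Below is a hedged sketch of how such a mesh can be built and indexed with stock PyTorch APIs; it assumes a 4-process launch (e.g. `torchrun --nproc-per-node=4`) and is only illustrative, not the code path diffusers uses to create the mesh internally.

```python
import torch.distributed as dist
from torch.distributed.device_mesh import init_device_mesh

dist.init_process_group("gloo")  # "nccl" on GPU setups

# Shape matches ContextParallelConfig.mesh_shape for ring_degree=4, ulysses_degree=1.
mesh = init_device_mesh("cpu", mesh_shape=(4, 1), mesh_dim_names=("ring", "ulysses"))

ring_mesh = mesh["ring"]        # 1D submesh whose ranks pass KV chunks around the ring
ulysses_mesh = mesh["ulysses"]  # 1D submesh used for the all-to-all in Ulysses attention
print(ring_mesh.get_local_rank(), ulysses_mesh.get_local_rank())

dist.destroy_process_group()
```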
@@ -220,7 +220,7 @@ class _AttentionBackendRegistry:
     _backends = {}
     _constraints = {}
     _supported_arg_names = {}
-    _supports_context_parallel = {}
+    _supports_context_parallel = set()
     _active_backend = AttentionBackendName(DIFFUSERS_ATTN_BACKEND)
     _checks_enabled = DIFFUSERS_ATTN_CHECKS
@@ -237,7 +237,9 @@ def decorator(func):
             cls._backends[backend] = func
             cls._constraints[backend] = constraints or []
             cls._supported_arg_names[backend] = set(inspect.signature(func).parameters.keys())
-            cls._supports_context_parallel[backend] = supports_context_parallel
+            if supports_context_parallel:
+                cls._supports_context_parallel.add(backend.value)

             return func

         return decorator
@@ -251,15 +253,12 @@ def list_backends(cls):
         return list(cls._backends.keys())

     @classmethod
-    def _is_context_parallel_enabled(
-        cls, backend: AttentionBackendName, parallel_config: Optional["ParallelConfig"]
+    def _is_context_parallel_available(
+        cls,
+        backend: AttentionBackendName,
     ) -> bool:
-        supports_context_parallel = backend in cls._supports_context_parallel
-        is_degree_greater_than_1 = parallel_config is not None and (
-            parallel_config.context_parallel_config.ring_degree > 1
-            or parallel_config.context_parallel_config.ulysses_degree > 1
-        )
-        return supports_context_parallel and is_degree_greater_than_1
+        supports_context_parallel = backend.value in cls._supports_context_parallel
+        return supports_context_parallel
|
Comment on lines
-257
to
+261
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Very nice cleanup here! |
 @contextlib.contextmanager
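The two registry hunks above replace a dict of booleans with a set of backend names, so the availability check becomes a plain membership test. Here is a minimal, self-contained sketch of that pattern (not the diffusers implementation); plain strings stand in for `AttentionBackendName.value`.

```python
class _Registry:
    _backends = {}
    _supports_context_parallel = set()  # names of backends that support context parallel

    @classmethod
    def register(cls, name: str, supports_context_parallel: bool = False):
        def decorator(func):
            cls._backends[name] = func
            if supports_context_parallel:
                cls._supports_context_parallel.add(name)
            return func
        return decorator

    @classmethod
    def _is_context_parallel_available(cls, name: str) -> bool:
        return name in cls._supports_context_parallel


@_Registry.register("flash", supports_context_parallel=True)
def flash_attention(query, key, value):
    ...


@_Registry.register("native")
def native_attention(query, key, value):
    ...


print(_Registry._is_context_parallel_available("flash"))   # True
print(_Registry._is_context_parallel_available("native"))  # False
```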
@@ -306,14 +305,6 @@ def dispatch_attention_fn(
     backend_name = AttentionBackendName(backend)
     backend_fn = _AttentionBackendRegistry._backends.get(backend_name)

-    if parallel_config is not None and not _AttentionBackendRegistry._is_context_parallel_enabled(
-        backend_name, parallel_config
-    ):
-        raise ValueError(
-            f"Backend {backend_name} either does not support context parallelism or context parallelism "
-            f"was enabled with a world size of 1."
-        )
-
     kwargs = {
         "query": query,
         "key": key,
Would it be possible to add a small explainer about what the different values would mean, for example "(3, 1)", "(1, 3)", etc.? When both are being set, both cannot be > 1.