
Commit 3c05b9f

KimbingNg and sayakpaul authored
Fixes #12673. record_stream in group offloading is not working properly (#12721)
* Fixes #12673. Wrong default_stream is used, leading to wrong execution order when record_stream is enabled.
* update
* Update test

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
1 parent 9379b23 commit 3c05b9f
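
For context on the fix described above: `Tensor.record_stream()` tells PyTorch's caching allocator which stream still uses a tensor's memory, so that memory is not reused until the recorded stream has caught up. Inside a `with <accelerator>.stream(side_stream):` block, `current_stream()` returns the side stream itself, so the old code recorded the onloaded tensors against the transfer stream rather than the stream that actually consumes them. Below is a minimal sketch of the two patterns using plain CUDA streams; the tensor and stream names are illustrative only and not part of diffusers.

```python
import torch

# Minimal sketch of the record_stream pitfall this commit fixes.
# Assumes a CUDA device; all names below are illustrative only.
if torch.cuda.is_available():
    transfer_stream = torch.cuda.Stream()

    # Capture the consumer (default) stream *before* entering the
    # side-stream context -- this is what the new code does.
    default_stream = torch.cuda.current_stream()

    cpu_tensor = torch.randn(1024, pin_memory=True)

    with torch.cuda.stream(transfer_stream):
        gpu_tensor = cpu_tensor.to("cuda", non_blocking=True)

        # Old pattern: current_stream() here is transfer_stream, so the
        # allocator is told the wrong stream still needs this memory.
        gpu_tensor.record_stream(torch.cuda.current_stream())

        # Fixed pattern: record against the stream that will actually
        # consume gpu_tensor once the copy has completed.
        gpu_tensor.record_stream(default_stream)
```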

File tree

src/diffusers/hooks/group_offloading.py
tests/models/test_modeling_common.py

2 files changed: +10 -11 lines


src/diffusers/hooks/group_offloading.py

Lines changed: 10 additions & 8 deletions

@@ -153,27 +153,27 @@ def _pinned_memory_tensors(self):
         finally:
             pinned_dict = None

-    def _transfer_tensor_to_device(self, tensor, source_tensor):
+    def _transfer_tensor_to_device(self, tensor, source_tensor, default_stream):
         tensor.data = source_tensor.to(self.onload_device, non_blocking=self.non_blocking)
         if self.record_stream:
-            tensor.data.record_stream(self._torch_accelerator_module.current_stream())
+            tensor.data.record_stream(default_stream)

-    def _process_tensors_from_modules(self, pinned_memory=None):
+    def _process_tensors_from_modules(self, pinned_memory=None, default_stream=None):
         for group_module in self.modules:
             for param in group_module.parameters():
                 source = pinned_memory[param] if pinned_memory else param.data
-                self._transfer_tensor_to_device(param, source)
+                self._transfer_tensor_to_device(param, source, default_stream)
             for buffer in group_module.buffers():
                 source = pinned_memory[buffer] if pinned_memory else buffer.data
-                self._transfer_tensor_to_device(buffer, source)
+                self._transfer_tensor_to_device(buffer, source, default_stream)

         for param in self.parameters:
             source = pinned_memory[param] if pinned_memory else param.data
-            self._transfer_tensor_to_device(param, source)
+            self._transfer_tensor_to_device(param, source, default_stream)

         for buffer in self.buffers:
             source = pinned_memory[buffer] if pinned_memory else buffer.data
-            self._transfer_tensor_to_device(buffer, source)
+            self._transfer_tensor_to_device(buffer, source, default_stream)

     def _onload_from_disk(self):
         if self.stream is not None:

@@ -208,10 +208,12 @@ def _onload_from_memory(self):
             self.stream.synchronize()

         context = nullcontext() if self.stream is None else self._torch_accelerator_module.stream(self.stream)
+        default_stream = self._torch_accelerator_module.current_stream() if self.stream is not None else None
+
         with context:
             if self.stream is not None:
                 with self._pinned_memory_tensors() as pinned_memory:
-                    self._process_tensors_from_modules(pinned_memory)
+                    self._process_tensors_from_modules(pinned_memory, default_stream=default_stream)
             else:
                 self._process_tensors_from_modules(None)

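Read together, the two hunks amount to the pattern below: capture the consuming stream before entering the transfer-stream context, pass it down, and record each onloaded tensor against it. This is a simplified, hypothetical sketch (the function name, bare parameter list, and hard-coded "cuda" device are stand-ins, not the real group_offloading API). Note that `record_stream` only protects the memory from premature reuse; execution ordering is still handled by the `self.stream.synchronize()` call that precedes the transfer in the real code.

```python
import contextlib

import torch


def onload_params(params, transfer_stream=None):
    """Hypothetical sketch of the fixed onload flow (not the diffusers API)."""
    # Capture the consumer stream before entering the side-stream context,
    # mirroring the new `default_stream = ...current_stream()` line.
    default_stream = torch.cuda.current_stream() if transfer_stream is not None else None

    ctx = torch.cuda.stream(transfer_stream) if transfer_stream is not None else contextlib.nullcontext()
    with ctx:
        for param in params:
            # Async H2D copy runs on transfer_stream when one is given.
            param.data = param.data.to("cuda", non_blocking=True)
            if default_stream is not None:
                # Keep the allocator from reusing this memory until the
                # consuming (default) stream has caught up.
                param.data.record_stream(default_stream)


if torch.cuda.is_available():
    params = [torch.nn.Parameter(torch.randn(16), requires_grad=False) for _ in range(3)]
    onload_params(params, transfer_stream=torch.cuda.Stream())
```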
tests/models/test_modeling_common.py

Lines changed: 0 additions & 3 deletions

@@ -1814,9 +1814,6 @@ def _run_forward(model, inputs_dict):
             torch.manual_seed(0)
             return model(**inputs_dict)[0]

-        if self.__class__.__name__ == "AutoencoderKLCosmosTests" and offload_type == "leaf_level":
-            pytest.skip("With `leaf_type` as the offloading type, it fails. Needs investigation.")
-
         init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
         torch.manual_seed(0)
         model = self.model_class(**init_dict)
