Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions OmniGen/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,8 +37,8 @@ def evict_previous_layer(self, layer_idx: int):
prev_layer_idx = -1
else:
prev_layer_idx = (layer_idx - 1) % len(self)
self.key_cache[prev_layer_idx] = self.key_cache[prev_layer_idx].to("cpu", non_blocking=True)
self.value_cache[prev_layer_idx] = self.value_cache[prev_layer_idx].to("cpu", non_blocking=True)
self.key_cache[prev_layer_idx] = self.key_cache[prev_layer_idx].to("cpu")
self.value_cache[prev_layer_idx] = self.value_cache[prev_layer_idx].to("cpu")


def __getitem__(self, layer_idx: int) -> List[Tuple[torch.Tensor]]:
Expand All @@ -49,9 +49,9 @@ def __getitem__(self, layer_idx: int) -> List[Tuple[torch.Tensor]]:
torch.cuda.current_stream().synchronize()
self.evict_previous_layer(layer_idx)
# Load current layer cache to its original device if not already there
original_device = self.original_device[layer_idx]
#original_device = self.original_device[layer_idx]
# self.prefetch_stream.synchronize(original_device)
torch.cuda.synchronize(self.prefetch_stream)
self.prefetch_stream.synchronize()
key_tensor = self.key_cache[layer_idx]
value_tensor = self.value_cache[layer_idx]

Expand Down
18 changes: 10 additions & 8 deletions OmniGen/transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,19 +40,22 @@ def evict_previous_layer(self, layer_idx: int):
"Moves the previous layer cache to the CPU"
prev_layer_idx = layer_idx - 1
for name, param in self.layers[prev_layer_idx].named_parameters():
param.data = param.data.to("cpu", non_blocking=True)
param.data = param.data.to("cpu")

def get_offlaod_layer(self, layer_idx: int, device: torch.device):
# init stream
if not hasattr(self, "prefetch_stream"):
self.prefetch_stream = torch.cuda.Stream()

# delete previous layer
torch.cuda.current_stream().synchronize()
self.evict_previous_layer(layer_idx)
# main stream sync shouldn't be necessary since all computation on iter i-1 is finished by iter i
# torch.cuda.current_stream().synchronize()
# avoid extra eviction of last layer
if layer_idx > 0:
self.evict_previous_layer(layer_idx)

# make sure the current layer is ready
torch.cuda.synchronize(self.prefetch_stream)
self.prefetch_stream.synchronize()

# load next layer
self.prefetch_layer((layer_idx + 1) % len(self.layers), device)
Expand Down Expand Up @@ -133,10 +136,9 @@ def forward(
all_self_attns = () if output_attentions else None
next_decoder_cache = None

layer_idx = -1
for decoder_layer in self.layers:
layer_idx += 1

for layer_idx in range(len(self.layers)):
# direct indexing since offloading may mutate self.layers during iteration
decoder_layer = self.layers[layer_idx]
if output_hidden_states:
all_hidden_states += (hidden_states,)

Expand Down