
Commit 901d10e

refactor transformer part 1
1 parent 199c240 commit 901d10e

File tree

4 files changed: +268 additions, -1133 deletions


src/diffusers/models/attention_processor.py

Lines changed: 92 additions & 0 deletions
@@ -1506,6 +1506,98 @@ def __call__(
         return hidden_states, encoder_hidden_states


+class AllegroAttnProcessor2_0:
+    r"""
+    Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0). This is
+    used in the Allegro model. It applies a normalization layer and rotary embedding on the query and key vectors.
+    """
+
+    def __init__(self):
+        if not hasattr(F, "scaled_dot_product_attention"):
+            raise ImportError("AllegroAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
+
+    def __call__(
+        self,
+        attn: Attention,
+        hidden_states: torch.Tensor,
+        encoder_hidden_states: Optional[torch.Tensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        temb: Optional[torch.Tensor] = None,
+        image_rotary_emb: Optional[torch.Tensor] = None,
+    ) -> torch.Tensor:
+        residual = hidden_states
+
+        if attn.spatial_norm is not None:
+            hidden_states = attn.spatial_norm(hidden_states, temb)
+
+        input_ndim = hidden_states.ndim
+
+        if input_ndim == 4:
+            batch_size, channel, height, width = hidden_states.shape
+            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
+
+        batch_size, sequence_length, _ = (
+            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
+        )
+
+        if attention_mask is not None:
+            attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
+            # scaled_dot_product_attention expects attention_mask shape to be
+            # (batch, heads, source_length, target_length)
+            attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
+
+        if attn.group_norm is not None:
+            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
+
+        query = attn.to_q(hidden_states)
+
+        if encoder_hidden_states is None:
+            encoder_hidden_states = hidden_states
+        elif attn.norm_cross:
+            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
+
+        key = attn.to_k(encoder_hidden_states)
+        value = attn.to_v(encoder_hidden_states)
+
+        inner_dim = key.shape[-1]
+        head_dim = inner_dim // attn.heads
+
+        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+
+        # Apply RoPE if needed
+        if image_rotary_emb is not None and not attn.is_cross_attention:
+            from .embeddings import apply_rotary_emb_allegro
+
+            query = apply_rotary_emb_allegro(query, image_rotary_emb[0], image_rotary_emb[1])
+            key = apply_rotary_emb_allegro(key, image_rotary_emb[0], image_rotary_emb[1])
+
+        # the output of sdp = (batch, num_heads, seq_len, head_dim)
+        # TODO: add support for attn.scale when we move to Torch 2.1
+        hidden_states = F.scaled_dot_product_attention(
+            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
+        )
+
+        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
+        hidden_states = hidden_states.to(query.dtype)
+
+        # linear proj
+        hidden_states = attn.to_out[0](hidden_states)
+        # dropout
+        hidden_states = attn.to_out[1](hidden_states)
+
+        if input_ndim == 4:
+            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
+
+        if attn.residual_connection:
+            hidden_states = hidden_states + residual
+
+        hidden_states = hidden_states / attn.rescale_output_factor
+
+        return hidden_states
+
+
 class AuraFlowAttnProcessor2_0:
     """Attention processor used typically in processing Aura Flow."""
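The diff only adds the processor definition, so here is a minimal usage sketch, not taken from this commit: it assumes the diffusers Attention block accepts a processor argument (true in recent versions), and the sizes (query_dim=96, heads=8, dim_head=12) are illustrative. Without image_rotary_emb, the call falls through to the plain SDPA self-attention path.

import torch
from diffusers.models.attention_processor import Attention, AllegroAttnProcessor2_0

# Hypothetical sizes: heads * dim_head gives the inner dim, and head_dim is divisible by 3 for the t/h/w RoPE split.
attn = Attention(
    query_dim=96,
    heads=8,
    dim_head=12,
    processor=AllegroAttnProcessor2_0(),
)

hidden_states = torch.randn(2, 16, 96)  # (batch, seq_len, query_dim)
out = attn(hidden_states)               # no image_rotary_emb -> no RoPE applied
print(out.shape)                        # torch.Size([2, 16, 96])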

src/diffusers/models/embeddings.py

Lines changed: 44 additions & 1 deletion
@@ -564,6 +564,31 @@ def combine_time_height_width(freqs_t, freqs_h, freqs_w):
     return cos, sin


+def get_3d_rotary_pos_embed_allegro(
+    embed_dim, crops_coords, grid_size, temporal_size, interpolation_scale: Tuple[float, float, float] = (1.0, 1.0, 1.0), theta: int = 10000
+) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
+    # TODO(aryan): docs
+    start, stop = crops_coords
+    grid_size_h, grid_size_w = grid_size
+    interpolation_scale_t, interpolation_scale_h, interpolation_scale_w = interpolation_scale
+    grid_t = np.linspace(0, temporal_size, temporal_size, endpoint=False, dtype=np.float32)
+    grid_h = np.linspace(start[0], stop[0], grid_size_h, endpoint=False, dtype=np.float32)
+    grid_w = np.linspace(start[1], stop[1], grid_size_w, endpoint=False, dtype=np.float32)
+
+    # Compute dimensions for each axis
+    dim_t = embed_dim // 3
+    dim_h = embed_dim // 3
+    dim_w = embed_dim // 3
+
+    # Temporal frequencies
+    freqs_t = get_1d_rotary_pos_embed(dim_t, grid_t / interpolation_scale_t, theta=theta, use_real=True, repeat_interleave_real=False)
+    # Spatial frequencies for height and width
+    freqs_h = get_1d_rotary_pos_embed(dim_h, grid_h / interpolation_scale_h, theta=theta, use_real=True, repeat_interleave_real=False)
+    freqs_w = get_1d_rotary_pos_embed(dim_w, grid_w / interpolation_scale_w, theta=theta, use_real=True, repeat_interleave_real=False)
+
+    return freqs_t, freqs_h, freqs_w, grid_t, grid_h, grid_w
+
+
 def get_2d_rotary_pos_embed(embed_dim, crops_coords, grid_size, use_real=True):
     """
     RoPE for image tokens with 2d structure.
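As a quick sanity check on the helper above, a hedged calling sketch follows; the grid sizes and embed_dim are made-up values. Each returned frequency pair has shape [S, D] per axis, which is what get_1d_rotary_pos_embed produces with use_real=True and repeat_interleave_real=False.

import torch
from diffusers.models.embeddings import get_3d_rotary_pos_embed_allegro

# Illustrative sizes only: embed_dim is split evenly across the t/h/w axes (96 // 3 = 32 each).
freqs_t, freqs_h, freqs_w, grid_t, grid_h, grid_w = get_3d_rotary_pos_embed_allegro(
    embed_dim=96,
    crops_coords=((0, 0), (40, 64)),
    grid_size=(40, 64),
    temporal_size=22,
)

t_cos, t_sin = freqs_t  # each [22, 32]
h_cos, h_sin = freqs_h  # each [40, 32]
w_cos, w_sin = freqs_w  # each [64, 32]
print(t_cos.shape, h_cos.shape, w_cos.shape)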
@@ -684,7 +709,7 @@ def get_1d_rotary_pos_embed(
         freqs_sin = freqs.sin().repeat_interleave(2, dim=1).float()  # [S, D]
         return freqs_cos, freqs_sin
     elif use_real:
-        # stable audio
+        # stable audio, allegro
         freqs_cos = torch.cat([freqs.cos(), freqs.cos()], dim=-1).float()  # [S, D]
         freqs_sin = torch.cat([freqs.sin(), freqs.sin()], dim=-1).float()  # [S, D]
         return freqs_cos, freqs_sin
@@ -743,6 +768,24 @@ def apply_rotary_emb(
         return x_out.type_as(x)


+def apply_rotary_emb_allegro(x: torch.Tensor, freqs_cis, positions):
+    # TODO(aryan): rewrite
+    def apply_1d_rope(tokens, pos, cos, sin):
+        cos = F.embedding(pos, cos)[:, None, :, :]
+        sin = F.embedding(pos, sin)[:, None, :, :]
+        x1, x2 = tokens[..., : tokens.shape[-1] // 2], tokens[..., tokens.shape[-1] // 2:]
+        tokens_rotated = torch.cat((-x2, x1), dim=-1)
+        return (tokens.float() * cos + tokens_rotated.float() * sin).to(tokens.dtype)
+
+    (t_cos, t_sin), (h_cos, h_sin), (w_cos, w_sin) = freqs_cis
+    t, h, w = x.chunk(3, dim=-1)
+    t = apply_1d_rope(t, positions[0], t_cos, t_sin)
+    h = apply_1d_rope(h, positions[1], h_cos, h_sin)
+    w = apply_1d_rope(w, positions[2], w_cos, w_sin)
+    x = torch.cat([t, h, w], dim=-1)
+    return x
+
+
 class FluxPosEmbed(nn.Module):
     # modified from https://github.com/black-forest-labs/flux/blob/c00d7c60b085fce8058b9df845e036090873f2ce/src/flux/modules/layers.py#L11
     def __init__(self, theta: int, axes_dim: List[int]):
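To show how the two new helpers compose, and how they line up with the image_rotary_emb[0] / image_rotary_emb[1] split in AllegroAttnProcessor2_0 above, here is a hedged end-to-end sketch. The position-index construction is an assumption for illustration (one (t, h, w) index per token so F.embedding inside apply_1d_rope can look up the cos/sin rows), not code from this commit.

import torch
from diffusers.models.embeddings import (
    apply_rotary_emb_allegro,
    get_3d_rotary_pos_embed_allegro,
)

num_frames, height, width, head_dim = 4, 6, 8, 96

freqs_t, freqs_h, freqs_w, grid_t, grid_h, grid_w = get_3d_rotary_pos_embed_allegro(
    embed_dim=head_dim,
    crops_coords=((0, 0), (height, width)),
    grid_size=(height, width),
    temporal_size=num_frames,
)

# Hypothetical positions: (t, h, w) indices per token, each shaped (1, seq_len).
t_idx, h_idx, w_idx = torch.meshgrid(
    torch.arange(num_frames), torch.arange(height), torch.arange(width), indexing="ij"
)
positions = (t_idx.reshape(1, -1), h_idx.reshape(1, -1), w_idx.reshape(1, -1))

query = torch.randn(2, 8, num_frames * height * width, head_dim)  # (batch, heads, seq, head_dim)
rotated = apply_rotary_emb_allegro(query, (freqs_t, freqs_h, freqs_w), positions)
print(rotated.shape)  # torch.Size([2, 8, 192, 96])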

0 commit comments