
Commit 759ea58

[core] reuse AttentionMixin for compatible classes (#12463)
* remove attn_processors property
* more
* up
* up more.
* up
* add AttentionMixin to AuraFlow.
* up
* up
* up
* up
1 parent f48f9c2 commit 759ea58

34 files changed: +106 −2139 lines
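
Note: each class touched below previously carried its own copy of the attn_processors property and the set_attn_processor method, and now inherits them from AttentionMixin (imported from ..attention) instead. The following is a minimal sketch of what such a mixin provides, reconstructed from the removed blocks in this diff; the class name is changed to mark it as illustrative, and the real AttentionMixin in diffusers may offer additional helpers.

# Sketch only: a trimmed-down mixin reconstructed from the methods removed in this commit.
# It assumes the consuming class is also a torch.nn.Module (as ModelMixin-based models are).
from typing import Dict, Union

import torch


class AttentionMixinSketch:
    @property
    def attn_processors(self) -> Dict[str, object]:
        # Walk all children and collect every module that exposes get_processor().
        processors: Dict[str, object] = {}

        def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors):
            if hasattr(module, "get_processor"):
                processors[f"{name}.processor"] = module.get_processor()
            for sub_name, child in module.named_children():
                fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
            return processors

        for name, module in self.named_children():
            fn_recursive_add_processors(name, module, processors)
        return processors

    def set_attn_processor(self, processor: Union[object, Dict[str, object]]):
        # Accept a single processor for all layers, or a dict keyed by "<module path>.processor".
        count = len(self.attn_processors.keys())
        if isinstance(processor, dict) and len(processor) != count:
            raise ValueError(
                f"Passed {len(processor)} processors, but the model has {count} attention layers."
            )

        def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
            if hasattr(module, "set_processor"):
                if not isinstance(processor, dict):
                    module.set_processor(processor)
                else:
                    module.set_processor(processor.pop(f"{name}.processor"))
            for sub_name, child in module.named_children():
                fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)

        for name, module in self.named_children():
            fn_recursive_attn_processor(name, module, processor)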

src/diffusers/models/autoencoders/autoencoder_kl.py

Lines changed: 5 additions & 63 deletions
@@ -11,7 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from typing import Dict, Optional, Tuple, Union
+from typing import Optional, Tuple, Union
 
 import torch
 import torch.nn as nn
@@ -21,11 +21,11 @@
 from ...loaders.single_file_model import FromOriginalModelMixin
 from ...utils import deprecate
 from ...utils.accelerate_utils import apply_forward_hook
+from ..attention import AttentionMixin
 from ..attention_processor import (
     ADDED_KV_ATTENTION_PROCESSORS,
     CROSS_ATTENTION_PROCESSORS,
     Attention,
-    AttentionProcessor,
     AttnAddedKVProcessor,
     AttnProcessor,
     FusedAttnProcessor2_0,
@@ -35,7 +35,9 @@
 from .vae import AutoencoderMixin, Decoder, DecoderOutput, DiagonalGaussianDistribution, Encoder
 
 
-class AutoencoderKL(ModelMixin, AutoencoderMixin, ConfigMixin, FromOriginalModelMixin, PeftAdapterMixin):
+class AutoencoderKL(
+    ModelMixin, AttentionMixin, AutoencoderMixin, ConfigMixin, FromOriginalModelMixin, PeftAdapterMixin
+):
     r"""
     A VAE model with KL loss for encoding images into latents and decoding latent representations into images.
 
@@ -138,66 +140,6 @@ def __init__(
         self.tile_latent_min_size = int(sample_size / (2 ** (len(self.config.block_out_channels) - 1)))
         self.tile_overlap_factor = 0.25
 
-    @property
-    # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
-    def attn_processors(self) -> Dict[str, AttentionProcessor]:
-        r"""
-        Returns:
-            `dict` of attention processors: A dictionary containing all attention processors used in the model with
-            indexed by its weight name.
-        """
-        # set recursively
-        processors = {}
-
-        def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
-            if hasattr(module, "get_processor"):
-                processors[f"{name}.processor"] = module.get_processor()
-
-            for sub_name, child in module.named_children():
-                fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
-
-            return processors
-
-        for name, module in self.named_children():
-            fn_recursive_add_processors(name, module, processors)
-
-        return processors
-
-    # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
-    def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
-        r"""
-        Sets the attention processor to use to compute attention.
-
-        Parameters:
-            processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
-                The instantiated processor class or a dictionary of processor classes that will be set as the processor
-                for **all** `Attention` layers.
-
-                If `processor` is a dict, the key needs to define the path to the corresponding cross attention
-                processor. This is strongly recommended when setting trainable attention processors.
-
-        """
-        count = len(self.attn_processors.keys())
-
-        if isinstance(processor, dict) and len(processor) != count:
-            raise ValueError(
-                f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
-                f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
-            )
-
-        def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
-            if hasattr(module, "set_processor"):
-                if not isinstance(processor, dict):
-                    module.set_processor(processor)
-                else:
-                    module.set_processor(processor.pop(f"{name}.processor"))
-
-            for sub_name, child in module.named_children():
-                fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
-
-        for name, module in self.named_children():
-            fn_recursive_attn_processor(name, module, processor)
-
     # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor
     def set_default_attn_processor(self):
         """

src/diffusers/models/autoencoders/autoencoder_kl_flux2.py

Lines changed: 5 additions & 63 deletions
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import math
-from typing import Dict, Optional, Tuple, Union
+from typing import Optional, Tuple, Union
 
 import torch
 import torch.nn as nn
@@ -22,11 +22,11 @@
 from ...loaders.single_file_model import FromOriginalModelMixin
 from ...utils import deprecate
 from ...utils.accelerate_utils import apply_forward_hook
+from ..attention import AttentionMixin
 from ..attention_processor import (
     ADDED_KV_ATTENTION_PROCESSORS,
     CROSS_ATTENTION_PROCESSORS,
     Attention,
-    AttentionProcessor,
     AttnAddedKVProcessor,
     AttnProcessor,
     FusedAttnProcessor2_0,
@@ -36,7 +36,9 @@
 from .vae import AutoencoderMixin, Decoder, DecoderOutput, DiagonalGaussianDistribution, Encoder
 
 
-class AutoencoderKLFlux2(ModelMixin, AutoencoderMixin, ConfigMixin, FromOriginalModelMixin, PeftAdapterMixin):
+class AutoencoderKLFlux2(
+    ModelMixin, AutoencoderMixin, AttentionMixin, ConfigMixin, FromOriginalModelMixin, PeftAdapterMixin
+):
     r"""
     A VAE model with KL loss for encoding images into latents and decoding latent representations into images.
 
@@ -154,66 +156,6 @@ def __init__(
         self.tile_latent_min_size = int(sample_size / (2 ** (len(self.config.block_out_channels) - 1)))
         self.tile_overlap_factor = 0.25
 
-    @property
-    # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
-    def attn_processors(self) -> Dict[str, AttentionProcessor]:
-        r"""
-        Returns:
-            `dict` of attention processors: A dictionary containing all attention processors used in the model with
-            indexed by its weight name.
-        """
-        # set recursively
-        processors = {}
-
-        def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
-            if hasattr(module, "get_processor"):
-                processors[f"{name}.processor"] = module.get_processor()
-
-            for sub_name, child in module.named_children():
-                fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
-
-            return processors
-
-        for name, module in self.named_children():
-            fn_recursive_add_processors(name, module, processors)
-
-        return processors
-
-    # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
-    def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
-        r"""
-        Sets the attention processor to use to compute attention.
-
-        Parameters:
-            processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
-                The instantiated processor class or a dictionary of processor classes that will be set as the processor
-                for **all** `Attention` layers.
-
-                If `processor` is a dict, the key needs to define the path to the corresponding cross attention
-                processor. This is strongly recommended when setting trainable attention processors.
-
-        """
-        count = len(self.attn_processors.keys())
-
-        if isinstance(processor, dict) and len(processor) != count:
-            raise ValueError(
-                f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
-                f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
-            )
-
-        def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
-            if hasattr(module, "set_processor"):
-                if not isinstance(processor, dict):
-                    module.set_processor(processor)
-                else:
-                    module.set_processor(processor.pop(f"{name}.processor"))
-
-            for sub_name, child in module.named_children():
-                fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
-
-        for name, module in self.named_children():
-            fn_recursive_attn_processor(name, module, processor)
-
     # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor
     def set_default_attn_processor(self):
         """

src/diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py

Lines changed: 4 additions & 63 deletions
@@ -12,14 +12,15 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import itertools
-from typing import Dict, Optional, Tuple, Union
+from typing import Optional, Tuple, Union
 
 import torch
 import torch.nn as nn
 
 from ...configuration_utils import ConfigMixin, register_to_config
 from ...utils.accelerate_utils import apply_forward_hook
-from ..attention_processor import CROSS_ATTENTION_PROCESSORS, AttentionProcessor, AttnProcessor
+from ..attention import AttentionMixin
+from ..attention_processor import CROSS_ATTENTION_PROCESSORS, AttnProcessor
 from ..modeling_outputs import AutoencoderKLOutput
 from ..modeling_utils import ModelMixin
 from ..unets.unet_3d_blocks import MidBlockTemporalDecoder, UpBlockTemporalDecoder
@@ -135,7 +136,7 @@ def forward(
         return sample
 
 
-class AutoencoderKLTemporalDecoder(ModelMixin, AutoencoderMixin, ConfigMixin):
+class AutoencoderKLTemporalDecoder(ModelMixin, AttentionMixin, AutoencoderMixin, ConfigMixin):
     r"""
     A VAE model with KL loss for encoding images into latents and decoding latent representations into images.
 
@@ -202,66 +203,6 @@ def __init__(
 
         self.quant_conv = nn.Conv2d(2 * latent_channels, 2 * latent_channels, 1)
 
-    @property
-    # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
-    def attn_processors(self) -> Dict[str, AttentionProcessor]:
-        r"""
-        Returns:
-            `dict` of attention processors: A dictionary containing all attention processors used in the model with
-            indexed by its weight name.
-        """
-        # set recursively
-        processors = {}
-
-        def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
-            if hasattr(module, "get_processor"):
-                processors[f"{name}.processor"] = module.get_processor()
-
-            for sub_name, child in module.named_children():
-                fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
-
-            return processors
-
-        for name, module in self.named_children():
-            fn_recursive_add_processors(name, module, processors)
-
-        return processors
-
-    # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
-    def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
-        r"""
-        Sets the attention processor to use to compute attention.
-
-        Parameters:
-            processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
-                The instantiated processor class or a dictionary of processor classes that will be set as the processor
-                for **all** `Attention` layers.
-
-                If `processor` is a dict, the key needs to define the path to the corresponding cross attention
-                processor. This is strongly recommended when setting trainable attention processors.
-
-        """
-        count = len(self.attn_processors.keys())
-
-        if isinstance(processor, dict) and len(processor) != count:
-            raise ValueError(
-                f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
-                f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
-            )
-
-        def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
-            if hasattr(module, "set_processor"):
-                if not isinstance(processor, dict):
-                    module.set_processor(processor)
-                else:
-                    module.set_processor(processor.pop(f"{name}.processor"))
-
-            for sub_name, child in module.named_children():
-                fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
-
-        for name, module in self.named_children():
-            fn_recursive_attn_processor(name, module, processor)
-
     def set_default_attn_processor(self):
         """
         Disables custom attention processors and sets the default attention implementation.

src/diffusers/models/autoencoders/consistency_decoder_vae.py

Lines changed: 3 additions & 63 deletions
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 from dataclasses import dataclass
-from typing import Dict, Optional, Tuple, Union
+from typing import Optional, Tuple, Union
 
 import torch
 import torch.nn.functional as F
@@ -23,10 +23,10 @@
 from ...utils import BaseOutput
 from ...utils.accelerate_utils import apply_forward_hook
 from ...utils.torch_utils import randn_tensor
+from ..attention import AttentionMixin
 from ..attention_processor import (
     ADDED_KV_ATTENTION_PROCESSORS,
     CROSS_ATTENTION_PROCESSORS,
-    AttentionProcessor,
     AttnAddedKVProcessor,
     AttnProcessor,
 )
@@ -49,7 +49,7 @@ class ConsistencyDecoderVAEOutput(BaseOutput):
     latent_dist: "DiagonalGaussianDistribution"
 
 
-class ConsistencyDecoderVAE(ModelMixin, AutoencoderMixin, ConfigMixin):
+class ConsistencyDecoderVAE(ModelMixin, AttentionMixin, AutoencoderMixin, ConfigMixin):
     r"""
     The consistency decoder used with DALL-E 3.
 
@@ -167,66 +167,6 @@ def __init__(
         self.tile_latent_min_size = int(sample_size / (2 ** (len(self.config.block_out_channels) - 1)))
         self.tile_overlap_factor = 0.25
 
-    @property
-    # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
-    def attn_processors(self) -> Dict[str, AttentionProcessor]:
-        r"""
-        Returns:
-            `dict` of attention processors: A dictionary containing all attention processors used in the model with
-            indexed by its weight name.
-        """
-        # set recursively
-        processors = {}
-
-        def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
-            if hasattr(module, "get_processor"):
-                processors[f"{name}.processor"] = module.get_processor()
-
-            for sub_name, child in module.named_children():
-                fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
-
-            return processors
-
-        for name, module in self.named_children():
-            fn_recursive_add_processors(name, module, processors)
-
-        return processors
-
-    # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
-    def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
-        r"""
-        Sets the attention processor to use to compute attention.
-
-        Parameters:
-            processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
-                The instantiated processor class or a dictionary of processor classes that will be set as the processor
-                for **all** `Attention` layers.
-
-                If `processor` is a dict, the key needs to define the path to the corresponding cross attention
-                processor. This is strongly recommended when setting trainable attention processors.
-
-        """
-        count = len(self.attn_processors.keys())
-
-        if isinstance(processor, dict) and len(processor) != count:
-            raise ValueError(
-                f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
-                f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
-            )
-
-        def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
-            if hasattr(module, "set_processor"):
-                if not isinstance(processor, dict):
-                    module.set_processor(processor)
-                else:
-                    module.set_processor(processor.pop(f"{name}.processor"))
-
-            for sub_name, child in module.named_children():
-                fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
-
-        for name, module in self.named_children():
-            fn_recursive_attn_processor(name, module, processor)
-
     # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor
     def set_default_attn_processor(self):
         """
