2 changes: 0 additions & 2 deletions src/diffusers/commands/custom_blocks.py
@@ -89,8 +89,6 @@ def run(self):
# automap = self._create_automap(parent_class=parent_class, child_class=child_class)
# with open(CONFIG, "w") as f:
# json.dump(automap, f)
with open("requirements.txt", "w") as f:
f.write("")

def _choose_block(self, candidates, chosen=None):
for cls, base in candidates:
40 changes: 29 additions & 11 deletions src/diffusers/modular_pipelines/modular_pipeline.py
@@ -39,6 +39,7 @@
InputParam,
InsertableDict,
OutputParam,
_validate_requirements,
format_components,
format_configs,
make_doc_string,
@@ -239,6 +240,7 @@ class ModularPipelineBlocks(ConfigMixin, PushToHubMixin):

config_name = "modular_config.json"
model_name = None
_requirements: Optional[Dict[str, str]] = None

@classmethod
def _get_signature_keys(cls, obj):
@@ -301,6 +303,19 @@ def from_pretrained(
trust_remote_code: bool = False,
**kwargs,
):
config = cls.load_config(pretrained_model_name_or_path)
has_remote_code = "auto_map" in config and cls.__name__ in config["auto_map"]
trust_remote_code = resolve_trust_remote_code(
trust_remote_code, pretrained_model_name_or_path, has_remote_code
)
if not (has_remote_code and trust_remote_code):
raise ValueError(
"Selected model repository does not happear to have any custom code or does not have a valid `config.json` file."
)
Comment on lines +306 to +314 (Member Author):
Same code as previous but just lifted a bit above.


if "requirements" in config and config["requirements"] is not None:
_ = _validate_requirements(config["requirements"])

hub_kwargs_names = [
"cache_dir",
"force_download",
@@ -313,16 +328,6 @@
]
hub_kwargs = {name: kwargs.pop(name) for name in hub_kwargs_names if name in kwargs}

config = cls.load_config(pretrained_model_name_or_path, **hub_kwargs)
has_remote_code = "auto_map" in config and cls.__name__ in config["auto_map"]
trust_remote_code = resolve_trust_remote_code(
trust_remote_code, pretrained_model_name_or_path, has_remote_code
)
if not has_remote_code and trust_remote_code:
raise ValueError(
"Selected model repository does not happear to have any custom code or does not have a valid `config.json` file."
)

class_ref = config["auto_map"][cls.__name__]
module_file, class_name = class_ref.split(".")
module_file = module_file + ".py"
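
For reference, a hedged sketch of the kind of config dict this load path consumes; the class, module, and package names below are invented, and only the `auto_map` and `requirements` keys matter for the code shown above:

```python
# Hypothetical contents of a custom block's modular_config.json, shown as a dict.
config = {
    "_class_name": "MyUpscaleBlock",
    "auto_map": {"ModularPipelineBlocks": "block.MyUpscaleBlock"},
    "requirements": {"torch": ">=2.1", "transformers": "4.44.0"},
}

# from_pretrained() resolves trust_remote_code first, runs config["requirements"]
# through _validate_requirements() (warnings only for missing/mismatched packages),
# and then loads the class referenced by auto_map:
class_ref = config["auto_map"]["ModularPipelineBlocks"]
module_file, class_name = class_ref.split(".")
print(module_file + ".py", class_name)  # -> block.py MyUpscaleBlock
```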
@@ -347,8 +352,13 @@ def save_pretrained(self, save_directory, push_to_hub=False, **kwargs):
module = full_mod.rsplit(".", 1)[-1].replace("__dynamic__", "")
parent_module = self.save_pretrained.__func__.__qualname__.split(".", 1)[0]
auto_map = {f"{parent_module}": f"{module}.{cls_name}"}

self.register_to_config(auto_map=auto_map)

# resolve requirements
requirements = _validate_requirements(getattr(self, "_requirements", None))
if requirements:
self.register_to_config(requirements=requirements)

self.save_config(save_directory=save_directory, push_to_hub=push_to_hub, **kwargs)
config = dict(self.config)
self._internal_dict = FrozenDict(config)
@@ -1132,6 +1142,14 @@ def doc(self):
expected_configs=self.expected_configs,
)

@property
def _requirements(self) -> Dict[str, str]:
requirements = {}
for block_name, block in self.sub_blocks.items():
if getattr(block, "_requirements", None):
requirements[block_name] = block._requirements
return requirements


class LoopSequentialPipelineBlocks(ModularPipelineBlocks):
"""
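
Taken together, the changes in this file let a block declare its package requirements and have them round-tripped through the config. A rough usage sketch, assuming the public `diffusers.modular_pipelines` import path; `MyUpscaleBlock` and the version pins are invented, and the save/load calls are commented out because they depend on a concrete repo layout:

```python
from typing import Dict, Optional

from diffusers.modular_pipelines import ModularPipelineBlocks


class MyUpscaleBlock(ModularPipelineBlocks):
    # Declared by the block author. Bare versions are normalized to "==<version>";
    # explicit specifiers such as ">=2.1" are kept as-is.
    _requirements: Optional[Dict[str, str]] = {"torch": ">=2.1", "transformers": "4.44.0"}


# save_pretrained() validates this dict and writes it under the "requirements"
# key of modular_config.json, next to auto_map:
# MyUpscaleBlock().save_pretrained("my-upscale-block")

# from_pretrained() later re-validates it against the installed environment,
# logging warnings (not raising) for missing or unsatisfied packages:
# MyUpscaleBlock.from_pretrained("user/my-upscale-block", trust_remote_code=True)

# Composite blocks that own `sub_blocks` expose a read-only `_requirements`
# property that nests each sub-block's dict under its block name; that nested
# mapping is flattened and merged by _normalize_requirements() at save time.
```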
85 changes: 85 additions & 0 deletions src/diffusers/modular_pipelines/modular_pipeline_utils.py
@@ -19,9 +19,11 @@
from typing import Any, Dict, List, Literal, Optional, Type, Union

import torch
from packaging.specifiers import InvalidSpecifier, SpecifierSet

from ..configuration_utils import ConfigMixin, FrozenDict
from ..utils import is_torch_available, logging
from ..utils.import_utils import _is_package_available


if is_torch_available():
@@ -670,3 +672,86 @@ def make_doc_string(
output += format_output_params(outputs, indent_level=2)

return output


def _validate_requirements(reqs):
if reqs is None:
normalized_reqs = {}
else:
if not isinstance(reqs, dict):
raise ValueError(
"Requirements must be provided as a dictionary mapping package names to version specifiers."
)
normalized_reqs = _normalize_requirements(reqs)

if not normalized_reqs:
return {}

final: Dict[str, str] = {}
for req, specified_ver in normalized_reqs.items():
req_available, req_actual_ver = _is_package_available(req)
if not req_available:
logger.warning(f"{req} was specified in the requirements but wasn't found in the current environment.")

if specified_ver:
try:
specifier = SpecifierSet(specified_ver)
except InvalidSpecifier as err:
raise ValueError(f"Requirement specifier '{specified_ver}' for {req} is invalid.") from err

if req_actual_ver == "N/A":
                logger.warning(
                    f"Version of {req} could not be determined to validate requirement '{specified_ver}'. Things might not work as expected."
                )
            elif not specifier.contains(req_actual_ver, prereleases=True):
                logger.warning(
                    f"{req} requirement '{specified_ver}' is not satisfied by the installed version {req_actual_ver}. Things might not work as expected."
                )

final[req] = specified_ver

return final
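
A minimal usage sketch of `_validate_requirements`, assuming `torch` is installed and `some_missing_pkg` is not; the results in the comments follow from the normalization and warning behaviour above rather than from any particular environment:

```python
from diffusers.modular_pipelines.modular_pipeline_utils import _validate_requirements

# Bare versions are normalized to exact pins; explicit specifiers pass through.
print(_validate_requirements({"torch": ">=2.0", "numpy": "1.26.4"}))
# -> {'torch': '>=2.0', 'numpy': '==1.26.4'}

# Missing or version-mismatched packages only produce logger warnings;
# the requirement is still returned so it can be saved to the config.
print(_validate_requirements({"some_missing_pkg": ">=1.0"}))
# -> {'some_missing_pkg': '>=1.0'}  (plus a warning in the logs)

# Non-dict input and malformed specifiers raise ValueError.
# _validate_requirements(["torch>=2.0"])        # ValueError: must be a dict
# _validate_requirements({"torch": "latest"})   # ValueError: invalid specifier
```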


def _normalize_requirements(reqs):
if not reqs:
return {}

normalized: "OrderedDict[str, str]" = OrderedDict()

def _accumulate(mapping: Dict[str, Any]):
for pkg, spec in mapping.items():
if isinstance(spec, dict):
# This is recursive because blocks are composable. This way, we can merge requirements
# from multiple blocks.
_accumulate(spec)
continue

pkg_name = str(pkg).strip()
if not pkg_name:
raise ValueError("Requirement package name cannot be empty.")

spec_str = "" if spec is None else str(spec).strip()
if spec_str and not spec_str.startswith(("<", ">", "=", "!", "~")):
spec_str = f"=={spec_str}"

existing_spec = normalized.get(pkg_name)
if existing_spec is not None:
if not existing_spec and spec_str:
normalized[pkg_name] = spec_str
elif existing_spec and spec_str and existing_spec != spec_str:
try:
combined_spec = SpecifierSet(",".join(filter(None, [existing_spec, spec_str])))
except InvalidSpecifier:
logger.warning(
f"Conflicting requirements for '{pkg_name}' detected: '{existing_spec}' vs '{spec_str}'. Keeping '{existing_spec}'."
)
else:
normalized[pkg_name] = str(combined_spec)
continue

normalized[pkg_name] = spec_str

_accumulate(reqs)

return normalized
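
Similarly, a short sketch of `_normalize_requirements` on an invented nested mapping with the shape the composite `_requirements` property produces (per-block dicts keyed by block name):

```python
from diffusers.modular_pipelines.modular_pipeline_utils import _normalize_requirements

nested = {
    "text_encoder_block": {"transformers": "4.44.0"},               # bare version
    "denoise_block": {"torch": ">=2.1", "transformers": ">=4.40"},  # overlapping package
}
print(dict(_normalize_requirements(nested)))
# Expected result, assuming packaging can combine the two transformers specifiers:
# {'transformers': '==4.44.0,>=4.40', 'torch': '>=2.1'}
```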