diff --git a/src/diffusers/commands/custom_blocks.py b/src/diffusers/commands/custom_blocks.py index 43d9ea88577a..953240c5a2c3 100644 --- a/src/diffusers/commands/custom_blocks.py +++ b/src/diffusers/commands/custom_blocks.py @@ -89,8 +89,6 @@ def run(self): # automap = self._create_automap(parent_class=parent_class, child_class=child_class) # with open(CONFIG, "w") as f: # json.dump(automap, f) - with open("requirements.txt", "w") as f: - f.write("") def _choose_block(self, candidates, chosen=None): for cls, base in candidates: diff --git a/src/diffusers/modular_pipelines/modular_pipeline.py b/src/diffusers/modular_pipelines/modular_pipeline.py index 307698245e5b..df2ae4837316 100644 --- a/src/diffusers/modular_pipelines/modular_pipeline.py +++ b/src/diffusers/modular_pipelines/modular_pipeline.py @@ -39,6 +39,7 @@ InputParam, InsertableDict, OutputParam, + _validate_requirements, format_components, format_configs, make_doc_string, @@ -239,6 +240,7 @@ class ModularPipelineBlocks(ConfigMixin, PushToHubMixin): config_name = "modular_config.json" model_name = None + _requirements: Optional[Dict[str, str]] = None @classmethod def _get_signature_keys(cls, obj): @@ -301,6 +303,19 @@ def from_pretrained( trust_remote_code: bool = False, **kwargs, ): + config = cls.load_config(pretrained_model_name_or_path) + has_remote_code = "auto_map" in config and cls.__name__ in config["auto_map"] + trust_remote_code = resolve_trust_remote_code( + trust_remote_code, pretrained_model_name_or_path, has_remote_code + ) + if not (has_remote_code and trust_remote_code): + raise ValueError( + "Selected model repository does not happear to have any custom code or does not have a valid `config.json` file." + ) + + if "requirements" in config and config["requirements"] is not None: + _ = _validate_requirements(config["requirements"]) + hub_kwargs_names = [ "cache_dir", "force_download", @@ -313,16 +328,6 @@ def from_pretrained( ] hub_kwargs = {name: kwargs.pop(name) for name in hub_kwargs_names if name in kwargs} - config = cls.load_config(pretrained_model_name_or_path, **hub_kwargs) - has_remote_code = "auto_map" in config and cls.__name__ in config["auto_map"] - trust_remote_code = resolve_trust_remote_code( - trust_remote_code, pretrained_model_name_or_path, has_remote_code - ) - if not has_remote_code and trust_remote_code: - raise ValueError( - "Selected model repository does not happear to have any custom code or does not have a valid `config.json` file." - ) - class_ref = config["auto_map"][cls.__name__] module_file, class_name = class_ref.split(".") module_file = module_file + ".py" @@ -347,8 +352,13 @@ def save_pretrained(self, save_directory, push_to_hub=False, **kwargs): module = full_mod.rsplit(".", 1)[-1].replace("__dynamic__", "") parent_module = self.save_pretrained.__func__.__qualname__.split(".", 1)[0] auto_map = {f"{parent_module}": f"{module}.{cls_name}"} - self.register_to_config(auto_map=auto_map) + + # resolve requirements + requirements = _validate_requirements(getattr(self, "_requirements", None)) + if requirements: + self.register_to_config(requirements=requirements) + self.save_config(save_directory=save_directory, push_to_hub=push_to_hub, **kwargs) config = dict(self.config) self._internal_dict = FrozenDict(config) @@ -1132,6 +1142,14 @@ def doc(self): expected_configs=self.expected_configs, ) + @property + def _requirements(self) -> Dict[str, str]: + requirements = {} + for block_name, block in self.sub_blocks.items(): + if getattr(block, "_requirements", None): + requirements[block_name] = block._requirements + return requirements + class LoopSequentialPipelineBlocks(ModularPipelineBlocks): """ diff --git a/src/diffusers/modular_pipelines/modular_pipeline_utils.py b/src/diffusers/modular_pipelines/modular_pipeline_utils.py index b15126868634..50190305be5e 100644 --- a/src/diffusers/modular_pipelines/modular_pipeline_utils.py +++ b/src/diffusers/modular_pipelines/modular_pipeline_utils.py @@ -19,9 +19,11 @@ from typing import Any, Dict, List, Literal, Optional, Type, Union import torch +from packaging.specifiers import InvalidSpecifier, SpecifierSet from ..configuration_utils import ConfigMixin, FrozenDict from ..utils import is_torch_available, logging +from ..utils.import_utils import _is_package_available if is_torch_available(): @@ -670,3 +672,86 @@ def make_doc_string( output += format_output_params(outputs, indent_level=2) return output + + +def _validate_requirements(reqs): + if reqs is None: + normalized_reqs = {} + else: + if not isinstance(reqs, dict): + raise ValueError( + "Requirements must be provided as a dictionary mapping package names to version specifiers." + ) + normalized_reqs = _normalize_requirements(reqs) + + if not normalized_reqs: + return {} + + final: Dict[str, str] = {} + for req, specified_ver in normalized_reqs.items(): + req_available, req_actual_ver = _is_package_available(req) + if not req_available: + logger.warning(f"{req} was specified in the requirements but wasn't found in the current environment.") + + if specified_ver: + try: + specifier = SpecifierSet(specified_ver) + except InvalidSpecifier as err: + raise ValueError(f"Requirement specifier '{specified_ver}' for {req} is invalid.") from err + + if req_actual_ver == "N/A": + logger.warning( + f"Version of {req} could not be determined to validate requirement '{specified_ver}'. Things might work unexpected." + ) + elif not specifier.contains(req_actual_ver, prereleases=True): + logger.warning( + f"{req} requirement '{specified_ver}' is not satisfied by the installed version {req_actual_ver}. Things might work unexpected." + ) + + final[req] = specified_ver + + return final + + +def _normalize_requirements(reqs): + if not reqs: + return {} + + normalized: "OrderedDict[str, str]" = OrderedDict() + + def _accumulate(mapping: Dict[str, Any]): + for pkg, spec in mapping.items(): + if isinstance(spec, dict): + # This is recursive because blocks are composable. This way, we can merge requirements + # from multiple blocks. + _accumulate(spec) + continue + + pkg_name = str(pkg).strip() + if not pkg_name: + raise ValueError("Requirement package name cannot be empty.") + + spec_str = "" if spec is None else str(spec).strip() + if spec_str and not spec_str.startswith(("<", ">", "=", "!", "~")): + spec_str = f"=={spec_str}" + + existing_spec = normalized.get(pkg_name) + if existing_spec is not None: + if not existing_spec and spec_str: + normalized[pkg_name] = spec_str + elif existing_spec and spec_str and existing_spec != spec_str: + try: + combined_spec = SpecifierSet(",".join(filter(None, [existing_spec, spec_str]))) + except InvalidSpecifier: + logger.warning( + f"Conflicting requirements for '{pkg_name}' detected: '{existing_spec}' vs '{spec_str}'. Keeping '{existing_spec}'." + ) + else: + normalized[pkg_name] = str(combined_spec) + continue + + normalized[pkg_name] = spec_str + + _accumulate(reqs) + + return normalized