11import inspect
2- from typing import Any , Dict , Type
2+
3+ from typing import Any , Callable , Dict , List , Optional , TypeVar , Type
34
45from dependency_injection .registration import Registration
56from dependency_injection .scope import DEFAULT_SCOPE_NAME , Scope
67from dependency_injection .utils .singleton_meta import SingletonMeta
78
9+ Self = TypeVar ('Self' , bound = 'DependencyContainer' )
10+
811
912DEFAULT_CONTAINER_NAME = "default_container"
1013
1114class DependencyContainer (metaclass = SingletonMeta ):
1215
13- def __init__ (self , name = None ):
16+ def __init__ (self , name : str = None ):
1417 self .name = name if name is not None else DEFAULT_CONTAINER_NAME
1518 self ._registrations = {}
1619 self ._singleton_instances = {}
1720 self ._scoped_instances = {}
1821
1922 @classmethod
20- def get_instance (cls , name = None ):
23+ def get_instance (cls , name : str = None ) -> Self :
2124 if name is None :
2225 name = DEFAULT_CONTAINER_NAME
2326
@@ -26,88 +29,118 @@ def get_instance(cls, name=None):
2629
2730 return cls ._instances [(cls , name )]
2831
29- def register_transient (self , interface , class_ , constructor_args = None ):
30- if interface in self ._registrations :
31- raise ValueError (f"Dependency { interface } is already registered." )
32- self ._registrations [interface ] = Registration (interface , class_ , Scope .TRANSIENT , constructor_args )
33-
34- def register_scoped (self , interface , class_ , constructor_args = None ):
35- if interface in self ._registrations :
36- raise ValueError (f"Dependency { interface } is already registered." )
37- self ._registrations [interface ] = Registration (interface , class_ , Scope .SCOPED , constructor_args )
38-
39- def register_singleton (self , interface , class_ , constructor_args = None ):
40- if interface in self ._registrations :
41- raise ValueError (f"Dependency { interface } is already registered." )
42- self ._registrations [interface ] = Registration (interface , class_ , Scope .SINGLETON , constructor_args )
43-
44- def resolve (self , interface , scope_name = DEFAULT_SCOPE_NAME ):
32+ def register_transient (self , dependency : Type , implementation : Optional [Type ] = None , tags : Optional [set ] = None , constructor_args : Optional [Dict [str , Any ]] = None ) -> None :
33+ if implementation is None :
34+ implementation = dependency
35+ if dependency in self ._registrations :
36+ raise ValueError (f"Dependency { dependency } is already registered." )
37+ self ._registrations [dependency ] = Registration (dependency , implementation , Scope .TRANSIENT , tags , constructor_args )
38+
39+ def register_scoped (self , dependency : Type , implementation : Optional [Type ] = None , tags : Optional [set ] = None , constructor_args : Optional [Dict [str , Any ]] = None ) -> None :
40+ if implementation is None :
41+ implementation = dependency
42+ if dependency in self ._registrations :
43+ raise ValueError (f"Dependency { dependency } is already registered." )
44+ self ._registrations [dependency ] = Registration (dependency , implementation , Scope .SCOPED , tags , constructor_args )
45+
46+ def register_singleton (self , dependency : Type , implementation : Optional [Type ] = None , tags : Optional [set ] = None , constructor_args : Optional [Dict [str , Any ]] = None ) -> None :
47+ if implementation is None :
48+ implementation = dependency
49+ if dependency in self ._registrations :
50+ raise ValueError (f"Dependency { dependency } is already registered." )
51+ self ._registrations [dependency ] = Registration (dependency , implementation , Scope .SINGLETON , tags , constructor_args )
52+
53+ def register_factory (self , dependency : Type , factory : Callable [[Any ], Any ], factory_args : Optional [Dict [str , Any ]] = None , tags : Optional [set ] = None ) -> None :
54+ if dependency in self ._registrations :
55+ raise ValueError (f"Dependency { dependency } is already registered." )
56+ self ._registrations [dependency ] = Registration (dependency , None , Scope .FACTORY , None , tags , factory , factory_args )
57+
58+ def register_instance (self , dependency : Type , instance : Any , tags : Optional [set ] = None ) -> None :
59+ if dependency in self ._registrations :
60+ raise ValueError (f"Dependency { dependency } is already registered." )
61+ self ._registrations [dependency ] = Registration (dependency , type (instance ), Scope .SINGLETON , constructor_args = {}, tags = tags )
62+ self ._singleton_instances [dependency ] = instance
63+
64+ def resolve (self , dependency : Type , scope_name : str = DEFAULT_SCOPE_NAME ) -> Type :
4565 if scope_name not in self ._scoped_instances :
4666 self ._scoped_instances [scope_name ] = {}
4767
48- if interface not in self ._registrations :
49- raise KeyError (f"Dependency { interface .__name__ } is not registered." )
68+ if dependency not in self ._registrations :
69+ raise KeyError (f"Dependency { dependency .__name__ } is not registered." )
5070
51- registration = self ._registrations [interface ]
52- dependency_scope = registration .scope
53- dependency_class = registration .class_
71+ registration = self ._registrations [dependency ]
72+ scope = registration .scope
73+ implementation = registration .implementation
5474 constructor_args = registration .constructor_args
5575
56- self ._validate_constructor_args (constructor_args = constructor_args , class_ = dependency_class )
76+ self ._validate_constructor_args (constructor_args = constructor_args , implementation = implementation )
5777
58- if dependency_scope == Scope .TRANSIENT :
78+ if scope == Scope .TRANSIENT :
5979 return self ._inject_dependencies (
60- class_ = dependency_class ,
80+ implementation = implementation ,
6181 constructor_args = constructor_args
6282 )
63- elif dependency_scope == Scope .SCOPED :
64- if interface not in self ._scoped_instances [scope_name ]:
65- self ._scoped_instances [scope_name ][interface ] = (
83+ elif scope == Scope .SCOPED :
84+ if dependency not in self ._scoped_instances [scope_name ]:
85+ self ._scoped_instances [scope_name ][dependency ] = (
6686 self ._inject_dependencies (
67- class_ = dependency_class ,
87+ implementation = implementation ,
6888 scope_name = scope_name ,
6989 constructor_args = constructor_args ,
7090 ))
71- return self ._scoped_instances [scope_name ][interface ]
72- elif dependency_scope == Scope .SINGLETON :
73- if interface not in self ._singleton_instances :
74- self ._singleton_instances [interface ] = (
91+ return self ._scoped_instances [scope_name ][dependency ]
92+ elif scope == Scope .SINGLETON :
93+ if dependency not in self ._singleton_instances :
94+ self ._singleton_instances [dependency ] = (
7595 self ._inject_dependencies (
76- class_ = dependency_class ,
96+ implementation = implementation ,
7797 constructor_args = constructor_args
7898 )
7999 )
80- return self ._singleton_instances [interface ]
81-
82- raise ValueError (f"Invalid dependency scope: { dependency_scope } " )
83-
84- def _validate_constructor_args (self , constructor_args : Dict [str , Any ], class_ : Type ) -> None :
85- class_constructor = inspect .signature (class_ .__init__ ).parameters
100+ return self ._singleton_instances [dependency ]
101+ elif scope == Scope .FACTORY :
102+ factory = registration .factory
103+ factory_args = registration .factory_args or {}
104+ return factory (** factory_args )
105+
106+ raise ValueError (f"Invalid dependency scope: { scope } " )
107+
108+ def resolve_all (self , tags : Optional [set ] = None ) -> List [Any ]:
109+ tags = tags or []
110+ resolved_dependencies = []
111+ for registration in self ._registrations .values ():
112+ if not len (tags ) or tags .intersection (registration .tags ):
113+ resolved_dependencies .append (
114+ self .resolve (registration .dependency ))
115+ return resolved_dependencies
116+
117+ def _validate_constructor_args (self , constructor_args : Dict [str , Any ], implementation : Type ) -> None :
118+ constructor = inspect .signature (implementation .__init__ ).parameters
86119
87120 # Check if any required parameter is missing
88- missing_params = [param for param in class_constructor .keys () if
121+ missing_params = [param for param in constructor .keys () if
89122 param not in ["self" , "cls" , "args" , "kwargs" ] and
90123 param not in constructor_args ]
91124 if missing_params :
92125 raise ValueError (
93126 f"Missing required constructor arguments: "
94- f"{ ', ' .join (missing_params )} for class '{ class_ .__name__ } '." )
127+ f"{ ', ' .join (missing_params )} for class '{ implementation .__name__ } '." )
95128
96129 for arg_name , arg_value in constructor_args .items ():
97- if arg_name not in class_constructor :
130+ if arg_name not in constructor :
98131 raise ValueError (
99- f"Invalid constructor argument '{ arg_name } ' for class '{ class_ .__name__ } '. "
132+ f"Invalid constructor argument '{ arg_name } ' for class '{ implementation .__name__ } '. "
100133 f"The class does not have a constructor parameter with this name." )
101134
102- expected_type = class_constructor [arg_name ].annotation
135+ expected_type = constructor [arg_name ].annotation
103136 if expected_type != inspect .Parameter .empty :
104137 if not isinstance (arg_value , expected_type ):
105138 raise TypeError (
106139 f"Constructor argument '{ arg_name } ' has an incompatible type. "
107140 f"Expected type: { expected_type } , provided type: { type (arg_value )} ." )
108141
109- def _inject_dependencies (self , class_ , scope_name = None , constructor_args = None ):
110- constructor = inspect .signature (class_ .__init__ )
142+ def _inject_dependencies (self , implementation : Type , scope_name : str = None , constructor_args : Optional [ Dict [ str , Any ]] = None ) -> Type :
143+ constructor = inspect .signature (implementation .__init__ )
111144 params = constructor .parameters
112145
113146 dependencies = {}
@@ -127,4 +160,4 @@ def _inject_dependencies(self, class_, scope_name=None, constructor_args=None):
127160 else :
128161 dependencies [param_name ] = self .resolve (param_info .annotation , scope_name = scope_name )
129162
130- return class_ (** dependencies )
163+ return implementation (** dependencies )
0 commit comments