Skip to content
Open
7 changes: 7 additions & 0 deletions src/dependency_injector/providers.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -239,6 +239,13 @@ cdef class Resource(Provider):
cpdef object _provide(self, tuple args, dict kwargs)


cdef class ContextLocalResource(Resource):
cdef object _resource_context_var
cdef object _shutdowner_context_var

cpdef object _provide(self, tuple args, dict kwargs)


cdef class Container(Provider):
cdef object _container_cls
cdef dict _overriding_providers
Expand Down
2 changes: 2 additions & 0 deletions src/dependency_injector/providers.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -525,6 +525,8 @@ class Resource(Provider[T]):
def init(self) -> Optional[Awaitable[T]]: ...
def shutdown(self) -> Optional[Awaitable]: ...

class ContextLocalResource(Resource[T]):...

class Container(Provider[T]):
def __init__(
self,
Expand Down
129 changes: 128 additions & 1 deletion src/dependency_injector/providers.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -3186,7 +3186,7 @@ cdef class ThreadLocalSingleton(BaseSingleton):
return future_result

self._storage.instance = instance

return instance

def _async_init_instance(self, future_result, result):
Expand Down Expand Up @@ -3867,6 +3867,133 @@ cdef class Resource(Provider):
return self._resource


cdef class ContextLocalResource(Resource):
_none = object()

def __init__(self, provides=None, *args, **kwargs):
self._resource_context_var = ContextVar("_resource_context_var", default=self._none)
self._shutdowner_context_var = ContextVar("_shutdowner_context_var", default=self._none)
super().__init__(provides, *args, **kwargs)

def __deepcopy__(self, memo):
"""Create and return full copy of provider."""
copied = memo.get(id(self))
if copied is not None:
return copied

if self._resource_context_var.get() != self._none:
raise Error("Can not copy initialized resource")
copied = _memorized_duplicate(self, memo)
copied.set_provides(_copy_if_provider(self.provides, memo))
copied.set_args(*deepcopy_args(self, self.args, memo))
copied.set_kwargs(**deepcopy_kwargs(self, self.kwargs, memo))

self._copy_overridings(copied, memo)

return copied

@property
def initialized(self):
"""Check if resource is initialized."""
return self._resource_context_var.get() != self._none


def shutdown(self):
"""Shutdown resource."""
if self._resource_context_var.get() == self._none :
self._reset_all_contex_vars()
if self._async_mode == ASYNC_MODE_ENABLED:
return NULL_AWAITABLE
return
if self._shutdowner_context_var.get() != self._none:
future = self._shutdowner_context_var.get()(None, None, None)
if __is_future_or_coroutine(future):
self._reset_all_contex_vars()
return ensure_future(self._shutdown_async(future))


self._reset_all_contex_vars()
if self._async_mode == ASYNC_MODE_ENABLED:
return NULL_AWAITABLE

def _reset_all_contex_vars(self):
self._resource_context_var.set(self._none)
self._shutdowner_context_var.set(self._none)


async def _shutdown_async(self, future) -> None:
await future


async def _handle_async_cm(self, obj) -> None:
resource = await obj.__aenter__()
return resource

async def _provide_async(self, future):
try:
obj = await future

if hasattr(obj, '__aenter__') and hasattr(obj, '__aexit__'):
resource = await obj.__aenter__()
shutdowner = obj.__aexit__
elif hasattr(obj, '__enter__') and hasattr(obj, '__exit__'):
resource = obj.__enter__()
shutdowner = obj.__exit__
else:
resource = obj
shutdowner = None

return resource, shutdowner
except:
raise

cpdef object _provide(self, tuple args, dict kwargs):
if self._resource_context_var.get() != self._none:
return self._resource_context_var.get()
obj = __call(
self._provides,
args,
self._args,
self._args_len,
kwargs,
self._kwargs,
self._kwargs_len,
self._async_mode,
)

if __is_future_or_coroutine(obj):
future_result = asyncio.Future()
future = ensure_future(self._provide_async(obj))
future.add_done_callback(functools.partial(self._async_init_instance, future_result))
return future_result
elif hasattr(obj, '__enter__') and hasattr(obj, '__exit__'):
resource = obj.__enter__()
self._resource_context_var.set(resource)
self._shutdowner_context_var.set(obj.__exit__)
elif hasattr(obj, '__aenter__') and hasattr(obj, '__aexit__'):
resource = ensure_future(self._handle_async_cm(obj))
self._resource_context_var.set(resource)
self._shutdowner_context_var.set(obj.__aexit__)
return resource
else:
self._resource_context_var.set(obj)
self._shutdowner_context_var.set(self._none)

return self._resource_context_var.get()

def _async_init_instance(self, future_result, result):
try:
resource, shutdowner = result.result()
except Exception as exception:
self._resource_context_var.set(self._none)
self._shutdowner_context_var.set(self._none)
future_result.set_exception(exception)
else:
self._resource_context_var.set(resource)
self._shutdowner_context_var.set(shutdowner)
future_result.set_result(resource)


cdef class Container(Provider):
"""Container provider provides an instance of declarative container.

Expand Down
Loading