Skip to content

Commit cd488f4

Browse files
authored
Check scoped_services before resolving from map when in a scope
1 parent 95edd65 commit cd488f4

2 files changed

Lines changed: 54 additions & 8 deletions

File tree

rodi/__init__.py

Lines changed: 11 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -79,7 +79,7 @@ def _get_obj_locals(obj) -> Optional[Dict[str, Any]]:
7979

8080

8181
def class_name(input_type):
82-
if input_type in {list, set} and str(
82+
if input_type in {list, set} and str( # noqa: E721
8383
type(input_type) == "<class 'types.GenericAlias'>"
8484
):
8585
# for Python 3.9 list[T], set[T]
@@ -233,9 +233,9 @@ class ActivationScope:
233233
def __init__(
234234
self,
235235
provider: Optional["Services"] = None,
236-
scoped_services: Optional[Dict[Type[T], T]] = None,
236+
scoped_services: Optional[Dict[Union[Type[T], str], T]] = None,
237237
):
238-
self.provider = provider
238+
self.provider = provider or Services()
239239
self.scoped_services = scoped_services or {}
240240

241241
def __enter__(self):
@@ -248,7 +248,7 @@ def __exit__(self, exc_type, exc_val, exc_tb):
248248

249249
def get(
250250
self,
251-
desired_type: Union[Type[T], str],
251+
desired_type: Union[Union[Type[T], str], str],
252252
scope: Optional["ActivationScope"] = None,
253253
*,
254254
default: Optional[Any] = ...,
@@ -713,13 +713,14 @@ def get(
713713
scope = ActivationScope(self)
714714

715715
resolver = self._map.get(desired_type)
716+
scoped_service = scope.scoped_services.get(desired_type) if scope else None
716717

717-
if not resolver:
718+
if not resolver and not scoped_service:
718719
if default is not ...:
719720
return cast(T, default)
720721
raise CannotResolveTypeException(desired_type)
721722

722-
return cast(T, resolver(scope, desired_type))
723+
return cast(T, scoped_service or resolver(scope, desired_type))
723724

724725
def _get_getter(self, key, param):
725726
if param.annotation is _empty:
@@ -756,13 +757,15 @@ def get_executor(self, method: Callable) -> Callable:
756757

757758
if iscoroutinefunction(method):
758759

759-
async def async_executor(scoped: Optional[Dict[Type, Any]] = None):
760+
async def async_executor(
761+
scoped: Optional[Dict[Union[Type, str], Any]] = None
762+
):
760763
with ActivationScope(self, scoped) as context:
761764
return await method(*[fn(context) for fn in fns])
762765

763766
return async_executor
764767

765-
def executor(scoped: Optional[Dict[Type, Any]] = None):
768+
def executor(scoped: Optional[Dict[Union[Type, str], Any]] = None):
766769
with ActivationScope(self, scoped) as context:
767770
return method(*[fn(context) for fn in fns])
768771

tests/test_services.py

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -505,6 +505,18 @@ def test_scoped_services_use_correct_scope_context_by_default_with_multiple_scop
505505
assert c is not f
506506

507507

508+
def test_scoped_services_works_with_str_keys():
509+
container = Container()
510+
container.add_singleton("Id", IdGetter)
511+
provider = container.build_provider()
512+
513+
with ActivationScope(provider) as scoped_provider:
514+
a = scoped_provider.get("Id")
515+
b = provider.get("id")
516+
517+
assert a is b
518+
519+
508520
def test_scoped_services():
509521
container = Container()
510522
container._add_exact_scoped(IdGetter)
@@ -522,6 +534,37 @@ def test_scoped_services():
522534
assert b is not d
523535

524536

537+
def test_scoped_service_from_scoped_services():
538+
container = Container()
539+
provider = container.build_provider()
540+
541+
scoped_service = IdGetter()
542+
543+
with ActivationScope(
544+
provider,
545+
{
546+
IdGetter: scoped_service,
547+
},
548+
) as context:
549+
a = provider.get(IdGetter, context)
550+
b = provider.get(IdGetter, default=None)
551+
c = provider.get(IdGetter, default=None)
552+
553+
with ActivationScope(
554+
scoped_services={
555+
IdGetter: scoped_service,
556+
}
557+
) as context:
558+
d = provider.get(IdGetter, context)
559+
e = provider.get(IdGetter, default=None)
560+
561+
assert a is scoped_service
562+
assert b is None
563+
assert c is None
564+
assert d is scoped_service
565+
assert e is None
566+
567+
525568
def test_scoped_services_with_shortcut():
526569
container = Container()
527570
container.add_scoped(IdGetter)

0 commit comments

Comments
 (0)