Skip to content

Commit 95edd65

Browse files
authored
Allow getting from scope context without needing to provide scope (#38)
1 parent b26f1be commit 95edd65

2 files changed

Lines changed: 48 additions & 0 deletions

File tree

rodi/__init__.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -246,6 +246,15 @@ def __enter__(self):
246246
def __exit__(self, exc_type, exc_val, exc_tb):
247247
self.dispose()
248248

249+
def get(
250+
self,
251+
desired_type: Union[Type[T], str],
252+
scope: Optional["ActivationScope"] = None,
253+
*,
254+
default: Optional[Any] = ...,
255+
) -> T:
256+
return self.provider.get(desired_type, scope or self, default=default)
257+
249258
def dispose(self):
250259
if self.provider:
251260
self.provider = None

tests/test_services.py

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -466,6 +466,45 @@ def test_transient_services():
466466
assert d is not c
467467

468468

469+
def test_scoped_services_use_scope_context_by_default():
470+
container = Container()
471+
container._add_exact_scoped(IdGetter)
472+
provider = container.build_provider()
473+
474+
with ActivationScope(provider) as scoped_provider:
475+
a = scoped_provider.get(IdGetter)
476+
b = scoped_provider.get(IdGetter)
477+
c = scoped_provider.get(IdGetter)
478+
d = provider.get(IdGetter)
479+
480+
assert a is b
481+
assert b is c
482+
assert a is not d
483+
assert b is not d
484+
485+
486+
def test_scoped_services_use_correct_scope_context_by_default_with_multiple_scopes():
487+
container = Container()
488+
container._add_exact_scoped(IdGetter)
489+
provider = container.build_provider()
490+
491+
with ActivationScope(provider) as scoped_provider_1:
492+
a = scoped_provider_1.get(IdGetter)
493+
b = scoped_provider_1.get(IdGetter)
494+
with ActivationScope(provider) as scoped_provider_2:
495+
c = scoped_provider_2.get(IdGetter)
496+
d = scoped_provider_2.get(IdGetter)
497+
e = scoped_provider_2.get(IdGetter, scoped_provider_1)
498+
f = provider.get(IdGetter)
499+
500+
assert a is b
501+
assert b is e
502+
assert c is d
503+
assert a is not c
504+
assert a is not f
505+
assert c is not f
506+
507+
469508
def test_scoped_services():
470509
container = Container()
471510
container._add_exact_scoped(IdGetter)

0 commit comments

Comments
 (0)