diff --git a/effectful/internals/unification.py b/effectful/internals/unification.py index ba770ad3..ad022a63 100644 --- a/effectful/internals/unification.py +++ b/effectful/internals/unification.py @@ -539,6 +539,17 @@ def _unify_generic(typ, subtyp, subs: Substitutions) -> Substitutions: typing.get_origin(typ), collections.abc.Generator ): return unify(typing.get_args(typ)[0], typing.get_args(subtyp)[0], subs) + elif typing.get_origin(subtyp) is effectful.ops.types.Operation and not ( + isinstance(typing.get_origin(typ), type) + and issubclass(typing.get_origin(typ), effectful.ops.types.Operation) + ): + # An Operation[P, R] is a Callable[P, R] (gh #669): unify the pattern + # against the operation's parameter/return signature. ``Operation``'s + # args are (params, return) just like ``Callable``'s, except params is + # a tuple (or ``...``) rather than a list. + op_params, op_ret = typing.get_args(subtyp) + callable_params = op_params if op_params is ... else list(op_params) + return unify(typ, collections.abc.Callable[callable_params, op_ret], subs) # type: ignore elif typing.get_origin(typ) == typing.get_origin(subtyp): return unify(typing.get_args(typ), typing.get_args(subtyp), subs) elif types.get_original_bases(typing.get_origin(subtyp)): @@ -556,6 +567,17 @@ def _unify_generic(typ, subtyp, subs: Substitutions) -> Substitutions: and issubclass(subtyp, typing.get_origin(typ)) ): return subs # implicit expansion to subtyp[Any] + elif isinstance(typ, GenericAlias): + # Special case for treating arrays as iterables of arrays + try: + import jax + + if typing.get_origin(typ) is collections.abc.Iterable and issubclass( + subtyp, jax.Array + ): + return unify(typing.get_args(typ)[0], jax.Array, subs) + except ImportError: + pass raise TypeError(f"Cannot unify generic type {typ} with {subtyp} given {subs}.") diff --git a/tests/test_internals_unification.py b/tests/test_internals_unification.py index fc1a4ba1..fe3a9ed0 100644 --- a/tests/test_internals_unification.py +++ b/tests/test_internals_unification.py @@ -1967,3 +1967,69 @@ class Info(typing.TypedDict): subs = unify(collections.abc.Mapping, Info) assert subs == {} + + +def test_unify_jax_array_iterable(): + import jax + + subs = unify(collections.abc.Iterable[T], jax.Array) + assert subs == {T: jax.Array} + + +def test_unify_operation_callable(): + """An ``Operation[P, R]`` unifies as a ``Callable[P, R]`` (gh #669).""" + from effectful.ops.types import Operation + + # TypeVar params bind to the operation's parameter/return types + assert unify(collections.abc.Callable[[T], V], Operation[[int], int]) == { + T: int, + V: int, + } + # a repeated TypeVar binds consistently + assert unify(collections.abc.Callable[[T], T], Operation[[int], int]) == {T: int} + # multiple parameters + assert unify(collections.abc.Callable[[T, U], V], Operation[[int, str], bool]) == { + T: int, + U: str, + V: bool, + } + # ``...`` parameters in the pattern ignore the operation's parameter types + assert unify(collections.abc.Callable[..., V], Operation[[int], int]) == {V: int} + # fully concrete: nothing to bind + assert unify(collections.abc.Callable[[int], int], Operation[[int], int]) == {} + # nested: an operation-valued argument + assert unify( + collections.abc.Callable[[T], list[V]], Operation[[int], list[str]] + ) == {T: int, V: str} + + +def test_unify_operation_callable_failure(): + """An arity mismatch between the Callable pattern and the Operation fails.""" + from effectful.ops.types import Operation + + with pytest.raises(TypeError): + unify(collections.abc.Callable[[T, U], V], Operation[[int], int]) + with pytest.raises(TypeError): + unify(collections.abc.Callable[[T], V], Operation[[int, str], bool]) + + +def test_operation_unifies_with_callable_param_gh669(): + """An Operation passed where a ``Callable`` is expected infers correctly. + + Regression test for gh #669: calling an operation whose parameter is typed + ``Callable[[S], T]`` with another operation should unify and infer the return + type, rather than raising ``Cannot unify generic type ...``. + """ + from effectful.ops.semantics import typeof + from effectful.ops.types import NotHandled, Operation + + @Operation.define + def f(x: int) -> int: + raise NotHandled + + @Operation.define + def g[S, R](x: collections.abc.Callable[[S], R]) -> R: + raise NotHandled + + term = g(f) + assert typeof(term) is int