Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
127 changes: 127 additions & 0 deletions effectful/internals/unification.py
Original file line number Diff line number Diff line change
Expand Up @@ -539,6 +539,17 @@ def _unify_generic(typ, subtyp, subs: Substitutions) -> Substitutions:
typing.get_origin(typ), collections.abc.Generator
):
return unify(typing.get_args(typ)[0], typing.get_args(subtyp)[0], subs)
elif typing.get_origin(subtyp) is effectful.ops.types.Operation and not (
isinstance(typing.get_origin(typ), type)
and issubclass(typing.get_origin(typ), effectful.ops.types.Operation)
):
# An Operation[P, R] is a Callable[P, R] (gh #669): unify the pattern
# against the operation's parameter/return signature. ``Operation``'s
# args are (params, return) just like ``Callable``'s, except params is
# a tuple (or ``...``) rather than a list.
op_params, op_ret = typing.get_args(subtyp)
callable_params = op_params if op_params is ... else list(op_params)
return unify(typ, collections.abc.Callable[callable_params, op_ret], subs) # type: ignore
elif typing.get_origin(typ) == typing.get_origin(subtyp):
return unify(typing.get_args(typ), typing.get_args(subtyp), subs)
elif types.get_original_bases(typing.get_origin(subtyp)):
Expand All @@ -556,6 +567,17 @@ def _unify_generic(typ, subtyp, subs: Substitutions) -> Substitutions:
and issubclass(subtyp, typing.get_origin(typ))
):
return subs # implicit expansion to subtyp[Any]
elif isinstance(typ, GenericAlias):
# Special case for treating arrays as iterables of arrays
try:
import jax

if typing.get_origin(typ) is collections.abc.Iterable and issubclass(
subtyp, jax.Array
):
return unify(typing.get_args(typ)[0], jax.Array, subs)
except ImportError:
pass
raise TypeError(f"Cannot unify generic type {typ} with {subtyp} given {subs}.")


Expand Down Expand Up @@ -1077,6 +1099,111 @@ def _(value: str | bytes | range | None):
return Box(type(value))


def _iterable_element_type(value: collections.abc.Iterable) -> TypeExpression:
"""Element type of an iterable *value*.

Infers the value's type with :func:`nested_type` and unifies it against
``Iterable[E]`` to recover the element type ``E``. Raises ``TypeError`` (the
element is itself a :class:`~effectful.ops.types.Term`) or ``KeyError`` (the
iterable is empty/bare so ``E`` is unbound) when the element type cannot be
determined; callers fall back to the bare iterator type in that case.
"""
if isinstance(value, str | bytes):
# str/bytes are atomic to nested_type; their elements share their type.
return type(value)
E = typing.TypeVar("E")
try:
return unify(collections.abc.Iterable[E], nested_type(value).value)[E] # type: ignore[return-value, valid-type]
except KeyError:
raise TypeError("Could not resolve concrete element type")


@nested_type.register
def _(value: collections.abc.Iterator):
    """Infer the nested type of a generic iterator from its ``__reduce__`` data.

    Iterators rebuilt by ``iter`` or ``reversed`` are typed as
    ``Iterator[E]``, where ``E`` is the element type of the underlying
    iterable. Every other iterator is treated as opaque and reported as
    ``type(value)``.
    """
    try:
        reduction = value.__reduce__()
        if isinstance(reduction, str):
            # Reduced to a global name: there is no structure to inspect.
            return Box(type(value))
        ctor, args, *_state = reduction
        # Results are discarded; presumably this only probes that every
        # constructor argument can be typed, so a TypeError lands in the
        # handler below -- TODO confirm intent.
        for arg in args:
            _ = nested_type(arg).value
    except (TypeError, AttributeError):
        # Un-reducible iterators are opaque.
        return Box(type(value))

    if ctor is not iter and ctor is not reversed:
        return Box(type(value))

    # ``args[0]`` is the underlying iterable. ``reversed([...])`` is a
    # ``list_reverseiterator`` -- not a ``reversed`` instance -- so it is
    # recognised here by constructor rather than by a type-keyed registration.
    try:
        element = _iterable_element_type(args[0])
        return Box(collections.abc.Iterator[element])  # type: ignore[misc]
    except TypeError:
        return Box(type(value))


@nested_type.register(map)
def _(value):
    """Infer the nested type of a ``map`` object.

    ``map(f, *iterables)`` yields ``f(*items)``, so the element type is the
    return type of ``f`` rather than the source element types. The return
    type is recovered by unifying ``(Callable[[X0, ...], Y], Iterable[X0],
    ...)`` against the inferred types of ``f`` and of each source.

    Returns:
        ``Box(Iterator[Y])`` when ``Y`` can be resolved. Otherwise (``f`` has
        no type arguments, there are no sources, ``Y`` stays unbound, or a
        ``TypeError`` is raised during inference) the result of the generic
        ``Iterator`` handler.
    """
    _ctor, (func, *sources), *_state = value.__reduce__()
    try:
        # Infer the function's type once and reuse it for both the guard and
        # the unification below.
        func_type = nested_type(func).value
        if typing.get_args(func_type) and sources:
            Xs = [typing.TypeVar(f"X{i}") for i in range(len(sources))]
            Y = typing.TypeVar("Y")
            typ = (
                collections.abc.Callable[Xs, Y],
                *[collections.abc.Iterable[Xi] for Xi in Xs],
            )
            subtyp = (
                func_type,
                *[nested_type(source).value for source in sources],
            )
            subs = unify(typ, subtyp)
            if Y in subs:
                return Box(collections.abc.Iterator[subs[Y]])
    except TypeError:
        # Deliberate best-effort: any typing/unification failure falls
        # through to the bare iterator type below.
        pass
    # Un-annotated function, no sources, unresolved return type, or failed
    # unification: fall back to the bare iterator type.
    return nested_type.dispatch(collections.abc.Iterator)(value)


@nested_type.register(filter)
def _(value):
    """Infer the nested type of a ``filter`` object.

    ``filter`` yields elements of its source unchanged, so the result is
    ``Iterator[E]`` with ``E`` the source's element type; when ``E`` cannot
    be resolved the bare ``type(value)`` is reported instead.
    """
    _, (_predicate, source), *_ = value.__reduce__()
    try:
        element = _iterable_element_type(source)
        return Box(collections.abc.Iterator[element])
    except TypeError:
        return Box(type(value))


@nested_type.register(zip)
def _(value):
    """Infer the nested type of a ``zip`` object.

    The result is ``Iterator[tuple[E0, E1, ...]]`` with one entry per zipped
    source. A ``zip`` with no sources is ``Iterator[tuple]``; if any element
    type cannot be resolved, every entry degrades to ``typing.Any``.
    """
    # After the constructor, ``__reduce__()`` carries the tuple of sources
    # (optionally followed by ``strict``); treat a missing entry as no sources.
    _ctor, *tail = value.__reduce__()
    sources = tail[0] if tail else ()
    if not sources:
        return Box(collections.abc.Iterator[tuple])

    try:
        elt_type = tuple[tuple(map(_iterable_element_type, sources))]
    except TypeError:
        elt_type = tuple[(typing.Any,) * len(sources)]
    return Box(collections.abc.Iterator[elt_type])


@nested_type.register(enumerate)
def _(value):
    """Infer the nested type of an ``enumerate`` object.

    ``enumerate`` yields ``(index, element)`` pairs, so the result is
    ``Iterator[tuple[int, E]]``; ``E`` degrades to ``typing.Any`` when the
    source's element type cannot be resolved.
    """
    _, (source, _first_index), *_ = value.__reduce__()
    try:
        element = _iterable_element_type(source)
    except TypeError:
        element = typing.Any
    return Box(collections.abc.Iterator[tuple[int, element]])


def freetypevars(typ) -> collections.abc.Set[TypeVariable]:
"""
Return a set of free type variables in the given type expression.
Expand Down
55 changes: 53 additions & 2 deletions effectful/ops/semantics.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import builtins
import collections.abc
import contextlib
import dataclasses
Expand Down Expand Up @@ -207,8 +208,6 @@ def evaluate[T](


@evaluate.register(object)
@evaluate.register(str)
@evaluate.register(bytes)
def _evaluate_object[T](expr: T, **kwargs) -> T:
if dataclasses.is_dataclass(expr) and not isinstance(expr, type):
return typing.cast(
Expand All @@ -224,6 +223,13 @@ def _evaluate_object[T](expr: T, **kwargs) -> T:
return expr


@evaluate.register(builtins.range)
@evaluate.register(str | bytes | bytearray)
@evaluate.register(int | float | complex | bool | type(None))
def _evaluate_atomic(expr: Any, **kwargs) -> Any:
    """Evaluate an atomic value by returning it unchanged.

    Registered for ``range``, ``str``/``bytes``/``bytearray`` and
    ``int``/``float``/``complex``/``bool``/``None``: these are handed back
    as-is, with no recursion into their contents. ``kwargs`` is accepted
    for signature compatibility and ignored.
    """
    return expr


@evaluate.register(Term)
def _evaluate_term(expr: Term, **kwargs):
args = tuple(evaluate(arg) for arg in expr.args)
Expand Down Expand Up @@ -284,6 +290,51 @@ def _evaluate_list_view(expr, **kwargs):
return [evaluate(item) for item in expr]


@evaluate.register(collections.abc.Set)
def _evaluate_set(expr, **kwargs):
    """Evaluate every member of a set and rebuild a set of the same type.

    The evaluated members are passed, as a single iterable, to
    ``type(expr)``. ``kwargs`` is accepted but not forwarded to the nested
    ``evaluate`` calls.
    """
    rebuild = type(expr)
    return rebuild(map(evaluate, expr))


@evaluate.register(collections.abc.Iterator)
def _evaluate_iterator(expr, **kwargs):
    """Evaluate an iterator by rebuilding it from its ``__reduce__`` data.

    The iterator is decomposed into ``(ctor, args, *state)``, the arguments
    are evaluated, and the iterator is rebuilt through a freshly defined
    operation wrapping ``ctor``. If the rebuilt value is a live iterator and
    the reduction carried state, that state is evaluated and restored.

    Returns:
        The rebuilt iterator, or a ``Term`` when applying the wrapping
        operation yields one. ``expr`` itself is returned unchanged when it
        cannot be decomposed: ``__reduce__`` raises ``TypeError`` or
        ``AttributeError``, or it reduces to a bare string (a global name)
        instead of a tuple.
    """
    try:
        reduced = expr.__reduce__()
        if isinstance(reduced, str):
            # ``__reduce__`` may return a global name instead of a tuple.
            # Destructuring it would split the string into characters, so
            # treat the iterator as opaque, as the ``nested_type`` handler
            # for iterators does.
            return expr
        ctor, args, *state = reduced
    except (TypeError, AttributeError):
        return expr  # un-reducible iterators are opaque, like any object we can't recurse into

    from effectful.internals.unification import nested_type

    ExprType = nested_type(expr).value

    @Operation.define
    def ctor_op(*args) -> ExprType:
        return ctor(*args)

    # Reify through ``ctor_op`` rather than calling the live constructor: when
    # the evaluated args contain a ``Term`` the result is a structural ``Term``
    # node whose source iterables stay traversable (so ``fvsof`` finds their
    # free variables and term-reconstruction interpretations can rewrite them).
    # For fully concrete args ``ctor_op``'s default rule rebuilds the live
    # iterator, preserving laziness.
    result = ctor_op(*evaluate(args))

    if (
        not isinstance(result, Term)
        and state
        and state[0] is not None
        and hasattr(result, "__setstate__")
    ):
        # Preserve position so advanced iterators don't reset.
        result.__setstate__(evaluate(state[0]))
    return result


@evaluate.register(builtins.slice)
def _evaluate_slice(expr, **kwargs):
    """Evaluate a slice by evaluating its ``start``, ``stop`` and ``step``.

    The three fields are evaluated together as one tuple and the result is
    unpacked into a new ``slice``. ``kwargs`` is accepted but not forwarded.
    """
    fields = (expr.start, expr.stop, expr.step)
    evaluated_fields = evaluate(fields)
    return builtins.slice(*evaluated_fields)


def _simple_type(tp: type) -> type:
"""Convert a type object into a type that can be dispatched on."""
if isinstance(tp, typing.TypeVar):
Expand Down
43 changes: 42 additions & 1 deletion tests/test_internals_product_n.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
from effectful.internals.product_n import argsof, productN
from collections.abc import Iterable

from effectful.internals.product_n import Product, argsof, productN
from effectful.internals.unification import Box
from effectful.ops.semantics import apply, coproduct, evaluate, handler
from effectful.ops.syntax import defop
Expand Down Expand Up @@ -158,3 +160,42 @@ def add[T](x: T, y: T) -> T:

assert result1.values(i) == result2.values(i) == 2
assert result1.values(s) == result2.values(s) == "aa"


def test_evaluate_iterator_under_product():
    """A builtin iterator must evaluate cleanly under a ``productN`` analysis.

    A ``productN`` that binds the universal ``apply`` operation (as the
    type/cast analysis in ``defdata`` does) intercepts *every* operation
    application and returns its result wrapped in a :class:`Product`, so
    under such an interpretation any sub-term evaluates to a ``Product``.

    ``evaluate`` rebuilds a builtin iterator from its constructor and its
    evaluated source iterables (see ``_evaluate_iterator``). A ``map`` eagerly
    stores ``iter(s())`` -- an iterator *term* -- as its source, so rebuilding
    evaluates that inner term to a ``Product``. Handing that to the live
    ``map`` constructor fails with ``TypeError: 'Product' object is not
    iterable`` because ``Product`` is not iterable.

    This test isolates that scenario: a builtin iterator wrapping a term,
    evaluated under a product interpretation, should evaluate successfully
    instead of crashing during iterator reconstruction.
    """

    @defop
    def s() -> Iterable[int]:
        raise NotHandled

    def boxed_type_rule(op, *args, **kwargs):
        return Box(op.__type_rule__(*args, **kwargs))

    typ = defop(object, name="typ")
    cast = defop(object, name="cast")
    # Both components handle the universal ``apply`` with the same rule.
    components = {component: {apply: boxed_type_rule} for component in (typ, cast)}
    product_intp = productN(components)

    # ``map`` eagerly calls ``iter(s())``, storing an iterator term as its source.
    mapped = map(lambda v: v, s())

    # Failure mode under test: ``TypeError: 'Product' object is not iterable``.
    evaluated = evaluate(mapped, intp=product_intp)
    assert isinstance(evaluated, Product)
Loading
Loading