Skip to content

Commit ef511d4

Browse files
authored
feat(py_class): support frozen=True for immutable instances (#542)
## Summary - Add `frozen` parameter to `@py_class()` decorator and `field()` function, mirroring Python's `dataclasses.dataclass(frozen=True)` semantics. - **Class-level frozen**: `@py_class(frozen=True)` installs `__setattr__`/`__delattr__` guards raising `FrozenInstanceError`. Frozen classes auto-get `__hash__` (safely hashable). - **Field-level frozen**: `field(frozen=True)` sets `TypeField.frozen=True` (`fset=None`), independent of class-level frozen. - `__replace__` (`copy.replace`) uses direct `FieldSetter` for `py_class` (bypasses both frozen mechanisms) while `c_class` respects C++ readonly fields via `object.__setattr__`. ## Test plan - [x] 101 new tests in `tests/python/test_dataclass_frozen.py` covering: - Class-level frozen basic behavior and error messages - Field-level frozen (individual field override) - Inheritance (frozen parent + mutable child, mixed) - Interactions with `eq`, `hash`, `order`, `copy`, `replace` - `object.__setattr__` escape hatch - `FrozenInstanceError` is `AttributeError` subclass - [x] Full test suite: 2105 passed, 0 failed, 38 skipped, 3 xfailed
1 parent e10f635 commit ef511d4

11 files changed

Lines changed: 717 additions & 18 deletions

File tree

python/tvm_ffi/_dunder.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -193,8 +193,9 @@ def _make_replace(_type_info: TypeInfo) -> Callable[..., Any]:
193193

194194
def __replace__(self: Any, **kwargs: Any) -> Any:
195195
obj = copy_copy(self)
196+
cls = type(obj)
196197
for key, value in kwargs.items():
197-
setattr(obj, key, value)
198+
getattr(cls, key).set(obj, value)
198199
return obj
199200

200201
return __replace__

python/tvm_ffi/cython/type_info.pxi

Lines changed: 23 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -593,6 +593,24 @@ def _annotation_cobject(cls, targs):
593593
return TypeSchema(origin, origin_type_index=info.type_index)
594594

595595

596+
class FFIProperty(property):
597+
"""Property descriptor for FFI-backed fields.
598+
599+
When *frozen* is True the public setter (``fset``) is suppressed so
600+
that normal attribute assignment raises ``AttributeError``. The
601+
real setter is stashed in :attr:`_fset` and exposed via the
602+
:meth:`set` escape-hatch.
603+
"""
604+
605+
def __init__(self, fget, fset, frozen, fdel=None, doc=None):
606+
super().__init__(fget, None if frozen else fset, fdel, doc)
607+
self._fset = fset
608+
609+
def set(self, obj, value):
610+
"""Force-set the field value, bypassing the frozen guard."""
611+
self._fset(obj, value)
612+
613+
596614
@dataclasses.dataclass(eq=False)
597615
class TypeField:
598616
"""Description of a single reflected field on an FFI-backed type."""
@@ -616,17 +634,18 @@ class TypeField:
616634
assert self.getter is not None
617635

618636
def as_property(self, object cls):
619-
"""Create a Python ``property`` object for this field on ``cls``."""
637+
"""Create an :class:`FFIProperty` descriptor for this field on ``cls``."""
620638
cdef str name = self.name
621639
cdef FieldGetter fget = self.getter
622640
cdef FieldSetter fset = self.setter
623641
cdef object ret
624642
fget.__name__ = fset.__name__ = name
625643
fget.__module__ = fset.__module__ = cls.__module__
626644
fget.__qualname__ = fset.__qualname__ = f"{cls.__qualname__}.{name}"
627-
ret = property(
645+
ret = FFIProperty(
628646
fget=fget,
629-
fset=fset if (not self.frozen) else None,
647+
fset=fset,
648+
frozen=self.frozen,
630649
)
631650
if self.doc:
632651
ret.__doc__ = self.doc
@@ -1003,7 +1022,7 @@ def _register_fields(type_info, fields, structure_kind=None):
10031022
doc=py_field.doc,
10041023
size=size,
10051024
offset=field_offset,
1006-
frozen=False,
1025+
frozen=py_field.frozen,
10071026
metadata={"type_schema": py_field.ty.to_json()},
10081027
getter=fgetter,
10091028
setter=fsetter,

python/tvm_ffi/dataclasses/field.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,8 @@ class Field:
5858
default_factory : Callable[[], object] | None
5959
A zero-argument callable that produces the default value.
6060
Mutually exclusive with *default*. ``None`` when not set.
61+
frozen : bool
62+
Whether this field is read-only after ``__init__``.
6163
init : bool
6264
Whether this field appears in the auto-generated ``__init__``.
6365
repr : bool
@@ -91,6 +93,7 @@ class Field:
9193
"default",
9294
"default_factory",
9395
"doc",
96+
"frozen",
9497
"hash",
9598
"init",
9699
"kw_only",
@@ -103,6 +106,7 @@ class Field:
103106
ty: TypeSchema | None
104107
default: object
105108
default_factory: Callable[[], object] | None
109+
frozen: bool
106110
init: bool
107111
repr: bool
108112
hash: bool | None
@@ -123,6 +127,7 @@ def __init__( # noqa: PLR0913
123127
*,
124128
default: object = MISSING,
125129
default_factory: Callable[[], object] | None = MISSING, # type: ignore[assignment]
130+
frozen: bool = False,
126131
init: bool = True,
127132
repr: bool = True,
128133
hash: bool | None = True,
@@ -151,6 +156,7 @@ def __init__( # noqa: PLR0913
151156
self.ty = ty
152157
self.default = default
153158
self.default_factory = default_factory
159+
self.frozen = frozen
154160
self.init = init
155161
self.repr = repr
156162
self.hash = hash
@@ -164,6 +170,7 @@ def field(
164170
*,
165171
default: object = MISSING,
166172
default_factory: Callable[[], object] | None = MISSING, # type: ignore[assignment]
173+
frozen: bool = False,
167174
init: bool = True,
168175
repr: bool = True,
169176
hash: bool | None = None,
@@ -189,6 +196,11 @@ def field(
189196
default_factory
190197
A zero-argument callable that produces the default value.
191198
Mutually exclusive with *default*.
199+
frozen
200+
Whether this field is read-only after ``__init__``. When True,
201+
the Python property descriptor has no setter; use the
202+
``type(obj).field_name.set(obj, value)`` escape hatch when
203+
mutation is necessary.
192204
init
193205
Whether this field appears in the auto-generated ``__init__``.
194206
repr
@@ -234,6 +246,7 @@ class MyFunc(Object):
234246
return Field(
235247
default=default,
236248
default_factory=default_factory,
249+
frozen=frozen,
237250
init=init,
238251
repr=repr,
239252
hash=hash,

python/tvm_ffi/dataclasses/py_class.py

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -134,10 +134,11 @@ def _rollback_registration(cls: type, type_info: Any) -> None:
134134
# ---------------------------------------------------------------------------
135135

136136

137-
def _collect_own_fields(
137+
def _collect_own_fields( # noqa: PLR0912
138138
cls: type,
139139
hints: dict[str, Any],
140140
decorator_kw_only: bool,
141+
decorator_frozen: bool,
141142
) -> list[Field]:
142143
"""Parse own annotations into :class:`Field` objects.
143144
@@ -194,6 +195,10 @@ def _collect_own_fields(
194195
if f.kw_only is None:
195196
f.kw_only = kw_only_active
196197

198+
# Apply class-level frozen when the field doesn't explicitly set it
199+
if decorator_frozen and not f.frozen:
200+
f.frozen = True
201+
197202
# Resolve hash=None → follow compare (native dataclass semantics)
198203
if f.hash is None:
199204
f.hash = f.compare
@@ -248,7 +253,7 @@ def _register_fields_into_type(
248253
except (NameError, AttributeError):
249254
return False
250255

251-
own_fields = _collect_own_fields(cls, hints, params["kw_only"])
256+
own_fields = _collect_own_fields(cls, hints, params["kw_only"], params["frozen"])
252257
py_methods = _collect_py_methods(cls)
253258

254259
# Register fields and type-level structural eq/hash kind with the C layer.
@@ -414,11 +419,12 @@ def _install_deferred_init(
414419
order_default=False,
415420
field_specifiers=(field, Field),
416421
)
417-
def py_class(
422+
def py_class( # noqa: PLR0913
418423
cls_or_type_key: type | str | None = None,
419424
/,
420425
*,
421426
type_key: str | None = None,
427+
frozen: bool = False,
422428
init: bool = True,
423429
repr: bool = True,
424430
eq: bool = False,
@@ -465,6 +471,11 @@ class MyNode(Object):
465471
type_key
466472
Explicit FFI type key. Auto-generated from
467473
``{module}.{qualname}`` when omitted.
474+
frozen
475+
If True, all fields are read-only after ``__init__`` by default.
476+
Individual fields can still be marked ``field(frozen=True)`` on a
477+
non-frozen class. Use ``type(obj).field_name.set(obj, value)``
478+
as an escape hatch when mutation is necessary.
468479
init
469480
If True (default), generate ``__init__`` from field annotations.
470481
repr
@@ -514,6 +525,7 @@ class MyNode(Object):
514525

515526
effective_type_key = type_key
516527
params: dict[str, Any] = {
528+
"frozen": frozen,
517529
"init": init,
518530
"repr": repr,
519531
"eq": eq,

tests/python/test_cubin_launcher.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -93,7 +93,7 @@ def _compile_kernel_to_cubin() -> bytes:
9393
)
9494

9595
if result.returncode != 0:
96-
pytest.skip(f"nvcc not available or compilation failed: {result.stderr}")
96+
pytest.skip(f"nvcc not available or compilation failed: {result.stderr}") # ty: ignore[invalid-argument-type, too-many-positional-arguments]
9797

9898
return cubin_file.read_bytes()
9999

tests/python/test_dataclass_copy.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -934,10 +934,14 @@ def test_original_unchanged(self) -> None:
934934
obj.__replace__(v_i64=100) # ty: ignore[unresolved-attribute]
935935
assert obj.v_i64 == 5 # ty: ignore[unresolved-attribute]
936936

937-
def test_replace_readonly_field_raises(self) -> None:
937+
def test_replace_readonly_field(self) -> None:
938+
# __replace__ uses the FFIProperty.set() escape hatch,
939+
# so it works even on frozen / read-only fields.
938940
pair = tvm_ffi.testing.TestIntPair(3, 4)
939-
with pytest.raises(AttributeError):
940-
pair.__replace__(a=10) # ty: ignore[unresolved-attribute]
941+
pair2 = pair.__replace__(a=10) # ty: ignore[unresolved-attribute]
942+
assert pair2.a == 10
943+
assert pair2.b == 4
944+
assert pair.a == 3 # original unchanged
941945

942946
def test_auto_replace_for_cxx_class(self) -> None:
943947
# _TestCxxClassBase is copy-constructible, so replace is auto-enabled

0 commit comments

Comments
 (0)