Skip to content

Commit e809f98

Browse files
authored
feat(container): add structural __eq__/__ne__/__hash__ to Array, List, Map, Dict (#545)
## Summary - Add structural `__eq__`, `__ne__`, and `__hash__` methods to the four container classes (`Array`, `List`, `Map`, `Dict`) in `python/tvm_ffi/container.py` - Delegates to existing `RecursiveEq` and `RecursiveHash` C++ FFI functions — the same infrastructure used by `_install_dataclass_dunders` in `_dunder.py` for `@c_class`/`@py_class` - Returns `NotImplemented` for unrelated types so Python's default comparison fallback applies - `Shape`, `String`, and `Bytes` are unchanged (already inherit correct behavior from `tuple`, `str`, `bytes`) ## Breaking Change Code that relied on identity-based equality for containers will now see structural equality instead. For example, `Array([1, 2]) == Array([1, 2])` now returns `True` (previously `False`). Two container objects with the same contents now compare equal and produce the same hash. ## Test Plan - [x] 17 new tests added in `tests/python/test_container.py` covering: - Structural equality and inequality for all four container types - Empty containers - Nested containers (Array of Arrays) - `NotImplemented` return for unrelated types (plain list, dict, str) - Hash consistency (equal objects produce equal hashes) - Usability as set members and dict keys - [x] Full Python test suite passes: 2072 passed, 38 skipped, 3 xfailed - [x] All pre-commit hooks pass (ruff, ty check, etc.)
1 parent ef511d4 commit e809f98

2 files changed

Lines changed: 180 additions & 0 deletions

File tree

python/tvm_ffi/container.py

Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -201,6 +201,22 @@ def __contains__(self, value: object) -> bool:
201201
"""Check if the array contains a value."""
202202
return _ffi_api.ArrayContains(self, value)
203203

204+
def __eq__(self, other: object) -> bool:
205+
"""Structural equality."""
206+
if not (isinstance(other, type(self)) or isinstance(self, type(other))):
207+
return NotImplemented
208+
return _ffi_api.RecursiveEq(self, other)
209+
210+
def __ne__(self, other: object) -> bool:
211+
"""Structural inequality."""
212+
if not (isinstance(other, type(self)) or isinstance(self, type(other))):
213+
return NotImplemented
214+
return not _ffi_api.RecursiveEq(self, other)
215+
216+
def __hash__(self) -> int:
217+
"""Structural hash."""
218+
return _ffi_api.RecursiveHash(self)
219+
204220
def __bool__(self) -> bool:
205221
"""Return True if the array is non-empty."""
206222
return len(self) > 0
@@ -344,6 +360,22 @@ def __contains__(self, value: object) -> bool:
344360
"""Check if the list contains a value."""
345361
return _ffi_api.ListContains(self, value)
346362

363+
def __eq__(self, other: object) -> bool:
364+
"""Structural equality."""
365+
if not (isinstance(other, type(self)) or isinstance(self, type(other))):
366+
return NotImplemented
367+
return _ffi_api.RecursiveEq(self, other)
368+
369+
def __ne__(self, other: object) -> bool:
370+
"""Structural inequality."""
371+
if not (isinstance(other, type(self)) or isinstance(self, type(other))):
372+
return NotImplemented
373+
return not _ffi_api.RecursiveEq(self, other)
374+
375+
def __hash__(self) -> int:
376+
"""Structural hash."""
377+
return _ffi_api.RecursiveHash(self)
378+
347379
def __bool__(self) -> bool:
348380
"""Return True if the list is non-empty."""
349381
return len(self) > 0
@@ -499,6 +531,22 @@ def __contains__(self, k: object) -> bool:
499531
"""Return True if the map contains key `k`."""
500532
return _ffi_api.MapCount(self, k) != 0
501533

534+
def __eq__(self, other: object) -> bool:
535+
"""Structural equality."""
536+
if not (isinstance(other, type(self)) or isinstance(self, type(other))):
537+
return NotImplemented
538+
return _ffi_api.RecursiveEq(self, other)
539+
540+
def __ne__(self, other: object) -> bool:
541+
"""Structural inequality."""
542+
if not (isinstance(other, type(self)) or isinstance(self, type(other))):
543+
return NotImplemented
544+
return not _ffi_api.RecursiveEq(self, other)
545+
546+
def __hash__(self) -> int:
547+
"""Structural hash."""
548+
return _ffi_api.RecursiveHash(self)
549+
502550
def keys(self) -> KeysView[K]:
503551
"""Return a dynamic view of the map's keys."""
504552
return KeysView(self)
@@ -607,6 +655,22 @@ def __contains__(self, k: object) -> bool:
607655
"""Return True if the dict contains key `k`."""
608656
return _ffi_api.DictCount(self, k) != 0
609657

658+
def __eq__(self, other: object) -> bool:
659+
"""Structural equality."""
660+
if not (isinstance(other, type(self)) or isinstance(self, type(other))):
661+
return NotImplemented
662+
return _ffi_api.RecursiveEq(self, other)
663+
664+
def __ne__(self, other: object) -> bool:
665+
"""Structural inequality."""
666+
if not (isinstance(other, type(self)) or isinstance(self, type(other))):
667+
return NotImplemented
668+
return not _ffi_api.RecursiveEq(self, other)
669+
670+
def __hash__(self) -> int:
671+
"""Structural hash."""
672+
return _ffi_api.RecursiveHash(self)
673+
610674
def __len__(self) -> int:
611675
"""Return the number of items in the dict."""
612676
return _ffi_api.DictSize(self)

tests/python/test_container.py

Lines changed: 116 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -733,3 +733,119 @@ def test_map_cross_conv_incompatible_map_to_dict() -> None:
733733
m = tvm_ffi.Map({"a": "not_int", "b": "still_not_int"})
734734
with pytest.raises(TypeError):
735735
testing.schema_id_dict_str_int(m) # type: ignore[invalid-argument-type]
736+
737+
738+
# ---------------------------------------------------------------------------
739+
# Structural __eq__ / __ne__ / __hash__ tests
740+
# ---------------------------------------------------------------------------
741+
742+
743+
def test_array_structural_eq() -> None:
744+
a = tvm_ffi.Array([1, 2, 3])
745+
b = tvm_ffi.Array([1, 2, 3])
746+
c = tvm_ffi.Array([1, 2, 4])
747+
assert a == b
748+
assert a != c
749+
assert not (a != b)
750+
assert not (a == c)
751+
752+
753+
def test_array_eq_empty() -> None:
754+
assert tvm_ffi.Array([]) == tvm_ffi.Array([])
755+
756+
757+
def test_array_eq_nested() -> None:
758+
a = tvm_ffi.Array([tvm_ffi.Array([1, 2]), tvm_ffi.Array([3])])
759+
b = tvm_ffi.Array([tvm_ffi.Array([1, 2]), tvm_ffi.Array([3])])
760+
c = tvm_ffi.Array([tvm_ffi.Array([1, 2]), tvm_ffi.Array([4])])
761+
assert a == b
762+
assert a != c
763+
764+
765+
def test_array_eq_not_implemented_for_unrelated() -> None:
766+
a = tvm_ffi.Array([1, 2, 3])
767+
assert a.__eq__([1, 2, 3]) is NotImplemented
768+
assert a.__ne__([1, 2, 3]) is NotImplemented
769+
assert a.__eq__("hello") is NotImplemented
770+
771+
772+
def test_array_hash() -> None:
773+
a = tvm_ffi.Array([1, 2, 3])
774+
b = tvm_ffi.Array([1, 2, 3])
775+
assert hash(a) == hash(b)
776+
# Usable in sets and as dict keys
777+
s = {a, b}
778+
assert len(s) == 1
779+
d = {a: "value"}
780+
assert d[b] == "value"
781+
782+
783+
def test_list_structural_eq() -> None:
784+
a = tvm_ffi.List([1, 2, 3])
785+
b = tvm_ffi.List([1, 2, 3])
786+
c = tvm_ffi.List([1, 2, 4])
787+
assert a == b
788+
assert a != c
789+
790+
791+
def test_list_eq_empty() -> None:
792+
assert tvm_ffi.List([]) == tvm_ffi.List([])
793+
794+
795+
def test_list_eq_not_implemented_for_unrelated() -> None:
796+
a = tvm_ffi.List([1, 2, 3])
797+
assert a.__eq__([1, 2, 3]) is NotImplemented
798+
799+
800+
def test_list_hash() -> None:
801+
a = tvm_ffi.List([1, 2, 3])
802+
b = tvm_ffi.List([1, 2, 3])
803+
assert hash(a) == hash(b)
804+
805+
806+
def test_map_structural_eq() -> None:
807+
a = tvm_ffi.Map({"x": 1, "y": 2})
808+
b = tvm_ffi.Map({"x": 1, "y": 2})
809+
c = tvm_ffi.Map({"x": 1, "y": 3})
810+
assert a == b
811+
assert a != c
812+
813+
814+
def test_map_eq_empty() -> None:
815+
assert tvm_ffi.Map({}) == tvm_ffi.Map({})
816+
817+
818+
def test_map_eq_not_implemented_for_unrelated() -> None:
819+
a = tvm_ffi.Map({"x": 1})
820+
assert a.__eq__({"x": 1}) is NotImplemented
821+
822+
823+
def test_map_hash() -> None:
824+
a = tvm_ffi.Map({"x": 1, "y": 2})
825+
b = tvm_ffi.Map({"x": 1, "y": 2})
826+
assert hash(a) == hash(b)
827+
s = {a, b}
828+
assert len(s) == 1
829+
830+
831+
def test_dict_structural_eq() -> None:
832+
a = tvm_ffi.Dict({"x": 1, "y": 2})
833+
b = tvm_ffi.Dict({"x": 1, "y": 2})
834+
c = tvm_ffi.Dict({"x": 1, "y": 3})
835+
assert a == b
836+
assert a != c
837+
838+
839+
def test_dict_eq_empty() -> None:
840+
assert tvm_ffi.Dict({}) == tvm_ffi.Dict({})
841+
842+
843+
def test_dict_eq_not_implemented_for_unrelated() -> None:
844+
a = tvm_ffi.Dict({"x": 1})
845+
assert a.__eq__({"x": 1}) is NotImplemented
846+
847+
848+
def test_dict_hash() -> None:
849+
a = tvm_ffi.Dict({"x": 1, "y": 2})
850+
b = tvm_ffi.Dict({"x": 1, "y": 2})
851+
assert hash(a) == hash(b)

0 commit comments

Comments
 (0)