3434# TODO(epot): Test dtype `complex`, `str`
3535
3636
37+ @dca .dataclass_array (broadcast = True , cast_dtype = True )
3738@dataclasses .dataclass (frozen = True )
3839class Point (dca .DataclassArray ):
3940 x : f32 ['*shape' ]
4041 y : f32 ['*shape' ]
4142
4243
44+ @dca .dataclass_array (broadcast = True , cast_dtype = True )
4345@dataclasses .dataclass (frozen = True )
4446class PointWrapper (dca .DataclassArray ):
4547 pts : Point
4648 rgb : f32 ['*shape 3' ]
4749
4850
51+ @dca .dataclass_array (broadcast = True , cast_dtype = True )
4952@dataclasses .dataclass (frozen = True )
5053class Isometrie (dca .DataclassArray ):
5154 r : f32 ['... 3 3' ]
5255 t : i32 [..., 2 ]
5356
5457
58+ @dca .dataclass_array (broadcast = True , cast_dtype = True )
5559@dataclasses .dataclass (frozen = True )
5660class Nested (dca .DataclassArray ):
5761 # pytype: disable=annotation-type-mismatch
@@ -61,6 +65,7 @@ class Nested(dca.DataclassArray):
6165 # pytype: enable=annotation-type-mismatch
6266
6367
68+ @dca .dataclass_array (broadcast = True , cast_dtype = True )
6469@dataclasses .dataclass (frozen = True )
6570class WithStatic (dca .DataclassArray ):
6671 """Mix of static and array fields."""
@@ -541,8 +546,6 @@ def test_dataclass_params_no_cast(xnp: enp.NpModule):
541546
542547 @dataclasses .dataclass (frozen = True )
543548 class PointNoCast (dca .DataclassArray ):
544- __dca_params__ = dca .DataclassParams (cast_dtype = False )
545-
546549 x : FloatArray ['*shape' ]
547550 y : IntArray ['*shape' ]
548551
@@ -564,10 +567,9 @@ class PointNoCast(dca.DataclassArray):
564567@enp .testing .parametrize_xnp ()
565568def test_dataclass_params_no_list (xnp : enp .NpModule ):
566569
570+ @dca .dataclass_array (cast_list = False )
567571 @dataclasses .dataclass (frozen = True )
568572 class PointNoList (dca .DataclassArray ):
569- __dca_params__ = dca .DataclassParams (cast_list = False )
570-
571573 x : FloatArray ['*shape' ]
572574 y : IntArray ['*shape' ]
573575
@@ -583,15 +585,13 @@ def test_dataclass_params_no_broadcast(xnp: enp.NpModule):
583585
584586 @dataclasses .dataclass (frozen = True )
585587 class PointNoBroadcast (dca .DataclassArray ):
586- __dca_params__ = dca .DataclassParams (broadcast = False )
587-
588588 x : FloatArray ['*shape' ]
589589 y : IntArray ['*shape' ]
590590
591591 with pytest .raises (ValueError , match = 'Cannot broadcast' ):
592592 PointNoBroadcast (
593593 x = xnp .array (1 , dtype = np .float16 ),
594- y = xnp .array ([1 , 2 , 3 ], dtype = np .float16 ),
594+ y = xnp .array ([1 , 2 , 3 ], dtype = np .int32 ),
595595 )
596596
597597
0 commit comments