@@ -710,6 +710,88 @@ def main(x: R.Tensor((5, 30), dtype="float32")) -> R.Tensor(out_shape, dtype="in
710710 verify (TfInput , Expected )
711711
712712
713+ def test_fully_connected ():
714+ class FullyConnected (tf .Module ):
715+ @tf .function (input_signature = [tf .TensorSpec (shape = (1 , 8 ), dtype = tf .float32 )])
716+ def func (self , x ):
717+ weight = tf .constant (np .arange (24 , dtype = np .float32 ).reshape ((3 , 8 )))
718+ bias = tf .constant (np .array ([0.5 , 1.0 , - 1.0 ], dtype = np .float32 ))
719+ out = tf .matmul (x , weight , transpose_b = True )
720+ return tf .nn .bias_add (out , bias )
721+
722+ verify (FullyConnected )
723+
724+
725+ def test_depthwise_conv2d ():
726+ class DepthwiseConv2D (tf .Module ):
727+ @tf .function (
728+ input_signature = [
729+ tf .TensorSpec (shape = (1 , 8 , 8 , 2 ), dtype = tf .float32 ),
730+ tf .TensorSpec (shape = (3 , 3 , 2 , 1 ), dtype = tf .float32 ),
731+ ]
732+ )
733+ def func (self , data , kernel ):
734+ return tf .nn .depthwise_conv2d (
735+ input = data ,
736+ filter = kernel ,
737+ strides = [1 , 1 , 1 , 1 ],
738+ padding = "SAME" ,
739+ )
740+
741+ verify (DepthwiseConv2D )
742+
743+
744+ def test_transpose_conv ():
745+ class TransposeConv (tf .Module ):
746+ @tf .function (
747+ input_signature = [
748+ tf .TensorSpec (shape = (1 , 8 , 8 , 2 ), dtype = tf .float32 ),
749+ tf .TensorSpec (shape = (3 , 3 , 3 , 2 ), dtype = tf .float32 ),
750+ ]
751+ )
752+ def func (self , data , kernel ):
753+ output_shape = tf .constant ([1 , 8 , 8 , 3 ], dtype = tf .int32 )
754+ return tf .nn .conv2d_transpose (
755+ input = data ,
756+ filters = kernel ,
757+ output_shape = output_shape ,
758+ strides = [1 , 1 , 1 , 1 ],
759+ padding = "SAME" ,
760+ )
761+
762+ verify (TransposeConv )
763+
764+ def test_l2_pool2d ():
765+ class L2Pool2D (tf .Module ):
766+ @tf .function (input_signature = [tf .TensorSpec (shape = (1 , 8 , 8 , 2 ), dtype = tf .float32 )])
767+ def func (self , data ):
768+ squared = tf .math .square (data )
769+ pooled = tf .nn .avg_pool2d (squared , ksize = [2 , 2 ], strides = [1 , 1 ], padding = "SAME" )
770+ return tf .math .sqrt (pooled )
771+
772+ @I .ir_module
773+ class Expected :
774+ @R .function
775+ def main (
776+ data : R .Tensor ((1 , 8 , 8 , 2 ), dtype = "float32" )
777+ ) -> R .Tensor ((1 , 8 , 8 , 2 ), dtype = "float32" ):
778+ R .func_attr ({"num_input" : 1 })
779+ with R .dataflow ():
780+ squared = R .power (data , R .const (2.0 , "float32" ))
781+ pooled = R .nn .avg_pool2d (
782+ squared ,
783+ pool_size = [2 , 2 ],
784+ strides = [1 , 1 ],
785+ padding = [0 , 0 , 1 , 1 ],
786+ layout = "NHWC" ,
787+ )
788+ gv = R .sqrt (pooled )
789+ R .output (gv )
790+ return gv
791+
792+ verify (L2Pool2D , Expected )
793+
794+
713795def test_l2_normalization ():
714796 class L2Normalization (tf .Module ):
715797 @tf .function (input_signature = [tf .TensorSpec (shape = (2 , 4 ), dtype = tf .float32 )])
@@ -758,7 +840,6 @@ def main(x: R.Tensor((2, 3), dtype="float32")) -> R.Tensor((2, 3), dtype="float3
758840
759841 verify (ReverseV2 , Expected )
760842
761-
762843def _make_conv2d_module (data_shape , kernel_shape , data_format , strides , padding ):
763844 class Conv2DModule (tf .Module ):
764845 @tf .function (
0 commit comments