Skip to content

Commit e5d4c55

Browse files
authored
[relax][frontend][tflite] Add tests for fully_connected/depthwise_conv2d/transpose_conv/l2_pool2d (#19372)
This PR adds Relax TFLite frontend test coverage for:
- FULLY_CONNECTED
- DEPTHWISE_CONV_2D
- TRANSPOSE_CONV
- L2_POOL_2D

Part of fixing #18971.
1 parent 645fcf9 commit e5d4c55

2 files changed

Lines changed: 82 additions & 3 deletions

File tree

python/tvm/relax/frontend/tflite/tflite_frontend.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3120,8 +3120,6 @@ def convert_transpose_conv(self, op):
31203120
weight_expr_iohw,
31213121
strides=(stride_h, stride_w),
31223122
padding=padding,
3123-
channels=int(out_channels),
3124-
kernel_size=(int(kernel_h), int(kernel_w)),
31253123
data_layout="NHWC",
31263124
kernel_layout="IOHW",
31273125
out_dtype=output_tensor_type_str,

tests/python/relax/test_frontend_tflite.py

Lines changed: 82 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -710,6 +710,88 @@ def main(x: R.Tensor((5, 30), dtype="float32")) -> R.Tensor(out_shape, dtype="in
710710
verify(TfInput, Expected)
711711

712712

713+
def test_fully_connected():
714+
class FullyConnected(tf.Module):
715+
@tf.function(input_signature=[tf.TensorSpec(shape=(1, 8), dtype=tf.float32)])
716+
def func(self, x):
717+
weight = tf.constant(np.arange(24, dtype=np.float32).reshape((3, 8)))
718+
bias = tf.constant(np.array([0.5, 1.0, -1.0], dtype=np.float32))
719+
out = tf.matmul(x, weight, transpose_b=True)
720+
return tf.nn.bias_add(out, bias)
721+
722+
verify(FullyConnected)
723+
724+
725+
def test_depthwise_conv2d():
726+
class DepthwiseConv2D(tf.Module):
727+
@tf.function(
728+
input_signature=[
729+
tf.TensorSpec(shape=(1, 8, 8, 2), dtype=tf.float32),
730+
tf.TensorSpec(shape=(3, 3, 2, 1), dtype=tf.float32),
731+
]
732+
)
733+
def func(self, data, kernel):
734+
return tf.nn.depthwise_conv2d(
735+
input=data,
736+
filter=kernel,
737+
strides=[1, 1, 1, 1],
738+
padding="SAME",
739+
)
740+
741+
verify(DepthwiseConv2D)
742+
743+
744+
def test_transpose_conv():
745+
class TransposeConv(tf.Module):
746+
@tf.function(
747+
input_signature=[
748+
tf.TensorSpec(shape=(1, 8, 8, 2), dtype=tf.float32),
749+
tf.TensorSpec(shape=(3, 3, 3, 2), dtype=tf.float32),
750+
]
751+
)
752+
def func(self, data, kernel):
753+
output_shape = tf.constant([1, 8, 8, 3], dtype=tf.int32)
754+
return tf.nn.conv2d_transpose(
755+
input=data,
756+
filters=kernel,
757+
output_shape=output_shape,
758+
strides=[1, 1, 1, 1],
759+
padding="SAME",
760+
)
761+
762+
verify(TransposeConv)
763+
764+
def test_l2_pool2d():
765+
class L2Pool2D(tf.Module):
766+
@tf.function(input_signature=[tf.TensorSpec(shape=(1, 8, 8, 2), dtype=tf.float32)])
767+
def func(self, data):
768+
squared = tf.math.square(data)
769+
pooled = tf.nn.avg_pool2d(squared, ksize=[2, 2], strides=[1, 1], padding="SAME")
770+
return tf.math.sqrt(pooled)
771+
772+
@I.ir_module
773+
class Expected:
774+
@R.function
775+
def main(
776+
data: R.Tensor((1, 8, 8, 2), dtype="float32")
777+
) -> R.Tensor((1, 8, 8, 2), dtype="float32"):
778+
R.func_attr({"num_input": 1})
779+
with R.dataflow():
780+
squared = R.power(data, R.const(2.0, "float32"))
781+
pooled = R.nn.avg_pool2d(
782+
squared,
783+
pool_size=[2, 2],
784+
strides=[1, 1],
785+
padding=[0, 0, 1, 1],
786+
layout="NHWC",
787+
)
788+
gv = R.sqrt(pooled)
789+
R.output(gv)
790+
return gv
791+
792+
verify(L2Pool2D, Expected)
793+
794+
713795
def test_l2_normalization():
714796
class L2Normalization(tf.Module):
715797
@tf.function(input_signature=[tf.TensorSpec(shape=(2, 4), dtype=tf.float32)])
@@ -758,7 +840,6 @@ def main(x: R.Tensor((2, 3), dtype="float32")) -> R.Tensor((2, 3), dtype="float3
758840

759841
verify(ReverseV2, Expected)
760842

761-
762843
def _make_conv2d_module(data_shape, kernel_shape, data_format, strides, padding):
763844
class Conv2DModule(tf.Module):
764845
@tf.function(

0 commit comments

Comments
 (0)