Skip to content

Commit 60bb59d

Browse files
authored
[Frontend][TFLite] Add TILE operator tests and edge cases (#19400)
This PR adds non-quantized TILE coverage in `test_frontend_tflite.py`. Related to #18971 What is included 1. One explicit Expected IR structural test for TILE 2. Parametrized TILE conversion tests covering: - baseline 2D and higher-rank cases - identity/no-op tiling - larger repeat factors - int32 non-quantized dtype path _**Note:** SHAPE and RANGE are excluded from this PR and will be handled separately because there's a related a bug in the frontend for them_ ### Validation ``` pytest test_frontend_tflite.py -v -k "test_tile_ir or test_tile" ``` <img width="1164" height="26" alt="image" src="https://github.com/user-attachments/assets/ede6c479-8b4d-4025-bb4a-2af8e132e162" />
1 parent 13867ea commit 60bb59d

1 file changed

Lines changed: 43 additions & 0 deletions

File tree

tests/python/relax/test_frontend_tflite.py

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -279,6 +279,49 @@ def main(x: R.Tensor((1, 30), dtype="float32")) -> R.Tensor((1, 2, 15), dtype="f
279279
verify(Reshape, Expected)
280280

281281

282+
def test_tile_ir():
283+
"""TILE conversion with explicit Relax IR structural check."""
284+
285+
class Tile(tf.Module):
286+
@tf.function(input_signature=[tf.TensorSpec(shape=(2, 3), dtype=tf.float32)])
287+
def func(self, x):
288+
return tf.tile(x, [2, 1])
289+
290+
@I.ir_module
291+
class Expected:
292+
@R.function
293+
def main(x: R.Tensor((2, 3), dtype="float32")) -> R.Tensor((4, 3), dtype="float32"):
294+
R.func_attr({"num_input": 1})
295+
with R.dataflow():
296+
gv: R.Tensor((4, 3), dtype="float32") = R.tile(x, repeats=[2, 1])
297+
R.output(gv)
298+
return gv
299+
300+
verify(Tile, Expected)
301+
302+
303+
@pytest.mark.parametrize(
304+
"input_shape, multiples, dtype",
305+
[
306+
((2, 3), [2, 1], tf.float32),
307+
((1, 4, 2), [3, 1, 2], tf.float32),
308+
((2, 1, 3, 1), [1, 2, 1, 4], tf.float32),
309+
((2, 3), [1, 1], tf.float32),
310+
((3,), [2], tf.float32),
311+
((2, 3), [4, 2], tf.float32),
312+
((2, 2), [1, 3], tf.int32),
313+
],
314+
)
315+
def test_tile(input_shape, multiples, dtype):
316+
"""TILE conversion for non-quantized input and repeat factors."""
317+
318+
class Tile(tf.Module):
319+
@tf.function(input_signature=[tf.TensorSpec(shape=input_shape, dtype=dtype)])
320+
def func(self, x):
321+
return tf.tile(x, multiples)
322+
323+
verify(Tile)
324+
282325
def test_concat_v2():
283326
class ConcatV2(tf.Module):
284327
@tf.function(input_signature=[tf.TensorSpec(shape=(1, 30), dtype=tf.float32)])

0 commit comments

Comments
 (0)