Skip to content

Commit 7ba2643

Browse files
authored
chore: add tests for statement formatter (#496)
1 parent 37a6529 commit 7ba2643

1 file changed

Lines changed: 83 additions & 0 deletions

File tree

tests/unit/common/test_typing_format.py

Lines changed: 83 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -269,3 +269,86 @@ def test_statement_to_set_errors(
269269

270270
def test_binary() -> None:
271271
assert Binary("abc") == b"abc"
272+
273+
274+
def test_convert_parameter_for_serialization(formatter: StatementFormatter) -> None:
275+
assert formatter.convert_parameter_for_serialization(1) == 1
276+
assert formatter.convert_parameter_for_serialization(1.1) == 1.1
277+
assert formatter.convert_parameter_for_serialization(True) is True
278+
assert formatter.convert_parameter_for_serialization(None) is None
279+
assert formatter.convert_parameter_for_serialization(Decimal("1.1")) == "1.1"
280+
assert formatter.convert_parameter_for_serialization(b"abc") == "abc"
281+
assert formatter.convert_parameter_for_serialization(
282+
[1, Decimal("1.1"), [b"x"]]
283+
) == [
284+
1,
285+
"1.1",
286+
["x"],
287+
]
288+
assert (
289+
formatter.convert_parameter_for_serialization(date(2022, 1, 1)) == "2022-01-01"
290+
)
291+
assert formatter.convert_parameter_for_serialization({"a": 1}) == "{'a': 1}"
292+
293+
294+
def test_format_bulk_insert(formatter: StatementFormatter) -> None:
295+
query = "INSERT INTO t VALUES (?, ?)"
296+
params = [[1, "a"], [2, "b"]]
297+
result = formatter.format_bulk_insert(query, params)
298+
assert result == "INSERT INTO t VALUES (1, 'a'); INSERT INTO t VALUES (2, 'b')"
299+
300+
with raises(DataError):
301+
formatter.format_bulk_insert("", [])
302+
303+
304+
def test_create_statement_formatter_invalid_version() -> None:
305+
with raises(ValueError) as excinfo:
306+
create_statement_formatter(3)
307+
assert "Unsupported version: 3" in str(excinfo.value)
308+
309+
310+
def test_patched_change_splitlevel(formatter: StatementFormatter) -> None:
311+
# Testing CREATE, DECLARE, BEGIN, END, CASE, IF, FOR, WHILE
312+
# These exercise _patched_change_splitlevel via split_format_sql which calls parse_sql
313+
314+
# CREATE, BEGIN, END
315+
sql = "CREATE PROCEDURE p AS BEGIN SELECT 1; END; SELECT 2;"
316+
results = formatter.split_format_sql(sql, None)
317+
assert len(results) == 2
318+
319+
# Testing CASE...END outside of CREATE
320+
sql = "SELECT CASE WHEN 1 THEN 'a' ELSE 'b' END FROM t; SELECT 2;"
321+
results = formatter.split_format_sql(sql, None)
322+
assert len(results) == 2
323+
assert "SELECT CASE" in str(results[0])
324+
assert "SELECT 2" in str(results[1])
325+
326+
# Testing IF, FOR, WHILE inside CREATE
327+
sql = """
328+
CREATE PROCEDURE p AS
329+
BEGIN
330+
IF 1 THEN
331+
SELECT 1;
332+
END IF;
333+
FOR i IN 1..10 LOOP
334+
SELECT i;
335+
END FOR;
336+
WHILE 1 LOOP
337+
SELECT 1;
338+
END WHILE;
339+
END;
340+
SELECT 2;
341+
"""
342+
results = formatter.split_format_sql(sql, None)
343+
# sqlparse might split at some ENDs depending on its internal state and how it tokens things
344+
assert len(results) >= 2
345+
346+
# Testing TRANSACTION, WORK etc
347+
sql = "BEGIN TRANSACTION; SELECT 1; COMMIT;"
348+
results = formatter.split_format_sql(sql, None)
349+
assert len(results) == 3
350+
351+
# Testing DECLARE
352+
sql = "CREATE PROCEDURE p AS DECLARE x INT; BEGIN SELECT x; END;"
353+
results = formatter.split_format_sql(sql, None)
354+
assert len(results) == 1

0 commit comments

Comments
 (0)