Skip to content

Commit 7f25681

Browse files
authored
Merge pull request #1834 from dbcli/RW/migrate-show-warnings-from-main-to-special
Move show_warnings from `main.py` to `special/iocommands.py`
2 parents b15fc19 + 8a70ca4 commit 7f25681

10 files changed

Lines changed: 138 additions & 101 deletions

File tree

changelog.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,7 @@ Internal
4949
* Move CLI utilities to a new `cli_utils.py`.
5050
* Move keybinding utilities to a new `key_binding_utils.py`.
5151
* Move interactive utilities to `interactive_utils.py`.
52+
* Move special commands out of `main.py`.
5253
* Modernize orthography of prompt_toolkit filters.
5354
* Pin all GitHub Actions to hashes.
5455

mycli/main.py

Lines changed: 7 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -144,6 +144,7 @@ def __init__(
144144
auto_vertical_output: bool = False,
145145
warn: bool | None = None,
146146
myclirc: str = "~/.myclirc",
147+
show_warnings: bool | None = None,
147148
) -> None:
148149
self.sqlexecute = sqlexecute
149150
self.logfile = logfile
@@ -178,6 +179,10 @@ def __init__(
178179
self.vi_ttimeoutlen = c['keys'].as_float('vi_ttimeoutlen')
179180
special.set_timing_enabled(c["main"].as_bool("timing"))
180181
special.set_show_favorite_query(c["main"].as_bool("show_favorite_query"))
182+
if show_warnings is not None:
183+
special.set_show_warnings_enabled(show_warnings)
184+
else:
185+
special.set_show_warnings_enabled(c['main'].as_bool('show_warnings'))
181186
self.beep_after_seconds = float(c["main"]["beep_after_seconds"] or 0)
182187
self.default_keepalive_ticks = c['connection'].as_int('default_keepalive_ticks')
183188

@@ -223,7 +228,6 @@ def __init__(
223228

224229
# read from cli argument or user config file
225230
self.auto_vertical_output = auto_vertical_output or c["main"].as_bool("auto_vertical_output")
226-
self.show_warnings = c["main"].as_bool("show_warnings")
227231

228232
# Write user config if system config wasn't the last config loaded.
229233
if c.filename not in self.system_config_files and not os.path.exists(myclirc):
@@ -328,22 +332,6 @@ def register_special_commands(self) -> None:
328332
aliases=["\\Tr"],
329333
case_sensitive=True,
330334
)
331-
special.register_special_command(
332-
self.disable_show_warnings,
333-
"nowarnings",
334-
"nowarnings",
335-
"Disable automatic warnings display.",
336-
aliases=["\\w"],
337-
case_sensitive=True,
338-
)
339-
special.register_special_command(
340-
self.enable_show_warnings,
341-
"warnings",
342-
"warnings",
343-
"Enable automatic warnings display.",
344-
aliases=["\\W"],
345-
case_sensitive=True,
346-
)
347335
special.register_special_command(
348336
self.execute_from_file, "source", "source <filename>", "Execute queries from a file.", aliases=["\\."]
349337
)
@@ -363,16 +351,6 @@ def manual_reconnect(self, arg: str = "", **_) -> Generator[SQLResult, None, Non
363351
else:
364352
yield self.change_db(arg).send(None)
365353

366-
def enable_show_warnings(self, **_) -> Generator[SQLResult, None, None]:
367-
self.show_warnings = True
368-
msg = "Show warnings enabled."
369-
yield SQLResult(status=msg)
370-
371-
def disable_show_warnings(self, **_) -> Generator[SQLResult, None, None]:
372-
self.show_warnings = False
373-
msg = "Show warnings disabled."
374-
yield SQLResult(status=msg)
375-
376354
def change_table_format(self, arg: str, **_) -> Generator[SQLResult, None, None]:
377355
try:
378356
self.main_formatter.format_name = arg
@@ -557,7 +535,6 @@ def connect(
557535
use_keyring: bool | None = None,
558536
reset_keyring: bool | None = None,
559537
keepalive_ticks: int | None = None,
560-
show_warnings: bool | None = None,
561538
) -> None:
562539
cnf = {
563540
"database": None,
@@ -587,8 +564,6 @@ def connect(
587564
ssl_config: dict[str, Any] = ssl or {}
588565
user_connection_config = self.config_without_package_defaults.get('connection', {})
589566
self.keepalive_ticks = keepalive_ticks
590-
if show_warnings is not None:
591-
self.show_warnings = show_warnings
592567

593568
int_port = port and int(port)
594569
if not int_port:
@@ -1088,7 +1063,7 @@ def run_query(
10881063
click.echo(line, nl=new_line)
10891064

10901065
# get and display warnings if enabled
1091-
if self.show_warnings and isinstance(result.rows, Cursor) and result.rows.warning_count > 0:
1066+
if special.is_show_warnings_enabled() and isinstance(result.rows, Cursor) and result.rows.warning_count > 0:
10921067
warnings = self.sqlexecute.run("SHOW WARNINGS")
10931068
for warning in warnings:
10941069
output = self.format_sqlresult(
@@ -1555,6 +1530,7 @@ def get_password_from_file(password_file: str | None) -> str | None:
15551530
auto_vertical_output=cli_args.auto_vertical_output,
15561531
warn=cli_args.warn,
15571532
myclirc=cli_args.myclirc,
1533+
show_warnings=cli_args.show_warnings,
15581534
)
15591535

15601536
if cli_args.checkup:
@@ -1916,7 +1892,6 @@ def get_password_from_file(password_file: str | None) -> str | None:
19161892
use_keyring=use_keyring,
19171893
reset_keyring=reset_keyring,
19181894
keepalive_ticks=keepalive_ticks,
1919-
show_warnings=cli_args.show_warnings,
19201895
)
19211896

19221897
if combined_init_cmd:

mycli/main_modes/repl.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -415,7 +415,7 @@ def _output_results(
415415
result_count += 1
416416
state.mutating = state.mutating or is_mutating(result.status_plain)
417417

418-
if mycli.show_warnings and isinstance(result.rows, Cursor) and result.rows.warning_count > 0:
418+
if special.is_show_warnings_enabled() and isinstance(result.rows, Cursor) and result.rows.warning_count > 0:
419419
warnings = sqlexecute.run('SHOW WARNINGS')
420420
warnings_duration = time.time() - start
421421
saw_warning = False

mycli/packages/special/__init__.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,9 @@
88
close_tee,
99
copy_query_to_clipboard,
1010
disable_pager,
11+
disable_show_warnings,
1112
editor_command,
13+
enable_show_warnings,
1214
flush_pipe_once_if_written,
1315
forced_horizontal,
1416
get_clip_query,
@@ -19,6 +21,7 @@
1921
is_pager_enabled,
2022
is_redirected,
2123
is_show_favorite_query,
24+
is_show_warnings_enabled,
2225
is_timing_enabled,
2326
open_external_editor,
2427
set_delimiter,
@@ -30,6 +33,7 @@
3033
set_pager_enabled,
3134
set_redirect,
3235
set_show_favorite_query,
36+
set_show_warnings_enabled,
3337
set_timing_enabled,
3438
split_queries,
3539
unset_once_if_written,
@@ -58,7 +62,9 @@
5862
'close_tee',
5963
'copy_query_to_clipboard',
6064
'disable_pager',
65+
'disable_show_warnings',
6166
'editor_command',
67+
'enable_show_warnings',
6268
'execute',
6369
'flush_pipe_once_if_written',
6470
'forced_horizontal',
@@ -71,6 +77,7 @@
7177
'is_llm_command',
7278
'is_pager_enabled',
7379
'is_redirected',
80+
'is_show_warnings_enabled',
7481
'is_timing_enabled',
7582
'list_databases',
7683
'list_tables',
@@ -85,6 +92,7 @@
8592
'set_pager',
8693
'set_pager_enabled',
8794
'set_redirect',
95+
'set_show_warnings_enabled',
8896
'set_timing_enabled',
8997
'set_show_favorite_query',
9098
'is_show_favorite_query',

mycli/packages/special/iocommands.py

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,7 @@
4646
delimiter_command = DelimiterCommand()
4747
favoritequeries = FavoriteQueries(ConfigObj())
4848
DESTRUCTIVE_KEYWORDS: list[str] = []
49+
SHOW_WARNINGS_ENABLED: bool = False
4950

5051

5152
def set_favorite_queries(config):
@@ -81,6 +82,45 @@ def set_destructive_keywords(val: list[str]) -> None:
8182
DESTRUCTIVE_KEYWORDS = val
8283

8384

85+
def set_show_warnings_enabled(val: bool) -> None:
86+
global SHOW_WARNINGS_ENABLED
87+
SHOW_WARNINGS_ENABLED = val
88+
89+
90+
def is_show_warnings_enabled() -> bool:
91+
return SHOW_WARNINGS_ENABLED
92+
93+
94+
@special_command(
95+
'warnings',
96+
'warnings',
97+
'Enable automatic warnings display.',
98+
arg_type=ArgType.NO_QUERY,
99+
aliases=['\\W'],
100+
case_sensitive=True,
101+
)
102+
def enable_show_warnings() -> Generator[SQLResult, None, None]:
103+
global SHOW_WARNINGS_ENABLED
104+
SHOW_WARNINGS_ENABLED = True
105+
msg = "Show warnings enabled."
106+
yield SQLResult(status=msg)
107+
108+
109+
@special_command(
110+
'nowarnings',
111+
'nowarnings',
112+
'Disable automatic warnings display.',
113+
arg_type=ArgType.NO_QUERY,
114+
aliases=['\\w'],
115+
case_sensitive=True,
116+
)
117+
def disable_show_warnings() -> Generator[SQLResult, None, None]:
118+
global SHOW_WARNINGS_ENABLED
119+
SHOW_WARNINGS_ENABLED = False
120+
msg = 'Show warnings disabled.'
121+
yield SQLResult(status=msg)
122+
123+
84124
@special_command(
85125
"pager",
86126
"pager [command]",

test/pytests/test_main.py

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,8 @@
3737
PORT,
3838
TEMPFILE_PREFIX,
3939
USER,
40+
DummyFormatter,
41+
FakeCursorBase,
4042
ReusableLock,
4143
call_click_entrypoint_direct,
4244
dbtest,
@@ -2324,3 +2326,42 @@ def test_click_entrypoint_callback_covers_mycnf_underscore_fallback(monkeypatch:
23242326

23252327
call_click_entrypoint_direct(main.CliArgs())
23262328
assert any('ssl-ca = /tmp/ca.pem' in line for line in click_lines)
2329+
2330+
2331+
def test_format_sqlresult_uses_redirect_formatter_when_redirected() -> None:
2332+
cli = make_bare_mycli()
2333+
cli.main_formatter = DummyFormatter()
2334+
cli.redirect_formatter = DummyFormatter()
2335+
2336+
result = SQLResult(header=['id'], rows=[(1,)], status='ok')
2337+
assert list(main.MyCli.format_sqlresult(cli, result, is_redirected=True)) == ['plain output']
2338+
2339+
assert cli.main_formatter.calls == []
2340+
assert len(cli.redirect_formatter.calls) == 1
2341+
2342+
2343+
def test_format_sqlresult_materializes_cursor_rows_when_width_is_limited(monkeypatch: pytest.MonkeyPatch) -> None:
2344+
cli = make_bare_mycli()
2345+
cli.main_formatter = DummyFormatter()
2346+
rows = FakeCursorBase(rows=[(1,)], rowcount=1, description=[('id', 3)])
2347+
monkeypatch.setattr(main, 'Cursor', FakeCursorBase)
2348+
2349+
result = SQLResult(header=['id'], rows=cast(Any, rows), status='ok')
2350+
list(main.MyCli.format_sqlresult(cli, result, max_width=100))
2351+
2352+
formatted_rows = cli.main_formatter.calls[-1][0][0]
2353+
assert formatted_rows == [(1,)]
2354+
2355+
2356+
def test_format_sqlresult_appends_postamble() -> None:
2357+
cli = make_bare_mycli()
2358+
result = SQLResult(header=['id'], rows=[(1,)], status='ok', postamble='done')
2359+
2360+
assert list(main.MyCli.format_sqlresult(cli, result))[-1] == 'done'
2361+
2362+
2363+
def test_get_last_query_returns_latest_query() -> None:
2364+
cli = make_bare_mycli()
2365+
cli.query_history = [main.Query('select 1', True, False)]
2366+
2367+
assert main.MyCli.get_last_query(cli) == 'select 1'

test/pytests/test_main_modes_repl.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -527,7 +527,6 @@ def run(self, text: str) -> list[SQLResult]:
527527
cli.auto_vertical_output = True
528528
cli.prompt_session = FakePromptSession(columns=91)
529529
cli.beep_after_seconds = 0.1
530-
cli.show_warnings = True
531530
state = repl_mode.ReplState()
532531
format_widths: list[int | None] = []
533532

@@ -540,6 +539,7 @@ def format_sqlresult(result: SQLResult, **kwargs: Any) -> Iterator[str]:
540539
monkeypatch.setattr(repl_mode.time, 'time', lambda: next(time_values))
541540
monkeypatch.setattr(repl_mode.special, 'is_expanded_output', lambda: False)
542541
monkeypatch.setattr(repl_mode.special, 'is_redirected', lambda: False)
542+
monkeypatch.setattr(repl_mode.special, 'is_show_warnings_enabled', lambda: True)
543543
monkeypatch.setattr(repl_mode.special, 'is_timing_enabled', lambda: True)
544544
monkeypatch.setattr(repl_mode, 'Cursor', FakeCursorBase)
545545
monkeypatch.setattr(repl_mode, 'is_select', lambda status: False)

0 commit comments

Comments
 (0)