Skip to content

Commit 3630c12

Browse files
committed
Type hint main.py and introduce DBCursor protocol.
1 parent 7988113 commit 3630c12

4 files changed

Lines changed: 68 additions & 45 deletions

File tree

litecli/main.py

Lines changed: 30 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
# mypy: ignore-errors
22

3-
from __future__ import print_function, unicode_literals
3+
from __future__ import annotations
44

55
import itertools
66
import logging
@@ -19,6 +19,7 @@
1919
except ImportError:
2020
from sqlite3 import OperationalError, sqlite_version
2121
from time import time
22+
from typing import Any, Iterable, Optional, Dict
2223

2324
import click
2425
import sqlparse
@@ -36,6 +37,7 @@
3637
)
3738
from prompt_toolkit.lexers import PygmentsLexer
3839
from prompt_toolkit.shortcuts import CompleteStyle, PromptSession
40+
from prompt_toolkit.completion import Completion
3941

4042
from .__init__ import __version__
4143
from .clibuffer import cli_is_multiline
@@ -66,13 +68,13 @@ class LiteCli(object):
6668

6769
def __init__(
6870
self,
69-
sqlexecute=None,
70-
prompt=None,
71-
logfile=None,
72-
auto_vertical_output=False,
73-
warn=None,
74-
liteclirc=None,
75-
):
71+
sqlexecute: Optional[SQLExecute] = None,
72+
prompt: Optional[str] = None,
73+
logfile: Optional[Any] = None,
74+
auto_vertical_output: bool = False,
75+
warn: Optional[bool] = None,
76+
liteclirc: Optional[str] = None,
77+
) -> None:
7678
self.sqlexecute = sqlexecute
7779
self.logfile = logfile
7880

@@ -139,7 +141,7 @@ def __init__(
139141

140142
self.prompt_app = None
141143

142-
def register_special_commands(self):
144+
def register_special_commands(self) -> None:
143145
special.register_special_command(
144146
self.change_db,
145147
".open",
@@ -180,7 +182,7 @@ def register_special_commands(self):
180182
case_sensitive=True,
181183
)
182184

183-
def change_table_format(self, arg, **_):
185+
def change_table_format(self, arg: str, **_: Any):
184186
try:
185187
self.formatter.format_name = arg
186188
yield (None, None, None, "Changed table format to {}".format(arg))
@@ -190,7 +192,7 @@ def change_table_format(self, arg, **_):
190192
msg += "\n\t{}".format(table_type)
191193
yield (None, None, None, msg)
192194

193-
def change_db(self, arg, **_):
195+
def change_db(self, arg: Optional[str], **_: Any):
194196
if arg is None:
195197
self.sqlexecute.connect()
196198
else:
@@ -204,7 +206,7 @@ def change_db(self, arg, **_):
204206
'You are now connected to database "%s"' % (self.sqlexecute.dbname),
205207
)
206208

207-
def execute_from_file(self, arg, **_):
209+
def execute_from_file(self, arg: Optional[str], **_: Any):
208210
if not arg:
209211
message = "Missing required argument, filename."
210212
return [(None, None, None, message)]
@@ -220,7 +222,7 @@ def execute_from_file(self, arg, **_):
220222

221223
return self.sqlexecute.run(query)
222224

223-
def change_prompt_format(self, arg, **_):
225+
def change_prompt_format(self, arg: Optional[str], **_: Any):
224226
"""
225227
Change the prompt format.
226228
"""
@@ -231,7 +233,7 @@ def change_prompt_format(self, arg, **_):
231233
self.prompt_format = self.get_prompt(arg)
232234
return [(None, None, None, "Changed prompt format to %s" % arg)]
233235

234-
def initialize_logging(self):
236+
def initialize_logging(self) -> None:
235237
log_file = self.config["main"]["log_file"]
236238
if log_file == "default":
237239
log_file = config_location() + "log"
@@ -279,7 +281,7 @@ def initialize_logging(self):
279281
root_logger.debug("Initializing litecli logging.")
280282
root_logger.debug("Log file %r.", log_file)
281283

282-
def read_my_cnf_files(self, keys):
284+
def read_my_cnf_files(self, keys: Iterable[str]) -> Dict[str, Optional[str]]:
283285
"""
284286
Reads a list of config files and merges them. The last one will win.
285287
:param files: list of files to read
@@ -299,7 +301,7 @@ def get(key):
299301

300302
return {x: get(x) for x in keys}
301303

302-
def connect(self, database=""):
304+
def connect(self, database: str = "") -> None:
303305
cnf = {"database": None}
304306

305307
cnf = self.read_my_cnf_files(cnf.keys())
@@ -321,7 +323,7 @@ def _connect():
321323
self.echo(str(e), err=True, fg="red")
322324
exit(1)
323325

324-
def handle_editor_command(self, text):
326+
def handle_editor_command(self, text: str) -> str:
325327
R"""Editor command is any query that is prefixed or suffixed by a '\e'.
326328
The reason for a while loop is because a user might edit a query
327329
multiple times. For eg:
@@ -351,7 +353,7 @@ def handle_editor_command(self, text):
351353
continue
352354
return text
353355

354-
def run_cli(self):
356+
def run_cli(self) -> None:
355357
iterations = 0
356358
sqlexecute = self.sqlexecute
357359
logger = self.logger
@@ -645,12 +647,12 @@ def startup_commands():
645647
if not self.less_chatty:
646648
self.echo("Goodbye!")
647649

648-
def log_output(self, output):
650+
def log_output(self, output: str) -> None:
649651
"""Log the output in the audit log, if it's enabled."""
650652
if self.logfile:
651653
click.echo(output, file=self.logfile)
652654

653-
def echo(self, s, **kwargs):
655+
def echo(self, s: str, **kwargs: Any) -> None:
654656
"""Print a message to stdout.
655657
656658
The message will be logged in the audit log, if enabled.
@@ -661,7 +663,7 @@ def echo(self, s, **kwargs):
661663
self.log_output(s)
662664
click.secho(s, **kwargs)
663665

664-
def get_output_margin(self, status=None):
666+
def get_output_margin(self, status: Optional[str] = None) -> int:
665667
"""Get the output margin (number of rows for the prompt, footer and
666668
timing message."""
667669
margin = self.get_reserved_space() + self.get_prompt(self.prompt_format).count("\n") + 2
@@ -670,7 +672,7 @@ def get_output_margin(self, status=None):
670672

671673
return margin
672674

673-
def output(self, output, status=None):
675+
def output(self, output: Iterable[str], status: Optional[str] = None) -> None:
674676
"""Output text to stdout or a pager command.
675677
676678
The status text is not outputted to pager or files.
@@ -723,7 +725,7 @@ def output(self, output, status=None):
723725
self.log_output(status)
724726
click.secho(status)
725727

726-
def configure_pager(self):
728+
def configure_pager(self) -> None:
727729
# Provide sane defaults for less if they are empty.
728730
if not os.environ.get("LESS"):
729731
os.environ["LESS"] = "-RXF"
@@ -738,7 +740,7 @@ def configure_pager(self):
738740
if cnf["skip-pager"] or not self.config["main"].as_bool("enable_pager"):
739741
special.disable_pager()
740742

741-
def refresh_completions(self, reset=False):
743+
def refresh_completions(self, reset: bool = False):
742744
if reset:
743745
with self._completer_lock:
744746
self.completer.reset_completions()
@@ -753,7 +755,7 @@ def refresh_completions(self, reset=False):
753755

754756
return [(None, None, None, "Auto-completion refresh started in the background.")]
755757

756-
def _on_completions_refreshed(self, new_completer):
758+
def _on_completions_refreshed(self, new_completer: SQLCompleter) -> None:
757759
"""Swap the completer object in cli with the newly created completer."""
758760
with self._completer_lock:
759761
self.completer = new_completer
@@ -763,11 +765,11 @@ def _on_completions_refreshed(self, new_completer):
763765
# "Refreshing completions..." indicator
764766
self.prompt_app.app.invalidate()
765767

766-
def get_completions(self, text, cursor_positition):
768+
def get_completions(self, text: str, cursor_positition: int) -> Iterable[Completion]:
767769
with self._completer_lock:
768770
return self.completer.get_completions(Document(text=text, cursor_position=cursor_positition), None)
769771

770-
def get_prompt(self, string):
772+
def get_prompt(self, string: str) -> str:
771773
self.logger.debug("Getting prompt %r", string)
772774
sqlexecute = self.sqlexecute
773775
now = datetime.now()
@@ -795,7 +797,7 @@ def replacer(match):
795797
# Perform the substitution
796798
return pattern.sub(replacer, string)
797799

798-
def run_query(self, query, new_line=True):
800+
def run_query(self, query: str, new_line: bool = True):
799801
"""Runs *query*."""
800802
results = self.sqlexecute.run(query)
801803
for result in results:

litecli/packages/special/dbcommands.py

Lines changed: 19 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
import sys
99
import platform
1010
import shlex
11-
from typing import Any, List, Optional, Tuple
11+
from typing import Any, List, Optional, Tuple, Protocol, Sequence
1212

1313

1414
from litecli import __version__
@@ -18,6 +18,16 @@
1818
log = logging.getLogger(__name__)
1919

2020

21+
class DBCursor(Protocol):
22+
description: Optional[Sequence[Sequence[Any]]]
23+
24+
def execute(self, sql: str, params: Any = ...) -> Any: ...
25+
26+
def fetchall(self) -> List[Tuple[Any, ...]]: ...
27+
28+
def fetchone(self) -> Optional[Tuple[Any, ...]]: ...
29+
30+
2131
@special_command(
2232
".tables",
2333
"\\dt",
@@ -27,7 +37,7 @@
2737
aliases=("\\dt",),
2838
)
2939
def list_tables(
30-
cur: Any,
40+
cur: DBCursor,
3141
arg: Optional[str] = None,
3242
arg_type: int = PARSED_QUERY,
3343
verbose: bool = False,
@@ -74,7 +84,7 @@ def list_tables(
7484
aliases=("\\dv",),
7585
)
7686
def list_views(
77-
cur: Any,
87+
cur: DBCursor,
7888
arg: Optional[str] = None,
7989
arg_type: int = PARSED_QUERY,
8090
verbose: bool = False,
@@ -111,7 +121,7 @@ def list_views(
111121
arg_type=PARSED_QUERY,
112122
case_sensitive=True,
113123
)
114-
def show_schema(cur: Any, arg: Optional[str] = None, **_: Any) -> List[Tuple]:
124+
def show_schema(cur: DBCursor, arg: Optional[str] = None, **_: Any) -> List[Tuple]:
115125
if arg:
116126
args = (arg,)
117127
query = """
@@ -147,7 +157,7 @@ def show_schema(cur: Any, arg: Optional[str] = None, **_: Any) -> List[Tuple]:
147157
case_sensitive=True,
148158
aliases=("\\l",),
149159
)
150-
def list_databases(cur: Any, **_: Any) -> List[Tuple]:
160+
def list_databases(cur: DBCursor, **_: Any) -> List[Tuple]:
151161
query = "PRAGMA database_list"
152162
log.debug(query)
153163
cur.execute(query)
@@ -167,7 +177,7 @@ def list_databases(cur: Any, **_: Any) -> List[Tuple]:
167177
aliases=("\\di",),
168178
)
169179
def list_indexes(
170-
cur: Any,
180+
cur: DBCursor,
171181
arg: Optional[str] = None,
172182
arg_type: int = PARSED_QUERY,
173183
verbose: bool = False,
@@ -206,7 +216,7 @@ def list_indexes(
206216
aliases=("\\s",),
207217
case_sensitive=True,
208218
)
209-
def status(cur: Any, **_: Any) -> List[Tuple]:
219+
def status(cur: DBCursor, **_: Any) -> List[Tuple]:
210220
# Create output buffers.
211221
footer = []
212222
footer.append("--------------")
@@ -267,7 +277,7 @@ def load_extension(cur, arg, **_):
267277
case_sensitive=True,
268278
aliases=("\\d", "desc"),
269279
)
270-
def describe(cur: Any, arg: Optional[str], **_: Any) -> List[Tuple]:
280+
def describe(cur: DBCursor, arg: Optional[str], **_: Any) -> List[Tuple]:
271281
if arg:
272282
query = """
273283
PRAGMA table_info({})
@@ -294,7 +304,7 @@ def describe(cur: Any, arg: Optional[str], **_: Any) -> List[Tuple]:
294304
arg_type=PARSED_QUERY,
295305
case_sensitive=True,
296306
)
297-
def import_file(cur: Any, arg: Optional[str] = None, **_: Any) -> List[Tuple]:
307+
def import_file(cur: DBCursor, arg: Optional[str] = None, **_: Any) -> List[Tuple]:
298308
def split(s):
299309
# this is a modification of shlex.split function, just to make it support '`',
300310
# because table name might contain '`' character.

litecli/packages/special/iocommands.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
import subprocess
1111
from io import open
1212
from time import sleep
13-
from typing import Any, Dict, Generator, List, Optional, Tuple
13+
from typing import Any, Dict, Generator, List, Optional, Tuple, TextIO
1414

1515
import click
1616
import sqlparse
@@ -24,10 +24,10 @@
2424

2525
use_expanded_output: bool = False
2626
PAGER_ENABLED: bool = True
27-
tee_file: Any = None
28-
once_file: Any = None
29-
written_to_once_file: Optional[bool] = None
30-
pipe_once_process: Any = None
27+
tee_file: Optional[TextIO] = None
28+
once_file: Optional[TextIO] = None
29+
written_to_once_file: bool = False
30+
pipe_once_process: Optional[subprocess.Popen[str]] = None
3131
written_to_pipe_once_process: bool = False
3232
favoritequeries: FavoriteQueries = FavoriteQueries(ConfigObj())
3333

litecli/packages/special/llm.py

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
import sys
1717
from runpy import run_module
1818
from time import time
19-
from typing import Any, Dict, List, Optional, Tuple
19+
from typing import Any, Dict, List, Optional, Tuple, Protocol, Sequence
2020

2121
import click
2222
import llm
@@ -219,7 +219,18 @@ def ensure_litecli_template(replace: bool = False) -> None:
219219

220220

221221
@export
222-
def handle_llm(text: str, cur: Any) -> Tuple[str, Optional[str], float]:
222+
class DBCursor(Protocol):
223+
description: Optional[Sequence[Sequence[Any]]]
224+
225+
def execute(self, sql: str, params: Any = ...) -> Any: ...
226+
227+
def fetchall(self) -> List[Tuple[Any, ...]]: ...
228+
229+
def fetchone(self) -> Optional[Tuple[Any, ...]]: ...
230+
231+
232+
@export
233+
def handle_llm(text: str, cur: DBCursor) -> Tuple[str, Optional[str], float]:
223234
"""This function handles the special command `\\llm`.
224235
225236
If it deals with a question that results in a SQL query then it will return
@@ -323,7 +334,7 @@ def is_llm_command(command: str) -> bool:
323334

324335
@export
325336
def sql_using_llm(
326-
cur: Any,
337+
cur: DBCursor,
327338
question: Optional[str] = None,
328339
verbose: bool = False,
329340
) -> Tuple[str, Optional[str], Optional[str]]:

0 commit comments

Comments
 (0)