Skip to content

Commit 7988113

Browse files
committed
Type hints for completion_refresher and sqlexecute
1 parent 4205f70 commit 7988113

2 files changed

Lines changed: 49 additions & 38 deletions

File tree

litecli/completion_refresher.py

Lines changed: 36 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,10 @@
11
# mypy: ignore-errors
22

3+
from __future__ import annotations
4+
35
import threading
6+
from typing import Callable, Dict, List, Optional, Tuple
7+
48
from .packages.special.main import COMMANDS
59
from collections import OrderedDict
610

@@ -9,13 +13,18 @@
913

1014

1115
class CompletionRefresher(object):
12-
refreshers = OrderedDict()
16+
refreshers: Dict[str, Callable] = OrderedDict()
1317

14-
def __init__(self):
15-
self._completer_thread = None
18+
def __init__(self) -> None:
19+
self._completer_thread: Optional[threading.Thread] = None
1620
self._restart_refresh = threading.Event()
1721

18-
def refresh(self, executor, callbacks, completer_options=None):
22+
def refresh(
23+
self,
24+
executor: SQLExecute,
25+
callbacks: Callable | List[Callable],
26+
completer_options: Optional[dict] = None,
27+
) -> List[Tuple]:
1928
"""Creates a SQLCompleter object and populates it with the relevant
2029
completion suggestions in a background thread.
2130
@@ -34,7 +43,7 @@ def refresh(self, executor, callbacks, completer_options=None):
3443
self._restart_refresh.set()
3544
return [(None, None, None, "Auto-completion refresh restarted.")]
3645
else:
37-
if executor.dbname == ":memory:":
46+
if executor.dbname == ":memory":
3847
# if DB is memory, needed to use same connection
3948
# So can't use same connection with different thread
4049
self._bg_refresh(executor, callbacks, completer_options)
@@ -46,19 +55,17 @@ def refresh(self, executor, callbacks, completer_options=None):
4655
)
4756
self._completer_thread.daemon = True
4857
self._completer_thread.start()
49-
return [
50-
(
51-
None,
52-
None,
53-
None,
54-
"Auto-completion refresh started in the background.",
55-
)
56-
]
57-
58-
def is_refreshing(self):
59-
return self._completer_thread and self._completer_thread.is_alive()
60-
61-
def _bg_refresh(self, sqlexecute, callbacks, completer_options):
58+
return [(None, None, None, "Auto-completion refresh started in the background.")]
59+
60+
def is_refreshing(self) -> bool:
61+
return bool(self._completer_thread and self._completer_thread.is_alive())
62+
63+
def _bg_refresh(
64+
self,
65+
sqlexecute: SQLExecute,
66+
callbacks: Callable | List[Callable],
67+
completer_options: dict,
68+
) -> None:
6269
completer = SQLCompleter(**completer_options)
6370

6471
e = sqlexecute
@@ -92,41 +99,42 @@ def _bg_refresh(self, sqlexecute, callbacks, completer_options):
9299
callback(completer)
93100

94101

95-
def refresher(name, refreshers=CompletionRefresher.refreshers):
102+
def refresher(name: str, refreshers: Dict[str, Callable] = CompletionRefresher.refreshers):
96103
"""Decorator to add the decorated function to the dictionary of
97104
refreshers. Any function decorated with a @refresher will be executed as
98105
part of the completion refresh routine."""
99106

100-
def wrapper(wrapped):
107+
def wrapper(wrapped: Callable) -> Callable:
101108
refreshers[name] = wrapped
102109
return wrapped
103110

104111
return wrapper
105112

106113

107114
@refresher("databases")
108-
def refresh_databases(completer, executor):
115+
def refresh_databases(completer: SQLCompleter, executor: SQLExecute) -> None:
109116
completer.extend_database_names(executor.databases())
110117

111118

112119
@refresher("schemata")
113-
def refresh_schemata(completer, executor):
120+
def refresh_schemata(completer: SQLCompleter, executor: SQLExecute) -> None:
114121
# name of the current database.
115122
completer.extend_schemata(executor.dbname)
116123
completer.set_dbname(executor.dbname)
117124

118125

119126
@refresher("tables")
120-
def refresh_tables(completer, executor):
121-
completer.extend_relations(executor.tables(), kind="tables")
122-
completer.extend_columns(executor.table_columns(), kind="tables")
127+
def refresh_tables(completer: SQLCompleter, executor: SQLExecute) -> None:
128+
table_cols = list(executor.table_columns())
129+
completer.extend_relations(table_cols, kind="tables")
130+
completer.extend_columns(table_cols, kind="tables")
123131

124132

125133
@refresher("functions")
126-
def refresh_functions(completer, executor):
134+
def refresh_functions(completer: SQLCompleter, executor: SQLExecute) -> None:
127135
completer.extend_functions(executor.functions())
128136

129137

130138
@refresher("special_commands")
131-
def refresh_special(completer, executor):
132-
completer.extend_special_commands(COMMANDS.keys())
139+
def refresh_special(completer: SQLCompleter, executor: SQLExecute) -> None:
140+
completer.extend_special_commands(list(COMMANDS.keys()))

litecli/sqlexecute.py

Lines changed: 13 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,9 @@
11
# mypy: ignore-errors
22

3+
from __future__ import annotations
4+
35
import logging
6+
from typing import Any, Generator, Iterable, Optional, Tuple
47

58
from contextlib import closing
69

@@ -57,7 +60,7 @@ class SQLExecute(object):
5760
functions_query = '''SELECT ROUTINE_NAME FROM INFORMATION_SCHEMA.ROUTINES
5861
WHERE ROUTINE_TYPE="FUNCTION" AND ROUTINE_SCHEMA = "%s"'''
5962

60-
def __init__(self, database):
63+
def __init__(self, database: Optional[str]):
6164
self.dbname = database
6265
self._server_type = None
6366
self.conn = None
@@ -66,7 +69,7 @@ def __init__(self, database):
6669
return
6770
self.connect()
6871

69-
def connect(self, database=None):
72+
def connect(self, database: Optional[str] = None) -> None:
7073
db = database or self.dbname
7174
_logger.debug("Connection DB Params: \n\tdatabase: %r", db)
7275

@@ -85,7 +88,7 @@ def connect(self, database=None):
8588
# successful connection.
8689
self.dbname = db
8790

88-
def run(self, statement):
91+
def run(self, statement: str) -> Iterable[Tuple]:
8992
"""Execute the sql in the database and return the results. The results
9093
are a list of tuples. Each tuple has 4 values
9194
(title, rows, headers, status).
@@ -140,7 +143,7 @@ def run(self, statement):
140143
cur.execute(sql)
141144
yield self.get_result(cur)
142145

143-
def get_result(self, cursor):
146+
def get_result(self, cursor: Any) -> Tuple[Optional[str], Optional[list], Optional[list], str]:
144147
"""Get the current result's data from the cursor."""
145148
title = headers = None
146149

@@ -164,7 +167,7 @@ def get_result(self, cursor):
164167

165168
return (title, cursor, headers, status)
166169

167-
def tables(self):
170+
def tables(self) -> Generator[Tuple[str], None, None]:
168171
"""Yields table names"""
169172

170173
with closing(self.conn.cursor()) as cur:
@@ -173,15 +176,15 @@ def tables(self):
173176
for row in cur:
174177
yield row
175178

176-
def table_columns(self):
179+
def table_columns(self) -> Generator[Tuple[str, str], None, None]:
177180
"""Yields column names"""
178181
with closing(self.conn.cursor()) as cur:
179182
_logger.debug("Columns Query. sql: %r", self.table_columns_query)
180183
cur.execute(self.table_columns_query)
181184
for row in cur:
182185
yield row
183186

184-
def databases(self):
187+
def databases(self) -> Iterable[str]:
185188
if not self.conn:
186189
return
187190

@@ -190,7 +193,7 @@ def databases(self):
190193
for row in cur.execute(self.databases_query):
191194
yield row[1]
192195

193-
def functions(self):
196+
def functions(self) -> Iterable[Tuple]:
194197
"""Yields tuples of (schema_name, function_name)"""
195198

196199
with closing(self.conn.cursor()) as cur:
@@ -199,7 +202,7 @@ def functions(self):
199202
for row in cur:
200203
yield row
201204

202-
def show_candidates(self):
205+
def show_candidates(self) -> Iterable[Any]:
203206
with closing(self.conn.cursor()) as cur:
204207
_logger.debug("Show Query. sql: %r", self.show_candidates_query)
205208
try:
@@ -211,6 +214,6 @@ def show_candidates(self):
211214
for row in cur:
212215
yield (row[0].split(None, 1)[-1],)
213216

214-
def server_type(self):
217+
def server_type(self) -> Tuple[str, str]:
215218
self._server_type = ("sqlite3", "3")
216219
return self._server_type

0 commit comments

Comments
 (0)