Skip to content

Commit 80c6873

Browse files
committed
Type hints for sqlcompleter.
1 parent ad1d5bf commit 80c6873

1 file changed

Lines changed: 46 additions & 38 deletions

File tree

litecli/sqlcompleter.py

Lines changed: 46 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,16 @@
1+
"""SQL completer with typed interfaces aligned to mycli."""
2+
13
# mypy: ignore-errors
24

3-
from __future__ import print_function
4-
from __future__ import unicode_literals
5+
from __future__ import annotations
6+
57
import logging
68
from re import compile, escape
79
from collections import Counter
10+
from typing import Any, Collection, Generator, Iterable, List, Optional, Set, Tuple, Literal
811

9-
from prompt_toolkit.completion import Completer, Completion
12+
from prompt_toolkit.completion import CompleteEvent, Completer, Completion
13+
from prompt_toolkit.completion.base import Document
1014

1115
from .packages.completion_engine import suggest_type
1216
from .packages.parseutils import last_word
@@ -18,7 +22,7 @@
1822

1923

2024
class SQLCompleter(Completer):
21-
keywords = [
25+
keywords: List[str] = [
2226
"ABORT",
2327
"ACTION",
2428
"ADD",
@@ -184,7 +188,7 @@ class SQLCompleter(Completer):
184188
"WITHOUT",
185189
]
186190

187-
functions = [
191+
functions: List[str] = [
188192
"ABS",
189193
"AVG",
190194
"CHANGES",
@@ -255,49 +259,49 @@ class SQLCompleter(Completer):
255259
"TRIM",
256260
]
257261

258-
def __init__(self, supported_formats=(), keyword_casing="auto"):
262+
def __init__(self, supported_formats: Iterable[str] = (), keyword_casing: Literal["upper", "lower", "auto"] = "auto"):
259263
super(self.__class__, self).__init__()
260-
self.reserved_words = set()
264+
self.reserved_words: Set[str] = set()
261265
for x in self.keywords:
262266
self.reserved_words.update(x.split())
263267
self.name_pattern = compile(r"^[_a-zA-Z][_a-zA-Z0-9\$]*$")
264268

265-
self.special_commands = []
266-
self.table_formats = supported_formats
269+
self.special_commands: List[str] = []
270+
self.table_formats: List[str] = list(supported_formats)
267271
if keyword_casing not in ("upper", "lower", "auto"):
268272
keyword_casing = "auto"
269-
self.keyword_casing = keyword_casing
273+
self.keyword_casing: Literal["upper", "lower", "auto"] = keyword_casing
270274
self.reset_completions()
271275

272-
def escape_name(self, name):
276+
def escape_name(self, name: str) -> str:
273277
if name and ((not self.name_pattern.match(name)) or (name.upper() in self.reserved_words) or (name.upper() in self.functions)):
274278
name = "`%s`" % name
275279

276280
return name
277281

278-
def unescape_name(self, name):
282+
def unescape_name(self, name: str) -> str:
279283
"""Unquote a string."""
280284
if name and name[0] == '"' and name[-1] == '"':
281285
name = name[1:-1]
282286

283287
return name
284288

285-
def escaped_names(self, names):
289+
def escaped_names(self, names: Iterable[str]) -> List[str]:
286290
return [self.escape_name(name) for name in names]
287291

288-
def extend_special_commands(self, special_commands):
292+
def extend_special_commands(self, special_commands: Iterable[str]) -> None:
289293
# Special commands are not part of all_completions since they can only
290294
# be at the beginning of a line.
291295
self.special_commands.extend(special_commands)
292296

293-
def extend_database_names(self, databases):
297+
def extend_database_names(self, databases: Iterable[str]) -> None:
294298
self.databases.extend(databases)
295299

296-
def extend_keywords(self, additional_keywords):
300+
def extend_keywords(self, additional_keywords: Iterable[str]) -> None:
297301
self.keywords.extend(additional_keywords)
298302
self.all_completions.update(additional_keywords)
299303

300-
def extend_schemata(self, schema):
304+
def extend_schemata(self, schema: Optional[str]) -> None:
301305
if schema is None:
302306
return
303307
metadata = self.dbmetadata["tables"]
@@ -308,7 +312,7 @@ def extend_schemata(self, schema):
308312
metadata[schema] = {}
309313
self.all_completions.update(schema)
310314

311-
def extend_relations(self, data, kind):
315+
def extend_relations(self, data: Iterable[Iterable[str]], kind: str) -> None:
312316
"""Extend metadata for tables or views
313317
314318
:param data: list of (rel_name, ) tuples
@@ -340,7 +344,7 @@ def extend_relations(self, data, kind):
340344
)
341345
self.all_completions.add(relname[0])
342346

343-
def extend_columns(self, column_data, kind):
347+
def extend_columns(self, column_data: Iterable[Iterable[str]], kind: str) -> None:
344348
"""Extend column metadata
345349
346350
:param column_data: list of (rel_name, column_name) tuples
@@ -362,7 +366,7 @@ def extend_columns(self, column_data, kind):
362366
metadata[self.dbname][relname].append(column)
363367
self.all_completions.add(column)
364368

365-
def extend_functions(self, func_data):
369+
def extend_functions(self, func_data: Iterable[Iterable[str]]) -> None:
366370
# 'func_data' is a generator object. It can throw an exception while
367371
# being consumed. This could happen if the user has launched the app
368372
# without specifying a database name. This exception must be handled to
@@ -381,24 +385,24 @@ def extend_functions(self, func_data):
381385
metadata[self.dbname][func[0]] = None
382386
self.all_completions.add(func[0])
383387

384-
def set_dbname(self, dbname):
388+
def set_dbname(self, dbname: str) -> None:
385389
self.dbname = dbname
386390

387-
def reset_completions(self):
388-
self.databases = []
391+
def reset_completions(self) -> None:
392+
self.databases: List[str] = []
389393
self.dbname = ""
390-
self.dbmetadata = {"tables": {}, "views": {}, "functions": {}}
391-
self.all_completions = set(self.keywords + self.functions)
394+
self.dbmetadata: dict[str, Any] = {"tables": {}, "views": {}, "functions": {}}
395+
self.all_completions: Set[str] = set(self.keywords + self.functions)
392396

393397
@staticmethod
394398
def find_matches(
395-
text,
396-
collection,
397-
start_only=False,
398-
fuzzy=True,
399-
casing=None,
400-
punctuations="most_punctuations",
401-
):
399+
text: str,
400+
collection: Collection[str],
401+
start_only: bool = False,
402+
fuzzy: bool = True,
403+
casing: Optional[str] = None,
404+
punctuations: str = "most_punctuations",
405+
) -> Generator[Completion, None, None]:
402406
"""Find completion matches for the given text.
403407
404408
Given the user's input text and a collection of available
@@ -434,16 +438,20 @@ def find_matches(
434438
if casing == "auto":
435439
casing = "lower" if last and last[-1].islower() else "upper"
436440

437-
def apply_case(kw):
441+
def apply_case(kw: str) -> str:
438442
if casing == "upper":
439443
return kw.upper()
440444
return kw.lower()
441445

442446
return (Completion(z if casing is None else apply_case(z), -len(text)) for x, y, z in sorted(completions))
443447

444-
def get_completions(self, document, complete_event):
448+
def get_completions(
449+
self,
450+
document: Document,
451+
complete_event: Optional[CompleteEvent],
452+
) -> Iterable[Completion]:
445453
word_before_cursor = document.get_word_before_cursor(WORD=True)
446-
completions = []
454+
completions: List[Completion] = []
447455
suggestions = suggest_type(document.text, document.text_before_cursor)
448456

449457
for suggestion in suggestions:
@@ -551,7 +559,7 @@ def get_completions(self, document, complete_event):
551559

552560
return completions
553561

554-
def find_files(self, word):
562+
def find_files(self, word: str) -> Generator[Completion, None, None]:
555563
"""Yield matching directory or file names.
556564
557565
:param word:
@@ -565,7 +573,7 @@ def find_files(self, word):
565573
if suggestion:
566574
yield Completion(suggestion, position)
567575

568-
def populate_scoped_cols(self, scoped_tbls):
576+
def populate_scoped_cols(self, scoped_tbls: List[Tuple[Optional[str], str, Optional[str]]]) -> List[str]:
569577
"""Find all columns in a set of scoped_tables
570578
:param scoped_tbls: list of (schema, table, alias) tuples
571579
:return: list of column names
@@ -603,7 +611,7 @@ def populate_scoped_cols(self, scoped_tbls):
603611

604612
return columns
605613

606-
def populate_schema_objects(self, schema, obj_type):
614+
def populate_schema_objects(self, schema: Optional[str], obj_type: str) -> List[str]:
607615
"""Returns list of tables or functions for a (optional) schema"""
608616
metadata = self.dbmetadata[obj_type]
609617
schema = schema or self.dbname

0 commit comments

Comments
 (0)