1+ """SQL completer with typed interfaces aligned to mycli."""
2+
13# mypy: ignore-errors
24
3- from __future__ import print_function
4- from __future__ import unicode_literals
5+ from __future__ import annotations
6+
57import logging
68from re import compile , escape
79from collections import Counter
10+ from typing import Any , Collection , Generator , Iterable , List , Optional , Set , Tuple , Literal
811
9- from prompt_toolkit .completion import Completer , Completion
12+ from prompt_toolkit .completion import CompleteEvent , Completer , Completion
13+ from prompt_toolkit .completion .base import Document
1014
1115from .packages .completion_engine import suggest_type
1216from .packages .parseutils import last_word
1822
1923
2024class SQLCompleter (Completer ):
21- keywords = [
25+ keywords : List [ str ] = [
2226 "ABORT" ,
2327 "ACTION" ,
2428 "ADD" ,
@@ -184,7 +188,7 @@ class SQLCompleter(Completer):
184188 "WITHOUT" ,
185189 ]
186190
187- functions = [
191+ functions : List [ str ] = [
188192 "ABS" ,
189193 "AVG" ,
190194 "CHANGES" ,
@@ -255,49 +259,49 @@ class SQLCompleter(Completer):
255259 "TRIM" ,
256260 ]
257261
258- def __init__ (self , supported_formats = (), keyword_casing = "auto" ):
262+ def __init__ (self , supported_formats : Iterable [ str ] = (), keyword_casing : Literal [ "upper" , "lower" , "auto" ] = "auto" ):
259263 super (self .__class__ , self ).__init__ ()
260- self .reserved_words = set ()
264+ self .reserved_words : Set [ str ] = set ()
261265 for x in self .keywords :
262266 self .reserved_words .update (x .split ())
263267 self .name_pattern = compile (r"^[_a-zA-Z][_a-zA-Z0-9\$]*$" )
264268
265- self .special_commands = []
266- self .table_formats = supported_formats
269+ self .special_commands : List [ str ] = []
270+ self .table_formats : List [ str ] = list ( supported_formats )
267271 if keyword_casing not in ("upper" , "lower" , "auto" ):
268272 keyword_casing = "auto"
269- self .keyword_casing = keyword_casing
273+ self .keyword_casing : Literal [ "upper" , "lower" , "auto" ] = keyword_casing
270274 self .reset_completions ()
271275
272- def escape_name (self , name ) :
276+ def escape_name (self , name : str ) -> str :
273277 if name and ((not self .name_pattern .match (name )) or (name .upper () in self .reserved_words ) or (name .upper () in self .functions )):
274278 name = "`%s`" % name
275279
276280 return name
277281
278- def unescape_name (self , name ) :
282+ def unescape_name (self , name : str ) -> str :
279283 """Unquote a string."""
280284 if name and name [0 ] == '"' and name [- 1 ] == '"' :
281285 name = name [1 :- 1 ]
282286
283287 return name
284288
285- def escaped_names (self , names ) :
289+ def escaped_names (self , names : Iterable [ str ]) -> List [ str ] :
286290 return [self .escape_name (name ) for name in names ]
287291
288- def extend_special_commands (self , special_commands ) :
292+ def extend_special_commands (self , special_commands : Iterable [ str ]) -> None :
289293 # Special commands are not part of all_completions since they can only
290294 # be at the beginning of a line.
291295 self .special_commands .extend (special_commands )
292296
293- def extend_database_names (self , databases ) :
297+ def extend_database_names (self , databases : Iterable [ str ]) -> None :
294298 self .databases .extend (databases )
295299
296- def extend_keywords (self , additional_keywords ) :
300+ def extend_keywords (self , additional_keywords : Iterable [ str ]) -> None :
297301 self .keywords .extend (additional_keywords )
298302 self .all_completions .update (additional_keywords )
299303
300- def extend_schemata (self , schema ) :
304+ def extend_schemata (self , schema : Optional [ str ]) -> None :
301305 if schema is None :
302306 return
303307 metadata = self .dbmetadata ["tables" ]
@@ -308,7 +312,7 @@ def extend_schemata(self, schema):
308312 metadata [schema ] = {}
309313 self .all_completions .update (schema )
310314
311- def extend_relations (self , data , kind ) :
315+ def extend_relations (self , data : Iterable [ Iterable [ str ]] , kind : str ) -> None :
312316 """Extend metadata for tables or views
313317
314318 :param data: list of (rel_name, ) tuples
@@ -340,7 +344,7 @@ def extend_relations(self, data, kind):
340344 )
341345 self .all_completions .add (relname [0 ])
342346
343- def extend_columns (self , column_data , kind ) :
347+ def extend_columns (self , column_data : Iterable [ Iterable [ str ]] , kind : str ) -> None :
344348 """Extend column metadata
345349
346350 :param column_data: list of (rel_name, column_name) tuples
@@ -362,7 +366,7 @@ def extend_columns(self, column_data, kind):
362366 metadata [self .dbname ][relname ].append (column )
363367 self .all_completions .add (column )
364368
365- def extend_functions (self , func_data ) :
369+ def extend_functions (self , func_data : Iterable [ Iterable [ str ]]) -> None :
366370 # 'func_data' is a generator object. It can throw an exception while
367371 # being consumed. This could happen if the user has launched the app
368372 # without specifying a database name. This exception must be handled to
@@ -381,24 +385,24 @@ def extend_functions(self, func_data):
381385 metadata [self .dbname ][func [0 ]] = None
382386 self .all_completions .add (func [0 ])
383387
384- def set_dbname (self , dbname ) :
388+ def set_dbname (self , dbname : str ) -> None :
385389 self .dbname = dbname
386390
387- def reset_completions (self ):
388- self .databases = []
391+ def reset_completions (self ) -> None :
392+ self .databases : List [ str ] = []
389393 self .dbname = ""
390- self .dbmetadata = {"tables" : {}, "views" : {}, "functions" : {}}
391- self .all_completions = set (self .keywords + self .functions )
394+ self .dbmetadata : dict [ str , Any ] = {"tables" : {}, "views" : {}, "functions" : {}}
395+ self .all_completions : Set [ str ] = set (self .keywords + self .functions )
392396
393397 @staticmethod
394398 def find_matches (
395- text ,
396- collection ,
397- start_only = False ,
398- fuzzy = True ,
399- casing = None ,
400- punctuations = "most_punctuations" ,
401- ):
399+ text : str ,
400+ collection : Collection [ str ] ,
401+ start_only : bool = False ,
402+ fuzzy : bool = True ,
403+ casing : Optional [ str ] = None ,
404+ punctuations : str = "most_punctuations" ,
405+ ) -> Generator [ Completion , None , None ] :
402406 """Find completion matches for the given text.
403407
404408 Given the user's input text and a collection of available
@@ -434,16 +438,20 @@ def find_matches(
434438 if casing == "auto" :
435439 casing = "lower" if last and last [- 1 ].islower () else "upper"
436440
437- def apply_case (kw ) :
441+ def apply_case (kw : str ) -> str :
438442 if casing == "upper" :
439443 return kw .upper ()
440444 return kw .lower ()
441445
442446 return (Completion (z if casing is None else apply_case (z ), - len (text )) for x , y , z in sorted (completions ))
443447
444- def get_completions (self , document , complete_event ):
448+ def get_completions (
449+ self ,
450+ document : Document ,
451+ complete_event : Optional [CompleteEvent ],
452+ ) -> Iterable [Completion ]:
445453 word_before_cursor = document .get_word_before_cursor (WORD = True )
446- completions = []
454+ completions : List [ Completion ] = []
447455 suggestions = suggest_type (document .text , document .text_before_cursor )
448456
449457 for suggestion in suggestions :
@@ -551,7 +559,7 @@ def get_completions(self, document, complete_event):
551559
552560 return completions
553561
554- def find_files (self , word ) :
562+ def find_files (self , word : str ) -> Generator [ Completion , None , None ] :
555563 """Yield matching directory or file names.
556564
557565 :param word:
@@ -565,7 +573,7 @@ def find_files(self, word):
565573 if suggestion :
566574 yield Completion (suggestion , position )
567575
568- def populate_scoped_cols (self , scoped_tbls ) :
576+ def populate_scoped_cols (self , scoped_tbls : List [ Tuple [ Optional [ str ], str , Optional [ str ]]]) -> List [ str ] :
569577 """Find all columns in a set of scoped_tables
570578 :param scoped_tbls: list of (schema, table, alias) tuples
571579 :return: list of column names
@@ -603,7 +611,7 @@ def populate_scoped_cols(self, scoped_tbls):
603611
604612 return columns
605613
606- def populate_schema_objects (self , schema , obj_type ) :
614+ def populate_schema_objects (self , schema : Optional [ str ] , obj_type : str ) -> List [ str ] :
607615 """Returns list of tables or functions for a (optional) schema"""
608616 metadata = self .dbmetadata [obj_type ]
609617 schema = schema or self .dbname
0 commit comments