11# mypy: ignore-errors
22
3- from __future__ import print_function , unicode_literals
3+ from __future__ import annotations
44
55import itertools
66import logging
1919except ImportError :
2020 from sqlite3 import OperationalError , sqlite_version
2121from time import time
22+ from typing import Any , Iterable , Optional , Dict
2223
2324import click
2425import sqlparse
3637)
3738from prompt_toolkit .lexers import PygmentsLexer
3839from prompt_toolkit .shortcuts import CompleteStyle , PromptSession
40+ from prompt_toolkit .completion import Completion
3941
4042from .__init__ import __version__
4143from .clibuffer import cli_is_multiline
@@ -66,13 +68,13 @@ class LiteCli(object):
6668
6769 def __init__ (
6870 self ,
69- sqlexecute = None ,
70- prompt = None ,
71- logfile = None ,
72- auto_vertical_output = False ,
73- warn = None ,
74- liteclirc = None ,
75- ):
71+ sqlexecute : Optional [ SQLExecute ] = None ,
72+ prompt : Optional [ str ] = None ,
73+ logfile : Optional [ Any ] = None ,
74+ auto_vertical_output : bool = False ,
75+ warn : Optional [ bool ] = None ,
76+ liteclirc : Optional [ str ] = None ,
77+ ) -> None :
7678 self .sqlexecute = sqlexecute
7779 self .logfile = logfile
7880
@@ -139,7 +141,7 @@ def __init__(
139141
140142 self .prompt_app = None
141143
142- def register_special_commands (self ):
144+ def register_special_commands (self ) -> None :
143145 special .register_special_command (
144146 self .change_db ,
145147 ".open" ,
@@ -180,7 +182,7 @@ def register_special_commands(self):
180182 case_sensitive = True ,
181183 )
182184
183- def change_table_format (self , arg , ** _ ):
185+ def change_table_format (self , arg : str , ** _ : Any ):
184186 try :
185187 self .formatter .format_name = arg
186188 yield (None , None , None , "Changed table format to {}" .format (arg ))
@@ -190,7 +192,7 @@ def change_table_format(self, arg, **_):
190192 msg += "\n \t {}" .format (table_type )
191193 yield (None , None , None , msg )
192194
193- def change_db (self , arg , ** _ ):
195+ def change_db (self , arg : Optional [ str ] , ** _ : Any ):
194196 if arg is None :
195197 self .sqlexecute .connect ()
196198 else :
@@ -204,7 +206,7 @@ def change_db(self, arg, **_):
204206 'You are now connected to database "%s"' % (self .sqlexecute .dbname ),
205207 )
206208
207- def execute_from_file (self , arg , ** _ ):
209+ def execute_from_file (self , arg : Optional [ str ] , ** _ : Any ):
208210 if not arg :
209211 message = "Missing required argument, filename."
210212 return [(None , None , None , message )]
@@ -220,7 +222,7 @@ def execute_from_file(self, arg, **_):
220222
221223 return self .sqlexecute .run (query )
222224
223- def change_prompt_format (self , arg , ** _ ):
225+ def change_prompt_format (self , arg : Optional [ str ] , ** _ : Any ):
224226 """
225227 Change the prompt format.
226228 """
@@ -231,7 +233,7 @@ def change_prompt_format(self, arg, **_):
231233 self .prompt_format = self .get_prompt (arg )
232234 return [(None , None , None , "Changed prompt format to %s" % arg )]
233235
234- def initialize_logging (self ):
236+ def initialize_logging (self ) -> None :
235237 log_file = self .config ["main" ]["log_file" ]
236238 if log_file == "default" :
237239 log_file = config_location () + "log"
@@ -279,7 +281,7 @@ def initialize_logging(self):
279281 root_logger .debug ("Initializing litecli logging." )
280282 root_logger .debug ("Log file %r." , log_file )
281283
282- def read_my_cnf_files (self , keys ) :
284+ def read_my_cnf_files (self , keys : Iterable [ str ]) -> Dict [ str , Optional [ str ]] :
283285 """
284286 Reads a list of config files and merges them. The last one will win.
285287 :param files: list of files to read
@@ -299,7 +301,7 @@ def get(key):
299301
300302 return {x : get (x ) for x in keys }
301303
302- def connect (self , database = "" ):
304+ def connect (self , database : str = "" ) -> None :
303305 cnf = {"database" : None }
304306
305307 cnf = self .read_my_cnf_files (cnf .keys ())
@@ -321,7 +323,7 @@ def _connect():
321323 self .echo (str (e ), err = True , fg = "red" )
322324 exit (1 )
323325
324- def handle_editor_command (self , text ) :
326+ def handle_editor_command (self , text : str ) -> str :
325327 R"""Editor command is any query that is prefixed or suffixed by a '\e'.
326328 The reason for a while loop is because a user might edit a query
327329 multiple times. For eg:
@@ -351,7 +353,7 @@ def handle_editor_command(self, text):
351353 continue
352354 return text
353355
354- def run_cli (self ):
356+ def run_cli (self ) -> None :
355357 iterations = 0
356358 sqlexecute = self .sqlexecute
357359 logger = self .logger
@@ -645,12 +647,12 @@ def startup_commands():
645647 if not self .less_chatty :
646648 self .echo ("Goodbye!" )
647649
648- def log_output (self , output ) :
650+ def log_output (self , output : str ) -> None :
649651 """Log the output in the audit log, if it's enabled."""
650652 if self .logfile :
651653 click .echo (output , file = self .logfile )
652654
653- def echo (self , s , ** kwargs ) :
655+ def echo (self , s : str , ** kwargs : Any ) -> None :
654656 """Print a message to stdout.
655657
656658 The message will be logged in the audit log, if enabled.
@@ -661,7 +663,7 @@ def echo(self, s, **kwargs):
661663 self .log_output (s )
662664 click .secho (s , ** kwargs )
663665
664- def get_output_margin (self , status = None ):
666+ def get_output_margin (self , status : Optional [ str ] = None ) -> int :
665667 """Get the output margin (number of rows for the prompt, footer and
666668 timing message."""
667669 margin = self .get_reserved_space () + self .get_prompt (self .prompt_format ).count ("\n " ) + 2
@@ -670,7 +672,7 @@ def get_output_margin(self, status=None):
670672
671673 return margin
672674
673- def output (self , output , status = None ):
675+ def output (self , output : Iterable [ str ] , status : Optional [ str ] = None ) -> None :
674676 """Output text to stdout or a pager command.
675677
676678 The status text is not outputted to pager or files.
@@ -723,7 +725,7 @@ def output(self, output, status=None):
723725 self .log_output (status )
724726 click .secho (status )
725727
726- def configure_pager (self ):
728+ def configure_pager (self ) -> None :
727729 # Provide sane defaults for less if they are empty.
728730 if not os .environ .get ("LESS" ):
729731 os .environ ["LESS" ] = "-RXF"
@@ -738,7 +740,7 @@ def configure_pager(self):
738740 if cnf ["skip-pager" ] or not self .config ["main" ].as_bool ("enable_pager" ):
739741 special .disable_pager ()
740742
741- def refresh_completions (self , reset = False ):
743+ def refresh_completions (self , reset : bool = False ):
742744 if reset :
743745 with self ._completer_lock :
744746 self .completer .reset_completions ()
@@ -753,7 +755,7 @@ def refresh_completions(self, reset=False):
753755
754756 return [(None , None , None , "Auto-completion refresh started in the background." )]
755757
756- def _on_completions_refreshed (self , new_completer ) :
758+ def _on_completions_refreshed (self , new_completer : SQLCompleter ) -> None :
757759 """Swap the completer object in cli with the newly created completer."""
758760 with self ._completer_lock :
759761 self .completer = new_completer
@@ -763,11 +765,11 @@ def _on_completions_refreshed(self, new_completer):
763765 # "Refreshing completions..." indicator
764766 self .prompt_app .app .invalidate ()
765767
766- def get_completions (self , text , cursor_positition ) :
768+ def get_completions (self , text : str , cursor_positition : int ) -> Iterable [ Completion ] :
767769 with self ._completer_lock :
768770 return self .completer .get_completions (Document (text = text , cursor_position = cursor_positition ), None )
769771
770- def get_prompt (self , string ) :
772+ def get_prompt (self , string : str ) -> str :
771773 self .logger .debug ("Getting prompt %r" , string )
772774 sqlexecute = self .sqlexecute
773775 now = datetime .now ()
@@ -795,7 +797,7 @@ def replacer(match):
795797 # Perform the substitution
796798 return pattern .sub (replacer , string )
797799
798- def run_query (self , query , new_line = True ):
800+ def run_query (self , query : str , new_line : bool = True ):
799801 """Runs *query*."""
800802 results = self .sqlexecute .run (query )
801803 for result in results :
0 commit comments