|
45 | 45 | from prompt_toolkit.output import ColorDepth |
46 | 46 | from prompt_toolkit.shortcuts import CompleteStyle, PromptSession |
47 | 47 | import pymysql |
48 | | -from pymysql.constants.ER import HANDSHAKE_ERROR |
| 48 | +from pymysql.constants.ER import ACCESS_DENIED_ERROR, HANDSHAKE_ERROR |
49 | 49 | from pymysql.cursors import Cursor |
50 | 50 | import sqlglot |
51 | 51 | import sqlparse |
@@ -660,63 +660,63 @@ def connect( |
660 | 660 | passwd = keyring.get_password(keychain_domain, keychain_identifier) |
661 | 661 | keychain_retrieved = True |
662 | 662 |
|
663 | | - # if no password was found from all of the above sources, ask for a password |
664 | | - if passwd is None or passwd == "MYCLI_ASK_PASSWORD": |
| 663 | + # prompt for password if requested by user |
| 664 | + if passwd == "MYCLI_ASK_PASSWORD": |
665 | 665 | passwd = click.prompt(f"Enter password for {user}", hide_input=True, show_default=False, default='', type=str, err=True) |
666 | 666 |
|
667 | | - if reset_keyring or (use_keyring and not keychain_retrieved): |
668 | | - try: |
669 | | - saved_pw = keyring.get_password(keychain_domain, keychain_identifier) |
670 | | - if passwd != saved_pw or reset_keyring: |
671 | | - keyring.set_password(keychain_domain, keychain_identifier, passwd) |
672 | | - click.secho('Password saved to the system keyring', err=True) |
673 | | - except Exception as e: |
674 | | - click.secho(f'Password not saved to the system keyring: {e}', err=True, fg='red') |
| 667 | + connection_info: dict[Any, Any] = { |
| 668 | + "database": database, |
| 669 | + "user": user, |
| 670 | + "password": passwd, |
| 671 | + "host": host, |
| 672 | + "port": int_port, |
| 673 | + "socket": socket, |
| 674 | + "character_set": character_set, |
| 675 | + "local_infile": use_local_infile, |
| 676 | + "ssl": ssl_config_or_none, |
| 677 | + "ssh_user": ssh_user, |
| 678 | + "ssh_host": ssh_host, |
| 679 | + "ssh_port": int(ssh_port) if ssh_port else None, |
| 680 | + "ssh_password": ssh_password, |
| 681 | + "ssh_key_filename": ssh_key_filename, |
| 682 | + "init_command": init_command, |
| 683 | + "unbuffered": unbuffered, |
| 684 | + } |
675 | 685 |
|
676 | | - # Connect to the database. |
677 | | - def _connect() -> None: |
| 686 | + def _update_keyring(password: str | None): |
| 687 | + if not password: |
| 688 | + return |
| 689 | + if reset_keyring or (use_keyring and not keychain_retrieved): |
| 690 | + try: |
| 691 | + saved_pw = keyring.get_password(keychain_domain, keychain_identifier) |
| 692 | + if password != saved_pw or reset_keyring: |
| 693 | + keyring.set_password(keychain_domain, keychain_identifier, password) |
| 694 | + click.secho('Password saved to the system keyring', err=True) |
| 695 | + except Exception as e: |
| 696 | + click.secho(f'Password not saved to the system keyring: {e}', err=True, fg='red') |
| 697 | + |
| 698 | + def _connect(retry_ssl: bool = False, retry_password: bool = False) -> None: |
678 | 699 | try: |
679 | | - self.sqlexecute = SQLExecute( |
680 | | - database, |
681 | | - user, |
682 | | - passwd, |
683 | | - host, |
684 | | - int_port, |
685 | | - socket, |
686 | | - character_set, |
687 | | - use_local_infile, |
688 | | - ssl_config_or_none, |
689 | | - ssh_user, |
690 | | - ssh_host, |
691 | | - int(ssh_port) if ssh_port else None, |
692 | | - ssh_password, |
693 | | - ssh_key_filename, |
694 | | - init_command, |
695 | | - unbuffered, |
696 | | - ) |
| 700 | + _update_keyring(connection_info["password"]) |
| 701 | + self.sqlexecute = SQLExecute(**connection_info) |
697 | 702 | except pymysql.OperationalError as e1: |
698 | 703 | if e1.args[0] == HANDSHAKE_ERROR and ssl is not None and ssl.get("mode", None) == "auto": |
699 | | - try: |
700 | | - self.sqlexecute = SQLExecute( |
701 | | - database, |
702 | | - user, |
703 | | - passwd, |
704 | | - host, |
705 | | - int_port, |
706 | | - socket, |
707 | | - character_set, |
708 | | - use_local_infile, |
709 | | - None, |
710 | | - ssh_user, |
711 | | - ssh_host, |
712 | | - int(ssh_port) if ssh_port else None, |
713 | | - ssh_password, |
714 | | - ssh_key_filename, |
715 | | - init_command, |
716 | | - unbuffered, |
717 | | - ) |
718 | | - except Exception as e2: |
719 | | - raise e2 |
| 704 | + # if we already tried and failed to connect without SSL, raise the error |
| 705 | + if retry_ssl: |
| 706 | + raise e1 |
| 707 | + # disable SSL and try to connect again |
| 708 | + connection_info["ssl"] = None |
| 709 | + _connect(retry_ssl=True) |
| 710 | + elif e1.args[0] == ACCESS_DENIED_ERROR and connection_info["password"] is None: |
| 711 | + # if we already tried and failed to connect with a new password, raise the error |
| 712 | + if retry_password: |
| 713 | + raise e1 |
| 714 | + # ask the user for a new password and try to connect again |
| 715 | + new_password = click.prompt( |
| 716 | + f"Enter password for {user}", hide_input=True, show_default=False, default='', type=str, err=True |
| 717 | + ) |
| 718 | + connection_info["password"] = new_password |
| 719 | + _connect(retry_password=True) |
720 | 720 | else: |
721 | 721 | raise e1 |
722 | 722 |
|
|
0 commit comments