Skip to content

Commit b15fc19

Browse files
authored
Merge pull request #1829 from scottnemes/feat/440/add-sandbox-mode
Feat/440/add sandbox mode
1 parent 9341037 commit b15fc19

9 files changed

Lines changed: 393 additions & 33 deletions

File tree

changelog.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ Features
77
* Make `--progress` and `--checkpoint` strictly by statement.
88
* Allow more characters in passwords read from a file.
99
* Show sponsors and contributors separately in startup messages.
10+
* Add support for expired password (sandbox) mode (#440)
1011

1112

1213
Bug Fixes

mycli/constants.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,3 +13,7 @@
1313

1414
DEFAULT_WIDTH = 80
1515
DEFAULT_HEIGHT = 25
16+
17+
# MySQL error codes not available in pymysql.constants.ER
18+
ER_MUST_CHANGE_PASSWORD_LOGIN = 1862
19+
ER_MUST_CHANGE_PASSWORD = 1820

mycli/main.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,7 @@
5858
DEFAULT_HOST,
5959
DEFAULT_PORT,
6060
DEFAULT_WIDTH,
61+
ER_MUST_CHANGE_PASSWORD_LOGIN,
6162
ISSUES_URL,
6263
REPO_URL,
6364
)
@@ -152,6 +153,7 @@ def __init__(
152153
self.prompt_session: PromptSession | None = None
153154
self._keepalive_counter = 0
154155
self.keepalive_ticks: int | None = 0
156+
self.sandbox_mode: bool = False
155157

156158
# self.cnf_files is a class variable that stores the list of mysql
157159
# config files to read in at launch.
@@ -750,6 +752,13 @@ def _connect(
750752
keyring_retrieved_cleanly=keyring_retrieved_cleanly,
751753
keyring_save_eligible=keyring_save_eligible,
752754
)
755+
elif e1.args[0] == ER_MUST_CHANGE_PASSWORD_LOGIN:
756+
self.echo(
757+
"Your password has expired and the server rejected the connection.",
758+
err=True,
759+
fg='red',
760+
)
761+
raise e1
753762
elif e1.args[0] == CR_SERVER_LOST:
754763
self.echo(
755764
(
@@ -803,6 +812,15 @@ def _connect(
803812
sys.exit(1)
804813

805814
_connect(keyring_retrieved_cleanly=keyring_retrieved_cleanly)
815+
816+
# Check if SQLExecute detected sandbox mode during connection
817+
if self.sqlexecute and self.sqlexecute.sandbox_mode:
818+
self.sandbox_mode = True
819+
self.echo(
820+
"Your password has expired. Use ALTER USER or SET PASSWSORD to set a new password, or quit.",
821+
err=True,
822+
fg='yellow',
823+
)
806824
except Exception as e: # Connecting to a database could fail.
807825
self.logger.debug("Database connection failed: %r.", e)
808826
self.logger.error("traceback: %r", traceback.format_exc())

mycli/main_modes/repl.py

Lines changed: 46 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@
3939
from mycli.constants import (
4040
DEFAULT_HOST,
4141
DEFAULT_WIDTH,
42+
ER_MUST_CHANGE_PASSWORD,
4243
HOME_URL,
4344
ISSUES_URL,
4445
)
@@ -55,8 +56,11 @@
5556
from mycli.packages.ptoolkit.history import FileHistoryWithTimestamp
5657
from mycli.packages.special.utils import format_uptime, get_ssl_version, get_uptime, get_warning_count
5758
from mycli.packages.sql_utils import (
59+
extract_new_password,
5860
is_dropping_database,
5961
is_mutating,
62+
is_password_change,
63+
is_sandbox_allowed,
6064
is_select,
6165
need_completion_refresh,
6266
need_completion_reset,
@@ -132,7 +136,8 @@ def _show_startup_banner(
132136
if mycli.less_chatty:
133137
return
134138

135-
print(sqlexecute.server_info)
139+
if sqlexecute.server_info is not None:
140+
print(sqlexecute.server_info)
136141
print('mycli', mycli_package.__version__)
137142
print(SUPPORT_INFO)
138143
if random.random() <= 0.25:
@@ -232,8 +237,6 @@ def get_prompt(
232237
) -> str:
233238
sqlexecute = mycli.sqlexecute
234239
assert sqlexecute is not None
235-
assert sqlexecute.server_info is not None
236-
assert sqlexecute.server_info.species is not None
237240
if mycli.login_path and mycli.login_path_as_host:
238241
prompt_host = mycli.login_path
239242
elif sqlexecute.host is not None:
@@ -250,7 +253,8 @@ def get_prompt(
250253
string = string.replace('\\h', prompt_host or '(none)')
251254
string = string.replace('\\H', short_prompt_host or '(none)')
252255
string = string.replace('\\d', sqlexecute.dbname or '(none)')
253-
string = string.replace('\\t', sqlexecute.server_info.species.name)
256+
species_name = sqlexecute.server_info.species.name if sqlexecute.server_info and sqlexecute.server_info.species else 'MySQL'
257+
string = string.replace('\\t', species_name)
254258
string = string.replace('\\n', '\n')
255259
string = string.replace('\\D', now.strftime('%a %b %d %H:%M:%S %Y'))
256260
string = string.replace('\\m', now.strftime('%M'))
@@ -615,6 +619,14 @@ def _one_iteration(
615619
mycli.echo(str(e), err=True, fg='red')
616620
return
617621

622+
if mycli.sandbox_mode and not is_sandbox_allowed(text):
623+
mycli.echo(
624+
"ERROR 1820: You must reset your password using ALTER USER or SET PASSWORD before executing this statement.",
625+
err=True,
626+
fg='red',
627+
)
628+
return
629+
618630
if mycli.destructive_warning:
619631
destroy = confirm_destructive_query(mycli.destructive_keywords, text)
620632
if destroy is None:
@@ -674,20 +686,44 @@ def _one_iteration(
674686
mycli.echo('Not Yet Implemented.', fg='yellow')
675687
except pymysql.OperationalError as e1:
676688
mycli.logger.debug('Exception: %r', e1)
677-
if e1.args[0] in (2003, 2006, 2013):
689+
if e1.args[0] == ER_MUST_CHANGE_PASSWORD:
690+
mycli.sandbox_mode = True
691+
mycli.echo(
692+
"ERROR 1820: You must reset your password using ALTER USER or SET PASSWORD before executing this statement.",
693+
err=True,
694+
fg='red',
695+
)
696+
elif e1.args[0] in (2003, 2006, 2013):
678697
if not mycli.reconnect():
679698
return
680699
_one_iteration(mycli, state, text)
681700
return
682-
683-
mycli.logger.error('sql: %r, error: %r', text, e1)
684-
mycli.logger.error('traceback: %r', traceback.format_exc())
685-
mycli.echo(str(e1), err=True, fg='red')
701+
else:
702+
mycli.logger.error('sql: %r, error: %r', text, e1)
703+
mycli.logger.error('traceback: %r', traceback.format_exc())
704+
mycli.echo(str(e1), err=True, fg='red')
686705
except Exception as e:
687706
mycli.logger.error('sql: %r, error: %r', text, e)
688707
mycli.logger.error('traceback: %r', traceback.format_exc())
689708
mycli.echo(str(e), err=True, fg='red')
690709
else:
710+
if mycli.sandbox_mode and is_password_change(text):
711+
new_password = extract_new_password(text)
712+
if new_password is not None:
713+
sqlexecute.password = new_password
714+
try:
715+
sqlexecute.connect()
716+
mycli.sandbox_mode = False
717+
mycli.echo("Password changed successfully. Reconnected.", err=True, fg='green')
718+
mycli.refresh_completions()
719+
except Exception as e:
720+
mycli.sandbox_mode = False
721+
mycli.echo(
722+
f"Password changed but reconnection failed: {e}\nPlease restart mycli with your new password.",
723+
err=True,
724+
fg='yellow',
725+
)
726+
691727
if is_dropping_database(text, sqlexecute.dbname):
692728
sqlexecute.dbname = None
693729
sqlexecute.connect()
@@ -756,7 +792,7 @@ def main_repl(mycli: 'MyCli') -> None:
756792
state = ReplState()
757793

758794
mycli.configure_pager()
759-
if mycli.smart_completion:
795+
if mycli.smart_completion and not mycli.sandbox_mode:
760796
mycli.refresh_completions()
761797

762798
history = _create_history(mycli)

mycli/packages/sql_utils.py

Lines changed: 83 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
from typing import Any, Generator, Literal
55

66
import sqlglot
7+
import sqlglot.tokens
78
import sqlparse
89
from sqlparse.sql import Function, Identifier, IdentifierList, Token, TokenList
910
from sqlparse.tokens import DML, Keyword, Punctuation
@@ -469,3 +470,85 @@ def is_select(status_plain: str | None) -> bool:
469470
if not status_plain:
470471
return False
471472
return status_plain.split(None, 1)[0].lower() == "select"
473+
474+
475+
def classify_sandbox_statement(text: str) -> tuple[str | None, str | None]:
476+
"""Classify a SQL statement for sandbox mode and extract the new password.
477+
478+
Returns (statement_type, new_password) where statement_type is one of:
479+
- 'alter_user' — ALTER USER ... IDENTIFIED BY ...
480+
- 'set_password' — SET PASSWORD [FOR ...] = ...
481+
- 'quit' — quit, exit, \\q
482+
- None — not allowed in sandbox mode
483+
"""
484+
stripped = text.strip()
485+
if not stripped:
486+
return ('quit', None)
487+
488+
tokens = list(sqlglot.tokenize(stripped, dialect='mysql'))
489+
if not tokens:
490+
return ('quit', None)
491+
492+
types = [t.token_type for t in tokens]
493+
texts = [t.text.upper() for t in tokens]
494+
tt = sqlglot.tokens.TokenType
495+
496+
# quit, exit
497+
if len(tokens) == 1 and types[0] == tt.VAR and texts[0] in ('QUIT', 'EXIT'):
498+
return ('quit', None)
499+
500+
# \q
501+
if len(tokens) == 2 and types[0] == tt.BACKSLASH and texts[1] == 'Q':
502+
return ('quit', None)
503+
504+
# ALTER USER ...
505+
if len(tokens) >= 2 and types[0] == tt.ALTER and texts[1] == 'USER':
506+
pw = _find_password_after_by(tokens)
507+
return ('alter_user', pw)
508+
509+
# SET PASSWORD ...
510+
if len(tokens) >= 2 and types[0] == tt.SET and texts[1] == 'PASSWORD':
511+
pw = _find_password_after_eq(tokens)
512+
return ('set_password', pw)
513+
514+
return (None, None)
515+
516+
517+
def _find_password_after_by(tokens: list[sqlglot.tokens.Token]) -> str | None:
518+
"""Find a password literal following a BY token (for ALTER USER ... IDENTIFIED BY 'pw')."""
519+
tt = sqlglot.tokens.TokenType
520+
for i, tok in enumerate(tokens):
521+
if tok.token_type == tt.VAR and tok.text.upper() == 'BY' and i + 1 < len(tokens):
522+
next_tok = tokens[i + 1]
523+
if next_tok.token_type == tt.STRING:
524+
return next_tok.text
525+
return None
526+
527+
528+
def _find_password_after_eq(tokens: list[sqlglot.tokens.Token]) -> str | None:
529+
"""Find a password literal following an = token (for SET PASSWORD = 'pw')."""
530+
tt = sqlglot.tokens.TokenType
531+
for i, tok in enumerate(tokens):
532+
if tok.token_type == tt.EQ and i + 1 < len(tokens):
533+
next_tok = tokens[i + 1]
534+
if next_tok.token_type == tt.STRING:
535+
return next_tok.text
536+
return None
537+
538+
539+
def is_sandbox_allowed(text: str) -> bool:
540+
"""Return True if the command is allowed in expired-password sandbox mode."""
541+
stmt_type, _ = classify_sandbox_statement(text)
542+
return stmt_type is not None
543+
544+
545+
def is_password_change(text: str) -> bool:
546+
"""Return True if the command is a password change statement."""
547+
stmt_type, _ = classify_sandbox_statement(text)
548+
return stmt_type in ('alter_user', 'set_password')
549+
550+
551+
def extract_new_password(text: str) -> str | None:
552+
"""Extract the new password from an ALTER USER or SET PASSWORD statement."""
553+
_, password = classify_sandbox_statement(text)
554+
return password

mycli/sqlexecute.py

Lines changed: 61 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
from pymysql.converters import conversions, convert_date, convert_datetime, convert_time, decoders
1515
from pymysql.cursors import Cursor
1616

17+
from mycli.constants import ER_MUST_CHANGE_PASSWORD
1718
from mycli.packages.special import iocommands
1819
from mycli.packages.special.main import CommandNotFound, execute
1920
from mycli.packages.sqlresult import SQLResult
@@ -280,32 +281,50 @@ def connect(
280281
client_flag = pymysql.constants.CLIENT.INTERACTIVE
281282
if init_command and len(list(iocommands.split_queries(init_command))) > 1:
282283
client_flag |= pymysql.constants.CLIENT.MULTI_STATEMENTS
284+
client_flag |= pymysql.constants.CLIENT.HANDLE_EXPIRED_PASSWORDS
283285

284286
ssl_context = None
285287
if ssl:
286288
ssl_context = self._create_ssl_ctx(ssl)
287289

288-
conn = pymysql.connect(
289-
database=db,
290-
user=user,
291-
password=password or '',
292-
host=host,
293-
port=port or 0,
294-
unix_socket=socket,
295-
use_unicode=True,
296-
charset=character_set or '',
297-
autocommit=True,
298-
client_flag=client_flag,
299-
local_infile=local_infile or False,
300-
conv=conv,
301-
ssl=ssl_context, # type: ignore[arg-type]
302-
program_name="mycli",
303-
defer_connect=defer_connect,
304-
init_command=init_command or None,
305-
cursorclass=pymysql.cursors.SSCursor if unbuffered else pymysql.cursors.Cursor,
306-
) # type: ignore[misc]
290+
connect_kwargs: dict[str, Any] = {
291+
"database": db,
292+
"user": user,
293+
"password": password or '',
294+
"host": host,
295+
"port": port or 0,
296+
"unix_socket": socket,
297+
"use_unicode": True,
298+
"charset": character_set or '',
299+
"autocommit": True,
300+
"client_flag": client_flag,
301+
"local_infile": local_infile or False,
302+
"conv": conv,
303+
"ssl": ssl_context, # type: ignore[arg-type]
304+
"program_name": "mycli",
305+
"defer_connect": defer_connect,
306+
"init_command": init_command or None,
307+
"cursorclass": pymysql.cursors.SSCursor if unbuffered else pymysql.cursors.Cursor,
308+
}
309+
310+
self.sandbox_mode = False
311+
try:
312+
conn = pymysql.connect(**connect_kwargs) # type: ignore[misc]
313+
except pymysql.OperationalError as e:
314+
if e.args[0] == ER_MUST_CHANGE_PASSWORD:
315+
# Post-handshake queries (SET NAMES, SET AUTOCOMMIT, init_command)
316+
# fail with ER_MUST_CHANGE_PASSWORD in sandbox mode.
317+
# Reconnect with only the raw handshake.
318+
connect_kwargs['defer_connect'] = True
319+
connect_kwargs['autocommit'] = None
320+
connect_kwargs['init_command'] = None
321+
conn = pymysql.connect(**connect_kwargs) # type: ignore[misc]
322+
self._connect_sandbox(conn)
323+
self.sandbox_mode = True
324+
else:
325+
raise
307326

308-
if ssh_host:
327+
if ssh_host and not self.sandbox_mode:
309328
##### paramiko.Channel is a bad socket implementation overall if you want SSL through an SSH tunnel
310329
#####
311330
# instead let's open a tunnel and rewrite host:port to local bind
@@ -343,9 +362,10 @@ def connect(
343362
self.ssl = ssl
344363
self.init_command = init_command
345364
self.unbuffered = unbuffered
346-
# retrieve connection id
347-
self.reset_connection_id()
348-
self.server_info = ServerInfo.from_version_string(conn.server_version) # type: ignore[attr-defined]
365+
# retrieve connection id (skip in sandbox mode as queries will fail)
366+
if not self.sandbox_mode:
367+
self.reset_connection_id()
368+
self.server_info = ServerInfo.from_version_string(conn.server_version) # type: ignore[attr-defined]
349369

350370
def run(self, statement: str) -> Generator[SQLResult, None, None]:
351371
"""Execute the sql in the database and return the results."""
@@ -576,6 +596,24 @@ def change_db(self, db: str) -> None:
576596
self.conn.select_db(db)
577597
self.dbname = db
578598

599+
@staticmethod
600+
def _connect_sandbox(conn: Connection) -> None:
601+
"""Connect in sandbox mode, performing only the handshake.
602+
603+
pymysql's normal connect() runs post-handshake queries (SET NAMES,
604+
SET AUTOCOMMIT, init_command) that all fail with ER_MUST_CHANGE_PASSWORD
605+
in sandbox mode. This method performs the raw socket connection and
606+
authentication handshake only.
607+
"""
608+
# Reuse pymysql internals for the handshake + auth, but
609+
# temporarily stub out set_character_set so it becomes a no-op.
610+
original_set_charset = conn.set_character_set
611+
conn.set_character_set = lambda *_args, **_kwargs: None # type: ignore[assignment]
612+
try:
613+
conn.connect()
614+
finally:
615+
conn.set_character_set = original_set_charset # type: ignore[assignment]
616+
579617
def _create_ssl_ctx(self, sslp: dict) -> ssl.SSLContext:
580618
ca = sslp.get("ca")
581619
capath = sslp.get("capath")

0 commit comments

Comments
 (0)