|
25 | 25 |
|
26 | 26 | from renku.core import errors |
27 | 27 | from renku.core.config import get_value |
28 | | -from renku.core.plugin.session import get_supported_session_providers |
| 28 | +from renku.core.plugin.session import get_supported_hibernating_session_providers, get_supported_session_providers |
29 | 29 | from renku.core.session.utils import get_image_repository_host, get_renku_project_name |
30 | 30 | from renku.core.util import communication |
31 | 31 | from renku.core.util.os import safe_read_yaml |
32 | 32 | from renku.core.util.ssh import SystemSSHConfig, generate_ssh_keys |
33 | | -from renku.domain_model.session import ISessionProvider, Session, SessionStopStatus |
| 33 | +from renku.domain_model.session import IHibernatingSessionProvider, ISessionProvider, Session, SessionStopStatus |
34 | 34 |
|
35 | 35 |
|
36 | 36 | def _safe_get_provider(provider: str) -> ISessionProvider: |
@@ -80,6 +80,22 @@ def search_session_providers(name: str) -> List[str]: |
80 | 80 | return [p.name for p in get_supported_session_providers() if p.name.lower().startswith(name)] |
81 | 81 |
|
82 | 82 |
|
| 83 | +@validate_arguments(config=dict(arbitrary_types_allowed=True)) |
| 84 | +def search_hibernating_session_providers(name: str) -> List[str]: |
| 85 | + """Get all session providers that support hibernation and their name starts with the given name. |
| 86 | +
|
| 87 | + Args: |
| 88 | + name(str): The name to search for. |
| 89 | +
|
| 90 | + Returns: |
| 91 | + All session providers whose name starts with ``name``. |
| 92 | + """ |
| 93 | + from renku.core.plugin.session import get_supported_hibernating_session_providers |
| 94 | + |
| 95 | + name = name.lower() |
| 96 | + return [p.name for p in get_supported_hibernating_session_providers() if p.name.lower().startswith(name)] |
| 97 | + |
| 98 | + |
83 | 99 | @validate_arguments(config=dict(arbitrary_types_allowed=True)) |
84 | 100 | def session_list(*, provider: Optional[str] = None) -> SessionList: |
85 | 101 | """List interactive sessions. |
@@ -358,3 +374,94 @@ def ssh_setup(existing_key: Optional[Path] = None, force: bool = False): |
358 | 374 | "This command does not add any public SSH keys to your project. " |
359 | 375 | "Keys have to be added manually or by using the 'renku session start' command with the '--ssh' flag." |
360 | 376 | ) |
| 377 | + |
| 378 | + |
| 379 | +@validate_arguments(config=dict(arbitrary_types_allowed=True)) |
| 380 | +def session_pause(session_name: Optional[str], provider: Optional[str] = None, **kwargs): |
| 381 | + """Pause an interactive session. |
| 382 | +
|
| 383 | + Args: |
| 384 | + session_name(Optional[str]): Name of the session. |
| 385 | + provider(Optional[str]): Name of the session provider to use. |
| 386 | + """ |
| 387 | + |
| 388 | + def pause(session_provider: IHibernatingSessionProvider) -> SessionStopStatus: |
| 389 | + try: |
| 390 | + return session_provider.session_pause(project_name=project_name, session_name=session_name) |
| 391 | + except errors.RenkulabSessionGetUrlError: |
| 392 | + if provider: |
| 393 | + raise |
| 394 | + return SessionStopStatus.FAILED |
| 395 | + |
| 396 | + project_name = get_renku_project_name() |
| 397 | + |
| 398 | + if provider: |
| 399 | + session_provider = _safe_get_provider(provider) |
| 400 | + if session_provider is None: |
| 401 | + raise errors.ParameterError(f"Provider '{provider}' not found") |
| 402 | + elif not isinstance(session_provider, IHibernatingSessionProvider): |
| 403 | + raise errors.ParameterError(f"Provider '{provider}' doesn't support pausing sessions") |
| 404 | + providers = [session_provider] |
| 405 | + else: |
| 406 | + providers = get_supported_hibernating_session_providers() |
| 407 | + |
| 408 | + session_message = f"session {session_name}" if session_name else "session" |
| 409 | + statues = [] |
| 410 | + warning_messages = [] |
| 411 | + with communication.busy(msg=f"Waiting for {session_message} to pause..."): |
| 412 | + for session_provider in sorted(providers, key=lambda p: p.priority): |
| 413 | + try: |
| 414 | + status = pause(session_provider) # type: ignore |
| 415 | + except errors.RenkuException as e: |
| 416 | + warning_messages.append(f"Cannot pause sessions in provider '{session_provider.name}': {e}") |
| 417 | + else: |
| 418 | + statues.append(status) |
| 419 | + |
| 420 | + # NOTE: The given session name was stopped; don't continue |
| 421 | + if session_name and status == SessionStopStatus.SUCCESSFUL: |
| 422 | + break |
| 423 | + |
| 424 | + if warning_messages: |
| 425 | + for message in warning_messages: |
| 426 | + communication.warn(message) |
| 427 | + |
| 428 | + if not statues: |
| 429 | + return |
| 430 | + elif all(s == SessionStopStatus.NO_ACTIVE_SESSION for s in statues): |
| 431 | + raise errors.ParameterError("There are no running sessions.") |
| 432 | + elif session_name and not any(s == SessionStopStatus.SUCCESSFUL for s in statues): |
| 433 | + raise errors.ParameterError(f"Could not find '{session_name}' among the running sessions.") |
| 434 | + elif not session_name and not any(s == SessionStopStatus.SUCCESSFUL for s in statues): |
| 435 | + raise errors.ParameterError("Session name is missing") |
| 436 | + |
| 437 | + |
| 438 | +@validate_arguments(config=dict(arbitrary_types_allowed=True)) |
| 439 | +def session_resume(session_name: Optional[str], provider: Optional[str] = None, **kwargs): |
| 440 | + """Resume a paused session. |
| 441 | +
|
| 442 | + Args: |
| 443 | + session_name(Optional[str]): Name of the session. |
| 444 | + provider(Optional[str]): Name of the session provider to use. |
| 445 | + """ |
| 446 | + project_name = get_renku_project_name() |
| 447 | + |
| 448 | + if provider: |
| 449 | + session_provider = _safe_get_provider(provider) |
| 450 | + if session_provider is None: |
| 451 | + raise errors.ParameterError(f"Provider '{provider}' not found") |
| 452 | + elif not isinstance(session_provider, IHibernatingSessionProvider): |
| 453 | + raise errors.ParameterError(f"Provider '{provider}' doesn't support pausing/resuming sessions") |
| 454 | + providers = [session_provider] |
| 455 | + else: |
| 456 | + providers = get_supported_hibernating_session_providers() |
| 457 | + |
| 458 | + session_message = f"session {session_name}" if session_name else "session" |
| 459 | + with communication.busy(msg=f"Waiting for {session_message} to resume..."): |
| 460 | + for session_provider in providers: |
| 461 | + if session_provider.session_resume(project_name, session_name, **kwargs): # type: ignore |
| 462 | + return |
| 463 | + |
| 464 | + if session_name: |
| 465 | + raise errors.ParameterError(f"Could not find '{session_name}' among the sessions.") |
| 466 | + else: |
| 467 | + raise errors.ParameterError("Session name is missing") |
0 commit comments