diff --git a/codecarbon/cli/main.py b/codecarbon/cli/main.py index d5175ea9c..fd6545a3f 100644 --- a/codecarbon/cli/main.py +++ b/codecarbon/cli/main.py @@ -377,7 +377,7 @@ def monitor( # If extra args are provided (e.g. `codecarbon monitor -- my_script.py`), delegate to `run_and_monitor` if getattr(ctx, "args", None): - return run_and_monitor(ctx, **tracker_args) + return run_and_monitor(ctx, offline=offline, **tracker_args) # Instantiate the tracker if offline: diff --git a/codecarbon/cli/monitor.py b/codecarbon/cli/monitor.py index e5925e3fe..98fa4e244 100644 --- a/codecarbon/cli/monitor.py +++ b/codecarbon/cli/monitor.py @@ -8,7 +8,7 @@ from rich import print from typing_extensions import Annotated -from codecarbon.emissions_tracker import EmissionsTracker +from codecarbon.emissions_tracker import EmissionsTracker, OfflineEmissionsTracker def run_and_monitor( @@ -17,6 +17,7 @@ def run_and_monitor( str, typer.Option(help="Log level (critical, error, warning, info, debug)"), ] = "error", + offline: bool = False, **tracker_args, ): """ @@ -63,8 +64,8 @@ def run_and_monitor( ) raise typer.Exit(1) - # Initialize tracker with specified logging level and shared args - tracker = EmissionsTracker( + tracker_cls = OfflineEmissionsTracker if offline else EmissionsTracker + tracker = tracker_cls( log_level=log_level, save_to_logger=False, tracking_mode="process", diff --git a/tests/cli/test_cli_main.py b/tests/cli/test_cli_main.py index f4cb2fef7..2319dadad 100644 --- a/tests/cli/test_cli_main.py +++ b/tests/cli/test_cli_main.py @@ -303,6 +303,45 @@ def stop(self): assert calls["kwargs"]["region"] == "IDF" +def test_monitor_delegates_offline_flag_to_run_and_monitor(monkeypatch): + captured = {} + + def fake_run_and_monitor(ctx, offline=False, **kwargs): + captured["offline"] = offline + captured["kwargs"] = kwargs + return "ok" + + monkeypatch.setattr(cli_main, "run_and_monitor", fake_run_and_monitor) + + ctx = SimpleNamespace(args=["python", "-c", "print(1)"]) + result = cli_main.monitor( + ctx=ctx, + offline=True, + country_iso_code="FRA", + ) + assert result == "ok" + assert captured["offline"] is True + assert captured["kwargs"]["country_iso_code"] == "FRA" + + +def test_monitor_delegates_online_mode_to_run_and_monitor(monkeypatch): + captured = {} + + def fake_run_and_monitor(ctx, offline=False, **kwargs): + captured["offline"] = offline + captured["kwargs"] = kwargs + return "ok" + + monkeypatch.setattr(cli_main, "run_and_monitor", fake_run_and_monitor) + monkeypatch.setattr(cli_main, "get_existing_exp_id", lambda: "exp-1") + + ctx = SimpleNamespace(args=["python", "train.py"]) + result = cli_main.monitor(ctx=ctx, api=True) + assert result == "ok" + assert captured["offline"] is False + assert captured["kwargs"]["save_to_api"] is True + + def test_monitor_delegates_to_run_and_monitor_with_extra_args(monkeypatch): captured = {} @@ -319,3 +358,21 @@ def fake_run_and_monitor(ctx, **kwargs): assert result == "ok" assert captured["args"] == ["python", "train.py"] assert captured["kwargs"]["save_to_api"] is False + + +def test_monitor_no_api_skips_experiment_id_requirement(monkeypatch): + captured = {} + + def fake_run_and_monitor(ctx, offline=False, **kwargs): + captured["offline"] = offline + captured["kwargs"] = kwargs + return "ok" + + monkeypatch.setattr(cli_main, "run_and_monitor", fake_run_and_monitor) + monkeypatch.setattr(cli_main, "get_existing_exp_id", lambda: None) + + ctx = SimpleNamespace(args=["python", "train.py"]) + result = cli_main.monitor(ctx=ctx, api=False) + assert result == "ok" + assert captured["offline"] is False + assert captured["kwargs"]["save_to_api"] is False diff --git a/tests/cli/test_monitor.py b/tests/cli/test_monitor.py index 907a8dd8e..d4dd718a2 100644 --- a/tests/cli/test_monitor.py +++ b/tests/cli/test_monitor.py @@ -60,6 +60,68 @@ def __init__(self, command, text=True): assert exc_info.value.exit_code == 1 +def test_run_and_monitor_uses_offline_tracker_when_offline_mode(monkeypatch): + captured = {} + + class FakeOfflineTracker(FakeTracker): + def __init__(self, **kwargs): + captured["kwargs"] = kwargs + super().__init__() + + class FakePopen: + def __init__(self, command, text=True): + pass + + def wait(self): + return 0 + + monkeypatch.setattr(monitor_module, "OfflineEmissionsTracker", FakeOfflineTracker) + monkeypatch.setattr(monitor_module, "EmissionsTracker", FakeTracker) + monkeypatch.setattr(monitor_module.subprocess, "Popen", FakePopen) + monkeypatch.setattr(monitor_module, "print", lambda *args, **kwargs: None) + + with pytest.raises(typer.Exit) as exc_info: + monitor_module.run_and_monitor( + SimpleNamespace(args=["echo", "hi"]), + offline=True, + country_iso_code="FRA", + ) + + assert exc_info.value.exit_code == 0 + assert captured["kwargs"]["country_iso_code"] == "FRA" + + +def test_run_and_monitor_uses_online_tracker_by_default(monkeypatch): + captured = {} + + class FakeOnlineTracker(FakeTracker): + def __init__(self, **kwargs): + captured["kwargs"] = kwargs + super().__init__() + + class FakePopen: + def __init__(self, command, text=True): + pass + + def wait(self): + return 0 + + monkeypatch.setattr(monitor_module, "EmissionsTracker", FakeOnlineTracker) + monkeypatch.setattr(monitor_module, "OfflineEmissionsTracker", FakeTracker) + monkeypatch.setattr(monitor_module.subprocess, "Popen", FakePopen) + monkeypatch.setattr(monitor_module, "print", lambda *args, **kwargs: None) + + with pytest.raises(typer.Exit) as exc_info: + monitor_module.run_and_monitor( + SimpleNamespace(args=["echo", "hi"]), + save_to_api=True, + ) + + assert exc_info.value.exit_code == 0 + assert captured["kwargs"]["tracking_mode"] == "process" + assert captured["kwargs"]["save_to_api"] is True + + def test_run_and_monitor_handles_keyboard_interrupt(monkeypatch): process_info = {"terminated": 0, "killed": 0}