diff --git a/.github/copilot-instructions.md b/.github/copilot-instructions.md
index 1a0fdd7..462f8c4 100644
--- a/.github/copilot-instructions.md
+++ b/.github/copilot-instructions.md
@@ -107,3 +107,18 @@ python -m pytest
 - `examples/` — example orchestrations (see `examples/README.md`)
 - `tests/` — test suite
 - `dev-requirements.txt` — development dependencies
+
+## Cross-Package Compatibility
+
+The `durabletask-azuremanaged` package extends the core `durabletask`
+package (e.g. `DurableTaskSchedulerWorker` subclasses
+`TaskHubGrpcWorker`). When adding or changing features in
+`durabletask/`, always verify that `durabletask-azuremanaged` still
+works correctly:
+
+- Check whether the azuremanaged worker, client, or tests override or
+  depend on the code you changed.
+- Run the azuremanaged unit tests if they exist for the affected area.
+- If a new public API is added to the core SDK (e.g. a method on
+  `OrchestrationContext`), confirm it is accessible through the
+  azuremanaged package and add a test or example if appropriate.
diff --git a/docs/features.md b/docs/features.md
index ba836c9..123312d 100644
--- a/docs/features.md
+++ b/docs/features.md
@@ -150,6 +150,49 @@ Orchestrations can be suspended using the `suspend_orchestration` client API and
 
 Orchestrations can specify retry policies for activities and sub-orchestrations. These policies control how many times and how frequently an activity or sub-orchestration will be retried in the event of a transient error.
 
+### Replay-safe logging
+
+Orchestrator functions replay their history each time they are resumed,
+which can cause duplicate log messages. The `create_replay_safe_logger`
+method on `OrchestrationContext` returns a `ReplaySafeLogger` that wraps
+a standard `logging.Logger` and automatically suppresses output while
+the orchestrator is replaying. `ReplaySafeLogger` extends Python's
+`logging.LoggerAdapter`, which is the idiomatic way to add context or
+modify behavior on an existing logger.
+
+```python
+import logging
+
+from durabletask import task
+
+logger = logging.getLogger("my_orchestrator")
+
+def my_orchestrator(ctx: task.OrchestrationContext, payload):
+    replay_logger = ctx.create_replay_safe_logger(logger)
+    replay_logger.info("Starting orchestration %s", ctx.instance_id)
+    result = yield ctx.call_activity(my_activity, input=payload)
+    replay_logger.info("Activity returned: %s", result)
+    return result
+```
+
+> [!NOTE]
+> Unlike the .NET SDK, where `CreateReplaySafeLogger` accepts a
+> category name string and internally creates the logger via
+> `ILoggerFactory`, the Python SDK requires you to pass an existing
+> `logging.Logger` instance. This is because Python's
+> `logging.getLogger(name)` already serves as the global factory and
+> is the standard way to obtain loggers.
+
+The replay-safe logger supports all standard log levels: `debug`,
+`info`, `warning`, `error`, `critical`, and `exception`, as well as
+the generic `log(level, msg)` method. It also exposes `isEnabledFor`
+which returns `False` during replay so callers can skip expensive
+message formatting.
+
+> [!TIP]
+> Create the replay-safe logger once at the start of your orchestrator
+> and reuse it throughout the function.
+
 ### Large payload externalization
 
 Orchestration inputs, outputs, and event data are transmitted through
diff --git a/durabletask/task.py b/durabletask/task.py
index 3f8f0f5..8276533 100644
--- a/durabletask/task.py
+++ b/durabletask/task.py
@@ -4,6 +4,7 @@
 # See https://peps.python.org/pep-0563/
 from __future__ import annotations
 
+import logging
 import math
 from abc import ABC, abstractmethod
 from datetime import datetime, timedelta, timezone
@@ -279,6 +280,51 @@ def new_uuid(self) -> str:
     def _exit_critical_section(self) -> None:
         pass
 
+    def create_replay_safe_logger(self, logger: logging.Logger) -> ReplaySafeLogger:
+        """Create a replay-safe logger that suppresses log messages during orchestration replay.
+
+        The returned logger wraps the provided logger and only emits log messages when
+        the orchestrator is not replaying. This prevents duplicate log messages from
+        appearing as a side effect of orchestration replay.
+
+        Parameters
+        ----------
+        logger : logging.Logger
+            The underlying logger to wrap.
+
+        Returns
+        -------
+        ReplaySafeLogger
+            A logger that only emits log messages when the orchestrator is not replaying.
+        """
+        return ReplaySafeLogger(logger, lambda: self.is_replaying)
+
+
+class ReplaySafeLogger(logging.LoggerAdapter):
+    """A logger adapter that suppresses log messages during orchestration replay.
+
+    This class extends :class:`logging.LoggerAdapter` and only emits log
+    messages when the orchestrator is *not* replaying. Use this to avoid
+    duplicate log entries that would otherwise appear every time the
+    orchestrator replays its history.
+
+    Obtain an instance by calling :meth:`OrchestrationContext.create_replay_safe_logger`.
+    """
+
+    def __init__(self, logger: logging.Logger, is_replaying: Callable[[], bool]) -> None:
+        super().__init__(logger, {})
+        self._is_replaying = is_replaying
+
+    def isEnabledFor(self, level: int) -> bool:
+        """Return whether logging is enabled for the given level.
+
+        Returns ``False`` while the orchestrator is replaying so that callers
+        can skip expensive message formatting during replay.
+        """
+        if self._is_replaying():
+            return False
+        return self.logger.isEnabledFor(level)
+
 
 class FailureDetails:
     def __init__(self, message: str, error_type: str, stack_trace: Optional[str]):
diff --git a/examples/activity_sequence.py b/examples/activity_sequence.py
index 420935d..b4b92ea 100644
--- a/examples/activity_sequence.py
+++ b/examples/activity_sequence.py
@@ -1,5 +1,6 @@
 """End-to-end sample that demonstrates how to configure an orchestrator
 that calls an activity function in a sequence and prints the outputs."""
+import logging
 import os
 
 from azure.identity import DefaultAzureCredential
@@ -8,6 +9,8 @@
 from durabletask.azuremanaged.client import DurableTaskSchedulerClient
 from durabletask.azuremanaged.worker import DurableTaskSchedulerWorker
 
+logger = logging.getLogger("activity_sequence")
+
 
 def hello(ctx: task.ActivityContext, name: str) -> str:
     """Activity function that returns a greeting"""
@@ -16,10 +19,15 @@ def hello(ctx: task.ActivityContext, name: str) -> str:
 
 def sequence(ctx: task.OrchestrationContext, _):
     """Orchestrator function that calls the 'hello' activity function in a sequence"""
+    # Create a replay-safe logger to avoid duplicate log messages during replay
+    replay_logger = ctx.create_replay_safe_logger(logger)
+
+    replay_logger.info("Starting activity sequence for instance %s", ctx.instance_id)
     # call "hello" activity function in a sequence
     result1 = yield ctx.call_activity(hello, input='Tokyo')
     result2 = yield ctx.call_activity(hello, input='Seattle')
     result3 = yield ctx.call_activity(hello, input='London')
+    replay_logger.info("All activities completed")
 
     # return an array of results
     return [result1, result2, result3]
diff --git a/tests/durabletask/test_orchestration_executor.py b/tests/durabletask/test_orchestration_executor.py
index ee4c0f2..0134b12 100644
--- a/tests/durabletask/test_orchestration_executor.py
+++ b/tests/durabletask/test_orchestration_executor.py
@@ -1501,6 +1501,198 @@ def orchestrator(ctx: task.OrchestrationContext, _):
     assert complete_action.result.value == encoded_output
 
 
+def test_replay_safe_logger_suppresses_during_replay():
+    """Validates that the replay-safe logger suppresses log messages during replay."""
+    log_calls: list[str] = []
+
+    class _RecordingHandler(logging.Handler):
+        def emit(self, record: logging.LogRecord) -> None:
+            log_calls.append(record.getMessage())
+
+    handler = _RecordingHandler()
+    inner_logger = logging.getLogger("test_replay_safe_logger")
+    inner_logger.setLevel(logging.DEBUG)
+    original_propagate = inner_logger.propagate
+    inner_logger.propagate = False
+    inner_logger.addHandler(handler)
+
+    try:
+        activity_name = "say_hello"
+
+        def say_hello(_, name: str) -> str:
+            return f"Hello, {name}!"
+
+        def orchestrator(ctx: task.OrchestrationContext, _):
+            replay_logger = ctx.create_replay_safe_logger(inner_logger)
+            replay_logger.info("Starting orchestration")
+            result = yield ctx.call_activity(say_hello, input="World")
+            replay_logger.info("Activity completed: %s", result)
+            return result
+
+        registry = worker._Registry()
+        activity_name = registry.add_activity(say_hello)
+        orchestrator_name = registry.add_orchestrator(orchestrator)
+
+        # First execution: starts the orchestration. The orchestrator runs without
+        # replay, emits the initial log message, and then schedules the activity.
+        new_events = [
+            helpers.new_orchestrator_started_event(datetime.now()),
+            helpers.new_execution_started_event(orchestrator_name, TEST_INSTANCE_ID, encoded_input=None),
+        ]
+        executor = worker._OrchestrationExecutor(registry, TEST_LOGGER)
+        result = executor.execute(TEST_INSTANCE_ID, [], new_events)
+        assert result.actions  # should have scheduled the activity
+
+        assert log_calls == ["Starting orchestration"]
+        log_calls.clear()
+
+        # Second execution: the orchestrator replays from history and then processes the
+        # activity completion. The "Starting orchestration" message is emitted during
+        # replay and should be suppressed; "Activity completed" is emitted after replay
+        # ends and should appear exactly once.
+        old_events = new_events + [
+            helpers.new_task_scheduled_event(1, activity_name),
+        ]
+        encoded_output = json.dumps(say_hello(None, "World"))
+        new_events = [helpers.new_task_completed_event(1, encoded_output)]
+        executor = worker._OrchestrationExecutor(registry, TEST_LOGGER)
+        result = executor.execute(TEST_INSTANCE_ID, old_events, new_events)
+        complete_action = get_and_validate_complete_orchestration_action_list(1, result.actions)
+        assert complete_action.orchestrationStatus == pb.ORCHESTRATION_STATUS_COMPLETED
+
+        assert log_calls == ["Activity completed: Hello, World!"]
+    finally:
+        inner_logger.removeHandler(handler)
+        inner_logger.propagate = original_propagate
+
+
+def test_replay_safe_logger_all_levels():
+    """Validates that all log levels are suppressed during replay and emitted otherwise."""
+    log_levels: list[str] = []
+
+    class _LevelRecorder(logging.Handler):
+        def emit(self, record: logging.LogRecord) -> None:
+            log_levels.append(record.levelname)
+
+    handler = _LevelRecorder()
+    inner_logger = logging.getLogger("test_replay_safe_logger_levels")
+    inner_logger.setLevel(logging.DEBUG)
+    original_propagate = inner_logger.propagate
+    inner_logger.propagate = False
+    inner_logger.addHandler(handler)
+
+    try:
+        def orchestrator(ctx: task.OrchestrationContext, _):
+            replay_logger = ctx.create_replay_safe_logger(inner_logger)
+            replay_logger.debug("debug msg")
+            replay_logger.info("info msg")
+            replay_logger.warning("warning msg")
+            replay_logger.error("error msg")
+            replay_logger.critical("critical msg")
+            return "done"
+
+        registry = worker._Registry()
+        orchestrator_name = registry.add_orchestrator(orchestrator)
+
+        new_events = [
+            helpers.new_orchestrator_started_event(datetime.now()),
+            helpers.new_execution_started_event(orchestrator_name, TEST_INSTANCE_ID, encoded_input=None),
+        ]
+        executor = worker._OrchestrationExecutor(registry, TEST_LOGGER)
+        result = executor.execute(TEST_INSTANCE_ID, [], new_events)
+        complete_action = get_and_validate_complete_orchestration_action_list(1, result.actions)
+        assert complete_action.orchestrationStatus == pb.ORCHESTRATION_STATUS_COMPLETED
+
+        assert log_levels == ["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"]
+    finally:
+        inner_logger.removeHandler(handler)
+        inner_logger.propagate = original_propagate
+
+
+def test_replay_safe_logger_direct():
+    """Unit test for ReplaySafeLogger — verifies suppression based on is_replaying flag."""
+    log_calls: list[str] = []
+
+    class _RecordingHandler(logging.Handler):
+        def emit(self, record: logging.LogRecord) -> None:
+            log_calls.append(record.getMessage())
+
+    handler = _RecordingHandler()
+    inner_logger = logging.getLogger("test_replay_safe_logger_direct")
+    inner_logger.setLevel(logging.DEBUG)
+    original_propagate = inner_logger.propagate
+    inner_logger.propagate = False
+    inner_logger.addHandler(handler)
+
+    try:
+        replaying = True
+        replay_logger = task.ReplaySafeLogger(inner_logger, lambda: replaying)
+
+        replay_logger.info("should be suppressed")
+        assert log_calls == []
+
+        replaying = False
+        replay_logger.info("should appear")
+        assert log_calls == ["should appear"]
+    finally:
+        inner_logger.removeHandler(handler)
+        inner_logger.propagate = original_propagate
+
+
+def test_replay_safe_logger_log_method():
+    """Validates the generic log() method respects the replay flag."""
+    log_calls: list[str] = []
+
+    class _RecordingHandler(logging.Handler):
+        def emit(self, record: logging.LogRecord) -> None:
+            log_calls.append(record.getMessage())
+
+    handler = _RecordingHandler()
+    inner_logger = logging.getLogger("test_replay_safe_logger_log_method")
+    inner_logger.setLevel(logging.DEBUG)
+    original_propagate = inner_logger.propagate
+    inner_logger.propagate = False
+    inner_logger.addHandler(handler)
+
+    try:
+        replaying = True
+        replay_logger = task.ReplaySafeLogger(inner_logger, lambda: replaying)
+
+        replay_logger.log(logging.WARNING, "suppressed warning")
+        assert log_calls == []
+
+        replaying = False
+        replay_logger.log(logging.WARNING, "visible warning")
+        assert log_calls == ["visible warning"]
+    finally:
+        inner_logger.removeHandler(handler)
+        inner_logger.propagate = original_propagate
+
+
+def test_replay_safe_logger_is_enabled_for():
+    """Validates isEnabledFor returns False during replay."""
+    inner_logger = logging.getLogger("test_replay_safe_logger_enabled")
+    inner_logger.setLevel(logging.DEBUG)
+
+    replaying = True
+    replay_logger = task.ReplaySafeLogger(inner_logger, lambda: replaying)
+
+    # During replay, isEnabledFor should always return False
+    assert replay_logger.isEnabledFor(logging.DEBUG) is False
+    assert replay_logger.isEnabledFor(logging.INFO) is False
+    assert replay_logger.isEnabledFor(logging.CRITICAL) is False
+
+    # After replay, delegates to the inner logger
+    replaying = False
+    assert replay_logger.isEnabledFor(logging.DEBUG) is True
+    assert replay_logger.isEnabledFor(logging.INFO) is True
+
+    # If a level is below the inner logger's level, should return False
+    inner_logger.setLevel(logging.WARNING)
+    assert replay_logger.isEnabledFor(logging.DEBUG) is False
+    assert replay_logger.isEnabledFor(logging.WARNING) is True
+
+
 def test_when_any_with_retry():
     """Tests that a when_any pattern works correctly with retries"""
     def dummy_activity(_, inp: str):