diff --git a/pyiceberg/io/pyarrow.py b/pyiceberg/io/pyarrow.py index be4369af..fe1592f1 100644 --- a/pyiceberg/io/pyarrow.py +++ b/pyiceberg/io/pyarrow.py @@ -128,6 +128,7 @@ DataFileContent, FileFormat, ) +from pyiceberg.observability import perf_timer from pyiceberg.partitioning import PartitionField, PartitionFieldValue, PartitionKey, PartitionSpec, partition_record_value from pyiceberg.schema import ( PartnerAccessor, @@ -1586,102 +1587,123 @@ def _task_to_record_batches( downcast_ns_timestamp_to_us: bool | None = None, batch_size: int | None = None, ) -> Iterator[pa.RecordBatch]: - arrow_format = _get_file_format(task.file.file_format, pre_buffer=True, buffer_size=(ONE_MEGABYTE * 8)) - with io.new_input(task.file.file_path).open() as fin: - fragment = arrow_format.make_fragment(fin) - physical_schema = fragment.physical_schema - - # For V1 and V2, we only support Timestamp 'us' in Iceberg Schema, - # therefore it is reasonable to always cast 'ns' timestamp to 'us' on read. - # For V3 this has to set explicitly to avoid nanosecond timestamp to be down-casted by default - downcast_ns_timestamp_to_us = ( - downcast_ns_timestamp_to_us if downcast_ns_timestamp_to_us is not None else format_version <= 2 - ) - file_schema = pyarrow_to_schema( - physical_schema, name_mapping, downcast_ns_timestamp_to_us=downcast_ns_timestamp_to_us, format_version=format_version - ) - - # Apply column projection rules: https://iceberg.apache.org/spec/#column-projection - projected_missing_fields = _get_column_projection_values( - task.file, projected_schema, table_schema, partition_spec, file_schema.field_ids - ) + with perf_timer( + "arrow.read_file", + file_path=task.file.file_path, + file_format=str(task.file.file_format), + ) as t: + t.metric("file_size_bytes", task.file.file_size_in_bytes) + arrow_format = _get_file_format(task.file.file_format, pre_buffer=True, buffer_size=(ONE_MEGABYTE * 8)) + batch_count = 0 + row_count = 0 + with io.new_input(task.file.file_path).open() as 
fin: + fragment = arrow_format.make_fragment(fin) + physical_schema = fragment.physical_schema + + # For V1 and V2, we only support Timestamp 'us' in Iceberg Schema, + # therefore it is reasonable to always cast 'ns' timestamp to 'us' on read. + # For V3 this has to set explicitly to avoid nanosecond timestamp to be down-casted by default + downcast_ns_timestamp_to_us = ( + downcast_ns_timestamp_to_us if downcast_ns_timestamp_to_us is not None else format_version <= 2 + ) + file_schema = pyarrow_to_schema( + physical_schema, + name_mapping, + downcast_ns_timestamp_to_us=downcast_ns_timestamp_to_us, + format_version=format_version, + ) - pyarrow_filter = None - if bound_row_filter is not AlwaysTrue(): - translated_row_filter = translate_column_names( - bound_row_filter, file_schema, case_sensitive=case_sensitive, projected_field_values=projected_missing_fields + # Apply column projection rules: https://iceberg.apache.org/spec/#column-projection + projected_missing_fields = _get_column_projection_values( + task.file, projected_schema, table_schema, partition_spec, file_schema.field_ids ) - bound_file_filter = bind(file_schema, translated_row_filter, case_sensitive=case_sensitive) - pyarrow_filter = expression_to_pyarrow(bound_file_filter, file_schema) - - file_project_schema = prune_columns(file_schema, projected_field_ids, select_full_types=False) - - scanner_kwargs: dict[str, Any] = { - "fragment": fragment, - "schema": physical_schema, - # This will push down the query to Arrow. 
- # But in case there are positional deletes, we have to apply them first - "filter": pyarrow_filter if not positional_deletes else None, - "columns": [col.name for col in file_project_schema.columns], - } - if batch_size is not None: - scanner_kwargs["batch_size"] = batch_size - fragment_scanner = ds.Scanner.from_fragment(**scanner_kwargs) + pyarrow_filter = None + if bound_row_filter is not AlwaysTrue(): + translated_row_filter = translate_column_names( + bound_row_filter, + file_schema, + case_sensitive=case_sensitive, + projected_field_values=projected_missing_fields, + ) + bound_file_filter = bind(file_schema, translated_row_filter, case_sensitive=case_sensitive) + pyarrow_filter = expression_to_pyarrow(bound_file_filter, file_schema) + + file_project_schema = prune_columns(file_schema, projected_field_ids, select_full_types=False) + + scanner_kwargs: dict[str, Any] = { + "fragment": fragment, + "schema": physical_schema, + # This will push down the query to Arrow. + # But in case there are positional deletes, we have to apply them first + "filter": pyarrow_filter if not positional_deletes else None, + "columns": [col.name for col in file_project_schema.columns], + } + if batch_size is not None: + scanner_kwargs["batch_size"] = batch_size - next_index = 0 - batches = fragment_scanner.to_batches() - for batch in batches: - next_index = next_index + len(batch) - current_index = next_index - len(batch) - current_batch = batch + fragment_scanner = ds.Scanner.from_fragment(**scanner_kwargs) - if positional_deletes: - # Create the mask of indices that we're interested in - indices = _combine_positional_deletes(positional_deletes, current_index, current_index + len(batch)) - current_batch = current_batch.take(indices) + next_index = 0 + batches = fragment_scanner.to_batches() + for batch in batches: + next_index = next_index + len(batch) + current_index = next_index - len(batch) + current_batch = batch - # skip empty batches - if current_batch.num_rows == 0: - 
continue + if positional_deletes: + # Create the mask of indices that we're interested in + indices = _combine_positional_deletes(positional_deletes, current_index, current_index + len(batch)) + current_batch = current_batch.take(indices) - # Apply the user filter - if pyarrow_filter is not None: - # Temporary fix until PyArrow 21 is released ( https://github.com/apache/arrow/pull/46057 ) - table = pa.Table.from_batches([current_batch]) - table = table.filter(pyarrow_filter) + batch_count += 1 # skip empty batches - if table.num_rows == 0: + if current_batch.num_rows == 0: continue - current_batch = table.combine_chunks().to_batches()[0] + # Apply the user filter + if pyarrow_filter is not None: + # Temporary fix until PyArrow 21 is released ( https://github.com/apache/arrow/pull/46057 ) + table = pa.Table.from_batches([current_batch]) + table = table.filter(pyarrow_filter) + # skip empty batches + if table.num_rows == 0: + continue - yield _to_requested_schema( - projected_schema, - file_project_schema, - current_batch, - downcast_ns_timestamp_to_us=downcast_ns_timestamp_to_us, - projected_missing_fields=projected_missing_fields, - allow_timestamp_tz_mismatch=True, - ) + current_batch = table.combine_chunks().to_batches()[0] + + row_count += current_batch.num_rows + yield _to_requested_schema( + projected_schema, + file_project_schema, + current_batch, + downcast_ns_timestamp_to_us=downcast_ns_timestamp_to_us, + projected_missing_fields=projected_missing_fields, + allow_timestamp_tz_mismatch=True, + ) + t.metric("batch_count", batch_count) + t.metric("row_count", row_count) def _read_all_delete_files(io: FileIO, tasks: Iterable[FileScanTask]) -> dict[str, list[ChunkedArray]]: - deletes_per_file: dict[str, list[ChunkedArray]] = {} - unique_deletes = set(itertools.chain.from_iterable([task.delete_files for task in tasks])) - if len(unique_deletes) > 0: - executor = ExecutorFactory.get_or_create() - deletes_per_files: Iterator[dict[str, ChunkedArray]] = executor.map( 
@staticmethod
def _iter_batches_counted(
    inner: Iterator[pa.RecordBatch],
    task_count: int,
) -> Iterator[pa.RecordBatch]:
    """Wrap an inner batch iterator with aggregate perf_timer tracking.

    Every batch produced by ``inner`` is passed through unchanged while the
    number of batches and rows is accumulated and attached to a single
    ``arrow.to_record_batches`` performance event.

    Args:
        inner: The batch iterator to pass through.
        task_count: Number of scan tasks feeding ``inner``; recorded as the
            ``task_count`` metric.

    Yields:
        The batches of ``inner``, in the same order.
    """
    with perf_timer("arrow.to_record_batches") as t:
        t.metric("task_count", task_count)
        batch_count = 0
        row_count = 0
        try:
            for batch in inner:
                batch_count += 1
                row_count += batch.num_rows
                yield batch
        finally:
            # Record the counters in ``finally`` so they are also attached when
            # the consumer stops iterating early (the generator is then closed
            # at the ``yield`` above) or when ``inner`` raises; otherwise the
            # event would be emitted without them.
            t.metric("batch_count", batch_count)
            t.metric("row_count", row_count)
""" - input_file = io.new_input(self.manifest_path) - with AvroFile[ManifestEntry]( - input_file, - MANIFEST_ENTRY_SCHEMAS[DEFAULT_READ_VERSION], - read_types={-1: ManifestEntry, 2: DataFile}, - read_enums={0: ManifestEntryStatus, 101: FileFormat, 134: DataFileContent}, - ) as reader: - return [ - _inherit_from_manifest(entry, self) - for entry in reader - if not discard_deleted or entry.status != ManifestEntryStatus.DELETED - ] + with perf_timer("manifest.fetch_entries", manifest_path=self.manifest_path) as t: + input_file = io.new_input(self.manifest_path) + with AvroFile[ManifestEntry]( + input_file, + MANIFEST_ENTRY_SCHEMAS[DEFAULT_READ_VERSION], + read_types={-1: ManifestEntry, 2: DataFile}, + read_enums={0: ManifestEntryStatus, 101: FileFormat, 134: DataFileContent}, + ) as reader: + result = [ + _inherit_from_manifest(entry, self) + for entry in reader + if not discard_deleted or entry.status != ManifestEntryStatus.DELETED + ] + t.metric("entry_count", len(result)) + return result def __eq__(self, other: Any) -> bool: """Return the equality of two instances of the ManifestFile class.""" @@ -924,19 +928,24 @@ def _manifests(io: FileIO, manifest_list: str) -> tuple[ManifestFile, ...]: Returns: A tuple of ManifestFile objects. 
""" - file = io.new_input(manifest_list) - manifest_files = list(read_manifest_list(file)) - - result = [] - with _manifest_cache_lock: - for manifest_file in manifest_files: - manifest_path = manifest_file.manifest_path - if manifest_path in _manifest_cache: - result.append(_manifest_cache[manifest_path]) - else: - _manifest_cache[manifest_path] = manifest_file - result.append(manifest_file) - + with perf_timer("manifest.read_list") as t: + file = io.new_input(manifest_list) + manifest_files = list(read_manifest_list(file)) + + result = [] + cache_hits = 0 + with _manifest_cache_lock: + for manifest_file in manifest_files: + manifest_path = manifest_file.manifest_path + if manifest_path in _manifest_cache: + result.append(_manifest_cache[manifest_path]) + cache_hits += 1 + else: + _manifest_cache[manifest_path] = manifest_file + result.append(manifest_file) + + t.metric("manifest_count", len(result)) + t.metric("cache_hits", cache_hits) return tuple(result) diff --git a/pyiceberg/observability.py b/pyiceberg/observability.py new file mode 100644 index 00000000..1a8602c5 --- /dev/null +++ b/pyiceberg/observability.py @@ -0,0 +1,190 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
from __future__ import annotations

import inspect
import logging
import threading
import time
from collections.abc import Callable, Generator
from contextlib import contextmanager
from dataclasses import dataclass, field
from functools import wraps
from typing import Any, Protocol, TypeVar, runtime_checkable

logger = logging.getLogger("pyiceberg.perf")

F = TypeVar("F", bound=Callable[..., Any])


@dataclass
class PerfEvent:
    """A structured performance event emitted by instrumentation points.

    ``tags`` are low-cardinality dimensions for grouping/filtering (e.g. database,
    table, file_format). ``metrics`` are measured values (e.g. row_count,
    batch_count, response_bytes).
    """

    # Name of the instrumented operation, e.g. "scan.plan_files".
    operation: str
    # Elapsed time of the instrumented block, in milliseconds.
    duration_ms: float
    tags: dict[str, str] = field(default_factory=dict)
    metrics: dict[str, int | float] = field(default_factory=dict)


@runtime_checkable
class PerfObserver(Protocol):
    """Protocol for receiving performance events."""

    def emit(self, event: PerfEvent) -> None: ...


def _emit_safely(observer: PerfObserver, event: PerfEvent) -> None:
    """Deliver ``event`` to ``observer``, logging instead of raising on failure.

    Instrumentation must never break the code it measures, so any exception
    raised by the observer is logged at DEBUG level and then dropped.
    """
    try:
        observer.emit(event)
    except Exception:
        logger.debug("Observer %s failed to emit event %s", type(observer).__name__, event.operation, exc_info=True)


class NullPerfObserver:
    """No-op observer that discards all events. This is the default observer."""

    def emit(self, event: PerfEvent) -> None:
        pass


class LoggingPerfObserver:
    """Observer that emits structured key=value log lines at DEBUG level."""

    def emit(self, event: PerfEvent) -> None:
        # Skip building the log line when nobody listens at DEBUG level.
        if not logger.isEnabledFor(logging.DEBUG):
            return
        parts = [f"operation={event.operation}", f"duration_ms={event.duration_ms:.3f}"]
        parts.extend(f"{key}={value}" for key, value in event.tags.items())
        parts.extend(f"{key}={metric_value}" for key, metric_value in event.metrics.items())
        logger.debug(" ".join(parts))


class CompositeObserver:
    """Fans out events to multiple observers.

    If an observer raises, the exception is logged and remaining observers
    still receive the event.
    """

    def __init__(self, *observers: PerfObserver) -> None:
        self._observers = observers

    def emit(self, event: PerfEvent) -> None:
        for observer in self._observers:
            _emit_safely(observer, event)


_observer: PerfObserver = NullPerfObserver()
_observer_lock = threading.Lock()


def set_observer(observer: PerfObserver) -> None:
    """Set the global performance observer. Thread-safe."""
    global _observer
    with _observer_lock:
        _observer = observer


def get_observer() -> PerfObserver:
    """Get the current global performance observer. Thread-safe."""
    with _observer_lock:
        return _observer


class _PerfTimerContext:
    """Context object yielded by perf_timer.

    Use ``.tag()`` for dimensions (low-cardinality strings) and
    ``.metric()`` for measured values (counts, sizes, etc.).
    """

    __slots__ = ("_tags", "_metrics")

    def __init__(self, initial_tags: dict[str, str]) -> None:
        self._tags = initial_tags
        self._metrics: dict[str, int | float] = {}

    def tag(self, key: str, value: str) -> None:
        """Set a dimension tag on the performance event."""
        self._tags[key] = value

    def metric(self, key: str, value: int | float) -> None:
        """Set a metric value on the performance event."""
        self._metrics[key] = value


@contextmanager
def perf_timer(operation: str, **tags: str) -> Generator[_PerfTimerContext, None, None]:
    """Context manager for timing a block of code and emitting a PerfEvent.

    Keyword arguments are recorded as dimension tags. Use ``ctx.metric()``
    inside the block for measured values.

    When the active observer is NullPerfObserver, time.monotonic() is skipped
    entirely and no event is built. Otherwise the event is emitted when the
    block exits, whether normally or through an exception. A failing observer
    is logged at DEBUG level and never propagates into the timed block.

    Args:
        operation: Name of the operation, used as ``PerfEvent.operation``.
        **tags: Initial dimension tags for the event.

    Yields:
        A context object on which further tags and metrics can be set.
    """
    # Snapshot the observer once, so the block reports to the observer that
    # was active when it started even if set_observer() is called meanwhile.
    with _observer_lock:
        observer = _observer

    ctx = _PerfTimerContext(tags)
    if isinstance(observer, NullPerfObserver):
        yield ctx
        return

    start = time.monotonic()
    try:
        yield ctx
    finally:
        duration_ms = (time.monotonic() - start) * 1000.0
        # Guarded emit: raising from ``finally`` would replace any exception
        # coming out of the timed block and fail an otherwise successful one.
        _emit_safely(
            observer,
            PerfEvent(operation=operation, duration_ms=duration_ms, tags=ctx._tags, metrics=ctx._metrics),
        )


def _accepts_perf_ctx(fn: Callable[..., Any]) -> bool:
    """Return True if ``fn`` can be called with a ``_perf_ctx`` keyword argument.

    That is the case when ``fn`` declares ``**kwargs`` or a parameter named
    ``_perf_ctx`` that may be passed by keyword. Callables whose signature
    cannot be introspected are treated as not accepting it.
    """
    try:
        parameters = inspect.signature(fn).parameters.values()
    except (TypeError, ValueError):
        return False
    by_keyword = (inspect.Parameter.POSITIONAL_OR_KEYWORD, inspect.Parameter.KEYWORD_ONLY)
    return any(
        param.kind is inspect.Parameter.VAR_KEYWORD or (param.name == "_perf_ctx" and param.kind in by_keyword)
        for param in parameters
    )


def timed(operation: str, **decorator_tags: str) -> Callable[[F], F]:
    """Decorate a function to wrap its body in perf_timer.

    Built on top of perf_timer internally — same PerfObserver/PerfEvent pipeline,
    same NullPerfObserver fast path, same structured log output.

    A decorated function may declare an extra ``_perf_ctx`` keyword argument
    (a :class:`_PerfTimerContext`) that can be used to set tags/metrics::

        @timed("my.operation")
        def process(data, *, _perf_ctx=None):
            ...
            if _perf_ctx:
                _perf_ctx.metric("row_count", len(data))

    Functions that don't declare ``_perf_ctx`` in their signature can ignore it
    — the wrapper only passes it if the function accepts ``**kwargs`` or has an
    explicit ``_perf_ctx`` parameter. This is decided once, at decoration time,
    from the function's signature, so the function is called exactly once per
    invocation and any exception it raises propagates unchanged.

    Args:
        operation: Name of the operation, used as ``PerfEvent.operation``.
        **decorator_tags: Dimension tags attached to every emitted event.

    Returns:
        A decorator that preserves the wrapped function's metadata.
    """

    def decorator(fn: F) -> F:
        pass_ctx = _accepts_perf_ctx(fn)

        @wraps(fn)
        def wrapper(*args: Any, **kwargs: Any) -> Any:
            with perf_timer(operation, **decorator_tags) as t:
                if pass_ctx:
                    kwargs["_perf_ctx"] = t
                return fn(*args, **kwargs)

        return wrapper  # type: ignore[return-value]

    return decorator
ResidualEvaluator]] = KeyDefaultDict(self._build_residual_evaluator) - - for manifest_entry in chain.from_iterable(self.scan_plan_helper()): - data_file = manifest_entry.data_file - if data_file.content == DataFileContent.DATA: - data_entries.append(manifest_entry) - elif data_file.content == DataFileContent.POSITION_DELETES: - delete_index.add_delete_file(manifest_entry, partition_key=data_file.partition) - elif data_file.content == DataFileContent.EQUALITY_DELETES: - raise ValueError("PyIceberg does not yet support equality deletes: https://github.com/apache/iceberg/issues/6568") - else: - raise ValueError(f"Unknown DataFileContent ({data_file.content}): {manifest_entry}") - return [ - FileScanTask( - data_entry.data_file, - delete_files=delete_index.for_data_file( - data_entry.sequence_number or INITIAL_SEQUENCE_NUMBER, - data_entry.data_file, - partition_key=data_entry.data_file.partition, - ), - residual=residual_evaluators[data_entry.data_file.spec_id](data_entry.data_file).residual_for( - data_entry.data_file.partition - ), + with perf_timer("scan.plan_files_local") as t: + data_entries: list[ManifestEntry] = [] + delete_index = DeleteFileIndex() + delete_entry_count = 0 + + residual_evaluators: dict[int, Callable[[DataFile], ResidualEvaluator]] = KeyDefaultDict( + self._build_residual_evaluator ) - for data_entry in data_entries - ] + + for manifest_entry in chain.from_iterable(self.scan_plan_helper()): + data_file = manifest_entry.data_file + if data_file.content == DataFileContent.DATA: + data_entries.append(manifest_entry) + elif data_file.content == DataFileContent.POSITION_DELETES: + delete_index.add_delete_file(manifest_entry, partition_key=data_file.partition) + delete_entry_count += 1 + elif data_file.content == DataFileContent.EQUALITY_DELETES: + raise ValueError( + "PyIceberg does not yet support equality deletes: https://github.com/apache/iceberg/issues/6568" + ) + else: + raise ValueError(f"Unknown DataFileContent ({data_file.content}): 
{manifest_entry}") + t.metric("data_file_count", len(data_entries)) + t.metric("delete_entry_count", delete_entry_count) + return [ + FileScanTask( + data_entry.data_file, + delete_files=delete_index.for_data_file( + data_entry.sequence_number or INITIAL_SEQUENCE_NUMBER, + data_entry.data_file, + partition_key=data_entry.data_file.partition, + ), + residual=residual_evaluators[data_entry.data_file.spec_id](data_entry.data_file).residual_for( + data_entry.data_file.partition + ), + ) + for data_entry in data_entries + ] def plan_files(self) -> Iterable[FileScanTask]: """Plans the relevant files by filtering on the PartitionSpecs. @@ -2184,9 +2200,15 @@ def plan_files(self) -> Iterable[FileScanTask]: Returns: List of FileScanTasks that contain both data and delete files. """ - if self._should_use_server_side_planning(): - return self._plan_files_server_side() - return self._plan_files_local() + with perf_timer("scan.plan_files") as t: + if self._should_use_server_side_planning(): + t.tag("planning_mode", "server_side") + result = list(self._plan_files_server_side()) + else: + t.tag("planning_mode", "local") + result = list(self._plan_files_local()) + t.metric("task_count", len(result)) + return result def to_arrow(self) -> pa.Table: """Read an Arrow table eagerly from this DataScan. diff --git a/tests/test_observability.py b/tests/test_observability.py new file mode 100644 index 00000000..bd88b8b3 --- /dev/null +++ b/tests/test_observability.py @@ -0,0 +1,279 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +import logging +from collections.abc import Iterator +from unittest.mock import patch + +import pytest + +from pyiceberg.observability import ( + CompositeObserver, + LoggingPerfObserver, + NullPerfObserver, + PerfEvent, + PerfObserver, + _PerfTimerContext, + get_observer, + perf_timer, + set_observer, + timed, +) + + +class CollectingObserver: + """Test observer that collects all emitted events.""" + + def __init__(self) -> None: + self.events: list[PerfEvent] = [] + + def emit(self, event: PerfEvent) -> None: + self.events.append(event) + + +@pytest.fixture(autouse=True) +def _reset_observer() -> Iterator[None]: + """Reset the global observer to NullPerfObserver after each test.""" + set_observer(NullPerfObserver()) + yield + set_observer(NullPerfObserver()) + + +class TestPerfEvent: + def test_perf_event_creation(self) -> None: + event = PerfEvent(operation="test.op", duration_ms=42.5, tags={"db": "prod"}, metrics={"rows": 100}) + assert event.operation == "test.op" + assert event.duration_ms == 42.5 + assert event.tags == {"db": "prod"} + assert event.metrics == {"rows": 100} + + def test_perf_event_default_tags_and_metrics(self) -> None: + event = PerfEvent(operation="test.op", duration_ms=0.0) + assert event.tags == {} + assert event.metrics == {} + + def test_tags_and_metrics_are_separate(self) -> None: + event = PerfEvent(operation="test.op", duration_ms=1.0, tags={"db": "prod"}, metrics={"rows": 42}) + assert "db" not in event.metrics + assert "rows" not in event.tags + + +class TestNullPerfObserver: + def 
test_emit_does_nothing(self) -> None: + observer = NullPerfObserver() + event = PerfEvent(operation="test.op", duration_ms=1.0) + observer.emit(event) # should not raise + + def test_satisfies_protocol(self) -> None: + observer = NullPerfObserver() + assert isinstance(observer, PerfObserver) + + +class TestLoggingPerfObserver: + def test_satisfies_protocol(self) -> None: + observer = LoggingPerfObserver() + assert isinstance(observer, PerfObserver) + + def test_emits_structured_log_line(self, caplog: pytest.LogCaptureFixture) -> None: + observer = LoggingPerfObserver() + event = PerfEvent(operation="test.op", duration_ms=123.456, tags={"table": "db.tbl"}, metrics={"rows": 42}) + with caplog.at_level(logging.DEBUG, logger="pyiceberg.perf"): + observer.emit(event) + assert len(caplog.records) == 1 + msg = caplog.records[0].message + assert "operation=test.op" in msg + assert "duration_ms=123.456" in msg + assert "table=db.tbl" in msg + assert "rows=42" in msg + + def test_emits_at_debug_level(self, caplog: pytest.LogCaptureFixture) -> None: + observer = LoggingPerfObserver() + event = PerfEvent(operation="test.op", duration_ms=1.0) + with caplog.at_level(logging.DEBUG, logger="pyiceberg.perf"): + observer.emit(event) + assert caplog.records[0].levelno == logging.DEBUG + + def test_no_log_at_info_level(self, caplog: pytest.LogCaptureFixture) -> None: + observer = LoggingPerfObserver() + event = PerfEvent(operation="test.op", duration_ms=1.0) + with caplog.at_level(logging.INFO, logger="pyiceberg.perf"): + observer.emit(event) + assert len(caplog.records) == 0 + + +class TestCompositeObserver: + def test_fans_out_to_all_observers(self) -> None: + a = CollectingObserver() + b = CollectingObserver() + composite = CompositeObserver(a, b) + event = PerfEvent(operation="test.op", duration_ms=1.0) + composite.emit(event) + assert len(a.events) == 1 + assert len(b.events) == 1 + assert a.events[0] is event + assert b.events[0] is event + + def 
test_empty_composite_does_not_raise(self) -> None: + composite = CompositeObserver() + composite.emit(PerfEvent(operation="test.op", duration_ms=1.0)) + + def test_satisfies_protocol(self) -> None: + assert isinstance(CompositeObserver(), PerfObserver) + + +class TestSetGetObserver: + def test_default_is_null_observer(self) -> None: + assert isinstance(get_observer(), NullPerfObserver) + + def test_set_and_get(self) -> None: + collecting = CollectingObserver() + set_observer(collecting) + assert get_observer() is collecting + + def test_set_back_to_null(self) -> None: + set_observer(CollectingObserver()) + set_observer(NullPerfObserver()) + assert isinstance(get_observer(), NullPerfObserver) + + +class TestPerfTimer: + def test_measures_duration(self) -> None: + collecting = CollectingObserver() + set_observer(collecting) + with perf_timer("test.op"): + pass + assert len(collecting.events) == 1 + assert collecting.events[0].operation == "test.op" + assert collecting.events[0].duration_ms >= 0 + + def test_captures_initial_tags(self) -> None: + collecting = CollectingObserver() + set_observer(collecting) + with perf_timer("test.op", table="db.tbl", db="prod"): + pass + tags = collecting.events[0].tags + assert tags["table"] == "db.tbl" + assert tags["db"] == "prod" + + def test_tag_and_metric_in_body(self) -> None: + collecting = CollectingObserver() + set_observer(collecting) + with perf_timer("test.op") as t: + t.tag("db", "prod") + t.metric("rows", 100) + t.metric("bytes", 2048) + event = collecting.events[0] + assert event.tags["db"] == "prod" + assert event.metrics["rows"] == 100 + assert event.metrics["bytes"] == 2048 + + def test_tag_overwrites_initial(self) -> None: + collecting = CollectingObserver() + set_observer(collecting) + with perf_timer("test.op", status="pending") as t: + t.tag("status", "done") + assert collecting.events[0].tags["status"] == "done" + + def test_emits_on_exception(self) -> None: + collecting = CollectingObserver() + 
set_observer(collecting) + try: + with perf_timer("test.op"): + raise ValueError("boom") + except ValueError: + pass + assert len(collecting.events) == 1 + assert collecting.events[0].operation == "test.op" + assert collecting.events[0].duration_ms >= 0 + + def test_null_observer_skips_timing(self) -> None: + set_observer(NullPerfObserver()) + with patch("pyiceberg.observability.time.monotonic") as mock_monotonic: + with perf_timer("test.op"): + pass + mock_monotonic.assert_not_called() + + def test_active_observer_calls_timing(self) -> None: + set_observer(CollectingObserver()) + with patch("pyiceberg.observability.time.monotonic", side_effect=[1.0, 2.0]) as mock_monotonic: + with perf_timer("test.op"): + pass + assert mock_monotonic.call_count == 2 + + def test_yields_perf_timer_context(self) -> None: + collecting = CollectingObserver() + set_observer(collecting) + with perf_timer("test.op") as t: + assert isinstance(t, _PerfTimerContext) + + +class TestTimedDecorator: + def test_basic_function(self) -> None: + collecting = CollectingObserver() + set_observer(collecting) + + @timed("test.add") + def add(a: int, b: int) -> int: + return a + b + + result = add(2, 3) + assert result == 5 + assert len(collecting.events) == 1 + assert collecting.events[0].operation == "test.add" + assert collecting.events[0].duration_ms >= 0 + + def test_with_decorator_tags(self) -> None: + collecting = CollectingObserver() + set_observer(collecting) + + @timed("test.op", source="test") + def noop() -> None: + pass + + noop() + assert collecting.events[0].tags["source"] == "test" + + def test_preserves_function_name(self) -> None: + @timed("test.op") + def my_function() -> None: + pass + + assert my_function.__name__ == "my_function" + + def test_null_observer_skips_timing(self) -> None: + set_observer(NullPerfObserver()) + + @timed("test.op") + def noop() -> None: + pass + + with patch("pyiceberg.observability.time.monotonic") as mock_monotonic: + noop() + 
mock_monotonic.assert_not_called() + + def test_propagates_exception(self) -> None: + collecting = CollectingObserver() + set_observer(collecting) + + @timed("test.op") + def fail() -> None: + raise RuntimeError("fail") + + with pytest.raises(RuntimeError, match="fail"): + fail() + assert len(collecting.events) == 1