Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
219 changes: 135 additions & 84 deletions pyiceberg/io/pyarrow.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,7 @@
DataFileContent,
FileFormat,
)
from pyiceberg.observability import perf_timer
from pyiceberg.partitioning import PartitionField, PartitionFieldValue, PartitionKey, PartitionSpec, partition_record_value
from pyiceberg.schema import (
PartnerAccessor,
Expand Down Expand Up @@ -1586,102 +1587,123 @@ def _task_to_record_batches(
downcast_ns_timestamp_to_us: bool | None = None,
batch_size: int | None = None,
) -> Iterator[pa.RecordBatch]:
arrow_format = _get_file_format(task.file.file_format, pre_buffer=True, buffer_size=(ONE_MEGABYTE * 8))
with io.new_input(task.file.file_path).open() as fin:
fragment = arrow_format.make_fragment(fin)
physical_schema = fragment.physical_schema

# For V1 and V2, we only support Timestamp 'us' in Iceberg Schema,
# therefore it is reasonable to always cast 'ns' timestamp to 'us' on read.
# For V3 this has to set explicitly to avoid nanosecond timestamp to be down-casted by default
downcast_ns_timestamp_to_us = (
downcast_ns_timestamp_to_us if downcast_ns_timestamp_to_us is not None else format_version <= 2
)
file_schema = pyarrow_to_schema(
physical_schema, name_mapping, downcast_ns_timestamp_to_us=downcast_ns_timestamp_to_us, format_version=format_version
)

# Apply column projection rules: https://iceberg.apache.org/spec/#column-projection
projected_missing_fields = _get_column_projection_values(
task.file, projected_schema, table_schema, partition_spec, file_schema.field_ids
)
with perf_timer(
"arrow.read_file",
file_path=task.file.file_path,
file_format=str(task.file.file_format),
) as t:
t.metric("file_size_bytes", task.file.file_size_in_bytes)
arrow_format = _get_file_format(task.file.file_format, pre_buffer=True, buffer_size=(ONE_MEGABYTE * 8))
batch_count = 0
row_count = 0
with io.new_input(task.file.file_path).open() as fin:
fragment = arrow_format.make_fragment(fin)
physical_schema = fragment.physical_schema

# For V1 and V2, we only support Timestamp 'us' in Iceberg Schema,
# therefore it is reasonable to always cast 'ns' timestamp to 'us' on read.
# For V3 this has to set explicitly to avoid nanosecond timestamp to be down-casted by default
downcast_ns_timestamp_to_us = (
downcast_ns_timestamp_to_us if downcast_ns_timestamp_to_us is not None else format_version <= 2
)
file_schema = pyarrow_to_schema(
physical_schema,
name_mapping,
downcast_ns_timestamp_to_us=downcast_ns_timestamp_to_us,
format_version=format_version,
)

pyarrow_filter = None
if bound_row_filter is not AlwaysTrue():
translated_row_filter = translate_column_names(
bound_row_filter, file_schema, case_sensitive=case_sensitive, projected_field_values=projected_missing_fields
# Apply column projection rules: https://iceberg.apache.org/spec/#column-projection
projected_missing_fields = _get_column_projection_values(
task.file, projected_schema, table_schema, partition_spec, file_schema.field_ids
)
bound_file_filter = bind(file_schema, translated_row_filter, case_sensitive=case_sensitive)
pyarrow_filter = expression_to_pyarrow(bound_file_filter, file_schema)

file_project_schema = prune_columns(file_schema, projected_field_ids, select_full_types=False)

scanner_kwargs: dict[str, Any] = {
"fragment": fragment,
"schema": physical_schema,
# This will push down the query to Arrow.
# But in case there are positional deletes, we have to apply them first
"filter": pyarrow_filter if not positional_deletes else None,
"columns": [col.name for col in file_project_schema.columns],
}
if batch_size is not None:
scanner_kwargs["batch_size"] = batch_size

fragment_scanner = ds.Scanner.from_fragment(**scanner_kwargs)
pyarrow_filter = None
if bound_row_filter is not AlwaysTrue():
translated_row_filter = translate_column_names(
bound_row_filter,
file_schema,
case_sensitive=case_sensitive,
projected_field_values=projected_missing_fields,
)
bound_file_filter = bind(file_schema, translated_row_filter, case_sensitive=case_sensitive)
pyarrow_filter = expression_to_pyarrow(bound_file_filter, file_schema)

file_project_schema = prune_columns(file_schema, projected_field_ids, select_full_types=False)

scanner_kwargs: dict[str, Any] = {
"fragment": fragment,
"schema": physical_schema,
# This will push down the query to Arrow.
# But in case there are positional deletes, we have to apply them first
"filter": pyarrow_filter if not positional_deletes else None,
"columns": [col.name for col in file_project_schema.columns],
}
if batch_size is not None:
scanner_kwargs["batch_size"] = batch_size

next_index = 0
batches = fragment_scanner.to_batches()
for batch in batches:
next_index = next_index + len(batch)
current_index = next_index - len(batch)
current_batch = batch
fragment_scanner = ds.Scanner.from_fragment(**scanner_kwargs)

if positional_deletes:
# Create the mask of indices that we're interested in
indices = _combine_positional_deletes(positional_deletes, current_index, current_index + len(batch))
current_batch = current_batch.take(indices)
next_index = 0
batches = fragment_scanner.to_batches()
for batch in batches:
next_index = next_index + len(batch)
current_index = next_index - len(batch)
current_batch = batch

# skip empty batches
if current_batch.num_rows == 0:
continue
if positional_deletes:
# Create the mask of indices that we're interested in
indices = _combine_positional_deletes(positional_deletes, current_index, current_index + len(batch))
current_batch = current_batch.take(indices)

# Apply the user filter
if pyarrow_filter is not None:
# Temporary fix until PyArrow 21 is released ( https://github.com/apache/arrow/pull/46057 )
table = pa.Table.from_batches([current_batch])
table = table.filter(pyarrow_filter)
batch_count += 1
# skip empty batches
if table.num_rows == 0:
if current_batch.num_rows == 0:
continue

current_batch = table.combine_chunks().to_batches()[0]
# Apply the user filter
if pyarrow_filter is not None:
# Temporary fix until PyArrow 21 is released ( https://github.com/apache/arrow/pull/46057 )
table = pa.Table.from_batches([current_batch])
table = table.filter(pyarrow_filter)
# skip empty batches
if table.num_rows == 0:
continue

yield _to_requested_schema(
projected_schema,
file_project_schema,
current_batch,
downcast_ns_timestamp_to_us=downcast_ns_timestamp_to_us,
projected_missing_fields=projected_missing_fields,
allow_timestamp_tz_mismatch=True,
)
current_batch = table.combine_chunks().to_batches()[0]

row_count += current_batch.num_rows
yield _to_requested_schema(
projected_schema,
file_project_schema,
current_batch,
downcast_ns_timestamp_to_us=downcast_ns_timestamp_to_us,
projected_missing_fields=projected_missing_fields,
allow_timestamp_tz_mismatch=True,
)
t.metric("batch_count", batch_count)
t.metric("row_count", row_count)


def _read_all_delete_files(io: FileIO, tasks: Iterable[FileScanTask]) -> dict[str, list[ChunkedArray]]:
    """Read every delete file referenced by the scan tasks, grouped per data file.

    Args:
        io: FileIO used to open each delete file.
        tasks: Scan tasks whose ``delete_files`` should be fetched.

    Returns:
        Mapping of data-file path to the list of delete arrays that apply to it.
    """
    with perf_timer("arrow.read_delete_files") as t:
        deletes_per_file: dict[str, list[ChunkedArray]] = {}
        # De-duplicate first: the same delete file may be referenced by many tasks.
        unique_deletes = set(itertools.chain.from_iterable([task.delete_files for task in tasks]))
        t.metric("unique_delete_files", len(unique_deletes))
        if len(unique_deletes) > 0:
            # Fetch delete files concurrently; each worker returns a mapping of
            # data-file path -> delete array, which we merge below.
            executor = ExecutorFactory.get_or_create()
            deletes_per_files: Iterator[dict[str, ChunkedArray]] = executor.map(
                lambda args: _read_deletes(*args),
                [(io, delete_file) for delete_file in unique_deletes],
            )
            for delete in deletes_per_files:
                for file, arr in delete.items():
                    if file in deletes_per_file:
                        deletes_per_file[file].append(arr)
                    else:
                        deletes_per_file[file] = [arr]

        t.metric("data_files_with_deletes", len(deletes_per_file))
    return deletes_per_file


Expand Down Expand Up @@ -1887,12 +1909,41 @@ def to_record_batches(
if order.concurrent_streams < 1:
raise ValueError(f"concurrent_streams must be >= 1, got {order.concurrent_streams}")
return self._apply_limit(
self._iter_batches_arrival(
task_list, deletes_per_file, order.batch_size, order.concurrent_streams, order.max_buffered_batches
self._iter_batches_counted(
self._iter_batches_arrival(
task_list,
deletes_per_file,
order.batch_size,
order.concurrent_streams,
order.max_buffered_batches,
),
task_count=len(task_list),
)
)

return self._apply_limit(self._iter_batches_materialized(task_list, deletes_per_file))
return self._apply_limit(
self._iter_batches_counted(
self._iter_batches_materialized(task_list, deletes_per_file),
task_count=len(task_list),
)
)

@staticmethod
def _iter_batches_counted(
    inner: Iterator[pa.RecordBatch],
    task_count: int,
) -> Iterator[pa.RecordBatch]:
    """Wrap an inner batch iterator with aggregate perf_timer tracking.

    Args:
        inner: The batch iterator to delegate to.
        task_count: Number of file-scan tasks feeding *inner*; recorded as a metric.

    Yields:
        The batches from *inner*, unchanged.
    """
    with perf_timer("arrow.to_record_batches") as t:
        t.metric("task_count", task_count)
        batch_count = 0
        row_count = 0
        try:
            for batch in inner:
                batch_count += 1
                row_count += batch.num_rows
                yield batch
        finally:
            # Record the counts even when the consumer closes the generator
            # early (e.g. a row limit in _apply_limit raises GeneratorExit at
            # the yield); otherwise limited scans would silently lose metrics.
            t.metric("batch_count", batch_count)
            t.metric("row_count", row_count)

def _prepare_tasks_and_deletes(
self, tasks: Iterable[FileScanTask]
Expand Down
59 changes: 34 additions & 25 deletions pyiceberg/manifest.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
from pyiceberg.conversions import to_bytes
from pyiceberg.exceptions import ValidationError
from pyiceberg.io import FileIO, InputFile, OutputFile
from pyiceberg.observability import perf_timer
from pyiceberg.partitioning import PartitionSpec
from pyiceberg.schema import Schema
from pyiceberg.typedef import Record, TableVersion
Expand Down Expand Up @@ -869,18 +870,21 @@ def fetch_manifest_entry(self, io: FileIO, discard_deleted: bool = True) -> list
Returns:
An Iterator of manifest entries.
"""
input_file = io.new_input(self.manifest_path)
with AvroFile[ManifestEntry](
input_file,
MANIFEST_ENTRY_SCHEMAS[DEFAULT_READ_VERSION],
read_types={-1: ManifestEntry, 2: DataFile},
read_enums={0: ManifestEntryStatus, 101: FileFormat, 134: DataFileContent},
) as reader:
return [
_inherit_from_manifest(entry, self)
for entry in reader
if not discard_deleted or entry.status != ManifestEntryStatus.DELETED
]
with perf_timer("manifest.fetch_entries", manifest_path=self.manifest_path) as t:
input_file = io.new_input(self.manifest_path)
with AvroFile[ManifestEntry](
input_file,
MANIFEST_ENTRY_SCHEMAS[DEFAULT_READ_VERSION],
read_types={-1: ManifestEntry, 2: DataFile},
read_enums={0: ManifestEntryStatus, 101: FileFormat, 134: DataFileContent},
) as reader:
result = [
_inherit_from_manifest(entry, self)
for entry in reader
if not discard_deleted or entry.status != ManifestEntryStatus.DELETED
]
t.metric("entry_count", len(result))
return result

def __eq__(self, other: Any) -> bool:
"""Return the equality of two instances of the ManifestFile class."""
Expand Down Expand Up @@ -924,19 +928,24 @@ def _manifests(io: FileIO, manifest_list: str) -> tuple[ManifestFile, ...]:
Returns:
A tuple of ManifestFile objects.
"""
file = io.new_input(manifest_list)
manifest_files = list(read_manifest_list(file))

result = []
with _manifest_cache_lock:
for manifest_file in manifest_files:
manifest_path = manifest_file.manifest_path
if manifest_path in _manifest_cache:
result.append(_manifest_cache[manifest_path])
else:
_manifest_cache[manifest_path] = manifest_file
result.append(manifest_file)

with perf_timer("manifest.read_list") as t:
file = io.new_input(manifest_list)
manifest_files = list(read_manifest_list(file))

result = []
cache_hits = 0
with _manifest_cache_lock:
for manifest_file in manifest_files:
manifest_path = manifest_file.manifest_path
if manifest_path in _manifest_cache:
result.append(_manifest_cache[manifest_path])
cache_hits += 1
else:
_manifest_cache[manifest_path] = manifest_file
result.append(manifest_file)

t.metric("manifest_count", len(result))
t.metric("cache_hits", cache_hits)
return tuple(result)


Expand Down
Loading
Loading