Source code for oxbow._pyarrow

"""
Utilities for exposing custom Arrow RecordBatch streams as PyArrow fragments and datasets.

Classes
-------
BatchReaderFragment
    A fragment that emits RecordBatches from a reproducible source.
BatchReaderDataset
    A PyArrow Dataset composed of one or more BatchReaderFragments.

Constants
---------
DEFAULT_BATCH_SIZE : int
    The default batch size for reading record batches.
DEFAULT_BATCH_READAHEAD : int
    The default number of batches to read ahead.
DEFAULT_FRAGMENT_READAHEAD : int
    The default number of fragments to read ahead.

Type Aliases
------------
RecordBatchIter : Union[pyarrow.RecordBatchReader, Iterator[pyarrow.RecordBatch]]
    A type alias for a PyArrow RecordBatchReader or an iterator of RecordBatches.
"""

from __future__ import annotations

from functools import partial
from itertools import chain
from typing import Any, Callable, Final, Iterator, Union

import pyarrow as pa
import pyarrow.dataset as ds
from pyarrow.dataset import Dataset, Scanner

DEFAULT_BATCH_SIZE: Final = 2**17
DEFAULT_BATCH_READAHEAD: Final = 16
DEFAULT_FRAGMENT_READAHEAD: Final = 4

RecordBatchIter = Union[pa.RecordBatchReader, Iterator[pa.RecordBatch]]


class BatchReaderFragment:
    """
    A Fragment that emits RecordBatches from a reproducible source.

    To provide stateless replay, a new record batch iterator over the same
    records is constructed whenever a scanner is requested.

    Parameters
    ----------
    make_batchreader : Callable[[list[str] | None, int], RecordBatchIter]
        A function that recreates a specific stream of record batches.
    schema : pyarrow.Schema
        The schema of the RecordBatches.
    batch_size : int, optional
        The maximum row count for scanned record batches.
    partition_expression : pyarrow.dataset.Expression, optional
        A partition expression for the fragment, by default None.

    Attributes
    ----------
    schema : pyarrow.Schema
        The schema of the RecordBatches in the fragment.
    partition_expression : pyarrow.dataset.Expression
        An expression that evaluates to true for all data viewed by this
        fragment.
    """

[docs] def __init__( self, make_batchreader: Callable[[list[str] | None, int], RecordBatchIter], schema: pa.Schema, batch_size: int = DEFAULT_BATCH_SIZE, partition_expression: ds.Expression = None, tokenize: Any = None, ): """ Create a BatchReaderFragment from a BatchReader factory function. Parameters ---------- make_batchreader : Callable[[list[str] | None, int], RecordBatchIter] A function that recreates a specific stream of record batches. schema : pyarrow.Schema The schema of the RecordBatches. batch_size : int, optional The maximum row count for scanned record batches. partition_expression : pyarrow.dataset.Expression, optional A partition expression for the fragment, by default None. Notes ----- The `make_batchreader` function should accept the following arguments and return a pyarrow.RecordBatchReader: * columns: The columns to project. * batch_size: The maximum number of rows per batch. """ self._make_batchreader = make_batchreader self._schema = schema self._batch_size = batch_size if partition_expression is not None: self._partition_expression = partition_expression else: self._partition_expression = ds.scalar(True) self._tokenize = tokenize
@property def physical_schema(self): raise NotImplementedError @property def schema(self) -> pa.Schema: """ The schema of the RecordBatches in the fragment. Returns ------- pyarrow.Schema """ return self._schema @property def partition_expression(self) -> ds.Expression: """ An expression that evaluates to true for all data viewed by this fragment. Returns ------- pyarrow.dataset.Expression """ return self._partition_expression
[docs] def scanner( self, schema: pa.Schema | None = None, columns: list[str] | None = None, filter: ds.Expression | None = None, batch_size: int | None = None, batch_readahead: int = DEFAULT_BATCH_READAHEAD, fragment_readahead: int = DEFAULT_FRAGMENT_READAHEAD, fragment_scan_options: ds.FragmentScanOptions | None = None, use_threads: bool = True, memory_pool: ds.MemoryPool | None = None, **kwargs, ) -> ds.Scanner: """ Build a scan operation against the fragment. Parameters ---------- schema : pyarrow.Schema, optional The schema to use for scanning. If not specified, uses the Fragment's physical schema. columns : list[str], optional Names of columns to project. By default, all available columns are projected. filter : pyarrow.dataset.Expression, optional A filter expression. Scan will return only the rows matching the filter. batch_size : int, optional The maximum row count for scanned record batches. batch_readahead : int, optional The number of batches to read ahead in a file. fragment_readahead : int, optional The number of fragments/files to read ahead. fragment_scan_options : pyarrow.dataset.FragmentScanOptions, optional Options specific to a particular scan and fragment type, by default None. use_threads : bool, optional If enabled, maximum parallelism will be used, by default True. memory_pool : pyarrow.dataset.MemoryPool, optional For memory allocations, if required. By default, uses the default pool. Returns ------- pyarrow.dataset.Scanner A scanner object for the fragment. """ schema = schema or self._schema batch_size = batch_size or self._batch_size reader = self._make_batchreader(columns, batch_size) # Wrap to pyarrow if the reader supports the Arrow PyCapsule protocol if not isinstance(reader, pa.RecordBatchReader) and hasattr( reader, "__arrow_c_stream__" ): reader = pa.RecordBatchReader.from_stream(reader) # Scanner.from_batches treats RecordBatchReader differently than an # opaque iterator, so we unwrap it for consistency def _iterate(reader): for batch in reader: yield batch return Scanner.from_batches( source=_iterate(reader), schema=schema, columns=columns, filter=filter, batch_size=batch_size, batch_readahead=batch_readahead, fragment_readahead=fragment_readahead, fragment_scan_options=fragment_scan_options, use_threads=use_threads, memory_pool=memory_pool, **kwargs, )
[docs] def to_batches( self, schema: pa.Schema | None = None, columns: list[str] | None = None, filter: ds.Expression | None = None, batch_size: int | None = None, batch_readahead: int = DEFAULT_BATCH_READAHEAD, fragment_readahead: int = DEFAULT_FRAGMENT_READAHEAD, fragment_scan_options: ds.FragmentScanOptions | None = None, use_threads: bool = True, memory_pool: ds.MemoryPool | None = None, **kwargs, ) -> Iterator[pa.RecordBatch]: """ Scan and read the fragment as materialized record batches. Projections and filters are applied if specified. Parameters ---------- schema : pyarrow.Schema, optional The schema to use for scanning. If not specified, uses the Fragment's physical schema. columns : list[str], optional Names of columns to project. By default, all available columns are projected. filter : pyarrow.dataset.Expression, optional A filter expression. Scan will return only the rows matching the filter. batch_size : int, optional The maximum row count for scanned record batches. batch_readahead : int, optional The number of batches to read ahead in a file. fragment_readahead : int, optional The number of fragments/files to read ahead,. fragment_scan_options : pyarrow.dataset.FragmentScanOptions, optional Options specific to a particular scan and fragment type, by default None. use_threads : bool, optional If enabled, maximum parallelism will be used, by default True. memory_pool : pyarrow.dataset.MemoryPool, optional For memory allocations, if required. By default, uses the default pool. Returns ------- Iterator[pyarrow.RecordBatch] An iterator of record batches. """ return self.scanner( schema=schema or self.schema, columns=columns, filter=filter, batch_size=batch_size or self._batch_size, batch_readahead=batch_readahead, fragment_readahead=fragment_readahead, fragment_scan_options=fragment_scan_options, use_threads=use_threads, memory_pool=memory_pool, **kwargs, ).to_batches()
def to_table( self, schema: pa.Schema | None = None, columns: list[str] | None = None, filter: ds.Expression | None = None, batch_size: int | None = None, batch_readahead: int = DEFAULT_BATCH_READAHEAD, fragment_readahead: int = DEFAULT_FRAGMENT_READAHEAD, fragment_scan_options: ds.FragmentScanOptions | None = None, use_threads: bool = True, memory_pool: ds.MemoryPool | None = None, **kwargs, ) -> pa.Table: return self.scanner( schema=schema or self.schema, columns=columns, filter=filter, batch_size=batch_size or self._batch_size, batch_readahead=batch_readahead, fragment_readahead=fragment_readahead, fragment_scan_options=fragment_scan_options, use_threads=use_threads, memory_pool=memory_pool, **kwargs, ).to_table() def take( self, indices, columns: list[str] | None = None, filter: ds.Expression | None = None, batch_size: int | None = None, batch_readahead: int = DEFAULT_BATCH_READAHEAD, fragment_readahead: int = DEFAULT_FRAGMENT_READAHEAD, fragment_scan_options: ds.FragmentScanOptions | None = None, use_threads: bool = True, memory_pool: ds.MemoryPool | None = None, **kwargs, ) -> pa.Table: return self.scanner( columns=columns, filter=filter, batch_size=batch_size or self._batch_size, batch_readahead=batch_readahead, fragment_readahead=fragment_readahead, fragment_scan_options=fragment_scan_options, use_threads=use_threads, memory_pool=memory_pool, **kwargs, ).take(indices) def head( self, num_rows: int, columns: list[str] | None = None, filter: ds.Expression | None = None, batch_size: int | None = None, batch_readahead: int = DEFAULT_BATCH_READAHEAD, fragment_readahead: int = DEFAULT_FRAGMENT_READAHEAD, fragment_scan_options: ds.FragmentScanOptions | None = None, use_threads: bool = True, memory_pool: ds.MemoryPool | None = None, **kwargs, ) -> pa.Table: return self.scanner( columns=columns, filter=filter, batch_size=batch_size or self._batch_size, batch_readahead=batch_readahead, fragment_readahead=fragment_readahead, fragment_scan_options=fragment_scan_options, use_threads=use_threads, memory_pool=memory_pool, **kwargs, ).head(num_rows) def count_rows( self, filter: ds.Expression | None = None, batch_size: int | None = None, batch_readahead: int = DEFAULT_BATCH_READAHEAD, fragment_readahead: int = DEFAULT_FRAGMENT_READAHEAD, fragment_scan_options: ds.FragmentScanOptions | None = None, use_threads: bool = True, memory_pool: ds.MemoryPool | None = None, **kwargs, ) -> int: return self.scanner( filter=filter, batch_size=batch_size or self._batch_size, batch_readahead=batch_readahead, fragment_readahead=fragment_readahead, fragment_scan_options=fragment_scan_options, use_threads=use_threads, memory_pool=memory_pool, **kwargs, ).count_rows()
[docs] def iter_batches( self, columns: list[str] | None = None, batch_size: int | None = None ) -> Iterator[pa.RecordBatch]: """ Iterate over batches in the fragment. Parameters ---------- columns : list[str], optional Names of columns to project. By default, all available columns are projected. batch_size : int, optional The maximum row count for scanned record batches. Returns ------- Iterator[pyarrow.RecordBatch] An iterator of record batches. """ return self._make_batchreader(columns, batch_size or self._batch_size)
def __dask_tokenize__(self): """ Return a representation of the fragment for Dask to tokenize. Returns ------- tuple A tuple that fully and deterministically represents the fragment. Notes ----- https://docs.dask.org/en/stable/custom-collections.html#implementing-deterministic-hashing """ from dask.base import normalize_token if self._tokenize is None: return ( normalize_token(self.__class__), self._make_batchreader, self._schema, self._batch_size, self._partition_expression, ) else: return (normalize_token(self.__class__), *self._tokenize) class BatchReaderDataset(Dataset): """ A PyArrow Dataset composed of one or more BatchReaderFragments. Parameters ---------- fragments : list[BatchReaderFragment] The list of fragments. partition_expression : pyarrow.dataset.Expression, optional A partition expression for the dataset, by default None. Attributes ---------- schema : pyarrow.Schema The schema of the RecordBatches in the dataset. partition_expression : pyarrow.dataset.Expression An expression that evaluates to true for all data viewed by this dataset. """
[docs] def __init__( self, fragments: list[BatchReaderFragment], partition_expression: ds.Expression = None, ): """ Create a BatchReaderDataset from a list of BatchReaderFragments. Parameters ---------- fragments : list[BatchReaderFragment] The list of fragments. partition_expression : pyarrow.dataset.Expression, optional A partition expression for the dataset, by default None. """ self._fragments = fragments self._schema = self._fragments[0].schema self._batch_size = self._fragments[0]._batch_size self._scan_options = {} if partition_expression is not None: self._partition_expression = partition_expression else: self._partition_expression = ds.scalar(True)
@property def partition_expression(self) -> ds.Expression: """ Returns ------- pyarrow.dataset.Expression An expression that evaluates to true for all data viewed by this dataset. """ return self._partition_expression @property def schema(self) -> pa.Schema: """ Returns ------- pyarrow.Schema The schema of the RecordBatches in the dataset. """ return self._schema
[docs] def get_fragments( self, filter: ds.Expression | None = None ) -> Iterator[BatchReaderFragment]: """ Return an iterator over fragments. Parameters ---------- filter : pyarrow.dataset.Expression, optional An expression to filter fragments. Not yet implemented. Returns ------- Iterator[BatchReaderFragment] An iterator over fragments. Notes ----- ``filter`` here is meant to be applied at the fragment level via comparison with the partition_expression, not at the row level. """ if filter is None: for fragment in self._fragments: yield fragment else: # https://github.com/apache/arrow/blob/76fa19e61af25d124ec0af5e543110a4672088db/cpp/src/arrow/dataset/dataset.cc#L204C1-L210C1 raise NotImplementedError("Fragment-level filtering is not yet implemented")
[docs] def scanner( self, columns: list[str] | None = None, filter: ds.Expression | None = None, batch_size: int | None = None, batch_readahead: int = DEFAULT_BATCH_READAHEAD, fragment_readahead: int = DEFAULT_FRAGMENT_READAHEAD, fragment_scan_options: ds.FragmentScanOptions | None = None, use_threads: bool | None = True, memory_pool: ds.MemoryPool | None = None, **kwargs, ) -> ds.Scanner: """ Build a scan operation against the dataset. This scanner chains the record batches from all fragments together and applies column projection and row filtering. Parameters ---------- columns : list[str], optional Names of columns to project. By default, all available columns are projected. filter : pyarrow.dataset.Expression, optional A filter expression. Scan will return only the rows matching the filter. batch_size : int, optional The maximum row count for scanned record batches. batch_readahead : int, optional The number of batches to read ahead in a file. fragment_readahead : int, optional The number of fragments/files to read ahead. fragment_scan_options : pyarrow.dataset.FragmentScanOptions, optional Options specific to a particular scan and fragment type, by default None. use_threads : bool, optional If enabled, maximum parallelism will be used, by default True. memory_pool : pyarrow.dataset.MemoryPool, optional For memory allocations, if required. By default, uses the default pool. Returns ------- pyarrow.dataset.Scanner A scanner object for the dataset. """ # TODO: Prune fragments using their partition expressions. # Chain all the fragments' record batch iterators together; don't # apply any filter yet. No batches should get materialized. # Wrap each fragment's arro3 reader to pyarrow before chaining. batch_size = batch_size or self._batch_size def _pyarrow_batches(fragment): reader = fragment.iter_batches(columns=columns, batch_size=batch_size) if not isinstance(reader, pa.RecordBatchReader) and hasattr( reader, "__arrow_c_stream__" ): reader = pa.RecordBatchReader.from_stream(reader) yield from reader batch_iter = chain.from_iterable( _pyarrow_batches(fragment) for fragment in self._fragments ) # Apply the row filter via the scanner. return Scanner.from_batches( source=batch_iter, schema=self.schema, columns=columns, filter=filter, batch_size=batch_size, batch_readahead=batch_readahead, fragment_readahead=fragment_readahead, fragment_scan_options=fragment_scan_options, use_threads=use_threads, memory_pool=memory_pool, **kwargs, )
[docs] def to_batches( self, columns: list[str] | None = None, filter: ds.Expression | None = None, batch_size: int | None = None, batch_readahead: int = DEFAULT_BATCH_READAHEAD, fragment_readahead: int = DEFAULT_FRAGMENT_READAHEAD, fragment_scan_options: ds.FragmentScanOptions | None = None, use_threads: bool | None = True, memory_pool: ds.MemoryPool | None = None, **kwargs, ) -> Iterator[pa.RecordBatch]: """ Read the dataset as materialized record batches. Parameters ---------- columns : list[str], optional Names of columns to project. By default, all available columns are projected. filter : pyarrow.dataset.Expression, optional A filter expression. Scan will return only the rows matching the filter. batch_size : int, optional The maximum row count for scanned record batches. batch_readahead : int, optional The number of batches to read ahead in a file. fragment_readahead : int, optional The number of fragments/files to read ahead. fragment_scan_options : pyarrow.dataset.FragmentScanOptions, optional Options specific to a particular scan and fragment type, by default None. use_threads : bool, optional If enabled, maximum parallelism will be used, by default True. memory_pool : pyarrow.dataset.MemoryPool, optional For memory allocations, if required. By default, uses the default pool. Returns ------- Iterator[pyarrow.RecordBatch] An iterator of record batches. """ return self.scanner( columns=columns, filter=filter, batch_size=batch_size or self._batch_size, batch_readahead=batch_readahead, fragment_readahead=fragment_readahead, fragment_scan_options=fragment_scan_options, use_threads=use_threads, memory_pool=memory_pool, **kwargs, ).to_batches()
[docs] def iter_batches( self, columns: list[str] | None = None, batch_size: int | None = None ) -> Iterator[pa.RecordBatch]: """ Iterate over batches in the dataset. Parameters ---------- columns : list[str], optional Names of columns to project. By default, all available columns are projected. batch_size : int, optional The maximum row count for scanned record batches. Returns ------- Iterator[pyarrow.RecordBatch] An iterator of record batches. """ return chain.from_iterable( fragment.iter_batches( columns=columns, batch_size=batch_size or self._batch_size, ) for fragment in self._fragments )
[docs] def filter(self, expression: ds.Expression) -> "BatchReaderDataset": """ Apply a row-level filter expression and return a filtered dataset. Parameters ---------- expression : pyarrow.dataset.Expression The filter expression. Returns ------- BatchReaderDataset The filtered dataset. Notes ----- This performs a row-level filter, not a fragment-level filter. The new filter is applied lazily on the underlying record batches by using a wrapper factory function to make a new BatchReaderFragment. """ new_filter = expression current_filter = self._scan_options.get("filter") if current_filter is not None and new_filter is not None: new_filter = current_filter & new_filter new_fragments = [] def filter_batches(fragment, columns, batch_size): batches = fragment.to_batches( schema=self.schema, columns=columns, filter=new_filter, batch_size=batch_size, ) yield from batches for fragment in self._fragments: new_fragment = BatchReaderFragment( partial(filter_batches, fragment), fragment.schema, fragment._batch_size ) new_fragments.append(new_fragment) filtered_dataset = self.__class__(new_fragments) filtered_dataset._scan_options = dict(filter=new_filter) return filtered_dataset
[docs] def replace_schema(self, schema): """ Replace the schema of the dataset. Parameters ---------- schema : pyarrow.Schema The new schema. Raises ------ NotImplementedError This method is not yet implemented. """ raise NotImplementedError
[docs] def sort_by(self, sorting, **kwargs): """ Sort the dataset by the specified columns. Parameters ---------- sorting : list[str] The columns to sort by. Raises ------ NotImplementedError This method is not yet implemented. """ raise NotImplementedError
[docs] def join(self, *args, **kwargs): """ Perform a join operation on the dataset. Raises ------ NotImplementedError This method is not yet implemented. """ raise NotImplementedError
[docs] def join_asof(self, *args, **kwargs): """ Perform an as-of join operation on the dataset. Raises ------ NotImplementedError This method is not yet implemented. """ raise NotImplementedError