Add read_from_parquet operation (#10)
simw authored Nov 14, 2023
1 parent 3c9980b commit ac36067
Showing 4 changed files with 124 additions and 4 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -5,7 +5,7 @@ build-backend = "poetry.core.masonry.api"

[tool.poetry]
name = "pipedata"
version = "0.2"
version = "0.2.1"
description = "Framework for building pipelines for data processing"
authors = ["Simon Wicks <[email protected]>"]
readme = "README.md"
2 changes: 1 addition & 1 deletion src/pipedata/__init__.py
@@ -1,4 +1,4 @@
__version__ = "0.2"
__version__ = "0.2.1"

__all__ = [
"__version__",
47 changes: 46 additions & 1 deletion src/pipedata/ops/files.py
@@ -1,13 +1,29 @@
import logging
import zipfile
from dataclasses import dataclass
from typing import IO, Iterator
from typing import (
    IO,
    Any,
    Callable,
    Dict,
    Iterator,
    List,
    Literal,
    Optional,
    Union,
)

import fsspec # type: ignore
import pyarrow as pa # type: ignore
import pyarrow.dataset as pa_dataset # type: ignore

logger = logging.getLogger(__name__)


class FilesReaderError(Exception):
    pass


@dataclass
class OpenedFileRef:
    name: str
@@ -30,3 +46,32 @@ def zipped_files(file_refs: Iterator[str]) -> Iterator[OpenedFileRef]:
                name=name,
                contents=inner_file,
            )


def read_from_parquet(
    columns: Optional[Union[List[str], Dict[str, Any]]] = None,
    return_as: Literal["recordbatch", "record"] = "record",
    batch_size: Optional[int] = 100_000,
) -> Callable[[Iterator[str]], Iterator[Union[Dict[str, Any], pa.RecordBatch]]]:
    logger.info(f"Initializing parquet reader with {batch_size=}")

    if return_as not in ("recordbatch", "record"):
        raise FilesReaderError(f"Unknown return_as value {return_as}")

    def parquet_batch_reader(
        file_refs: Iterator[str],
    ) -> Iterator[Union[Dict[str, Any], pa.RecordBatch]]:
        for file_ref in file_refs:
            logger.info(f"Reading parquet file {file_ref}")
            ds = pa_dataset.dataset(file_ref, format="parquet")
            for batch in ds.to_batches(columns=columns, batch_size=batch_size):
                if return_as == "recordbatch":
                    yield batch
                elif return_as == "record":
                    yield from batch.to_pylist()
                else:
                    raise FilesReaderError(
                        f"Unknown return_as value {return_as}"
                    )  # pragma: no cover

    return parquet_batch_reader
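
For orientation, here is a minimal usage sketch of the new operation, mirroring the tests below: read_from_parquet builds a callable that a stream applies via flat_map. The file path is hypothetical.

# Usage sketch (assumed local path; the tests below exercise the same API).
from pipedata.core import StreamStart
from pipedata.ops.files import read_from_parquet

reader = read_from_parquet(columns=["a"], batch_size=10_000)
records = StreamStart(["data/example.parquet"]).flat_map(reader).to_list()
# records is a list of dicts, one per row, e.g. [{"a": 1}, {"a": 2}, ...]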
77 changes: 76 additions & 1 deletion tests/ops/test_files.py
@@ -2,8 +2,11 @@
import zipfile
from pathlib import Path

import pyarrow as pa  # type: ignore
import pyarrow.parquet  # type: ignore # noqa: F401 - explicit submodule import so pa.parquet resolves
import pytest

from pipedata.core import StreamStart
from pipedata.ops.files import zipped_files
from pipedata.ops.files import FilesReaderError, read_from_parquet, zipped_files


def test_zipped_files() -> None:
@@ -44,3 +47,75 @@ def test_zipped_file_contents() -> None:
"Hello, world 3!",
]
assert result == expected


def test_parquet_reading_simple() -> None:
    with tempfile.TemporaryDirectory() as temp_dir:
        parquet_path = Path(temp_dir) / "test.parquet"

        table = pa.Table.from_pydict(
            {
                "a": [1, 2, 3],
                "b": [4, 5, 6],
            }
        )
        pa.parquet.write_table(table, parquet_path)

        parquet_reader = read_from_parquet()
        result = StreamStart([str(parquet_path)]).flat_map(parquet_reader).to_list()

        expected = [
            {"a": 1, "b": 4},
            {"a": 2, "b": 5},
            {"a": 3, "b": 6},
        ]
        assert result == expected


def test_parquet_reading_with_columns() -> None:
    with tempfile.TemporaryDirectory() as temp_dir:
        parquet_path = Path(temp_dir) / "test.parquet"

        table = pa.Table.from_pydict(
            {
                "a": [1, 2, 3],
                "b": [4, 5, 6],
            }
        )
        pa.parquet.write_table(table, parquet_path)

        parquet_reader = read_from_parquet(columns=["a"])
        result = StreamStart([str(parquet_path)]).flat_map(parquet_reader).to_list()

        expected = [
            {"a": 1},
            {"a": 2},
            {"a": 3},
        ]
        assert result == expected


def test_parquet_reading_record_batch() -> None:
    with tempfile.TemporaryDirectory() as temp_dir:
        parquet_path = Path(temp_dir) / "test.parquet"

        table = pa.Table.from_pydict(
            {
                "a": [1, 2, 3],
                "b": [4, 5, 6],
            }
        )
        pa.parquet.write_table(table, parquet_path)

        parquet_reader = read_from_parquet(columns=["a"], return_as="recordbatch")
        result = StreamStart([str(parquet_path)]).flat_map(parquet_reader).to_list()

        schema = pa.schema([("a", pa.int64())])
        a_array = pa.array([1, 2, 3])
        rb = pa.RecordBatch.from_arrays([a_array], schema=schema)
        assert result == [rb]


def test_parquet_reading_invalid_return_as() -> None:
    with pytest.raises(FilesReaderError):
        read_from_parquet(columns=["a"], return_as="unknown")  # type: ignore
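
The recordbatch mode avoids building per-row dictionaries. As a hedged sketch (hypothetical file path; pa.Table.from_batches is standard pyarrow), the yielded batches can be recombined into a single table:

# Sketch: collect RecordBatches from the stream and rebuild one table.
import pyarrow as pa  # type: ignore

from pipedata.core import StreamStart
from pipedata.ops.files import read_from_parquet

reader = read_from_parquet(return_as="recordbatch", batch_size=50_000)
batches = StreamStart(["data/example.parquet"]).flat_map(reader).to_list()
table = pa.Table.from_batches(batches)  # schema inferred from the first batch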
