From e86c3e3d9cb97d0cbe9a9a50efc06e0dae15a700 Mon Sep 17 00:00:00 2001 From: nuno-faria Date: Mon, 5 May 2025 22:29:31 +0100 Subject: [PATCH] feat: Support Parquet writer options --- python/datafusion/__init__.py | 6 +- python/datafusion/dataframe.py | 247 +++++++++++++++------- python/tests/test_dataframe.py | 373 ++++++++++++++++++++++++++++++--- src/dataframe.rs | 154 ++++++++++---- src/lib.rs | 2 + 5 files changed, 623 insertions(+), 159 deletions(-) diff --git a/python/datafusion/__init__.py b/python/datafusion/__init__.py index 15ceefbdb..273abbadb 100644 --- a/python/datafusion/__init__.py +++ b/python/datafusion/__init__.py @@ -31,7 +31,7 @@ from . import functions, object_store, substrait, unparser # The following imports are okay to remain as opaque to the user. -from ._internal import Config +from ._internal import Config, ParquetWriterOptions from .catalog import Catalog, Database, Table from .common import ( DFSchema, @@ -42,7 +42,7 @@ SessionContext, SQLOptions, ) -from .dataframe import DataFrame +from .dataframe import DataFrame, ParquetColumnOptions from .expr import ( Expr, WindowFrame, @@ -66,6 +66,8 @@ "ExecutionPlan", "Expr", "LogicalPlan", + "ParquetColumnOptions", + "ParquetWriterOptions", "RecordBatch", "RecordBatchStream", "RuntimeEnvBuilder", diff --git a/python/datafusion/dataframe.py b/python/datafusion/dataframe.py index 26fe8f453..96f939e70 100644 --- a/python/datafusion/dataframe.py +++ b/python/datafusion/dataframe.py @@ -28,7 +28,6 @@ Iterable, Literal, Optional, - Union, overload, ) @@ -51,67 +50,58 @@ from datafusion._internal import DataFrame as DataFrameInternal from datafusion._internal import expr as expr_internal -from enum import Enum - +from datafusion._internal import ParquetColumnOptions as ParquetColumnOptionsInternal +from datafusion._internal import ParquetWriterOptions as ParquetWriterOptionsInternal from datafusion.expr import Expr, SortExpr, sort_or_default -# excerpt from deltalake -# https://github.com/apache/datafusion-python/pull/981#discussion_r1905619163 -class Compression(Enum): - """Enum representing the available compression types for Parquet files.""" - - UNCOMPRESSED = "uncompressed" - SNAPPY = "snappy" - GZIP = "gzip" - BROTLI = "brotli" - LZ4 = "lz4" - # lzo is not implemented yet - # https://github.com/apache/arrow-rs/issues/6970 - # LZO = "lzo" - ZSTD = "zstd" - LZ4_RAW = "lz4_raw" - - @classmethod - def from_str(cls: type[Compression], value: str) -> Compression: - """Convert a string to a Compression enum value. - - Args: - value: The string representation of the compression type. - - Returns: - The Compression enum lowercase value. - - Raises: - ValueError: If the string does not match any Compression enum value. - """ - try: - return cls(value.lower()) - except ValueError as err: - valid_values = str([item.value for item in Compression]) - error_msg = f""" - {value} is not a valid Compression. - Valid values are: {valid_values} - """ - raise ValueError(error_msg) from err - - def get_default_level(self) -> Optional[int]: - """Get the default compression level for the compression type. +class ParquetColumnOptions: + """Parquet options for individual columns. + + Contains the available options that can be applied for an individual Parquet column, + replacing the provided options in the `write_parquet`. + + Attributes: + encoding: Sets encoding for the column path. 
Valid values are: `plain`, + `plain_dictionary`, `rle`, `bit_packed`, `delta_binary_packed`, + `delta_length_byte_array`, `delta_byte_array`, `rle_dictionary`, and + `byte_stream_split`. These values are not case-sensitive. If `None`, uses + the default parquet options + dictionary_enabled: Sets if dictionary encoding is enabled for the column path. + If `None`, uses the default parquet options + compression: Sets default parquet compression codec for the column path. Valid + values are `uncompressed`, `snappy`, `gzip(level)`, `lzo`, `brotli(level)`, + `lz4`, `zstd(level)`, and `lz4_raw`. These values are not case-sensitive. If + `None`, uses the default parquet options. + statistics_enabled: Sets if statistics are enabled for the column Valid values + are: `none`, `chunk`, and `page` These values are not case sensitive. If + `None`, uses the default parquet options. + bloom_filter_enabled: Sets if bloom filter is enabled for the column path. If + `None`, uses the default parquet options. + bloom_filter_fpp: Sets bloom filter false positive probability for the column + path. If `None`, uses the default parquet options. + bloom_filter_ndv: Sets bloom filter number of distinct values. If `None`, uses + the default parquet options. + """ - Returns: - The default compression level for the compression type. - """ - # GZIP, BROTLI default values from deltalake repo - # https://github.com/apache/datafusion-python/pull/981#discussion_r1905619163 - # ZSTD default value from delta-rs - # https://github.com/apache/datafusion-python/pull/981#discussion_r1904789223 - if self == Compression.GZIP: - return 6 - if self == Compression.BROTLI: - return 1 - if self == Compression.ZSTD: - return 4 - return None + def __init__( + self, + encoding: Optional[str] = None, + dictionary_enabled: Optional[bool] = None, + compression: Optional[str] = None, + statistics_enabled: Optional[str] = None, + bloom_filter_enabled: Optional[bool] = None, + bloom_filter_fpp: Optional[float] = None, + bloom_filter_ndv: Optional[int] = None, + ) -> None: + """Initialize the ParquetColumnOptions.""" + self.encoding = encoding + self.dictionary_enabled = dictionary_enabled + self.compression = compression + self.statistics_enabled = statistics_enabled + self.bloom_filter_enabled = bloom_filter_enabled + self.bloom_filter_fpp = bloom_filter_fpp + self.bloom_filter_ndv = bloom_filter_ndv class DataFrame: @@ -704,38 +694,135 @@ def write_csv(self, path: str | pathlib.Path, with_header: bool = False) -> None def write_parquet( self, path: str | pathlib.Path, - compression: Union[str, Compression] = Compression.ZSTD, - compression_level: int | None = None, + data_pagesize_limit: int = 1024 * 1024, + write_batch_size: int = 1024, + writer_version: str = "1.0", + skip_arrow_metadata: bool = False, + compression: Optional[str] = "zstd(3)", + dictionary_enabled: Optional[bool] = True, + dictionary_page_size_limit: int = 1024 * 1024, + statistics_enabled: Optional[str] = "page", + max_row_group_size: int = 1024 * 1024, + created_by: str = "datafusion-python", + column_index_truncate_length: Optional[int] = 64, + statistics_truncate_length: Optional[int] = None, + data_page_row_count_limit: int = 20_000, + encoding: Optional[str] = None, + bloom_filter_on_write: bool = False, + bloom_filter_fpp: Optional[float] = None, + bloom_filter_ndv: Optional[int] = None, + allow_single_file_parallelism: bool = True, + maximum_parallel_row_group_writers: int = 1, + maximum_buffered_record_batches_per_stream: int = 2, + 
column_specific_options: Optional[dict[str, ParquetColumnOptions]] = None, ) -> None: """Execute the :py:class:`DataFrame` and write the results to a Parquet file. Args: path: Path of the Parquet file to write. - compression: Compression type to use. Default is "ZSTD". - Available compression types are: + data_pagesize_limit: Sets best effort maximum size of data page in bytes. + write_batch_size: Sets write_batch_size in bytes. + writer_version: Sets parquet writer version. Valid values are `1.0` and + `2.0`. + skip_arrow_metadata: Skip encoding the embedded arrow metadata in the + KV_meta. + compression: Compression type to use. Default is "zstd(3)". + Available compression types are - "uncompressed": No compression. - "snappy": Snappy compression. - - "gzip": Gzip compression. - - "brotli": Brotli compression. + - "gzip(n)": Gzip compression with level n. + - "brotli(n)": Brotli compression with level n. - "lz4": LZ4 compression. - "lz4_raw": LZ4_RAW compression. - - "zstd": Zstandard compression. - Note: LZO is not yet implemented in arrow-rs and is therefore excluded. - compression_level: Compression level to use. For ZSTD, the - recommended range is 1 to 22, with the default being 4. Higher levels - provide better compression but slower speed. - """ - # Convert string to Compression enum if necessary - if isinstance(compression, str): - compression = Compression.from_str(compression) - - if ( - compression in {Compression.GZIP, Compression.BROTLI, Compression.ZSTD} - and compression_level is None - ): - compression_level = compression.get_default_level() + - "zstd(n)": Zstandard compression with level n. + dictionary_enabled: Sets if dictionary encoding is enabled. If None, uses + the default parquet writer setting. + dictionary_page_size_limit: Sets best effort maximum dictionary page size, + in bytes. + statistics_enabled: Sets if statistics are enabled for any column Valid + values are `none`, `chunk`, and `page`. If None, uses the default + parquet writer setting. + max_row_group_size: Target maximum number of rows in each row group + (defaults to 1M rows). Writing larger row groups requires more memory to + write, but can get better compression and be faster to read. + created_by: Sets "created by" property. + column_index_truncate_length: Sets column index truncate length. + statistics_truncate_length: Sets statistics truncate length. If None, uses + the default parquet writer setting. + data_page_row_count_limit: Sets best effort maximum number of rows in a data + page. + encoding: Sets default encoding for any column. Valid values are `plain`, + `plain_dictionary`, `rle`, `bit_packed`, `delta_binary_packed`, + `delta_length_byte_array`, `delta_byte_array`, `rle_dictionary`, and + `byte_stream_split`. If None, uses the default parquet writer setting. + bloom_filter_on_write: Write bloom filters for all columns when creating + parquet files. + bloom_filter_fpp: Sets bloom filter false positive probability. If None, + uses the default parquet writer setting + bloom_filter_ndv: Sets bloom filter number of distinct values. If None, uses + the default parquet writer setting. + allow_single_file_parallelism: Controls whether DataFusion will attempt to + speed up writing parquet files by serializing them in parallel. Each + column in each row group in each output file are serialized in parallel + leveraging a maximum possible core count of n_files * n_row_groups * + n_columns. 
+ maximum_parallel_row_group_writers: By default parallel parquet writer is + tuned for minimum memory usage in a streaming execution plan. You may + see a performance benefit when writing large parquet files by increasing + `maximum_parallel_row_group_writers` and + `maximum_buffered_record_batches_per_stream` if your system has idle + cores and can tolerate additional memory usage. Boosting these values is + likely worthwhile when writing out already in-memory data, such as from + a cached data frame. + maximum_buffered_record_batches_per_stream: See + `maximum_parallel_row_group_writers`. + column_specific_options: Overrides options for specific columns. If a column + is not a part of this dictionary, it will use the parameters provided in + the `write_parquet`. + """ + options_internal = ParquetWriterOptionsInternal( + data_pagesize_limit, + write_batch_size, + writer_version, + skip_arrow_metadata, + compression, + dictionary_enabled, + dictionary_page_size_limit, + statistics_enabled, + max_row_group_size, + created_by, + column_index_truncate_length, + statistics_truncate_length, + data_page_row_count_limit, + encoding, + bloom_filter_on_write, + bloom_filter_fpp, + bloom_filter_ndv, + allow_single_file_parallelism, + maximum_parallel_row_group_writers, + maximum_buffered_record_batches_per_stream, + ) + + if column_specific_options is None: + column_specific_options = {} + + column_specific_options_internal = {} + for column, opts in column_specific_options.items(): + column_specific_options_internal[column] = ParquetColumnOptionsInternal( + bloom_filter_enabled=opts.bloom_filter_enabled, + encoding=opts.encoding, + dictionary_enabled=opts.dictionary_enabled, + compression=opts.compression, + statistics_enabled=opts.statistics_enabled, + bloom_filter_fpp=opts.bloom_filter_fpp, + bloom_filter_ndv=opts.bloom_filter_ndv, + ) - self.df.write_parquet(str(path), compression.value, compression_level) + self.df.write_parquet( + str(path), + options_internal, + column_specific_options_internal, + ) def write_json(self, path: str | pathlib.Path) -> None: """Execute the :py:class:`DataFrame` and write the results to a JSON file. 
diff --git a/python/tests/test_dataframe.py b/python/tests/test_dataframe.py index e01308c86..e1e29c45c 100644 --- a/python/tests/test_dataframe.py +++ b/python/tests/test_dataframe.py @@ -23,6 +23,7 @@ import pytest from datafusion import ( DataFrame, + ParquetColumnOptions, SessionContext, WindowFrame, column, @@ -62,6 +63,21 @@ def df(): return ctx.from_arrow(batch) +@pytest.fixture +def large_df(): + ctx = SessionContext() + + rows = 100000 + data = { + "a": list(range(rows)), + "b": [f"s-{i}" for i in range(rows)], + "c": [float(i + 0.1) for i in range(rows)], + } + batch = pa.record_batch(data) + + return ctx.from_arrow(batch) + + @pytest.fixture def struct_df(): ctx = SessionContext() @@ -1533,16 +1549,26 @@ def test_write_parquet(df, tmp_path, path_to_str): assert result == expected +def test_write_parquet_default_compression(df, tmp_path): + """Test that the default compression is ZSTD.""" + df.write_parquet(tmp_path) + + for file in tmp_path.rglob("*.parquet"): + metadata = pq.ParquetFile(file).metadata.to_dict() + for row_group in metadata["row_groups"]: + for col in row_group["columns"]: + assert col["compression"].lower() == "zstd" + + @pytest.mark.parametrize( - ("compression", "compression_level"), - [("gzip", 6), ("brotli", 7), ("zstd", 15)], + "compression", + ["gzip(6)", "brotli(7)", "zstd(15)", "snappy", "uncompressed"], ) -def test_write_compressed_parquet(df, tmp_path, compression, compression_level): - path = tmp_path +def test_write_compressed_parquet(df, tmp_path, compression): + import re - df.write_parquet( - str(path), compression=compression, compression_level=compression_level - ) + path = tmp_path + df.write_parquet(str(path), compression=compression) # test that the actual compression scheme is the one written for _root, _dirs, files in os.walk(path): @@ -1550,8 +1576,10 @@ def test_write_compressed_parquet(df, tmp_path, compression, compression_level): if file.endswith(".parquet"): metadata = pq.ParquetFile(tmp_path / file).metadata.to_dict() for row_group in metadata["row_groups"]: - for columns in row_group["columns"]: - assert columns["compression"].lower() == compression + for col in row_group["columns"]: + assert col["compression"].lower() == re.sub( + r"\(\d+\)", "", compression + ) result = pq.read_table(str(path)).to_pydict() expected = df.to_pydict() @@ -1560,40 +1588,323 @@ def test_write_compressed_parquet(df, tmp_path, compression, compression_level): @pytest.mark.parametrize( - ("compression", "compression_level"), - [("gzip", 12), ("brotli", 15), ("zstd", 23), ("wrong", 12)], + "compression", + ["gzip(12)", "brotli(15)", "zstd(23)"], ) -def test_write_compressed_parquet_wrong_compression_level( - df, tmp_path, compression, compression_level -): +def test_write_compressed_parquet_wrong_compression_level(df, tmp_path, compression): path = tmp_path - with pytest.raises(ValueError): - df.write_parquet( - str(path), - compression=compression, - compression_level=compression_level, - ) + with pytest.raises(Exception, match=r"valid compression range .*? 
exceeded."): + df.write_parquet(str(path), compression=compression) -@pytest.mark.parametrize("compression", ["wrong"]) +@pytest.mark.parametrize("compression", ["wrong", "wrong(12)"]) def test_write_compressed_parquet_invalid_compression(df, tmp_path, compression): path = tmp_path - with pytest.raises(ValueError): + with pytest.raises(Exception, match="Unknown or unsupported parquet compression"): df.write_parquet(str(path), compression=compression) -# not testing lzo because it it not implemented yet -# https://github.com/apache/arrow-rs/issues/6970 -@pytest.mark.parametrize("compression", ["zstd", "brotli", "gzip"]) -def test_write_compressed_parquet_default_compression_level(df, tmp_path, compression): - # Test write_parquet with zstd, brotli, gzip default compression level, - # ie don't specify compression level - # should complete without error - path = tmp_path +@pytest.mark.parametrize( + ("writer_version", "format_version"), + [("1.0", "1.0"), ("2.0", "2.6"), (None, "1.0")], +) +def test_write_parquet_writer_version(df, tmp_path, writer_version, format_version): + """Test the Parquet writer version. Note that writer_version=2.0 results in + format_version=2.6""" + if writer_version is None: + df.write_parquet(tmp_path) + else: + df.write_parquet(tmp_path, writer_version=writer_version) - df.write_parquet(str(path), compression=compression) + for file in tmp_path.rglob("*.parquet"): + parquet = pq.ParquetFile(file) + metadata = parquet.metadata.to_dict() + assert metadata["format_version"] == format_version + + +@pytest.mark.parametrize("writer_version", ["1.2.3", "custom-version", "0"]) +def test_write_parquet_wrong_writer_version(df, tmp_path, writer_version): + """Test that invalid writer versions in Parquet throw an exception.""" + with pytest.raises( + Exception, match="Unknown or unsupported parquet writer version" + ): + df.write_parquet(tmp_path, writer_version=writer_version) + + +@pytest.mark.parametrize("dictionary_enabled", [True, False, None]) +def test_write_parquet_dictionary_enabled(df, tmp_path, dictionary_enabled): + """Test enabling/disabling the dictionaries in Parquet.""" + df.write_parquet(tmp_path, dictionary_enabled=dictionary_enabled) + # by default, the dictionary is enabled, so None results in True + result = dictionary_enabled if dictionary_enabled is not None else True + + for file in tmp_path.rglob("*.parquet"): + parquet = pq.ParquetFile(file) + metadata = parquet.metadata.to_dict() + + for row_group in metadata["row_groups"]: + for col in row_group["columns"]: + assert col["has_dictionary_page"] == result + + +@pytest.mark.parametrize( + ("statistics_enabled", "has_statistics"), + [("page", True), ("chunk", True), ("none", False), (None, True)], +) +def test_write_parquet_statistics_enabled( + df, tmp_path, statistics_enabled, has_statistics +): + """Test configuring the statistics in Parquet. 
In pyarrow we can only check for + column-level statistics, so "page" and "chunk" are tested in the same way.""" + df.write_parquet(tmp_path, statistics_enabled=statistics_enabled) + + for file in tmp_path.rglob("*.parquet"): + parquet = pq.ParquetFile(file) + metadata = parquet.metadata.to_dict() + + for row_group in metadata["row_groups"]: + for col in row_group["columns"]: + if has_statistics: + assert col["statistics"] is not None + else: + assert col["statistics"] is None + + +@pytest.mark.parametrize("max_row_group_size", [1000, 5000, 10000, 100000]) +def test_write_parquet_max_row_group_size(large_df, tmp_path, max_row_group_size): + """Test configuring the max number of rows per group in Parquet. These test cases + guarantee that the number of rows for each row group is max_row_group_size, given + the total number of rows is a multiple of max_row_group_size.""" + large_df.write_parquet(tmp_path, max_row_group_size=max_row_group_size) + + for file in tmp_path.rglob("*.parquet"): + parquet = pq.ParquetFile(file) + metadata = parquet.metadata.to_dict() + for row_group in metadata["row_groups"]: + assert row_group["num_rows"] == max_row_group_size + + +@pytest.mark.parametrize("created_by", ["datafusion", "datafusion-python", "custom"]) +def test_write_parquet_created_by(df, tmp_path, created_by): + """Test configuring the created by metadata in Parquet.""" + df.write_parquet(tmp_path, created_by=created_by) + + for file in tmp_path.rglob("*.parquet"): + parquet = pq.ParquetFile(file) + metadata = parquet.metadata.to_dict() + assert metadata["created_by"] == created_by + + +@pytest.mark.parametrize("statistics_truncate_length", [5, 25, 50]) +def test_write_parquet_statistics_truncate_length( + df, tmp_path, statistics_truncate_length +): + """Test configuring the truncate limit in Parquet's row-group-level statistics.""" + ctx = SessionContext() + data = { + "a": [ + "a_the_quick_brown_fox_jumps_over_the_lazy_dog", + "m_the_quick_brown_fox_jumps_over_the_lazy_dog", + "z_the_quick_brown_fox_jumps_over_the_lazy_dog", + ], + "b": ["a_smaller", "m_smaller", "z_smaller"], + } + df = ctx.from_arrow(pa.record_batch(data)) + df.write_parquet(tmp_path, statistics_truncate_length=statistics_truncate_length) + + for file in tmp_path.rglob("*.parquet"): + parquet = pq.ParquetFile(file) + metadata = parquet.metadata.to_dict() + + for row_group in metadata["row_groups"]: + for col in row_group["columns"]: + statistics = col["statistics"] + assert len(statistics["min"]) <= statistics_truncate_length + assert len(statistics["max"]) <= statistics_truncate_length + + +def test_write_parquet_default_encoding(tmp_path): + """Test that, by default, Parquet files are written with dictionary encoding. 
+ Note that dictionary encoding is not used for boolean values, so it is not tested + here.""" + ctx = SessionContext() + data = { + "a": [1, 2, 3], + "b": ["1", "2", "3"], + "c": [1.01, 2.02, 3.03], + } + df = ctx.from_arrow(pa.record_batch(data)) + df.write_parquet(tmp_path) + + for file in tmp_path.rglob("*.parquet"): + parquet = pq.ParquetFile(file) + metadata = parquet.metadata.to_dict() + + for row_group in metadata["row_groups"]: + for col in row_group["columns"]: + assert col["encodings"] == ("PLAIN", "RLE", "RLE_DICTIONARY") + + +@pytest.mark.parametrize( + ("encoding", "data_types", "result"), + [ + ("plain", ["int", "float", "str", "bool"], ("PLAIN", "RLE")), + ("rle", ["bool"], ("RLE",)), + ("delta_binary_packed", ["int"], ("RLE", "DELTA_BINARY_PACKED")), + ("delta_length_byte_array", ["str"], ("RLE", "DELTA_LENGTH_BYTE_ARRAY")), + ("delta_byte_array", ["str"], ("RLE", "DELTA_BYTE_ARRAY")), + ("byte_stream_split", ["int", "float"], ("RLE", "BYTE_STREAM_SPLIT")), + ], +) +def test_write_parquet_encoding(tmp_path, encoding, data_types, result): + """Test different encodings in Parquet in their respective support column types.""" + ctx = SessionContext() + + data = {} + for data_type in data_types: + match data_type: + case "int": + data["int"] = [1, 2, 3] + case "float": + data["float"] = [1.01, 2.02, 3.03] + case "str": + data["str"] = ["a", "b", "c"] + case "bool": + data["bool"] = [True, False, True] + + df = ctx.from_arrow(pa.record_batch(data)) + df.write_parquet(tmp_path, encoding=encoding, dictionary_enabled=False) + + for file in tmp_path.rglob("*.parquet"): + parquet = pq.ParquetFile(file) + metadata = parquet.metadata.to_dict() + + for row_group in metadata["row_groups"]: + for col in row_group["columns"]: + assert col["encodings"] == result + + +@pytest.mark.parametrize("encoding", ["bit_packed"]) +def test_write_parquet_unsupported_encoding(df, tmp_path, encoding): + """Test that unsupported Parquet encodings do not work.""" + # BaseException is used since this throws a Rust panic: https://github.com/PyO3/pyo3/issues/3519 + with pytest.raises(BaseException, match="Encoding .*? is not supported"): + df.write_parquet(tmp_path, encoding=encoding) + + +@pytest.mark.parametrize("encoding", ["non_existent", "unknown", "plain123"]) +def test_write_parquet_invalid_encoding(df, tmp_path, encoding): + """Test that invalid Parquet encodings do not work.""" + with pytest.raises(Exception, match="Unknown or unsupported parquet encoding"): + df.write_parquet(tmp_path, encoding=encoding) + + +@pytest.mark.parametrize("encoding", ["plain_dictionary", "rle_dictionary"]) +def test_write_parquet_dictionary_encoding_fallback(df, tmp_path, encoding): + """Test that the dictionary encoding cannot be used as fallback in Parquet.""" + # BaseException is used since this throws a Rust panic: https://github.com/PyO3/pyo3/issues/3519 + with pytest.raises( + BaseException, match="Dictionary encoding can not be used as fallback encoding" + ): + df.write_parquet(tmp_path, encoding=encoding) + + +def test_write_parquet_bloom_filter(df, tmp_path): + """Test Parquet files with and without (default) bloom filters. 
Since pyarrow does + not expose any information about bloom filters, the easiest way to confirm that they + are actually written is to compare the file size.""" + path_no_bloom_filter = tmp_path / "1" + path_bloom_filter = tmp_path / "2" + + df.write_parquet(path_no_bloom_filter) + df.write_parquet(path_bloom_filter, bloom_filter_on_write=True) + + size_no_bloom_filter = 0 + for file in path_no_bloom_filter.rglob("*.parquet"): + size_no_bloom_filter += os.path.getsize(file) + + size_bloom_filter = 0 + for file in path_bloom_filter.rglob("*.parquet"): + size_bloom_filter += os.path.getsize(file) + + assert size_no_bloom_filter < size_bloom_filter + + +def test_write_parquet_column_options(df, tmp_path): + """Test writing Parquet files with different options for each column, which replace + the global configs (when provided).""" + data = { + "a": [1, 2, 3], + "b": ["a", "b", "c"], + "c": [False, True, False], + "d": [1.01, 2.02, 3.03], + "e": [4, 5, 6], + } + + column_specific_options = { + "a": ParquetColumnOptions(statistics_enabled="none"), + "b": ParquetColumnOptions(encoding="plain", dictionary_enabled=False), + "c": ParquetColumnOptions( + compression="snappy", encoding="rle", dictionary_enabled=False + ), + "d": ParquetColumnOptions( + compression="zstd(6)", + encoding="byte_stream_split", + dictionary_enabled=False, + statistics_enabled="none", + ), + # column "e" will use the global configs + } + + results = { + "a": { + "statistics": False, + "compression": "brotli", + "encodings": ("PLAIN", "RLE", "RLE_DICTIONARY"), + }, + "b": { + "statistics": True, + "compression": "brotli", + "encodings": ("PLAIN", "RLE"), + }, + "c": { + "statistics": True, + "compression": "snappy", + "encodings": ("RLE",), + }, + "d": { + "statistics": False, + "compression": "zstd", + "encodings": ("RLE", "BYTE_STREAM_SPLIT"), + }, + "e": { + "statistics": True, + "compression": "brotli", + "encodings": ("PLAIN", "RLE", "RLE_DICTIONARY"), + }, + } + + ctx = SessionContext() + df = ctx.from_arrow(pa.record_batch(data)) + df.write_parquet( + tmp_path, + compression="brotli(8)", + column_specific_options=column_specific_options, + ) + + for file in tmp_path.rglob("*.parquet"): + parquet = pq.ParquetFile(file) + metadata = parquet.metadata.to_dict() + + for row_group in metadata["row_groups"]: + for col in row_group["columns"]: + column_name = col["path_in_schema"] + result = results[column_name] + assert (col["statistics"] is not None) == result["statistics"] + assert col["compression"].lower() == result["compression"].lower() + assert col["encodings"] == result["encodings"] def test_dataframe_export(df) -> None: diff --git a/src/dataframe.rs b/src/dataframe.rs index 211e31bd1..ffb3f36cf 100644 --- a/src/dataframe.rs +++ b/src/dataframe.rs @@ -15,6 +15,7 @@ // specific language governing permissions and limitations // under the License. 
+use std::collections::HashMap;
 use std::ffi::CString;
 use std::sync::Arc;
 
@@ -27,12 +28,11 @@ use datafusion::arrow::datatypes::Schema;
 use datafusion::arrow::pyarrow::{PyArrowType, ToPyArrow};
 use datafusion::arrow::util::pretty;
 use datafusion::common::UnnestOptions;
-use datafusion::config::{CsvOptions, TableParquetOptions};
+use datafusion::config::{CsvOptions, ParquetColumnOptions, ParquetOptions, TableParquetOptions};
 use datafusion::dataframe::{DataFrame, DataFrameWriteOptions};
 use datafusion::datasource::TableProvider;
 use datafusion::error::DataFusionError;
 use datafusion::execution::SendableRecordBatchStream;
-use datafusion::parquet::basic::{BrotliLevel, Compression, GzipLevel, ZstdLevel};
 use datafusion::prelude::*;
 use futures::{StreamExt, TryStreamExt};
 use pyo3::exceptions::PyValueError;
@@ -165,10 +165,105 @@ fn build_formatter_config_from_python(formatter: &Bound<'_, PyAny>) -> PyResult<
     // Return the validated config, converting String error to PyErr
     config
         .validate()
-        .map_err(|e| pyo3::exceptions::PyValueError::new_err(e))?;
+        .map_err(pyo3::exceptions::PyValueError::new_err)?;
 
     Ok(config)
 }
 
+/// Python mapping of `ParquetOptions` (includes just the writer-related options).
+#[pyclass(name = "ParquetWriterOptions", module = "datafusion", subclass)]
+#[derive(Clone, Default)]
+pub struct PyParquetWriterOptions {
+    options: ParquetOptions,
+}
+
+#[pymethods]
+impl PyParquetWriterOptions {
+    #[new]
+    #[allow(clippy::too_many_arguments)]
+    pub fn new(
+        data_pagesize_limit: usize,
+        write_batch_size: usize,
+        writer_version: String,
+        skip_arrow_metadata: bool,
+        compression: Option<String>,
+        dictionary_enabled: Option<bool>,
+        dictionary_page_size_limit: usize,
+        statistics_enabled: Option<String>,
+        max_row_group_size: usize,
+        created_by: String,
+        column_index_truncate_length: Option<usize>,
+        statistics_truncate_length: Option<usize>,
+        data_page_row_count_limit: usize,
+        encoding: Option<String>,
+        bloom_filter_on_write: bool,
+        bloom_filter_fpp: Option<f64>,
+        bloom_filter_ndv: Option<u64>,
+        allow_single_file_parallelism: bool,
+        maximum_parallel_row_group_writers: usize,
+        maximum_buffered_record_batches_per_stream: usize,
+    ) -> Self {
+        Self {
+            options: ParquetOptions {
+                data_pagesize_limit,
+                write_batch_size,
+                writer_version,
+                skip_arrow_metadata,
+                compression,
+                dictionary_enabled,
+                dictionary_page_size_limit,
+                statistics_enabled,
+                max_row_group_size,
+                created_by,
+                column_index_truncate_length,
+                statistics_truncate_length,
+                data_page_row_count_limit,
+                encoding,
+                bloom_filter_on_write,
+                bloom_filter_fpp,
+                bloom_filter_ndv,
+                allow_single_file_parallelism,
+                maximum_parallel_row_group_writers,
+                maximum_buffered_record_batches_per_stream,
+                ..Default::default()
+            },
+        }
+    }
+}
+
+/// Python mapping of `ParquetColumnOptions`.
+#[pyclass(name = "ParquetColumnOptions", module = "datafusion", subclass)]
+#[derive(Clone, Default)]
+pub struct PyParquetColumnOptions {
+    options: ParquetColumnOptions,
+}
+
+#[pymethods]
+impl PyParquetColumnOptions {
+    #[new]
+    pub fn new(
+        bloom_filter_enabled: Option<bool>,
+        encoding: Option<String>,
+        dictionary_enabled: Option<bool>,
+        compression: Option<String>,
+        statistics_enabled: Option<String>,
+        bloom_filter_fpp: Option<f64>,
+        bloom_filter_ndv: Option<u64>,
+    ) -> Self {
+        Self {
+            options: ParquetColumnOptions {
+                bloom_filter_enabled,
+                encoding,
+                dictionary_enabled,
+                compression,
+                statistics_enabled,
+                bloom_filter_fpp,
+                bloom_filter_ndv,
+                ..Default::default()
+            },
+        }
+    }
+}
+
 /// A PyDataFrame is a representation of a logical plan and an API to compose statements.
 /// Use it to build a plan and `.collect()` to execute the plan and collect the result.
 /// The actual execution of a plan runs natively on Rust and Arrow on a multi-threaded environment.
@@ -613,61 +708,28 @@ impl PyDataFrame {
     }
 
     /// Write a `DataFrame` to a Parquet file.
-    #[pyo3(signature = (
-        path,
-        compression="zstd",
-        compression_level=None
-    ))]
     fn write_parquet(
         &self,
         path: &str,
-        compression: &str,
-        compression_level: Option<u32>,
+        options: PyParquetWriterOptions,
+        column_specific_options: HashMap<String, PyParquetColumnOptions>,
         py: Python,
     ) -> PyDataFusionResult<()> {
-        fn verify_compression_level(cl: Option<u32>) -> Result<u32, PyErr> {
-            cl.ok_or(PyValueError::new_err("compression_level is not defined"))
-        }
-
-        let _validated = match compression.to_lowercase().as_str() {
-            "snappy" => Compression::SNAPPY,
-            "gzip" => Compression::GZIP(
-                GzipLevel::try_new(compression_level.unwrap_or(6))
-                    .map_err(|e| PyValueError::new_err(format!("{e}")))?,
-            ),
-            "brotli" => Compression::BROTLI(
-                BrotliLevel::try_new(verify_compression_level(compression_level)?)
-                    .map_err(|e| PyValueError::new_err(format!("{e}")))?,
-            ),
-            "zstd" => Compression::ZSTD(
-                ZstdLevel::try_new(verify_compression_level(compression_level)? as i32)
-                    .map_err(|e| PyValueError::new_err(format!("{e}")))?,
-            ),
-            "lzo" => Compression::LZO,
-            "lz4" => Compression::LZ4,
-            "lz4_raw" => Compression::LZ4_RAW,
-            "uncompressed" => Compression::UNCOMPRESSED,
-            _ => {
-                return Err(PyDataFusionError::Common(format!(
-                    "Unrecognized compression type {compression}"
-                )));
-            }
+        let table_options = TableParquetOptions {
+            global: options.options,
+            column_specific_options: column_specific_options
+                .into_iter()
+                .map(|(k, v)| (k, v.options))
+                .collect(),
+            ..Default::default()
         };
 
-        let mut compression_string = compression.to_string();
-        if let Some(level) = compression_level {
-            compression_string.push_str(&format!("({level})"));
-        }
-
-        let mut options = TableParquetOptions::default();
-        options.global.compression = Some(compression_string);
-
         wait_for_future(
             py,
             self.df.as_ref().clone().write_parquet(
                 path,
                 DataFrameWriteOptions::new(),
-                Option::from(options),
+                Option::from(table_options),
             ),
         )?;
         Ok(())
diff --git a/src/lib.rs b/src/lib.rs
index 6eeda0878..990231c66 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -85,6 +85,8 @@ fn _internal(py: Python, m: Bound<'_, PyModule>) -> PyResult<()> {
     m.add_class::()?;
     m.add_class::()?;
     m.add_class::()?;
+    m.add_class::<dataframe::PyParquetWriterOptions>()?;
+    m.add_class::<dataframe::PyParquetColumnOptions>()?;
     m.add_class::()?;
     m.add_class::()?;
     m.add_class::()?;
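As a quick sanity check, the effect of these options can be inspected with pyarrow, mirroring the approach used in the tests above (the directory name is illustrative):

    import pathlib
    import pyarrow.parquet as pq

    for file in pathlib.Path("output_dir/").rglob("*.parquet"):
        metadata = pq.ParquetFile(file).metadata.to_dict()
        for row_group in metadata["row_groups"]:
            for col in row_group["columns"]:
                print(col["path_in_schema"], col["compression"], col["encodings"])

Each entry reports the codec and encodings actually written for a column chunk, which is how the tests verify compression, encoding, and per-column overrides.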