From d65609a67a4e683235f7bbdf7d29bd303aaf710f Mon Sep 17 00:00:00 2001
From: Kun H
Date: Wed, 5 Feb 2025 10:54:08 +0000
Subject: [PATCH 1/3] Use stream in repr to truncate rows

---
 src/dataframe.rs | 13 +++++++++++--
 1 file changed, 11 insertions(+), 2 deletions(-)

diff --git a/src/dataframe.rs b/src/dataframe.rs
index b875480a7..4a6688dd8 100644
--- a/src/dataframe.rs
+++ b/src/dataframe.rs
@@ -33,6 +33,7 @@ use datafusion::dataframe::{DataFrame, DataFrameWriteOptions};
 use datafusion::execution::SendableRecordBatchStream;
 use datafusion::parquet::basic::{BrotliLevel, Compression, GzipLevel, ZstdLevel};
 use datafusion::prelude::*;
+use futures::StreamExt;
 use pyo3::exceptions::{PyTypeError, PyValueError};
 use pyo3::prelude::*;
 use pyo3::pybacked::PyBackedStr;
@@ -90,8 +91,16 @@ impl PyDataFrame {
     }
 
     fn __repr__(&self, py: Python) -> PyResult<String> {
-        let df = self.df.as_ref().clone().limit(0, Some(10))?;
-        let batches = wait_for_future(py, df.collect())?;
+        let df = self.df.as_ref().clone();
+
+        let stream = wait_for_future(py, df.execute_stream()).map_err(py_datafusion_err)?;
+
+        let batches: Vec<RecordBatch> = wait_for_future(
+            py,
+            stream.take(10).collect::<Vec<_>>())
+            .into_iter()
+            .collect::<Result<Vec<RecordBatch>, _>>()?;
+
         let batches_as_string = pretty::pretty_format_batches(&batches);
         match batches_as_string {
             Ok(batch) => Ok(format!("DataFrame()\n{batch}")),
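
Note on the semantics above: `StreamExt::take(10)` bounds the number of stream items — record batches — not rows, so this first version can still render far more (or fewer) than 10 rows depending on batch sizes; that is the bug the next patch fixes. A minimal sketch of the distinction using plain `futures` streams, with integers standing in for per-batch row counts (the `tokio` runtime and the made-up batch sizes are assumptions for illustration only):

    use futures::{stream, StreamExt};

    #[tokio::main]
    async fn main() {
        // Four "batches" of 100 rows each (sizes are made up).
        let batch_rows = stream::iter(vec![100usize, 100, 100, 100]);

        // take(2) keeps the first two *batches*: 200 rows, not 2 rows.
        let kept: Vec<usize> = batch_rows.take(2).collect().await;
        assert_eq!(kept.iter().sum::<usize>(), 200);
    }
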
\n"); @@ -742,3 +734,38 @@ fn record_batch_into_schema( RecordBatch::try_new(schema, data_arrays) } + +fn get_batches( + py: Python, + df: DataFrame, + max_rows: usize, +) -> Result, PyDataFusionError> { + let partitioned_stream = wait_for_future(py, df.execute_stream_partitioned()).map_err(py_datafusion_err)?; + let stream = futures::stream::iter(partitioned_stream).flatten(); + wait_for_future( + py, + stream + .scan(0, |state, x| { + let total = *state; + if total >= max_rows { + future::ready(None) + } else { + match x { + Ok(batch) => { + if total + batch.num_rows() <= max_rows { + *state = total + batch.num_rows(); + future::ready(Some(Ok(batch))) + } else { + *state = max_rows; + future::ready(Some(Ok(batch.slice(0, max_rows - total)))) + } + } + Err(err) => future::ready(Some(Err(PyDataFusionError::from(err)))), + } + } + }) + .collect::>(), + ) + .into_iter() + .collect::, _>>() +} \ No newline at end of file From 549d2ddb3a6331c14e6d08fc5b3deaca89057db7 Mon Sep 17 00:00:00 2001 From: Kun H Date: Sun, 23 Feb 2025 15:00:53 +0000 Subject: [PATCH 3/3] Add comments and fix format check failure. --- src/dataframe.rs | 21 +++++++++++++++++++-- 1 file changed, 19 insertions(+), 2 deletions(-) diff --git a/src/dataframe.rs b/src/dataframe.rs index ed886bac0..58533154b 100644 --- a/src/dataframe.rs +++ b/src/dataframe.rs @@ -92,7 +92,12 @@ impl PyDataFrame { fn __repr__(&self, py: Python) -> PyDataFusionResult { let df = self.df.as_ref().clone(); + + // Mostly the same functionality of `df.limit(0, 10).collect()`. But + // `df.limit(0, 10)` is a semantically different plan, which might be + // invalid. A case is df=`EXPLAIN ...` as `Explain` must be the root. let batches: Vec = get_batches(py, df, 10)?; + let batches_as_string = pretty::pretty_format_batches(&batches); match batches_as_string { Ok(batch) => Ok(format!("DataFrame()\n{batch}")), @@ -103,6 +108,9 @@ impl PyDataFrame { fn _repr_html_(&self, py: Python) -> PyDataFusionResult { let mut html_str = "\n".to_string(); + // Mostly the same functionality of `df.limit(0, 10).collect()`. But + // `df.limit(0, 10)` is a semantically different plan, which might be + // invalid. A case is df=`EXPLAIN ...` as `Explain` must be the root. let df = self.df.as_ref().clone(); let batches: Vec = get_batches(py, df, 10)?; @@ -735,12 +743,18 @@ fn record_batch_into_schema( RecordBatch::try_new(schema, data_arrays) } +/// get dataframe as a list of `RecordBatch`es containing at most `max_rows` rows. fn get_batches( py: Python, df: DataFrame, max_rows: usize, ) -> Result, PyDataFusionError> { - let partitioned_stream = wait_for_future(py, df.execute_stream_partitioned()).map_err(py_datafusion_err)?; + // Here uses `df.execute_stream_partitioned` instead of `df.execute_stream` + // as the later one internally appends `CoalescePartitionsExec` to merge + // the result into a signle partition thus might cause loading of + // unnecessary partitions. 
From 549d2ddb3a6331c14e6d08fc5b3deaca89057db7 Mon Sep 17 00:00:00 2001
From: Kun H
Date: Sun, 23 Feb 2025 15:00:53 +0000
Subject: [PATCH 3/3] Add comments and fix format check failure.

---
 src/dataframe.rs | 21 +++++++++++++++++++--
 1 file changed, 19 insertions(+), 2 deletions(-)

diff --git a/src/dataframe.rs b/src/dataframe.rs
index ed886bac0..58533154b 100644
--- a/src/dataframe.rs
+++ b/src/dataframe.rs
@@ -92,7 +92,12 @@ impl PyDataFrame {
 
     fn __repr__(&self, py: Python) -> PyDataFusionResult<String> {
         let df = self.df.as_ref().clone();
+
+        // Mostly the same functionality as `df.limit(0, Some(10)).collect()`,
+        // but `limit` gives a semantically different plan that might be invalid,
+        // e.g. df=`EXPLAIN ...`, since `Explain` must be the root of the plan.
        let batches: Vec<RecordBatch> = get_batches(py, df, 10)?;
+
         let batches_as_string = pretty::pretty_format_batches(&batches);
         match batches_as_string {
             Ok(batch) => Ok(format!("DataFrame()\n{batch}")),
@@ -103,6 +108,9 @@ impl PyDataFrame {
     fn _repr_html_(&self, py: Python) -> PyDataFusionResult<String> {
         let mut html_str = "<table border='1'>\n".to_string();
 
+        // Mostly the same functionality as `df.limit(0, Some(10)).collect()`,
+        // but `limit` gives a semantically different plan that might be invalid,
+        // e.g. df=`EXPLAIN ...`, since `Explain` must be the root of the plan.
         let df = self.df.as_ref().clone();
         let batches: Vec<RecordBatch> = get_batches(py, df, 10)?;
 
@@ -735,12 +743,18 @@ fn record_batch_into_schema(
 
     RecordBatch::try_new(schema, data_arrays)
 }
 
+/// Collect the DataFrame as a list of `RecordBatch`es containing at most `max_rows` rows.
 fn get_batches(
     py: Python,
     df: DataFrame,
     max_rows: usize,
 ) -> Result<Vec<RecordBatch>, PyDataFusionError> {
+    // This uses `df.execute_stream_partitioned` instead of `df.execute_stream`
+    // because the latter internally appends a `CoalescePartitionsExec` to merge
+    // the result into a single partition, and thus might cause unnecessary
+    // partitions to be loaded.
-    let partitioned_stream = wait_for_future(py, df.execute_stream_partitioned()).map_err(py_datafusion_err)?;
+    let partitioned_stream =
+        wait_for_future(py, df.execute_stream_partitioned()).map_err(py_datafusion_err)?;
     let stream = futures::stream::iter(partitioned_stream).flatten();
     wait_for_future(
         py,
         stream
             .scan(0, |state, x| {
                 let total = *state;
                 if total >= max_rows {
+                    // Stop scanning once `max_rows` rows have been seen.
                     future::ready(None)
                 } else {
                     match x {
                         Ok(batch) => {
                             if total + batch.num_rows() <= max_rows {
+                                // Take the whole batch when it does not exceed `max_rows`.
                                 *state = total + batch.num_rows();
                                 future::ready(Some(Ok(batch)))
                             } else {
+                                // Take only the first `max_rows - total` rows.
                                 *state = max_rows;
                                 future::ready(Some(Ok(batch.slice(0, max_rows - total))))
                             }
                         }
                         Err(err) => future::ready(Some(Err(PyDataFusionError::from(err)))),
                     }
                 }
             })
             .collect::<Vec<_>>(),
     )
     .into_iter()
     .collect::<Result<Vec<_>, _>>()
-}
\ No newline at end of file
+}
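
Note on the partitioning comment in this last patch: `flatten` on a stream of streams polls the inner streams strictly in order, so once the `scan` ends the merged stream, the remaining partition streams are dropped unpolled — which is what the comment credits for avoiding unnecessary partition loads, compared with the `CoalescePartitionsExec`-based `execute_stream`. The ordering behavior in isolation, again with plain `futures` streams, made-up partition contents, and `tokio` assumed as the driver:

    use futures::{stream, StreamExt};

    #[tokio::main]
    async fn main() {
        let partition_a = stream::iter(vec!["a1", "a2", "a3"]);
        let partition_b = stream::iter(vec!["b1", "b2", "b3"]);

        // flatten() chains the partitions lazily, in order.
        let merged = stream::iter(vec![partition_a, partition_b]).flatten();
        let first_two: Vec<&str> = merged.take(2).collect().await;

        // Both items came from partition_a; partition_b was never polled.
        assert_eq!(first_two, vec!["a1", "a2"]);
    }
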