diff --git a/datafusion/core/tests/memory_limit/mod.rs b/datafusion/core/tests/memory_limit/mod.rs index 212ffdaaa2a52..aff11b0dddeea 100644 --- a/datafusion/core/tests/memory_limit/mod.rs +++ b/datafusion/core/tests/memory_limit/mod.rs @@ -241,15 +241,15 @@ async fn sort_preserving_merge() { // SortPreservingMergeExec (not a Sort which would compete // with the SortPreservingMergeExec for memory) &[ - "+---------------+------------------------------------------------------------------------------------------------------------+", - "| plan_type | plan |", - "+---------------+------------------------------------------------------------------------------------------------------------+", - "| logical_plan | Sort: t.a ASC NULLS LAST, t.b ASC NULLS LAST, fetch=10 |", - "| | TableScan: t projection=[a, b] |", - "| physical_plan | SortPreservingMergeExec: [a@0 ASC NULLS LAST, b@1 ASC NULLS LAST], fetch=10 |", - "| | MemoryExec: partitions=2, partition_sizes=[5, 5], output_ordering=a@0 ASC NULLS LAST, b@1 ASC NULLS LAST |", - "| | |", - "+---------------+------------------------------------------------------------------------------------------------------------+", + "+---------------+----------------------------------------------------------------------------------------------------------------------+", + "| plan_type | plan |", + "+---------------+----------------------------------------------------------------------------------------------------------------------+", + "| logical_plan | Sort: t.a ASC NULLS LAST, t.b ASC NULLS LAST, fetch=10 |", + "| | TableScan: t projection=[a, b] |", + "| physical_plan | SortPreservingMergeExec: [a@0 ASC NULLS LAST, b@1 ASC NULLS LAST], fetch=10 |", + "| | MemoryExec: partitions=2, partition_sizes=[5, 5], limit=10, output_ordering=a@0 ASC NULLS LAST, b@1 ASC NULLS LAST |", + "| | |", + "+---------------+----------------------------------------------------------------------------------------------------------------------+", ] ) .run() diff --git a/datafusion/physical-plan/src/memory.rs b/datafusion/physical-plan/src/memory.rs index 198b8ccd69926..876bd2ab61909 100644 --- a/datafusion/physical-plan/src/memory.rs +++ b/datafusion/physical-plan/src/memory.rs @@ -65,6 +65,8 @@ pub struct MemoryExec { cache: PlanProperties, /// if partition sizes should be displayed show_sizes: bool, + /// Maximum number of rows to return + fetch: Option, } impl fmt::Debug for MemoryExec { @@ -74,6 +76,7 @@ impl fmt::Debug for MemoryExec { .field("schema", &self.schema) .field("projection", &self.projection) .field("sort_information", &self.sort_information) + .field("fetch", &self.fetch) .finish() } } @@ -100,16 +103,20 @@ impl DisplayAs for MemoryExec { format!(", {}", constraints) }; + let limit = self + .fetch + .map_or(String::new(), |limit| format!(", limit={}", limit)); + if self.show_sizes { write!( f, - "MemoryExec: partitions={}, partition_sizes={partition_sizes:?}{output_ordering}{constraints}", + "MemoryExec: partitions={}, partition_sizes={partition_sizes:?}{limit}{output_ordering}{constraints}", partition_sizes.len(), ) } else { write!( f, - "MemoryExec: partitions={}{output_ordering}{constraints}", + "MemoryExec: partitions={}{limit}{output_ordering}{constraints}", partition_sizes.len(), ) } @@ -154,11 +161,14 @@ impl ExecutionPlan for MemoryExec { partition: usize, _context: Arc, ) -> Result { - Ok(Box::pin(MemoryStream::try_new( - self.partitions[partition].clone(), - Arc::clone(&self.projected_schema), - self.projection.clone(), - )?)) + Ok(Box::pin( + MemoryStream::try_new( + self.partitions[partition].clone(), + Arc::clone(&self.projected_schema), + self.projection.clone(), + )? + .with_fetch(self.fetch), + )) } /// We recompute the statistics dynamically from the arrow metadata as it is pretty cheap to do so @@ -193,6 +203,23 @@ impl ExecutionPlan for MemoryExec { }) .transpose() } + + fn fetch(&self) -> Option { + self.fetch + } + + fn with_fetch(&self, limit: Option) -> Option> { + Some(Arc::new(Self { + partitions: self.partitions.clone(), + schema: Arc::clone(&self.schema), + projected_schema: Arc::clone(&self.projected_schema), + projection: self.projection.clone(), + sort_information: self.sort_information.clone(), + cache: self.cache.clone(), + show_sizes: self.show_sizes, + fetch: limit, + })) + } } impl MemoryExec { @@ -219,6 +246,7 @@ impl MemoryExec { sort_information: vec![], cache, show_sizes: true, + fetch: None, }) } @@ -314,6 +342,7 @@ impl MemoryExec { sort_information: vec![], cache, show_sizes: true, + fetch: None, }) } @@ -462,6 +491,8 @@ pub struct MemoryStream { projection: Option>, /// Index into the data index: usize, + /// The remaining number of rows to return + fetch: Option, } impl MemoryStream { @@ -477,6 +508,7 @@ impl MemoryStream { schema, projection, index: 0, + fetch: None, }) } @@ -485,6 +517,12 @@ impl MemoryStream { self.reservation = Some(reservation); self } + + /// Set the number of rows to produce + pub(super) fn with_fetch(mut self, fetch: Option) -> Self { + self.fetch = fetch; + self + } } impl Stream for MemoryStream { @@ -494,20 +532,35 @@ impl Stream for MemoryStream { mut self: std::pin::Pin<&mut Self>, _: &mut Context<'_>, ) -> Poll> { - Poll::Ready(if self.index < self.data.len() { - self.index += 1; - let batch = &self.data[self.index - 1]; + if self.index >= self.data.len() { + return Poll::Ready(None); + } - // return just the columns requested - let batch = match self.projection.as_ref() { - Some(columns) => batch.project(columns)?, - None => batch.clone(), - }; + self.index += 1; + let batch = &self.data[self.index - 1]; - Some(Ok(batch)) + // return just the columns requested + let batch = match self.projection.as_ref() { + Some(columns) => batch.project(columns)?, + None => batch.clone(), + }; + + if self.fetch.is_none() { + return Poll::Ready(Some(Ok(batch))); + } + + let fetch = self.fetch.unwrap(); + if fetch == 0 { + return Poll::Ready(None); + } + + let batch = if batch.num_rows() > fetch { + batch.slice(0, fetch) } else { - None - }) + batch + }; + self.fetch = Some(fetch - batch.num_rows()); + Poll::Ready(Some(Ok(batch))) } fn size_hint(&self) -> (usize, Option) { @@ -859,7 +912,9 @@ mod tests { use crate::test::{self, make_partition}; use arrow_schema::{DataType, Field}; + use datafusion_common::assert_batches_eq; use datafusion_common::stats::{ColumnStatistics, Precision}; + use futures::StreamExt; #[tokio::test] async fn values_empty_case() -> Result<()> { @@ -944,4 +999,30 @@ mod tests { Ok(()) } + + #[tokio::test] + async fn exec_with_limit() -> Result<()> { + let task_ctx = Arc::new(TaskContext::default()); + let batch = make_partition(7); + let schema = batch.schema(); + let batches = vec![batch.clone(), batch]; + + let exec = MemoryExec::try_new_from_batches(schema, batches).unwrap(); + assert_eq!(exec.fetch(), None); + + let exec = exec.with_fetch(Some(4)).unwrap(); + assert_eq!(exec.fetch(), Some(4)); + + let mut it = exec.execute(0, task_ctx)?; + let mut results = vec![]; + while let Some(batch) = it.next().await { + results.push(batch?); + } + + let expected = [ + "+---+", "| i |", "+---+", "| 0 |", "| 1 |", "| 2 |", "| 3 |", "+---+", + ]; + assert_batches_eq!(expected, &results); + Ok(()) + } } diff --git a/datafusion/sqllogictest/test_files/limit.slt b/datafusion/sqllogictest/test_files/limit.slt index 65f35d40fcf53..2aefdfc82b801 100644 --- a/datafusion/sqllogictest/test_files/limit.slt +++ b/datafusion/sqllogictest/test_files/limit.slt @@ -739,10 +739,8 @@ physical_plan 01)ProjectionExec: expr=[a@2 as a, b@3 as b, a@0 as a, b@1 as b] 02)--GlobalLimitExec: skip=0, fetch=10 03)----CrossJoinExec -04)------GlobalLimitExec: skip=0, fetch=1 -05)--------MemoryExec: partitions=1, partition_sizes=[1] -06)------GlobalLimitExec: skip=0, fetch=10 -07)--------MemoryExec: partitions=1, partition_sizes=[1] +04)------MemoryExec: partitions=1, partition_sizes=[1], limit=1 +05)------MemoryExec: partitions=1, partition_sizes=[1], limit=10 query IIII @@ -765,10 +763,8 @@ logical_plan physical_plan 01)GlobalLimitExec: skip=0, fetch=2 02)--CrossJoinExec -03)----GlobalLimitExec: skip=0, fetch=2 -04)------MemoryExec: partitions=1, partition_sizes=[1] -05)----GlobalLimitExec: skip=0, fetch=2 -06)------MemoryExec: partitions=1, partition_sizes=[1] +03)----MemoryExec: partitions=1, partition_sizes=[1], limit=2 +04)----MemoryExec: partitions=1, partition_sizes=[1], limit=2 statement ok drop table testSubQueryLimit; diff --git a/datafusion/sqllogictest/test_files/string/dictionary_utf8.slt b/datafusion/sqllogictest/test_files/string/dictionary_utf8.slt index 2f12e9c7a39b5..99271a28f9509 100644 --- a/datafusion/sqllogictest/test_files/string/dictionary_utf8.slt +++ b/datafusion/sqllogictest/test_files/string/dictionary_utf8.slt @@ -67,4 +67,4 @@ statement ok drop table test_substr_base; statement ok -drop table test_datetime_base; \ No newline at end of file +drop table test_datetime_base; diff --git a/datafusion/sqllogictest/test_files/string/large_string.slt b/datafusion/sqllogictest/test_files/string/large_string.slt index 93ec796ec6f05..9126a80383ef1 100644 --- a/datafusion/sqllogictest/test_files/string/large_string.slt +++ b/datafusion/sqllogictest/test_files/string/large_string.slt @@ -75,4 +75,4 @@ statement ok drop table test_substr_base; statement ok -drop table test_datetime_base; \ No newline at end of file +drop table test_datetime_base;