diff --git a/datafusion/core/src/physical_plan/planner.rs b/datafusion/core/src/physical_plan/planner.rs index 7504dd4385e43..df402584cc910 100644 --- a/datafusion/core/src/physical_plan/planner.rs +++ b/datafusion/core/src/physical_plan/planner.rs @@ -1697,7 +1697,7 @@ mod tests { .build()?; let execution_plan = plan(&logical_plan).await?; // verify that the plan correctly adds cast from Int64(1) to Utf8 - let expected = "InListExpr { expr: Column { name: \"c1\", index: 0 }, list: [Literal { value: Utf8(\"a\") }, CastExpr { expr: Literal { value: Int64(1) }, cast_type: Utf8, cast_options: CastOptions { safe: false } }], negated: false }"; + let expected = "InListExpr { expr: Column { name: \"c1\", index: 0 }, list: [Literal { value: Utf8(\"a\") }, CastExpr { expr: Literal { value: Int64(1) }, cast_type: Utf8, cast_options: CastOptions { safe: false } }], negated: false, set: None }"; assert!(format!("{:?}", execution_plan).contains(expected)); // expression: "a in (true, 'a')" @@ -1733,6 +1733,64 @@ mod tests { Ok(()) } + #[tokio::test] + async fn in_set_test() -> Result<()> { + let testdata = crate::test_util::arrow_test_data(); + let path = format!("{}/csv/aggregate_test_100.csv", testdata); + let options = CsvReadOptions::new().schema_infer_max_records(100); + + // OPTIMIZER_INSET_THRESHOLD = 30; the 31-element list below exceeds it + // expression: "a in ('a', 1, 2, ..30)" + let mut list = vec![Expr::Literal(ScalarValue::Utf8(Some("a".to_string())))]; + for i in 1..31 { + list.push(Expr::Literal(ScalarValue::Int64(Some(i)))); + } + + let logical_plan = LogicalPlanBuilder::scan_csv( + Arc::new(LocalFileSystem {}), + &path, + options, + None, + 1, + ) + .await? + .filter(col("c12").lt(lit(0.05)))? + .project(vec![col("c1").in_list(list, false)])? 
+ .build()?; + let execution_plan = plan(&logical_plan).await?; + let expected = "expr: [(InListExpr { expr: Column { name: \"c1\", index: 0 }, list: [Literal { value: Utf8(\"a\") }, CastExpr { expr: Literal { value: Int64(1) }, cast_type: Utf8, cast_options: CastOptions { safe: false } }, CastExpr { expr: Literal { value: Int64(2) }, cast_type: Utf8, cast_options: CastOptions { safe: false } }, CastExpr { expr: Literal { value: Int64(3) }, cast_type: Utf8, cast_options: CastOptions { safe: false } }, CastExpr { expr: Literal { value: Int64(4) }, cast_type: Utf8, cast_options: CastOptions { safe: false } }, CastExpr { expr: Literal { value: Int64(5) }, cast_type: Utf8, cast_options: CastOptions { safe: false } }, CastExpr { expr: Literal { value: Int64(6) }, cast_type: Utf8, cast_options: CastOptions { safe: false } }, CastExpr { expr: Literal { value: Int64(7) }, cast_type: Utf8, cast_options: CastOptions { safe: false } }, CastExpr { expr: Literal { value: Int64(8) }, cast_type: Utf8, cast_options: CastOptions { safe: false } }, CastExpr { expr: Literal { value: Int64(9) }, cast_type: Utf8, cast_options: CastOptions { safe: false } }, CastExpr { expr: Literal { value: Int64(10) }, cast_type: Utf8, cast_options: CastOptions { safe: false } }, CastExpr { expr: Literal { value: Int64(11) }, cast_type: Utf8, cast_options: CastOptions { safe: false } }, CastExpr { expr: Literal { value: Int64(12) }, cast_type: Utf8, cast_options: CastOptions { safe: false } }, CastExpr { expr: Literal { value: Int64(13) }, cast_type: Utf8, cast_options: CastOptions { safe: false } }, CastExpr { expr: Literal { value: Int64(14) }, cast_type: Utf8, cast_options: CastOptions { safe: false } }, CastExpr { expr: Literal { value: Int64(15) }, cast_type: Utf8, cast_options: CastOptions { safe: false } }, CastExpr { expr: Literal { value: Int64(16) }, cast_type: Utf8, cast_options: CastOptions { safe: false } }, CastExpr { expr: Literal { value: Int64(17) }, cast_type: Utf8, cast_options: 
CastOptions { safe: false } }, CastExpr { expr: Literal { value: Int64(18) }, cast_type: Utf8, cast_options: CastOptions { safe: false } }, CastExpr { expr: Literal { value: Int64(19) }, cast_type: Utf8, cast_options: CastOptions { safe: false } }, CastExpr { expr: Literal { value: Int64(20) }, cast_type: Utf8, cast_options: CastOptions { safe: false } }, CastExpr { expr: Literal { value: Int64(21) }, cast_type: Utf8, cast_options: CastOptions { safe: false } }, CastExpr { expr: Literal { value: Int64(22) }, cast_type: Utf8, cast_options: CastOptions { safe: false } }, CastExpr { expr: Literal { value: Int64(23) }, cast_type: Utf8, cast_options: CastOptions { safe: false } }, CastExpr { expr: Literal { value: Int64(24) }, cast_type: Utf8, cast_options: CastOptions { safe: false } }, CastExpr { expr: Literal { value: Int64(25) }, cast_type: Utf8, cast_options: CastOptions { safe: false } }, CastExpr { expr: Literal { value: Int64(26) }, cast_type: Utf8, cast_options: CastOptions { safe: false } }, CastExpr { expr: Literal { value: Int64(27) }, cast_type: Utf8, cast_options: CastOptions { safe: false } }, CastExpr { expr: Literal { value: Int64(28) }, cast_type: Utf8, cast_options: CastOptions { safe: false } }, CastExpr { expr: Literal { value: Int64(29) }, cast_type: Utf8, cast_options: CastOptions { safe: false } }, CastExpr { expr: Literal { value: Int64(30) }, cast_type: Utf8, cast_options: CastOptions { safe: false } }], negated: false, set: Some(InSet { set:"; + assert!(format!("{:?}", execution_plan).contains(expected)); + Ok(()) + } + + #[tokio::test] + async fn in_set_null_test() -> Result<()> { + let testdata = crate::test_util::arrow_test_data(); + let path = format!("{}/csv/aggregate_test_100.csv", testdata); + let options = CsvReadOptions::new().schema_infer_max_records(100); + // test NULL + let mut list = vec![Expr::Literal(ScalarValue::Int64(None))]; + for i in 1..31 { + list.push(Expr::Literal(ScalarValue::Int64(Some(i)))); + } + + let logical_plan 
= LogicalPlanBuilder::scan_csv( + Arc::new(LocalFileSystem {}), + &path, + options, + None, + 1, + ) + .await? + .filter(col("c12").lt(lit(0.05)))? + .project(vec![col("c1").in_list(list, false)])? + .build()?; + let execution_plan = plan(&logical_plan).await?; + let expected = "expr: [(InListExpr { expr: Column { name: \"c1\", index: 0 }, list: [CastExpr { expr: Literal { value: Int64(NULL) }, cast_type: Utf8, cast_options: CastOptions { safe: false } }, CastExpr { expr: Literal { value: Int64(1) }, cast_type: Utf8, cast_options: CastOptions { safe: false } }, CastExpr { expr: Literal { value: Int64(2) }, cast_type: Utf8, cast_options: CastOptions { safe: false } }, CastExpr { expr: Literal { value: Int64(3) }, cast_type: Utf8, cast_options: CastOptions { safe: false } }, CastExpr { expr: Literal { value: Int64(4) }, cast_type: Utf8, cast_options: CastOptions { safe: false } }, CastExpr { expr: Literal { value: Int64(5) }, cast_type: Utf8, cast_options: CastOptions { safe: false } }, CastExpr { expr: Literal { value: Int64(6) }, cast_type: Utf8, cast_options: CastOptions { safe: false } }, CastExpr { expr: Literal { value: Int64(7) }, cast_type: Utf8, cast_options: CastOptions { safe: false } }, CastExpr { expr: Literal { value: Int64(8) }, cast_type: Utf8, cast_options: CastOptions { safe: false } }, CastExpr { expr: Literal { value: Int64(9) }, cast_type: Utf8, cast_options: CastOptions { safe: false } }, CastExpr { expr: Literal { value: Int64(10) }, cast_type: Utf8, cast_options: CastOptions { safe: false } }, CastExpr { expr: Literal { value: Int64(11) }, cast_type: Utf8, cast_options: CastOptions { safe: false } }, CastExpr { expr: Literal { value: Int64(12) }, cast_type: Utf8, cast_options: CastOptions { safe: false } }, CastExpr { expr: Literal { value: Int64(13) }, cast_type: Utf8, cast_options: CastOptions { safe: false } }, CastExpr { expr: Literal { value: Int64(14) }, cast_type: Utf8, cast_options: CastOptions { safe: false } }, CastExpr { expr: 
Literal { value: Int64(15) }, cast_type: Utf8, cast_options: CastOptions { safe: false } }, CastExpr { expr: Literal { value: Int64(16) }, cast_type: Utf8, cast_options: CastOptions { safe: false } }, CastExpr { expr: Literal { value: Int64(17) }, cast_type: Utf8, cast_options: CastOptions { safe: false } }, CastExpr { expr: Literal { value: Int64(18) }, cast_type: Utf8, cast_options: CastOptions { safe: false } }, CastExpr { expr: Literal { value: Int64(19) }, cast_type: Utf8, cast_options: CastOptions { safe: false } }, CastExpr { expr: Literal { value: Int64(20) }, cast_type: Utf8, cast_options: CastOptions { safe: false } }, CastExpr { expr: Literal { value: Int64(21) }, cast_type: Utf8, cast_options: CastOptions { safe: false } }, CastExpr { expr: Literal { value: Int64(22) }, cast_type: Utf8, cast_options: CastOptions { safe: false } }, CastExpr { expr: Literal { value: Int64(23) }, cast_type: Utf8, cast_options: CastOptions { safe: false } }, CastExpr { expr: Literal { value: Int64(24) }, cast_type: Utf8, cast_options: CastOptions { safe: false } }, CastExpr { expr: Literal { value: Int64(25) }, cast_type: Utf8, cast_options: CastOptions { safe: false } }, CastExpr { expr: Literal { value: Int64(26) }, cast_type: Utf8, cast_options: CastOptions { safe: false } }, CastExpr { expr: Literal { value: Int64(27) }, cast_type: Utf8, cast_options: CastOptions { safe: false } }, CastExpr { expr: Literal { value: Int64(28) }, cast_type: Utf8, cast_options: CastOptions { safe: false } }, CastExpr { expr: Literal { value: Int64(29) }, cast_type: Utf8, cast_options: CastOptions { safe: false } }, CastExpr { expr: Literal { value: Int64(30) }, cast_type: Utf8, cast_options: CastOptions { safe: false } }], negated: false, set: Some(InSet { set: "; + assert!(format!("{:?}", execution_plan).contains(expected)); + Ok(()) + } + #[tokio::test] async fn hash_agg_input_schema() -> Result<()> { let testdata = crate::test_util::arrow_test_data(); diff --git 
a/datafusion/core/tests/sql/predicates.rs b/datafusion/core/tests/sql/predicates.rs index 1369baa75f4cc..ea79e2b142d55 100644 --- a/datafusion/core/tests/sql/predicates.rs +++ b/datafusion/core/tests/sql/predicates.rs @@ -369,3 +369,20 @@ async fn test_expect_distinct() -> Result<()> { assert_batches_eq!(expected, &actual); Ok(()) } + +#[tokio::test] +async fn csv_in_set_test() -> Result<()> { + let ctx = SessionContext::new(); + register_aggregate_csv(&ctx).await?; + let sql = "SELECT count(*) FROM aggregate_test_100 WHERE c7 in ('25','155','204','77','208','67','139','191','26','7','202','113','129','197','249','146','129','220','154','163','220','19','71','243','150','231','196','170','99','255');"; + let actual = execute_to_batches(&ctx, sql).await; + let expected = vec![ + "+-----------------+", + "| COUNT(UInt8(1)) |", + "+-----------------+", + "| 36 |", + "+-----------------+", + ]; + assert_batches_sorted_eq!(expected, &actual); + Ok(()) +} diff --git a/datafusion/physical-expr/src/expressions/in_list.rs b/datafusion/physical-expr/src/expressions/in_list.rs index 2aee0d87dbde3..a6894b938ff68 100644 --- a/datafusion/physical-expr/src/expressions/in_list.rs +++ b/datafusion/physical-expr/src/expressions/in_list.rs @@ -18,6 +18,7 @@ //! InList expression use std::any::Any; +use std::collections::HashSet; use std::sync::Arc; use arrow::array::GenericStringArray; @@ -32,13 +33,19 @@ use arrow::{ record_batch::RecordBatch, }; -use crate::PhysicalExpr; +use crate::{expressions, PhysicalExpr}; use arrow::array::*; use arrow::buffer::{Buffer, MutableBuffer}; use datafusion_common::ScalarValue; use datafusion_common::{DataFusionError, Result}; use datafusion_expr::ColumnarValue; +/// List length above which `IN` / `NOT IN` is evaluated with a `HashSet` instead of a `Vec`. +/// Value chosen by the benchmark at +/// https://github.com/apache/arrow-datafusion/pull/2156#discussion_r845198369 +/// TODO: add switch codeGen in In_List +static OPTIMIZER_INSET_THRESHOLD: usize = 30; + macro_rules! 
compare_op_scalar { ($left: expr, $right:expr, $op:expr) => {{ let null_bit_buffer = $left.data().null_buffer().cloned(); @@ -69,6 +76,23 @@ pub struct InListExpr { expr: Arc, list: Vec>, negated: bool, + set: Option, +} + +/// Hash set of the static `IN`-list values, used for membership lookups on large lists. +#[derive(Debug)] +pub struct InSet { + set: HashSet, +} + +impl InSet { + pub fn new(set: HashSet) -> Self { + Self { set } + } + + pub fn get_set(&self) -> &HashSet { + &self.set + } } macro_rules! make_contains { @@ -181,6 +205,26 @@ macro_rules! make_contains_primitive { }}; } +macro_rules! set_contains_with_negated { + ($ARRAY:expr, $LIST_VALUES:expr, $NEGATED:expr) => {{ + if $NEGATED { + return Ok(ColumnarValue::Array(Arc::new( + $ARRAY + .iter() + .map(|x| x.map(|v| !$LIST_VALUES.contains(&v.try_into().unwrap()))) + .collect::(), + ))); + } else { + return Ok(ColumnarValue::Array(Arc::new( + $ARRAY + .iter() + .map(|x| x.map(|v| $LIST_VALUES.contains(&v.try_into().unwrap()))) + .collect::(), + ))); + } + }}; +} + // whether each value on the left (can be null) is contained in the non-null list fn in_list_primitive( array: &PrimitiveArray, @@ -220,6 +264,42 @@ fn not_in_list_utf8( compare_op_scalar!(array, values, |x, v: &[&str]| !v.contains(&x)) } +// Check that all filter values of the IN clause are static. 
+// A static value is either a `Literal` or a `CastExpr` wrapping a `Literal`. +fn check_all_static_filter_expr(list: &[Arc]) -> bool { + list.iter().all(|v| { + let cast = v.as_any().downcast_ref::(); + if let Some(c) = cast { + c.expr() + .as_any() + .downcast_ref::() + .is_some() + } else { + let cast = v.as_any().downcast_ref::(); + cast.is_some() + } + }) +} + +fn cast_static_filter_to_set(list: &[Arc]) -> HashSet { + HashSet::from_iter(list.iter().map(|expr| { + if let Some(cast) = expr.as_any().downcast_ref::() { + cast.expr() + .as_any() + .downcast_ref::() + .unwrap() + .value() + .clone() + } else { + expr.as_any() + .downcast_ref::() + .unwrap() + .value() + .clone() + } + })) +} + impl InListExpr { /// Create a new InList expression pub fn new( @@ -227,10 +307,20 @@ impl InListExpr { list: Vec>, negated: bool, ) -> Self { - Self { - expr, - list, - negated, + if list.len() > OPTIMIZER_INSET_THRESHOLD && check_all_static_filter_expr(&list) { + Self { + expr, + set: Some(InSet::new(cast_static_filter_to_set(&list))), + list, + negated, + } + } else { + Self { + expr, + list, + negated, + set: None, + } } } @@ -318,7 +408,13 @@ impl InListExpr { impl std::fmt::Display for InListExpr { fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { if self.negated { - write!(f, "{} NOT IN ({:?})", self.expr, self.list) + if self.set.is_some() { + write!(f, "{} NOT IN (SET) ({:?})", self.expr, self.list) + } else { + write!(f, "{} NOT IN ({:?})", self.expr, self.list) + } + } else if self.set.is_some() { + write!(f, "Use {} IN (SET) ({:?})", self.expr, self.list) } else { write!(f, "{} IN ({:?})", self.expr, self.list) } @@ -342,119 +438,202 @@ impl PhysicalExpr for InListExpr { fn evaluate(&self, batch: &RecordBatch) -> Result { let value = self.expr.evaluate(batch)?; let value_data_type = value.data_type(); - let list_values = self - .list - .iter() - .map(|expr| expr.evaluate(batch)) - .collect::>>()?; - - let array = match value { - ColumnarValue::Array(array) => array, - 
ColumnarValue::Scalar(scalar) => scalar.to_array(), - }; - match value_data_type { - DataType::Float32 => { - make_contains_primitive!( - array, - list_values, - self.negated, - Float32, - Float32Array - ) - } - DataType::Float64 => { - make_contains_primitive!( - array, - list_values, - self.negated, - Float64, - Float64Array - ) - } - DataType::Int16 => { - make_contains_primitive!( - array, - list_values, - self.negated, - Int16, - Int16Array - ) - } - DataType::Int32 => { - make_contains_primitive!( - array, - list_values, - self.negated, - Int32, - Int32Array - ) - } - DataType::Int64 => { - make_contains_primitive!( - array, - list_values, - self.negated, - Int64, - Int64Array - ) - } - DataType::Int8 => { - make_contains_primitive!( - array, - list_values, - self.negated, - Int8, - Int8Array - ) - } - DataType::UInt16 => { - make_contains_primitive!( - array, - list_values, - self.negated, - UInt16, - UInt16Array - ) - } - DataType::UInt32 => { - make_contains_primitive!( - array, - list_values, - self.negated, - UInt32, - UInt32Array - ) - } - DataType::UInt64 => { - make_contains_primitive!( - array, - list_values, - self.negated, - UInt64, - UInt64Array - ) - } - DataType::UInt8 => { - make_contains_primitive!( - array, - list_values, - self.negated, - UInt8, - UInt8Array - ) - } - DataType::Boolean => { - make_contains!(array, list_values, self.negated, Boolean, BooleanArray) - } - DataType::Utf8 => self.compare_utf8::(array, list_values, self.negated), - DataType::LargeUtf8 => { - self.compare_utf8::(array, list_values, self.negated) + if let Some(in_set) = &self.set { + let array = match value { + ColumnarValue::Array(array) => array, + ColumnarValue::Scalar(scalar) => scalar.to_array(), + }; + let set = in_set.get_set(); + match value_data_type { + DataType::Boolean => { + let array = array.as_any().downcast_ref::().unwrap(); + set_contains_with_negated!(array, set, self.negated) + } + DataType::Int8 => { + let array = 
array.as_any().downcast_ref::().unwrap(); + set_contains_with_negated!(array, set, self.negated) + } + DataType::Int16 => { + let array = array.as_any().downcast_ref::().unwrap(); + set_contains_with_negated!(array, set, self.negated) + } + DataType::Int32 => { + let array = array.as_any().downcast_ref::().unwrap(); + set_contains_with_negated!(array, set, self.negated) + } + DataType::Int64 => { + let array = array.as_any().downcast_ref::().unwrap(); + set_contains_with_negated!(array, set, self.negated) + } + DataType::UInt8 => { + let array = array.as_any().downcast_ref::().unwrap(); + set_contains_with_negated!(array, set, self.negated) + } + DataType::UInt16 => { + let array = array.as_any().downcast_ref::().unwrap(); + set_contains_with_negated!(array, set, self.negated) + } + DataType::UInt32 => { + let array = array.as_any().downcast_ref::().unwrap(); + set_contains_with_negated!(array, set, self.negated) + } + DataType::UInt64 => { + let array = array.as_any().downcast_ref::().unwrap(); + set_contains_with_negated!(array, set, self.negated) + } + DataType::Float32 => { + let array = array.as_any().downcast_ref::().unwrap(); + set_contains_with_negated!(array, set, self.negated) + } + DataType::Float64 => { + let array = array.as_any().downcast_ref::().unwrap(); + set_contains_with_negated!(array, set, self.negated) + } + DataType::Utf8 => { + let array = array + .as_any() + .downcast_ref::>() + .unwrap(); + set_contains_with_negated!(array, set, self.negated) + } + DataType::LargeUtf8 => { + let array = array + .as_any() + .downcast_ref::>() + .unwrap(); + set_contains_with_negated!(array, set, self.negated) + } + datatype => { + return Result::Err(DataFusionError::NotImplemented(format!( + "InSet does not support datatype {:?}.", + datatype + ))) + } + }; + } else { + let list_values = self + .list + .iter() + .map(|expr| expr.evaluate(batch)) + .collect::>>()?; + + let array = match value { + ColumnarValue::Array(array) => array, + 
ColumnarValue::Scalar(scalar) => scalar.to_array(), + }; + + match value_data_type { + DataType::Float32 => { + make_contains_primitive!( + array, + list_values, + self.negated, + Float32, + Float32Array + ) + } + DataType::Float64 => { + make_contains_primitive!( + array, + list_values, + self.negated, + Float64, + Float64Array + ) + } + DataType::Int16 => { + make_contains_primitive!( + array, + list_values, + self.negated, + Int16, + Int16Array + ) + } + DataType::Int32 => { + make_contains_primitive!( + array, + list_values, + self.negated, + Int32, + Int32Array + ) + } + DataType::Int64 => { + make_contains_primitive!( + array, + list_values, + self.negated, + Int64, + Int64Array + ) + } + DataType::Int8 => { + make_contains_primitive!( + array, + list_values, + self.negated, + Int8, + Int8Array + ) + } + DataType::UInt16 => { + make_contains_primitive!( + array, + list_values, + self.negated, + UInt16, + UInt16Array + ) + } + DataType::UInt32 => { + make_contains_primitive!( + array, + list_values, + self.negated, + UInt32, + UInt32Array + ) + } + DataType::UInt64 => { + make_contains_primitive!( + array, + list_values, + self.negated, + UInt64, + UInt64Array + ) + } + DataType::UInt8 => { + make_contains_primitive!( + array, + list_values, + self.negated, + UInt8, + UInt8Array + ) + } + DataType::Boolean => { + make_contains!( + array, + list_values, + self.negated, + Boolean, + BooleanArray + ) + } + DataType::Utf8 => { + self.compare_utf8::(array, list_values, self.negated) + } + DataType::LargeUtf8 => { + self.compare_utf8::(array, list_values, self.negated) + } + datatype => Result::Err(DataFusionError::NotImplemented(format!( + "InList does not support datatype {:?}.", + datatype + ))), } - datatype => Result::Err(DataFusionError::NotImplemented(format!( - "InList does not support datatype {:?}.", - datatype - ))), } } }