Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion datafusion/core/src/physical_plan/planner.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2230,7 +2230,7 @@ mod tests {
let execution_plan = plan(&logical_plan).await?;
// verify that the plan correctly adds cast from Int64(1) to Utf8, and the const will be evaluated.

let expected = "expr: [(BinaryExpr { left: BinaryExpr { left: Column { name: \"c1\", index: 0 }, op: Eq, right: Literal { value: Utf8(\"1\") } }, op: Or, right: BinaryExpr { left: Column { name: \"c1\", index: 0 }, op: Eq, right: Literal { value: Utf8(\"a\") } } }";
let expected = "expr: [(BinaryExpr { left: BinaryExpr { left: Column { name: \"c1\", index: 0 }, op: Eq, right: Literal { value: Utf8(\"a\") } }, op: Or, right: BinaryExpr { left: Column { name: \"c1\", index: 0 }, op: Eq, right: Literal { value: Utf8(\"1\") } } }";

let actual = format!("{execution_plan:?}");
assert!(actual.contains(expected), "{}", actual);
Expand Down
70 changes: 70 additions & 0 deletions datafusion/core/tests/sqllogictests/test_files/predicates.slt
Original file line number Diff line number Diff line change
Expand Up @@ -249,6 +249,76 @@ SELECT * FROM test WHERE column1 IN ('foo', 'Bar', 'fazzz')
foo
fazzz


###
# Test logical plan simplifies large OR chains
Comment thread
jackwener marked this conversation as resolved.
###

statement ok
set datafusion.explain.logical_plan_only = true

# Number of OR statements is less than or equal to threshold
query TT
EXPLAIN SELECT * FROM test WHERE column1 = 'foo' OR column1 = 'bar' OR column1 = 'fazzz'
----
logical_plan
Filter: test.column1 = Utf8("foo") OR test.column1 = Utf8("bar") OR test.column1 = Utf8("fazzz")
--TableScan: test projection=[column1]

# Number of OR statements is greater than threshold
query TT
EXPLAIN SELECT * FROM test WHERE column1 = 'foo' OR column1 = 'bar' OR column1 = 'fazzz' OR column1 = 'barfoo'
----
logical_plan
Filter: test.column1 IN ([Utf8("foo"), Utf8("bar"), Utf8("fazzz"), Utf8("barfoo")])
--TableScan: test projection=[column1]

# Complex OR statements
query TT
EXPLAIN SELECT * FROM test WHERE column1 = 'foo' OR column1 = 'bar' OR column1 = 'fazzz' OR column1 = 'barfoo' OR false OR column1 = 'foobar'
----
logical_plan
Filter: test.column1 IN ([Utf8("foo"), Utf8("bar"), Utf8("fazzz"), Utf8("barfoo"), Utf8("foobar")])
--TableScan: test projection=[column1]

# Balanced OR structures
query TT
EXPLAIN SELECT * FROM test WHERE (column1 = 'foo' OR column1 = 'bar') OR (column1 = 'fazzz' OR column1 = 'barfoo')
----
logical_plan
Filter: test.column1 IN ([Utf8("foo"), Utf8("bar"), Utf8("fazzz"), Utf8("barfoo")])
--TableScan: test projection=[column1]

# Right-deep OR structures
Comment thread
aprimadi marked this conversation as resolved.
query TT
EXPLAIN SELECT * FROM test WHERE column1 = 'foo' OR (column1 = 'bar' OR (column1 = 'fazzz' OR column1 = 'barfoo'))
----
logical_plan
Filter: test.column1 IN ([Utf8("foo"), Utf8("bar"), Utf8("fazzz"), Utf8("barfoo")])
--TableScan: test projection=[column1]

# Not simplifiable, mixed column
query TT
EXPLAIN SELECT * FROM aggregate_test_100
WHERE (c2 = 1 OR c3 = 100) OR (c2 = 2 OR c2 = 3 OR c2 = 4)
----
logical_plan
Filter: aggregate_test_100.c2 = Int8(1) OR aggregate_test_100.c3 = Int16(100) OR aggregate_test_100.c2 = Int8(2) OR aggregate_test_100.c2 = Int8(3) OR aggregate_test_100.c2 = Int8(4)
--TableScan: aggregate_test_100 projection=[c1, c2, c3, c4, c5, c6, c7, c8, c9, c10, c11, c12, c13], partial_filters=[aggregate_test_100.c2 = Int8(1) OR aggregate_test_100.c3 = Int16(100) OR aggregate_test_100.c2 = Int8(2) OR aggregate_test_100.c2 = Int8(3) OR aggregate_test_100.c2 = Int8(4)]

# Partially simplifiable, mixed column
query TT
EXPLAIN SELECT * FROM aggregate_test_100
WHERE (c2 = 1 OR c3 = 100) OR (c2 = 2 OR c2 = 3 OR c2 = 4 OR c2 = 5)
----
logical_plan
Filter: aggregate_test_100.c2 = Int8(1) OR aggregate_test_100.c3 = Int16(100) OR aggregate_test_100.c2 IN ([Int8(2), Int8(3), Int8(4), Int8(5)])
--TableScan: aggregate_test_100 projection=[c1, c2, c3, c4, c5, c6, c7, c8, c9, c10, c11, c12, c13], partial_filters=[aggregate_test_100.c2 = Int8(1) OR aggregate_test_100.c3 = Int16(100) OR aggregate_test_100.c2 IN ([Int8(2), Int8(3), Int8(4), Int8(5)])]

statement ok
set datafusion.explain.logical_plan_only = false


# async fn test_expect_all
query IR
SELECT int_col, double_col FROM alltypes_plain where int_col > 0 EXCEPT ALL SELECT int_col, double_col FROM alltypes_plain where int_col < 1
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -55,8 +55,8 @@ Sort: lineitem.l_shipmode ASC NULLS LAST
------Projection: lineitem.l_shipmode, orders.o_orderpriority
--------Inner Join: lineitem.l_orderkey = orders.o_orderkey
----------Projection: lineitem.l_orderkey, lineitem.l_shipmode
------------Filter: (lineitem.l_shipmode = Utf8("SHIP") OR lineitem.l_shipmode = Utf8("MAIL")) AND lineitem.l_commitdate < lineitem.l_receiptdate AND lineitem.l_shipdate < lineitem.l_commitdate AND lineitem.l_receiptdate >= Date32("8766") AND lineitem.l_receiptdate < Date32("9131")
--------------TableScan: lineitem projection=[l_orderkey, l_shipdate, l_commitdate, l_receiptdate, l_shipmode], partial_filters=[lineitem.l_shipmode = Utf8("SHIP") OR lineitem.l_shipmode = Utf8("MAIL"), lineitem.l_commitdate < lineitem.l_receiptdate, lineitem.l_shipdate < lineitem.l_commitdate, lineitem.l_receiptdate >= Date32("8766"), lineitem.l_receiptdate < Date32("9131")]
------------Filter: (lineitem.l_shipmode = Utf8("MAIL") OR lineitem.l_shipmode = Utf8("SHIP")) AND lineitem.l_commitdate < lineitem.l_receiptdate AND lineitem.l_shipdate < lineitem.l_commitdate AND lineitem.l_receiptdate >= Date32("8766") AND lineitem.l_receiptdate < Date32("9131")
--------------TableScan: lineitem projection=[l_orderkey, l_shipdate, l_commitdate, l_receiptdate, l_shipmode], partial_filters=[lineitem.l_shipmode = Utf8("MAIL") OR lineitem.l_shipmode = Utf8("SHIP"), lineitem.l_commitdate < lineitem.l_receiptdate, lineitem.l_shipdate < lineitem.l_commitdate, lineitem.l_receiptdate >= Date32("8766"), lineitem.l_receiptdate < Date32("9131")]
----------TableScan: orders projection=[o_orderkey, o_orderpriority]
physical_plan
SortPreservingMergeExec: [l_shipmode@0 ASC NULLS LAST]
Expand All @@ -73,7 +73,7 @@ SortPreservingMergeExec: [l_shipmode@0 ASC NULLS LAST]
----------------------RepartitionExec: partitioning=Hash([Column { name: "l_orderkey", index: 0 }], 4), input_partitions=4
------------------------ProjectionExec: expr=[l_orderkey@0 as l_orderkey, l_shipmode@4 as l_shipmode]
--------------------------CoalesceBatchesExec: target_batch_size=8192
----------------------------FilterExec: (l_shipmode@4 = SHIP OR l_shipmode@4 = MAIL) AND l_commitdate@2 < l_receiptdate@3 AND l_shipdate@1 < l_commitdate@2 AND l_receiptdate@3 >= 8766 AND l_receiptdate@3 < 9131
----------------------------FilterExec: (l_shipmode@4 = MAIL OR l_shipmode@4 = SHIP) AND l_commitdate@2 < l_receiptdate@3 AND l_shipdate@1 < l_commitdate@2 AND l_receiptdate@3 >= 8766 AND l_receiptdate@3 < 9131
------------------------------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1
--------------------------------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/sqllogictests/test_files/tpch/data/lineitem.tbl]]}, projection=[l_orderkey, l_shipdate, l_commitdate, l_receiptdate, l_shipmode], has_header=false
--------------------CoalesceBatchesExec: target_batch_size=8192
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -59,8 +59,8 @@ Projection: SUM(lineitem.l_extendedprice * Int64(1) - lineitem.l_discount) AS re
----Projection: lineitem.l_extendedprice, lineitem.l_discount
------Inner Join: lineitem.l_partkey = part.p_partkey Filter: part.p_brand = Utf8("Brand#12") AND part.p_container IN ([Utf8("SM CASE"), Utf8("SM BOX"), Utf8("SM PACK"), Utf8("SM PKG")]) AND lineitem.l_quantity >= Decimal128(Some(100),15,2) AND lineitem.l_quantity <= Decimal128(Some(1100),15,2) AND part.p_size <= Int32(5) OR part.p_brand = Utf8("Brand#23") AND part.p_container IN ([Utf8("MED BAG"), Utf8("MED BOX"), Utf8("MED PKG"), Utf8("MED PACK")]) AND lineitem.l_quantity >= Decimal128(Some(1000),15,2) AND lineitem.l_quantity <= Decimal128(Some(2000),15,2) AND part.p_size <= Int32(10) OR part.p_brand = Utf8("Brand#34") AND part.p_container IN ([Utf8("LG CASE"), Utf8("LG BOX"), Utf8("LG PACK"), Utf8("LG PKG")]) AND lineitem.l_quantity >= Decimal128(Some(2000),15,2) AND lineitem.l_quantity <= Decimal128(Some(3000),15,2) AND part.p_size <= Int32(15)
--------Projection: lineitem.l_partkey, lineitem.l_quantity, lineitem.l_extendedprice, lineitem.l_discount
----------Filter: (lineitem.l_quantity >= Decimal128(Some(100),15,2) AND lineitem.l_quantity <= Decimal128(Some(1100),15,2) OR lineitem.l_quantity >= Decimal128(Some(1000),15,2) AND lineitem.l_quantity <= Decimal128(Some(2000),15,2) OR lineitem.l_quantity >= Decimal128(Some(2000),15,2) AND lineitem.l_quantity <= Decimal128(Some(3000),15,2)) AND (lineitem.l_shipmode = Utf8("AIR REG") OR lineitem.l_shipmode = Utf8("AIR")) AND lineitem.l_shipinstruct = Utf8("DELIVER IN PERSON")
------------TableScan: lineitem projection=[l_partkey, l_quantity, l_extendedprice, l_discount, l_shipinstruct, l_shipmode], partial_filters=[lineitem.l_shipmode = Utf8("AIR REG") OR lineitem.l_shipmode = Utf8("AIR"), lineitem.l_shipinstruct = Utf8("DELIVER IN PERSON"), lineitem.l_quantity >= Decimal128(Some(100),15,2) AND lineitem.l_quantity <= Decimal128(Some(1100),15,2) OR lineitem.l_quantity >= Decimal128(Some(1000),15,2) AND lineitem.l_quantity <= Decimal128(Some(2000),15,2) OR lineitem.l_quantity >= Decimal128(Some(2000),15,2) AND lineitem.l_quantity <= Decimal128(Some(3000),15,2)]
----------Filter: (lineitem.l_quantity >= Decimal128(Some(100),15,2) AND lineitem.l_quantity <= Decimal128(Some(1100),15,2) OR lineitem.l_quantity >= Decimal128(Some(1000),15,2) AND lineitem.l_quantity <= Decimal128(Some(2000),15,2) OR lineitem.l_quantity >= Decimal128(Some(2000),15,2) AND lineitem.l_quantity <= Decimal128(Some(3000),15,2)) AND (lineitem.l_shipmode = Utf8("AIR") OR lineitem.l_shipmode = Utf8("AIR REG")) AND lineitem.l_shipinstruct = Utf8("DELIVER IN PERSON")
------------TableScan: lineitem projection=[l_partkey, l_quantity, l_extendedprice, l_discount, l_shipinstruct, l_shipmode], partial_filters=[lineitem.l_shipmode = Utf8("AIR") OR lineitem.l_shipmode = Utf8("AIR REG"), lineitem.l_shipinstruct = Utf8("DELIVER IN PERSON"), lineitem.l_quantity >= Decimal128(Some(100),15,2) AND lineitem.l_quantity <= Decimal128(Some(1100),15,2) OR lineitem.l_quantity >= Decimal128(Some(1000),15,2) AND lineitem.l_quantity <= Decimal128(Some(2000),15,2) OR lineitem.l_quantity >= Decimal128(Some(2000),15,2) AND lineitem.l_quantity <= Decimal128(Some(3000),15,2)]
--------Filter: (part.p_brand = Utf8("Brand#12") AND part.p_container IN ([Utf8("SM CASE"), Utf8("SM BOX"), Utf8("SM PACK"), Utf8("SM PKG")]) AND part.p_size <= Int32(5) OR part.p_brand = Utf8("Brand#23") AND part.p_container IN ([Utf8("MED BAG"), Utf8("MED BOX"), Utf8("MED PKG"), Utf8("MED PACK")]) AND part.p_size <= Int32(10) OR part.p_brand = Utf8("Brand#34") AND part.p_container IN ([Utf8("LG CASE"), Utf8("LG BOX"), Utf8("LG PACK"), Utf8("LG PKG")]) AND part.p_size <= Int32(15)) AND part.p_size >= Int32(1)
----------TableScan: part projection=[p_partkey, p_brand, p_size, p_container], partial_filters=[part.p_size >= Int32(1), part.p_brand = Utf8("Brand#12") AND part.p_container IN ([Utf8("SM CASE"), Utf8("SM BOX"), Utf8("SM PACK"), Utf8("SM PKG")]) AND part.p_size <= Int32(5) OR part.p_brand = Utf8("Brand#23") AND part.p_container IN ([Utf8("MED BAG"), Utf8("MED BOX"), Utf8("MED PKG"), Utf8("MED PACK")]) AND part.p_size <= Int32(10) OR part.p_brand = Utf8("Brand#34") AND part.p_container IN ([Utf8("LG CASE"), Utf8("LG BOX"), Utf8("LG PACK"), Utf8("LG PKG")]) AND part.p_size <= Int32(15)]
physical_plan
Expand All @@ -75,7 +75,7 @@ ProjectionExec: expr=[SUM(lineitem.l_extendedprice * Int64(1) - lineitem.l_disco
----------------RepartitionExec: partitioning=Hash([Column { name: "l_partkey", index: 0 }], 4), input_partitions=4
------------------ProjectionExec: expr=[l_partkey@0 as l_partkey, l_quantity@1 as l_quantity, l_extendedprice@2 as l_extendedprice, l_discount@3 as l_discount]
--------------------CoalesceBatchesExec: target_batch_size=8192
----------------------FilterExec: (l_quantity@1 >= Some(100),15,2 AND l_quantity@1 <= Some(1100),15,2 OR l_quantity@1 >= Some(1000),15,2 AND l_quantity@1 <= Some(2000),15,2 OR l_quantity@1 >= Some(2000),15,2 AND l_quantity@1 <= Some(3000),15,2) AND (l_shipmode@5 = AIR REG OR l_shipmode@5 = AIR) AND l_shipinstruct@4 = DELIVER IN PERSON
----------------------FilterExec: (l_quantity@1 >= Some(100),15,2 AND l_quantity@1 <= Some(1100),15,2 OR l_quantity@1 >= Some(1000),15,2 AND l_quantity@1 <= Some(2000),15,2 OR l_quantity@1 >= Some(2000),15,2 AND l_quantity@1 <= Some(3000),15,2) AND (l_shipmode@5 = AIR OR l_shipmode@5 = AIR REG) AND l_shipinstruct@4 = DELIVER IN PERSON
------------------------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1
--------------------------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/sqllogictests/test_files/tpch/data/lineitem.tbl]]}, projection=[l_partkey, l_quantity, l_extendedprice, l_discount, l_shipinstruct, l_shipmode], has_header=false
--------------CoalesceBatchesExec: target_batch_size=8192
Expand Down
76 changes: 59 additions & 17 deletions datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,9 @@

use std::ops::Not;

use super::or_in_list_simplifier::OrInListSimplifier;
use super::utils::*;

use crate::analyzer::type_coercion::TypeCoercionRewriter;
use crate::simplify_expressions::regex::simplify_regex_expr;
use arrow::{
Expand Down Expand Up @@ -116,13 +118,15 @@ impl<S: SimplifyInfo> ExprSimplifier<S> {
pub fn simplify(&self, expr: Expr) -> Result<Expr> {
let mut simplifier = Simplifier::new(&self.info);
let mut const_evaluator = ConstEvaluator::try_new(self.info.execution_props())?;
let mut or_in_list_simplifier = OrInListSimplifier::new();

// TODO iterate until no changes are made during rewrite
// (evaluating constants can enable new simplifications and
// simplifications can enable new constant evaluation)
// https://github.com/apache/arrow-datafusion/issues/1160
expr.rewrite(&mut const_evaluator)?
.rewrite(&mut simplifier)?
.rewrite(&mut or_in_list_simplifier)?
// run both passes twice to try an minimize simplifications that we missed
.rewrite(&mut const_evaluator)?
.rewrite(&mut simplifier)
Expand Down Expand Up @@ -432,17 +436,37 @@ impl<'a, S: SimplifyInfo> TreeNodeRewriter for Simplifier<'a, S> {
{
let first_val = list[0].clone();
if negated {
list.into_iter()
.skip(1)
.fold((*expr.clone()).not_eq(first_val), |acc, y| {
(*expr.clone()).not_eq(y).and(acc)
})
list.into_iter().skip(1).fold(
(*expr.clone()).not_eq(first_val),
|acc, y| {
// Note that `A and B and C and D` is a left-deep tree structure
// as such we want to maintain this structure as much as possible
// to avoid reordering the expression during each optimization
// pass.
//
// Left-deep tree structure for `A and B and C and D`:
// ```
// &
// / \
// & D
// / \
// & C
// / \
// A B
// ```
//
// The code below maintain the left-deep tree structure.
acc.and((*expr.clone()).not_eq(y))
},
)
} else {
list.into_iter()
.skip(1)
.fold((*expr.clone()).eq(first_val), |acc, y| {
(*expr.clone()).eq(y).or(acc)
})
list.into_iter().skip(1).fold(
(*expr.clone()).eq(first_val),
|acc, y| {
// Same reasoning as above
acc.or((*expr.clone()).eq(y))
},
)
}
}
//
Expand Down Expand Up @@ -2888,11 +2912,11 @@ mod tests {

assert_eq!(
simplify(in_list(col("c1"), vec![lit(1), lit(2)], false)),
col("c1").eq(lit(2)).or(col("c1").eq(lit(1)))
col("c1").eq(lit(1)).or(col("c1").eq(lit(2)))
);
assert_eq!(
simplify(in_list(col("c1"), vec![lit(1), lit(2)], true)),
col("c1").not_eq(lit(2)).and(col("c1").not_eq(lit(1)))
col("c1").not_eq(lit(1)).and(col("c1").not_eq(lit(2)))
);

let subquery = Arc::new(test_table_scan_with_name("test").unwrap());
Expand All @@ -2918,26 +2942,44 @@ mod tests {
let subquery2 =
scalar_subquery(Arc::new(test_table_scan_with_name("test2").unwrap()));

// c1 NOT IN (<subquery1>, <subquery2>) -> c1 != <subquery2> AND c1 != <subquery1>
// c1 NOT IN (<subquery1>, <subquery2>) -> c1 != <subquery1> AND c1 != <subquery2>
assert_eq!(
simplify(in_list(
col("c1"),
vec![subquery1.clone(), subquery2.clone()],
true
)),
col("c1")
.not_eq(subquery2.clone())
.and(col("c1").not_eq(subquery1.clone()))
.not_eq(subquery1.clone())
.and(col("c1").not_eq(subquery2.clone()))
);

// c1 IN (<subquery1>, <subquery2>) -> c1 == <subquery2> OR c1 == <subquery1>
// c1 IN (<subquery1>, <subquery2>) -> c1 == <subquery1> OR c1 == <subquery2>
assert_eq!(
simplify(in_list(
col("c1"),
vec![subquery1.clone(), subquery2.clone()],
false
)),
col("c1").eq(subquery2).or(col("c1").eq(subquery1))
col("c1").eq(subquery1).or(col("c1").eq(subquery2))
);

// c1 NOT IN (1, 2, 3, 4) OR c1 NOT IN (5, 6, 7, 8) ->
// c1 NOT IN (1, 2, 3, 4) OR c1 NOT IN (5, 6, 7, 8)
let expr = in_list(col("c1"), vec![lit(1), lit(2), lit(3), lit(4)], true).or(
in_list(col("c1"), vec![lit(5), lit(6), lit(7), lit(8)], true),
);
assert_eq!(simplify(expr.clone()), expr);
}

#[test]
fn simplify_large_or() {
let expr = (0..5)
.map(|i| col("c1").eq(lit(i)))
.fold(lit(false), |acc, e| acc.or(e));
assert_eq!(
simplify(expr),
in_list(col("c1"), (0..5).map(lit).collect(), false),
);
}

Expand Down
1 change: 1 addition & 0 deletions datafusion/optimizer/src/simplify_expressions/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

pub mod context;
pub mod expr_simplifier;
mod or_in_list_simplifier;
mod regex;
pub mod simplify_exprs;
mod utils;
Expand Down
Loading