Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
60 changes: 59 additions & 1 deletion datafusion/core/src/physical_plan/planner.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1697,7 +1697,7 @@ mod tests {
.build()?;
let execution_plan = plan(&logical_plan).await?;
// verify that the plan correctly adds cast from Int64(1) to Utf8
let expected = "InListExpr { expr: Column { name: \"c1\", index: 0 }, list: [Literal { value: Utf8(\"a\") }, CastExpr { expr: Literal { value: Int64(1) }, cast_type: Utf8, cast_options: CastOptions { safe: false } }], negated: false }";
let expected = "InListExpr { expr: Column { name: \"c1\", index: 0 }, list: [Literal { value: Utf8(\"a\") }, CastExpr { expr: Literal { value: Int64(1) }, cast_type: Utf8, cast_options: CastOptions { safe: false } }], negated: false, set: None }";
assert!(format!("{:?}", execution_plan).contains(expected));

// expression: "a in (true, 'a')"
Expand Down Expand Up @@ -1733,6 +1733,64 @@ mod tests {
Ok(())
}

#[tokio::test]
async fn in_set_test() -> Result<()> {
let testdata = crate::test_util::arrow_test_data();
let path = format!("{}/csv/aggregate_test_100.csv", testdata);
let options = CsvReadOptions::new().schema_infer_max_records(100);

// OPTIMIZER_INSET_THRESHOLD = 10
// expression: "a in ('a', 1, 2, ..30)"
let mut list = vec![Expr::Literal(ScalarValue::Utf8(Some("a".to_string())))];
for i in 1..31 {
list.push(Expr::Literal(ScalarValue::Int64(Some(i))));
}

let logical_plan = LogicalPlanBuilder::scan_csv(
Arc::new(LocalFileSystem {}),
&path,
options,
None,
1,
)
.await?
.filter(col("c12").lt(lit(0.05)))?
.project(vec![col("c1").in_list(list, false)])?
.build()?;
let execution_plan = plan(&logical_plan).await?;
let expected = "expr: [(InListExpr { expr: Column { name: \"c1\", index: 0 }, list: [Literal { value: Utf8(\"a\") }, CastExpr { expr: Literal { value: Int64(1) }, cast_type: Utf8, cast_options: CastOptions { safe: false } }, CastExpr { expr: Literal { value: Int64(2) }, cast_type: Utf8, cast_options: CastOptions { safe: false } }, CastExpr { expr: Literal { value: Int64(3) }, cast_type: Utf8, cast_options: CastOptions { safe: false } }, CastExpr { expr: Literal { value: Int64(4) }, cast_type: Utf8, cast_options: CastOptions { safe: false } }, CastExpr { expr: Literal { value: Int64(5) }, cast_type: Utf8, cast_options: CastOptions { safe: false } }, CastExpr { expr: Literal { value: Int64(6) }, cast_type: Utf8, cast_options: CastOptions { safe: false } }, CastExpr { expr: Literal { value: Int64(7) }, cast_type: Utf8, cast_options: CastOptions { safe: false } }, CastExpr { expr: Literal { value: Int64(8) }, cast_type: Utf8, cast_options: CastOptions { safe: false } }, CastExpr { expr: Literal { value: Int64(9) }, cast_type: Utf8, cast_options: CastOptions { safe: false } }, CastExpr { expr: Literal { value: Int64(10) }, cast_type: Utf8, cast_options: CastOptions { safe: false } }, CastExpr { expr: Literal { value: Int64(11) }, cast_type: Utf8, cast_options: CastOptions { safe: false } }, CastExpr { expr: Literal { value: Int64(12) }, cast_type: Utf8, cast_options: CastOptions { safe: false } }, CastExpr { expr: Literal { value: Int64(13) }, cast_type: Utf8, cast_options: CastOptions { safe: false } }, CastExpr { expr: Literal { value: Int64(14) }, cast_type: Utf8, cast_options: CastOptions { safe: false } }, CastExpr { expr: Literal { value: Int64(15) }, cast_type: Utf8, cast_options: CastOptions { safe: false } }, CastExpr { expr: Literal { value: Int64(16) }, cast_type: Utf8, cast_options: CastOptions { safe: false } }, CastExpr { expr: Literal { value: Int64(17) }, cast_type: Utf8, cast_options: CastOptions { safe: false } }, CastExpr { expr: Literal { value: Int64(18) }, cast_type: Utf8, cast_options: CastOptions { safe: false } }, CastExpr { expr: Literal { value: Int64(19) }, cast_type: Utf8, cast_options: CastOptions { safe: false } }, CastExpr { expr: Literal { value: Int64(20) }, cast_type: Utf8, cast_options: CastOptions { safe: false } }, CastExpr { expr: Literal { value: Int64(21) }, cast_type: Utf8, cast_options: CastOptions { safe: false } }, CastExpr { expr: Literal { value: Int64(22) }, cast_type: Utf8, cast_options: CastOptions { safe: false } }, CastExpr { expr: Literal { value: Int64(23) }, cast_type: Utf8, cast_options: CastOptions { safe: false } }, CastExpr { expr: Literal { value: Int64(24) }, cast_type: Utf8, cast_options: CastOptions { safe: false } }, CastExpr { expr: Literal { value: Int64(25) }, cast_type: Utf8, cast_options: CastOptions { safe: false } }, CastExpr { expr: Literal { value: Int64(26) }, cast_type: Utf8, cast_options: CastOptions { safe: false } }, CastExpr { expr: Literal { value: Int64(27) }, cast_type: Utf8, cast_options: CastOptions { safe: false } }, CastExpr { expr: Literal { value: Int64(28) }, cast_type: Utf8, cast_options: CastOptions { safe: false } }, CastExpr { expr: Literal { value: Int64(29) }, cast_type: Utf8, cast_options: CastOptions { safe: false } }, CastExpr { expr: Literal { value: Int64(30) }, cast_type: Utf8, cast_options: CastOptions { safe: false } }], negated: false, set: Some(InSet { set:";
assert!(format!("{:?}", execution_plan).contains(expected));
Ok(())
}

#[tokio::test]
async fn in_set_null_test() -> Result<()> {
let testdata = crate::test_util::arrow_test_data();
let path = format!("{}/csv/aggregate_test_100.csv", testdata);
let options = CsvReadOptions::new().schema_infer_max_records(100);
// test NULL
let mut list = vec![Expr::Literal(ScalarValue::Int64(None))];
for i in 1..31 {
list.push(Expr::Literal(ScalarValue::Int64(Some(i))));
}

let logical_plan = LogicalPlanBuilder::scan_csv(
Arc::new(LocalFileSystem {}),
&path,
options,
None,
1,
)
.await?
.filter(col("c12").lt(lit(0.05)))?
.project(vec![col("c1").in_list(list, false)])?
.build()?;
let execution_plan = plan(&logical_plan).await?;
let expected = "expr: [(InListExpr { expr: Column { name: \"c1\", index: 0 }, list: [CastExpr { expr: Literal { value: Int64(NULL) }, cast_type: Utf8, cast_options: CastOptions { safe: false } }, CastExpr { expr: Literal { value: Int64(1) }, cast_type: Utf8, cast_options: CastOptions { safe: false } }, CastExpr { expr: Literal { value: Int64(2) }, cast_type: Utf8, cast_options: CastOptions { safe: false } }, CastExpr { expr: Literal { value: Int64(3) }, cast_type: Utf8, cast_options: CastOptions { safe: false } }, CastExpr { expr: Literal { value: Int64(4) }, cast_type: Utf8, cast_options: CastOptions { safe: false } }, CastExpr { expr: Literal { value: Int64(5) }, cast_type: Utf8, cast_options: CastOptions { safe: false } }, CastExpr { expr: Literal { value: Int64(6) }, cast_type: Utf8, cast_options: CastOptions { safe: false } }, CastExpr { expr: Literal { value: Int64(7) }, cast_type: Utf8, cast_options: CastOptions { safe: false } }, CastExpr { expr: Literal { value: Int64(8) }, cast_type: Utf8, cast_options: CastOptions { safe: false } }, CastExpr { expr: Literal { value: Int64(9) }, cast_type: Utf8, cast_options: CastOptions { safe: false } }, CastExpr { expr: Literal { value: Int64(10) }, cast_type: Utf8, cast_options: CastOptions { safe: false } }, CastExpr { expr: Literal { value: Int64(11) }, cast_type: Utf8, cast_options: CastOptions { safe: false } }, CastExpr { expr: Literal { value: Int64(12) }, cast_type: Utf8, cast_options: CastOptions { safe: false } }, CastExpr { expr: Literal { value: Int64(13) }, cast_type: Utf8, cast_options: CastOptions { safe: false } }, CastExpr { expr: Literal { value: Int64(14) }, cast_type: Utf8, cast_options: CastOptions { safe: false } }, CastExpr { expr: Literal { value: Int64(15) }, cast_type: Utf8, cast_options: CastOptions { safe: false } }, CastExpr { expr: Literal { value: Int64(16) }, cast_type: Utf8, cast_options: CastOptions { safe: false } }, CastExpr { expr: Literal { value: Int64(17) }, cast_type: Utf8, cast_options: CastOptions { safe: false } }, CastExpr { expr: Literal { value: Int64(18) }, cast_type: Utf8, cast_options: CastOptions { safe: false } }, CastExpr { expr: Literal { value: Int64(19) }, cast_type: Utf8, cast_options: CastOptions { safe: false } }, CastExpr { expr: Literal { value: Int64(20) }, cast_type: Utf8, cast_options: CastOptions { safe: false } }, CastExpr { expr: Literal { value: Int64(21) }, cast_type: Utf8, cast_options: CastOptions { safe: false } }, CastExpr { expr: Literal { value: Int64(22) }, cast_type: Utf8, cast_options: CastOptions { safe: false } }, CastExpr { expr: Literal { value: Int64(23) }, cast_type: Utf8, cast_options: CastOptions { safe: false } }, CastExpr { expr: Literal { value: Int64(24) }, cast_type: Utf8, cast_options: CastOptions { safe: false } }, CastExpr { expr: Literal { value: Int64(25) }, cast_type: Utf8, cast_options: CastOptions { safe: false } }, CastExpr { expr: Literal { value: Int64(26) }, cast_type: Utf8, cast_options: CastOptions { safe: false } }, CastExpr { expr: Literal { value: Int64(27) }, cast_type: Utf8, cast_options: CastOptions { safe: false } }, CastExpr { expr: Literal { value: Int64(28) }, cast_type: Utf8, cast_options: CastOptions { safe: false } }, CastExpr { expr: Literal { value: Int64(29) }, cast_type: Utf8, cast_options: CastOptions { safe: false } }, CastExpr { expr: Literal { value: Int64(30) }, cast_type: Utf8, cast_options: CastOptions { safe: false } }], negated: false, set: Some(InSet { set: ";
assert!(format!("{:?}", execution_plan).contains(expected));
Ok(())
}

#[tokio::test]
async fn hash_agg_input_schema() -> Result<()> {
let testdata = crate::test_util::arrow_test_data();
Expand Down
17 changes: 17 additions & 0 deletions datafusion/core/tests/sql/predicates.rs
Original file line number Diff line number Diff line change
Expand Up @@ -369,3 +369,20 @@ async fn test_expect_distinct() -> Result<()> {
assert_batches_eq!(expected, &actual);
Ok(())
}

#[tokio::test]
async fn csv_in_set_test() -> Result<()> {
let ctx = SessionContext::new();
register_aggregate_csv(&ctx).await?;
let sql = "SELECT count(*) FROM aggregate_test_100 WHERE c7 in ('25','155','204','77','208','67','139','191','26','7','202','113','129','197','249','146','129','220','154','163','220','19','71','243','150','231','196','170','99','255');";
let actual = execute_to_batches(&ctx, sql).await;
let expected = vec![
"+-----------------+",
"| COUNT(UInt8(1)) |",
"+-----------------+",
"| 36 |",
"+-----------------+",
];
assert_batches_sorted_eq!(expected, &actual);
Ok(())
}
Loading