Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions rust/datafusion/src/execution/context.rs
Original file line number Diff line number Diff line change
Expand Up @@ -514,7 +514,7 @@ mod tests {

use super::*;
use crate::datasource::MemTable;
use crate::logical_plan::{aggregate_expr, col, create_udf};
use crate::logical_plan::{col, create_udf, sum};
use crate::physical_plan::functions::ScalarFunctionImplementation;
use crate::test;
use crate::variable::VarType;
Expand Down Expand Up @@ -941,7 +941,7 @@ mod tests {
]));

let plan = LogicalPlanBuilder::scan("default", "test", schema.as_ref(), None)?
.aggregate(vec![col("c1")], vec![aggregate_expr("SUM", col("c2"))])?
.aggregate(vec![col("c1")], vec![sum(col("c2"))])?
.project(vec![col("c1"), col("SUM(c2)").alias("total_salary")])?
.build()?;

Expand Down
94 changes: 33 additions & 61 deletions rust/datafusion/src/logical_plan/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ use crate::datasource::TableProvider;
use crate::error::{ExecutionError, Result};
use crate::{
physical_plan::{
expressions::binary_operator_data_type, functions,
aggregates, expressions::binary_operator_data_type, functions,
type_coercion::can_coerce_from, udf::ScalarUDF,
},
sql::parser::FileType,
Expand Down Expand Up @@ -210,12 +210,12 @@ fn create_name(e: &Expr, input_schema: &Schema) -> Result<String> {
}
Ok(format!("{}({})", fun.name, names.join(",")))
}
Expr::AggregateFunction { name, args, .. } => {
Expr::AggregateFunction { fun, args, .. } => {
let mut names = Vec::with_capacity(args.len());
for e in args {
names.push(create_name(e, input_schema)?);
}
Ok(format!("{}({})", name, names.join(",")))
Ok(format!("{}({})", fun, names.join(",")))
}
other => Err(ExecutionError::NotImplemented(format!(
"Physical plan does not support logical expression {:?}",
Expand Down Expand Up @@ -290,7 +290,7 @@ pub enum Expr {
/// aggregate function
AggregateFunction {
/// Name of the function
name: String,
fun: aggregates::AggregateFunction,
/// List of expressions to feed to the functions as arguments
args: Vec<Expr>,
},
Expand Down Expand Up @@ -321,47 +321,12 @@ impl Expr {
.collect::<Result<Vec<_>>>()?;
functions::return_type(fun, &data_types)
}
Expr::AggregateFunction { name, args, .. } => {
match name.to_uppercase().as_str() {
"MIN" | "MAX" => args[0].get_type(schema),
"SUM" => match args[0].get_type(schema)? {
DataType::Int8
| DataType::Int16
| DataType::Int32
| DataType::Int64 => Ok(DataType::Int64),
DataType::UInt8
| DataType::UInt16
| DataType::UInt32
| DataType::UInt64 => Ok(DataType::UInt64),
DataType::Float32 => Ok(DataType::Float32),
DataType::Float64 => Ok(DataType::Float64),
other => Err(ExecutionError::General(format!(
"SUM does not support {:?}",
other
))),
},
"AVG" => match args[0].get_type(schema)? {
DataType::Int8
| DataType::Int16
| DataType::Int32
| DataType::Int64
| DataType::UInt8
| DataType::UInt16
| DataType::UInt32
| DataType::UInt64
| DataType::Float32
| DataType::Float64 => Ok(DataType::Float64),
other => Err(ExecutionError::General(format!(
"AVG does not support {:?}",
other
))),
},
"COUNT" => Ok(DataType::UInt64),
other => Err(ExecutionError::General(format!(
"Invalid aggregate function '{:?}'",
other
))),
}
Expr::AggregateFunction { fun, args, .. } => {
let data_types = args
.iter()
.map(|e| e.get_type(schema))
.collect::<Result<Vec<_>>>()?;
aggregates::return_type(fun, &data_types)
}
Expr::Not(_) => Ok(DataType::Boolean),
Expr::IsNull(_) => Ok(DataType::Boolean),
Expand Down Expand Up @@ -573,27 +538,42 @@ pub fn col(name: &str) -> Expr {

/// Create an expression to represent the min() aggregate function
pub fn min(expr: Expr) -> Expr {
aggregate_expr("MIN", expr)
Expr::AggregateFunction {
fun: aggregates::AggregateFunction::Min,
args: vec![expr],
}
}

/// Create an expression to represent the max() aggregate function
pub fn max(expr: Expr) -> Expr {
aggregate_expr("MAX", expr)
Expr::AggregateFunction {
fun: aggregates::AggregateFunction::Max,
args: vec![expr],
}
}

/// Create an expression to represent the sum() aggregate function
pub fn sum(expr: Expr) -> Expr {
aggregate_expr("SUM", expr)
Expr::AggregateFunction {
fun: aggregates::AggregateFunction::Sum,
args: vec![expr],
}
}

/// Create an expression to represent the avg() aggregate function
pub fn avg(expr: Expr) -> Expr {
aggregate_expr("AVG", expr)
Expr::AggregateFunction {
fun: aggregates::AggregateFunction::Avg,
args: vec![expr],
}
}

/// Create an expression to represent the count() aggregate function
pub fn count(expr: Expr) -> Expr {
aggregate_expr("COUNT", expr)
Expr::AggregateFunction {
fun: aggregates::AggregateFunction::Count,
args: vec![expr],
}
}

/// Whether it can be represented as a literal expression
Expand Down Expand Up @@ -690,14 +670,6 @@ pub fn concat(args: Vec<Expr>) -> Expr {
}
}

/// Create an aggregate expression
pub fn aggregate_expr(name: &str, expr: Expr) -> Expr {
Expr::AggregateFunction {
name: name.to_owned(),
args: vec![expr],
}
}

/// Creates a new UDF with a specific signature and specific return type.
/// This is a helper function to create a new UDF.
/// The function `create_udf` returns a subset of all possible `ScalarFunction`:
Expand Down Expand Up @@ -767,8 +739,8 @@ impl fmt::Debug for Expr {

write!(f, ")")
}
Expr::AggregateFunction { name, ref args, .. } => {
write!(f, "{}(", name)?;
Expr::AggregateFunction { fun, ref args, .. } => {
write!(f, "{}(", fun)?;
for i in 0..args.len() {
if i > 0 {
write!(f, ", ")?;
Expand Down Expand Up @@ -1429,7 +1401,7 @@ mod tests {
)?
.aggregate(
vec![col("state")],
vec![aggregate_expr("SUM", col("salary")).alias("total_salary")],
vec![sum(col("salary")).alias("total_salary")],
)?
.project(vec![col("state"), col("total_salary")])?
.build()?;
Expand Down
14 changes: 4 additions & 10 deletions rust/datafusion/src/optimizer/filter_push_down.rs
Original file line number Diff line number Diff line change
Expand Up @@ -303,7 +303,7 @@ fn rewrite(expr: &Expr, projection: &HashMap<String, Expr>) -> Result<Expr> {
mod tests {
use super::*;
use crate::logical_plan::col;
use crate::logical_plan::{aggregate_expr, lit, Expr, LogicalPlanBuilder, Operator};
use crate::logical_plan::{lit, sum, Expr, LogicalPlanBuilder, Operator};
use crate::test::*;

fn assert_optimized_plan_eq(plan: &LogicalPlan, expected: &str) {
Expand Down Expand Up @@ -369,10 +369,7 @@ mod tests {
fn filter_move_agg() -> Result<()> {
let table_scan = test_table_scan()?;
let plan = LogicalPlanBuilder::from(&table_scan)
.aggregate(
vec![col("a")],
vec![aggregate_expr("SUM", col("b")).alias("total_salary")],
)?
.aggregate(vec![col("a")], vec![sum(col("b")).alias("total_salary")])?
.filter(col("a").gt(lit(10i64)))?
.build()?;
// filter of key aggregation is commutative
Expand All @@ -388,10 +385,7 @@ mod tests {
fn filter_keep_agg() -> Result<()> {
let table_scan = test_table_scan()?;
let plan = LogicalPlanBuilder::from(&table_scan)
.aggregate(
vec![col("a")],
vec![aggregate_expr("SUM", col("b")).alias("b")],
)?
.aggregate(vec![col("a")], vec![sum(col("b")).alias("b")])?
.filter(col("b").gt(lit(10i64)))?
.build()?;
// filter of aggregate is after aggregation since they are non-commutative
Expand Down Expand Up @@ -508,7 +502,7 @@ mod tests {
let table_scan = test_table_scan()?;
let plan = LogicalPlanBuilder::from(&table_scan)
.project(vec![col("a").alias("b"), col("c")])?
.aggregate(vec![col("b")], vec![aggregate_expr("SUM", col("c"))])?
.aggregate(vec![col("b")], vec![sum(col("c"))])?
.filter(col("b").gt(lit(10i64)))?
.filter(col("SUM(c)").gt(lit(10i64)))?
.build()?;
Expand Down
2 changes: 1 addition & 1 deletion rust/datafusion/src/optimizer/projection_push_down.rs
Original file line number Diff line number Diff line change
Expand Up @@ -338,7 +338,7 @@ mod tests {

use super::*;
use crate::logical_plan::{col, lit};
use crate::logical_plan::{Expr, LogicalPlanBuilder};
use crate::logical_plan::{max, min, Expr, LogicalPlanBuilder};
use crate::test::*;
use arrow::datatypes::DataType;

Expand Down
4 changes: 2 additions & 2 deletions rust/datafusion/src/optimizer/utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -239,8 +239,8 @@ pub fn rewrite_expression(expr: &Expr, expressions: &Vec<Expr>) -> Result<Expr>
fun: fun.clone(),
args: expressions.clone(),
}),
Expr::AggregateFunction { name, .. } => Ok(Expr::AggregateFunction {
name: name.clone(),
Expr::AggregateFunction { fun, .. } => Ok(Expr::AggregateFunction {
fun: fun.clone(),
args: expressions.clone(),
}),
Expr::Cast { data_type, .. } => Ok(Expr::Cast {
Expand Down
Loading