diff --git a/rust/datafusion/src/execution/context.rs b/rust/datafusion/src/execution/context.rs index d6b6bd3498be..1cb725e0a6c8 100644 --- a/rust/datafusion/src/execution/context.rs +++ b/rust/datafusion/src/execution/context.rs @@ -514,7 +514,7 @@ mod tests { use super::*; use crate::datasource::MemTable; - use crate::logical_plan::{aggregate_expr, col, create_udf}; + use crate::logical_plan::{col, create_udf, sum}; use crate::physical_plan::functions::ScalarFunctionImplementation; use crate::test; use crate::variable::VarType; @@ -941,7 +941,7 @@ mod tests { ])); let plan = LogicalPlanBuilder::scan("default", "test", schema.as_ref(), None)? - .aggregate(vec![col("c1")], vec![aggregate_expr("SUM", col("c2"))])? + .aggregate(vec![col("c1")], vec![sum(col("c2"))])? .project(vec![col("c1"), col("SUM(c2)").alias("total_salary")])? .build()?; diff --git a/rust/datafusion/src/logical_plan/mod.rs b/rust/datafusion/src/logical_plan/mod.rs index d4256c3db03c..29f748a788da 100644 --- a/rust/datafusion/src/logical_plan/mod.rs +++ b/rust/datafusion/src/logical_plan/mod.rs @@ -32,7 +32,7 @@ use crate::datasource::TableProvider; use crate::error::{ExecutionError, Result}; use crate::{ physical_plan::{ - expressions::binary_operator_data_type, functions, + aggregates, expressions::binary_operator_data_type, functions, type_coercion::can_coerce_from, udf::ScalarUDF, }, sql::parser::FileType, @@ -210,12 +210,12 @@ fn create_name(e: &Expr, input_schema: &Schema) -> Result { } Ok(format!("{}({})", fun.name, names.join(","))) } - Expr::AggregateFunction { name, args, .. } => { + Expr::AggregateFunction { fun, args, .. } => { let mut names = Vec::with_capacity(args.len()); for e in args { names.push(create_name(e, input_schema)?); } - Ok(format!("{}({})", name, names.join(","))) + Ok(format!("{}({})", fun, names.join(","))) } other => Err(ExecutionError::NotImplemented(format!( "Physical plan does not support logical expression {:?}", @@ -290,7 +290,7 @@ pub enum Expr { /// aggregate function AggregateFunction { /// Name of the function - name: String, + fun: aggregates::AggregateFunction, /// List of expressions to feed to the functions as arguments args: Vec, }, @@ -321,47 +321,12 @@ impl Expr { .collect::>>()?; functions::return_type(fun, &data_types) } - Expr::AggregateFunction { name, args, .. } => { - match name.to_uppercase().as_str() { - "MIN" | "MAX" => args[0].get_type(schema), - "SUM" => match args[0].get_type(schema)? { - DataType::Int8 - | DataType::Int16 - | DataType::Int32 - | DataType::Int64 => Ok(DataType::Int64), - DataType::UInt8 - | DataType::UInt16 - | DataType::UInt32 - | DataType::UInt64 => Ok(DataType::UInt64), - DataType::Float32 => Ok(DataType::Float32), - DataType::Float64 => Ok(DataType::Float64), - other => Err(ExecutionError::General(format!( - "SUM does not support {:?}", - other - ))), - }, - "AVG" => match args[0].get_type(schema)? { - DataType::Int8 - | DataType::Int16 - | DataType::Int32 - | DataType::Int64 - | DataType::UInt8 - | DataType::UInt16 - | DataType::UInt32 - | DataType::UInt64 - | DataType::Float32 - | DataType::Float64 => Ok(DataType::Float64), - other => Err(ExecutionError::General(format!( - "AVG does not support {:?}", - other - ))), - }, - "COUNT" => Ok(DataType::UInt64), - other => Err(ExecutionError::General(format!( - "Invalid aggregate function '{:?}'", - other - ))), - } + Expr::AggregateFunction { fun, args, .. } => { + let data_types = args + .iter() + .map(|e| e.get_type(schema)) + .collect::>>()?; + aggregates::return_type(fun, &data_types) } Expr::Not(_) => Ok(DataType::Boolean), Expr::IsNull(_) => Ok(DataType::Boolean), @@ -573,27 +538,42 @@ pub fn col(name: &str) -> Expr { /// Create an expression to represent the min() aggregate function pub fn min(expr: Expr) -> Expr { - aggregate_expr("MIN", expr) + Expr::AggregateFunction { + fun: aggregates::AggregateFunction::Min, + args: vec![expr], + } } /// Create an expression to represent the max() aggregate function pub fn max(expr: Expr) -> Expr { - aggregate_expr("MAX", expr) + Expr::AggregateFunction { + fun: aggregates::AggregateFunction::Max, + args: vec![expr], + } } /// Create an expression to represent the sum() aggregate function pub fn sum(expr: Expr) -> Expr { - aggregate_expr("SUM", expr) + Expr::AggregateFunction { + fun: aggregates::AggregateFunction::Sum, + args: vec![expr], + } } /// Create an expression to represent the avg() aggregate function pub fn avg(expr: Expr) -> Expr { - aggregate_expr("AVG", expr) + Expr::AggregateFunction { + fun: aggregates::AggregateFunction::Avg, + args: vec![expr], + } } /// Create an expression to represent the count() aggregate function pub fn count(expr: Expr) -> Expr { - aggregate_expr("COUNT", expr) + Expr::AggregateFunction { + fun: aggregates::AggregateFunction::Count, + args: vec![expr], + } } /// Whether it can be represented as a literal expression @@ -690,14 +670,6 @@ pub fn concat(args: Vec) -> Expr { } } -/// Create an aggregate expression -pub fn aggregate_expr(name: &str, expr: Expr) -> Expr { - Expr::AggregateFunction { - name: name.to_owned(), - args: vec![expr], - } -} - /// Creates a new UDF with a specific signature and specific return type. /// This is a helper function to create a new UDF. /// The function `create_udf` returns a subset of all possible `ScalarFunction`: @@ -767,8 +739,8 @@ impl fmt::Debug for Expr { write!(f, ")") } - Expr::AggregateFunction { name, ref args, .. } => { - write!(f, "{}(", name)?; + Expr::AggregateFunction { fun, ref args, .. } => { + write!(f, "{}(", fun)?; for i in 0..args.len() { if i > 0 { write!(f, ", ")?; @@ -1429,7 +1401,7 @@ mod tests { )? .aggregate( vec![col("state")], - vec![aggregate_expr("SUM", col("salary")).alias("total_salary")], + vec![sum(col("salary")).alias("total_salary")], )? .project(vec![col("state"), col("total_salary")])? .build()?; diff --git a/rust/datafusion/src/optimizer/filter_push_down.rs b/rust/datafusion/src/optimizer/filter_push_down.rs index 64ee18f9d13c..28a99eedfe14 100644 --- a/rust/datafusion/src/optimizer/filter_push_down.rs +++ b/rust/datafusion/src/optimizer/filter_push_down.rs @@ -303,7 +303,7 @@ fn rewrite(expr: &Expr, projection: &HashMap) -> Result { mod tests { use super::*; use crate::logical_plan::col; - use crate::logical_plan::{aggregate_expr, lit, Expr, LogicalPlanBuilder, Operator}; + use crate::logical_plan::{lit, sum, Expr, LogicalPlanBuilder, Operator}; use crate::test::*; fn assert_optimized_plan_eq(plan: &LogicalPlan, expected: &str) { @@ -369,10 +369,7 @@ mod tests { fn filter_move_agg() -> Result<()> { let table_scan = test_table_scan()?; let plan = LogicalPlanBuilder::from(&table_scan) - .aggregate( - vec![col("a")], - vec![aggregate_expr("SUM", col("b")).alias("total_salary")], - )? + .aggregate(vec![col("a")], vec![sum(col("b")).alias("total_salary")])? .filter(col("a").gt(lit(10i64)))? .build()?; // filter of key aggregation is commutative @@ -388,10 +385,7 @@ mod tests { fn filter_keep_agg() -> Result<()> { let table_scan = test_table_scan()?; let plan = LogicalPlanBuilder::from(&table_scan) - .aggregate( - vec![col("a")], - vec![aggregate_expr("SUM", col("b")).alias("b")], - )? + .aggregate(vec![col("a")], vec![sum(col("b")).alias("b")])? .filter(col("b").gt(lit(10i64)))? .build()?; // filter of aggregate is after aggregation since they are non-commutative @@ -508,7 +502,7 @@ mod tests { let table_scan = test_table_scan()?; let plan = LogicalPlanBuilder::from(&table_scan) .project(vec![col("a").alias("b"), col("c")])? - .aggregate(vec![col("b")], vec![aggregate_expr("SUM", col("c"))])? + .aggregate(vec![col("b")], vec![sum(col("c"))])? .filter(col("b").gt(lit(10i64)))? .filter(col("SUM(c)").gt(lit(10i64)))? .build()?; diff --git a/rust/datafusion/src/optimizer/projection_push_down.rs b/rust/datafusion/src/optimizer/projection_push_down.rs index 8445e8fce6b3..0bd46ee235c9 100644 --- a/rust/datafusion/src/optimizer/projection_push_down.rs +++ b/rust/datafusion/src/optimizer/projection_push_down.rs @@ -338,7 +338,7 @@ mod tests { use super::*; use crate::logical_plan::{col, lit}; - use crate::logical_plan::{Expr, LogicalPlanBuilder}; + use crate::logical_plan::{max, min, Expr, LogicalPlanBuilder}; use crate::test::*; use arrow::datatypes::DataType; diff --git a/rust/datafusion/src/optimizer/utils.rs b/rust/datafusion/src/optimizer/utils.rs index e0f641f98ad7..788d1e4c2fb0 100644 --- a/rust/datafusion/src/optimizer/utils.rs +++ b/rust/datafusion/src/optimizer/utils.rs @@ -239,8 +239,8 @@ pub fn rewrite_expression(expr: &Expr, expressions: &Vec) -> Result fun: fun.clone(), args: expressions.clone(), }), - Expr::AggregateFunction { name, .. } => Ok(Expr::AggregateFunction { - name: name.clone(), + Expr::AggregateFunction { fun, .. } => Ok(Expr::AggregateFunction { + fun: fun.clone(), args: expressions.clone(), }), Expr::Cast { data_type, .. } => Ok(Expr::Cast { diff --git a/rust/datafusion/src/physical_plan/aggregates.rs b/rust/datafusion/src/physical_plan/aggregates.rs new file mode 100644 index 000000000000..772e9b6967ac --- /dev/null +++ b/rust/datafusion/src/physical_plan/aggregates.rs @@ -0,0 +1,203 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Declaration of built-in (aggregate) functions. +//! This module contains built-in aggregates' enumeration and metadata. +//! +//! Generally, an aggregate has: +//! * a signature +//! * a return type, that is a function of the incoming argument's types +//! * the computation, that must accept each valid signature +//! +//! * Signature: see `Signature` +//! * Return type: a function `(arg_types) -> return_type`. E.g. for min, ([f32]) -> f32, ([f64]) -> f64. + +use super::{ + functions::Signature, + type_coercion::{coerce, data_types}, + AggregateExpr, PhysicalExpr, +}; +use crate::error::{ExecutionError, Result}; +use crate::physical_plan::expressions; +use arrow::datatypes::{DataType, Schema}; +use expressions::{avg_return_type, sum_return_type}; +use std::{fmt, str::FromStr, sync::Arc}; + +/// Enum of all built-in scalar functions +#[derive(Debug, Clone, PartialEq, Eq)] +pub enum AggregateFunction { + /// count + Count, + /// sum + Sum, + /// min + Min, + /// max + Max, + /// avg + Avg, +} + +impl fmt::Display for AggregateFunction { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + // uppercase of the debug. + write!(f, "{}", format!("{:?}", self).to_uppercase()) + } +} + +impl FromStr for AggregateFunction { + type Err = ExecutionError; + fn from_str(name: &str) -> Result { + Ok(match &*name.to_uppercase() { + "MIN" => AggregateFunction::Min, + "MAX" => AggregateFunction::Max, + "COUNT" => AggregateFunction::Count, + "AVG" => AggregateFunction::Avg, + "SUM" => AggregateFunction::Sum, + _ => { + return Err(ExecutionError::General(format!( + "There is no built-in function named {}", + name + ))) + } + }) + } +} + +/// Returns the datatype of the scalar function +pub fn return_type( + fun: &AggregateFunction, + arg_types: &Vec, +) -> Result { + // Note that this function *must* return the same type that the respective physical expression returns + // or the execution panics. + + // verify that this is a valid set of data types for this function + data_types(arg_types, &signature(fun))?; + + match fun { + AggregateFunction::Count => Ok(DataType::UInt64), + AggregateFunction::Max | AggregateFunction::Min => Ok(arg_types[0].clone()), + AggregateFunction::Sum => sum_return_type(&arg_types[0]), + AggregateFunction::Avg => avg_return_type(&arg_types[0]), + } +} + +/// Create a physical (function) expression. +/// This function errors when `args`' can't be coerced to a valid argument type of the function. +pub fn create_aggregate_expr( + fun: &AggregateFunction, + args: &Vec>, + input_schema: &Schema, +) -> Result> { + // coerce + let arg = coerce(args, input_schema, &signature(fun))?[0].clone(); + + Ok(match fun { + AggregateFunction::Count => expressions::count(arg), + AggregateFunction::Sum => expressions::sum(arg), + AggregateFunction::Min => expressions::min(arg), + AggregateFunction::Max => expressions::max(arg), + AggregateFunction::Avg => expressions::avg(arg), + }) +} + +static NUMERICS: &'static [DataType] = &[ + DataType::Int8, + DataType::Int16, + DataType::Int32, + DataType::Int64, + DataType::UInt8, + DataType::UInt16, + DataType::UInt32, + DataType::UInt64, + DataType::Float32, + DataType::Float64, +]; + +/// the signatures supported by the function `fun`. +fn signature(fun: &AggregateFunction) -> Signature { + // note: the physical expression must accept the type returned by this function or the execution panics. + match fun { + AggregateFunction::Count => Signature::Any(1), + AggregateFunction::Min | AggregateFunction::Max => { + let mut valid = vec![DataType::Utf8]; + valid.extend_from_slice(NUMERICS); + Signature::Uniform(1, valid) + } + AggregateFunction::Avg | AggregateFunction::Sum => { + Signature::Uniform(1, NUMERICS.to_vec()) + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::error::Result; + + #[test] + fn test_min_max() -> Result<()> { + let observed = return_type(&AggregateFunction::Min, &vec![DataType::Utf8])?; + assert_eq!(DataType::Utf8, observed); + + let observed = return_type(&AggregateFunction::Max, &vec![DataType::Int32])?; + assert_eq!(DataType::Int32, observed); + Ok(()) + } + + #[test] + fn test_sum_no_utf8() -> Result<()> { + let observed = return_type(&AggregateFunction::Sum, &vec![DataType::Utf8]); + assert!(observed.is_err()); + Ok(()) + } + + #[test] + fn test_sum_upcasts() -> Result<()> { + let observed = return_type(&AggregateFunction::Sum, &vec![DataType::UInt32])?; + assert_eq!(DataType::UInt64, observed); + Ok(()) + } + + #[test] + fn test_count_return_type() -> Result<()> { + let observed = return_type(&AggregateFunction::Count, &vec![DataType::Utf8])?; + assert_eq!(DataType::UInt64, observed); + + let observed = return_type(&AggregateFunction::Count, &vec![DataType::Int8])?; + assert_eq!(DataType::UInt64, observed); + Ok(()) + } + + #[test] + fn test_avg_return_type() -> Result<()> { + let observed = return_type(&AggregateFunction::Avg, &vec![DataType::Float32])?; + assert_eq!(DataType::Float64, observed); + + let observed = return_type(&AggregateFunction::Avg, &vec![DataType::Float64])?; + assert_eq!(DataType::Float64, observed); + Ok(()) + } + + #[test] + fn test_avg_no_utf8() -> Result<()> { + let observed = return_type(&AggregateFunction::Avg, &vec![DataType::Utf8]); + assert!(observed.is_err()); + Ok(()) + } +} diff --git a/rust/datafusion/src/physical_plan/expressions.rs b/rust/datafusion/src/physical_plan/expressions.rs index c53599ca7f04..80cf9ea27ca8 100644 --- a/rust/datafusion/src/physical_plan/expressions.rs +++ b/rust/datafusion/src/physical_plan/expressions.rs @@ -107,22 +107,27 @@ impl Sum { } } +/// function return type of a sum +pub fn sum_return_type(arg_type: &DataType) -> Result { + match arg_type { + DataType::Int8 | DataType::Int16 | DataType::Int32 | DataType::Int64 => { + Ok(DataType::Int64) + } + DataType::UInt8 | DataType::UInt16 | DataType::UInt32 | DataType::UInt64 => { + Ok(DataType::UInt64) + } + DataType::Float32 => Ok(DataType::Float32), + DataType::Float64 => Ok(DataType::Float64), + other => Err(ExecutionError::General(format!( + "SUM does not support type \"{:?}\"", + other + ))), + } +} + impl AggregateExpr for Sum { fn data_type(&self, input_schema: &Schema) -> Result { - match self.expr.data_type(input_schema)? { - DataType::Int8 | DataType::Int16 | DataType::Int32 | DataType::Int64 => { - Ok(DataType::Int64) - } - DataType::UInt8 | DataType::UInt16 | DataType::UInt32 | DataType::UInt64 => { - Ok(DataType::UInt64) - } - DataType::Float32 => Ok(DataType::Float32), - DataType::Float64 => Ok(DataType::Float64), - other => Err(ExecutionError::General(format!( - "SUM does not support {:?}", - other - ))), - } + sum_return_type(&self.expr.data_type(input_schema)?) } fn nullable(&self, _input_schema: &Schema) -> Result { @@ -306,24 +311,29 @@ impl Avg { } } +/// function return type of an average +pub fn avg_return_type(arg_type: &DataType) -> Result { + match arg_type { + DataType::Int8 + | DataType::Int16 + | DataType::Int32 + | DataType::Int64 + | DataType::UInt8 + | DataType::UInt16 + | DataType::UInt32 + | DataType::UInt64 + | DataType::Float32 + | DataType::Float64 => Ok(DataType::Float64), + other => Err(ExecutionError::General(format!( + "AVG does not support {:?}", + other + ))), + } +} + impl AggregateExpr for Avg { fn data_type(&self, input_schema: &Schema) -> Result { - match self.expr.data_type(input_schema)? { - DataType::Int8 - | DataType::Int16 - | DataType::Int32 - | DataType::Int64 - | DataType::UInt8 - | DataType::UInt16 - | DataType::UInt32 - | DataType::UInt64 - | DataType::Float32 - | DataType::Float64 => Ok(DataType::Float64), - other => Err(ExecutionError::General(format!( - "AVG does not support {:?}", - other - ))), - } + avg_return_type(&self.expr.data_type(input_schema)?) } fn nullable(&self, _input_schema: &Schema) -> Result { diff --git a/rust/datafusion/src/physical_plan/functions.rs b/rust/datafusion/src/physical_plan/functions.rs index 06ce949b91c7..ea6147cacef9 100644 --- a/rust/datafusion/src/physical_plan/functions.rs +++ b/rust/datafusion/src/physical_plan/functions.rs @@ -63,6 +63,8 @@ pub enum Signature { Uniform(usize, Vec), /// exact number of arguments of an exact type Exact(Vec), + /// fixed number of arguments of arbitrary types + Any(usize), } /// Scalar function diff --git a/rust/datafusion/src/physical_plan/mod.rs b/rust/datafusion/src/physical_plan/mod.rs index e2265cd1d5bb..99ce8d6d4223 100644 --- a/rust/datafusion/src/physical_plan/mod.rs +++ b/rust/datafusion/src/physical_plan/mod.rs @@ -130,6 +130,7 @@ pub trait Accumulator: Debug { fn get_value(&self) -> Result>; } +pub mod aggregates; pub mod common; pub mod csv; pub mod datetime_expressions; diff --git a/rust/datafusion/src/physical_plan/planner.rs b/rust/datafusion/src/physical_plan/planner.rs index c2c0e4b807c4..8230ba883d02 100644 --- a/rust/datafusion/src/physical_plan/planner.rs +++ b/rust/datafusion/src/physical_plan/planner.rs @@ -19,7 +19,7 @@ use std::sync::Arc; -use super::{empty::EmptyExec, expressions::binary, functions}; +use super::{aggregates, empty::EmptyExec, expressions::binary, functions}; use crate::error::{ExecutionError, Result}; use crate::execution::context::ExecutionContextState; use crate::logical_plan::{ @@ -27,9 +27,7 @@ use crate::logical_plan::{ }; use crate::physical_plan::csv::{CsvExec, CsvReadOptions}; use crate::physical_plan::explain::ExplainExec; -use crate::physical_plan::expressions::{ - Avg, Column, Count, Literal, Max, Min, PhysicalSortExpr, Sum, -}; +use crate::physical_plan::expressions::{Column, Literal, PhysicalSortExpr}; use crate::physical_plan::filter::FilterExec; use crate::physical_plan::hash_aggregate::{AggregateMode, HashAggregateExec}; use crate::physical_plan::limit::{GlobalLimitExec, LocalLimitExec}; @@ -477,38 +475,12 @@ impl DefaultPhysicalPlanner { ctx_state: &ExecutionContextState, ) -> Result> { match e { - Expr::AggregateFunction { name, args, .. } => { - match name.to_lowercase().as_ref() { - "sum" => Ok(Arc::new(Sum::new(self.create_physical_expr( - &args[0], - input_schema, - ctx_state, - )?))), - "avg" => Ok(Arc::new(Avg::new(self.create_physical_expr( - &args[0], - input_schema, - ctx_state, - )?))), - "max" => Ok(Arc::new(Max::new(self.create_physical_expr( - &args[0], - input_schema, - ctx_state, - )?))), - "min" => Ok(Arc::new(Min::new(self.create_physical_expr( - &args[0], - input_schema, - ctx_state, - )?))), - "count" => Ok(Arc::new(Count::new(self.create_physical_expr( - &args[0], - input_schema, - ctx_state, - )?))), - other => Err(ExecutionError::NotImplemented(format!( - "Unsupported aggregate function '{}'", - other - ))), - } + Expr::AggregateFunction { fun, args, .. } => { + let args = args + .iter() + .map(|e| self.create_physical_expr(e, input_schema, ctx_state)) + .collect::>>()?; + aggregates::create_aggregate_expr(fun, &args, input_schema) } other => Err(ExecutionError::General(format!( "Invalid aggregate expression '{:?}'", @@ -561,7 +533,7 @@ impl ExtensionPlanner for DefaultExtensionPlanner { #[cfg(test)] mod tests { use super::*; - use crate::logical_plan::{aggregate_expr, col, lit, LogicalPlanBuilder}; + use crate::logical_plan::{col, lit, sum, LogicalPlanBuilder}; use crate::physical_plan::{csv::CsvReadOptions, Partitioning}; use crate::{prelude::ExecutionConfig, test::arrow_testdata_path}; use arrow::{ @@ -596,7 +568,7 @@ mod tests { // filter clause needs the type coercion rule applied .filter(col("c7").lt(lit(5_u8)))? .project(vec![col("c1"), col("c2")])? - .aggregate(vec![col("c1")], vec![aggregate_expr("SUM", col("c2"))])? + .aggregate(vec![col("c1")], vec![sum(col("c2"))])? .sort(vec![col("c1").sort(true, true)])? .limit(10)? .build()?; diff --git a/rust/datafusion/src/physical_plan/type_coercion.rs b/rust/datafusion/src/physical_plan/type_coercion.rs index 4ad6085bd141..cfdc8167a57f 100644 --- a/rust/datafusion/src/physical_plan/type_coercion.rs +++ b/rust/datafusion/src/physical_plan/type_coercion.rs @@ -67,6 +67,16 @@ pub fn data_types( .collect()] } Signature::Exact(valid_types) => vec![valid_types.clone()], + Signature::Any(number) => { + if current_types.len() != *number { + return Err(ExecutionError::General(format!( + "The function expected {} arguments but received {}", + number, + current_types.len() + ))); + } + vec![(0..*number).map(|i| current_types[i].clone()).collect()] + } }; if valid_types.contains(current_types) { @@ -276,6 +286,12 @@ mod tests { Signature::Variadic(vec![DataType::UInt32, DataType::UInt64]), vec![DataType::UInt64, DataType::UInt64], )?, + // f32 -> f32 + case( + vec![DataType::Float32], + Signature::Any(1), + vec![DataType::Float32], + )?, ]; for case in cases { @@ -304,6 +320,8 @@ mod tests { Signature::Variadic(vec![DataType::UInt32]), vec![], )?, + // expected two arguments + case(vec![DataType::UInt32], Signature::Any(2), vec![])?, ]; for case in cases { diff --git a/rust/datafusion/src/sql/planner.rs b/rust/datafusion/src/sql/planner.rs index e8c77dbad84a..d08e57a3859f 100644 --- a/rust/datafusion/src/sql/planner.rs +++ b/rust/datafusion/src/sql/planner.rs @@ -27,8 +27,8 @@ use crate::logical_plan::{ StringifiedPlan, }; use crate::{ - physical_plan::functions, physical_plan::udf::ScalarUDF, + physical_plan::{aggregates, functions}, sql::parser::{CreateExternalTable, FileType, Statement as DFStatement}, }; @@ -519,22 +519,10 @@ impl<'a, S: SchemaProvider> SqlToRel<'a, S> { return Ok(Expr::ScalarFunction { fun, args }); }; - //TODO: fix this hack - match name.to_lowercase().as_ref() { - "min" | "max" | "sum" | "avg" => { - let rex_args = function - .args - .iter() - .map(|a| self.sql_to_rex(a, schema)) - .collect::>>()?; - - Ok(Expr::AggregateFunction { - name: name.clone(), - args: rex_args, - }) - } - "count" => { - let rex_args = function + // next, aggregate built-ins + if let Ok(fun) = aggregates::AggregateFunction::from_str(&name) { + let args = if fun == aggregates::AggregateFunction::Count { + function .args .iter() .map(|a| match a { @@ -542,32 +530,36 @@ impl<'a, S: SchemaProvider> SqlToRel<'a, S> { SQLExpr::Wildcard => Ok(lit(1_u8)), _ => self.sql_to_rex(a, schema), }) + .collect::>>()? + } else { + function + .args + .iter() + .map(|a| self.sql_to_rex(a, schema)) + .collect::>>()? + }; + + return Ok(Expr::AggregateFunction { fun, args }); + }; + + // finally, user-defined functions + match self.schema_provider.get_function_meta(&name) { + Some(fm) => { + let args = function + .args + .iter() + .map(|a| self.sql_to_rex(a, schema)) .collect::>>()?; - Ok(Expr::AggregateFunction { - name: name.clone(), - args: rex_args, + Ok(Expr::ScalarUDF { + fun: fm.clone(), + args, }) } - // finally, user-defined functions - _ => match self.schema_provider.get_function_meta(&name) { - Some(fm) => { - let args = function - .args - .iter() - .map(|a| self.sql_to_rex(a, schema)) - .collect::>>()?; - - Ok(Expr::ScalarUDF { - fun: fm.clone(), - args, - }) - } - _ => Err(ExecutionError::General(format!( - "Invalid function '{}'", - name - ))), - }, + _ => Err(ExecutionError::General(format!( + "Invalid function '{}'", + name + ))), } } diff --git a/rust/datafusion/src/test/mod.rs b/rust/datafusion/src/test/mod.rs index c700243c1554..cbd7425fe370 100644 --- a/rust/datafusion/src/test/mod.rs +++ b/rust/datafusion/src/test/mod.rs @@ -20,7 +20,7 @@ use crate::datasource::{MemTable, TableProvider}; use crate::error::Result; use crate::execution::context::ExecutionContext; -use crate::logical_plan::{Expr, LogicalPlan, LogicalPlanBuilder}; +use crate::logical_plan::{LogicalPlan, LogicalPlanBuilder}; use crate::physical_plan::ExecutionPlan; use arrow::array; use arrow::datatypes::{DataType, Field, Schema, SchemaRef}; @@ -242,18 +242,4 @@ pub fn assert_fields_eq(plan: &LogicalPlan, expected: Vec<&str>) { assert_eq!(actual, expected); } -pub fn max(expr: Expr) -> Expr { - Expr::AggregateFunction { - name: "MAX".to_owned(), - args: vec![expr], - } -} - -pub fn min(expr: Expr) -> Expr { - Expr::AggregateFunction { - name: "MIN".to_owned(), - args: vec![expr], - } -} - pub mod variable;