Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 4 additions & 3 deletions datafusion-examples/examples/simple_udaf.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ use datafusion::arrow::{
};

use datafusion::from_slice::FromSlice;
use datafusion::logical_expr::AggregateState;
use datafusion::{error::Result, logical_plan::create_udaf, physical_plan::Accumulator};
use datafusion::{logical_expr::Volatility, prelude::*, scalar::ScalarValue};
use std::sync::Arc;
Expand Down Expand Up @@ -107,10 +108,10 @@ impl Accumulator for GeometricMean {
// This function serializes our state to `ScalarValue`, which DataFusion uses
// to pass this state between execution stages.
// Note that this can be arbitrary data.
fn state(&self) -> Result<Vec<ScalarValue>> {
fn state(&self) -> Result<Vec<AggregateState>> {
Ok(vec![
ScalarValue::from(self.prod),
ScalarValue::from(self.n),
AggregateState::Scalar(ScalarValue::from(self.prod)),
AggregateState::Scalar(ScalarValue::from(self.n)),
])
}

Expand Down
1 change: 1 addition & 0 deletions datafusion/common/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -45,4 +45,5 @@ object_store = { version = "0.3", optional = true }
ordered-float = "3.0"
parquet = { version = "19.0.0", features = ["arrow"], optional = true }
pyo3 = { version = "0.16", optional = true }
serde_json = "1.0"
sqlparser = "0.19"
6 changes: 4 additions & 2 deletions datafusion/core/src/physical_plan/aggregates/hash.rs
Original file line number Diff line number Diff line change
Expand Up @@ -428,8 +428,10 @@ fn create_batch_from_map(
AggregateMode::Partial => {
let res = ScalarValue::iter_to_array(
accumulators.group_states.iter().map(|group_state| {
let x = group_state.accumulator_set[x].state().unwrap();
x[y].clone()
group_state.accumulator_set[x]
.state()
.and_then(|x| x[y].as_scalar().map(|v| v.clone()))
.expect("unexpected accumulator state in hash aggregate")
}),
)?;

Expand Down
189 changes: 186 additions & 3 deletions datafusion/core/tests/sql/aggregates.rs
Original file line number Diff line number Diff line change
Expand Up @@ -221,7 +221,7 @@ async fn csv_query_stddev_6() -> Result<()> {
}

#[tokio::test]
async fn csv_query_median_1() -> Result<()> {
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If possible, I would recommend adding a basic test in sql for a median for all the different data types that are supported (not just on aggregate_test_100 but a dedicated test setup with known data (maybe integers 10, 9, 8, ... 0)

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Added in ef1effd

async fn csv_query_approx_median_1() -> Result<()> {
let ctx = SessionContext::new();
register_aggregate_csv(&ctx).await?;
let sql = "SELECT approx_median(c2) FROM aggregate_test_100";
Expand All @@ -232,7 +232,7 @@ async fn csv_query_median_1() -> Result<()> {
}

#[tokio::test]
async fn csv_query_median_2() -> Result<()> {
async fn csv_query_approx_median_2() -> Result<()> {
let ctx = SessionContext::new();
register_aggregate_csv(&ctx).await?;
let sql = "SELECT approx_median(c6) FROM aggregate_test_100";
Expand All @@ -243,7 +243,7 @@ async fn csv_query_median_2() -> Result<()> {
}

#[tokio::test]
async fn csv_query_median_3() -> Result<()> {
async fn csv_query_approx_median_3() -> Result<()> {
let ctx = SessionContext::new();
register_aggregate_csv(&ctx).await?;
let sql = "SELECT approx_median(c12) FROM aggregate_test_100";
Expand All @@ -253,6 +253,189 @@ async fn csv_query_median_3() -> Result<()> {
Ok(())
}

#[tokio::test]
async fn csv_query_median_1() -> Result<()> {
let ctx = SessionContext::new();
register_aggregate_csv(&ctx).await?;
let sql = "SELECT median(c2) FROM aggregate_test_100";
let actual = execute(&ctx, sql).await;
let expected = vec![vec!["3"]];
assert_float_eq(&expected, &actual);
Ok(())
}

#[tokio::test]
async fn csv_query_median_2() -> Result<()> {
let ctx = SessionContext::new();
register_aggregate_csv(&ctx).await?;
let sql = "SELECT median(c6) FROM aggregate_test_100";
let actual = execute(&ctx, sql).await;
let expected = vec![vec!["1125553990140691277"]];
assert_float_eq(&expected, &actual);
Ok(())
}

#[tokio::test]
async fn csv_query_median_3() -> Result<()> {
let ctx = SessionContext::new();
register_aggregate_csv(&ctx).await?;
let sql = "SELECT median(c12) FROM aggregate_test_100";
let actual = execute(&ctx, sql).await;
let expected = vec![vec!["0.5513900544385053"]];
assert_float_eq(&expected, &actual);
Ok(())
}

#[tokio::test]
async fn median_i8() -> Result<()> {
median_test(
"median",
DataType::Int8,
Arc::new(Int8Array::from(vec![i8::MIN, i8::MIN, 100, i8::MAX])),
"-14",
)
.await
}

#[tokio::test]
async fn median_i16() -> Result<()> {
median_test(
"median",
DataType::Int16,
Arc::new(Int16Array::from(vec![i16::MIN, i16::MIN, 100, i16::MAX])),
"-16334",
)
.await
}

#[tokio::test]
async fn median_i32() -> Result<()> {
median_test(
"median",
DataType::Int32,
Arc::new(Int32Array::from(vec![i32::MIN, i32::MIN, 100, i32::MAX])),
"-1073741774",
)
.await
}

#[tokio::test]
async fn median_i64() -> Result<()> {
median_test(
"median",
DataType::Int64,
Arc::new(Int64Array::from(vec![i64::MIN, i64::MIN, 100, i64::MAX])),
"-4611686018427388000",
)
.await
}

#[tokio::test]
async fn median_u8() -> Result<()> {
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nice

median_test(
"median",
DataType::UInt8,
Arc::new(UInt8Array::from(vec![u8::MIN, u8::MIN, 100, u8::MAX])),
"50",
)
.await
}

#[tokio::test]
async fn median_u16() -> Result<()> {
median_test(
"median",
DataType::UInt16,
Arc::new(UInt16Array::from(vec![u16::MIN, u16::MIN, 100, u16::MAX])),
"50",
)
.await
}

#[tokio::test]
async fn median_u32() -> Result<()> {
median_test(
"median",
DataType::UInt32,
Arc::new(UInt32Array::from(vec![u32::MIN, u32::MIN, 100, u32::MAX])),
"50",
)
.await
}

#[tokio::test]
async fn median_u64() -> Result<()> {
median_test(
"median",
DataType::UInt64,
Arc::new(UInt64Array::from(vec![u64::MIN, u64::MIN, 100, u64::MAX])),
"50",
)
.await
}

#[tokio::test]
async fn median_f32() -> Result<()> {
median_test(
"median",
DataType::Float32,
Arc::new(Float32Array::from(vec![1.1, 4.4, 5.5, 3.3, 2.2])),
"3.3",
)
.await
}

#[tokio::test]
async fn median_f64() -> Result<()> {
median_test(
"median",
DataType::Float64,
Arc::new(Float64Array::from(vec![1.1, 4.4, 5.5, 3.3, 2.2])),
"3.3",
)
.await
}

#[tokio::test]
async fn median_f64_nan() -> Result<()> {
median_test(
"median",
DataType::Float64,
Arc::new(Float64Array::from(vec![1.1, f64::NAN, f64::NAN, f64::NAN])),
"NaN", // probably not the desired behavior? - see https://github.com/apache/arrow-datafusion/issues/3039
)
.await
}

#[tokio::test]
async fn approx_median_f64_nan() -> Result<()> {
median_test(
"approx_median",
DataType::Float64,
Arc::new(Float64Array::from(vec![1.1, f64::NAN, f64::NAN, f64::NAN])),
"NaN", // probably not the desired behavior? - see https://github.com/apache/arrow-datafusion/issues/3039
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

testing for the win!

)
.await
}

async fn median_test(
func: &str,
data_type: DataType,
values: ArrayRef,
expected: &str,
) -> Result<()> {
let ctx = SessionContext::new();
let schema = Arc::new(Schema::new(vec![Field::new("a", data_type, false)]));
let batch = RecordBatch::try_new(schema.clone(), vec![values])?;
let table = Arc::new(MemTable::try_new(schema, vec![vec![batch]])?);
ctx.register_table("t", table)?;
let sql = format!("SELECT {}(a) FROM t", func);
let actual = execute(&ctx, &sql).await;
let expected = vec![vec![expected.to_owned()]];
assert_float_eq(&expected, &actual);
Ok(())
}

#[tokio::test]
async fn csv_query_external_table_count() {
let ctx = SessionContext::new();
Expand Down
6 changes: 5 additions & 1 deletion datafusion/core/tests/sql/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,11 @@ where
l.as_ref().parse::<f64>().unwrap(),
r.as_str().parse::<f64>().unwrap(),
);
assert!((l - r).abs() <= 2.0 * f64::EPSILON);
if l.is_nan() || r.is_nan() {
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

👍

assert!(l.is_nan() && r.is_nan());
} else if (l - r).abs() > 2.0 * f64::EPSILON {
panic!("{} != {}", l, r)
}
});
}

Expand Down
45 changes: 40 additions & 5 deletions datafusion/expr/src/accumulator.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,22 +18,22 @@
//! Accumulator module contains the trait definition for aggregation function's accumulators.

use arrow::array::ArrayRef;
use datafusion_common::{Result, ScalarValue};
use datafusion_common::{DataFusionError, Result, ScalarValue};
use std::fmt::Debug;

/// An accumulator represents a stateful object that lives throughout the evaluation of multiple rows and
/// generically accumulates values.
///
/// An accumulator knows how to:
/// * update its state from inputs via `update_batch`
/// * convert its internal state to a vector of scalar values
/// * convert its internal state to a vector of aggregate values
/// * update its state from multiple accumulators' states via `merge_batch`
/// * compute the final value from its internal state via `evaluate`
pub trait Accumulator: Send + Sync + Debug {
/// Returns the state of the accumulator at the end of the accumulation.
// in the case of an average on which we track `sum` and `n`, this function should return a vector
// of two values, sum and n.
fn state(&self) -> Result<Vec<ScalarValue>>;
/// in the case of an average on which we track `sum` and `n`, this function should return a vector
/// of two values, sum and n.
fn state(&self) -> Result<Vec<AggregateState>>;

/// updates the accumulator's state from a vector of arrays.
fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()>;
Expand All @@ -44,3 +44,38 @@ pub trait Accumulator: Send + Sync + Debug {
/// returns its value based on its current state.
fn evaluate(&self) -> Result<ScalarValue>;
}

/// Representation of internal accumulator state. Accumulators can potentially have a mix of
/// scalar and array values. It may be desirable to add custom aggregator states here as well
/// in the future (perhaps `Custom(Box<dyn Any>)`?).
#[derive(Debug)]
pub enum AggregateState {
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is a very elegant idea. Can you please add docstrings to AggregateState explaining what is going on?

I think it would be worth updating the docstrings in the accumulator trait with some discussion / examples of how to use the Array state.

/// Simple scalar value. Note that `ScalarValue::List` can be used to pass multiple
/// values around
Scalar(ScalarValue),
/// Arrays can be used instead of `ScalarValue::List` and could potentially have better
/// performance with large data sets, although this has not been verified. It also allows
/// for use of arrow kernels with less overhead.
Array(ArrayRef),
}

impl AggregateState {
/// Access the aggregate state as a scalar value. An error will occur if the
/// state is not a scalar value.
pub fn as_scalar(&self) -> Result<&ScalarValue> {
match &self {
Self::Scalar(v) => Ok(v),
_ => Err(DataFusionError::Internal(
"AggregateState is not a scalar aggregate".to_string(),
)),
}
}

/// Access the aggregate state as an array value.
pub fn to_array(&self) -> ArrayRef {
match &self {
Self::Scalar(v) => v.to_array(),
Self::Array(array) => array.clone(),
}
}
}
9 changes: 8 additions & 1 deletion datafusion/expr/src/aggregate_function.rs
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,8 @@ pub enum AggregateFunction {
Max,
/// avg
Avg,
/// median
Median,
/// Approximate aggregate function
ApproxDistinct,
/// array_agg
Expand Down Expand Up @@ -107,6 +109,7 @@ impl FromStr for AggregateFunction {
"avg" => AggregateFunction::Avg,
"mean" => AggregateFunction::Avg,
"sum" => AggregateFunction::Sum,
"median" => AggregateFunction::Median,
"approx_distinct" => AggregateFunction::ApproxDistinct,
"array_agg" => AggregateFunction::ArrayAgg,
"var" => AggregateFunction::Variance,
Expand Down Expand Up @@ -175,7 +178,9 @@ pub fn return_type(
AggregateFunction::ApproxPercentileContWithWeight => {
Ok(coerced_data_types[0].clone())
}
AggregateFunction::ApproxMedian => Ok(coerced_data_types[0].clone()),
AggregateFunction::ApproxMedian | AggregateFunction::Median => {
Ok(coerced_data_types[0].clone())
}
AggregateFunction::Grouping => Ok(DataType::Int32),
}
}
Expand Down Expand Up @@ -330,6 +335,7 @@ pub fn coerce_types(
}
Ok(input_types.to_vec())
}
AggregateFunction::Median => Ok(input_types.to_vec()),
AggregateFunction::Grouping => Ok(vec![input_types[0].clone()]),
}
}
Expand Down Expand Up @@ -358,6 +364,7 @@ pub fn signature(fun: &AggregateFunction) -> Signature {
| AggregateFunction::VariancePop
| AggregateFunction::Stddev
| AggregateFunction::StddevPop
| AggregateFunction::Median
| AggregateFunction::ApproxMedian => {
Signature::uniform(1, NUMERICS.to_vec(), Volatility::Immutable)
}
Expand Down
2 changes: 1 addition & 1 deletion datafusion/expr/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ pub mod utils;
pub mod window_frame;
pub mod window_function;

pub use accumulator::Accumulator;
pub use accumulator::{Accumulator, AggregateState};
pub use aggregate_function::AggregateFunction;
pub use built_in_function::BuiltinScalarFunction;
pub use columnar_value::{ColumnarValue, NullColumnarValue};
Expand Down
Loading