-
Notifications
You must be signed in to change notification settings - Fork 2.1k
Implement exact median, add AggregateState
#3009
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
429b001
60d8395
a18567a
408f9c3
fa40574
e456b39
a3e3c2b
1975629
34e87eb
8c89a74
dc4cb35
547f980
9c17969
63cfbe6
ef0f7dc
ff89d6d
f31d92d
b6eae6f
ca9d6a9
ef1effd
74ff2ef
7be6781
1a92bea
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -221,7 +221,7 @@ async fn csv_query_stddev_6() -> Result<()> { | |
| } | ||
|
|
||
| #[tokio::test] | ||
| async fn csv_query_median_1() -> Result<()> { | ||
| async fn csv_query_approx_median_1() -> Result<()> { | ||
| let ctx = SessionContext::new(); | ||
| register_aggregate_csv(&ctx).await?; | ||
| let sql = "SELECT approx_median(c2) FROM aggregate_test_100"; | ||
|
|
@@ -232,7 +232,7 @@ async fn csv_query_median_1() -> Result<()> { | |
| } | ||
|
|
||
| #[tokio::test] | ||
| async fn csv_query_median_2() -> Result<()> { | ||
| async fn csv_query_approx_median_2() -> Result<()> { | ||
| let ctx = SessionContext::new(); | ||
| register_aggregate_csv(&ctx).await?; | ||
| let sql = "SELECT approx_median(c6) FROM aggregate_test_100"; | ||
|
|
@@ -243,7 +243,7 @@ async fn csv_query_median_2() -> Result<()> { | |
| } | ||
|
|
||
| #[tokio::test] | ||
| async fn csv_query_median_3() -> Result<()> { | ||
| async fn csv_query_approx_median_3() -> Result<()> { | ||
| let ctx = SessionContext::new(); | ||
| register_aggregate_csv(&ctx).await?; | ||
| let sql = "SELECT approx_median(c12) FROM aggregate_test_100"; | ||
|
|
@@ -253,6 +253,189 @@ async fn csv_query_median_3() -> Result<()> { | |
| Ok(()) | ||
| } | ||
|
|
||
| #[tokio::test] | ||
| async fn csv_query_median_1() -> Result<()> { | ||
| let ctx = SessionContext::new(); | ||
| register_aggregate_csv(&ctx).await?; | ||
| let sql = "SELECT median(c2) FROM aggregate_test_100"; | ||
| let actual = execute(&ctx, sql).await; | ||
| let expected = vec![vec!["3"]]; | ||
| assert_float_eq(&expected, &actual); | ||
| Ok(()) | ||
| } | ||
|
|
||
| #[tokio::test] | ||
| async fn csv_query_median_2() -> Result<()> { | ||
| let ctx = SessionContext::new(); | ||
| register_aggregate_csv(&ctx).await?; | ||
| let sql = "SELECT median(c6) FROM aggregate_test_100"; | ||
| let actual = execute(&ctx, sql).await; | ||
| let expected = vec![vec!["1125553990140691277"]]; | ||
| assert_float_eq(&expected, &actual); | ||
| Ok(()) | ||
| } | ||
|
|
||
| #[tokio::test] | ||
| async fn csv_query_median_3() -> Result<()> { | ||
| let ctx = SessionContext::new(); | ||
| register_aggregate_csv(&ctx).await?; | ||
| let sql = "SELECT median(c12) FROM aggregate_test_100"; | ||
| let actual = execute(&ctx, sql).await; | ||
| let expected = vec![vec!["0.5513900544385053"]]; | ||
| assert_float_eq(&expected, &actual); | ||
| Ok(()) | ||
| } | ||
|
|
||
| #[tokio::test] | ||
| async fn median_i8() -> Result<()> { | ||
| median_test( | ||
| "median", | ||
| DataType::Int8, | ||
| Arc::new(Int8Array::from(vec![i8::MIN, i8::MIN, 100, i8::MAX])), | ||
| "-14", | ||
| ) | ||
| .await | ||
| } | ||
|
|
||
| #[tokio::test] | ||
| async fn median_i16() -> Result<()> { | ||
| median_test( | ||
| "median", | ||
| DataType::Int16, | ||
| Arc::new(Int16Array::from(vec![i16::MIN, i16::MIN, 100, i16::MAX])), | ||
| "-16334", | ||
| ) | ||
| .await | ||
| } | ||
|
|
||
| #[tokio::test] | ||
| async fn median_i32() -> Result<()> { | ||
| median_test( | ||
| "median", | ||
| DataType::Int32, | ||
| Arc::new(Int32Array::from(vec![i32::MIN, i32::MIN, 100, i32::MAX])), | ||
| "-1073741774", | ||
| ) | ||
| .await | ||
| } | ||
|
|
||
| #[tokio::test] | ||
| async fn median_i64() -> Result<()> { | ||
| median_test( | ||
| "median", | ||
| DataType::Int64, | ||
| Arc::new(Int64Array::from(vec![i64::MIN, i64::MIN, 100, i64::MAX])), | ||
| "-4611686018427388000", | ||
| ) | ||
| .await | ||
| } | ||
|
|
||
| #[tokio::test] | ||
| async fn median_u8() -> Result<()> { | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. nice |
||
| median_test( | ||
| "median", | ||
| DataType::UInt8, | ||
| Arc::new(UInt8Array::from(vec![u8::MIN, u8::MIN, 100, u8::MAX])), | ||
| "50", | ||
| ) | ||
| .await | ||
| } | ||
|
|
||
| #[tokio::test] | ||
| async fn median_u16() -> Result<()> { | ||
| median_test( | ||
| "median", | ||
| DataType::UInt16, | ||
| Arc::new(UInt16Array::from(vec![u16::MIN, u16::MIN, 100, u16::MAX])), | ||
| "50", | ||
| ) | ||
| .await | ||
| } | ||
|
|
||
| #[tokio::test] | ||
| async fn median_u32() -> Result<()> { | ||
| median_test( | ||
| "median", | ||
| DataType::UInt32, | ||
| Arc::new(UInt32Array::from(vec![u32::MIN, u32::MIN, 100, u32::MAX])), | ||
| "50", | ||
| ) | ||
| .await | ||
| } | ||
|
|
||
| #[tokio::test] | ||
| async fn median_u64() -> Result<()> { | ||
| median_test( | ||
| "median", | ||
| DataType::UInt64, | ||
| Arc::new(UInt64Array::from(vec![u64::MIN, u64::MIN, 100, u64::MAX])), | ||
| "50", | ||
| ) | ||
| .await | ||
| } | ||
|
|
||
| #[tokio::test] | ||
| async fn median_f32() -> Result<()> { | ||
| median_test( | ||
| "median", | ||
| DataType::Float32, | ||
| Arc::new(Float32Array::from(vec![1.1, 4.4, 5.5, 3.3, 2.2])), | ||
| "3.3", | ||
| ) | ||
| .await | ||
| } | ||
|
|
||
| #[tokio::test] | ||
| async fn median_f64() -> Result<()> { | ||
| median_test( | ||
| "median", | ||
| DataType::Float64, | ||
| Arc::new(Float64Array::from(vec![1.1, 4.4, 5.5, 3.3, 2.2])), | ||
| "3.3", | ||
| ) | ||
| .await | ||
| } | ||
|
|
||
| #[tokio::test] | ||
| async fn median_f64_nan() -> Result<()> { | ||
| median_test( | ||
| "median", | ||
| DataType::Float64, | ||
| Arc::new(Float64Array::from(vec![1.1, f64::NAN, f64::NAN, f64::NAN])), | ||
| "NaN", // probably not the desired behavior? - see https://github.com/apache/arrow-datafusion/issues/3039 | ||
| ) | ||
| .await | ||
| } | ||
|
|
||
| #[tokio::test] | ||
| async fn approx_median_f64_nan() -> Result<()> { | ||
| median_test( | ||
| "approx_median", | ||
| DataType::Float64, | ||
| Arc::new(Float64Array::from(vec![1.1, f64::NAN, f64::NAN, f64::NAN])), | ||
| "NaN", // probably not the desired behavior? - see https://github.com/apache/arrow-datafusion/issues/3039 | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. testing for the win! |
||
| ) | ||
| .await | ||
| } | ||
|
|
||
| async fn median_test( | ||
| func: &str, | ||
| data_type: DataType, | ||
| values: ArrayRef, | ||
| expected: &str, | ||
| ) -> Result<()> { | ||
| let ctx = SessionContext::new(); | ||
| let schema = Arc::new(Schema::new(vec![Field::new("a", data_type, false)])); | ||
| let batch = RecordBatch::try_new(schema.clone(), vec![values])?; | ||
| let table = Arc::new(MemTable::try_new(schema, vec![vec![batch]])?); | ||
| ctx.register_table("t", table)?; | ||
| let sql = format!("SELECT {}(a) FROM t", func); | ||
| let actual = execute(&ctx, &sql).await; | ||
| let expected = vec![vec![expected.to_owned()]]; | ||
| assert_float_eq(&expected, &actual); | ||
| Ok(()) | ||
| } | ||
|
|
||
| #[tokio::test] | ||
| async fn csv_query_external_table_count() { | ||
| let ctx = SessionContext::new(); | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -127,7 +127,11 @@ where | |
| l.as_ref().parse::<f64>().unwrap(), | ||
| r.as_str().parse::<f64>().unwrap(), | ||
| ); | ||
| assert!((l - r).abs() <= 2.0 * f64::EPSILON); | ||
| if l.is_nan() || r.is_nan() { | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. 👍 |
||
| assert!(l.is_nan() && r.is_nan()); | ||
| } else if (l - r).abs() > 2.0 * f64::EPSILON { | ||
| panic!("{} != {}", l, r) | ||
| } | ||
| }); | ||
| } | ||
|
|
||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -18,22 +18,22 @@ | |
| //! Accumulator module contains the trait definition for aggregation function's accumulators. | ||
|
|
||
| use arrow::array::ArrayRef; | ||
| use datafusion_common::{Result, ScalarValue}; | ||
| use datafusion_common::{DataFusionError, Result, ScalarValue}; | ||
| use std::fmt::Debug; | ||
|
|
||
| /// An accumulator represents a stateful object that lives throughout the evaluation of multiple rows and | ||
| /// generically accumulates values. | ||
| /// | ||
| /// An accumulator knows how to: | ||
| /// * update its state from inputs via `update_batch` | ||
| /// * convert its internal state to a vector of scalar values | ||
| /// * convert its internal state to a vector of aggregate values | ||
| /// * update its state from multiple accumulators' states via `merge_batch` | ||
| /// * compute the final value from its internal state via `evaluate` | ||
| pub trait Accumulator: Send + Sync + Debug { | ||
| /// Returns the state of the accumulator at the end of the accumulation. | ||
| // in the case of an average on which we track `sum` and `n`, this function should return a vector | ||
| // of two values, sum and n. | ||
| fn state(&self) -> Result<Vec<ScalarValue>>; | ||
| /// in the case of an average on which we track `sum` and `n`, this function should return a vector | ||
| /// of two values, sum and n. | ||
| fn state(&self) -> Result<Vec<AggregateState>>; | ||
|
|
||
| /// updates the accumulator's state from a vector of arrays. | ||
| fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()>; | ||
|
|
@@ -44,3 +44,38 @@ pub trait Accumulator: Send + Sync + Debug { | |
| /// returns its value based on its current state. | ||
| fn evaluate(&self) -> Result<ScalarValue>; | ||
| } | ||
|
|
||
| /// Representation of internal accumulator state. Accumulators can potentially have a mix of | ||
| /// scalar and array values. It may be desirable to add custom aggregator states here as well | ||
| /// in the future (perhaps `Custom(Box<dyn Any>)`?). | ||
| #[derive(Debug)] | ||
| pub enum AggregateState { | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. This is a very elegant idea. Can you please add docstrings to this enum and its variants? I think it would also be worth updating the docstrings in the `Accumulator` trait with some discussion / examples of how to use the `Array` state. |
||
| /// Simple scalar value. Note that `ScalarValue::List` can be used to pass multiple | ||
| /// values around | ||
| Scalar(ScalarValue), | ||
| /// Arrays can be used instead of `ScalarValue::List` and could potentially have better | ||
| /// performance with large data sets, although this has not been verified. It also allows | ||
| /// for use of arrow kernels with less overhead. | ||
| Array(ArrayRef), | ||
| } | ||
|
|
||
| impl AggregateState { | ||
| /// Access the aggregate state as a scalar value. An error will occur if the | ||
| /// state is not a scalar value. | ||
| pub fn as_scalar(&self) -> Result<&ScalarValue> { | ||
| match &self { | ||
| Self::Scalar(v) => Ok(v), | ||
| _ => Err(DataFusionError::Internal( | ||
| "AggregateState is not a scalar aggregate".to_string(), | ||
| )), | ||
| } | ||
| } | ||
|
|
||
| /// Access the aggregate state as an array value. | ||
| pub fn to_array(&self) -> ArrayRef { | ||
| match &self { | ||
| Self::Scalar(v) => v.to_array(), | ||
| Self::Array(array) => array.clone(), | ||
| } | ||
| } | ||
| } | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
If possible, I would recommend adding a basic SQL test for median covering all the supported data types (not just on aggregate_test_100, but a dedicated test setup with known data, maybe the integers 10, 9, 8, ... 0).
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Added in ef1effd