diff --git a/datafusion/functions/src/math/power.rs b/datafusion/functions/src/math/power.rs index 3fb48c392cbfe..198ad88b945bd 100644 --- a/datafusion/functions/src/math/power.rs +++ b/datafusion/functions/src/math/power.rs @@ -172,34 +172,39 @@ impl ScalarUDFImpl for PowerFunc { fn coerce_types(&self, arg_types: &[DataType]) -> Result<Vec<DataType>> { let [arg1, arg2] = take_function_args(self.name(), arg_types)?; - fn coerced_type_base(name: &str, data_type: &DataType) -> Result<DataType> { + fn coerced_type_exp(name: &str, data_type: &DataType) -> Result<DataType> { match data_type { DataType::Null => Ok(DataType::Int64), d if d.is_floating() => Ok(DataType::Float64), d if d.is_integer() => Ok(DataType::Int64), - d if is_decimal(d) => Ok(d.clone()), + d if is_decimal(d) => Ok(DataType::Float64), other => { exec_err!("Unsupported data type {other:?} for {} function", name) } } } - fn coerced_type_exp(name: &str, data_type: &DataType) -> Result<DataType> { + // Coerce the exponent first; the base coercion below does not depend on the result + let exp_type = coerced_type_exp(self.name(), arg2)?; + + // For base coercion: always use Float64 for integer/null bases + // This matches PostgreSQL behavior and handles negative exponents correctly + fn coerced_type_base(name: &str, data_type: &DataType) -> Result<DataType> { match data_type { - DataType::Null => Ok(DataType::Int64), d if d.is_floating() => Ok(DataType::Float64), - d if d.is_integer() => Ok(DataType::Int64), - d if is_decimal(d) => Ok(DataType::Float64), + // Integer and Null bases always coerce to Float64 + // (integer power doesn't support negative exponents, and pow() + // should return float like PostgreSQL does) + DataType::Null => Ok(DataType::Float64), + d if d.is_integer() => Ok(DataType::Float64), + d if is_decimal(d) => Ok(d.clone()), other => { exec_err!("Unsupported data type {other:?} for {} function", name) } } } - Ok(vec![ - coerced_type_base(self.name(), arg1)?, - coerced_type_exp(self.name(), arg2)?, - ]) + Ok(vec![coerced_type_base(self.name(), arg1)?, 
exp_type]) } fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { @@ -214,18 +219,6 @@ impl ScalarUDFImpl for PowerFunc { |b, e| Ok(f64::powf(b, e)), )? } - (DataType::Int64, _) => { - calculate_binary_math::( - &base, - exponent, - |b, e| match e.try_into() { - Ok(exp_u32) => b.pow_checked(exp_u32), - Err(_) => Err(ArrowError::ArithmeticOverflow(format!( - "Exponent {e} in integer computation is out of bounds." - ))), - }, - )? - } (DataType::Decimal32(precision, scale), DataType::Int64) => { calculate_binary_decimal_math::( &base, @@ -392,10 +385,7 @@ mod tests { use super::*; use arrow::array::{Array, Decimal128Array, Float64Array, Int64Array}; use arrow::datatypes::{DECIMAL128_MAX_SCALE, Field}; - use arrow_buffer::NullBuffer; - use datafusion_common::cast::{ - as_decimal128_array, as_float64_array, as_int64_array, - }; + use datafusion_common::cast::{as_decimal128_array, as_float64_array}; use datafusion_common::config::ConfigOptions; use std::sync::Arc; @@ -446,43 +436,6 @@ mod tests { } } - #[test] - fn test_power_i64() { - let arg_fields = vec![ - Field::new("a", DataType::Int64, true).into(), - Field::new("a", DataType::Int64, true).into(), - ]; - let args = ScalarFunctionArgs { - args: vec![ - ColumnarValue::Array(Arc::new(Int64Array::from(vec![2, 2, 3, 5]))), // base - ColumnarValue::Array(Arc::new(Int64Array::from(vec![3, 2, 4, 4]))), // exponent - ], - arg_fields, - number_rows: 4, - return_field: Field::new("f", DataType::Int64, true).into(), - config_options: Arc::new(ConfigOptions::default()), - }; - let result = PowerFunc::new() - .invoke_with_args(args) - .expect("failed to initialize function power"); - - match result { - ColumnarValue::Array(arr) => { - let ints = as_int64_array(&arr) - .expect("failed to convert result to a Int64Array"); - - assert_eq!(ints.len(), 4); - assert_eq!(ints.value(0), 8); - assert_eq!(ints.value(1), 4); - assert_eq!(ints.value(2), 81); - assert_eq!(ints.value(3), 625); - } - ColumnarValue::Scalar(_) => { 
- panic!("Expected an array value") - } - } - } - #[test] fn test_power_i128() { let arg_fields = vec![ @@ -539,20 +492,21 @@ mod tests { #[test] fn test_power_array_null() { let arg_fields = vec![ - Field::new("a", DataType::Int64, true).into(), - Field::new("a", DataType::Int64, true).into(), + Field::new("a", DataType::Float64, true).into(), + Field::new("a", DataType::Float64, true).into(), ]; let args = ScalarFunctionArgs { args: vec![ - ColumnarValue::Array(Arc::new(Int64Array::from(vec![2, 2, 2]))), // base - ColumnarValue::Array(Arc::new(Int64Array::from_iter_values_with_nulls( - vec![1, 2, 3], - Some(NullBuffer::from(vec![true, false, true])), - ))), // exponent + ColumnarValue::Array(Arc::new(Float64Array::from(vec![2.0, 2.0, 2.0]))), // base + ColumnarValue::Array(Arc::new(Float64Array::from(vec![ + Some(1.0), + None, + Some(3.0), + ]))), // exponent ], arg_fields, - number_rows: 1, - return_field: Field::new("f", DataType::Int64, true).into(), + number_rows: 3, + return_field: Field::new("f", DataType::Float64, true).into(), config_options: Arc::new(ConfigOptions::default()), }; let result = PowerFunc::new() @@ -561,15 +515,15 @@ mod tests { match result { ColumnarValue::Array(arr) => { - let ints = - as_int64_array(&arr).expect("failed to convert result to an array"); - - assert_eq!(ints.len(), 3); - assert!(!ints.is_null(0)); - assert_eq!(ints.value(0), i64::from(2)); - assert!(ints.is_null(1)); - assert!(!ints.is_null(2)); - assert_eq!(ints.value(2), i64::from(8)); + let floats = + as_float64_array(&arr).expect("failed to convert result to an array"); + + assert_eq!(floats.len(), 3); + assert!(!floats.is_null(0)); + assert_eq!(floats.value(0), 2.0); + assert!(floats.is_null(1)); + assert!(!floats.is_null(2)); + assert_eq!(floats.value(2), 8.0); } ColumnarValue::Scalar(_) => { panic!("Expected an array value") @@ -647,4 +601,46 @@ mod tests { "Not yet implemented: Negative scale is not yet supported value: -1" ); } + + #[test] + fn 
test_power_coerce_types() { + let power_func = PowerFunc::new(); + + // Int64 base with Int64 exponent -> base coerced to Float64 (like PostgreSQL) + // This allows negative exponents to work correctly + let result = power_func + .coerce_types(&[DataType::Int64, DataType::Int64]) + .unwrap(); + assert_eq!(result, vec![DataType::Float64, DataType::Int64]); + + // Float64 base with Float64 exponent -> both stay Float64 + let result = power_func + .coerce_types(&[DataType::Float64, DataType::Float64]) + .unwrap(); + assert_eq!(result, vec![DataType::Float64, DataType::Float64]); + + // Int64 base with Float64 exponent -> base coerced to Float64 + let result = power_func + .coerce_types(&[DataType::Int64, DataType::Float64]) + .unwrap(); + assert_eq!(result, vec![DataType::Float64, DataType::Float64]); + + // Int32 base with Float32 exponent -> both coerced to Float64 + let result = power_func + .coerce_types(&[DataType::Int32, DataType::Float32]) + .unwrap(); + assert_eq!(result, vec![DataType::Float64, DataType::Float64]); + + // Null base with Float64 exponent -> base coerced to Float64 + let result = power_func + .coerce_types(&[DataType::Null, DataType::Float64]) + .unwrap(); + assert_eq!(result, vec![DataType::Float64, DataType::Float64]); + + // Null base with Int64 exponent -> base coerced to Float64 (like PostgreSQL) + let result = power_func + .coerce_types(&[DataType::Null, DataType::Int64]) + .unwrap(); + assert_eq!(result, vec![DataType::Float64, DataType::Int64]); + } } diff --git a/datafusion/sqllogictest/test_files/decimal.slt b/datafusion/sqllogictest/test_files/decimal.slt index b93340639aafc..49909596e66d4 100644 --- a/datafusion/sqllogictest/test_files/decimal.slt +++ b/datafusion/sqllogictest/test_files/decimal.slt @@ -945,13 +945,25 @@ SELECT power(2.5, 0) ---- 1 -# TODO: check backward compatibility for a case with base in64 and exponent float64 since the power coercion is introduced +# int64 base with decimal exponent (coerced to float 
computation) +query R +SELECT power(10, -2.0) +---- +0.01 + +query R +SELECT power(2, -0.5) +---- +0.707106781187 # query error Unsupported data type Decimal128\(2, 1\) for power function # SELECT power(2.5, 4.0) -query error Arrow error: .*in integer computation is out of bounds +# power() with very large exponent returns infinity (Float64 behavior) +query R SELECT power(2, 100000000000) +---- +Infinity query error Arrow error: Arithmetic overflow: Unsupported exp value SELECT power(2::decimal(38, 0), -5) @@ -1076,7 +1088,7 @@ SELECT power(2.5, 4) ---- 39.0625 -query I +query R SELECT power(2, null) ---- NULL diff --git a/datafusion/sqllogictest/test_files/expr.slt b/datafusion/sqllogictest/test_files/expr.slt index d724ddae30e12..cec9b63675a66 100644 --- a/datafusion/sqllogictest/test_files/expr.slt +++ b/datafusion/sqllogictest/test_files/expr.slt @@ -22,7 +22,7 @@ SELECT true, false, false = false, true = false true false true false # test_mathematical_expressions_with_null -query RRRRRRRRRRRRRRRRRRRRRRRRIIIRRRRRRBB +query RRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRBB SELECT sqrt(NULL), cbrt(NULL), diff --git a/datafusion/sqllogictest/test_files/math.slt b/datafusion/sqllogictest/test_files/math.slt index f34e1156a785a..322ba7a104a7d 100644 --- a/datafusion/sqllogictest/test_files/math.slt +++ b/datafusion/sqllogictest/test_files/math.slt @@ -696,8 +696,49 @@ query error DataFusion error: Arrow error: Compute error: Signed integer overflo select lcm(2, 9223372036854775803); -query error DataFusion error: Arrow error: Arithmetic overflow: Overflow happened on: 2107754225 \^ 1221660777 +## pow/power + +# pow() with integer base and negative float exponent (verifies type coercion) +query R +SELECT pow(2, -0.5) +---- +0.707106781187 + +# pow() with negative integer base and negative float exponent (returns NaN) +query R +SELECT pow(-2, -0.5) +---- +NaN + +# pow() with zero base and negative exponent (returns Infinity) +query R +SELECT pow(0, -0.5) +---- +Infinity + +# pow() 
with integer base of 1 and negative exponent +query R +SELECT pow(1, -0.5) +---- +1 + +# pow() with large integer base and small negative exponent +query R +SELECT pow(1000, -0.1) +---- +0.501187233627 + +# pow() with integer base and negative integer exponent returns float (like PostgreSQL) +query R +SELECT pow(2, -2) +---- +0.25 + +# power() with very large exponent returns infinity (Float64 behavior) +query R select power(2107754225, 1221660777); +---- +Infinity # factorial overflow query error DataFusion error: Arrow error: Compute error: Overflow happened on FACTORIAL\(350943270\) diff --git a/datafusion/sqllogictest/test_files/scalar.slt b/datafusion/sqllogictest/test_files/scalar.slt index 8eac9bd0c9558..7c6b38b78e500 100644 --- a/datafusion/sqllogictest/test_files/scalar.slt +++ b/datafusion/sqllogictest/test_files/scalar.slt @@ -746,26 +746,26 @@ select pi(), pi() / 2, pi() / 3; ## power -# power scalar function -query III rowsort +# power scalar function (integer and float bases are computed as Float64, like PostgreSQL; decimal bases keep their decimal type) +query RRR rowsort select power(2, 0), power(2, 1), power(2, 2); ---- 1 2 4 # power scalar nulls -query I rowsort +query R rowsort select power(null, 64); ---- NULL # power scalar nulls #1 -query I rowsort +query R rowsort select power(2, null); ---- NULL # power scalar nulls #2 -query I rowsort +query R rowsort select power(null, null); ---- NULL @@ -1775,7 +1775,7 @@ CREATE TABLE test( (-14, -14, -14.5, -14.5), (NULL, NULL, NULL, NULL); -query IIRRIR rowsort +query RRRRRR rowsort SELECT power(i32, exp_i) as power_i32, power(i64, exp_f) as power_i64, pow(f32, exp_i) as power_f32,