Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
158 changes: 77 additions & 81 deletions datafusion/functions/src/math/power.rs
Original file line number Diff line number Diff line change
Expand Up @@ -172,34 +172,39 @@ impl ScalarUDFImpl for PowerFunc {
fn coerce_types(&self, arg_types: &[DataType]) -> Result<Vec<DataType>> {
let [arg1, arg2] = take_function_args(self.name(), arg_types)?;

fn coerced_type_base(name: &str, data_type: &DataType) -> Result<DataType> {
fn coerced_type_exp(name: &str, data_type: &DataType) -> Result<DataType> {
match data_type {
DataType::Null => Ok(DataType::Int64),
d if d.is_floating() => Ok(DataType::Float64),
d if d.is_integer() => Ok(DataType::Int64),
d if is_decimal(d) => Ok(d.clone()),
d if is_decimal(d) => Ok(DataType::Float64),
other => {
exec_err!("Unsupported data type {other:?} for {} function", name)
}
}
}

fn coerced_type_exp(name: &str, data_type: &DataType) -> Result<DataType> {
// Determine the exponent type first, as it affects base coercion
let exp_type = coerced_type_exp(self.name(), arg2)?;

// For base coercion: always use Float64 for integer/null bases
// This matches PostgreSQL behavior and handles negative exponents correctly
fn coerced_type_base(name: &str, data_type: &DataType) -> Result<DataType> {
match data_type {
DataType::Null => Ok(DataType::Int64),
d if d.is_floating() => Ok(DataType::Float64),
d if d.is_integer() => Ok(DataType::Int64),
d if is_decimal(d) => Ok(DataType::Float64),
// Integer and Null bases always coerce to Float64
// (integer power doesn't support negative exponents, and pow()
// should return float like PostgreSQL does)
DataType::Null => Ok(DataType::Float64),
d if d.is_integer() => Ok(DataType::Float64),
d if is_decimal(d) => Ok(d.clone()),
other => {
exec_err!("Unsupported data type {other:?} for {} function", name)
}
}
}

Ok(vec![
coerced_type_base(self.name(), arg1)?,
coerced_type_exp(self.name(), arg2)?,
])
Ok(vec![coerced_type_base(self.name(), arg1)?, exp_type])
}

fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
Expand All @@ -214,18 +219,6 @@ impl ScalarUDFImpl for PowerFunc {
|b, e| Ok(f64::powf(b, e)),
)?
}
(DataType::Int64, _) => {
calculate_binary_math::<Int64Type, Int64Type, Int64Type, _>(
&base,
exponent,
|b, e| match e.try_into() {
Ok(exp_u32) => b.pow_checked(exp_u32),
Err(_) => Err(ArrowError::ArithmeticOverflow(format!(
"Exponent {e} in integer computation is out of bounds."
))),
},
)?
}
(DataType::Decimal32(precision, scale), DataType::Int64) => {
calculate_binary_decimal_math::<Decimal32Type, Int64Type, Decimal32Type, _>(
&base,
Expand Down Expand Up @@ -392,10 +385,7 @@ mod tests {
use super::*;
use arrow::array::{Array, Decimal128Array, Float64Array, Int64Array};
use arrow::datatypes::{DECIMAL128_MAX_SCALE, Field};
use arrow_buffer::NullBuffer;
use datafusion_common::cast::{
as_decimal128_array, as_float64_array, as_int64_array,
};
use datafusion_common::cast::{as_decimal128_array, as_float64_array};
use datafusion_common::config::ConfigOptions;
use std::sync::Arc;

Expand Down Expand Up @@ -446,43 +436,6 @@ mod tests {
}
}

#[test]
fn test_power_i64() {
let arg_fields = vec![
Field::new("a", DataType::Int64, true).into(),
Field::new("a", DataType::Int64, true).into(),
];
let args = ScalarFunctionArgs {
args: vec![
ColumnarValue::Array(Arc::new(Int64Array::from(vec![2, 2, 3, 5]))), // base
ColumnarValue::Array(Arc::new(Int64Array::from(vec![3, 2, 4, 4]))), // exponent
],
arg_fields,
number_rows: 4,
return_field: Field::new("f", DataType::Int64, true).into(),
config_options: Arc::new(ConfigOptions::default()),
};
let result = PowerFunc::new()
.invoke_with_args(args)
.expect("failed to initialize function power");

match result {
ColumnarValue::Array(arr) => {
let ints = as_int64_array(&arr)
.expect("failed to convert result to a Int64Array");

assert_eq!(ints.len(), 4);
assert_eq!(ints.value(0), 8);
assert_eq!(ints.value(1), 4);
assert_eq!(ints.value(2), 81);
assert_eq!(ints.value(3), 625);
}
ColumnarValue::Scalar(_) => {
panic!("Expected an array value")
}
}
}

#[test]
fn test_power_i128() {
let arg_fields = vec![
Expand Down Expand Up @@ -539,20 +492,21 @@ mod tests {
#[test]
fn test_power_array_null() {
let arg_fields = vec![
Field::new("a", DataType::Int64, true).into(),
Field::new("a", DataType::Int64, true).into(),
Field::new("a", DataType::Float64, true).into(),
Field::new("a", DataType::Float64, true).into(),
];
let args = ScalarFunctionArgs {
args: vec![
ColumnarValue::Array(Arc::new(Int64Array::from(vec![2, 2, 2]))), // base
ColumnarValue::Array(Arc::new(Int64Array::from_iter_values_with_nulls(
vec![1, 2, 3],
Some(NullBuffer::from(vec![true, false, true])),
))), // exponent
ColumnarValue::Array(Arc::new(Float64Array::from(vec![2.0, 2.0, 2.0]))), // base
ColumnarValue::Array(Arc::new(Float64Array::from(vec![
Some(1.0),
None,
Some(3.0),
]))), // exponent
],
arg_fields,
number_rows: 1,
return_field: Field::new("f", DataType::Int64, true).into(),
number_rows: 3,
return_field: Field::new("f", DataType::Float64, true).into(),
config_options: Arc::new(ConfigOptions::default()),
};
let result = PowerFunc::new()
Expand All @@ -561,15 +515,15 @@ mod tests {

match result {
ColumnarValue::Array(arr) => {
let ints =
as_int64_array(&arr).expect("failed to convert result to an array");

assert_eq!(ints.len(), 3);
assert!(!ints.is_null(0));
assert_eq!(ints.value(0), i64::from(2));
assert!(ints.is_null(1));
assert!(!ints.is_null(2));
assert_eq!(ints.value(2), i64::from(8));
let floats =
as_float64_array(&arr).expect("failed to convert result to an array");

assert_eq!(floats.len(), 3);
assert!(!floats.is_null(0));
assert_eq!(floats.value(0), 2.0);
assert!(floats.is_null(1));
assert!(!floats.is_null(2));
assert_eq!(floats.value(2), 8.0);
}
ColumnarValue::Scalar(_) => {
panic!("Expected an array value")
Expand Down Expand Up @@ -647,4 +601,46 @@ mod tests {
"Not yet implemented: Negative scale is not yet supported value: -1"
);
}

#[test]
fn test_power_coerce_types() {
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I like this, very useful unit test

let power_func = PowerFunc::new();

// Int64 base with Int64 exponent -> base coerced to Float64 (like PostgreSQL)
// This allows negative exponents to work correctly
let result = power_func
.coerce_types(&[DataType::Int64, DataType::Int64])
.unwrap();
assert_eq!(result, vec![DataType::Float64, DataType::Int64]);

// Float64 base with Float64 exponent -> both stay Float64
let result = power_func
.coerce_types(&[DataType::Float64, DataType::Float64])
.unwrap();
assert_eq!(result, vec![DataType::Float64, DataType::Float64]);

// Int64 base with Float64 exponent -> base coerced to Float64
let result = power_func
.coerce_types(&[DataType::Int64, DataType::Float64])
.unwrap();
assert_eq!(result, vec![DataType::Float64, DataType::Float64]);

// Int32 base with Float32 exponent -> both coerced to Float64
let result = power_func
.coerce_types(&[DataType::Int32, DataType::Float32])
.unwrap();
assert_eq!(result, vec![DataType::Float64, DataType::Float64]);

// Null base with Float64 exponent -> base coerced to Float64
let result = power_func
.coerce_types(&[DataType::Null, DataType::Float64])
.unwrap();
assert_eq!(result, vec![DataType::Float64, DataType::Float64]);

// Null base with Int64 exponent -> base coerced to Float64 (like PostgreSQL)
let result = power_func
.coerce_types(&[DataType::Null, DataType::Int64])
.unwrap();
assert_eq!(result, vec![DataType::Float64, DataType::Int64]);
}
}
18 changes: 15 additions & 3 deletions datafusion/sqllogictest/test_files/decimal.slt
Original file line number Diff line number Diff line change
Expand Up @@ -945,13 +945,25 @@ SELECT power(2.5, 0)
----
1

# TODO: check backward compatibility for a case with base in64 and exponent float64 since the power coercion is introduced
# int64 base with decimal exponent (coerced to float computation)
query R
SELECT power(10, -2.0)
----
0.01

query R
SELECT power(2, -0.5)
----
0.707106781187

# query error Unsupported data type Decimal128\(2, 1\) for power function
# SELECT power(2.5, 4.0)

query error Arrow error: .*in integer computation is out of bounds
# power() with very large exponent returns infinity (Float64 behavior)
query R
SELECT power(2, 100000000000)
----
Infinity

query error Arrow error: Arithmetic overflow: Unsupported exp value
SELECT power(2::decimal(38, 0), -5)
Expand Down Expand Up @@ -1076,7 +1088,7 @@ SELECT power(2.5, 4)
----
39.0625

query I
query R
SELECT power(2, null)
----
NULL
Expand Down
2 changes: 1 addition & 1 deletion datafusion/sqllogictest/test_files/expr.slt
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ SELECT true, false, false = false, true = false
true false true false

# test_mathematical_expressions_with_null
query RRRRRRRRRRRRRRRRRRRRRRRRIIIRRRRRRBB
query RRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRBB
SELECT
sqrt(NULL),
cbrt(NULL),
Expand Down
43 changes: 42 additions & 1 deletion datafusion/sqllogictest/test_files/math.slt
Original file line number Diff line number Diff line change
Expand Up @@ -696,8 +696,49 @@ query error DataFusion error: Arrow error: Compute error: Signed integer overflo
select lcm(2, 9223372036854775803);


query error DataFusion error: Arrow error: Arithmetic overflow: Overflow happened on: 2107754225 \^ 1221660777
## pow/power
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I've placed a TODO
# TODO: check backward compatibility for a case with base in64 and exponent float64 since the power coercion is introduced
in decimal.slt:948, and they are testing datafusion.sql_parser.parse_float_as_decimal = true behaviour.

Tests in math.slt test pure float behaviour for this specific issue. Unfortunately, decimal behaviour is not triggered. Can we port a few of them to the decimal.slt to test the new flag?

If you've covered this case, let's remove that TODO.


# pow() with integer base and negative float exponent (verifies type coercion)
query R
SELECT pow(2, -0.5)
----
0.707106781187

# pow() with negative integer base and negative float exponent (returns NaN)
query R
SELECT pow(-2, -0.5)
----
NaN

# pow() with zero base and negative exponent (returns Infinity)
query R
SELECT pow(0, -0.5)
----
Infinity

# pow() with integer base of 1 and negative exponent
query R
SELECT pow(1, -0.5)
----
1

# pow() with large integer base and small negative exponent
query R
SELECT pow(1000, -0.1)
----
0.501187233627

# pow() with integer base and negative integer exponent returns float (like PostgreSQL)
query R
SELECT pow(2, -2)
----
0.25

# power() with very large exponent returns infinity (Float64 behavior)
query R
select power(2107754225, 1221660777);
----
Infinity

# factorial overflow
query error DataFusion error: Arrow error: Compute error: Overflow happened on FACTORIAL\(350943270\)
Expand Down
12 changes: 6 additions & 6 deletions datafusion/sqllogictest/test_files/scalar.slt
Original file line number Diff line number Diff line change
Expand Up @@ -746,26 +746,26 @@ select pi(), pi() / 2, pi() / 3;

## power

# power scalar function
query III rowsort
# power scalar function (always returns Float64, like PostgreSQL)
query RRR rowsort
select power(2, 0), power(2, 1), power(2, 2);
----
1 2 4

# power scalar nulls
query I rowsort
query R rowsort
select power(null, 64);
----
NULL

# power scalar nulls #1
query I rowsort
query R rowsort
select power(2, null);
----
NULL

# power scalar nulls #2
query I rowsort
query R rowsort
select power(null, null);
----
NULL
Expand Down Expand Up @@ -1775,7 +1775,7 @@ CREATE TABLE test(
(-14, -14, -14.5, -14.5),
(NULL, NULL, NULL, NULL);

query IIRRIR rowsort
query RRRRRR rowsort
SELECT power(i32, exp_i) as power_i32,
power(i64, exp_f) as power_i64,
pow(f32, exp_i) as power_f32,
Expand Down