ReturnFieldArgs.scalar_arguments type doesn't match with arg_fields #19982

@Jefffrey

Describe the bug

/// Information about arguments passed to the function
///
/// This structure contains metadata about how the function was called
/// such as the type of the arguments, any scalar arguments and if the
/// arguments can (ever) be null
///
/// See [`ScalarUDFImpl::return_field_from_args`] for more information
#[derive(Debug)]
pub struct ReturnFieldArgs<'a> {
    /// The data types of the arguments to the function
    pub arg_fields: &'a [FieldRef],
    /// Is argument `i` to the function a scalar (constant)?
    ///
    /// If the argument `i` is not a scalar, it will be None
    ///
    /// For example, if a function is called like `my_function(column_a, 5)`
    /// this field will be `[None, Some(ScalarValue::Int32(Some(5)))]`
    pub scalar_arguments: &'a [Option<&'a ScalarValue>],
}

The ScalarValues in scalar_arguments don't necessarily have the same types as the corresponding entries in arg_fields, which can lead to confusion and extra handling in UDF implementations.

To Reproduce

// datafusion/core/tests/user_defined/user_defined_scalar_functions.rs
#[tokio::test]
async fn test_scalar_arg_types() -> Result<()> {
    #[derive(Debug, PartialEq, Eq, Hash)]
    struct TestUdf {
        signature: Signature,
    }

    impl Default for TestUdf {
        fn default() -> Self {
            Self {
                signature: Signature::coercible(
                    vec![Coercion::new_implicit(
                        TypeSignatureClass::Native(logical_int16()),
                        vec![TypeSignatureClass::Numeric],
                        NativeType::Int16,
                    )],
                    Volatility::Immutable,
                ),
            }
        }
    }

    impl ScalarUDFImpl for TestUdf {
        fn as_any(&self) -> &dyn Any {
            self
        }

        fn name(&self) -> &str {
            "test_udf"
        }

        fn signature(&self) -> &Signature {
            &self.signature
        }

        fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
            unreachable!()
        }

        fn return_field_from_args(&self, args: ReturnFieldArgs) -> Result<FieldRef> {
            assert_eq!(args.arg_fields.len(), 1);
            assert_eq!(args.scalar_arguments.len(), 1);
            assert_eq!(
                args.arg_fields[0].data_type(),
                &args.scalar_arguments[0].unwrap().data_type()
            ); // <-- This assertion fails
            Ok(
                Field::new(self.name(), args.arg_fields[0].data_type().clone(), true)
                    .into(),
            )
        }

        fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
            assert!(matches!(
                args.args[0],
                ColumnarValue::Scalar(ScalarValue::Int16(Some(_)))
            ));
            Ok(args.args[0].clone())
        }
    }

    let ctx = SessionContext::new();
    ctx.register_udf(TestUdf::default().into());

    ctx.sql("select test_udf(1)").await?.collect().await?;

    Ok(())
}

Running this test:

datafusion (main)$ cargo test -p datafusion --test user_defined_integration user_defined::user_defined_scalar_functions::test_scalar_arg_types -- --nocapture --exact
    Finished `test` profile [unoptimized + debuginfo] target(s) in 0.38s
     Running tests/user_defined_integration.rs (/Users/jeffrey/.cargo_target_cache/debug/deps/user_defined_integration-53cd66eeac0051d3)

running 1 test

thread 'user_defined::user_defined_scalar_functions::test_scalar_arg_types' (14600918) panicked at datafusion/core/tests/user_defined/user_defined_scalar_functions.rs:2264:13:
assertion `left == right` failed
  left: Int16
 right: Int64
note: run with `RUST_BACKTRACE=1` environment variable to display a backtrace
test user_defined::user_defined_scalar_functions::test_scalar_arg_types ... FAILED

failures:

failures:
    user_defined::user_defined_scalar_functions::test_scalar_arg_types

test result: FAILED. 0 passed; 1 failed; 0 ignored; 0 measured; 79 filtered out; finished in 0.01s

error: test failed, to rerun pass `-p datafusion --test user_defined_integration`

The field says the argument should be an i16 (Int16), per the signature, but the scalar argument is an i64 (Int64).
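
Until this is resolved, a UDF that inspects a literal in return_field_from_args probably has to align it with the coerced field type itself. Below is a minimal sketch of that workaround. It uses toy stand-ins for DataType and ScalarValue (two integer variants only) so that it compiles without DataFusion, and `aligned_scalar` is a hypothetical helper name, not an existing API. In real code, something like `ScalarValue::cast_to` would presumably replace the toy cast.

```rust
// Toy stand-ins for DataFusion's `DataType` and `ScalarValue`.
// These are NOT the real types, only a model of the mismatch above.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum DataType {
    Int16,
    Int64,
}

#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum ScalarValue {
    Int16(Option<i16>),
    Int64(Option<i64>),
}

impl ScalarValue {
    fn data_type(&self) -> DataType {
        match self {
            ScalarValue::Int16(_) => DataType::Int16,
            ScalarValue::Int64(_) => DataType::Int64,
        }
    }

    /// Toy cast between the two modelled types; narrowing fails on overflow.
    fn cast_to(&self, target: DataType) -> Result<ScalarValue, String> {
        match (*self, target) {
            (ScalarValue::Int16(v), DataType::Int16) => Ok(ScalarValue::Int16(v)),
            (ScalarValue::Int64(v), DataType::Int64) => Ok(ScalarValue::Int64(v)),
            (ScalarValue::Int16(v), DataType::Int64) => {
                Ok(ScalarValue::Int64(v.map(i64::from)))
            }
            (ScalarValue::Int64(v), DataType::Int16) => v
                .map(|x| {
                    i16::try_from(x).map_err(|_| format!("{x} does not fit in Int16"))
                })
                .transpose()
                .map(ScalarValue::Int16),
        }
    }
}

/// Hypothetical helper a UDF could call before looking at a literal:
/// cast the scalar (if any) to the type the coerced field reports.
fn aligned_scalar(
    field_type: DataType,
    scalar: Option<&ScalarValue>,
) -> Result<Option<ScalarValue>, String> {
    scalar.map(|sv| sv.cast_to(field_type)).transpose()
}
```

For `select test_udf(1)`, calling `aligned_scalar(DataType::Int16, Some(&ScalarValue::Int64(Some(1))))` yields the Int16 value the signature promised, and a literal that does not fit in an i16 surfaces as an error instead of a silent mismatch.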

Expected behavior

Either somehow get type coercion to apply on scalar_arguments (not sure if this is possible), or update documentation to make this limitation clear.

Additional context

Related code

Expr::ScalarFunction(ScalarFunction { func, args }) => {
    let (arg_types, fields): (Vec<DataType>, Vec<Arc<Field>>) = args
        .iter()
        .map(|e| e.to_field(schema).map(|(_, f)| f))
        .collect::<Result<Vec<_>>>()?
        .into_iter()
        .map(|f| (f.data_type().clone(), f))
        .unzip();
    // Verify that function is invoked with correct number and type of arguments as defined in `TypeSignature`
    let new_fields =
        fields_with_udf(&fields, func.as_ref()).map_err(|err| {
            plan_datafusion_err!(
                "{} {}",
                match err {
                    DataFusionError::Plan(msg) => msg,
                    err => err.to_string(),
                },
                utils::generate_signature_error_msg(
                    func.name(),
                    func.signature().clone(),
                    &arg_types,
                )
            )
        })?;
    let arguments = args
        .iter()
        .map(|e| match e {
            Expr::Literal(sv, _) => Some(sv),
            _ => None,
        })
        .collect::<Vec<_>>();
    let args = ReturnFieldArgs {
        arg_fields: &new_fields,
        scalar_arguments: &arguments,
    };
    func.return_field_from_args(args)
}
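
In the snippet above, new_fields carries the coerced types while arguments still borrows the original, uncoerced literals. A rough sketch of what applying coercion to the literals might look like follows. It uses toy types rather than DataFusion's, and `cast_scalar`, `coerced_scalar_arguments` and `as_refs` are hypothetical names. Because scalar_arguments holds references, the cast values would need to live in an owned Vec first and be borrowed afterwards, which this sketch models as two steps.

```rust
// Toy stand-ins for DataFusion's `DataType`, `ScalarValue` and `Expr`.
// Only the cases needed to model the planner snippet are included.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum DataType {
    Int16,
    Int64,
}

#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum ScalarValue {
    Int16(Option<i16>),
    Int64(Option<i64>),
}

enum Expr {
    Literal(ScalarValue),
    Column,
}

/// Toy cast between the two modelled types; narrowing fails on overflow.
fn cast_scalar(sv: ScalarValue, target: DataType) -> Result<ScalarValue, String> {
    match (sv, target) {
        (ScalarValue::Int16(v), DataType::Int16) => Ok(ScalarValue::Int16(v)),
        (ScalarValue::Int64(v), DataType::Int64) => Ok(ScalarValue::Int64(v)),
        (ScalarValue::Int16(v), DataType::Int64) => {
            Ok(ScalarValue::Int64(v.map(i64::from)))
        }
        (ScalarValue::Int64(v), DataType::Int16) => v
            .map(|x| i16::try_from(x).map_err(|_| format!("{x} does not fit in Int16")))
            .transpose()
            .map(ScalarValue::Int16),
    }
}

/// Step 1: build owned, coerced scalars. Literals are cast to the coerced
/// type of their position; non-literal arguments become `None`.
fn coerced_scalar_arguments(
    args: &[Expr],
    coerced_types: &[DataType],
) -> Result<Vec<Option<ScalarValue>>, String> {
    args.iter()
        .zip(coerced_types)
        .map(|(expr, ty)| match expr {
            Expr::Literal(sv) => cast_scalar(*sv, *ty).map(Some),
            Expr::Column => Ok(None),
        })
        .collect()
}

/// Step 2: borrow the owned values in the `[Option<&ScalarValue>]` shape
/// that `scalar_arguments` expects.
fn as_refs(owned: &[Option<ScalarValue>]) -> Vec<Option<&ScalarValue>> {
    owned.iter().map(Option::as_ref).collect()
}
```

Whether a failed cast should be a planning error or fall back to the original literal is an open design question that this sketch does not settle; here it simply propagates the error.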

Might also affect udaf/udwf?
