Support table inputs for user defined table functions #18121

@alamb

Is your feature request related to a problem or challenge?

(Based on a Discord thread with @timsaucer, @pepijnve, and @bmmeijers.)

DataFusion has a TableFunction API:
https://docs.rs/datafusion/latest/datafusion/catalog/trait.TableFunctionImpl.html
https://docs.rs/datafusion/latest/datafusion/catalog/struct.TableFunction.html

These functions take one or more arguments and produce a table as output. The API is used to implement functions such as generate_series, which works like this:

> select * from generate_series(1,2);
+-------+
| value |
+-------+
| 1     |
| 2     |
+-------+
2 row(s) fetched.
Elapsed 0.004 seconds.

However, as @bmmeijers reports, the TableFunction API does not allow access to data in columns from another table (e.g. via LATERAL joins).

See commented lines 117--125 in https://pastebin.com/td76M8Fj

(alternate copy)

use datafusion::arrow::array::{ArrayRef, Int64Array, StringArray};
use datafusion::arrow::datatypes::{DataType, Field, Schema};
use datafusion::arrow::record_batch::RecordBatch;
use datafusion::catalog::{TableFunctionImpl, TableProvider};
use datafusion::common::{Result, ScalarValue, plan_err};
use datafusion::datasource::memory::MemTable;
use datafusion::logical_expr::Expr;
use datafusion::prelude::SessionContext;
use std::sync::Arc;
 
#[derive(Debug)]
pub struct TransformFunction {}
 
impl TableFunctionImpl for TransformFunction {
    fn call(&self, exprs: &[Expr]) -> Result<Arc<dyn TableProvider>> {
        if exprs.len() != 3 {
            return plan_err!(
                "Expected exactly three arguments: a, b, c, but got {}",
                exprs.len()
            );
        }
 
        println!("{:?}", exprs);
 
        let extract_int64 = |expr: &Expr, arg_name: &str| -> Result<i64> {
            match expr {
                Expr::Literal(ScalarValue::Int64(Some(val)), _) => Ok(*val),
                // Expr::Column()
                _ => plan_err!("Argument {} must be an Int64 literal", arg_name),
            }
        };
 
        let a = extract_int64(&exprs[0], "a")?;
        let b = extract_int64(&exprs[1], "b")?;
        let c = extract_int64(&exprs[2], "c")?;
 
        // Compute output columns: x = a + b, y = b * c
        let x = a + b;
        let y = b * c;
 
        // Define output schema
        let schema = Arc::new(Schema::new(vec![
            Field::new("x", DataType::Int64, false),
            Field::new("y", DataType::Int64, false),
        ]));
 
        // Create output arrays
        let x_array = Arc::new(Int64Array::from(vec![x])) as ArrayRef;
        let y_array = Arc::new(Int64Array::from(vec![y])) as ArrayRef;
 
        // Create a single RecordBatch
        let batch = RecordBatch::try_new(schema.clone(), vec![x_array, y_array])?;
 
        // Wrap in a MemTable
        let provider = MemTable::try_new(schema, vec![vec![batch]])?;
 
        Ok(Arc::new(provider))
    }
}
 
// --- Usage Example ---
 
/// Registers the TransformFunction as a table UDF in the SessionContext.
fn register_udtf(ctx: &mut SessionContext) -> Result<()> {
    // 1. Create the implementation instance
    let udtf = Arc::new(TransformFunction {});
    ctx.register_udtf("my_transform", udtf);
 
    Ok(())
}
 
/// Creates a small in-memory table for demonstration.
fn create_dummy_table(ctx: &mut SessionContext) -> Result<()> {
    let schema = Arc::new(Schema::new(vec![
        Field::new("id", DataType::Utf8, false),
        Field::new("a", DataType::Int64, false),
        Field::new("b", DataType::Int64, false),
        Field::new("c", DataType::Int64, false),
    ]));
 
    let batch = RecordBatch::try_new(
        schema.clone(),
        vec![
            Arc::new(StringArray::from(vec!["r1", "r2"])) as ArrayRef,
            Arc::new(Int64Array::from(vec![10, 20])) as ArrayRef,
            Arc::new(Int64Array::from(vec![5, 6])) as ArrayRef,
            Arc::new(Int64Array::from(vec![2, 3])) as ArrayRef,
        ],
    )?;
 
    let provider = MemTable::try_new(schema, vec![vec![batch]])?;
    ctx.register_table("my_table", Arc::new(provider))?;
    Ok(())
}
 
#[tokio::main]
async fn main() -> Result<()> {
    let mut ctx = SessionContext::new();
 
    // 1. Register the custom UDTF
    register_udtf(&mut ctx)?;
 
    // 2. Register a dummy table
    create_dummy_table(&mut ctx)?;
 
    // 3. Define and execute the SQL query
    let sql = r#"
        SELECT 
            t1.id, 
            t2.x AS a_plus_b, 
            t2.y AS b_times_c
        FROM 
            my_table AS t1,
            LATERAL my_transform(1, 2, 3) AS t2(x, y)
    "#;
 
    // let sql = r#"
    //     SELECT 
    //         t1.id, 
    //         t2.x AS a_plus_b, 
    //         t2.y AS b_times_c
    //     FROM 
    //         my_table AS t1,
    //         LATERAL my_transform(t1.a, t1.b, t1.c) AS t2(x, y)
    // "#;
 
 
    println!("Executing SQL:\n{}", sql);
 
    let df = ctx.sql(sql).await?;
 
    println!("\nQuery Result:");
    df.show().await?;
 
    Ok(())
}

Specifically, call(&self, args: &[Expr]) receives its arguments as expressions. For example, this query works because the arguments to my_transform are the scalar values 1, 2, 3:

        SELECT 
            t1.id, 
            t2.x AS a_plus_b, 
            t2.y AS b_times_c
        FROM 
            my_table AS t1,
            LATERAL my_transform(1, 2, 3) AS t2(x, y)

However, this doesn't work:

         SELECT 
             t1.id, 
             t2.x AS a_plus_b, 
             t2.y AS b_times_c
         FROM 
             my_table AS t1,
             LATERAL my_transform(t1.a, t1.b, t1.c) AS t2(x, y)

Instead, it fails with an error like: This feature is not implemented: Physical plan does not support logical expression OuterReferenceColumn(

Describe the solution you'd like

I would like to be able to access the values from another table from a TableFunction.
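
To make the requested semantics concrete, here is a minimal sketch in plain Rust (no DataFusion APIs) of what LATERAL my_transform(t1.a, t1.b, t1.c) would be expected to return for the my_table data in the example above: the function is evaluated once per outer row, and its output is paired with that row. The Row struct and the sample_rows and lateral_transform helpers are hypothetical names used only for illustration; this is not a proposed API.

```rust
/// One row of `my_table` from the example above (hypothetical struct, illustration only).
#[derive(Debug, Clone, PartialEq)]
struct Row {
    id: String,
    a: i64,
    b: i64,
    c: i64,
}

/// The body of `my_transform`: x = a + b, y = b * c, returned as a one-row table.
fn my_transform(a: i64, b: i64, c: i64) -> Vec<(i64, i64)> {
    vec![(a + b, b * c)]
}

/// The two rows registered by `create_dummy_table` in the example above.
fn sample_rows() -> Vec<Row> {
    vec![
        Row { id: "r1".to_string(), a: 10, b: 5, c: 2 },
        Row { id: "r2".to_string(), a: 20, b: 6, c: 3 },
    ]
}

/// Emulates `FROM my_table AS t1, LATERAL my_transform(t1.a, t1.b, t1.c) AS t2(x, y)`:
/// the table function is evaluated once per outer row, and every output row
/// is paired with the outer row that produced it.
fn lateral_transform(rows: &[Row]) -> Vec<(String, i64, i64)> {
    rows.iter()
        .flat_map(|r| {
            my_transform(r.a, r.b, r.c)
                .into_iter()
                .map(move |(x, y)| (r.id.clone(), x, y))
        })
        .collect()
}

fn main() {
    for (id, a_plus_b, b_times_c) in lateral_transform(&sample_rows()) {
        println!("{id} {a_plus_b} {b_times_c}");
    }
}
```

For the sample data this yields (r1, 15, 10) and (r2, 26, 18), which is the result the failing query above would ideally produce.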

Here is a simple example of LATERAL and the range function that works today:

> SELECT *
  FROM range(3) t(i), LATERAL range(3) t2(j);
+---+---+
| i | j |
+---+---+
| 0 | 0 |
| 0 | 1 |
| 0 | 2 |
| 1 | 0 |
| 1 | 1 |
| 1 | 2 |
| 2 | 0 |
| 2 | 1 |
| 2 | 2 |
+---+---+
9 row(s) fetched.
Elapsed 0.002 seconds.

However, you can't refer to a column of the preceding relation from inside the LATERAL subquery:

> SELECT *
  FROM range(3) t(i), LATERAL (SELECT i + 1) t2(j);
This feature is not implemented: Physical plan does not support logical expression OuterReferenceColumn(Field { name: "i", data_type: Int64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }, Column { relation: Some(Bare { table: "t" }), name: "i" })

I want some way to refer to other inputs in table expressions. Here is how it works in DuckDB:

D SELECT *
  FROM range(3) t(i), LATERAL (SELECT i + 1) t2(j);
┌───────┬───────┐
│   i   │   j   │
│ int64 │ int64 │
├───────┼───────┤
│     0 │     1 │
│     1 │     2 │
│     2 │     3 │
└───────┴───────┘

Describe alternatives you've considered

Here is the DuckDB explain plan:

D explain SELECT *
  FROM range(3) t(i), LATERAL (SELECT i + 1) t2(j);

┌─────────────────────────────┐
│┌───────────────────────────┐│
││       Physical Plan       ││
│└───────────────────────────┘│
└─────────────────────────────┘
┌───────────────────────────┐
│      LEFT_DELIM_JOIN      │
│    ────────────────────   │
│      Join Type: INNER     │
│                           │
│        Conditions:        ├──────────────┬─────────────────────────────────────────────────────────┐
│  i IS NOT DISTINCT FROM i │              │                                                         │
│                           │              │                                                         │
│          ~3 rows          │              │                                                         │
└─────────────┬─────────────┘              │                                                         │
┌─────────────┴─────────────┐┌─────────────┴─────────────┐                             ┌─────────────┴─────────────┐
│           RANGE           ││         HASH_JOIN         │                             │       HASH_GROUP_BY       │
│    ────────────────────   ││    ────────────────────   │                             │    ────────────────────   │
│      Function: RANGE      ││      Join Type: INNER     │                             │         Groups: #0        │
│                           ││                           │                             │                           │
│                           ││        Conditions:        ├──────────────┐              │                           │
│                           ││  i IS NOT DISTINCT FROM i │              │              │                           │
│                           ││                           │              │              │                           │
│          ~3 rows          ││          ~3 rows          │              │              │           ~1 row          │
└───────────────────────────┘└─────────────┬─────────────┘              │              └───────────────────────────┘
                             ┌─────────────┴─────────────┐┌─────────────┴─────────────┐
                             │      COLUMN_DATA_SCAN     ││         PROJECTION        │
                             │    ────────────────────   ││    ────────────────────   │
                             │                           ││          (i + 1)          │
                             │                           ││             i             │
                             │                           ││                           │
                             │          ~3 rows          ││           ~1 row          │
                             └───────────────────────────┘└─────────────┬─────────────┘
                                                          ┌─────────────┴─────────────┐
                                                          │         DELIM_SCAN        │
                                                          │    ────────────────────   │
                                                          │       Delim Index: 1      │
                                                          │                           │
                                                          │           ~1 row          │
                                                          └───────────────────────────┘
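
One way to read the plan above is as a three-step decorrelation: collect the distinct values of the correlated column (HASH_GROUP_BY), evaluate the subquery once per distinct value (DELIM_SCAN feeding PROJECTION), and join the results back to the outer rows on that column. The following plain Rust sketch illustrates that reading only; it is not DuckDB's implementation, the function names are hypothetical, and it uses non-null integers, so the null-safe IS NOT DISTINCT FROM comparison reduces to plain equality.

```rust
use std::collections::{BTreeSet, HashMap};

/// The correlated subquery body from the example: `SELECT i + 1`.
fn subquery(i: i64) -> Vec<i64> {
    vec![i + 1]
}

/// Emulates `FROM range(n) t(i), LATERAL (SELECT i + 1) t2(j)` without
/// re-running the subquery for duplicate outer values.
fn lateral_via_dedup(outer: &[i64]) -> Vec<(i64, i64)> {
    // Step 1: distinct values of the correlated column.
    let distinct: BTreeSet<i64> = outer.iter().copied().collect();

    // Step 2: evaluate the subquery once per distinct value.
    let per_value: HashMap<i64, Vec<i64>> = distinct
        .into_iter()
        .map(|i| (i, subquery(i)))
        .collect();

    // Step 3: join the subquery results back to every outer row by equality.
    outer
        .iter()
        .flat_map(|i| per_value[i].iter().map(move |j| (*i, *j)))
        .collect()
}

fn main() {
    // range(3) produces 0, 1, 2.
    for (i, j) in lateral_via_dedup(&[0, 1, 2]) {
        println!("{i} {j}");
    }
}
```

For the input 0, 1, 2 the sketch returns the pairs (0, 1), (1, 2), (2, 3), matching the DuckDB output shown earlier. The deduplication step only pays off when the outer side contains repeated values.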

Additional context

Here are some good documents about user defined table functions in the Snowflake documentation:

Here are the docs about LATERAL join from DuckDB:
