diff --git a/datafusion/sql/src/unparser/ast.rs b/datafusion/sql/src/unparser/ast.rs index 0a76aee2e0660..c8f53cd7baeba 100644 --- a/datafusion/sql/src/unparser/ast.rs +++ b/datafusion/sql/src/unparser/ast.rs @@ -50,6 +50,9 @@ impl QueryBuilder { new.body = Option::Some(value); new } + pub fn take_body(&mut self) -> Option> { + self.body.take() + } pub fn order_by(&mut self, value: Vec) -> &mut Self { let new = self; new.order_by = value; diff --git a/datafusion/sql/src/unparser/plan.rs b/datafusion/sql/src/unparser/plan.rs index de36fb0371ee7..3373220a84761 100644 --- a/datafusion/sql/src/unparser/plan.rs +++ b/datafusion/sql/src/unparser/plan.rs @@ -17,7 +17,7 @@ use datafusion_common::{internal_err, not_impl_err, plan_err, DataFusionError, Result}; use datafusion_expr::{expr::Alias, Expr, JoinConstraint, JoinType, LogicalPlan}; -use sqlparser::ast::{self}; +use sqlparser::ast::{self, SetExpr}; use crate::unparser::utils::unproject_agg_exprs; @@ -78,7 +78,7 @@ impl Unparser<'_> { | LogicalPlan::Limit(_) | LogicalPlan::Statement(_) | LogicalPlan::Values(_) - | LogicalPlan::Distinct(_) => self.select_to_sql(plan), + | LogicalPlan::Distinct(_) => self.select_to_sql_statement(plan), LogicalPlan::Dml(_) => self.dml_to_sql(plan), LogicalPlan::Explain(_) | LogicalPlan::Analyze(_) @@ -92,32 +92,47 @@ impl Unparser<'_> { } } - fn select_to_sql(&self, plan: &LogicalPlan) -> Result { - let mut query_builder = QueryBuilder::default(); + fn select_to_sql_statement(&self, plan: &LogicalPlan) -> Result { + let mut query_builder = Some(QueryBuilder::default()); + + let body = self.select_to_sql_expr(plan, &mut query_builder)?; + + let query = query_builder.unwrap().body(Box::new(body)).build()?; + + Ok(ast::Statement::Query(Box::new(query))) + } + + fn select_to_sql_expr( + &self, + plan: &LogicalPlan, + query: &mut Option, + ) -> Result { let mut select_builder = SelectBuilder::default(); select_builder.push_from(TableWithJoinsBuilder::default()); let mut relation_builder = RelationBuilder::default(); self.select_to_sql_recursively( plan, - &mut query_builder, + query, &mut select_builder, &mut relation_builder, )?; + // If we were able to construct a full body (i.e. UNION ALL), return it + if let Some(body) = query.as_mut().and_then(|q| q.take_body()) { + return Ok(*body); + } + let mut twj = select_builder.pop_from().unwrap(); twj.relation(relation_builder); select_builder.push_from(twj); - let body = ast::SetExpr::Select(Box::new(select_builder.build()?)); - let query = query_builder.body(Box::new(body)).build()?; - - Ok(ast::Statement::Query(Box::new(query))) + Ok(ast::SetExpr::Select(Box::new(select_builder.build()?))) } fn select_to_sql_recursively( &self, plan: &LogicalPlan, - query: &mut QueryBuilder, + query: &mut Option, select: &mut SelectBuilder, relation: &mut RelationBuilder, ) -> Result<()> { @@ -206,6 +221,11 @@ impl Unparser<'_> { } LogicalPlan::Limit(limit) => { if let Some(fetch) = limit.fetch { + let Some(query) = query.as_mut() else { + return internal_err!( + "Limit operator only valid in a statement context." + ); + }; query.limit(Some(ast::Expr::Value(ast::Value::Number( fetch.to_string(), false, @@ -220,7 +240,13 @@ impl Unparser<'_> { ) } LogicalPlan::Sort(sort) => { - query.order_by(self.sort_to_sql(sort.expr.clone())?); + if let Some(query_ref) = query { + query_ref.order_by(self.sort_to_sql(sort.expr.clone())?); + } else { + return internal_err!( + "Sort operator only valid in a statement context." + ); + } self.select_to_sql_recursively( sort.input.as_ref(), @@ -347,8 +373,33 @@ impl Unparser<'_> { Ok(()) } - LogicalPlan::Union(_union) => { - not_impl_err!("Unsupported operator: {plan:?}") + LogicalPlan::Union(union) => { + if union.inputs.len() != 2 { + return not_impl_err!( + "UNION ALL expected 2 inputs, but found {}", + union.inputs.len() + ); + } + + let input_exprs: Vec = union + .inputs + .iter() + .map(|input| self.select_to_sql_expr(input, &mut None)) + .collect::>>()?; + + let union_expr = ast::SetExpr::SetOperation { + op: ast::SetOperator::Union, + set_quantifier: ast::SetQuantifier::All, + left: Box::new(input_exprs[0].clone()), + right: Box::new(input_exprs[1].clone()), + }; + + query + .as_mut() + .expect("to have a query builder") + .body(Box::new(union_expr)); + + Ok(()) } LogicalPlan::Window(_window) => { not_impl_err!("Unsupported operator: {plan:?}") diff --git a/datafusion/sql/tests/sql_integration.rs b/datafusion/sql/tests/sql_integration.rs index af4dac5f3f891..29243a211d30b 100644 --- a/datafusion/sql/tests/sql_integration.rs +++ b/datafusion/sql/tests/sql_integration.rs @@ -4641,6 +4641,14 @@ fn roundtrip_statement() -> Result<()> { group by "Last Name", p.id having count_first_name>5 and count_first_name<10 order by count_first_name, "Last Name""#, + r#"SELECT j1_string as string FROM j1 + UNION ALL + SELECT j2_string as string FROM j2"#, + r#"SELECT j1_string as string FROM j1 + UNION ALL + SELECT j2_string as string FROM j2 + ORDER BY string DESC + LIMIT 10"# ]; // For each test sql string, we transform as follows: