Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions datafusion/sql/src/unparser/ast.rs
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,9 @@ impl QueryBuilder {
new.body = Option::Some(value);
new
}
pub fn take_body(&mut self) -> Option<Box<ast::SetExpr>> {
self.body.take()
}
pub fn order_by(&mut self, value: Vec<ast::OrderByExpr>) -> &mut Self {
let new = self;
new.order_by = value;
Expand Down
77 changes: 64 additions & 13 deletions datafusion/sql/src/unparser/plan.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@

use datafusion_common::{internal_err, not_impl_err, plan_err, DataFusionError, Result};
use datafusion_expr::{expr::Alias, Expr, JoinConstraint, JoinType, LogicalPlan};
use sqlparser::ast::{self};
use sqlparser::ast::{self, SetExpr};

use crate::unparser::utils::unproject_agg_exprs;

Expand Down Expand Up @@ -78,7 +78,7 @@ impl Unparser<'_> {
| LogicalPlan::Limit(_)
| LogicalPlan::Statement(_)
| LogicalPlan::Values(_)
| LogicalPlan::Distinct(_) => self.select_to_sql(plan),
| LogicalPlan::Distinct(_) => self.select_to_sql_statement(plan),
LogicalPlan::Dml(_) => self.dml_to_sql(plan),
LogicalPlan::Explain(_)
| LogicalPlan::Analyze(_)
Expand All @@ -92,32 +92,47 @@ impl Unparser<'_> {
}
}

fn select_to_sql(&self, plan: &LogicalPlan) -> Result<ast::Statement> {
let mut query_builder = QueryBuilder::default();
fn select_to_sql_statement(&self, plan: &LogicalPlan) -> Result<ast::Statement> {
let mut query_builder = Some(QueryBuilder::default());

let body = self.select_to_sql_expr(plan, &mut query_builder)?;

let query = query_builder.unwrap().body(Box::new(body)).build()?;

Ok(ast::Statement::Query(Box::new(query)))
}

fn select_to_sql_expr(
&self,
plan: &LogicalPlan,
query: &mut Option<QueryBuilder>,
) -> Result<ast::SetExpr> {
let mut select_builder = SelectBuilder::default();
select_builder.push_from(TableWithJoinsBuilder::default());
let mut relation_builder = RelationBuilder::default();
self.select_to_sql_recursively(
plan,
&mut query_builder,
query,
&mut select_builder,
&mut relation_builder,
)?;

// If we were able to construct a full body (i.e. UNION ALL), return it
if let Some(body) = query.as_mut().and_then(|q| q.take_body()) {
return Ok(*body);
}

let mut twj = select_builder.pop_from().unwrap();
twj.relation(relation_builder);
select_builder.push_from(twj);

let body = ast::SetExpr::Select(Box::new(select_builder.build()?));
let query = query_builder.body(Box::new(body)).build()?;

Ok(ast::Statement::Query(Box::new(query)))
Ok(ast::SetExpr::Select(Box::new(select_builder.build()?)))
}

fn select_to_sql_recursively(
&self,
plan: &LogicalPlan,
query: &mut QueryBuilder,
query: &mut Option<QueryBuilder>,
select: &mut SelectBuilder,
relation: &mut RelationBuilder,
) -> Result<()> {
Expand Down Expand Up @@ -206,6 +221,11 @@ impl Unparser<'_> {
}
LogicalPlan::Limit(limit) => {
if let Some(fetch) = limit.fetch {
let Some(query) = query.as_mut() else {
return internal_err!(
"Limit operator only valid in a statement context."
);
};
query.limit(Some(ast::Expr::Value(ast::Value::Number(
fetch.to_string(),
false,
Expand All @@ -220,7 +240,13 @@ impl Unparser<'_> {
)
}
LogicalPlan::Sort(sort) => {
query.order_by(self.sort_to_sql(sort.expr.clone())?);
if let Some(query_ref) = query {
query_ref.order_by(self.sort_to_sql(sort.expr.clone())?);
} else {
return internal_err!(
"Sort operator only valid in a statement context."
);
}

self.select_to_sql_recursively(
sort.input.as_ref(),
Expand Down Expand Up @@ -347,8 +373,33 @@ impl Unparser<'_> {

Ok(())
}
LogicalPlan::Union(_union) => {
not_impl_err!("Unsupported operator: {plan:?}")
LogicalPlan::Union(union) => {
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

if union.inputs.len() != 2 {
return not_impl_err!(
"UNION ALL expected 2 inputs, but found {}",
union.inputs.len()
);
}

let input_exprs: Vec<SetExpr> = union
.inputs
.iter()
.map(|input| self.select_to_sql_expr(input, &mut None))
.collect::<Result<Vec<_>>>()?;

let union_expr = ast::SetExpr::SetOperation {
op: ast::SetOperator::Union,
set_quantifier: ast::SetQuantifier::All,
left: Box::new(input_exprs[0].clone()),
right: Box::new(input_exprs[1].clone()),
};

query
.as_mut()
.expect("to have a query builder")
.body(Box::new(union_expr));

Ok(())
}
LogicalPlan::Window(_window) => {
not_impl_err!("Unsupported operator: {plan:?}")
Expand Down
8 changes: 8 additions & 0 deletions datafusion/sql/tests/sql_integration.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4641,6 +4641,14 @@ fn roundtrip_statement() -> Result<()> {
group by "Last Name", p.id
having count_first_name>5 and count_first_name<10
order by count_first_name, "Last Name""#,
r#"SELECT j1_string as string FROM j1
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Perhaps we should also add a test (as a follow on PR) for UNION (not UNION ALL)

UNION ALL
SELECT j2_string as string FROM j2"#,
r#"SELECT j1_string as string FROM j1
UNION ALL
SELECT j2_string as string FROM j2
ORDER BY string DESC
LIMIT 10"#
];

// For each test sql string, we transform as follows:
Expand Down