diff --git a/queryscript/Cargo.toml b/queryscript/Cargo.toml
index 2c9e300..b359b13 100644
--- a/queryscript/Cargo.toml
+++ b/queryscript/Cargo.toml
@@ -46,8 +46,12 @@ lsp = [
     "tower-lsp",
 ]
 backtraces = ["snafu/backtraces"]
+
+# Engines
 duckdb-bundled = ["duckdb/bundled"]
 
+# mysql = ["mysql_async"]
+
 [dependencies]
 # This version is synchronized with duckdb-rs
 arrow = { version = "28", default-features = false, features = [
@@ -89,12 +93,14 @@ url = "2.3.1"
 
 # -- ENGINES ---
 
-# TODO: We should put each database dependency runtime behind a feature flag
 # DuckDB.
 duckdb = { version = "0.6.1" }
 # These are duckdb dependencies that we access directly
 hashlink = { version = "0.8" }
 
+# MySQL (TODO)
+# mysql_async = { version = "0.31.3", optional = true }
+
 # -- CLI ---
 
 clap = { version = "4.0", features = ["derive"], optional = true }
diff --git a/queryscript/src/runtime/error.rs b/queryscript/src/runtime/error.rs
index a30d154..391e815 100644
--- a/queryscript/src/runtime/error.rs
+++ b/queryscript/src/runtime/error.rs
@@ -51,6 +51,13 @@ pub enum RuntimeError {
         backtrace: Option<Backtrace>,
     },
 
+    #[cfg(feature = "mysql")]
+    #[snafu(context(false))]
+    MySQLError {
+        source: mysql_async::Error,
+        backtrace: Option<Backtrace>,
+    },
+
     #[snafu(context(false))]
     IOError {
         source: std::io::Error,
diff --git a/queryscript/src/runtime/mysql/mod.rs b/queryscript/src/runtime/mysql/mod.rs
new file mode 100644
index 0000000..c758c60
--- /dev/null
+++ b/queryscript/src/runtime/mysql/mod.rs
@@ -0,0 +1,159 @@
+use std::{
+    cell::RefCell,
+    collections::{HashMap, HashSet},
+    sync::Arc,
+};
+
+use mysql_async::{prelude::*, Conn, Params, Pool, Row};
+use sqlparser::ast as sqlast;
+
+use crate::{
+    compile::schema::Ident,
+    types::{record::VecRow, Relation, Type, Value},
+};
+
+use super::{
+    error::rt_unimplemented, normalize::Normalizer, Result, SQLEngine, SQLEnginePool,
+    SQLEngineType, SQLParam,
+};
+
+mod value;
+
+/// Holds a `mysql_async` pool together with one connection checked out from it.
+///
+/// Queries are executed on that single connection.
+#[derive(Debug)]
+pub struct MySQLEngine {
+    pool: Pool,
+    conn: Conn,
+}
+
+#[async_trait::async_trait]
+impl SQLEnginePool for MySQLEngine {
+    /// Builds a pool from `url` and checks out the connection stored in the engine.
+    // NOTE(review): the type arguments on this signature were lost when the patch was
+    // extracted; `ConnectionString` is an assumption -- confirm against the trait definition.
+    async fn new(url: Arc<crate::compile::ConnectionString>) -> Result<Box<dyn SQLEngine>> {
+        let pool = Pool::from_url(url.get_url().as_str())?;
+        let conn = pool.get_conn().await?;
+        Ok(Box::new(MySQLEngine { pool, conn }))
+    }
+}
+
+#[async_trait::async_trait]
+impl SQLEngine for MySQLEngine {
+    /// Runs `query` on the held connection and returns the resulting rows as a relation.
+    ///
+    /// Only scalar parameters are supported: a relation or function parameter makes the
+    /// call return the `rt_unimplemented!` error before anything is sent to the server.
+    async fn eval(
+        &mut self,
+        query: &sqlast::Statement,
+        params: HashMap<Ident, SQLParam>,
+    ) -> Result<Arc<dyn Relation>> {
+        let mut scalar_params = Vec::with_capacity(params.len());
+
+        for (key, param) in params.iter() {
+            match &param.value {
+                Value::Relation(_) => {
+                    return rt_unimplemented!("Relation parameters in MySQL");
+                }
+                Value::Fn(_) => {
+                    return rt_unimplemented!("Function parameters in MySQL");
+                }
+                _ => {
+                    scalar_params.push(key.clone());
+                }
+            }
+        }
+
+        let normalizer = MySQLNormalizer::new(&scalar_params);
+        let query = normalizer.normalize(&query).as_result()?;
+
+        let query_string = query.to_string();
+        let mysql_params = if params.is_empty() {
+            Params::Empty
+        } else {
+            // The query text refers to a parameter as `:name`, but the driver looks the
+            // value up by the bare `name`, so the map is keyed without the leading colon.
+            Params::Named(
+                params
+                    .into_iter()
+                    .map(|(k, v)| (k.as_str().as_bytes().to_vec(), v.value.into()))
+                    .collect(),
+            )
+        };
+
+        let result: Vec<value::MySQLRow> = self.conn.exec(query_string, mysql_params).await?;
+
+        // No column metadata is consulted here: rows and relation share an empty schema.
+        let empty_schema = Arc::new(Vec::new());
+        let relation = Arc::new(value::MySQLRelation {
+            rows: result
+                .into_iter()
+                .map(|r| VecRow::new(empty_schema.clone(), r.0))
+                .collect(),
+            schema: empty_schema.clone(),
+        });
+        Ok(relation)
+    }
+
+    /// Not implemented for MySQL yet: returns the `rt_unimplemented!` error instead of panicking.
+    async fn load(
+        &mut self,
+        table: &sqlast::ObjectName,
+        value: Value,
+        type_: Type,
+        temporary: bool,
+    ) -> Result<()> {
+        rt_unimplemented!("Loading tables in MySQL")
+    }
+
+    async fn create(&mut self) -> Result<()> {
+        // NOTE: This should probably be a method on the pool, not the engine,
+        // since the connection assumes that the database exists.
+        Ok(())
+    }
+
+    /// Ideally, this gets generalized and we use information schema tables. However, there's
+    /// no standard way to tell what database we're currently in. We should generalize this
+    /// function eventually.
+    ///
+    /// Not implemented for MySQL yet: returns the `rt_unimplemented!` error instead of panicking.
+    async fn table_exists(&mut self, name: &sqlast::ObjectName) -> Result<bool> {
+        rt_unimplemented!("Checking whether a table exists in MySQL")
+    }
+
+    fn engine_type(&self) -> SQLEngineType {
+        SQLEngineType::MySQL
+    }
+}
+
+/// Holds the `:name` placeholder text for each scalar parameter of a query.
+pub struct MySQLNormalizer {
+    params: HashMap<String, String>,
+}
+
+impl MySQLNormalizer {
+    /// Maps every scalar parameter name to its `:name` placeholder.
+    pub fn new(scalar_params: &[Ident]) -> MySQLNormalizer {
+        let params = scalar_params
+            .iter()
+            .map(|s| (s.to_string(), format!(":{}", s)))
+            .collect();
+
+        MySQLNormalizer { params }
+    }
+}
+
+impl Normalizer for MySQLNormalizer {
+    /// Identifiers are quoted with backticks.
+    fn quote_style(&self) -> Option<char> {
+        Some('`')
+    }
+
+    /// Placeholder text for `key`, or `None` when `key` is not a known scalar parameter.
+    fn param(&self, key: &str) -> Option<&str> {
+        self.params.get(key).map(|s| s.as_str())
+    }
+}
diff --git a/queryscript/src/runtime/mysql/value.rs b/queryscript/src/runtime/mysql/value.rs
new file mode 100644
index 0000000..9b9c221
--- /dev/null
+++ b/queryscript/src/runtime/mysql/value.rs
@@ -0,0 +1,129 @@
+use std::any::Any;
+use std::sync::Arc;
+
+use crate::types::{
+    record::VecRow, value::ArrowRecordBatch, Field, Record, RecordBatch, Relation, Value,
+};
+use mysql_async::{
+    prelude::{ConvIr, FromRow, FromValue},
+    FromRowError, FromValueError, Row as MRow, Value as MValue,
+};
+
+/// Conversion of a QueryScript value into a MySQL parameter value (not implemented yet).
+///
+/// Written as `From` so that `Into<MValue>` comes from the standard blanket impl.
+impl From<Value> for MValue {
+    fn from(_value: Value) -> MValue {
+        todo!()
+    }
+}
+
+/// Intermediate state of a MySQL -> `Value` conversion: keeps the driver value so the
+/// conversion can be rolled back.
+pub struct ValueIr {
+    source_value: MValue,
+    target_value: Value,
+}
+
+impl ConvIr<Value> for ValueIr {
+    /// Converts `v` eagerly; date and time values are rejected with `FromValueError`.
+    fn new(v: MValue) -> Result<ValueIr, FromValueError> {
+        let target_value = match &v {
+            MValue::NULL => Value::Null,
+            MValue::Bytes(bytes) => Value::Binary(bytes.clone()),
+            MValue::Int(i) => Value::Int64(*i),
+            MValue::UInt(u) => Value::UInt64(*u),
+            MValue::Float(f) => Value::Float32(*f),
+            MValue::Double(d) => Value::Float64(*d),
+            MValue::Date(..) | MValue::Time(..) => return Err(FromValueError(v)),
+        };
+
+        Ok(ValueIr {
+            source_value: v,
+            target_value,
+        })
+    }
+    fn commit(self) -> Value {
+        self.target_value
+    }
+    fn rollback(self) -> MValue {
+        self.source_value
+    }
+}
+
+impl FromValue for Value {
+    type Intermediate = ValueIr;
+}
+
+/// One result row, converted column by column into `Value`s.
+#[derive(Debug, Clone)]
+pub struct MySQLRow(pub Vec<Value>);
+
+impl FromRow for MySQLRow {
+    /// Takes every column out of `row`; fails if any column cannot be converted.
+    fn from_row_opt(mut row: MRow) -> Result<Self, FromRowError> {
+        let converted = (0..row.len())
+            // Each index is in range and taken exactly once, so `take_opt` yields a value.
+            .map(|i| row.take_opt(i).expect("column index is in range and not yet taken"))
+            .collect::<Result<Vec<Value>, _>>();
+
+        match converted {
+            Ok(values) => Ok(MySQLRow(values)),
+            Err(_) => Err(FromRowError(row)),
+        }
+    }
+}
+
+/// In-memory relation holding the rows of one MySQL result set as a single batch.
+#[derive(Debug)]
+pub struct MySQLRelation {
+    pub rows: Vec<Arc<dyn Record>>,
+    pub schema: Arc<Vec<Field>>,
+}
+
+impl Relation for MySQLRelation {
+    fn schema(&self) -> Vec<Field> {
+        self.schema.as_ref().clone()
+    }
+
+    fn as_any(&self) -> &dyn Any {
+        self
+    }
+
+    fn num_batches(&self) -> usize {
+        1
+    }
+    fn batch(&self, _index: usize) -> &dyn RecordBatch {
+        self
+    }
+
+    /// Re-labels the same rows with `target_schema`; the row values are not converted.
+    fn try_cast(
+        &self,
+        target_schema: &Vec<Field>,
+    ) -> crate::types::error::Result<Arc<dyn Relation>> {
+        Ok(Arc::new(MySQLRelation {
+            rows: self.rows.clone(),
+            schema: Arc::new(target_schema.clone()),
+        }))
+    }
+}
+
+impl RecordBatch for MySQLRelation {
+    fn schema(&self) -> Vec<Field> {
+        self.schema.as_ref().clone()
+    }
+
+    fn as_any(&self) -> &dyn Any {
+        self
+    }
+
+    fn records(&self) -> Vec<Arc<dyn Record>> {
+        self.rows.clone()
+    }
+
+    /// Arrow export is not implemented for MySQL results; calling this panics.
+    fn as_arrow_recordbatch(&self) -> &ArrowRecordBatch {
+        todo!()
+    }
+}