Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 7 additions & 1 deletion queryscript/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -46,8 +46,12 @@ lsp = [
"tower-lsp",
]
backtraces = ["snafu/backtraces"]

# Engines
duckdb-bundled = ["duckdb/bundled"]

# mysql = ["mysql_async"]

[dependencies]
# This version is synchronized with duckdb-rs
arrow = { version = "28", default-features = false, features = [
Expand Down Expand Up @@ -89,12 +93,14 @@ url = "2.3.1"


# -- ENGINES ---
# TODO: We should put each database dependency runtime behind a feature flag
# DuckDB.
duckdb = { version = "0.6.1" }
# These are duckdb dependencies that we access directly
hashlink = { version = "0.8" }

# MySQL (TODO: enable together with the `mysql` feature once the engine is complete)
# mysql_async = { version = "0.31.3", optional = true }


# -- CLI ---
clap = { version = "4.0", features = ["derive"], optional = true }
Expand Down
7 changes: 7 additions & 0 deletions queryscript/src/runtime/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,13 @@ pub enum RuntimeError {
backtrace: Option<Backtrace>,
},

#[cfg(feature = "mysql")]
#[snafu(context(false))]
MySQLError {
source: mysql_async::Error,
backtrace: Option<Backtrace>,
},

#[snafu(context(false))]
IOError {
source: std::io::Error,
Expand Down
150 changes: 150 additions & 0 deletions queryscript/src/runtime/mysql/mod.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,150 @@
use std::{
cell::RefCell,
collections::{HashMap, HashSet},
sync::Arc,
};

use mysql_async::{prelude::*, Conn, Params, Pool, Row};
use sqlparser::ast as sqlast;

use crate::{
compile::schema::Ident,
types::{record::VecRow, Relation, Type, Value},
};

use super::{
error::rt_unimplemented, normalize::Normalizer, Result, SQLEngine, SQLEnginePool,
SQLEngineType, SQLParam,
};

mod value;

/// SQL engine backed by a MySQL server, accessed through `mysql_async`.
///
/// Holds both the pool that was built from the connection URL and one
/// connection checked out of that pool; queries are executed on `conn`.
#[derive(Debug)]
pub struct MySQLEngine {
// Pool created from the connection URL in `SQLEnginePool::new`.
pool: Pool,
// Connection obtained from `pool`; used by `eval` to execute statements.
conn: Conn,
}

#[async_trait::async_trait]
impl SQLEnginePool for MySQLEngine {
    /// Builds a MySQL engine for the given connection string.
    ///
    /// A pool is created from the URL and a single connection is checked out
    /// of it right away; both are stored in the returned engine.
    ///
    /// # Errors
    /// Propagates any error from parsing the URL or from acquiring the
    /// connection.
    async fn new(url: Arc<crate::compile::ConnectionString>) -> Result<Box<dyn SQLEngine>> {
        let connection_url = url.get_url();
        let pool = Pool::from_url(connection_url.as_str())?;
        let conn = pool.get_conn().await?;

        let engine = MySQLEngine { pool, conn };
        Ok(Box::new(engine))
    }
}

#[async_trait::async_trait]
impl SQLEngine for MySQLEngine {
    /// Normalizes `query` for MySQL, binds `params` as named parameters and
    /// executes it on the engine's connection.
    ///
    /// Only scalar parameters are supported: a relation or function parameter
    /// makes the call return an "unimplemented" error before anything is sent
    /// to the server.
    ///
    /// NOTE(review): the returned relation (and every row in it) is built with
    /// an empty schema; callers appear to be expected to attach one later via
    /// `try_cast` — confirm.
    ///
    /// # Errors
    /// Returns an error if a parameter is a relation or a function, if
    /// normalization fails, or if executing the statement fails.
    async fn eval(
        &mut self,
        query: &sqlast::Statement,
        params: HashMap<Ident, SQLParam>,
    ) -> Result<Arc<dyn Relation>> {
        let mut scalar_params = Vec::new();

        for (key, param) in &params {
            match &param.value {
                Value::Relation(_) => {
                    return rt_unimplemented!("Relation parameters in MySQL");
                }
                Value::Fn(_) => {
                    return rt_unimplemented!("Function parameters in MySQL");
                }
                _ => {
                    scalar_params.push(key.clone());
                }
            }
        }

        let normalizer = MySQLNormalizer::new(&scalar_params);
        let query = normalizer.normalize(&query).as_result()?;

        let query_string = query.to_string();
        let mysql_params = if params.is_empty() {
            Params::Empty
        } else {
            Params::Named(
                params
                    .into_iter()
                    .map(|(name, param)| {
                        // Every key of `params` was pushed into `scalar_params`
                        // above (non-scalar parameters return early), so the
                        // normalizer always knows it.
                        let placeholder = normalizer
                            .params
                            .get(name.as_str())
                            .expect("every parameter was registered with the normalizer");
                        // The placeholder is the text spliced into the query
                        // (`:name`), but mysql_async keys named parameters by
                        // the bare name, without the leading colon.
                        let bare_name = placeholder
                            .strip_prefix(':')
                            .unwrap_or(placeholder.as_str());
                        (bare_name.as_bytes().to_vec(), param.value.into())
                    })
                    .collect(),
            )
        };

        let result: Vec<value::MySQLRow> = self.conn.exec(query_string, mysql_params).await?;

        let empty_schema = Arc::new(Vec::new());
        let relation = Arc::new(value::MySQLRelation {
            rows: result
                .into_iter()
                .map(|r| VecRow::new(empty_schema.clone(), r.0))
                .collect(),
            schema: empty_schema.clone(),
        });
        Ok(relation)
    }

    /// Loads `value` into `table`.
    ///
    /// # Panics
    /// Not implemented yet for MySQL; always panics via `todo!()`.
    async fn load(
        &mut self,
        table: &sqlast::ObjectName,
        value: Value,
        type_: Type,
        temporary: bool,
    ) -> Result<()> {
        todo!()
    }

    /// No-op for MySQL; always returns `Ok(())`.
    async fn create(&mut self) -> Result<()> {
        // NOTE: This should probably be a method on the pool, not the engine,
        // since the connection assumes that the database exists.
        Ok(())
    }

    /// Ideally, this gets generalized and we use information schema tables. However, there's
    /// no standard way to tell what database we're currently in. We should generalize this
    /// function eventually.
    ///
    /// # Panics
    /// Not implemented yet for MySQL; always panics via `todo!()`.
    async fn table_exists(&mut self, name: &sqlast::ObjectName) -> Result<bool> {
        todo!()
    }

    /// Identifies this engine as MySQL.
    fn engine_type(&self) -> SQLEngineType {
        SQLEngineType::MySQL
    }
}

/// Query normalizer for MySQL.
///
/// Maps each scalar parameter name to the named placeholder (`:name`) that
/// replaces it in the generated SQL.
pub struct MySQLNormalizer {
    /// Parameter name -> placeholder text (`:name`).
    params: HashMap<String, String>,
}

impl MySQLNormalizer {
    /// Creates a normalizer that knows about the given scalar parameters.
    ///
    /// Each identifier is registered under its `Display` form and mapped to
    /// that same text prefixed with a colon.
    pub fn new(scalar_params: &[Ident]) -> MySQLNormalizer {
        let params: HashMap<String, String> = scalar_params
            .iter()
            .map(|ident| (ident.to_string(), format!(":{}", ident)))
            .collect();

        MySQLNormalizer { params }
    }
}

impl Normalizer for MySQLNormalizer {
    /// Identifiers are quoted with backticks.
    fn quote_style(&self) -> Option<char> {
        let backtick = '`';
        Some(backtick)
    }

    /// Looks up the placeholder registered for `key`, if there is one.
    fn param(&self, key: &str) -> Option<&str> {
        let placeholder = self.params.get(key)?;
        Some(placeholder.as_str())
    }
}
116 changes: 116 additions & 0 deletions queryscript/src/runtime/mysql/value.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,116 @@
use std::any::Any;
use std::sync::Arc;

use crate::types::{
record::VecRow, value::ArrowRecordBatch, Field, Record, RecordBatch, Relation, Value,
};
use mysql_async::{
prelude::{ConvIr, FromRow, FromValue},
FromRowError, FromValueError, Row as MRow, Value as MValue,
};

impl Into<MValue> for Value {
fn into(self) -> MValue {
todo!()
}
}

/// Intermediate state for converting a MySQL value into a `Value`.
///
/// Keeps both the original MySQL value (handed back by `rollback`) and the
/// already-converted value (returned by `commit`).
pub struct ValueIr {
// The untouched MySQL value this conversion started from.
source_value: MValue,
// The converted value, computed eagerly in `ConvIr::new`.
target_value: Value,
}

impl ConvIr<Value> for ValueIr {
    /// Attempts the conversion up front, keeping the source value around so
    /// it can be returned by `rollback`.
    ///
    /// Date and time values are not supported and are rejected with a
    /// `FromValueError` that carries the original value.
    fn new(v: MValue) -> Result<ValueIr, FromValueError> {
        let converted = match &v {
            MValue::NULL => Some(Value::Null),
            MValue::Bytes(raw) => Some(Value::Binary(raw.clone())),
            MValue::Int(signed) => Some(Value::Int64(*signed)),
            MValue::UInt(unsigned) => Some(Value::UInt64(*unsigned)),
            MValue::Float(single) => Some(Value::Float32(*single)),
            MValue::Double(double) => Some(Value::Float64(*double)),
            MValue::Date(..) | MValue::Time(..) => None,
        };

        match converted {
            Some(target_value) => Ok(ValueIr {
                source_value: v,
                target_value,
            }),
            None => Err(FromValueError(v)),
        }
    }

    /// Yields the converted value.
    fn commit(self) -> Value {
        let ValueIr { target_value, .. } = self;
        target_value
    }

    /// Gives back the original MySQL value.
    fn rollback(self) -> MValue {
        let ValueIr { source_value, .. } = self;
        source_value
    }
}

/// Lets `mysql_async` produce a `Value` from a MySQL value, using `ValueIr`
/// as the intermediate conversion state.
impl FromValue for Value {
type Intermediate = ValueIr;
}

/// A single result row: one converted `Value` per column, in column order.
#[derive(Debug, Clone)]
pub struct MySQLRow(pub Vec<Value>);

impl FromRow for MySQLRow {
fn from_row_opt(mut row: MRow) -> Result<Self, FromRowError> {
match (0..row.len())
.map(|i| row.take_opt(i).unwrap())
.collect::<Result<Vec<_>, _>>()
{
Ok(values) => Ok(MySQLRow(values)),
Err(_) => Err(FromRowError(row)),
}
}
}

/// In-memory relation holding the rows returned by a MySQL query.
///
/// It acts as both a `Relation` and its own single `RecordBatch`.
#[derive(Debug)]
pub struct MySQLRelation {
// All rows of the result set.
pub rows: Vec<Arc<dyn Record>>,
// Schema reported for this relation.
pub schema: Arc<Vec<Field>>,
}

impl Relation for MySQLRelation {
    /// Returns a copy of the relation's schema.
    fn schema(&self) -> Vec<Field> {
        (*self.schema).clone()
    }

    /// Exposes the relation for downcasting.
    fn as_any(&self) -> &dyn Any {
        self
    }

    /// The whole result set is a single batch.
    fn num_batches(&self) -> usize {
        1
    }

    /// Returns the relation itself as the batch; the index is ignored.
    fn batch(&self, _index: usize) -> &dyn RecordBatch {
        self
    }

    /// Produces a relation with the same rows and `target_schema` as its
    /// schema. The rows themselves are not converted.
    fn try_cast(
        &self,
        target_schema: &Vec<Field>,
    ) -> crate::types::error::Result<Arc<dyn Relation>> {
        let recast = MySQLRelation {
            rows: self.rows.clone(),
            schema: Arc::new(target_schema.clone()),
        };
        Ok(Arc::new(recast))
    }
}

impl RecordBatch for MySQLRelation {
    /// Returns a copy of the batch's schema.
    fn schema(&self) -> Vec<Field> {
        (*self.schema).clone()
    }

    /// Exposes the batch for downcasting.
    fn as_any(&self) -> &dyn Any {
        self
    }

    /// Returns all rows (the `Arc` handles are cloned, not the records).
    fn records(&self) -> Vec<Arc<dyn Record>> {
        let rows = &self.rows;
        rows.clone()
    }

    /// Arrow view of the batch.
    ///
    /// # Panics
    /// Not implemented yet; always panics via `todo!()`.
    fn as_arrow_recordbatch(&self) -> &ArrowRecordBatch {
        todo!()
    }
}