Skip to content

Commit

Permalink
review fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
Weakky committed Aug 23, 2024
1 parent 7edb9b8 commit aad0c68
Show file tree
Hide file tree
Showing 8 changed files with 98 additions and 34 deletions.
1 change: 0 additions & 1 deletion libs/test-cli/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -212,7 +212,6 @@ async fn main() -> anyhow::Result<()> {
.first_datasource()
.load_url(|key| std::env::var(key).ok())
.unwrap(),
force: false,
queries: vec![SqlQueryInput {
name: "query".to_string(),
source: query_str,
Expand Down
4 changes: 2 additions & 2 deletions quaint/src/connector.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,11 @@
mod column_type;
mod connection_info;

mod describe;
pub mod external;
pub mod metrics;
#[cfg(native)]
pub mod native;
mod parsed_query;
mod queryable;
mod result_set;
#[cfg(any(feature = "mssql-native", feature = "postgresql-native", feature = "mysql-native"))]
Expand All @@ -32,8 +32,8 @@ pub use connection_info::*;
#[cfg(native)]
pub use native::*;

pub use describe::*;
pub use external::*;
pub use parsed_query::*;
pub use queryable::*;
pub use transaction::*;

Expand Down
File renamed without changes.
47 changes: 28 additions & 19 deletions quaint/src/connector/postgres/native/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -595,30 +595,41 @@ impl Queryable for PostgreSql {
let mut columns: Vec<DescribedColumn> = Vec::with_capacity(stmt.columns().len());
let mut parameters: Vec<DescribedParameter> = Vec::with_capacity(stmt.params().len());

async fn infer_type(this: &PostgreSql, ty: &PostgresType) -> crate::Result<(ColumnType, Option<String>)> {
let enums_results = self
.query_raw("SELECT oid, typname FROM pg_type WHERE typtype = 'e';", &[])
.await?;

fn find_enum_by_oid(enums: &ResultSet, enum_oid: u32) -> Option<&str> {
enums.iter().find_map(|row| {
let oid = row.get("oid")?.as_i64()?;
let name = row.get("typname")?.as_str()?;

if enum_oid == u32::try_from(oid).unwrap() {
Some(name)
} else {
None
}
})
}

fn resolve_type(ty: &PostgresType, enums: &ResultSet) -> (ColumnType, Option<String>) {
let column_type = ColumnType::from(ty);

match ty.kind() {
PostgresKind::Enum => {
let enum_name = this
.query_raw("SELECT typname FROM pg_type WHERE oid = $1;", &[Value::int64(ty.oid())])
.await?
.into_single()?
.at(0)
.expect("could not find enum name")
.to_string()
.expect("enum name is not a string");

Ok((column_type, Some(enum_name)))
let enum_name = find_enum_by_oid(enums, ty.oid())
.unwrap_or_else(|| panic!("Could not find enum with oid {}", ty.oid()));

(column_type, Some(enum_name.to_string()))
}
_ => Ok((column_type, None)),
_ => (column_type, None),
}
}

let nullables = self.get_nullable_for_columns(&stmt).await?;

for (idx, (col, nullable)) in stmt.columns().iter().zip(nullables).enumerate() {
let (typ, enum_name) = infer_type(self, col.type_()).await?;
let (typ, enum_name) = resolve_type(col.type_(), &enums_results);

if col.name() == "?column?" {
let kind = ErrorKind::QueryInvalidInput(format!("Invalid column name '?column?' for index {idx}. Your SQL query must explicitly alias that column name."));
Expand All @@ -635,17 +646,15 @@ impl Queryable for PostgreSql {
}

for param in stmt.params() {
let (typ, enum_name) = infer_type(self, param).await?;
let (typ, enum_name) = resolve_type(param, &enums_results);

parameters.push(DescribedParameter::new_named(param.name(), typ).with_enum_name(enum_name));
}

let enum_names = self
.query_raw("SELECT typname FROM pg_type WHERE typtype = 'e';", &[])
.await?
let enum_names = enums_results
.into_iter()
.flat_map(|row| row.into_single().ok())
.flat_map(|v| v.to_string())
.filter_map(|row| row.take("typname"))
.filter_map(|v| v.to_string())
.collect::<Vec<_>>();

Ok(DescribedQuery {
Expand Down
10 changes: 9 additions & 1 deletion quaint/src/connector/result_set/result_row.rs
Original file line number Diff line number Diff line change
Expand Up @@ -62,12 +62,20 @@ impl ResultRow {
}
}

/// Take a value with the given column name from the row. Usage
/// Get a value with the given column name from the row. Usage
/// documentation in [ResultRowRef](struct.ResultRowRef.html).
pub fn get(&self, name: &str) -> Option<&Value<'static>> {
self.columns.iter().position(|c| c == name).map(|idx| &self.values[idx])
}

/// Take a value with the given column name from the row.
pub fn take(mut self, name: &str) -> Option<Value<'static>> {
self.columns
.iter()
.position(|c| c == name)
.map(|idx| self.values.remove(idx))
}

/// Make a referring [ResultRowRef](struct.ResultRowRef.html).
pub fn as_ref(&self) -> ResultRowRef {
ResultRowRef {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -184,6 +184,17 @@ fn build_error(input: Input<'_>, msg: &str) -> ConnectorError {
ConnectorError::from_msg(format!("SQL documentation parsing: {msg} at '{input}'."))
}

fn render_enum_names(enum_names: &[String]) -> String {
if enum_names.is_empty() {
String::new()
} else {
format!(
", {enum_names}",
enum_names = enum_names.iter().map(|name| format!("'{name}'")).join(", ")
)
}
}

fn parse_typ_opt<'a>(
input: Input<'a>,
enum_names: &'a [String],
Expand Down Expand Up @@ -215,7 +226,7 @@ fn parse_typ_opt<'a>(
})
.ok_or_else(|| build_error(
input,
&format!("invalid type: '{typ}' (accepted types are: 'Int', 'BigInt', 'Float', 'Boolean', 'String', 'DateTime', 'Json', 'Bytes', 'Decimal'{})", if enum_names.is_empty() { String::new() } else { format!(" , {}", enum_names.iter().map(|name| format!("'{name}'")).join(", ")) }),
&format!("invalid type: '{typ}' (accepted types are: 'Int', 'BigInt', 'Float', 'Boolean', 'String', 'DateTime', 'Json', 'Bytes', 'Decimal'{})", render_enum_names(enum_names)),
))?;

Ok((input.move_from(end + 1), Some(parsed_typ)))
Expand Down Expand Up @@ -247,19 +258,19 @@ fn parse_position_opt(input: Input<'_>) -> ConnectorResult<(Input<'_>, Option<us
}
}

fn parse_alias_opt(input: Input<'_>) -> ConnectorResult<(Input<'_>, Option<&'_ str>, bool)> {
fn parse_alias_opt(input: Input<'_>) -> ConnectorResult<(Input<'_>, Option<&'_ str>, Option<bool>)> {
if let Some((input, alias)) = input
.trim_start()
.strip_prefix_char(':')
.map(|input| input.take_until_pattern_or_eol(&[' ']))
{
if let Some(alias) = alias.strip_suffix_char('?') {
Ok((input, Some(alias.inner()), true))
Ok((input, Some(alias.inner()), Some(true)))
} else {
Ok((input, Some(alias.inner()), false))
Ok((input, Some(alias.inner()), None))
}
} else {
Ok((input, None, false))
Ok((input, None, None))
}
}

Expand Down Expand Up @@ -296,7 +307,7 @@ fn parse_param<'a>(param_input: Input<'a>, enum_names: &'a [String]) -> Connecto
let mut param = ParsedParameterDoc::default();

param.set_typ(typ);
param.set_nullable(nullable.then_some(true));
param.set_nullable(nullable);
param.set_position(position);
param.set_alias(alias);
param.set_documentation(documentation);
Expand Down Expand Up @@ -902,12 +913,12 @@ mod tests {
ConnectorErrorImpl {
user_facing_error: None,
message: Some(
"SQL documentation parsing: invalid type: 'UnknownTyp' (accepted types are: 'Int', 'BigInt', 'Float', 'Boolean', 'String', 'DateTime', 'Json', 'Bytes', 'Decimal' , 'MyEnum', 'MyEnum2') at '{UnknownTyp} $12567:alias'.",
"SQL documentation parsing: invalid type: 'UnknownTyp' (accepted types are: 'Int', 'BigInt', 'Float', 'Boolean', 'String', 'DateTime', 'Json', 'Bytes', 'Decimal', 'MyEnum', 'MyEnum2') at '{UnknownTyp} $12567:alias'.",
),
source: None,
context: SpanTrace [],
}
SQL documentation parsing: invalid type: 'UnknownTyp' (accepted types are: 'Int', 'BigInt', 'Float', 'Boolean', 'String', 'DateTime', 'Json', 'Bytes', 'Decimal' , 'MyEnum', 'MyEnum2') at '{UnknownTyp} $12567:alias'.
SQL documentation parsing: invalid type: 'UnknownTyp' (accepted types are: 'Int', 'BigInt', 'Float', 'Boolean', 'String', 'DateTime', 'Json', 'Bytes', 'Decimal', 'MyEnum', 'MyEnum2') at '{UnknownTyp} $12567:alias'.
,
)
"#]];
Expand Down
3 changes: 0 additions & 3 deletions schema-engine/json-rpc-api-build/methods/introspectSql.toml
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,6 @@ shape = "string"
shape = "sqlQueryInput"
isList = true

[recordShapes.introspectSqlParams.fields.force]
shape = "bool"

# Result

[recordShapes.introspectSqlResult]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,46 @@ fn parses_doc_no_alias(api: TestApi) {
api.introspect_sql("test_1", sql).send_sync().expect_result(expected)
}

#[test_connector(tags(Postgres))]
fn parses_doc_enum_name(api: TestApi) {
api.schema_push(ENUM_SCHEMA).send().assert_green();

let expected = expect![[r#"
IntrospectSqlQueryOutput {
name: "test_1",
source: "\n -- @param {MyFancyEnum} $1\n SELECT * FROM model WHERE id = $1;\n ",
documentation: None,
parameters: [
IntrospectSqlQueryParameterOutput {
documentation: None,
name: "int4",
typ: "MyFancyEnum",
nullable: false,
},
],
result_columns: [
IntrospectSqlQueryColumnOutput {
name: "id",
typ: "int",
nullable: false,
},
IntrospectSqlQueryColumnOutput {
name: "enum",
typ: "MyFancyEnum",
nullable: false,
},
],
}
"#]];

let sql = r#"
-- @param {MyFancyEnum} $1
SELECT * FROM model WHERE id = ?;
"#;

api.introspect_sql("test_1", sql).send_sync().expect_result(expected)
}

#[test_connector(tags(Postgres))]
fn invalid_position_fails(api: TestApi) {
api.schema_push(SIMPLE_SCHEMA).send().assert_green();
Expand Down

0 comments on commit aad0c68

Please sign in to comment.