Skip to content

Commit

Permalink
wip: merge pgxsql and #[pgx(sql)] behavior
Browse files Browse the repository at this point in the history
  • Loading branch information
bitwalker committed Feb 1, 2022
1 parent 043a4c6 commit 6e20553
Show file tree
Hide file tree
Showing 3 changed files with 82 additions and 92 deletions.
19 changes: 13 additions & 6 deletions pgx-utils/src/sql_entity_graph/pg_extern/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,8 @@ impl PgExtern {
&self.attr_tokens
}

fn overridden(&self) -> Option<String> {
fn overridden(&self) -> Option<syn::LitStr> {
let mut span = None;
let mut retval = None;
let mut in_commented_sql_block = false;
for attr in &self.func.attrs {
Expand All @@ -91,7 +92,8 @@ impl PgExtern {
Meta::Path(_) | Meta::List(_) => continue,
Meta::NameValue(mnv) => mnv,
};
if let syn::Lit::Str(inner) = content.lit {
if let syn::Lit::Str(ref inner) = content.lit {
span.get_or_insert(content.lit.span());
if !in_commented_sql_block && inner.value().trim() == "```pgxsql" {
in_commented_sql_block = true;
} else if in_commented_sql_block && inner.value().trim() == "```" {
Expand All @@ -108,7 +110,7 @@ impl PgExtern {
}
}
}
retval
retval.map(|s| syn::LitStr::new(s.as_ref(), span.unwrap()))
}

fn operator(&self) -> Option<PgOperator> {
Expand Down Expand Up @@ -227,8 +229,14 @@ impl ToTokens for PgExtern {
}
};
let operator = self.operator().into_iter();
let overridden = self.overridden().into_iter();
let to_sql_config = &self.to_sql_config;
let to_sql_config = match self.overridden() {
None => self.to_sql_config.clone(),
Some(content) => {
let mut config = self.to_sql_config.clone();
config.content = Some(content);
config
}
};

let sql_graph_entity_fn_name =
syn::Ident::new(&format!("__pgx_internals_fn_{}", ident), Span::call_site());
Expand All @@ -249,7 +257,6 @@ impl ToTokens for PgExtern {
fn_args: vec![#(#inputs),*],
fn_return: #returns,
operator: None#( .unwrap_or(Some(#operator)) )*,
overridden: None#( .unwrap_or(Some(#overridden)) )*,
to_sql_config: #to_sql_config,
};
pgx::datum::sql_entity_graph::SqlGraphEntity::Function(submission)
Expand Down
6 changes: 3 additions & 3 deletions pgx-utils/src/sql_entity_graph/to_sql.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,9 @@ use syn::{AttrStyle, Attribute, Lit, Meta, MetaList, MetaNameValue, NestedMeta};

#[derive(Debug, Clone)]
pub struct ToSqlConfig {
enabled: bool,
callback: Option<syn::Path>,
content: Option<syn::LitStr>,
pub enabled: bool,
pub callback: Option<syn::Path>,
pub content: Option<syn::LitStr>,
}
impl Default for ToSqlConfig {
fn default() -> Self {
Expand Down
149 changes: 66 additions & 83 deletions pgx/src/datum/sql_entity_graph/pg_extern/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,6 @@ pub struct PgExternEntity {
pub fn_args: Vec<PgExternArgumentEntity>,
pub fn_return: PgExternReturnEntity,
pub operator: Option<PgOperatorEntity>,
pub overridden: Option<&'static str>,
pub to_sql_config: ToSqlConfigEntity,
}

Expand Down Expand Up @@ -239,25 +238,12 @@ impl ToSql for PgExternEntity {
-- {module_path}::{name}\n\
{requires}\
{fn_sql}\
{overridden}\
",
name = self.name,
module_path = self.module_path,
file = self.file,
line = self.line,
fn_sql = if self.overridden.is_some() {
let mut inner = fn_sql
.lines()
.map(|f| format!("-- {}", f))
.collect::<Vec<_>>()
.join("\n");
inner.push_str(
"\n--\n-- Overridden as (due to a `///` comment with a `pgxsql` code block):",
);
inner
} else {
fn_sql
},
fn_sql = fn_sql,
requires = {
let requires_attrs = self
.extern_attrs
Expand All @@ -284,59 +270,56 @@ impl ToSql for PgExternEntity {
"".to_string()
}
},
overridden = self
.overridden
.map(|f| String::from("\n") + f + "\n")
.unwrap_or_default(),
);
tracing::trace!(sql = %ext_sql);

let rendered = match (self.overridden, &self.operator) {
(None, Some(op)) => {
let mut optionals = vec![];
if let Some(it) = op.commutator {
optionals.push(format!("\tCOMMUTATOR = {}", it));
};
if let Some(it) = op.negator {
optionals.push(format!("\tNEGATOR = {}", it));
};
if let Some(it) = op.restrict {
optionals.push(format!("\tRESTRICT = {}", it));
};
if let Some(it) = op.join {
optionals.push(format!("\tJOIN = {}", it));
};
if op.hashes {
optionals.push(String::from("\tHASHES"));
};
if op.merges {
optionals.push(String::from("\tMERGES"));
};
let rendered = if let Some(op) = &self.operator {
let mut optionals = vec![];
if let Some(it) = op.commutator {
optionals.push(format!("\tCOMMUTATOR = {}", it));
};
if let Some(it) = op.negator {
optionals.push(format!("\tNEGATOR = {}", it));
};
if let Some(it) = op.restrict {
optionals.push(format!("\tRESTRICT = {}", it));
};
if let Some(it) = op.join {
optionals.push(format!("\tJOIN = {}", it));
};
if op.hashes {
optionals.push(String::from("\tHASHES"));
};
if op.merges {
optionals.push(String::from("\tMERGES"));
};

let left_arg = self.fn_args.get(0).ok_or_else(|| {
eyre!("Did not find `left_arg` for operator `{}`.", self.name)
})?;
let left_arg_graph_index = context
.graph
.neighbors_undirected(self_index)
.find(|neighbor| match &context.graph[*neighbor] {
SqlGraphEntity::Type(ty) => ty.id_matches(&left_arg.ty_id),
_ => false,
})
.ok_or_else(|| eyre!("Could not find left arg function in graph."))?;
let right_arg = self.fn_args.get(1).ok_or_else(|| {
eyre!("Did not find `left_arg` for operator `{}`.", self.name)
})?;
let right_arg_graph_index = context
.graph
.neighbors_undirected(self_index)
.find(|neighbor| match &context.graph[*neighbor] {
SqlGraphEntity::Type(ty) => ty.id_matches(&right_arg.ty_id),
_ => false,
})
.ok_or_else(|| eyre!("Could not find right arg function in graph."))?;
let left_arg = self
.fn_args
.get(0)
.ok_or_else(|| eyre!("Did not find `left_arg` for operator `{}`.", self.name))?;
let left_arg_graph_index = context
.graph
.neighbors_undirected(self_index)
.find(|neighbor| match &context.graph[*neighbor] {
SqlGraphEntity::Type(ty) => ty.id_matches(&left_arg.ty_id),
_ => false,
})
.ok_or_else(|| eyre!("Could not find left arg function in graph."))?;
let right_arg = self
.fn_args
.get(1)
.ok_or_else(|| eyre!("Did not find `left_arg` for operator `{}`.", self.name))?;
let right_arg_graph_index = context
.graph
.neighbors_undirected(self_index)
.find(|neighbor| match &context.graph[*neighbor] {
SqlGraphEntity::Type(ty) => ty.id_matches(&right_arg.ty_id),
_ => false,
})
.ok_or_else(|| eyre!("Could not find right arg function in graph."))?;

let operator_sql = format!("\n\n\
let operator_sql = format!("\n\n\
-- {file}:{line}\n\
-- {module_path}::{unaliased_name}\n\
CREATE OPERATOR {opname} (\n\
Expand All @@ -345,26 +328,26 @@ impl ToSql for PgExternEntity {
\tRIGHTARG={schema_prefix_right}{right_arg}{maybe_comma} /* {right_name} */\n\
{optionals}\
);\
",
opname = op.opname.unwrap(),
file = self.file,
line = self.line,
name = self.name,
unaliased_name = self.unaliased_name,
module_path = self.module_path,
left_name = left_arg.full_path,
right_name = right_arg.full_path,
schema_prefix_left = context.schema_prefix_for(&left_arg_graph_index),
left_arg = context.type_id_to_sql_type(left_arg.ty_id).ok_or_else(|| eyre!("Failed to map argument `{}` type `{}` to SQL type while building operator `{}`.", left_arg.pattern, left_arg.full_path, self.name))?,
schema_prefix_right = context.schema_prefix_for(&right_arg_graph_index),
right_arg = context.type_id_to_sql_type(right_arg.ty_id).ok_or_else(|| eyre!("Failed to map argument `{}` type `{}` to SQL type while building operator `{}`.", right_arg.pattern, right_arg.full_path, self.name))?,
maybe_comma = if optionals.len() >= 1 { "," } else { "" },
optionals = if !optionals.is_empty() { optionals.join(",\n") + "\n" } else { "".to_string() },
);
tracing::trace!(sql = %operator_sql);
ext_sql + &operator_sql
}
(None, None) | (Some(_), Some(_)) | (Some(_), None) => ext_sql,
",
opname = op.opname.unwrap(),
file = self.file,
line = self.line,
name = self.name,
unaliased_name = self.unaliased_name,
module_path = self.module_path,
left_name = left_arg.full_path,
right_name = right_arg.full_path,
schema_prefix_left = context.schema_prefix_for(&left_arg_graph_index),
left_arg = context.type_id_to_sql_type(left_arg.ty_id).ok_or_else(|| eyre!("Failed to map argument `{}` type `{}` to SQL type while building operator `{}`.", left_arg.pattern, left_arg.full_path, self.name))?,
schema_prefix_right = context.schema_prefix_for(&right_arg_graph_index),
right_arg = context.type_id_to_sql_type(right_arg.ty_id).ok_or_else(|| eyre!("Failed to map argument `{}` type `{}` to SQL type while building operator `{}`.", right_arg.pattern, right_arg.full_path, self.name))?,
maybe_comma = if optionals.len() >= 1 { "," } else { "" },
optionals = if !optionals.is_empty() { optionals.join(",\n") + "\n" } else { "".to_string() },
);
tracing::trace!(sql = %operator_sql);
ext_sql + &operator_sql
} else {
ext_sql
};
Ok(rendered)
}
Expand Down

0 comments on commit 6e20553

Please sign in to comment.