Skip to content

Commit

Permalink
feat: query json columns using dot notation, e.g. field.key~log4j
Browse files Browse the repository at this point in the history
This may be useful for other types of nested structures, too.

Signed-off-by: Jim Crossley <[email protected]>
  • Loading branch information
jcrossley3 authored and ctron committed Nov 7, 2024
1 parent ba0cce2 commit 15f52f6
Show file tree
Hide file tree
Showing 2 changed files with 66 additions and 12 deletions.
12 changes: 10 additions & 2 deletions common/src/db/query.rs
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,7 @@ impl Query {

fn parse(&self) -> Vec<Constraint> {
// regex for filters: {field}{op}{value}
const RE: &str = r"^(?<field>[[:word:]]+)(?<op>=|!=|~|!~|>=|>|<=|<)(?<value>.*)$";
const RE: &str = r"^(?<field>[[:word:]\.]+)(?<op>=|!=|~|!~|>=|>|<=|<)(?<value>.*)$";
static LOCK: OnceLock<Regex> = OnceLock::new();
#[allow(clippy::unwrap_used)]
let regex = LOCK.get_or_init(|| (Regex::new(RE).unwrap()));
Expand Down Expand Up @@ -225,7 +225,10 @@ pub(crate) mod tests {
/////////////////////////////////////////////////////////////////////////

pub(crate) mod advisory {
use sea_orm::entity::prelude::*;
use std::collections::HashMap;

use sea_orm::{entity::prelude::*, FromJsonQueryResult};
use serde::{Deserialize, Serialize};
use time::OffsetDateTime;

#[derive(Clone, Debug, PartialEq, DeriveEntityModel)]
Expand All @@ -238,6 +241,8 @@ pub(crate) mod tests {
pub published: Option<OffsetDateTime>,
pub severity: Severity,
pub score: f64,
#[sea_orm(column_type = "JsonBinary")]
pub purl: CanonicalPurl,
}
#[derive(Copy, Clone, Debug, EnumIter, DeriveRelation)]
pub enum Relation {}
Expand All @@ -253,5 +258,8 @@ pub(crate) mod tests {
#[sea_orm(string_value = "high")]
High,
}

#[derive(Clone, Debug, PartialEq, Serialize, Deserialize, FromJsonQueryResult)]
pub struct CanonicalPurl(pub HashMap<String, String>);
}
}
66 changes: 56 additions & 10 deletions common/src/db/query/columns.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ use std::fmt::{Display, Formatter};

use sea_orm::entity::ColumnDef;
use sea_orm::{sea_query, ColumnTrait, ColumnType, EntityTrait, IntoIdentity, Iterable};
use sea_query::extension::postgres::PgExpr;
use sea_query::{Alias, ColumnRef, Expr, IntoColumnRef, IntoIden};

/// Context of columns which can be used for filtering and sorting.
Expand Down Expand Up @@ -127,19 +128,40 @@ impl Columns {
self.columns.iter()
}

/// Look up the column context for a given simple field name.
/// Look up the column context for a given simple field name.
pub(crate) fn for_field(&self, field: &str) -> Option<(Expr, ColumnDef)> {
self.columns
.iter()
.find(|(col, _)| {
fn name_match(tgt: &str) -> impl Fn(&&(ColumnRef, ColumnDef)) -> bool + use<'_> {
|(col, _)| {
matches!(col,
ColumnRef::Column(name)
| ColumnRef::TableColumn(_, name)
| ColumnRef::SchemaTableColumn(_, _, name)
if name.to_string().eq_ignore_ascii_case(field))
})
.map(|(r, d)| (Expr::col(r.clone()), d.clone()))
ColumnRef::Column(name)
| ColumnRef::TableColumn(_, name)
| ColumnRef::SchemaTableColumn(_, _, name)
if name.to_string().eq_ignore_ascii_case(tgt))
}
}
match field.split_once('.') {
None => self
.columns
.iter()
.find(name_match(field))
.map(|(r, d)| (Expr::col(r.clone()), d.clone())),
Some((col, key)) => self
.columns
.iter()
.filter(|(_, def)| {
matches!(
def.get_column_type(),
ColumnType::Json | ColumnType::JsonBinary
)
})
.find(name_match(col))
.map(|(r, d)| {
(
Expr::expr(Expr::col(r.clone()).cast_json_field(key)),
d.clone(),
)
}),
}
}

pub(crate) fn translate(&self, field: &str, op: &str, value: &str) -> Option<String> {
Expand Down Expand Up @@ -291,4 +313,28 @@ mod tests {

Ok(())
}

#[test(tokio::test)]
async fn json_queries() -> Result<(), anyhow::Error> {
let clause = advisory::Entity::find()
.select_only()
.column(advisory::Column::Id)
.filtering_with(
q("purl.name~log4j&purl.version>1.0&purl.ty=maven").sort("purl.name"),
advisory::Entity.columns().alias("advisory", "foo"),
)?
.build(sea_orm::DatabaseBackend::Postgres)
.to_string()
.split("WHERE ")
.last()
.unwrap()
.to_string();

assert_eq!(
clause,
r#"(("foo"."purl" ->> 'name') ILIKE '%log4j%') AND ("foo"."purl" ->> 'version') > '1.0' AND ("foo"."purl" ->> 'ty') = 'maven' ORDER BY "foo"."purl" ->> 'name' ASC"#
);

Ok(())
}
}

0 comments on commit 15f52f6

Please sign in to comment.