Skip to content

Commit

Permalink
feat(api): add API for unwrapping JSON values into backend-native val…
Browse files Browse the repository at this point in the history
…ues (#8958)

Adds `str`, `int`, `float`, and `bool` properties to `JSONValue` as well
as an `unwrap_as` method for easier programmatic usage and more
fine-grained casting.

Unless someone really hates the static property names, I'd prefer to
keep them as they are. Open to alternative names for `unwrap_as` though.

In theory this can all be done with casting, but if you look at what's
being done in the various backends it's typically a lot more involved
than that. Trino in particular requires queries over JSON to be
`VARCHAR` inputs, when then have to be cast **back** to its `JSON` type
to be able to cast _that_ to the desired output type.

Complicating the cast branching _just_ for the `JSON -> not JSON` case
seemed like the wrong tradeoff.

I went with these names to match the `map` and `array` APIs, and to
match the short type names we have for the specific types (`str`, `int`,
`float`, and `bool`), which exist to match the equivalent Python types.
  • Loading branch information
cpcloud committed Apr 15, 2024
1 parent c8d98a1 commit aebb5cf
Show file tree
Hide file tree
Showing 22 changed files with 772 additions and 16 deletions.
10 changes: 9 additions & 1 deletion ci/schema/bigquery.sql
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,15 @@ INSERT INTO {dataset}.json_t VALUES
(JSON '{{"a":"foo", "c":null}}'),
(JSON 'null'),
(JSON '[42,47,55]'),
(JSON '[]');
(JSON '[]'),
(JSON '"a"'),
(JSON '""'),
(JSON '"b"'),
(NULL),
(JSON 'true'),
(JSON 'false'),
(JSON '42'),
(JSON '37.37');


LOAD DATA OVERWRITE {dataset}.functional_alltypes (
Expand Down
12 changes: 10 additions & 2 deletions ci/schema/duckdb.sql
Original file line number Diff line number Diff line change
Expand Up @@ -29,15 +29,23 @@ INSERT INTO struct VALUES
(NULL),
({'a': 3.0, 'b': 'orange', 'c': NULL});

CREATE OR REPLACE TABLE json_t (js TEXT);
CREATE OR REPLACE TABLE json_t (js JSON);

INSERT INTO json_t VALUES
('{"a": [1,2,3,4], "b": 1}'),
('{"a":null,"b":2}'),
('{"a":"foo", "c":null}'),
('null'),
('[42,47,55]'),
('[]');
('[]'),
('"a"'),
('""'),
('"b"'),
(NULL),
('true'),
('false'),
('42'),
('37.37');

CREATE OR REPLACE TABLE win (g TEXT, x BIGINT NOT NULL, y BIGINT);
INSERT INTO win VALUES
Expand Down
10 changes: 9 additions & 1 deletion ci/schema/mysql.sql
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,15 @@ INSERT INTO json_t VALUES
('{"a":"foo", "c":null}'),
('null'),
('[42,47,55]'),
('[]');
('[]'),
('"a"'),
('""'),
('"b"'),
(NULL),
('true'),
('false'),
('42'),
('37.37');

DROP TABLE IF EXISTS win CASCADE;

Expand Down
10 changes: 9 additions & 1 deletion ci/schema/postgres.sql
Original file line number Diff line number Diff line change
Expand Up @@ -273,7 +273,15 @@ INSERT INTO json_t VALUES
('{"a":"foo", "c":null}'),
('null'),
('[42,47,55]'),
('[]');
('[]'),
('"a"'),
('""'),
('"b"'),
(NULL),
('true'),
('false'),
('42'),
('37.37');

DROP TABLE IF EXISTS win CASCADE;
CREATE TABLE win (g TEXT, x BIGINT NOT NULL, y BIGINT);
Expand Down
10 changes: 9 additions & 1 deletion ci/schema/risingwave.sql
Original file line number Diff line number Diff line change
Expand Up @@ -165,7 +165,15 @@ INSERT INTO "json_t" VALUES
('{"a":"foo", "c":null}'),
('null'),
('[42,47,55]'),
('[]');
('[]'),
('"a"'),
('""'),
('"b"'),
(NULL),
('true'),
('false'),
('42'),
('37.37');

DROP TABLE IF EXISTS "win" CASCADE;
CREATE TABLE "win" ("g" TEXT, "x" BIGINT, "y" BIGINT);
Expand Down
10 changes: 9 additions & 1 deletion ci/schema/snowflake.sql
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,15 @@ INSERT INTO "json_t" ("js")
SELECT parse_json('{"a":"foo", "c":null}') UNION
SELECT parse_json('null') UNION
SELECT parse_json('[42,47,55]') UNION
SELECT parse_json('[]');
SELECT parse_json('[]') UNION
SELECT parse_json('"a"') UNION
SELECT parse_json('""') UNION
SELECT parse_json('"b"') UNION
SELECT NULL UNION
SELECT parse_json('true') UNION
SELECT parse_json('false') UNION
SELECT parse_json('42') UNION
SELECT parse_json('37.37');

CREATE OR REPLACE TABLE "win" ("g" TEXT, "x" BIGINT NOT NULL, "y" BIGINT);
INSERT INTO "win" VALUES
Expand Down
10 changes: 9 additions & 1 deletion ci/schema/sqlite.sql
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,15 @@ INSERT INTO json_t VALUES
('{"a":"foo", "c":null}'),
('null'),
('[42,47,55]'),
('[]');
('[]'),
('"a"'),
('""'),
('"b"'),
(NULL),
('true'),
('false'),
('42'),
('37.37');

DROP TABLE IF EXISTS win;
CREATE TABLE win (g TEXT, x BIGINT NOT NULL, y BIGINT);
Expand Down
10 changes: 9 additions & 1 deletion ci/schema/trino.sql
Original file line number Diff line number Diff line change
Expand Up @@ -168,7 +168,15 @@ INSERT INTO memory.default.json_t VALUES
(JSON '{"a":"foo", "c":null}'),
(JSON 'null'),
(JSON '[42,47,55]'),
(JSON '[]');
(JSON '[]'),
(JSON '"a"'),
(JSON '""'),
(JSON '"b"'),
(NULL),
(JSON 'true'),
(JSON 'false'),
(JSON '42'),
(JSON '37.37');

DROP TABLE IF EXISTS win;
CREATE TABLE win (g VARCHAR, x BIGINT, y BIGINT);
Expand Down
12 changes: 12 additions & 0 deletions ibis/backends/bigquery/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -381,6 +381,18 @@ def visit_Cast(self, op, *, arg, to):
def visit_JSONGetItem(self, op, *, arg, index):
return arg[index]

def visit_UnwrapJSONString(self, op, *, arg):
return self.f.anon["safe.string"](arg)

def visit_UnwrapJSONInt64(self, op, *, arg):
return self.f.anon["safe.int64"](arg)

def visit_UnwrapJSONFloat64(self, op, *, arg):
return self.f.anon["safe.float64"](arg)

def visit_UnwrapJSONBoolean(self, op, *, arg):
return self.f.anon["safe.bool"](arg)

def visit_ExtractEpochSeconds(self, op, *, arg):
return self.f.unix_seconds(arg)

Expand Down
36 changes: 34 additions & 2 deletions ibis/backends/duckdb/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -209,10 +209,42 @@ def visit_MapContains(self, op, *, arg, key):
return self.f.len(self.f.element_at(arg, key)).neq(0)

def visit_ToJSONMap(self, op, *, arg):
return sge.TryCast(this=arg, to=self.type_mapper.from_ibis(op.dtype))
return self.if_(
self.f.json_type(arg).eq("OBJECT"),
self.cast(self.cast(arg, dt.json), op.dtype),
NULL,
)

def visit_ToJSONArray(self, op, *, arg):
return self.visit_ToJSONMap(op, arg=arg)
return self.if_(
self.f.json_type(arg).eq("ARRAY"),
self.cast(self.cast(arg, dt.json), op.dtype),
NULL,
)

def visit_UnwrapJSONString(self, op, *, arg):
return self.if_(
self.f.json_type(arg).eq("VARCHAR"),
self.f.json_extract_string(arg, "$"),
NULL,
)

def visit_UnwrapJSONInt64(self, op, *, arg):
arg_type = self.f.json_type(arg)
return self.if_(
arg_type.isin("UBIGINT", "BIGINT"), self.cast(arg, op.dtype), NULL
)

def visit_UnwrapJSONFloat64(self, op, *, arg):
arg_type = self.f.json_type(arg)
return self.if_(
arg_type.isin("UBIGINT", "BIGINT", "DOUBLE"), self.cast(arg, op.dtype), NULL
)

def visit_UnwrapJSONBoolean(self, op, *, arg):
return self.if_(
self.f.json_type(arg).eq("BOOLEAN"), self.cast(arg, op.dtype), NULL
)

def visit_ArrayConcat(self, op, *, arg):
# TODO(cpcloud): map ArrayConcat to this in sqlglot instead of here
Expand Down
22 changes: 22 additions & 0 deletions ibis/backends/mysql/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -342,3 +342,25 @@ def visit_TimestampAdd(self, op, *, left, right):
this=right.this * 1_000, unit=sge.Var(this="MICROSECOND")
)
return self.f.date_add(left, right, dialect=self.dialect)

def visit_UnwrapJSONString(self, op, *, arg):
return self.if_(
self.f.json_type(arg).eq("STRING"), self.f.json_unquote(arg), NULL
)

def visit_UnwrapJSONInt64(self, op, *, arg):
return self.if_(
self.f.json_type(arg).eq("INTEGER"), self.cast(arg, op.dtype), NULL
)

def visit_UnwrapJSONFloat64(self, op, *, arg):
return self.if_(
self.f.json_type(arg).isin("DOUBLE", "INTEGER"),
self.cast(arg, op.dtype),
NULL,
)

def visit_UnwrapJSONBoolean(self, op, *, arg):
return self.if_(
self.f.json_type(arg).eq("BOOLEAN"), self.if_(arg.eq("true"), 1, 0), NULL
)
47 changes: 47 additions & 0 deletions ibis/backends/postgres/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -325,6 +325,53 @@ def visit_StructField(self, op, *, arg, field):
op.dtype,
)

def visit_UnwrapJSONString(self, op, *, arg):
return self.if_(
self.f.json_typeof(arg).eq("string"),
self.f.json_extract_path_text(
arg,
# this is apparently how you pass in no additional arguments to
# a variadic function, see the "Variadic Function Resolution"
# section in
# https://www.postgresql.org/docs/current/typeconv-func.html
sge.Var(this="VARIADIC ARRAY[]::TEXT[]"),
),
NULL,
)

def visit_UnwrapJSONInt64(self, op, *, arg):
text = self.f.json_extract_path_text(
arg, sge.Var(this="VARIADIC ARRAY[]::TEXT[]")
)
return self.if_(
self.f.json_typeof(arg).eq("number"),
self.cast(
self.if_(self.f.regexp_like(text, r"^\d+$", "g"), text, NULL),
op.dtype,
),
NULL,
)

def visit_UnwrapJSONFloat64(self, op, *, arg):
text = self.f.json_extract_path_text(
arg, sge.Var(this="VARIADIC ARRAY[]::TEXT[]")
)
return self.if_(
self.f.json_typeof(arg).eq("number"), self.cast(text, op.dtype), NULL
)

def visit_UnwrapJSONBoolean(self, op, *, arg):
return self.if_(
self.f.json_typeof(arg).eq("boolean"),
self.cast(
self.f.json_extract_path_text(
arg, sge.Var(this="VARIADIC ARRAY[]::TEXT[]")
),
op.dtype,
),
NULL,
)

def visit_StructColumn(self, op, *, names, values):
return self.f.row(*map(self.cast, values, op.dtype.types))

Expand Down
46 changes: 46 additions & 0 deletions ibis/backends/pyspark/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from pyspark import SparkConf
from pyspark.sql import DataFrame, SparkSession
from pyspark.sql.functions import PandasUDFType, pandas_udf
from pyspark.sql.types import BooleanType, DoubleType, LongType, StringType

import ibis.common.exceptions as com
import ibis.config
Expand Down Expand Up @@ -40,6 +41,47 @@ def normalize_filenames(source_list):
return list(map(util.normalize_filename, source_list))


@pandas_udf(returnType=DoubleType(), functionType=PandasUDFType.SCALAR)
def unwrap_json_float(s: pd.Series) -> pd.Series:
import json

import pandas as pd

def nullify_type_mismatched_value(raw):
if pd.isna(raw):
return None

value = json.loads(raw)
# exact type check because we want to distinguish between integer
# and booleans and bool is a subclass of int
return value if type(value) in (float, int) else None

return s.map(nullify_type_mismatched_value)


def unwrap_json(typ):
import json

import pandas as pd

type_mapping = {str: StringType(), int: LongType(), bool: BooleanType()}

@pandas_udf(returnType=type_mapping[typ], functionType=PandasUDFType.SCALAR)
def unwrap(s: pd.Series) -> pd.Series:
def nullify_type_mismatched_value(raw):
if pd.isna(raw):
return None

value = json.loads(raw)
# exact type check because we want to distinguish between integer
# and booleans and bool is a subclass of int
return value if type(value) == typ else None

return s.map(nullify_type_mismatched_value)

return unwrap


class _PySparkCursor:
"""Spark cursor.
Expand Down Expand Up @@ -252,6 +294,10 @@ def _register_udfs(self, expr: ir.Expr) -> None:
spark_udf = pandas_udf(udf_func, udf_return, PandasUDFType.GROUPED_AGG)
self._session.udf.register(udf_name, spark_udf)

for typ in (str, int, bool):
self._session.udf.register(f"unwrap_json_{typ.__name__}", unwrap_json(typ))
self._session.udf.register("unwrap_json_float", unwrap_json_float)

def _register_in_memory_table(self, op: ops.InMemoryTable) -> None:
schema = PySparkSchema.from_ibis(op.schema)
df = self._session.createDataFrame(data=op.data.to_frame(), schema=schema)
Expand Down
4 changes: 4 additions & 0 deletions ibis/backends/pyspark/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,10 @@ class PySparkCompiler(SQLGlotCompiler):
ops.MapMerge: "map_concat",
ops.MapKeys: "map_keys",
ops.MapValues: "map_values",
ops.UnwrapJSONString: "unwrap_json_str",
ops.UnwrapJSONInt64: "unwrap_json_int",
ops.UnwrapJSONFloat64: "unwrap_json_float",
ops.UnwrapJSONBoolean: "unwrap_json_bool",
}

def _aggregate(self, funcname: str, *args, where):
Expand Down
12 changes: 12 additions & 0 deletions ibis/backends/snowflake/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,6 +179,18 @@ def visit_ToJSONMap(self, op, *, arg):
def visit_ToJSONArray(self, op, *, arg):
return self.if_(self.f.is_array(arg), arg, NULL)

def visit_UnwrapJSONString(self, op, *, arg):
return self.if_(self.f.is_varchar(arg), self.f.as_varchar(arg), NULL)

def visit_UnwrapJSONInt64(self, op, *, arg):
return self.if_(self.f.is_integer(arg), self.f.as_integer(arg), NULL)

def visit_UnwrapJSONFloat64(self, op, *, arg):
return self.if_(self.f.is_double(arg), self.f.as_double(arg), NULL)

def visit_UnwrapJSONBoolean(self, op, *, arg):
return self.if_(self.f.is_boolean(arg), self.f.as_boolean(arg), NULL)

def visit_IsNan(self, op, *, arg):
return arg.eq(self.NAN)

Expand Down
Loading

0 comments on commit aebb5cf

Please sign in to comment.