diff --git a/core/src/main/scala/io/projectglow/functions.scala b/core/src/main/scala/io/projectglow/functions.scala index bf2221fbf..c38b0a6ce 100644 --- a/core/src/main/scala/io/projectglow/functions.scala +++ b/core/src/main/scala/io/projectglow/functions.scala @@ -217,7 +217,7 @@ object functions { } /** - * Sustitutes the missing values of a numeric array using the mean of the non-missing values. Any values that are NaN, null or equal to the missing value parameter are considered missing. See :ref:`variant-data-transformations` for more details. + * Substitutes the missing values of a numeric array using the mean of the non-missing values. Any values that are NaN, null or equal to the missing value parameter are considered missing. See :ref:`variant-data-transformations` for more details. * @group etl * @since 0.4.0 * diff --git a/functions.yml b/functions.yml index 7b1649cf0..6e7f8d24d 100644 --- a/functions.yml +++ b/functions.yml @@ -245,7 +245,7 @@ etl: type: str returns: A struct as explained above - name: mean_substitute - doc: Sustitutes the missing values of a numeric array using the mean of the non-missing values. Any values that + doc: Substitutes the missing values of a numeric array using the mean of the non-missing values. Any values that are NaN, null or equal to the missing value parameter are considered missing. See :ref:`variant-data-transformations` for more details. since: 0.4.0 diff --git a/python/glow/conversions.py b/python/glow/conversions.py new file mode 100644 index 000000000..cd3c20484 --- /dev/null +++ b/python/glow/conversions.py @@ -0,0 +1,86 @@ +import numpy as np +from py4j.java_collections import JavaArray +from pyspark import SparkContext +from typeguard import check_argument_types, check_return_type + + +def _is_numpy_double_array(object, dimensions: int) -> bool: + assert check_argument_types() + output = isinstance(object, np.ndarray) and len(object.shape) == dimensions and object.dtype.type == np.double + assert check_return_type(output) + return output + + +def _convert_numpy_to_java_array(np_arr: np.ndarray) -> JavaArray: + """ + Converts a flat numpy array of doubles to a Java array of doubles. + """ + assert check_argument_types() + assert len(np_arr.shape) == 1 + assert np_arr.dtype.type == np.double + + sc = SparkContext._active_spark_context + java_arr = sc._gateway.new_array(sc._jvm.double, np_arr.shape[0]) + for idx, ele in enumerate(np_arr): + java_arr[idx] = ele.item() + + assert check_return_type(java_arr) + return java_arr + + +class OneDimensionalDoubleNumpyArrayConverter(object): + """ + Replaces any 1-dimensional numpy array of doubles with a literal Java array. + + Added in version 0.4.0. + + Examples: + >>> import numpy as np + >>> from pyspark.sql.functions import lit + >>> from pyspark.sql.types import StringType + >>> str_list = ['a', 'b'] + >>> df = spark.createDataFrame(str_list, StringType()) + >>> ndarray = np.array([1.0, 2.1, 3.2]) + >>> df.withColumn("array", lit(ndarray)).collect() + [Row(value='a', array=[1.0, 2.1, 3.2]), Row(value='b', array=[1.0, 2.1, 3.2])] + """ + + def can_convert(self, object): + return _is_numpy_double_array(object, dimensions = 1) + + def convert(self, object, gateway_client): + sc = SparkContext._active_spark_context + java_arr = _convert_numpy_to_java_array(object) + return java_arr + + +class TwoDimensionalDoubleNumpyArrayConverter(object): + """ + Replaces any 2-dimensional numpy array of doubles with a literal DenseMatrix. + + Added in version 0.4.0. + + Examples: + >>> import numpy as np + >>> from pyspark.sql.functions import lit + >>> from pyspark.sql.types import StringType + >>> str_list = ['a', 'b'] + >>> df = spark.createDataFrame(str_list, StringType()) + >>> ndarray = np.array([[1.0, 2.1, 3.2], [4.3, 5.4, 6.5]]) + >>> df.withColumn("matrix", lit(ndarray)).collect() + [Row(value='a', matrix=DenseMatrix(2, 3, [1.0, 2.1, 3.2, 4.3, 5.4, 6.5], False)), Row(value='b', matrix=DenseMatrix(2, 3, [1.0, 2.1, 3.2, 4.3, 5.4, 6.5], False))] + """ + + def can_convert(self, object): + return _is_numpy_double_array(object, dimensions = 2) + + def convert(self, object, gateway_client): + sc = SparkContext._active_spark_context + flat_arr = object.ravel() + java_arr = _convert_numpy_to_java_array(flat_arr) + dense_matrix = sc._jvm.org.apache.spark.ml.linalg.DenseMatrix(object.shape[0], object.shape[1], java_arr) + matrix_udt = sc._jvm.org.apache.spark.ml.linalg.MatrixUDT() + converter = sc._jvm.org.apache.spark.sql.catalyst.CatalystTypeConverters.createToCatalystConverter(matrix_udt) + literal_matrix = sc._jvm.org.apache.spark.sql.catalyst.expressions.Literal.create( + converter.apply(dense_matrix), matrix_udt) + return literal_matrix diff --git a/python/glow/functions.py b/python/glow/functions.py index 8db42610e..0565c924d 100644 --- a/python/glow/functions.py +++ b/python/glow/functions.py @@ -318,7 +318,7 @@ def normalize_variant(contigName: Union[Column, str], start: Union[Column, str], def mean_substitute(array: Union[Column, str], missingValue: Union[Column, str] = None) -> Column: """ - Sustitutes the missing values of a numeric array using the mean of the non-missing values. Any values that are NaN, null or equal to the missing value parameter are considered missing. See :ref:`variant-data-transformations` for more details. + Substitutes the missing values of a numeric array using the mean of the non-missing values. Any values that are NaN, null or equal to the missing value parameter are considered missing. See :ref:`variant-data-transformations` for more details. Added in version 0.4.0. diff --git a/python/glow/glow.py b/python/glow/glow.py index 23496ccc9..77968a0c8 100644 --- a/python/glow/glow.py +++ b/python/glow/glow.py @@ -1,3 +1,6 @@ +from glow.conversions import OneDimensionalDoubleNumpyArrayConverter, TwoDimensionalDoubleNumpyArrayConverter +from py4j import protocol +from py4j.protocol import register_input_converter from pyspark import SparkContext from pyspark.sql import DataFrame, SQLContext, SparkSession from typing import Dict @@ -40,7 +43,7 @@ def transform(operation: str, df: DataFrame, arg_map: Dict[str, str]=None, def register(session: SparkSession): """ - Register SQL extensions for a Spark session. + Register SQL extensions and py4j converters for a Spark session. Args: session: Spark session @@ -51,3 +54,9 @@ def register(session: SparkSession): """ assert check_argument_types() session._jvm.io.projectglow.Glow.register(session._jsparkSession) + +# Register input converters in idempotent fashion +glow_input_converters = [OneDimensionalDoubleNumpyArrayConverter, TwoDimensionalDoubleNumpyArrayConverter] +for gic in glow_input_converters: + if not any(type(pic) is gic for pic in protocol.INPUT_CONVERTER): + register_input_converter(gic(), prepend = True) diff --git a/python/glow/tests/test_conversions.py b/python/glow/tests/test_conversions.py new file mode 100644 index 000000000..a3eaa6ebe --- /dev/null +++ b/python/glow/tests/test_conversions.py @@ -0,0 +1,62 @@ +from glow.conversions import OneDimensionalDoubleNumpyArrayConverter, TwoDimensionalDoubleNumpyArrayConverter +from importlib import reload +import numpy as np +from py4j.protocol import Py4JJavaError +from pyspark.ml.linalg import DenseMatrix +from pyspark.sql.functions import lit +from pyspark.sql.types import StringType +import pytest + + +def test_convert_matrix(spark): + str_list = ['a', 'b'] + df = spark.createDataFrame(str_list, StringType()) + ndarray = np.array([[1.0, 2.1, 3.2], [4.3, 5.4, 6.5]]) + output_rows = df.withColumn("matrix", lit(ndarray)).collect() + expected_matrix = DenseMatrix(2, 3, [1.0, 2.1, 3.2, 4.3, 5.4, 6.5]) + assert(output_rows[0].matrix == expected_matrix) + assert(output_rows[1].matrix == expected_matrix) + + +def test_convert_array(spark): + str_list = ['a', 'b'] + df = spark.createDataFrame(str_list, StringType()) + ndarray = np.array([1.0, 2.1, 3.2]) + output_rows = df.withColumn("array", lit(ndarray)).collect() + expected_array = [1.0, 2.1, 3.2] + assert(output_rows[0].array == expected_array) + assert(output_rows[1].array == expected_array) + + +def test_convert_checks_dimension(spark): + # No support for 3-dimensional arrays + ndarray = np.array([[[1.]]]) + with pytest.raises(Py4JJavaError): + lit(ndarray) + + +def test_convert_matrix_checks_type(spark): + ndarray = np.array([[1, 2], [3, 4]]) + with pytest.raises(AttributeError): + lit(ndarray) + + +def test_convert_array_checks_type(spark): + ndarray = np.array([1, 2]) + with pytest.raises(AttributeError): + lit(ndarray) + + +def test_register_converters_idempotent(spark): + import glow.glow + for _ in range(3): + reload(glow.glow) + one_d_converters = 0 + two_d_converters = 0 + for c in spark._sc._gateway._gateway_client.converters: + if type(c) is OneDimensionalDoubleNumpyArrayConverter: + one_d_converters += 1 + if type(c) is TwoDimensionalDoubleNumpyArrayConverter: + two_d_converters += 1 + assert(one_d_converters == 1) + assert(two_d_converters == 1)