Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion core/src/main/scala/io/projectglow/functions.scala
Original file line number Diff line number Diff line change
Expand Up @@ -217,7 +217,7 @@ object functions {
}

/**
* Sustitutes the missing values of a numeric array using the mean of the non-missing values. Any values that are NaN, null or equal to the missing value parameter are considered missing. See :ref:`variant-data-transformations` for more details.
* Substitutes the missing values of a numeric array using the mean of the non-missing values. Any values that are NaN, null or equal to the missing value parameter are considered missing. See :ref:`variant-data-transformations` for more details.
* @group etl
* @since 0.4.0
*
Expand Down
2 changes: 1 addition & 1 deletion functions.yml
Original file line number Diff line number Diff line change
Expand Up @@ -245,7 +245,7 @@ etl:
type: str
returns: A struct as explained above
- name: mean_substitute
doc: Sustitutes the missing values of a numeric array using the mean of the non-missing values. Any values that
doc: Substitutes the missing values of a numeric array using the mean of the non-missing values. Any values that
are NaN, null or equal to the missing value parameter are considered missing. See
:ref:`variant-data-transformations` for more details.
since: 0.4.0
Expand Down
86 changes: 86 additions & 0 deletions python/glow/conversions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
import numpy as np
from py4j.java_collections import JavaArray
from pyspark import SparkContext
from typeguard import check_argument_types, check_return_type


def _is_numpy_double_array(object, dimensions: int) -> bool:
assert check_argument_types()
output = isinstance(object, np.ndarray) and len(object.shape) == dimensions and object.dtype.type == np.double
assert check_return_type(output)
return output


def _convert_numpy_to_java_array(np_arr: np.ndarray) -> JavaArray:
"""
Converts a flat numpy array of doubles to a Java array of doubles.
"""
assert check_argument_types()
assert len(np_arr.shape) == 1
assert np_arr.dtype.type == np.double

sc = SparkContext._active_spark_context
java_arr = sc._gateway.new_array(sc._jvm.double, np_arr.shape[0])
for idx, ele in enumerate(np_arr):
java_arr[idx] = ele.item()

assert check_return_type(java_arr)
return java_arr


class OneDimensionalDoubleNumpyArrayConverter(object):
"""
Replaces any 1-dimensional numpy array of doubles with a literal Java array.

Added in version 0.4.0.

Examples:
>>> import numpy as np
>>> from pyspark.sql.functions import lit
>>> from pyspark.sql.types import StringType
>>> str_list = ['a', 'b']
>>> df = spark.createDataFrame(str_list, StringType())
>>> ndarray = np.array([1.0, 2.1, 3.2])
>>> df.withColumn("array", lit(ndarray)).collect()
[Row(value='a', array=[1.0, 2.1, 3.2]), Row(value='b', array=[1.0, 2.1, 3.2])]
"""

def can_convert(self, object):
return _is_numpy_double_array(object, dimensions = 1)

def convert(self, object, gateway_client):
sc = SparkContext._active_spark_context
java_arr = _convert_numpy_to_java_array(object)
return java_arr


class TwoDimensionalDoubleNumpyArrayConverter(object):
"""
Replaces any 2-dimensional numpy array of doubles with a literal DenseMatrix.

Added in version 0.4.0.

Examples:
>>> import numpy as np
>>> from pyspark.sql.functions import lit
>>> from pyspark.sql.types import StringType
>>> str_list = ['a', 'b']
>>> df = spark.createDataFrame(str_list, StringType())
>>> ndarray = np.array([[1.0, 2.1, 3.2], [4.3, 5.4, 6.5]])
>>> df.withColumn("matrix", lit(ndarray)).collect()
[Row(value='a', matrix=DenseMatrix(2, 3, [1.0, 2.1, 3.2, 4.3, 5.4, 6.5], False)), Row(value='b', matrix=DenseMatrix(2, 3, [1.0, 2.1, 3.2, 4.3, 5.4, 6.5], False))]
"""

def can_convert(self, object):
return _is_numpy_double_array(object, dimensions = 2)

def convert(self, object, gateway_client):
sc = SparkContext._active_spark_context
flat_arr = object.ravel()
java_arr = _convert_numpy_to_java_array(flat_arr)
dense_matrix = sc._jvm.org.apache.spark.ml.linalg.DenseMatrix(object.shape[0], object.shape[1], java_arr)
matrix_udt = sc._jvm.org.apache.spark.ml.linalg.MatrixUDT()
converter = sc._jvm.org.apache.spark.sql.catalyst.CatalystTypeConverters.createToCatalystConverter(matrix_udt)
literal_matrix = sc._jvm.org.apache.spark.sql.catalyst.expressions.Literal.create(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hm, it seems a little weird that we explicitly wrap in a Literal here. Does that mean that this conversion doesn't work with spark.createDataFrame? Not sure I would expect that as a user.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Unfortunately the PySpark pathways for createDataFrame and lit are very different. Passing this as a Literal allows us to bypass the usual type-checking and conversion logic.

converter.apply(dense_matrix), matrix_udt)
return literal_matrix
2 changes: 1 addition & 1 deletion python/glow/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -318,7 +318,7 @@ def normalize_variant(contigName: Union[Column, str], start: Union[Column, str],

def mean_substitute(array: Union[Column, str], missingValue: Union[Column, str] = None) -> Column:
"""
Sustitutes the missing values of a numeric array using the mean of the non-missing values. Any values that are NaN, null or equal to the missing value parameter are considered missing. See :ref:`variant-data-transformations` for more details.
Substitutes the missing values of a numeric array using the mean of the non-missing values. Any values that are NaN, null or equal to the missing value parameter are considered missing. See :ref:`variant-data-transformations` for more details.

Added in version 0.4.0.

Expand Down
11 changes: 10 additions & 1 deletion python/glow/glow.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
from glow.conversions import OneDimensionalDoubleNumpyArrayConverter, TwoDimensionalDoubleNumpyArrayConverter
from py4j import protocol
from py4j.protocol import register_input_converter
from pyspark import SparkContext
from pyspark.sql import DataFrame, SQLContext, SparkSession
from typing import Dict
Expand Down Expand Up @@ -40,7 +43,7 @@ def transform(operation: str, df: DataFrame, arg_map: Dict[str, str]=None,

def register(session: SparkSession):
"""
Register SQL extensions for a Spark session.
Register SQL extensions and py4j converters for a Spark session.

Args:
session: Spark session
Expand All @@ -51,3 +54,9 @@ def register(session: SparkSession):
"""
assert check_argument_types()
session._jvm.io.projectglow.Glow.register(session._jsparkSession)

# Register input converters in idempotent fashion
glow_input_converters = [OneDimensionalDoubleNumpyArrayConverter, TwoDimensionalDoubleNumpyArrayConverter]
for gic in glow_input_converters:
if not any(type(pic) is gic for pic in protocol.INPUT_CONVERTER):
register_input_converter(gic(), prepend = True)
62 changes: 62 additions & 0 deletions python/glow/tests/test_conversions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
from glow.conversions import OneDimensionalDoubleNumpyArrayConverter, TwoDimensionalDoubleNumpyArrayConverter
from importlib import reload
import numpy as np
from py4j.protocol import Py4JJavaError
from pyspark.ml.linalg import DenseMatrix
from pyspark.sql.functions import lit
from pyspark.sql.types import StringType
import pytest


def test_convert_matrix(spark):
str_list = ['a', 'b']
df = spark.createDataFrame(str_list, StringType())
ndarray = np.array([[1.0, 2.1, 3.2], [4.3, 5.4, 6.5]])
output_rows = df.withColumn("matrix", lit(ndarray)).collect()
expected_matrix = DenseMatrix(2, 3, [1.0, 2.1, 3.2, 4.3, 5.4, 6.5])
assert(output_rows[0].matrix == expected_matrix)
assert(output_rows[1].matrix == expected_matrix)


def test_convert_array(spark):
str_list = ['a', 'b']
df = spark.createDataFrame(str_list, StringType())
ndarray = np.array([1.0, 2.1, 3.2])
output_rows = df.withColumn("array", lit(ndarray)).collect()
expected_array = [1.0, 2.1, 3.2]
assert(output_rows[0].array == expected_array)
assert(output_rows[1].array == expected_array)


def test_convert_checks_dimension(spark):
# No support for 3-dimensional arrays
ndarray = np.array([[[1.]]])
with pytest.raises(Py4JJavaError):
lit(ndarray)


def test_convert_matrix_checks_type(spark):
ndarray = np.array([[1, 2], [3, 4]])
with pytest.raises(AttributeError):
lit(ndarray)


def test_convert_array_checks_type(spark):
ndarray = np.array([1, 2])
with pytest.raises(AttributeError):
lit(ndarray)


def test_register_converters_idempotent(spark):
import glow.glow
for _ in range(3):
reload(glow.glow)
one_d_converters = 0
two_d_converters = 0
for c in spark._sc._gateway._gateway_client.converters:
if type(c) is OneDimensionalDoubleNumpyArrayConverter:
one_d_converters += 1
if type(c) is TwoDimensionalDoubleNumpyArrayConverter:
two_d_converters += 1
assert(one_d_converters == 1)
assert(two_d_converters == 1)