Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 27 additions & 1 deletion python/pyspark/sql/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
from pyspark.sql.dataframe import DataFrame
from pyspark.sql.readwriter import DataFrameReader
from pyspark.sql.streaming import DataStreamReader
from pyspark.sql.types import Row, StringType
from pyspark.sql.types import IntegerType, Row, StringType
from pyspark.sql.utils import install_exception_handler

__all__ = ["SQLContext", "HiveContext", "UDFRegistration"]
Expand Down Expand Up @@ -202,6 +202,32 @@ def registerFunction(self, name, f, returnType=StringType()):
"""
self.sparkSession.catalog.registerFunction(name, f, returnType)

@ignore_unicode_prefix
@since(2.1)
def registerJavaFunction(self, name, javaClassName, returnType=None):
"""Register a java UDF so it can be used in SQL statements.

In addition to a name and the function itself, the return type can be optionally specified.
When the return type is not specified we would infer it via reflection.
:param name: name of the UDF
:param javaClassName: fully qualified name of java class
:param returnType: a :class:`pyspark.sql.types.DataType` object

>>> sqlContext.registerJavaFunction("javaStringLength",
... "test.org.apache.spark.sql.JavaStringLength", IntegerType())
>>> sqlContext.sql("SELECT javaStringLength('test')").collect()
[Row(UDF(test)=4)]
>>> sqlContext.registerJavaFunction("javaStringLength2",
... "test.org.apache.spark.sql.JavaStringLength")
>>> sqlContext.sql("SELECT javaStringLength2('test')").collect()
[Row(UDF(test)=4)]

"""
jdt = None
if returnType is not None:
jdt = self.sparkSession._jsparkSession.parseDataType(returnType.json())
self.sparkSession._jsparkSession.udf().registerJava(name, javaClassName, jdt)

# TODO(andrew): delete this once we refactor things to take in SparkSession
def _inferSchema(self, rdd, samplingRatio=None):
"""
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ object JavaTypeInference {
* @param typeToken Java type
* @return (SQL data type, nullable)
*/
private def inferDataType(typeToken: TypeToken[_]): (DataType, Boolean) = {
private[sql] def inferDataType(typeToken: TypeToken[_]): (DataType, Boolean) = {
typeToken.getRawType match {
case c: Class[_] if c.isAnnotationPresent(classOf[SQLUserDefinedType]) =>
(c.getAnnotation(classOf[SQLUserDefinedType]).udt().newInstance(), true)
Expand Down
75 changes: 73 additions & 2 deletions sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala
Original file line number Diff line number Diff line change
Expand Up @@ -17,19 +17,25 @@

package org.apache.spark.sql

import java.io.IOException
import java.lang.reflect.{ParameterizedType, Type}

import scala.reflect.runtime.universe.TypeTag
import scala.util.Try

import com.google.common.reflect.TypeToken

import org.apache.spark.annotation.InterfaceStability
import org.apache.spark.internal.Logging
import org.apache.spark.sql.api.java._
import org.apache.spark.sql.catalyst.{JavaTypeInference, ScalaReflection}
import org.apache.spark.sql.catalyst.analysis.FunctionRegistry
import org.apache.spark.sql.catalyst.ScalaReflection
import org.apache.spark.sql.catalyst.expressions.{Expression, ScalaUDF}
import org.apache.spark.sql.execution.aggregate.ScalaUDAF
import org.apache.spark.sql.execution.python.UserDefinedPythonFunction
import org.apache.spark.sql.expressions.{UserDefinedAggregateFunction, UserDefinedFunction}
import org.apache.spark.sql.types.DataType
import org.apache.spark.sql.types.{DataType, DataTypes}
import org.apache.spark.util.Utils

/**
* Functions for registering user-defined functions. Use [[SQLContext.udf]] to access this.
Expand Down Expand Up @@ -413,6 +419,71 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends
//////////////////////////////////////////////////////////////////////////////////////////////
//////////////////////////////////////////////////////////////////////////////////////////////

/**
Copy link
Contributor

@marmbrus marmbrus Oct 12, 2016

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It would be nice to turn style back on here since most of this function is not auto generated.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I can turn it on, but it would make the function less readable, especially for the following statements where it beyond line length limitation.

case 14 => register(name, udf.asInstanceOf[UDF13[_, _, _, _, _, _, _, _, _, _, _, _, _, _]], returnType)

* Register a Java UDF class using reflection, for use from pyspark
*
* @param name udf name
* @param className fully qualified class name of udf
* @param returnDataType return type of udf. If it is null, spark would try to infer
* via reflection.
*/
private[sql] def registerJava(name: String, className: String, returnDataType: DataType): Unit = {

try {
val clazz = Utils.classForName(className)
val udfInterfaces = clazz.getGenericInterfaces
.filter(_.isInstanceOf[ParameterizedType])
.map(_.asInstanceOf[ParameterizedType])
.filter(e => e.getRawType.isInstanceOf[Class[_]] && e.getRawType.asInstanceOf[Class[_]].getCanonicalName.startsWith("org.apache.spark.sql.api.java.UDF"))
if (udfInterfaces.length == 0) {
throw new IOException(s"UDF class ${className} doesn't implement any UDF interface")
} else if (udfInterfaces.length > 1) {
throw new IOException(s"It is invalid to implement multiple UDF interfaces, UDF class ${className}")
} else {
try {
val udf = clazz.newInstance()
val udfReturnType = udfInterfaces(0).getActualTypeArguments.last
var returnType = returnDataType
if (returnType == null) {
returnType = JavaTypeInference.inferDataType(TypeToken.of(udfReturnType))._1
}

udfInterfaces(0).getActualTypeArguments.length match {
case 2 => register(name, udf.asInstanceOf[UDF1[_, _]], returnType)
case 3 => register(name, udf.asInstanceOf[UDF2[_, _, _]], returnType)
case 4 => register(name, udf.asInstanceOf[UDF3[_, _, _, _]], returnType)
case 5 => register(name, udf.asInstanceOf[UDF4[_, _, _, _, _]], returnType)
case 6 => register(name, udf.asInstanceOf[UDF5[_, _, _, _, _, _]], returnType)
case 7 => register(name, udf.asInstanceOf[UDF6[_, _, _, _, _, _, _]], returnType)
case 8 => register(name, udf.asInstanceOf[UDF7[_, _, _, _, _, _, _, _]], returnType)
case 9 => register(name, udf.asInstanceOf[UDF8[_, _, _, _, _, _, _, _, _]], returnType)
case 10 => register(name, udf.asInstanceOf[UDF9[_, _, _, _, _, _, _, _, _, _]], returnType)
case 11 => register(name, udf.asInstanceOf[UDF10[_, _, _, _, _, _, _, _, _, _, _]], returnType)
case 12 => register(name, udf.asInstanceOf[UDF11[_, _, _, _, _, _, _, _, _, _, _, _]], returnType)
case 13 => register(name, udf.asInstanceOf[UDF12[_, _, _, _, _, _, _, _, _, _, _, _, _]], returnType)
case 14 => register(name, udf.asInstanceOf[UDF13[_, _, _, _, _, _, _, _, _, _, _, _, _, _]], returnType)
case 15 => register(name, udf.asInstanceOf[UDF14[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _]], returnType)
case 16 => register(name, udf.asInstanceOf[UDF15[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _]], returnType)
case 17 => register(name, udf.asInstanceOf[UDF16[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _]], returnType)
case 18 => register(name, udf.asInstanceOf[UDF17[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _]], returnType)
case 19 => register(name, udf.asInstanceOf[UDF18[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _]], returnType)
case 20 => register(name, udf.asInstanceOf[UDF19[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _]], returnType)
case 21 => register(name, udf.asInstanceOf[UDF20[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _]], returnType)
case 22 => register(name, udf.asInstanceOf[UDF21[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _]], returnType)
case 23 => register(name, udf.asInstanceOf[UDF22[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _]], returnType)
case n => logError(s"UDF class with ${n} type arguments is not supported ")
}
} catch {
case e @ (_: InstantiationException | _: IllegalArgumentException) =>
logError(s"Can not instantiate class ${className}, please make sure it has public non argument constructor")
}
}
} catch {
case e: ClassNotFoundException => logError(s"Can not load class ${className}, please make sure it is on the classpath")
}

}

/**
* Register a user-defined function with 1 arguments.
* @since 1.3.0
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package test.org.apache.spark.sql;

import org.apache.spark.sql.api.java.UDF1;

/**
* It is used for register Java UDF from PySpark
*/
public class JavaStringLength implements UDF1<String, Integer> {
@Override
public Integer call(String str) throws Exception {
return new Integer(str.length());
}
}
21 changes: 21 additions & 0 deletions sql/core/src/test/java/test/org/apache/spark/sql/JavaUDFSuite.java
Original file line number Diff line number Diff line change
Expand Up @@ -87,4 +87,25 @@ public Integer call(String str1, String str2) {
Row result = spark.sql("SELECT stringLengthTest('test', 'test2')").head();
Assert.assertEquals(9, result.getInt(0));
}

public static class StringLengthTest implements UDF2<String, String, Integer> {
@Override
public Integer call(String str1, String str2) throws Exception {
return new Integer(str1.length() + str2.length());
}
}

@SuppressWarnings("unchecked")
@Test
public void udf3Test() {
spark.udf().registerJava("stringLengthTest", StringLengthTest.class.getName(),
DataTypes.IntegerType);
Row result = spark.sql("SELECT stringLengthTest('test', 'test2')").head();
Assert.assertEquals(9, result.getInt(0));

// returnType is not provided
spark.udf().registerJava("stringLengthTest2", StringLengthTest.class.getName(), null);
result = spark.sql("SELECT stringLengthTest('test', 'test2')").head();
Assert.assertEquals(9, result.getInt(0));
}
}