-
Notifications
You must be signed in to change notification settings - Fork 29k
[SPARK-11775][PYSPARK][SQL] Allow PySpark to register Java UDF #9766
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 4 commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,30 @@ | ||
| /* | ||
| * Licensed to the Apache Software Foundation (ASF) under one or more | ||
| * contributor license agreements. See the NOTICE file distributed with | ||
| * this work for additional information regarding copyright ownership. | ||
| * The ASF licenses this file to You under the Apache License, Version 2.0 | ||
| * (the "License"); you may not use this file except in compliance with | ||
| * the License. You may obtain a copy of the License at | ||
| * | ||
| * http://www.apache.org/licenses/LICENSE-2.0 | ||
| * | ||
| * Unless required by applicable law or agreed to in writing, software | ||
| * distributed under the License is distributed on an "AS IS" BASIS, | ||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| * See the License for the specific language governing permissions and | ||
| * limitations under the License. | ||
| */ | ||
|
|
||
| package org.apache.spark.sql.test; | ||
|
|
||
| import org.apache.spark.sql.api.java.UDF1; | ||
|
|
||
| /** | ||
| * It is used for registering a Java UDF from PySpark. | ||
| */ | ||
| public class JavaStringLength implements UDF1<String, Integer> { | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Could this be moved to
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Fixed |
||
| /** | ||
| * Returns the length of the given string, boxed as an Integer. | ||
| * | ||
| * @param str the input string | ||
| * @return the number of chars in str | ||
| */ | ||
| @Override | ||
| public Integer call(String str) throws Exception { | ||
| // Integer.valueOf replaces the deprecated `new Integer(int)` constructor and can reuse cached instances. | ||
| return Integer.valueOf(str.length()); | ||
| } | ||
| } | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -17,6 +17,9 @@ | |
|
|
||
| package org.apache.spark.sql | ||
|
|
||
| import java.io.IOException | ||
| import java.lang.reflect.{ParameterizedType, Type} | ||
|
|
||
| import scala.reflect.runtime.universe.TypeTag | ||
| import scala.util.Try | ||
|
|
||
|
|
@@ -29,7 +32,8 @@ import org.apache.spark.sql.catalyst.expressions.{Expression, ScalaUDF} | |
| import org.apache.spark.sql.execution.aggregate.ScalaUDAF | ||
| import org.apache.spark.sql.execution.python.UserDefinedPythonFunction | ||
| import org.apache.spark.sql.expressions.{UserDefinedAggregateFunction, UserDefinedFunction} | ||
| import org.apache.spark.sql.types.DataType | ||
| import org.apache.spark.sql.types.{DataType, DataTypes} | ||
| import org.apache.spark.util.Utils | ||
|
|
||
| /** | ||
| * Functions for registering user-defined functions. Use [[SQLContext.udf]] to access this. | ||
|
|
@@ -413,6 +417,84 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends | |
| ////////////////////////////////////////////////////////////////////////////////////////////// | ||
| ////////////////////////////////////////////////////////////////////////////////////////////// | ||
|
|
||
| /** | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. It would be nice to turn style back on here since most of this function is not auto generated.
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I can turn it on, but it would make the function less readable, especially for the following statements where it beyond line length limitation. |
||
| * Register a Java UDF class using reflection, for use from pyspark | ||
| * | ||
| * @param name udf name | ||
| * @param className fully qualified class name of udf | ||
| * @param returnDataType return type of udf. When it is null, Spark attempts to infer it | ||
| * using reflection. | ||
| */ | ||
| def registerJava(name: String, className: String, returnDataType: DataType): Unit = { | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Is it possible to make this non-public? I believe we do this in other cases for code only called from python.
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Fixed |
||
|
|
||
| // Load the UDF class reflectively, find the single Spark Java UDF interface it implements, | ||
| // resolve the return type (explicit or inferred), and register it under `name`. | ||
| // Class-loading and instantiation failures are logged; an invalid UDF class throws. | ||
| try { | ||
| val clazz = Utils.classForName(className) | ||
| // Keep only parameterized interfaces whose raw type is one of org.apache.spark.sql.api.java.UDF*. | ||
| val udfInterfaces = clazz.getGenericInterfaces | ||
| .collect { case p: ParameterizedType => p } | ||
| .filter(e => e.getRawType.isInstanceOf[Class[_]] && e.getRawType.asInstanceOf[Class[_]].getCanonicalName.startsWith("org.apache.spark.sql.api.java.UDF")) | ||
| if (udfInterfaces.length == 0) { | ||
| throw new IOException(s"UDF class ${className} doesn't implement any UDF interface") | ||
| } else if (udfInterfaces.length > 1) { | ||
| throw new IOException(s"It is invalid to implement multiple UDF interfaces, UDF class ${className}") | ||
| } else { | ||
| try { | ||
| val udf = clazz.newInstance() | ||
| // The last type argument of the UDF interface is the declared return type. | ||
| val udfReturnType = udfInterfaces(0).getActualTypeArguments.last | ||
| // An explicitly supplied returnDataType wins; otherwise infer it from the boxed Java class. | ||
| val returnType = | ||
| if (returnDataType != null) { | ||
| returnDataType | ||
| } else { | ||
| udfReturnType match { | ||
| case c: Class[_] => c.getCanonicalName match { | ||
| case "java.lang.String" => DataTypes.StringType | ||
| case "java.lang.Boolean" => DataTypes.BooleanType | ||
| case "java.lang.Double" => DataTypes.DoubleType | ||
| case "java.lang.Float" => DataTypes.FloatType | ||
| case "java.lang.Byte" => DataTypes.ByteType | ||
| case "java.lang.Integer" => DataTypes.IntegerType | ||
| case "java.lang.Long" => DataTypes.LongType | ||
| case "java.lang.Short" => DataTypes.ShortType | ||
| case _ => throw new RuntimeException(s"Can not infer the return type: ${udfReturnType}, please declare returnType explicitly.") | ||
| } | ||
| case _ => | ||
| throw new RuntimeException("The return type of UDF is not valid, returnType:" + udfReturnType) | ||
| } | ||
| } | ||
|
|
||
| // The number of type arguments is the UDF arity plus one (the return type). | ||
| udfInterfaces(0).getActualTypeArguments.length match { | ||
| case 2 => register(name, udf.asInstanceOf[UDF1[_, _]], returnType) | ||
| case 3 => register(name, udf.asInstanceOf[UDF2[_, _, _]], returnType) | ||
| case 4 => register(name, udf.asInstanceOf[UDF3[_, _, _, _]], returnType) | ||
| case 5 => register(name, udf.asInstanceOf[UDF4[_, _, _, _, _]], returnType) | ||
| case 6 => register(name, udf.asInstanceOf[UDF5[_, _, _, _, _, _]], returnType) | ||
| case 7 => register(name, udf.asInstanceOf[UDF6[_, _, _, _, _, _, _]], returnType) | ||
| case 8 => register(name, udf.asInstanceOf[UDF7[_, _, _, _, _, _, _, _]], returnType) | ||
| case 9 => register(name, udf.asInstanceOf[UDF8[_, _, _, _, _, _, _, _, _]], returnType) | ||
| case 10 => register(name, udf.asInstanceOf[UDF9[_, _, _, _, _, _, _, _, _, _]], returnType) | ||
| case 11 => register(name, udf.asInstanceOf[UDF10[_, _, _, _, _, _, _, _, _, _, _]], returnType) | ||
| case 12 => register(name, udf.asInstanceOf[UDF11[_, _, _, _, _, _, _, _, _, _, _, _]], returnType) | ||
| case 13 => register(name, udf.asInstanceOf[UDF12[_, _, _, _, _, _, _, _, _, _, _, _, _]], returnType) | ||
| case 14 => register(name, udf.asInstanceOf[UDF13[_, _, _, _, _, _, _, _, _, _, _, _, _, _]], returnType) | ||
| case 15 => register(name, udf.asInstanceOf[UDF14[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _]], returnType) | ||
| case 16 => register(name, udf.asInstanceOf[UDF15[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _]], returnType) | ||
| case 17 => register(name, udf.asInstanceOf[UDF16[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _]], returnType) | ||
| case 18 => register(name, udf.asInstanceOf[UDF17[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _]], returnType) | ||
| case 19 => register(name, udf.asInstanceOf[UDF18[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _]], returnType) | ||
| case 20 => register(name, udf.asInstanceOf[UDF19[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _]], returnType) | ||
| case 21 => register(name, udf.asInstanceOf[UDF20[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _]], returnType) | ||
| case 22 => register(name, udf.asInstanceOf[UDF21[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _]], returnType) | ||
| case 23 => register(name, udf.asInstanceOf[UDF22[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _]], returnType) | ||
| case n => logError(s"UDF class with ${n} type arguments is not supported ") | ||
| } | ||
| } catch { | ||
| // Class.newInstance throws IllegalAccessException when the no-arg constructor is not accessible. | ||
| case _: InstantiationException | _: IllegalAccessException | _: IllegalArgumentException => | ||
| logError(s"Can not instantiate class ${className}, please make sure it has public non argument constructor") | ||
| } | ||
| } | ||
| } catch { | ||
| case e: ClassNotFoundException => logError(s"Can not load class ${className}, please make sure it is on the classpath") | ||
| } | ||
|
|
||
| } | ||
|
|
||
| /** | ||
| * Register a user-defined function with 1 arguments. | ||
| * @since 1.3.0 | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nit: it's a little odd to mix `return type` with `returnType`. Perhaps, "When the return type is not specified we attempt to infer it using reflection"
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Fixed