Skip to content
Closed
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -18,26 +18,28 @@
package org.apache.spark.sql.hive

import scala.collection.JavaConverters._
import scala.util.Random

import org.apache.hadoop.hive.ql.udf.generic.GenericUDF
import org.apache.hadoop.hive.serde2.objectinspector.{ObjectInspector, StandardListObjectInspector}
import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory

import org.apache.spark.sql.{QueryTest, RandomDataGenerator, Row}
import org.apache.spark.sql.{QueryTest, Row}
import org.apache.spark.sql.catalyst.FunctionIdentifier
import org.apache.spark.sql.hive.test.TestHiveSingleton
import org.apache.spark.sql.test.{ExamplePoint, ExamplePointUDT}
import org.apache.spark.sql.types.StructType


class HiveUserDefinedTypeSuite extends QueryTest with TestHiveSingleton {
private val functionClass = classOf[org.apache.spark.sql.hive.TestUDF].getCanonicalName

test("Support UDT in Hive UDF") {
val rand = new Random
val functionName = "get_point_x"
try {
val schema = new StructType().add("point", new ExamplePointUDT, nullable = false)
val inputGenerator = RandomDataGenerator.forType(schema, nullable = false).get
val input = inputGenerator.apply().asInstanceOf[Row]
val input = Row.fromSeq(Seq(new ExamplePoint(rand.nextDouble(), rand.nextDouble())))
val df = spark.createDataFrame(Array(input).toList.asJava, schema)
df.createOrReplaceTempView("src")
spark.sql(s"CREATE FUNCTION $functionName AS '$functionClass'")
Expand Down