From 6d5549c625d6327f1537be7228f90ed2d2a351ed Mon Sep 17 00:00:00 2001
From: Frank Yin
Date: Mon, 27 Jan 2025 16:22:42 -0800
Subject: [PATCH] fix MapType not using the frameless injection properly
 (#421)

---
 .../scalapb/spark/FromCatalystHelpers.scala  |  2 +-
 .../src/test/protobuf/customizations.proto   |  4 ++++
 .../test/{assets => resources}/address.json  |  0
 .../person_null_repeated.json                |  0
 .../src/test/scala/PersonSpec.scala          |  4 ++--
 .../src/test/scala/TimestampSpec.scala       | 18 +++++++++++++++++-
 6 files changed, 24 insertions(+), 4 deletions(-)
 rename sparksql-scalapb/src/test/{assets => resources}/address.json (100%)
 rename sparksql-scalapb/src/test/{assets => resources}/person_null_repeated.json (100%)

diff --git a/sparksql-scalapb/src/main/scala/scalapb/spark/FromCatalystHelpers.scala b/sparksql-scalapb/src/main/scala/scalapb/spark/FromCatalystHelpers.scala
index 944efcb..0fffcda 100644
--- a/sparksql-scalapb/src/main/scala/scalapb/spark/FromCatalystHelpers.scala
+++ b/sparksql-scalapb/src/main/scala/scalapb/spark/FromCatalystHelpers.scala
@@ -97,7 +97,7 @@ trait FromCatalystHelpers {
       input,
       (in: Expression) => singleFieldValueFromCatalyst(mapEntryCmp, keyDesc, in),
       (in: Expression) => singleFieldValueFromCatalyst(mapEntryCmp, valDesc, in),
-      ProtoSQL.dataTypeFor(fd).asInstanceOf[MapType],
+      protoSql.dataTypeFor(fd).asInstanceOf[MapType],
       classOf[Vector[(Any, Any)]]
     )
     val objs = MyCatalystToExternalMap(urobjs)
diff --git a/sparksql-scalapb/src/test/protobuf/customizations.proto b/sparksql-scalapb/src/test/protobuf/customizations.proto
index 5e3badc..f5911f0 100644
--- a/sparksql-scalapb/src/test/protobuf/customizations.proto
+++ b/sparksql-scalapb/src/test/protobuf/customizations.proto
@@ -22,3 +22,7 @@ message BothTimestampTypes {
   google.protobuf.Timestamp google_ts = 1;
   google.protobuf.Timestamp google_ts_as_sql_ts = 2 [(scalapb.field).type = "java.sql.Timestamp"];
 }
+
+message TimestampTypesMap {
+  map<string, SQLTimestampFromGoogleTimestamp> map_field = 1;
+}
diff --git a/sparksql-scalapb/src/test/assets/address.json b/sparksql-scalapb/src/test/resources/address.json
similarity index 100%
rename from sparksql-scalapb/src/test/assets/address.json
rename to sparksql-scalapb/src/test/resources/address.json
diff --git a/sparksql-scalapb/src/test/assets/person_null_repeated.json b/sparksql-scalapb/src/test/resources/person_null_repeated.json
similarity index 100%
rename from sparksql-scalapb/src/test/assets/person_null_repeated.json
rename to sparksql-scalapb/src/test/resources/person_null_repeated.json
diff --git a/sparksql-scalapb/src/test/scala/PersonSpec.scala b/sparksql-scalapb/src/test/scala/PersonSpec.scala
index 3c6a3ab..96c021c 100644
--- a/sparksql-scalapb/src/test/scala/PersonSpec.scala
+++ b/sparksql-scalapb/src/test/scala/PersonSpec.scala
@@ -310,7 +310,7 @@ class PersonSpec extends AnyFlatSpec with Matchers with BeforeAndAfterAll {
   }
 
   "UDFs that returns protos" should "work when reading local files" in {
-    val df = spark.read.json("./sparksql-scalapb/src/test/assets/address.json")
+    val df = spark.read.json(getClass.getResource("/address.json").toURI.toString)
     val returnAddress = ProtoSQL.udf { s: String => Address() }
@@ -349,7 +349,7 @@
   "parsing null repeated from json" should "work" in {
     spark.read
       .schema(ProtoSQL.schemaFor[Person].asInstanceOf[types.StructType])
-      .json("./sparksql-scalapb/src/test/assets/person_null_repeated.json")
+      .json(getClass.getResource("/person_null_repeated.json").toURI.toString)
       .as[Person]
       .collect() must contain theSameElementsAs Seq(
       Person().withTags(Seq("foo", "bar")),
diff --git a/sparksql-scalapb/src/test/scala/TimestampSpec.scala b/sparksql-scalapb/src/test/scala/TimestampSpec.scala
index 776604f..cadf9a2 100644
--- a/sparksql-scalapb/src/test/scala/TimestampSpec.scala
+++ b/sparksql-scalapb/src/test/scala/TimestampSpec.scala
@@ -8,7 +8,8 @@
 import org.scalatest.matchers.must.Matchers
 import scalapb.spark.test3.customizations.{
   BothTimestampTypes,
   SQLTimestampFromGoogleTimestamp,
-  StructFromGoogleTimestamp
+  StructFromGoogleTimestamp,
+  TimestampTypesMap
 }
 import java.sql.{Timestamp => SQLTimestamp}
@@ -158,6 +159,21 @@ class TimestampSpec extends AnyFlatSpec with Matchers with BeforeAndAfterAll {
     )
   }
 
+  "spark.createDataset from proto messages with spark timestamp in map" should "be able to convert items with correct timestamp values" in {
+    import ProtoSQL.withSparkTimestamps.implicits._
+
+    val value = TimestampTypesMap(mapField =
+      Map(
+        "a" -> SQLTimestampFromGoogleTimestamp(googleTsAsSqlTs = Some(sqlTimestampMicrosPrecision))
+      )
+    )
+    val ds: Dataset[TimestampTypesMap] = spark.createDataset(Seq(value))
+
+    ds.collect() must contain theSameElementsAs Seq(
+      value
+    )
+  }
+
   "df with case class timestamp as well as both types of google timestamp" should "not have StructType for timestamps" in {
     import ProtoSQL.withSparkTimestamps.implicits._
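
Why the one-line change in FromCatalystHelpers matters: the trait builds the
deserializer for a protobuf map field, and the MapType it hands to Catalyst
must come from the same ProtoSQL instance whose implicits are in scope.
Calling dataTypeFor on the ProtoSQL companion object always yields the
default type mappings, so under a customized instance such as
ProtoSQL.withSparkTimestamps (where a google.protobuf.Timestamp field
annotated as java.sql.Timestamp maps to Catalyst TimestampType instead of a
struct) the declared MapType disagreed with the frameless injection applied
to the map values. The sketch below is a minimal reproduction in the spirit
of the new TimestampTypesMap test, not part of the patch; it assumes the
generated test3 classes from this repository, a locally built SparkSession,
and an illustrative timestamp literal.

  import org.apache.spark.sql.{Dataset, SparkSession}
  import scalapb.spark.ProtoSQL
  import scalapb.spark.test3.customizations.{
    SQLTimestampFromGoogleTimestamp,
    TimestampTypesMap
  }

  val spark = SparkSession.builder().master("local[1]").getOrCreate()

  // Customized instance: timestamp fields annotated with
  // [(scalapb.field).type = "java.sql.Timestamp"] are encoded as
  // Catalyst TimestampType via a frameless injection.
  import ProtoSQL.withSparkTimestamps.implicits._

  val value = TimestampTypesMap(mapField =
    Map("a" -> SQLTimestampFromGoogleTimestamp(
      googleTsAsSqlTs = Some(new java.sql.Timestamp(1738023762000L))
    ))
  )

  // Round-trip through a Dataset. Before this patch, the map field was
  // deserialized against the companion object's default MapType, so the
  // injection for the java.sql.Timestamp map values was not applied
  // consistently on the way back out.
  val ds: Dataset[TimestampTypesMap] = spark.createDataset(Seq(value))
  assert(ds.collect().toSeq == Seq(value))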