110 changes: 56 additions & 54 deletions python/pyspark/sql/connect/proto/types_pb2.py

Large diffs are not rendered by default.

43 changes: 43 additions & 0 deletions python/pyspark/sql/connect/proto/types_pb2.pyi
@@ -279,6 +279,40 @@ class DataType(google.protobuf.message.Message):
],
) -> None: ...

class Time(google.protobuf.message.Message):
DESCRIPTOR: google.protobuf.descriptor.Descriptor

PRECISION_FIELD_NUMBER: builtins.int
TYPE_VARIATION_REFERENCE_FIELD_NUMBER: builtins.int
precision: builtins.int
type_variation_reference: builtins.int
def __init__(
self,
*,
precision: builtins.int | None = ...,
type_variation_reference: builtins.int = ...,
) -> None: ...
def HasField(
self,
field_name: typing_extensions.Literal[
"_precision", b"_precision", "precision", b"precision"
],
) -> builtins.bool: ...
def ClearField(
self,
field_name: typing_extensions.Literal[
"_precision",
b"_precision",
"precision",
b"precision",
"type_variation_reference",
b"type_variation_reference",
],
) -> None: ...
def WhichOneof(
self, oneof_group: typing_extensions.Literal["_precision", b"_precision"]
) -> typing_extensions.Literal["precision"] | None: ...

class CalendarInterval(google.protobuf.message.Message):
DESCRIPTOR: google.protobuf.descriptor.Descriptor

@@ -788,6 +822,7 @@ class DataType(google.protobuf.message.Message):
VARIANT_FIELD_NUMBER: builtins.int
UDT_FIELD_NUMBER: builtins.int
UNPARSED_FIELD_NUMBER: builtins.int
TIME_FIELD_NUMBER: builtins.int
@property
def null(self) -> global___DataType.NULL: ...
@property
@@ -845,6 +880,8 @@ class DataType(google.protobuf.message.Message):
@property
def unparsed(self) -> global___DataType.Unparsed:
"""UnparsedDataType"""
@property
def time(self) -> global___DataType.Time: ...
def __init__(
self,
*,
@@ -873,6 +910,7 @@ class DataType(google.protobuf.message.Message):
variant: global___DataType.Variant | None = ...,
udt: global___DataType.UDT | None = ...,
unparsed: global___DataType.Unparsed | None = ...,
time: global___DataType.Time | None = ...,
) -> None: ...
def HasField(
self,
@@ -915,6 +953,8 @@ class DataType(google.protobuf.message.Message):
b"string",
"struct",
b"struct",
"time",
b"time",
"timestamp",
b"timestamp",
"timestamp_ntz",
@@ -972,6 +1012,8 @@ class DataType(google.protobuf.message.Message):
b"string",
"struct",
b"struct",
"time",
b"time",
"timestamp",
b"timestamp",
"timestamp_ntz",
@@ -1017,6 +1059,7 @@ class DataType(google.protobuf.message.Message):
"variant",
"udt",
"unparsed",
"time",
]
| None
): ...
@@ -135,4 +135,8 @@ object TimeFormatter {
def apply(isParsing: Boolean): TimeFormatter = {
getFormatter(None, defaultLocale, isParsing)
}

def getFractionFormatter(): TimeFormatter = {
new FractionTimeFormatter()
}
}
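
For context, a minimal sketch of how this fraction formatter is used by the Arrow reader further down (the format(Long) overload taking nanoseconds of day is assumed from that call site):

  import org.apache.spark.sql.catalyst.util.TimeFormatter

  // Hedged sketch: render a nanosecond-of-day value as a time string
  val formatter = TimeFormatter.getFractionFormatter()
  formatter.format(3723123456789L)  // e.g. "01:02:03.123456789"
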
@@ -58,6 +58,7 @@ private[sql] object ArrowUtils {
case TimestampType => new ArrowType.Timestamp(TimeUnit.MICROSECOND, timeZoneId)
case TimestampNTZType =>
new ArrowType.Timestamp(TimeUnit.MICROSECOND, null)
case _: TimeType => new ArrowType.Time(TimeUnit.NANOSECOND, 8 * 8)
case NullType => ArrowType.Null.INSTANCE
case _: YearMonthIntervalType => new ArrowType.Interval(IntervalUnit.YEAR_MONTH)
case _: DayTimeIntervalType => new ArrowType.Duration(TimeUnit.MICROSECOND)
@@ -88,6 +89,8 @@ private[sql] object ArrowUtils {
if ts.getUnit == TimeUnit.MICROSECOND && ts.getTimezone == null =>
TimestampNTZType
case ts: ArrowType.Timestamp if ts.getUnit == TimeUnit.MICROSECOND => TimestampType
case t: ArrowType.Time if t.getUnit == TimeUnit.NANOSECOND =>
TimeType(TimeType.MICROS_PRECISION)
case ArrowType.Null.INSTANCE => NullType
case yi: ArrowType.Interval if yi.getUnit == IntervalUnit.YEAR_MONTH =>
YearMonthIntervalType()
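
Note the asymmetry in the mapping above: a Spark TIME of any precision is written as a 64-bit nanosecond Arrow time, while a nanosecond Arrow time read back becomes TimeType(MICROS_PRECISION). A short sketch of the two directions, with the types exactly as they appear in this hunk:

  import org.apache.arrow.vector.types.TimeUnit
  import org.apache.arrow.vector.types.pojo.ArrowType
  import org.apache.spark.sql.types.TimeType

  // Spark -> Arrow: 64-bit nanosecond time (8 * 8 bits)
  val arrowTime = new ArrowType.Time(TimeUnit.NANOSECOND, 8 * 8)
  // Arrow -> Spark: nanosecond Arrow times come back with microsecond precision
  val sparkTime = TimeType(TimeType.MICROS_PRECISION)
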
@@ -182,6 +182,8 @@ void initAccessor(ValueVector vector) {
accessor = new TimestampAccessor(timeStampMicroTZVector);
} else if (vector instanceof TimeStampMicroVector timeStampMicroVector) {
accessor = new TimestampNTZAccessor(timeStampMicroVector);
} else if (vector instanceof TimeNanoVector timeNanoVector) {
Member commented:

Just a stupid question: when are these methods used? Why do we need to change Arrow-related methods when adding a Connect data type?

@peter-toth (Contributor, Author) replied on Jul 15, 2025:

No, it isn't a stupid question. Arrow is used for streaming data between the Connect server and clients: https://spark.apache.org/docs/latest/spark-connect-overview.html#how-spark-connect-works

accessor = new TimeNanoAccessor(timeNanoVector);
} else if (vector instanceof MapVector mapVector) {
accessor = new MapAccessor(mapVector);
} else if (vector instanceof ListVector listVector) {
@@ -522,6 +524,21 @@ final long getLong(int rowId) {
}
}

static class TimeNanoAccessor extends ArrowVectorAccessor {

private final TimeNanoVector accessor;

TimeNanoAccessor(TimeNanoVector vector) {
super(vector);
this.accessor = vector;
}

@Override
final long getLong(int rowId) {
return accessor.get(rowId);
}
}

static class ArrayAccessor extends ArrowVectorAccessor {

private final ListVector accessor;
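
A self-contained sketch of what the new accessor reads: a TimeNanoVector stores each time as a long count of nanoseconds since midnight, which maps directly onto java.time.LocalTime (the vector name and values here are illustrative):

  import java.time.LocalTime
  import org.apache.arrow.memory.RootAllocator
  import org.apache.arrow.vector.TimeNanoVector

  val allocator = new RootAllocator()
  val vector = new TimeNanoVector("time", allocator)
  vector.setSafe(0, 43200000000000L)            // 12:00:00 as nanos of day
  vector.setValueCount(1)
  val t = LocalTime.ofNanoOfDay(vector.get(0))  // what getLong(rowId) returns, decoded
  vector.close()
  allocator.close()
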
@@ -68,6 +68,7 @@ object ArrowWriter {
case (DateType, vector: DateDayVector) => new DateWriter(vector)
case (TimestampType, vector: TimeStampMicroTZVector) => new TimestampWriter(vector)
case (TimestampNTZType, vector: TimeStampMicroVector) => new TimestampNTZWriter(vector)
case (_: TimeType, vector: TimeNanoVector) => new TimeWriter(vector)
case (ArrayType(_, _), vector: ListVector) =>
val elementVector = createFieldWriter(vector.getDataVector())
new ArrayWriter(vector, elementVector)
@@ -359,6 +360,18 @@ private[arrow] class TimestampNTZWriter(
}
}

private[arrow] class TimeWriter(
val valueVector: TimeNanoVector) extends ArrowFieldWriter {

override def setNull(): Unit = {
valueVector.setNull(count)
}

override def setValue(input: SpecializedGetters, ordinal: Int): Unit = {
valueVector.setSafe(count, input.getLong(ordinal))
}
}

private[arrow] class ArrayWriter(
val valueVector: ListVector,
val elementWriter: ArrowFieldWriter) extends ArrowFieldWriter {
@@ -18,7 +18,7 @@ package org.apache.spark.sql.connect

import java.io.{ByteArrayOutputStream, PrintStream}
import java.nio.file.Files
import java.time.DateTimeException
import java.time.{DateTimeException, LocalTime}
import java.util.Properties

import scala.collection.mutable
@@ -1670,6 +1670,12 @@ class ClientE2ETestSuite
}
checkAnswer(df, (0 until 6).map(i => Row(i)))
}

test("SPARK-52770: Support Time type") {
val df = spark.sql("SELECT TIME '12:13:14'")

checkAnswer(df, Row(LocalTime.of(12, 13, 14)))
}
}

private[sql] case class ClassData(a: String, b: Int)
@@ -69,6 +69,8 @@ message DataType {

// UnparsedDataType
Unparsed unparsed = 24;

Time time = 28;
@dongjoon-hyun (Member) commented on Jul 14, 2025:

Just for my understanding: do you know where 25, 26, and 27 are, @peter-toth?

@peter-toth (Contributor, Author) replied on Jul 14, 2025:

Unfortunately these are not in order. A few lines up:

  Variant variant = 25;

and a few lines down:

  // Reserved for geometry and geography types
  reserved 26, 27;

}

// Reserved for geometry and geography types
@@ -127,6 +129,11 @@ message DataType {
uint32 type_variation_reference = 1;
}

message Time {
optional int32 precision = 1;
uint32 type_variation_reference = 2;
}

message CalendarInterval {
uint32 type_variation_reference = 1;
}
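
A hedged sketch of constructing this message from Scala through the generated builders (the builder calls mirror what DataTypeProtoConverter does further down; the precision value here is illustrative):

  import org.apache.spark.connect.proto

  // precision is an optional field, so it can be set or left unset
  val time = proto.DataType.Time.newBuilder().setPrecision(6).build()
  val dataType = proto.DataType.newBuilder().setTime(time).build()
  dataType.getTime.hasPrecision  // true
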
@@ -200,6 +200,10 @@ object ArrowDeserializers {
new LeafFieldDeserializer[LocalDateTime](encoder, v, timeZoneId) {
override def value(i: Int): LocalDateTime = reader.getLocalDateTime(i)
}
case (LocalTimeEncoder, v: FieldVector) =>
new LeafFieldDeserializer[LocalTime](encoder, v, timeZoneId) {
override def value(i: Int): LocalTime = reader.getLocalTime(i)
}

case (OptionEncoder(value), v) =>
val deserializer = deserializerFor(value, v, timeZoneId)
@@ -20,7 +20,7 @@ import java.io.{ByteArrayOutputStream, OutputStream}
import java.lang.invoke.{MethodHandles, MethodType}
import java.math.{BigDecimal => JBigDecimal, BigInteger => JBigInteger}
import java.nio.channels.Channels
import java.time.{Duration, Instant, LocalDate, LocalDateTime, Period}
import java.time.{Duration, Instant, LocalDate, LocalDateTime, LocalTime, Period}
import java.util.{Map => JMap, Objects}

import scala.jdk.CollectionConverters._
@@ -392,6 +392,11 @@ object ArrowSerializer {
override def set(index: Int, value: LocalDateTime): Unit =
vector.setSafe(index, SparkDateTimeUtils.localDateTimeToMicros(value))
}
case (LocalTimeEncoder, v: TimeNanoVector) =>
new FieldSerializer[LocalTime, TimeNanoVector](v) {
override def set(index: Int, value: LocalTime): Unit =
vector.setSafe(index, SparkDateTimeUtils.localTimeToNanos(value))
}

case (OptionEncoder(value), v) =>
new Serializer {
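
The write path above encodes LocalTime with SparkDateTimeUtils.localTimeToNanos, and the reader below decodes with nanosToLocalTime. Assuming these match java.time's own nano-of-day conversions, the pair can be sanity-checked without Spark at all:

  import java.time.LocalTime

  // 12:13:14 -> (12 * 3600 + 13 * 60 + 14) seconds = 43994 s since midnight
  val nanos = LocalTime.of(12, 13, 14).toNanoOfDay  // 43994000000000L
  val back = LocalTime.ofNanoOfDay(nanos)           // 12:13:14
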
@@ -18,12 +18,12 @@ package org.apache.spark.sql.connect.client.arrow

import java.math.{BigDecimal => JBigDecimal}
import java.sql.{Date, Timestamp}
import java.time.{Duration, Instant, LocalDate, LocalDateTime, Period, ZoneOffset}
import java.time.{Duration, Instant, LocalDate, LocalDateTime, LocalTime, Period, ZoneOffset}

import org.apache.arrow.vector._
import org.apache.arrow.vector.util.Text

import org.apache.spark.sql.catalyst.util.{DateFormatter, SparkIntervalUtils, SparkStringUtils, TimestampFormatter}
import org.apache.spark.sql.catalyst.util.{DateFormatter, SparkIntervalUtils, SparkStringUtils, TimeFormatter, TimestampFormatter}
import org.apache.spark.sql.catalyst.util.DateTimeConstants.MICROS_PER_SECOND
import org.apache.spark.sql.catalyst.util.IntervalStringStyles.ANSI_STYLE
import org.apache.spark.sql.catalyst.util.SparkDateTimeUtils._
@@ -59,6 +59,7 @@ private[arrow] abstract class ArrowVectorReader {
def getInstant(i: Int): java.time.Instant = unsupported()
def getLocalDate(i: Int): java.time.LocalDate = unsupported()
def getLocalDateTime(i: Int): java.time.LocalDateTime = unsupported()
def getLocalTime(i: Int): java.time.LocalTime = unsupported()
private def unsupported(): Nothing = throw new UnsupportedOperationException()
}

@@ -90,6 +91,7 @@ object ArrowVectorReader {
case v: DateDayVector => new DateDayVectorReader(v, timeZoneId)
case v: TimeStampMicroTZVector => new TimeStampMicroTZVectorReader(v)
case v: TimeStampMicroVector => new TimeStampMicroVectorReader(v, timeZoneId)
case v: TimeNanoVector => new TimeVectorReader(v)
case _: NullVector => NullVectorReader
case _ => throw new RuntimeException("Unsupported Vector Type: " + vector.getClass)
}
@@ -275,3 +277,11 @@ private[arrow] class TimeStampMicroVectorReader(v: TimeStampMicroVector, timeZon
override def getLocalDateTime(i: Int): LocalDateTime = microsToLocalDateTime(utcMicros(i))
override def getString(i: Int): String = formatter.format(utcMicros(i))
}

private[arrow] class TimeVectorReader(v: TimeNanoVector)
extends TypedArrowVectorReader[TimeNanoVector](v) {
private lazy val formatter = TimeFormatter.getFractionFormatter()
private def nanos(i: Int): Long = vector.get(i)
override def getLocalTime(i: Int): LocalTime = nanosToLocalTime(nanos(i))
override def getString(i: Int): String = formatter.format(nanos(i))
}
@@ -53,6 +53,12 @@ object DataTypeProtoConverter {
case proto.DataType.KindCase.DATE => DateType
case proto.DataType.KindCase.TIMESTAMP => TimestampType
case proto.DataType.KindCase.TIMESTAMP_NTZ => TimestampNTZType
case proto.DataType.KindCase.TIME =>
if (t.getTime.hasPrecision) {
TimeType(t.getTime.getPrecision)
} else {
TimeType()
}

case proto.DataType.KindCase.CALENDAR_INTERVAL => CalendarIntervalType
case proto.DataType.KindCase.YEAR_MONTH_INTERVAL =>
@@ -204,6 +210,12 @@ object DataTypeProtoConverter {

case TimestampNTZType => ProtoDataTypes.TimestampNTZType

case TimeType(precision) =>
proto.DataType
.newBuilder()
.setTime(proto.DataType.Time.newBuilder().setPrecision(precision).build())
.build()

case CalendarIntervalType => ProtoDataTypes.CalendarIntervalType

case YearMonthIntervalType(startField, endField) =>
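
Putting the two directions together, a TIME type should survive the proto round trip. A sketch, with the converter's method names assumed from the surrounding object:

  import org.apache.spark.sql.connect.common.DataTypeProtoConverter
  import org.apache.spark.sql.types.TimeType

  // Hypothetical round trip; the toConnectProtoType/toCatalystType names are assumed
  val p = DataTypeProtoConverter.toConnectProtoType(TimeType(3))
  assert(DataTypeProtoConverter.toCatalystType(p) == TimeType(3))
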
@@ -20,6 +20,7 @@ import java.io.{ByteArrayOutputStream, DataOutputStream, File}
import java.nio.charset.StandardCharsets
import java.sql.{Date, Timestamp}
import java.text.SimpleDateFormat
import java.time.LocalTime
import java.util.Locale

import com.google.common.io.Files
@@ -731,6 +732,43 @@ class ArrowConvertersSuite extends SharedSparkSession {
}
}

test("time type conversion") {
val json =
s"""
|{
| "schema" : {
| "fields" : [ {
| "name" : "time",
| "type" : {
| "name" : "time",
| "unit" : "NANOSECOND",
| "bitWidth" : 64
| },
| "nullable" : true,
| "children" : [ ]
| } ]
| },
| "batches" : [ {
| "count" : 3,
| "columns" : [ {
| "name" : "time",
| "count" : 3,
| "VALIDITY" : [ 1, 1, 1 ],
| "DATA" : [ 0, 43200000000000, 3723123456789 ]
| } ]
| } ]
|}
""".stripMargin

val t1 = LocalTime.of(0, 0, 0)
val t2 = LocalTime.of(12, 0, 0)
val t3 = LocalTime.of(1, 2, 3, 123456789)

val df = Seq(t1, t2, t3).toDF("time")

collectAndValidate(df, json, "timeData.json")
}
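
The DATA values in the JSON above are nanoseconds since midnight; a quick check that they match t1, t2 and t3 using plain java.time:

  import java.time.LocalTime

  LocalTime.of(0, 0, 0).toNanoOfDay             // 0L
  LocalTime.of(12, 0, 0).toNanoOfDay            // 43200000000000L = 12 * 3600 * 1e9
  LocalTime.of(1, 2, 3, 123456789).toNanoOfDay  // 3723123456789L  = 3723 s + 123456789 ns
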

test("floating-point NaN") {
val json =
s"""