JavaSimpleWritableDataSource.java
@@ -33,7 +33,6 @@
import org.apache.spark.sql.catalyst.expressions.GenericInternalRow;
import org.apache.spark.sql.connector.SimpleCounter;
import org.apache.spark.sql.connector.TestingV2Source;
-import org.apache.spark.sql.connector.catalog.SessionConfigSupport;
import org.apache.spark.sql.connector.catalog.SupportsWrite;
import org.apache.spark.sql.connector.catalog.Table;
import org.apache.spark.sql.connector.catalog.TableCapability;
@@ -50,12 +49,7 @@
* Each task writes data to `target/_temporary/uniqueId/$jobId-$partitionId-$attemptNumber`.
* Each job moves files from `target/_temporary/uniqueId/` to `target`.
*/
-public class JavaSimpleWritableDataSource implements TestingV2Source, SessionConfigSupport {
-
-@Override
-public String keyPrefix() {
-return "javaSimpleWritableDataSource";
-}
+public class JavaSimpleWritableDataSource implements TestingV2Source {

static class MyScanBuilder extends JavaSimpleScanBuilder {

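The class comment above describes the commit protocol shared by both test sources: each write task lands its output under `target/_temporary/uniqueId/`, and the job-level commit moves those files into `target`. A minimal sketch of that job commit using the Hadoop FileSystem API; the object and method names (TempDirCommitSketch, commitJob) are illustrative and not taken from the suite:

```scala
import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.{FileSystem, Path}

// Illustrative-only helper mirroring the protocol in the class comment:
// tasks write under target/_temporary/<uniqueId>/, and the job commit moves
// those files into target/ before cleaning up the temporary directory.
object TempDirCommitSketch {
  def commitJob(targetDir: String, uniqueId: String, conf: Configuration): Unit = {
    val target = new Path(targetDir)
    val fs: FileSystem = target.getFileSystem(conf)
    val tempDir = new Path(target, s"_temporary/$uniqueId")
    if (fs.exists(tempDir)) {
      // Promote every task output file from the temporary directory.
      fs.listStatus(tempDir).foreach { status =>
        fs.rename(status.getPath, new Path(target, status.getPath.getName))
      }
      // Remove the now-empty temporary directory.
      fs.delete(tempDir, true)
    }
  }
}
```

The point of the indirection is that an aborted or re-run task attempt never leaves partial output directly under `target`.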
SimpleWritableDataSource.scala
@@ -27,12 +27,10 @@ import org.apache.hadoop.fs.{FileSystem, Path}

import org.apache.spark.SparkContext
import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.connector.catalog.{SessionConfigSupport, SupportsWrite, Table, TableCapability}
+import org.apache.spark.sql.connector.catalog.{SupportsWrite, Table, TableCapability}
import org.apache.spark.sql.connector.catalog.TableCapability._
import org.apache.spark.sql.connector.read.{InputPartition, PartitionReader, PartitionReaderFactory, ScanBuilder}
import org.apache.spark.sql.connector.write._
-import org.apache.spark.sql.internal.connector.SimpleTableProvider
-import org.apache.spark.sql.types.StructType
import org.apache.spark.sql.util.CaseInsensitiveStringMap
import org.apache.spark.util.SerializableConfiguration

@@ -41,11 +39,7 @@ import org.apache.spark.util.SerializableConfiguration
* Each task writes data to `target/_temporary/uniqueId/$jobId-$partitionId-$attemptNumber`.
* Each job moves files from `target/_temporary/uniqueId/` to `target`.
*/
-class SimpleWritableDataSource extends SimpleTableProvider with SessionConfigSupport {
-
-private val tableSchema = new StructType().add("i", "long").add("j", "long")
-
-override def keyPrefix: String = "simpleWritableDataSource"
+class SimpleWritableDataSource extends TestingV2Source {

class MyScanBuilder(path: String, conf: Configuration) extends SimpleScanBuilder {
override def planInputPartitions(): Array[InputPartition] = {
@@ -67,8 +61,6 @@ class SimpleWritableDataSource extends SimpleTableProvider with SessionConfigSup
val serializableConf = new SerializableConfiguration(conf)
new CSVReaderFactory(serializableConf)
}

-override def readSchema(): StructType = tableSchema
}

class MyWriteBuilder(path: String, info: LogicalWriteInfo)
@@ -134,8 +126,6 @@ class SimpleWritableDataSource extends SimpleTableProvider with SessionConfigSup
private val path = options.get("path")
private val conf = SparkContext.getActive.get.hadoopConfiguration

-override def schema(): StructType = tableSchema

override def newScanBuilder(options: CaseInsensitiveStringMap): ScanBuilder = {
new MyScanBuilder(new Path(path).toUri.toString, conf)
}
@@ -179,7 +169,7 @@ class CSVReaderFactory(conf: SerializableConfiguration)
}
}

-override def get(): InternalRow = InternalRow(currentLine.split(",").map(_.trim.toLong): _*)
+override def get(): InternalRow = InternalRow(currentLine.split(",").map(_.trim.toInt): _*)

override def close(): Unit = {
inputStream.close()
@@ -222,7 +212,7 @@ class CSVDataWriter(fs: FileSystem, file: Path) extends DataWriter[InternalRow]
private val out = fs.create(file)

override def write(record: InternalRow): Unit = {
out.writeBytes(s"${record.getLong(0)},${record.getLong(1)}\n")
out.writeBytes(s"${record.getInt(0)},${record.getInt(1)}\n")
Member: This change just matches the TestingV2Source.schema?

Contributor: Yea

}

override def commit(): WriterCommitMessage = {
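On the reviewer exchange above: the switch from Long to Int reads and writes lines up with the schema exposed by the shared TestingV2Source trait that both sources now extend. A rough sketch of the shape of that trait, with the two int columns inferred from the `getInt(0)`/`getInt(1)` calls in this diff; the name carries a Sketch suffix because the real definition lives in the connector test suite and may differ in detail:

```scala
import org.apache.spark.sql.connector.catalog.{Table, TableProvider}
import org.apache.spark.sql.types.StructType
import org.apache.spark.sql.util.CaseInsensitiveStringMap

// Sketch of the shared testing trait; the real TestingV2Source may differ.
// The fixed two-int-column schema is what the getInt/toInt changes in this
// diff are aligning with.
trait TestingV2SourceSketch extends TableProvider {
  // Every testing source reports the same schema: two int columns, i and j.
  override def inferSchema(options: CaseInsensitiveStringMap): StructType =
    new StructType().add("i", "int").add("j", "int")

  // Concrete sources only need to supply a Table for the given options.
  def getTable(options: CaseInsensitiveStringMap): Table
}
```

With the schema owned by the shared trait, the per-source tableSchema, readSchema(), and schema() overrides removed above become redundant.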