[vpj] Fix DATASET_CHANGED tracking for Spark jobs #1513

Draft · wants to merge 1 commit into base: main

@@ -113,54 +113,19 @@ public void configure(VeniceProperties props, PushJobSetting pushJobSetting) {
this.props = props;
this.pushJobSetting = pushJobSetting;
setupDefaultSparkSessionForDataWriterJob(pushJobSetting, props);
setupSparkDataWriterJobFlow(pushJobSetting);
}

private void setupSparkDataWriterJobFlow(PushJobSetting pushJobSetting) {
ExpressionEncoder<Row> rowEncoder = RowEncoder.apply(DEFAULT_SCHEMA);
ExpressionEncoder<Row> rowEncoderWithPartition = RowEncoder.apply(DEFAULT_SCHEMA_WITH_PARTITION);
int numOutputPartitions = pushJobSetting.partitionCount;

// Load data from input path
Dataset<Row> dataFrameForDataWriterJob = getInputDataFrame();
Objects.requireNonNull(dataFrameForDataWriterJob, "The input data frame cannot be null");
this.dataFrame = getInputDataFrame();
Objects.requireNonNull(this.dataFrame, "The input data frame cannot be null");
validateDataFrameSchema(this.dataFrame);

Properties jobProps = new Properties();
sparkSession.conf().getAll().foreach(entry -> jobProps.setProperty(entry._1, entry._2));
if (pushJobSetting.materializedViewConfigFlatMap != null) {
jobProps.put(PUSH_JOB_VIEW_CONFIGS, pushJobSetting.materializedViewConfigFlatMap);
}
JavaSparkContext sparkContext = JavaSparkContext.fromSparkContext(sparkSession.sparkContext());
Broadcast<Properties> broadcastProperties = sparkContext.broadcast(jobProps);
accumulatorsForDataWriterJob = new DataWriterAccumulators(sparkSession);
taskTracker = new SparkDataWriterTaskTracker(accumulatorsForDataWriterJob);

// Validate the schema of the input data
validateDataFrameSchema(dataFrameForDataWriterJob);

// Convert all rows to byte[], byte[] pairs (compressed if compression is enabled)
// We could have used "map", but since records are sprayed across all the PartitionWriters, we need to use "flatMap"
dataFrameForDataWriterJob = dataFrameForDataWriterJob
.flatMap(new SparkInputRecordProcessorFactory(broadcastProperties, accumulatorsForDataWriterJob), rowEncoder);

// TODO: Add map-side combiner to reduce the data size before shuffling

// Partition the data using the custom partitioner and sort the data within that partition
dataFrameForDataWriterJob = SparkPartitionUtils.repartitionAndSortWithinPartitions(
dataFrameForDataWriterJob,
new VeniceSparkPartitioner(broadcastProperties, numOutputPartitions),
new PartitionSorter());

// Add a partition column to all rows based on the custom partitioner
dataFrameForDataWriterJob =
dataFrameForDataWriterJob.withColumn(PARTITION_COLUMN_NAME, functions.spark_partition_id());

// Write the data to PubSub
dataFrameForDataWriterJob = dataFrameForDataWriterJob.mapPartitions(
new SparkPartitionWriterFactory(broadcastProperties, accumulatorsForDataWriterJob),
rowEncoderWithPartition);

this.dataFrame = dataFrameForDataWriterJob;
}
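
For context on the shuffle step built here (and rebuilt in runComputeJob in the next hunk): the key/value pairs emitted by the flatMap stage are routed to their target Venice partition and sorted within that partition before the partition writers consume them. Below is a minimal, self-contained sketch of that repartition-and-sort pattern using Spark's standard JavaPairRDD API; the class name, the toy records, the HashPartitioner (standing in for VeniceSparkPartitioner), and the natural-order comparator (standing in for PartitionSorter) are illustrative assumptions, not code from this PR.

import java.util.Arrays;
import java.util.Comparator;
import org.apache.spark.HashPartitioner;
import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.sql.SparkSession;
import scala.Tuple2;

public final class RepartitionSortSketch {
  public static void main(String[] args) {
    SparkSession spark =
        SparkSession.builder().master("local[2]").appName("repartition-sort-sketch").getOrCreate();
    JavaSparkContext jsc = JavaSparkContext.fromSparkContext(spark.sparkContext());

    // Toy (key, value) records standing in for the serialized key/value byte[] pairs
    // produced by the flatMap stage.
    JavaPairRDD<String, Integer> records = jsc.parallelizePairs(
        Arrays.asList(new Tuple2<>("b", 2), new Tuple2<>("a", 3), new Tuple2<>("a", 1)));

    // Route each record to its target partition and sort by key within that partition,
    // in a single shuffle. In the PR, VeniceSparkPartitioner decides the partition and
    // PartitionSorter defines the ordering; a hash partitioner and natural key order
    // are used here purely for illustration.
    JavaPairRDD<String, Integer> shuffled = records.repartitionAndSortWithinPartitions(
        new HashPartitioner(4), Comparator.<String>naturalOrder());

    // Each partition writer would then consume one sorted partition at a time.
    shuffled.foreachPartition(iter -> iter.forEachRemaining(System.out::println));

    spark.stop();
  }
}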

/**
@@ -372,8 +337,38 @@ public PushJobSetting getPushJobSetting() {

@Override
protected void runComputeJob() {
ExpressionEncoder<Row> rowEncoder = RowEncoder.apply(DEFAULT_SCHEMA);
ExpressionEncoder<Row> rowEncoderWithPartition = RowEncoder.apply(DEFAULT_SCHEMA_WITH_PARTITION);
int numOutputPartitions = pushJobSetting.partitionCount;

Properties jobProps = new Properties();
this.sparkSession.conf().getAll().foreach(entry -> jobProps.setProperty(entry._1, entry._2));
JavaSparkContext sparkContext = JavaSparkContext.fromSparkContext(sparkSession.sparkContext());
Broadcast<Properties> broadcastProperties = sparkContext.broadcast(jobProps);

LOGGER.info("Triggering Spark job for data writer");
try {
// Convert all rows to byte[], byte[] pairs (compressed if compression is enabled)
// We could have used "map", but since records are sprayed across all the PartitionWriters, we need to use "flatMap"
dataFrame = dataFrame
.flatMap(new SparkInputRecordProcessorFactory(broadcastProperties, accumulatorsForDataWriterJob), rowEncoder);

// TODO: Add map-side combiner to reduce the data size before shuffling

// Partition the data using the custom partitioner and sort the data within that partition
dataFrame = SparkPartitionUtils.repartitionAndSortWithinPartitions(
dataFrame,
new VeniceSparkPartitioner(broadcastProperties, numOutputPartitions),
new PartitionSorter());

// Add a partition column to all rows based on the custom partitioner
dataFrame = dataFrame.withColumn(PARTITION_COLUMN_NAME, functions.spark_partition_id());

// Write the data to PubSub
dataFrame = dataFrame.mapPartitions(
new SparkPartitionWriterFactory(broadcastProperties, accumulatorsForDataWriterJob),
rowEncoderWithPartition);

// For VPJ, we don't care about the output of the DAG. ".count()" is an action that triggers execution of
// the DAG to completion without copying all the rows to the driver, which keeps memory usage low.
dataFrame.count();
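
The comment above relies on Spark's lazy evaluation: flatMap, the repartition, withColumn, and mapPartitions only build a plan, and nothing reads input or shuffles data until an action runs. count() is used because it executes the full DAG on the executors and returns a single long to the driver, whereas collect() would materialize every row there. A tiny standalone illustration of that distinction (class name and dataset are illustrative, not from this PR):

import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SparkSession;

public final class LazyActionSketch {
  public static void main(String[] args) {
    SparkSession spark =
        SparkSession.builder().master("local[2]").appName("lazy-action-sketch").getOrCreate();

    // Transformations only describe the computation; nothing has executed yet.
    Dataset<Row> planned = spark.range(1_000_000)
        .toDF("id")
        .filter("id % 2 = 0")
        .repartition(8);

    // count() is an action: it runs the whole plan to completion on the executors and
    // ships back a single long, instead of copying every row to the driver.
    long rows = planned.count();
    System.out.println("rows processed: " + rows);

    spark.stop();
  }
}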