diff --git a/hudi-examples/hudi-examples-spark/src/test/java/org/apache/hudi/examples/quickstart/TestHoodieSparkQuickstart.java b/hudi-examples/hudi-examples-spark/src/test/java/org/apache/hudi/examples/quickstart/TestHoodieSparkQuickstart.java index b9ab12046058..32c51788ee85 100644 --- a/hudi-examples/hudi-examples-spark/src/test/java/org/apache/hudi/examples/quickstart/TestHoodieSparkQuickstart.java +++ b/hudi-examples/hudi-examples-spark/src/test/java/org/apache/hudi/examples/quickstart/TestHoodieSparkQuickstart.java @@ -24,6 +24,7 @@ import org.apache.hudi.common.model.HoodieAvroPayload; import org.apache.hudi.examples.common.HoodieExampleDataGenerator; import org.apache.hudi.testutils.providers.SparkProvider; + import org.apache.spark.SparkConf; import org.apache.spark.api.java.JavaSparkContext; import org.apache.spark.sql.SQLContext; diff --git a/hudi-examples/hudi-examples-spark/src/test/python/HoodiePySparkQuickstart.py b/hudi-examples/hudi-examples-spark/src/test/python/HoodiePySparkQuickstart.py new file mode 100644 index 000000000000..c3be6a176c9b --- /dev/null +++ b/hudi-examples/hudi-examples-spark/src/test/python/HoodiePySparkQuickstart.py @@ -0,0 +1,266 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
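+"""Self-contained PySpark quickstart for Apache Hudi.
+
+Walks through the core write and query paths on a small generated dataset:
+insert, upsert, snapshot queries, time travel, incremental and point-in-time
+queries, soft and hard deletes, and insert overwrite. The table is written to
+a temporary directory that is removed when the run finishes.
+"""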
+ + +import sys +import os +from pyspark import sql +import random +from pyspark.sql.functions import lit +from functools import reduce +import tempfile +import argparse + + +class ExamplePySpark: + def __init__(self, spark: sql.SparkSession, tableName: str, basePath: str): + self.spark = spark + self.tableName = tableName + self.basePath = basePath + "/" + tableName + self.hudi_options = { + 'hoodie.table.name': tableName, + 'hoodie.datasource.write.recordkey.field': 'uuid', + 'hoodie.datasource.write.partitionpath.field': 'partitionpath', + 'hoodie.datasource.write.table.name': tableName, + 'hoodie.datasource.write.operation': 'upsert', + 'hoodie.datasource.write.precombine.field': 'ts', + 'hoodie.upsert.shuffle.parallelism': 2, + 'hoodie.insert.shuffle.parallelism': 2 + } + + self.dataGen = spark._jvm.org.apache.hudi.QuickstartUtils.DataGenerator() + self.snapshotQuery = "SELECT begin_lat, begin_lon, driver, end_lat, end_lon, fare, partitionpath, rider, ts, uuid FROM hudi_trips_snapshot" + return + + def runQuickstart(self): + + def snap(): + return self.spark.sql(self.snapshotQuery) + insertDf = self.insertData() + self.queryData() + assert len(insertDf.exceptAll(snap()).collect()) == 0 + + snapshotBeforeUpdate = snap() + updateDf = self.updateData() + self.queryData() + assert len(snap().intersect(updateDf).collect()) == len(updateDf.collect()) + assert len(snap().exceptAll(updateDf).exceptAll(snapshotBeforeUpdate).collect()) == 0 + + + self.timeTravelQuery() + self.incrementalQuery() + self.pointInTimeQuery() + + self.softDeletes() + self.queryData() + + snapshotBeforeDelete = snap() + deletesDf = self.hardDeletes() + self.queryData() + assert len(snap().select(["uuid", "partitionpath", "ts"]).intersect(deletesDf).collect()) == 0 + assert len(snapshotBeforeDelete.exceptAll(snap()).exceptAll(snapshotBeforeDelete).collect()) == 0 + + snapshotBeforeInsertOverwrite = snap() + insertOverwriteDf = self.insertOverwrite() + self.queryData() + withoutSanFran = snapshotBeforeInsertOverwrite.filter("partitionpath != 'americas/united_states/san_francisco'") + expectedDf = withoutSanFran.union(insertOverwriteDf) + assert len(snap().exceptAll(expectedDf).collect()) == 0 + return + + def insertData(self): + print("Insert Data") + inserts = self.spark._jvm.org.apache.hudi.QuickstartUtils.convertToStringList(self.dataGen.generateInserts(10)) + df = self.spark.read.json(self.spark.sparkContext.parallelize(inserts, 2)) + df.write.format("hudi").options(**self.hudi_options).mode("overwrite").save(self.basePath) + return df + + def updateData(self): + print("Update Data") + updates = self.spark._jvm.org.apache.hudi.QuickstartUtils.convertToStringList(self.dataGen.generateUniqueUpdates(5)) + df = self.spark.read.json(spark.sparkContext.parallelize(updates, 2)) + df.write.format("hudi").options(**self.hudi_options).mode("append").save(self.basePath) + return df + + def queryData(self): + print("Query Data") + tripsSnapshotDF = self.spark.read.format("hudi").load(self.basePath) + tripsSnapshotDF.createOrReplaceTempView("hudi_trips_snapshot") + self.spark.sql("SELECT fare, begin_lon, begin_lat, ts FROM hudi_trips_snapshot WHERE fare > 20.0").show() + self.spark.sql("SELECT _hoodie_commit_time, _hoodie_record_key, _hoodie_partition_path, rider, driver, fare FROM hudi_trips_snapshot").show() + return + + def timeTravelQuery(self): + query = "SELECT begin_lat, begin_lon, driver, end_lat, end_lon, fare, partitionpath, rider, ts, uuid FROM time_travel_query" + print("Time Travel Query") + 
self.spark.read.format("hudi").option("as.of.instant", "20210728141108").load(self.basePath).createOrReplaceTempView("time_travel_query") + self.spark.sql(query) + self.spark.read.format("hudi").option("as.of.instant", "2021-07-28 14:11:08.000").load(self.basePath).createOrReplaceTempView("time_travel_query") + self.spark.sql(query) + self.spark.read.format("hudi").option("as.of.instant", "2021-07-28").load(self.basePath).createOrReplaceTempView("time_travel_query") + self.spark.sql(query) + return + + def incrementalQuery(self): + print("Incremental Query") + self.spark.read.format("hudi").load(self.basePath).createOrReplaceTempView("hudi_trips_snapshot") + self.commits = list(map(lambda row: row[0], self.spark.sql("SELECT DISTINCT(_hoodie_commit_time) AS commitTime FROM hudi_trips_snapshot ORDER BY commitTime").limit(50).collect())) + beginTime = self.commits[len(self.commits) - 2] + incremental_read_options = { + 'hoodie.datasource.query.type': 'incremental', + 'hoodie.datasource.read.begin.instanttime': beginTime, + } + tripsIncrementalDF = self.spark.read.format("hudi").options(**incremental_read_options).load(self.basePath) + tripsIncrementalDF.createOrReplaceTempView("hudi_trips_incremental") + self.spark.sql("SELECT `_hoodie_commit_time`, fare, begin_lon, begin_lat, ts FROM hudi_trips_incremental WHERE fare > 20.0").show() + + def pointInTimeQuery(self): + print("Point-in-time Query") + beginTime = "000" + endTime = self.commits[len(self.commits) - 2] + point_in_time_read_options = { + 'hoodie.datasource.query.type': 'incremental', + 'hoodie.datasource.read.end.instanttime': endTime, + 'hoodie.datasource.read.begin.instanttime': beginTime + } + + tripsPointInTimeDF = self.spark.read.format("hudi").options(**point_in_time_read_options).load(self.basePath) + tripsPointInTimeDF.createOrReplaceTempView("hudi_trips_point_in_time") + self.spark.sql("SELECT `_hoodie_commit_time`, fare, begin_lon, begin_lat, ts FROM hudi_trips_point_in_time WHERE fare > 20.0").show() + + def softDeletes(self): + print("Soft Deletes") + spark.read.format("hudi").load(self.basePath).createOrReplaceTempView("hudi_trips_snapshot") + + # fetch total records count + trip_count = spark.sql("SELECT uuid, partitionpath FROM hudi_trips_snapshot").count() + non_null_rider_count = spark.sql("SELECT uuid, partitionpath FROM hudi_trips_snapshot WHERE rider IS NOT null").count() + print(f"trip count: {trip_count}, non null rider count: {non_null_rider_count}") + # fetch two records for soft deletes + soft_delete_ds = spark.sql("SELECT * FROM hudi_trips_snapshot").limit(2) + # prepare the soft deletes by ensuring the appropriate fields are nullified + meta_columns = ["_hoodie_commit_time", "_hoodie_commit_seqno", "_hoodie_record_key", + "_hoodie_partition_path", "_hoodie_file_name"] + excluded_columns = meta_columns + ["ts", "uuid", "partitionpath"] + nullify_columns = list(filter(lambda field: field[0] not in excluded_columns, \ + list(map(lambda field: (field.name, field.dataType), soft_delete_ds.schema.fields)))) + + hudi_soft_delete_options = { + 'hoodie.table.name': self.tableName, + 'hoodie.datasource.write.recordkey.field': 'uuid', + 'hoodie.datasource.write.partitionpath.field': 'partitionpath', + 'hoodie.datasource.write.table.name': self.tableName, + 'hoodie.datasource.write.operation': 'upsert', + 'hoodie.datasource.write.precombine.field': 'ts', + 'hoodie.upsert.shuffle.parallelism': 2, + 'hoodie.insert.shuffle.parallelism': 2 + } + + soft_delete_df = reduce(lambda df,col: df.withColumn(col[0], 
lit(None).cast(col[1])), \ + nullify_columns, reduce(lambda df,col: df.drop(col[0]), meta_columns, soft_delete_ds)) + + # simply upsert the table after setting these fields to null + soft_delete_df.write.format("hudi").options(**hudi_soft_delete_options).mode("append").save(self.basePath) + + # reload data + self.spark.read.format("hudi").load(self.basePath).createOrReplaceTempView("hudi_trips_snapshot") + + # This should return the same total count as before + trip_count = self.spark.sql("SELECT uuid, partitionpath FROM hudi_trips_snapshot").count() + # This should return (total - 2) count as two records are updated with nulls + non_null_rider_count = self.spark.sql("SELECT uuid, partitionpath FROM hudi_trips_snapshot WHERE rider IS NOT null").count() + print(f"trip count: {trip_count}, non null rider count: {non_null_rider_count}") + + def hardDeletes(self): + print("Hard Deletes") + # fetch total records count + total_count = self.spark.sql("SELECT uuid, partitionpath FROM hudi_trips_snapshot").count() + print(f"total count: {total_count}") + # fetch two records to be deleted + ds = self.spark.sql("SELECT uuid, partitionpath FROM hudi_trips_snapshot").limit(2) + + # issue deletes + hudi_hard_delete_options = { + 'hoodie.table.name': self.tableName, + 'hoodie.datasource.write.recordkey.field': 'uuid', + 'hoodie.datasource.write.partitionpath.field': 'partitionpath', + 'hoodie.datasource.write.table.name': self.tableName, + 'hoodie.datasource.write.operation': 'delete', + 'hoodie.datasource.write.precombine.field': 'ts', + 'hoodie.upsert.shuffle.parallelism': 2, + 'hoodie.insert.shuffle.parallelism': 2 + } + + deletes = list(map(lambda row: (row[0], row[1]), ds.collect())) + hard_delete_df = self.spark.sparkContext.parallelize(deletes).toDF(['uuid', 'partitionpath']).withColumn('ts', lit(0.0)) + hard_delete_df.write.format("hudi").options(**hudi_hard_delete_options).mode("append").save(self.basePath) + + # run the same read query as above. 
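+        # only the record key and partition path identify the rows to delete; the
+        # 'ts' column above is a dummy value added just to satisfy the configured
+        # precombine field in the schema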
+        roAfterDeleteViewDF = self.spark.read.format("hudi").load(self.basePath)
+        roAfterDeleteViewDF.createOrReplaceTempView("hudi_trips_snapshot")
+        # fetch should return (total - 2) records
+        total_count = self.spark.sql("SELECT uuid, partitionpath FROM hudi_trips_snapshot").count()
+        print(f"total count: {total_count}")
+        return hard_delete_df
+
+    def insertOverwrite(self):
+        print("Insert Overwrite")
+        self.spark.read.format("hudi").load(self.basePath).select(["uuid", "partitionpath"]).sort(["partitionpath", "uuid"]).show(n=100, truncate=False)
+        inserts = self.spark._jvm.org.apache.hudi.QuickstartUtils.convertToStringList(self.dataGen.generateInserts(10))
+        df = self.spark.read.json(self.spark.sparkContext.parallelize(inserts, 2)).filter("partitionpath = 'americas/united_states/san_francisco'")
+        hudi_insert_overwrite_options = {
+            'hoodie.table.name': self.tableName,
+            'hoodie.datasource.write.recordkey.field': 'uuid',
+            'hoodie.datasource.write.partitionpath.field': 'partitionpath',
+            'hoodie.datasource.write.table.name': self.tableName,
+            'hoodie.datasource.write.operation': 'insert_overwrite',
+            'hoodie.datasource.write.precombine.field': 'ts',
+            'hoodie.upsert.shuffle.parallelism': 2,
+            'hoodie.insert.shuffle.parallelism': 2
+        }
+        df.write.format("hudi").options(**hudi_insert_overwrite_options).mode("append").save(self.basePath)
+        self.spark.read.format("hudi").load(self.basePath).select(["uuid", "partitionpath"]).sort(["partitionpath", "uuid"]).show(n=100, truncate=False)
+        return df
+
+
+if __name__ == "__main__":
+    random.seed(46474747)
+    parser = argparse.ArgumentParser(description="Examples of various operations to perform on Hudi with PySpark", formatter_class=argparse.ArgumentDefaultsHelpFormatter)
+    parser.add_argument("-t", "--table", action="store", required=True, help="the name of the table to create")
+    group = parser.add_mutually_exclusive_group(required=True)
+    group.add_argument("-p", "--package", action="store", help="the name of the hudi-spark-bundle package\n eg. \"org.apache.hudi:hudi-spark3.3-bundle_2.12:0.12.0\"")
+    group.add_argument("-j", "--jar", action="store", help="the full path to the hudi-spark-bundle .jar file\n eg. \"[HUDI_BASE_PATH]/packaging/hudi-spark-bundle/target/hudi-spark-bundle[VERSION].jar\"")
+    args = vars(parser.parse_args())
+    package = args["package"]
+    jar = args["jar"]
+    if package is not None:
+        os.environ["PYSPARK_SUBMIT_ARGS"] = f"--packages {package} pyspark-shell"
+    elif jar is not None:
+        os.environ["PYSPARK_SUBMIT_ARGS"] = f"--jars {jar} pyspark-shell"
+
+    with tempfile.TemporaryDirectory() as tmpdirname:
+        spark = sql.SparkSession \
+            .builder \
+            .appName("Hudi Spark basic example") \
+            .config("spark.serializer", "org.apache.spark.serializer.KryoSerializer") \
+            .config("spark.kryoserializer.buffer.max", "512m") \
+            .config("spark.sql.extensions", "org.apache.spark.sql.hudi.HoodieSparkSessionExtension") \
+            .getOrCreate()
+        ps = ExamplePySpark(spark, args["table"], tmpdirname)
+        ps.runQuickstart()
diff --git a/hudi-examples/hudi-examples-spark/src/test/python/README.md b/hudi-examples/hudi-examples-spark/src/test/python/README.md
new file mode 100644
index 000000000000..71382fd979d6
--- /dev/null
+++ b/hudi-examples/hudi-examples-spark/src/test/python/README.md
@@ -0,0 +1,42 @@
+
+# Requirements
+Python 3 is required to run this example.
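+You can check which Python and PySpark versions are on your path (this assumes PySpark is already importable, e.g. via pip or the PYTHONPATH setup described below):
+```bash
+python3 --version
+python3 -c "import pyspark; print(pyspark.__version__)"
+```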
+PySpark 2.4.7 does not work with recent versions of Python (3.8+), so if you want to use a newer Python, build Hudi against a newer Spark version instead (Spark 3.3 in the example below):
+```bash
+cd $HUDI_DIR
+mvn clean install -DskipTests -Dspark3.3 -Dscala2.12
+```
+You may also need a few additional Python packages; install pip and then run **pip install \<package\>** for any that are missing.
+# How to Run
+1. [Download pyspark](https://spark.apache.org/downloads)
+2. Extract it where you want it to be installed and note that location
+3. Run (or add to your .bashrc) the following, making sure SPARK_HOME points to the location from step 2:
+```bash
+export SPARK_HOME=/path/to/spark/home
+export PATH=$PATH:$SPARK_HOME/bin:$SPARK_HOME/sbin
+export PYSPARK_SUBMIT_ARGS="--master local[*]"
+export PYTHONPATH=$SPARK_HOME/python/:$PYTHONPATH
+export PYTHONPATH=$SPARK_HOME/python/lib/*.zip:$PYTHONPATH
+```
+4. Identify the Hudi Spark bundle package or .jar file that you wish to use:
+A package will be in the format **org.apache.hudi:hudi-spark3.3-bundle_2.12:0.12.0**
+A jar will be in the format **\[HUDI_BASE_PATH\]/packaging/hudi-spark-bundle/target/hudi-spark-bundle\[VERSION\].jar**
+5. Go to the Hudi directory and run the quickstart example with the command below, passing the table name with the -t flag and either -p for a package or -j for a jar:
+```bash
+cd $HUDI_DIR
+python3 hudi-examples/hudi-examples-spark/src/test/python/HoodiePySparkQuickstart.py [-h] -t TABLE (-p PACKAGE | -j JAR)
+```
\ No newline at end of file
diff --git a/hudi-spark-datasource/hudi-spark/src/main/java/org/apache/hudi/QuickstartUtils.java b/hudi-spark-datasource/hudi-spark/src/main/java/org/apache/hudi/QuickstartUtils.java
index 56ad5a8b66c8..453cbb4e748a 100644
--- a/hudi-spark-datasource/hudi-spark/src/main/java/org/apache/hudi/QuickstartUtils.java
+++ b/hudi-spark-datasource/hudi-spark/src/main/java/org/apache/hudi/QuickstartUtils.java
@@ -33,7 +33,9 @@
 import org.apache.spark.sql.Row;
 
 import java.io.IOException;
+import java.util.ArrayList;
 import java.util.Arrays;
+import java.util.Collections;
 import java.util.HashMap;
 import java.util.List;
 import java.util.Map;
@@ -100,7 +102,7 @@ public int getNumExistingKeys() {
     }
 
     public static GenericRecord generateGenericRecord(String rowKey, String riderName, String driverName,
-        long timestamp) {
+                                                      long timestamp) {
       GenericRecord rec = new GenericData.Record(avroSchema);
       rec.put("uuid", rowKey);
       rec.put("ts", timestamp);
@@ -135,7 +137,7 @@ public static OverwriteWithLatestAvroPayload generateRandomValue(HoodieKey key,
      */
     private static long generateRangeRandomTimestamp(int daysTillNow) {
       long maxIntervalMillis = daysTillNow * 24 * 60 * 60 * 1000L;
-      return System.currentTimeMillis() - (long)(Math.random() * maxIntervalMillis);
+      return System.currentTimeMillis() - (long) (Math.random() * maxIntervalMillis);
     }
 
     /**
@@ -190,6 +192,30 @@ public List<HoodieRecord> generateUpdates(Integer n) {
       }).collect(Collectors.toList());
     }
 
+    /**
+     * Generates new unique updates, one for each of n distinct keys chosen at
+     * random from the existing keys.
+     *
+     * @param n Number of updates (must be no more than number of existing keys)
+     * @return list of hoodie record updates
+     */
+    public List<HoodieRecord> generateUniqueUpdates(Integer n) {
+      if (numExistingKeys < n) {
+        throw new HoodieException("Data must have been written before performing the update operation");
+      }
+      List<Integer> keys = IntStream.range(0, numExistingKeys).boxed()
+          .collect(Collectors.toCollection(ArrayList::new));
+      Collections.shuffle(keys);
+      String randomString = generateRandomString();
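+      // 'keys' is now a shuffled permutation of the indices into existingKeys, so
+      // taking the first n entries below yields n distinct records to update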
+ return IntStream.range(0, n).boxed().map(x -> { + try { + return generateUpdateRecord(existingKeys.get(keys.get(x)), randomString); + } catch (IOException e) { + throw new HoodieIOException(e.getMessage(), e); + } + }).collect(Collectors.toList()); + } + /** * Generates delete records for the passed in rows. * @@ -200,9 +226,9 @@ public List generateDeletes(List rows) { // if row.length() == 2, then the record contains "uuid" and "partitionpath" fields, otherwise, // another field "ts" is available return rows.stream().map(row -> row.length() == 2 - ? convertToString(row.getAs("uuid"), row.getAs("partitionpath"), null) : - convertToString(row.getAs("uuid"), row.getAs("partitionpath"), row.getAs("ts")) - ).filter(os -> os.isPresent()).map(os -> os.get()) + ? convertToString(row.getAs("uuid"), row.getAs("partitionpath"), null) : + convertToString(row.getAs("uuid"), row.getAs("partitionpath"), row.getAs("ts")) + ).filter(os -> os.isPresent()).map(os -> os.get()) .collect(Collectors.toList()); }