diff --git a/pom.xml b/pom.xml
index bdd41e9f..4f4b5860 100644
--- a/pom.xml
+++ b/pom.xml
@@ -106,6 +106,7 @@
kafkaspark
+ spark4hbase-connectors-assembly
diff --git a/spark4/README.md b/spark4/README.md
new file mode 100644
index 00000000..e4d596f1
--- /dev/null
+++ b/spark4/README.md
@@ -0,0 +1,40 @@
+
+
+# Apache HBase™ Spark Connector
+
+## Spark, Scala and Configurable Options
+
+To generate an artifact for a different [Spark version](https://mvnrepository.com/artifact/org.apache.spark/spark-core), [Scala version](https://www.scala-lang.org/download/all.html),
+[Hadoop version](https://mvnrepository.com/artifact/org.apache.hadoop/hadoop-core), or [HBase version](https://mvnrepository.com/artifact/org.apache.hbase/hbase), pass the corresponding command-line options (changing the version numbers appropriately):
+
+```
+$ MAVEN_OPTS="--add-opens java.base/java.util=ALL-UNNAMED --add-opens java.base/java.lang=ALL-UNNAMED --add-opens java.base/sun.net.util=ALL-UNNAMED" \
+  mvn -Dcheckstyle.skip=true -Dspark.version=4.0.0-preview1 -Dscala.version=2.13.14 -Dhadoop-three.version=3.2.3 -Dscala.binary.version=2.13 -Dhbase.version=2.4.8 clean install
+```
+
+
+## Configuration and Installation
+**Client-side** (Spark) configuration:
+- The HBase configuration file `hbase-site.xml` should be made available to Spark; for example, copy it to `$SPARK_CONF_DIR` (default: `$SPARK_HOME/conf`).
+
+**Server-side** (HBase region servers) configuration:
+- The following jars need to be in the CLASSPATH of the HBase region servers:
+ - scala-library, hbase-spark, and hbase-spark-protocol-shaded.
+- The server-side configuration is needed for column filter pushdown.
+  - If you cannot perform the server-side configuration, disable pushdown with `.option("hbase.spark.pushdown.columnfilter", false)`, as shown in the sketch below.
+- The Scala library version must match the Scala version (2.13) used for compiling the connector.
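+
+For illustration, here is a minimal client-side read sketch in Scala (e.g. pasted into `spark-shell` started with the connector jar on the classpath). It assumes the connector's `org.apache.hadoop.hbase.spark` data source name; the table name and column mapping below are placeholders, so adjust them to your schema.
+
+```scala
+import org.apache.spark.sql.SparkSession
+
+val spark = SparkSession.builder().appName("hbase-spark-read").getOrCreate()
+
+// Placeholder table name and column mapping.
+val df = spark.read
+  .format("org.apache.hadoop.hbase.spark")
+  .option("hbase.table", "my_table")
+  .option("hbase.columns.mapping",
+    "id STRING :key, name STRING cf:name, value DOUBLE cf:value")
+  // Let the connector build its configuration from hbase-site.xml on the classpath.
+  .option("hbase.spark.use.hbasecontext", false)
+  // Only needed when the server-side jars listed above are not installed.
+  .option("hbase.spark.pushdown.columnfilter", false)
+  .load()
+
+df.filter("value > 1.0").show()
+```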
diff --git a/spark4/hbase-spark4-it/pom.xml b/spark4/hbase-spark4-it/pom.xml
new file mode 100644
index 00000000..db850594
--- /dev/null
+++ b/spark4/hbase-spark4-it/pom.xml
@@ -0,0 +1,360 @@
+
+
+
+ 4.0.0
+
+ org.apache.hbase.connectors
+ spark4
+ ${revision}
+ ../pom.xml
+
+ org.apache.hbase.connectors.spark
+ hbase-spark4-it
+ Apache HBase - Spark4 Integration Tests
+ Integration and System tests for HBase
+
+
+ **/Test*.java
+ **/IntegrationTest*.java
+
+ 4g
+
+
+
+
+
+ org.slf4j
+ slf4j-api
+
+
+
+ org.apache.hbase
+ hbase-shaded-testing-util
+
+
+ org.apache.hbase
+ hbase-it
+ test-jar
+
+
+ org.apache.hbase.connectors.spark
+ hbase-spark4
+ ${revision}
+
+
+ org.apache.hbase
+ ${compat.module}
+
+
+ org.apache.hbase.thirdparty
+ hbase-shaded-miscellaneous
+
+
+ org.apache.hadoop
+ hadoop-common
+ ${hadoop-two.version}
+
+
+ com.google.code.findbugs
+ jsr305
+
+
+ log4j
+ log4j
+
+
+ org.slf4j
+ slf4j-log4j12
+
+
+
+
+ org.apache.hadoop
+ hadoop-common
+ ${hadoop-two.version}
+ test-jar
+ test
+
+
+ com.google.code.findbugs
+ jsr305
+
+
+ log4j
+ log4j
+
+
+ org.slf4j
+ slf4j-log4j12
+
+
+
+
+
+ com.fasterxml.jackson.module
+ jackson-module-scala_${scala.binary.version}
+ ${jackson.version}
+ test
+
+
+ org.apache.spark
+ spark-core_${scala.binary.version}
+ ${spark.version}
+ provided
+
+
+
+ org.scala-lang
+ scala-library
+
+
+
+ org.scala-lang
+ scalap
+
+
+ com.google.code.findbugs
+ jsr305
+
+
+ org.apache.hadoop
+ hadoop-client-api
+
+
+ org.apache.hadoop
+ hadoop-client-runtime
+
+
+
+
+ org.scala-lang
+ scala-library
+ ${scala.version}
+ provided
+
+
+ org.scala-lang.modules
+ scala-xml_2.11
+ 1.0.4
+ provided
+
+
+ org.apache.spark
+ spark-sql_${scala.binary.version}
+ ${spark.version}
+ provided
+
+
+ org.apache.spark
+ spark-streaming_${scala.binary.version}
+ ${spark.version}
+ provided
+
+
+ org.apache.spark
+ spark-streaming_${scala.binary.version}
+ ${spark.version}
+ tests
+ test-jar
+ test
+
+
+
+ junit
+ junit
+ test
+
+
+ org.mockito
+ mockito-all
+ test
+
+
+
+
+
+
+
+ org.apache.maven.plugins
+ maven-source-plugin
+
+
+
+ maven-assembly-plugin
+
+ true
+
+
+
+ org.apache.maven.plugins
+ maven-failsafe-plugin
+ ${surefire.version}
+
+
+ ${integrationtest.include}
+
+
+ ${unittest.include}
+ **/*$*
+
+ ${test.output.tofile}
+ false
+ false
+
+
+
+ org.apache.maven.surefire
+ surefire-junit4
+ ${surefire.version}
+
+
+
+
+ integration-test
+
+ integration-test
+
+ integration-test
+
+
+ verify
+
+ verify
+
+ verify
+
+
+
+
+
+
+
+
+
+ org.apache.maven.plugins
+ maven-failsafe-plugin
+
+ false
+ 1
+ true
+
+ 1800
+ -enableassertions -Xmx${failsafe.Xmx}
+ -Djava.security.egd=file:/dev/./urandom -XX:+CMSClassUnloadingEnabled
+ -verbose:gc -XX:+PrintCommandLineFlags -XX:+PrintFlagsFinal
+
+
+
+ org.apache.maven.plugins
+ maven-enforcer-plugin
+
+
+
+ banned-hbase-spark
+
+ enforce
+
+
+ true
+
+
+
+ banned-scala
+
+ enforce
+
+
+ true
+
+
+
+
+
+ maven-dependency-plugin
+
+
+ create-mrapp-generated-classpath
+
+ build-classpath
+
+ generate-test-resources
+
+
+ ${project.build.directory}/test-classes/spark-generated-classpath
+
+
+
+
+
+ org.apache.maven.plugins
+ maven-checkstyle-plugin
+
+
+ net.revelc.code
+ warbucks-maven-plugin
+
+
+
+
+
+
+
+ org.apache.maven.plugins
+ maven-surefire-report-plugin
+ 3.0.0-M4
+
+
+ spark-integration-tests
+
+ report-only
+
+
+ failsafe-report
+
+ ${project.build.directory}/failsafe-reports
+
+
+
+
+
+
+
+
+
+
+
+ skipIntegrationTests
+
+
+ skipIntegrationTests
+
+
+
+ true
+
+
+
+
+
diff --git a/spark4/hbase-spark4-it/src/test/java/org/apache/hadoop/hbase/spark/IntegrationTestSparkBulkLoad.java b/spark4/hbase-spark4-it/src/test/java/org/apache/hadoop/hbase/spark/IntegrationTestSparkBulkLoad.java
new file mode 100644
index 00000000..07669157
--- /dev/null
+++ b/spark4/hbase-spark4-it/src/test/java/org/apache/hadoop/hbase/spark/IntegrationTestSparkBulkLoad.java
@@ -0,0 +1,654 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.hadoop.hbase.spark;
+
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.HashMap;
+import java.util.Iterator;
+import java.util.LinkedList;
+import java.util.List;
+import java.util.Map;
+import java.util.Random;
+import java.util.Set;
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.fs.Path;
+import org.apache.hadoop.hbase.Cell;
+import org.apache.hadoop.hbase.CellUtil;
+import org.apache.hadoop.hbase.HBaseConfiguration;
+import org.apache.hadoop.hbase.HBaseTestingUtility;
+import org.apache.hadoop.hbase.HConstants;
+import org.apache.hadoop.hbase.HTableDescriptor;
+import org.apache.hadoop.hbase.IntegrationTestBase;
+import org.apache.hadoop.hbase.IntegrationTestingUtility;
+import org.apache.hadoop.hbase.TableName;
+import org.apache.hadoop.hbase.client.Admin;
+import org.apache.hadoop.hbase.client.Connection;
+import org.apache.hadoop.hbase.client.ConnectionFactory;
+import org.apache.hadoop.hbase.client.Consistency;
+import org.apache.hadoop.hbase.client.RegionLocator;
+import org.apache.hadoop.hbase.client.Result;
+import org.apache.hadoop.hbase.client.Scan;
+import org.apache.hadoop.hbase.client.Table;
+import org.apache.hadoop.hbase.io.ImmutableBytesWritable;
+import org.apache.hadoop.hbase.mapreduce.IntegrationTestBulkLoad;
+import org.apache.hadoop.hbase.tool.LoadIncrementalHFiles;
+import org.apache.hadoop.hbase.util.Bytes;
+import org.apache.hadoop.hbase.util.EnvironmentEdgeManager;
+import org.apache.hadoop.hbase.util.Pair;
+import org.apache.hadoop.hbase.util.RegionSplitter;
+import org.apache.hadoop.util.StringUtils;
+import org.apache.hadoop.util.ToolRunner;
+import org.apache.spark.Partitioner;
+import org.apache.spark.SerializableWritable;
+import org.apache.spark.SparkConf;
+import org.apache.spark.api.java.JavaRDD;
+import org.apache.spark.api.java.JavaSparkContext;
+import org.apache.spark.api.java.function.Function;
+import org.apache.spark.api.java.function.Function2;
+import org.apache.spark.api.java.function.PairFlatMapFunction;
+import org.apache.spark.api.java.function.VoidFunction;
+import org.junit.Test;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+import scala.Tuple2;
+
+import org.apache.hbase.thirdparty.com.google.common.collect.Sets;
+import org.apache.hbase.thirdparty.org.apache.commons.cli.CommandLine;
+
+/**
+ * Test Bulk Load and Spark on a distributed cluster. It starts a Spark job that creates linked
+ * chains. This test mimics {@link IntegrationTestBulkLoad} in mapreduce. Usage on cluster: first add
+ * the HBase-related jars and hbase-spark.jar to the Spark classpath, then run: spark-submit --class
+ * org.apache.hadoop.hbase.spark.IntegrationTestSparkBulkLoad
+ * HBASE_HOME/lib/hbase-spark-it-XXX-tests.jar -m slowDeterministic
+ * -Dhbase.spark.bulkload.chainlength=300
+ */
+public class IntegrationTestSparkBulkLoad extends IntegrationTestBase {
+
+ private static final Logger LOG = LoggerFactory.getLogger(IntegrationTestSparkBulkLoad.class);
+
+ // The number of partitions for random generated data
+ private static String BULKLOAD_PARTITIONS_NUM = "hbase.spark.bulkload.partitionsnum";
+ private static int DEFAULT_BULKLOAD_PARTITIONS_NUM = 3;
+
+ private static String BULKLOAD_CHAIN_LENGTH = "hbase.spark.bulkload.chainlength";
+ private static int DEFAULT_BULKLOAD_CHAIN_LENGTH = 200000;
+
+ private static String BULKLOAD_IMPORT_ROUNDS = "hbase.spark.bulkload.importround";
+ private static int DEFAULT_BULKLOAD_IMPORT_ROUNDS = 1;
+
+ private static String CURRENT_ROUND_NUM = "hbase.spark.bulkload.current.roundnum";
+
+ private static String NUM_REPLICA_COUNT_KEY = "hbase.spark.bulkload.replica.countkey";
+ private static int DEFAULT_NUM_REPLICA_COUNT = 1;
+
+ private static String BULKLOAD_TABLE_NAME = "hbase.spark.bulkload.tableName";
+ private static String DEFAULT_BULKLOAD_TABLE_NAME = "IntegrationTestSparkBulkLoad";
+
+ private static String BULKLOAD_OUTPUT_PATH = "hbase.spark.bulkload.output.path";
+
+ private static final String OPT_LOAD = "load";
+ private static final String OPT_CHECK = "check";
+
+ private boolean load = false;
+ private boolean check = false;
+
+ private static final byte[] CHAIN_FAM = Bytes.toBytes("L");
+ private static final byte[] SORT_FAM = Bytes.toBytes("S");
+ private static final byte[] DATA_FAM = Bytes.toBytes("D");
+
+ /**
+ * Runs a Spark job to load data into the HBase table.
+ */
+ public void runLoad() throws Exception {
+ setupTable();
+ int numImportRounds = getConf().getInt(BULKLOAD_IMPORT_ROUNDS, DEFAULT_BULKLOAD_IMPORT_ROUNDS);
+ LOG.info("Running load with numIterations:" + numImportRounds);
+ for (int i = 0; i < numImportRounds; i++) {
+ runLinkedListSparkJob(i);
+ }
+ }
+
+ /**
+ * Runs a Spark job to create a linked list for testing.
+ * @param iteration the iteration number of this job
+ * @throws Exception if an HBase operation or getting the test directory fails
+ */
+ public void runLinkedListSparkJob(int iteration) throws Exception {
+ String jobName = IntegrationTestSparkBulkLoad.class.getSimpleName() + " _load "
+ + EnvironmentEdgeManager.currentTime();
+
+ LOG.info("Running iteration " + iteration + "in Spark Job");
+
+ Path output = null;
+ if (conf.get(BULKLOAD_OUTPUT_PATH) == null) {
+ output = util.getDataTestDirOnTestFS(getTablename() + "-" + iteration);
+ } else {
+ output = new Path(conf.get(BULKLOAD_OUTPUT_PATH));
+ }
+
+ SparkConf sparkConf = new SparkConf().setAppName(jobName).setMaster("local");
+ Configuration hbaseConf = new Configuration(getConf());
+ hbaseConf.setInt(CURRENT_ROUND_NUM, iteration);
+ int partitionNum = hbaseConf.getInt(BULKLOAD_PARTITIONS_NUM, DEFAULT_BULKLOAD_PARTITIONS_NUM);
+
+ JavaSparkContext jsc = new JavaSparkContext(sparkConf);
+ JavaHBaseContext hbaseContext = new JavaHBaseContext(jsc, hbaseConf);
+
+ LOG.info("Partition RDD into " + partitionNum + " parts");
+ List<String> temp = new ArrayList<>();
+ JavaRDD<List<byte[]>> rdd = jsc.parallelize(temp, partitionNum).mapPartitionsWithIndex(
+ new LinkedListCreationMapper(new SerializableWritable<>(hbaseConf)), false);
+
+ hbaseContext.bulkLoad(rdd, getTablename(), new ListToKeyValueFunc(), output.toUri().getPath(),
+ new HashMap<>(), false, HConstants.DEFAULT_MAX_FILE_SIZE);
+
+ try (Connection conn = ConnectionFactory.createConnection(conf); Admin admin = conn.getAdmin();
+ Table table = conn.getTable(getTablename());
+ RegionLocator regionLocator = conn.getRegionLocator(getTablename())) {
+ // Create a new loader.
+ LoadIncrementalHFiles loader = new LoadIncrementalHFiles(conf);
+
+ // Load the HFiles into table.
+ loader.doBulkLoad(output, admin, table, regionLocator);
+ }
+
+ // Delete the files.
+ util.getTestFileSystem().delete(output, true);
+ jsc.close();
+ }
+
+ // See mapreduce.IntegrationTestBulkLoad#LinkedListCreationMapper
+ // Used to generate test data
+ public static class LinkedListCreationMapper
+ implements Function2<Integer, Iterator<String>, Iterator<List<byte[]>>> {
+
+ SerializableWritable swConfig = null;
+ private Random rand = new Random();
+
+ public LinkedListCreationMapper(SerializableWritable conf) {
+ this.swConfig = conf;
+ }
+
+ @Override
+ public Iterator<List<byte[]>> call(Integer v1, Iterator<String> v2) throws Exception {
+ Configuration config = (Configuration) swConfig.value();
+ int partitionId = v1.intValue();
+ LOG.info("Starting create List in Partition " + partitionId);
+
+ int partitionNum = config.getInt(BULKLOAD_PARTITIONS_NUM, DEFAULT_BULKLOAD_PARTITIONS_NUM);
+ int chainLength = config.getInt(BULKLOAD_CHAIN_LENGTH, DEFAULT_BULKLOAD_CHAIN_LENGTH);
+ int iterationsNum = config.getInt(BULKLOAD_IMPORT_ROUNDS, DEFAULT_BULKLOAD_IMPORT_ROUNDS);
+ int iterationsCur = config.getInt(CURRENT_ROUND_NUM, 0);
+ List<List<byte[]>> res = new LinkedList<>();
+
+ long tempId = partitionId + iterationsCur * partitionNum;
+ long totalPartitionNum = partitionNum * iterationsNum;
+ long chainId = Math.abs(rand.nextLong());
+ chainId = chainId - (chainId % totalPartitionNum) + tempId;
+
+ byte[] chainIdArray = Bytes.toBytes(chainId);
+ long currentRow = 0;
+ long nextRow = getNextRow(0, chainLength);
+ for (long i = 0; i < chainLength; i++) {
+ byte[] rk = Bytes.toBytes(currentRow);
+ // Insert record into a list
+ List<byte[]> tmp1 = Arrays.asList(rk, CHAIN_FAM, chainIdArray, Bytes.toBytes(nextRow));
+ List<byte[]> tmp2 = Arrays.asList(rk, SORT_FAM, chainIdArray, Bytes.toBytes(i));
+ List<byte[]> tmp3 = Arrays.asList(rk, DATA_FAM, chainIdArray, Bytes.toBytes("random" + i));
+ res.add(tmp1);
+ res.add(tmp2);
+ res.add(tmp3);
+
+ currentRow = nextRow;
+ nextRow = getNextRow(i + 1, chainLength);
+ }
+ return res.iterator();
+ }
+
+ /** Returns a unique row id within this chain for this index */
+ private long getNextRow(long index, long chainLength) {
+ long nextRow = Math.abs(new Random().nextLong());
+ // use significant bits from the random number, but pad with index to ensure it is unique
+ // this also ensures that we do not reuse row = 0
+ // row collisions from multiple mappers are fine, since we guarantee unique chainIds
+ nextRow = nextRow - (nextRow % chainLength) + index;
+ return nextRow;
+ }
+ }
+
+ public static class ListToKeyValueFunc
+ implements Function<List<byte[]>, Pair<KeyFamilyQualifier, byte[]>> {
+ @Override
+ public Pair<KeyFamilyQualifier, byte[]> call(List<byte[]> v1) throws Exception {
+ if (v1 == null || v1.size() != 4) {
+ return null;
+ }
+ KeyFamilyQualifier kfq = new KeyFamilyQualifier(v1.get(0), v1.get(1), v1.get(2));
+
+ return new Pair<>(kfq, v1.get(3));
+ }
+ }
+
+ /**
+ * After adding data to the table, start a Spark job to check the bulk load.
+ */
+ public void runCheck() throws Exception {
+ LOG.info("Running check");
+ String jobName = IntegrationTestSparkBulkLoad.class.getSimpleName() + "_check"
+ + EnvironmentEdgeManager.currentTime();
+
+ SparkConf sparkConf = new SparkConf().setAppName(jobName).setMaster("local");
+ Configuration hbaseConf = new Configuration(getConf());
+ JavaSparkContext jsc = new JavaSparkContext(sparkConf);
+ JavaHBaseContext hbaseContext = new JavaHBaseContext(jsc, hbaseConf);
+
+ Scan scan = new Scan();
+ scan.addFamily(CHAIN_FAM);
+ scan.addFamily(SORT_FAM);
+ scan.setMaxVersions(1);
+ scan.setCacheBlocks(false);
+ scan.setBatch(1000);
+ int replicaCount = conf.getInt(NUM_REPLICA_COUNT_KEY, DEFAULT_NUM_REPLICA_COUNT);
+ if (replicaCount != DEFAULT_NUM_REPLICA_COUNT) {
+ scan.setConsistency(Consistency.TIMELINE);
+ }
+
+ // 1. Using TableInputFormat to get data from HBase table
+ // 2. Mimic LinkedListCheckingMapper in mapreduce.IntegrationTestBulkLoad
+ // 3. Sort LinkKey by its order ID
+ // 4. Group LinkKey if they have same chainId, and repartition RDD by NaturalKeyPartitioner
+ // 5. Check LinkList in each Partition using LinkedListCheckingFlatMapFunc
+ hbaseContext.hbaseRDD(getTablename(), scan).flatMapToPair(new LinkedListCheckingFlatMapFunc())
+ .sortByKey()
+ .combineByKey(new createCombinerFunc(), new mergeValueFunc(), new mergeCombinersFunc(),
+ new NaturalKeyPartitioner(new SerializableWritable<>(hbaseConf)))
+ .foreach(new LinkedListCheckingForeachFunc(new SerializableWritable<>(hbaseConf)));
+ jsc.close();
+ }
+
+ private void runCheckWithRetry() throws Exception {
+ try {
+ runCheck();
+ } catch (Throwable t) {
+ LOG.warn("Received " + StringUtils.stringifyException(t));
+ LOG.warn("Running the check MR Job again to see whether an ephemeral problem or not");
+ runCheck();
+ throw t; // we should still fail the test even if second retry succeeds
+ }
+ // everything green
+ }
+
+ /**
+ * PairFlatMapFunction used to transfer {@code <ImmutableBytesWritable, Result>} to
+ * {@code Tuple<SparkLinkKey, SparkLinkChain>}.
+ */
+ public static class LinkedListCheckingFlatMapFunc implements
+ PairFlatMapFunction<Tuple2<ImmutableBytesWritable, Result>, SparkLinkKey, SparkLinkChain> {
+
+ @Override
+ public Iterator<Tuple2<SparkLinkKey, SparkLinkChain>>
+ call(Tuple2<ImmutableBytesWritable, Result> v) throws Exception {
+ Result value = v._2();
+ long longRk = Bytes.toLong(value.getRow());
+ List<Tuple2<SparkLinkKey, SparkLinkChain>> list = new LinkedList<>();
+
+ for (Map.Entry<byte[], byte[]> entry : value.getFamilyMap(CHAIN_FAM).entrySet()) {
+ long chainId = Bytes.toLong(entry.getKey());
+ long next = Bytes.toLong(entry.getValue());
+ Cell c = value.getColumnCells(SORT_FAM, entry.getKey()).get(0);
+ long order = Bytes.toLong(CellUtil.cloneValue(c));
+ Tuple2<SparkLinkKey, SparkLinkChain> tuple2 =
+ new Tuple2<>(new SparkLinkKey(chainId, order), new SparkLinkChain(longRk, next));
+ list.add(tuple2);
+ }
+ return list.iterator();
+ }
+ }
+
+ public static class createCombinerFunc implements Function<SparkLinkChain, List<SparkLinkChain>> {
+ @Override
+ public List<SparkLinkChain> call(SparkLinkChain v1) throws Exception {
+ List<SparkLinkChain> list = new LinkedList<>();
+ list.add(v1);
+ return list;
+ }
+ }
+
+ public static class mergeValueFunc
+ implements Function2<List<SparkLinkChain>, SparkLinkChain, List<SparkLinkChain>> {
+ @Override
+ public List<SparkLinkChain> call(List<SparkLinkChain> v1, SparkLinkChain v2) throws Exception {
+ if (v1 == null) {
+ v1 = new LinkedList<>();
+ }
+
+ v1.add(v2);
+ return v1;
+ }
+ }
+
+ public static class mergeCombinersFunc
+ implements Function2<List<SparkLinkChain>, List<SparkLinkChain>, List<SparkLinkChain>> {
+ @Override
+ public List<SparkLinkChain> call(List<SparkLinkChain> v1, List<SparkLinkChain> v2)
+ throws Exception {
+ v1.addAll(v2);
+ return v1;
+ }
+ }
+
+ /**
+ * Class to figure out what partition to send a link in the chain to. This is based upon the
+ * linkKey's ChainId.
+ */
+ public static class NaturalKeyPartitioner extends Partitioner {
+
+ private int numPartions = 0;
+
+ public NaturalKeyPartitioner(SerializableWritable swConf) {
+ Configuration hbaseConf = (Configuration) swConf.value();
+ numPartions = hbaseConf.getInt(BULKLOAD_PARTITIONS_NUM, DEFAULT_BULKLOAD_PARTITIONS_NUM);
+
+ }
+
+ @Override
+ public int numPartitions() {
+ return numPartions;
+ }
+
+ @Override
+ public int getPartition(Object key) {
+ if (!(key instanceof SparkLinkKey)) {
+ return -1;
+ }
+
+ int hash = ((SparkLinkKey) key).getChainId().hashCode();
+ return Math.abs(hash % numPartions);
+
+ }
+ }
+
+ /**
+ * Sort all LinkChain for one LinkKey, and test {@code List<SparkLinkChain>}.
+ */
+ public static class LinkedListCheckingForeachFunc
+ implements VoidFunction<Tuple2<SparkLinkKey, List<SparkLinkChain>>> {
+
+ private SerializableWritable swConf = null;
+
+ public LinkedListCheckingForeachFunc(SerializableWritable conf) {
+ swConf = conf;
+ }
+
+ @Override
+ public void call(Tuple2<SparkLinkKey, List<SparkLinkChain>> v1) throws Exception {
+ long next = -1L;
+ long prev = -1L;
+ long count = 0L;
+
+ SparkLinkKey key = v1._1();
+ List<SparkLinkChain> values = v1._2();
+
+ for (SparkLinkChain lc : values) {
+
+ if (next == -1) {
+ if (lc.getRk() != 0L) {
+ String msg = "Chains should all start at rk 0, but read rk " + lc.getRk() + ". Chain:"
+ + key.getChainId() + ", order:" + key.getOrder();
+ throw new RuntimeException(msg);
+ }
+ next = lc.getNext();
+ } else {
+ if (next != lc.getRk()) {
+ String msg = "Missing a link in the chain. Prev rk " + prev + " was, expecting " + next
+ + " but got " + lc.getRk() + ". Chain:" + key.getChainId() + ", order:"
+ + key.getOrder();
+ throw new RuntimeException(msg);
+ }
+ prev = lc.getRk();
+ next = lc.getNext();
+ }
+ count++;
+ }
+ Configuration hbaseConf = (Configuration) swConf.value();
+ int expectedChainLen = hbaseConf.getInt(BULKLOAD_CHAIN_LENGTH, DEFAULT_BULKLOAD_CHAIN_LENGTH);
+ if (count != expectedChainLen) {
+ String msg = "Chain wasn't the correct length. Expected " + expectedChainLen + " got "
+ + count + ". Chain:" + key.getChainId() + ", order:" + key.getOrder();
+ throw new RuntimeException(msg);
+ }
+ }
+ }
+
+ /**
+ * Writable class used as the key to group links in the linked list. Used as the key emitted from a
+ * pass over the table.
+ */
+ public static class SparkLinkKey implements java.io.Serializable, Comparable<SparkLinkKey> {
+
+ private Long chainId;
+ private Long order;
+
+ public Long getOrder() {
+ return order;
+ }
+
+ public Long getChainId() {
+ return chainId;
+ }
+
+ public SparkLinkKey(long chainId, long order) {
+ this.chainId = chainId;
+ this.order = order;
+ }
+
+ @Override
+ public int hashCode() {
+ return this.getChainId().hashCode();
+ }
+
+ @Override
+ public boolean equals(Object other) {
+ if (!(other instanceof SparkLinkKey)) {
+ return false;
+ }
+
+ SparkLinkKey otherKey = (SparkLinkKey) other;
+ return this.getChainId().equals(otherKey.getChainId());
+ }
+
+ @Override
+ public int compareTo(SparkLinkKey other) {
+ int res = getChainId().compareTo(other.getChainId());
+
+ if (res == 0) {
+ res = getOrder().compareTo(other.getOrder());
+ }
+
+ return res;
+ }
+ }
+
+ /**
+ * Writable used as the value emitted from a pass over the hbase table.
+ */
+ public static class SparkLinkChain implements java.io.Serializable, Comparable<SparkLinkChain> {
+
+ public Long getNext() {
+ return next;
+ }
+
+ public Long getRk() {
+ return rk;
+ }
+
+ public SparkLinkChain(Long rk, Long next) {
+ this.rk = rk;
+ this.next = next;
+ }
+
+ private Long rk;
+ private Long next;
+
+ @Override
+ public int compareTo(SparkLinkChain linkChain) {
+ int res = getRk().compareTo(linkChain.getRk());
+ if (res == 0) {
+ res = getNext().compareTo(linkChain.getNext());
+ }
+ return res;
+ }
+
+ @Override
+ public int hashCode() {
+ return getRk().hashCode() ^ getNext().hashCode();
+ }
+
+ @Override
+ public boolean equals(Object other) {
+ if (!(other instanceof SparkLinkChain)) {
+ return false;
+ }
+
+ SparkLinkChain otherKey = (SparkLinkChain) other;
+ return this.getRk().equals(otherKey.getRk()) && this.getNext().equals(otherKey.getNext());
+ }
+ }
+
+ /**
+ * Allow the scan to go to a replica; this does not affect runCheck() since the data is bulk-loaded
+ * from HFiles into the table.
+ * @throws IOException if an HBase operation fails
+ * @throws InterruptedException if modifying the table fails
+ */
+ private void installSlowingCoproc() throws IOException, InterruptedException {
+ int replicaCount = conf.getInt(NUM_REPLICA_COUNT_KEY, DEFAULT_NUM_REPLICA_COUNT);
+
+ if (replicaCount == DEFAULT_NUM_REPLICA_COUNT) {
+ return;
+ }
+
+ TableName t = getTablename();
+ Admin admin = util.getAdmin();
+ HTableDescriptor desc = admin.getTableDescriptor(t);
+ desc.addCoprocessor(IntegrationTestBulkLoad.SlowMeCoproScanOperations.class.getName());
+ HBaseTestingUtility.modifyTableSync(admin, desc);
+ }
+
+ @Test
+ public void testBulkLoad() throws Exception {
+ runLoad();
+ installSlowingCoproc();
+ runCheckWithRetry();
+ }
+
+ private byte[][] getSplits(int numRegions) {
+ RegionSplitter.UniformSplit split = new RegionSplitter.UniformSplit();
+ split.setFirstRow(Bytes.toBytes(0L));
+ split.setLastRow(Bytes.toBytes(Long.MAX_VALUE));
+ return split.split(numRegions);
+ }
+
+ private void setupTable() throws IOException, InterruptedException {
+ if (util.getAdmin().tableExists(getTablename())) {
+ util.deleteTable(getTablename());
+ }
+
+ util.createTable(getTablename(), new byte[][] { CHAIN_FAM, SORT_FAM, DATA_FAM }, getSplits(16));
+
+ int replicaCount = conf.getInt(NUM_REPLICA_COUNT_KEY, DEFAULT_NUM_REPLICA_COUNT);
+
+ if (replicaCount == DEFAULT_NUM_REPLICA_COUNT) {
+ return;
+ }
+
+ TableName t = getTablename();
+ HBaseTestingUtility.setReplicas(util.getAdmin(), t, replicaCount);
+ }
+
+ @Override
+ public void setUpCluster() throws Exception {
+ util = getTestingUtil(getConf());
+ util.initializeCluster(1);
+ int replicaCount = getConf().getInt(NUM_REPLICA_COUNT_KEY, DEFAULT_NUM_REPLICA_COUNT);
+ if (LOG.isDebugEnabled() && replicaCount != DEFAULT_NUM_REPLICA_COUNT) {
+ LOG.debug("Region Replicas enabled: " + replicaCount);
+ }
+
+ // Scale this up on a real cluster
+ if (util.isDistributedCluster()) {
+ util.getConfiguration().setIfUnset(BULKLOAD_PARTITIONS_NUM,
+ String.valueOf(DEFAULT_BULKLOAD_PARTITIONS_NUM));
+ util.getConfiguration().setIfUnset(BULKLOAD_IMPORT_ROUNDS, "1");
+ } else {
+ util.startMiniMapReduceCluster();
+ }
+ }
+
+ @Override
+ protected void addOptions() {
+ super.addOptions();
+ super.addOptNoArg(OPT_CHECK, "Run check only");
+ super.addOptNoArg(OPT_LOAD, "Run load only");
+ }
+
+ @Override
+ protected void processOptions(CommandLine cmd) {
+ super.processOptions(cmd);
+ check = cmd.hasOption(OPT_CHECK);
+ load = cmd.hasOption(OPT_LOAD);
+ }
+
+ @Override
+ public int runTestFromCommandLine() throws Exception {
+ if (load) {
+ runLoad();
+ } else if (check) {
+ installSlowingCoproc();
+ runCheckWithRetry();
+ } else {
+ testBulkLoad();
+ }
+ return 0;
+ }
+
+ @Override
+ public TableName getTablename() {
+ return getTableName(getConf());
+ }
+
+ public static TableName getTableName(Configuration conf) {
+ return TableName.valueOf(conf.get(BULKLOAD_TABLE_NAME, DEFAULT_BULKLOAD_TABLE_NAME));
+ }
+
+ @Override
+ protected Set<String> getColumnFamilies() {
+ return Sets.newHashSet(Bytes.toString(CHAIN_FAM), Bytes.toString(DATA_FAM),
+ Bytes.toString(SORT_FAM));
+ }
+
+ public static void main(String[] args) throws Exception {
+ Configuration conf = HBaseConfiguration.create();
+ IntegrationTestingUtility.setUseDistributedCluster(conf);
+ int status = ToolRunner.run(conf, new IntegrationTestSparkBulkLoad(), args);
+ System.exit(status);
+ }
+}
diff --git a/spark4/hbase-spark4-it/src/test/resources/hbase-site.xml b/spark4/hbase-spark4-it/src/test/resources/hbase-site.xml
new file mode 100644
index 00000000..99d2ab8d
--- /dev/null
+++ b/spark4/hbase-spark4-it/src/test/resources/hbase-site.xml
@@ -0,0 +1,32 @@
+
+
+
+
+
+ hbase.defaults.for.version.skip
+ true
+
+
+ hbase.hconnection.threads.keepalivetime
+ 3
+
+
diff --git a/spark4/hbase-spark4-protocol-shaded/pom.xml b/spark4/hbase-spark4-protocol-shaded/pom.xml
new file mode 100644
index 00000000..de113ecd
--- /dev/null
+++ b/spark4/hbase-spark4-protocol-shaded/pom.xml
@@ -0,0 +1,89 @@
+
+
+
+ 4.0.0
+
+
+ org.apache.hbase.connectors
+ spark4
+ ${revision}
+ ../pom.xml
+
+
+ org.apache.hbase.connectors.spark
+ hbase-spark4-protocol-shaded
+ Apache HBase - Spark4 Protocol (Shaded)
+
+
+
+
+ org.apache.hbase.connectors.spark
+ hbase-spark4-protocol
+ true
+
+
+ org.apache.hbase.thirdparty
+ hbase-shaded-protobuf
+ ${hbase-thirdparty.version}
+
+
+
+
+
+
+ org.apache.maven.plugins
+ maven-shade-plugin
+
+
+
+ shade
+
+ package
+
+ true
+ true
+
+ false
+
+
+ com.google.protobuf
+ org.apache.hbase.thirdparty.com.google.protobuf
+
+
+
+
+ com.google.protobuf:protobuf-java
+ org.apache.hbase.thirdparty:*
+
+
+
+
+
+
+
+
+
+
diff --git a/spark4/hbase-spark4-protocol/pom.xml b/spark4/hbase-spark4-protocol/pom.xml
new file mode 100644
index 00000000..77a8d0d6
--- /dev/null
+++ b/spark4/hbase-spark4-protocol/pom.xml
@@ -0,0 +1,76 @@
+
+
+
+ 4.0.0
+
+
+ org.apache.hbase.connectors
+ spark4
+ ${revision}
+ ../
+
+
+ org.apache.hbase.connectors.spark
+ hbase-spark4-protocol
+ Apache HBase - Spark4 Protocol
+
+
+
+ com.google.protobuf
+ protobuf-java
+
+
+
+
+
+
+ org.apache.maven.plugins
+ maven-compiler-plugin
+
+
+ org.xolstice.maven.plugins
+ protobuf-maven-plugin
+
+
+ compile-protoc
+
+ compile
+
+ generate-sources
+
+
+
+
+ org.apache.maven.plugins
+ maven-source-plugin
+
+
+ attach-sources
+
+ jar
+
+
+
+
+
+
+
+
diff --git a/spark4/hbase-spark4-protocol/src/main/protobuf/SparkFilter.proto b/spark4/hbase-spark4-protocol/src/main/protobuf/SparkFilter.proto
new file mode 100644
index 00000000..e16c5517
--- /dev/null
+++ b/spark4/hbase-spark4-protocol/src/main/protobuf/SparkFilter.proto
@@ -0,0 +1,40 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+// This file contains protocol buffers that are used for Spark filters
+// over in the hbase-spark module
+package hbase.pb;
+
+option java_package = "org.apache.hadoop.hbase.spark.protobuf.generated";
+option java_outer_classname = "SparkFilterProtos";
+option java_generic_services = true;
+option java_generate_equals_and_hash = true;
+option optimize_for = SPEED;
+
+message SQLPredicatePushDownCellToColumnMapping {
+ required bytes column_family = 1;
+ required bytes qualifier = 2;
+ required string column_name = 3;
+}
+
+message SQLPredicatePushDownFilter {
+ required string dynamic_logic_expression = 1;
+ repeated bytes value_from_query_array = 2;
+ repeated SQLPredicatePushDownCellToColumnMapping cell_to_column_mapping = 3;
+ optional string encoderClassName = 4;
+}
diff --git a/spark4/hbase-spark4/README.md b/spark4/hbase-spark4/README.md
new file mode 100644
index 00000000..315fa640
--- /dev/null
+++ b/spark4/hbase-spark4/README.md
@@ -0,0 +1,24 @@
+
+
+## ON PROTOBUFS
+This Maven module has the core protobuf definition files ('.proto') used by HBase
+Spark that ship with HBase core, including tests.
+
+Generation of Java files from the protobuf .proto files included here is done as
+part of the build.
diff --git a/spark4/hbase-spark4/pom.xml b/spark4/hbase-spark4/pom.xml
new file mode 100644
index 00000000..7edafe56
--- /dev/null
+++ b/spark4/hbase-spark4/pom.xml
@@ -0,0 +1,621 @@
+
+
+
+ 4.0.0
+
+
+ org.apache.hbase.connectors
+ spark4
+ ${revision}
+ ../pom.xml
+
+
+ org.apache.hbase.connectors.spark
+ hbase-spark4
+ Apache HBase - Spark4 Connector
+
+
+
+ org.slf4j
+ slf4j-api
+
+
+ org.apache.hbase.thirdparty
+ hbase-shaded-miscellaneous
+
+
+
+ javax.servlet
+ javax.servlet-api
+ 3.1.0
+ test
+
+
+
+ org.scala-lang
+ scala-library
+ ${scala.version}
+ provided
+
+
+ org.apache.spark
+ spark-core_${scala.binary.version}
+ ${spark.version}
+ provided
+
+
+
+ org.scala-lang
+ scala-library
+
+
+
+ org.scala-lang
+ scalap
+
+
+ com.google.code.findbugs
+ jsr305
+
+
+
+ org.xerial.snappy
+ snappy-java
+
+
+ xerces
+ xercesImpl
+
+
+ org.apache.hadoop
+ hadoop-client-api
+
+
+ org.apache.hadoop
+ hadoop-client-runtime
+
+
+
+
+ org.apache.spark
+ spark-sql_${scala.binary.version}
+ ${spark.version}
+ provided
+
+
+ org.apache.spark
+ spark-streaming_${scala.binary.version}
+ ${spark.version}
+ provided
+
+
+ org.apache.spark
+ spark-streaming_${scala.binary.version}
+ ${spark.version}
+ tests
+ test-jar
+ test
+
+
+ junit
+ junit
+ test
+
+
+ org.mockito
+ mockito-all
+ test
+
+
+ org.scalatest
+ scalatest_${scala.binary.version}
+ 3.2.19
+ test
+
+
+ org.scala-lang
+ scala-library
+
+
+
+
+ org.apache.hbase
+ hbase-shaded-client
+
+
+ org.apache.hbase.connectors.spark
+ hbase-spark4-protocol-shaded
+
+
+ org.apache.yetus
+ audience-annotations
+
+
+ org.apache.hbase
+ hbase-shaded-testing-util
+
+
+ org.apache.hbase
+ hbase-annotations
+ test-jar
+ test
+
+
+ org.apache.hbase
+ hbase-shaded-mapreduce
+ ${hbase.version}
+
+
+ org.apache.avro
+ avro
+
+
+ org.scala-lang
+ scala-reflect
+ ${scala.version}
+
+
+ org.apache.spark
+ spark-unsafe_${scala.binary.version}
+ ${spark.version}
+
+
+ org.apache.spark
+ spark-catalyst_${scala.binary.version}
+ ${spark.version}
+
+
+
+
+
+
+ org.apache.maven.plugins
+ maven-compiler-plugin
+
+
+
+ org.codehaus.mojo
+ build-helper-maven-plugin
+
+
+ add-source
+
+ add-source
+
+ validate
+
+
+ src/main/scala
+
+
+
+
+ add-test-source
+
+ add-test-source
+
+ validate
+
+
+ src/test/scala
+
+
+
+
+
+
+ org.apache.maven.plugins
+ maven-enforcer-plugin
+
+
+
+ banned-scala
+
+ enforce
+
+
+ true
+
+
+
+
+
+ org.apache.maven.plugins
+ maven-checkstyle-plugin
+
+
+ net.revelc.code
+ warbucks-maven-plugin
+
+
+
+ true
+
+
+
+ (?!.*(.generated.|.tmpl.|\$|org.apache.hadoop.hbase.spark.hbase.package)).*
+ false
+ true
+ false
+ false
+ false
+ org[.]apache[.]yetus[.]audience[.]InterfaceAudience.*
+
+
+
+
+
+
+
+
+
+
+ skipSparkTests
+
+
+ skipSparkTests
+
+
+
+ true
+ true
+ true
+
+
+
+
+ hadoop-2.0
+
+
+ hadoop.profile
+ 2.0
+
+
+
+
+ org.apache.hadoop
+ hadoop-client
+ ${hadoop-two.version}
+
+
+ org.apache.hadoop
+ hadoop-common
+ ${hadoop-two.version}
+
+
+ com.google.code.findbugs
+ jsr305
+
+
+
+
+ org.apache.hadoop
+ hadoop-common
+ ${hadoop-two.version}
+ test-jar
+ test
+
+
+ com.google.code.findbugs
+ jsr305
+
+
+
+
+ org.apache.hadoop
+ hadoop-hdfs
+ ${hadoop-two.version}
+ test-jar
+ test
+
+
+ com.google.code.findbugs
+ jsr305
+
+
+ xerces
+ xercesImpl
+
+
+
+
+ org.apache.hadoop
+ hadoop-minikdc
+ ${hadoop-two.version}
+ test
+
+
+ org.apache.directory.jdbm
+ apacheds-jdbm1
+
+
+
+
+
+
+
+ hadoop-3.0
+
+
+ !hadoop.profile
+
+
+
+ ${hadoop-three.version}
+
+
+
+ org.apache.hadoop
+ hadoop-client
+ ${hadoop-three.version}
+
+
+ org.apache.hadoop
+ hadoop-common
+ ${hadoop-three.version}
+
+
+ com.google.code.findbugs
+ jsr305
+
+
+
+
+ org.apache.hadoop
+ hadoop-common
+ ${hadoop-three.version}
+ test-jar
+ test
+
+
+ com.google.code.findbugs
+ jsr305
+
+
+ log4j
+ log4j
+
+
+ org.slf4j
+ slf4j-log4j12
+
+
+
+
+ org.apache.hadoop
+ hadoop-hdfs
+ ${hadoop-three.version}
+ test-jar
+ test
+
+
+ com.google.code.findbugs
+ jsr305
+
+
+
+
+ org.apache.hadoop
+ hadoop-minikdc
+ ${hadoop-three.version}
+ test
+
+
+
+
+
+
+ build-scala-sources
+
+
+ scala.skip
+ !true
+
+
+
+
+
+
+ org.codehaus.gmaven
+ gmaven-plugin
+ 1.5
+
+
+
+ execute
+
+ validate
+
+
+
+
+
+
+
+ net.alchim31.maven
+ scala-maven-plugin
+ 4.9.2
+
+ ${project.build.sourceEncoding}
+ ${scala.version}
+
+ -feature
+
+ ${target.jvm}
+
+ ${compileSource}
+ ${compileSource}
+
+
+
+ scala-compile-first
+
+ add-source
+ compile
+
+ process-resources
+
+
+ scala-test-compile
+
+ testCompile
+
+ process-test-resources
+
+
+
+
+ org.scalatest
+ scalatest-maven-plugin
+ 2.0.2
+
+ ${project.build.directory}/surefire-reports
+ .
+ WDF TestSuite.txt
+ false
+
+
+
+ test
+
+ test
+
+ test
+
+ -Xmx1536m -XX:ReservedCodeCacheSize=512m
+ false
+
+
+
+
+
+
+
+
+
+ coverage
+
+ false
+
+
+ src/main/
+ ${project.parent.parent.basedir}
+
+
+
+
+ net.alchim31.maven
+ scala-maven-plugin
+
+ src/test/scala
+ src/main/scala
+
+
+
+ default-sbt-compile
+
+ compile
+ testCompile
+
+
+
+
+
+ org.apache.maven.plugins
+ maven-compiler-plugin
+
+ true
+ true
+
+
+
+ org.scoverage
+ scoverage-maven-plugin
+
+ 1.4.11
+ true
+ true
+ false
+ ${project.build.sourceEncoding}
+ ${scoverageReportDir}
+ ${scoverageReportDir}
+
+
+
+ instrument
+
+ pre-compile
+ post-compile
+
+
+
+ package
+
+ package
+ report
+
+
+
+ scoverage-report
+
+
+ report-only
+
+ prepare-package
+
+
+
+
+
+
+
+
+ org.scoverage
+ scoverage-maven-plugin
+
+
+
+ report-only
+
+
+
+
+
+
+
+
+
+
diff --git a/spark4/hbase-spark4/src/main/java/org/apache/hadoop/hbase/spark/SparkSQLPushDownFilter.java b/spark4/hbase-spark4/src/main/java/org/apache/hadoop/hbase/spark/SparkSQLPushDownFilter.java
new file mode 100644
index 00000000..f965e602
--- /dev/null
+++ b/spark4/hbase-spark4/src/main/java/org/apache/hadoop/hbase/spark/SparkSQLPushDownFilter.java
@@ -0,0 +1,279 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.hadoop.hbase.spark;
+
+import java.io.IOException;
+import java.util.Arrays;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.Objects;
+import org.apache.hadoop.hbase.Cell;
+import org.apache.hadoop.hbase.exceptions.DeserializationException;
+import org.apache.hadoop.hbase.filter.Filter.ReturnCode;
+import org.apache.hadoop.hbase.filter.FilterBase;
+import org.apache.hadoop.hbase.spark.datasources.BytesEncoder;
+import org.apache.hadoop.hbase.spark.datasources.Field;
+import org.apache.hadoop.hbase.spark.datasources.JavaBytesEncoder;
+import org.apache.hadoop.hbase.spark.protobuf.generated.SparkFilterProtos;
+import org.apache.hadoop.hbase.util.Bytes;
+import org.apache.yetus.audience.InterfaceAudience;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+import scala.collection.mutable.ListBuffer;
+
+import org.apache.hbase.thirdparty.com.google.protobuf.ByteString;
+import org.apache.hbase.thirdparty.com.google.protobuf.InvalidProtocolBufferException;
+
+/**
+ * This filter will push down all qualifier logic given to us by SparkSQL so that we can apply the
+ * filters at the region server level and avoid sending the data back to the client to be filtered.
+ */
+@InterfaceAudience.Private
+public class SparkSQLPushDownFilter extends FilterBase {
+ protected static final Logger log = LoggerFactory.getLogger(SparkSQLPushDownFilter.class);
+
+ // The following values are populated with protobuffer
+ DynamicLogicExpression dynamicLogicExpression;
+ byte[][] valueFromQueryArray;
+ HashMap<ByteArrayComparable, HashMap<ByteArrayComparable, String>> currentCellToColumnIndexMap;
+
+ // The following values are transient
+ HashMap<String, ByteArrayComparable> columnToCurrentRowValueMap = null;
+
+ static final byte[] rowKeyFamily = new byte[0];
+ static final byte[] rowKeyQualifier = Bytes.toBytes("key");
+
+ String encoderClassName;
+
+ public SparkSQLPushDownFilter(DynamicLogicExpression dynamicLogicExpression,
+ byte[][] valueFromQueryArray,
+ HashMap<ByteArrayComparable, HashMap<ByteArrayComparable, String>> currentCellToColumnIndexMap,
+ String encoderClassName) {
+ this.dynamicLogicExpression = dynamicLogicExpression;
+ this.valueFromQueryArray = valueFromQueryArray;
+ this.currentCellToColumnIndexMap = currentCellToColumnIndexMap;
+ this.encoderClassName = encoderClassName;
+ }
+
+ public SparkSQLPushDownFilter(DynamicLogicExpression dynamicLogicExpression,
+ byte[][] valueFromQueryArray, ListBuffer<Field> fields, String encoderClassName) {
+ this.dynamicLogicExpression = dynamicLogicExpression;
+ this.valueFromQueryArray = valueFromQueryArray;
+ this.encoderClassName = encoderClassName;
+
+ // generate family qualifier to index mapping
+ this.currentCellToColumnIndexMap = new HashMap<>();
+
+ for (int i = 0; i < fields.size(); i++) {
+ Field field = fields.apply(i);
+
+ byte[] cfBytes = field.cfBytes();
+ ByteArrayComparable familyByteComparable =
+ new ByteArrayComparable(cfBytes, 0, cfBytes.length);
+
+ HashMap<ByteArrayComparable, String> qualifierIndexMap =
+ currentCellToColumnIndexMap.get(familyByteComparable);
+
+ if (qualifierIndexMap == null) {
+ qualifierIndexMap = new HashMap<>();
+ currentCellToColumnIndexMap.put(familyByteComparable, qualifierIndexMap);
+ }
+ byte[] qBytes = field.colBytes();
+ ByteArrayComparable qualifierByteComparable =
+ new ByteArrayComparable(qBytes, 0, qBytes.length);
+
+ qualifierIndexMap.put(qualifierByteComparable, field.colName());
+ }
+ }
+
+ @Override
+ public ReturnCode filterCell(final Cell c) throws IOException {
+
+ // If the map RowValueMap is empty then we need to populate
+ // the row key
+ if (columnToCurrentRowValueMap == null) {
+ columnToCurrentRowValueMap = new HashMap<>();
+ HashMap<ByteArrayComparable, String> qualifierColumnMap = currentCellToColumnIndexMap
+ .get(new ByteArrayComparable(rowKeyFamily, 0, rowKeyFamily.length));
+
+ if (qualifierColumnMap != null) {
+ String rowKeyColumnName = qualifierColumnMap
+ .get(new ByteArrayComparable(rowKeyQualifier, 0, rowKeyQualifier.length));
+ // Make sure that the rowKey is part of the where clause
+ if (rowKeyColumnName != null) {
+ columnToCurrentRowValueMap.put(rowKeyColumnName,
+ new ByteArrayComparable(c.getRowArray(), c.getRowOffset(), c.getRowLength()));
+ }
+ }
+ }
+
+ // Always populate the column value into the RowValueMap
+ ByteArrayComparable currentFamilyByteComparable =
+ new ByteArrayComparable(c.getFamilyArray(), c.getFamilyOffset(), c.getFamilyLength());
+
+ HashMap<ByteArrayComparable, String> qualifierColumnMap =
+ currentCellToColumnIndexMap.get(currentFamilyByteComparable);
+
+ if (qualifierColumnMap != null) {
+
+ String columnName = qualifierColumnMap.get(new ByteArrayComparable(c.getQualifierArray(),
+ c.getQualifierOffset(), c.getQualifierLength()));
+
+ if (columnName != null) {
+ columnToCurrentRowValueMap.put(columnName,
+ new ByteArrayComparable(c.getValueArray(), c.getValueOffset(), c.getValueLength()));
+ }
+ }
+
+ return ReturnCode.INCLUDE;
+ }
+
+ @Override
+ public boolean filterRow() throws IOException {
+
+ try {
+ boolean result =
+ dynamicLogicExpression.execute(columnToCurrentRowValueMap, valueFromQueryArray);
+ columnToCurrentRowValueMap = null;
+ return !result;
+ } catch (Throwable e) {
+ log.error("Error running dynamic logic on row", e);
+ }
+ return false;
+ }
+
+ /**
+ * @param pbBytes A pb serialized instance
+ * @return An instance of SparkSQLPushDownFilter
+ * @throws DeserializationException if the filter cannot be parsed from the given bytes
+ */
+ @SuppressWarnings("unused")
+ public static SparkSQLPushDownFilter parseFrom(final byte[] pbBytes)
+ throws DeserializationException {
+
+ SparkFilterProtos.SQLPredicatePushDownFilter proto;
+ try {
+ proto = SparkFilterProtos.SQLPredicatePushDownFilter.parseFrom(pbBytes);
+ } catch (InvalidProtocolBufferException e) {
+ throw new DeserializationException(e);
+ }
+
+ String encoder = proto.getEncoderClassName();
+ BytesEncoder enc = JavaBytesEncoder.create(encoder);
+
+ // Load DynamicLogicExpression
+ DynamicLogicExpression dynamicLogicExpression =
+ DynamicLogicExpressionBuilder.build(proto.getDynamicLogicExpression(), enc);
+
+ // Load valuesFromQuery
+ final List<ByteString> valueFromQueryArrayList = proto.getValueFromQueryArrayList();
+ byte[][] valueFromQueryArray = new byte[valueFromQueryArrayList.size()][];
+ for (int i = 0; i < valueFromQueryArrayList.size(); i++) {
+ valueFromQueryArray[i] = valueFromQueryArrayList.get(i).toByteArray();
+ }
+
+ // Load mapping from HBase family/qualifier to Spark SQL columnName
+ HashMap<ByteArrayComparable, HashMap<ByteArrayComparable, String>> currentCellToColumnIndexMap =
+ new HashMap<>();
+
+ for (SparkFilterProtos.SQLPredicatePushDownCellToColumnMapping sqlPredicatePushDownCellToColumnMapping : proto
+ .getCellToColumnMappingList()) {
+
+ byte[] familyArray = sqlPredicatePushDownCellToColumnMapping.getColumnFamily().toByteArray();
+ ByteArrayComparable familyByteComparable =
+ new ByteArrayComparable(familyArray, 0, familyArray.length);
+ HashMap<ByteArrayComparable, String> qualifierMap =
+ currentCellToColumnIndexMap.get(familyByteComparable);
+
+ if (qualifierMap == null) {
+ qualifierMap = new HashMap<>();
+ currentCellToColumnIndexMap.put(familyByteComparable, qualifierMap);
+ }
+ byte[] qualifierArray = sqlPredicatePushDownCellToColumnMapping.getQualifier().toByteArray();
+
+ ByteArrayComparable qualifierByteComparable =
+ new ByteArrayComparable(qualifierArray, 0, qualifierArray.length);
+
+ qualifierMap.put(qualifierByteComparable,
+ sqlPredicatePushDownCellToColumnMapping.getColumnName());
+ }
+
+ return new SparkSQLPushDownFilter(dynamicLogicExpression, valueFromQueryArray,
+ currentCellToColumnIndexMap, encoder);
+ }
+
+ /** Returns The filter serialized using pb */
+ public byte[] toByteArray() {
+
+ SparkFilterProtos.SQLPredicatePushDownFilter.Builder builder =
+ SparkFilterProtos.SQLPredicatePushDownFilter.newBuilder();
+
+ SparkFilterProtos.SQLPredicatePushDownCellToColumnMapping.Builder columnMappingBuilder =
+ SparkFilterProtos.SQLPredicatePushDownCellToColumnMapping.newBuilder();
+
+ builder.setDynamicLogicExpression(dynamicLogicExpression.toExpressionString());
+ for (byte[] valueFromQuery : valueFromQueryArray) {
+ builder.addValueFromQueryArray(ByteString.copyFrom(valueFromQuery));
+ }
+
+ for (Map.Entry<ByteArrayComparable, HashMap<ByteArrayComparable, String>> familyEntry : currentCellToColumnIndexMap.entrySet()) {
+ for (Map.Entry<ByteArrayComparable, String> qualifierEntry : familyEntry.getValue()
+ .entrySet()) {
+ columnMappingBuilder.setColumnFamily(ByteString.copyFrom(familyEntry.getKey().bytes()));
+ columnMappingBuilder.setQualifier(ByteString.copyFrom(qualifierEntry.getKey().bytes()));
+ columnMappingBuilder.setColumnName(qualifierEntry.getValue());
+ builder.addCellToColumnMapping(columnMappingBuilder.build());
+ }
+ }
+ builder.setEncoderClassName(encoderClassName);
+
+ return builder.build().toByteArray();
+ }
+
+ @Override
+ public boolean equals(Object obj) {
+ if (!(obj instanceof SparkSQLPushDownFilter)) {
+ return false;
+ }
+ if (this == obj) {
+ return true;
+ }
+ SparkSQLPushDownFilter f = (SparkSQLPushDownFilter) obj;
+ if (this.valueFromQueryArray.length != f.valueFromQueryArray.length) {
+ return false;
+ }
+ int i = 0;
+ for (byte[] val : this.valueFromQueryArray) {
+ if (!Bytes.equals(val, f.valueFromQueryArray[i])) {
+ return false;
+ }
+ i++;
+ }
+ return this.dynamicLogicExpression.equals(f.dynamicLogicExpression)
+ && this.currentCellToColumnIndexMap.equals(f.currentCellToColumnIndexMap)
+ && this.encoderClassName.equals(f.encoderClassName);
+ }
+
+ @Override
+ public int hashCode() {
+ return Objects.hash(this.dynamicLogicExpression, Arrays.hashCode(this.valueFromQueryArray),
+ this.currentCellToColumnIndexMap, this.encoderClassName);
+ }
+}
diff --git a/spark4/hbase-spark4/src/main/java/org/apache/hadoop/hbase/spark/example/hbasecontext/JavaHBaseBulkDeleteExample.java b/spark4/hbase-spark4/src/main/java/org/apache/hadoop/hbase/spark/example/hbasecontext/JavaHBaseBulkDeleteExample.java
new file mode 100644
index 00000000..8be0acc9
--- /dev/null
+++ b/spark4/hbase-spark4/src/main/java/org/apache/hadoop/hbase/spark/example/hbasecontext/JavaHBaseBulkDeleteExample.java
@@ -0,0 +1,82 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.hadoop.hbase.spark.example.hbasecontext;
+
+import java.util.ArrayList;
+import java.util.List;
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.hbase.HBaseConfiguration;
+import org.apache.hadoop.hbase.TableName;
+import org.apache.hadoop.hbase.client.Delete;
+import org.apache.hadoop.hbase.spark.JavaHBaseContext;
+import org.apache.hadoop.hbase.util.Bytes;
+import org.apache.spark.SparkConf;
+import org.apache.spark.api.java.JavaRDD;
+import org.apache.spark.api.java.JavaSparkContext;
+import org.apache.spark.api.java.function.Function;
+import org.apache.yetus.audience.InterfaceAudience;
+
+/**
+ * This is a simple example of deleting records in HBase with the bulkDelete function.
+ */
+@InterfaceAudience.Private
+final public class JavaHBaseBulkDeleteExample {
+
+ private JavaHBaseBulkDeleteExample() {
+ }
+
+ public static void main(String[] args) {
+ if (args.length < 1) {
+ System.out.println("JavaHBaseBulkDeleteExample {tableName}");
+ return;
+ }
+
+ String tableName = args[0];
+
+ SparkConf sparkConf = new SparkConf().setAppName("JavaHBaseBulkDeleteExample " + tableName);
+ JavaSparkContext jsc = new JavaSparkContext(sparkConf);
+
+ try {
+ List<byte[]> list = new ArrayList<>(5);
+ list.add(Bytes.toBytes("1"));
+ list.add(Bytes.toBytes("2"));
+ list.add(Bytes.toBytes("3"));
+ list.add(Bytes.toBytes("4"));
+ list.add(Bytes.toBytes("5"));
+
+ JavaRDD<byte[]> rdd = jsc.parallelize(list);
+
+ Configuration conf = HBaseConfiguration.create();
+
+ JavaHBaseContext hbaseContext = new JavaHBaseContext(jsc, conf);
+
+ hbaseContext.bulkDelete(rdd, TableName.valueOf(tableName), new DeleteFunction(), 4);
+ } finally {
+ jsc.stop();
+ }
+
+ }
+
+ public static class DeleteFunction implements Function<byte[], Delete> {
+ private static final long serialVersionUID = 1L;
+
+ public Delete call(byte[] v) throws Exception {
+ return new Delete(v);
+ }
+ }
+}
diff --git a/spark4/hbase-spark4/src/main/java/org/apache/hadoop/hbase/spark/example/hbasecontext/JavaHBaseBulkGetExample.java b/spark4/hbase-spark4/src/main/java/org/apache/hadoop/hbase/spark/example/hbasecontext/JavaHBaseBulkGetExample.java
new file mode 100644
index 00000000..8ff21ea4
--- /dev/null
+++ b/spark4/hbase-spark4/src/main/java/org/apache/hadoop/hbase/spark/example/hbasecontext/JavaHBaseBulkGetExample.java
@@ -0,0 +1,111 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.hadoop.hbase.spark.example.hbasecontext;
+
+import java.util.ArrayList;
+import java.util.Iterator;
+import java.util.List;
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.hbase.Cell;
+import org.apache.hadoop.hbase.HBaseConfiguration;
+import org.apache.hadoop.hbase.TableName;
+import org.apache.hadoop.hbase.client.Get;
+import org.apache.hadoop.hbase.client.Result;
+import org.apache.hadoop.hbase.spark.JavaHBaseContext;
+import org.apache.hadoop.hbase.util.Bytes;
+import org.apache.spark.SparkConf;
+import org.apache.spark.api.java.JavaRDD;
+import org.apache.spark.api.java.JavaSparkContext;
+import org.apache.spark.api.java.function.Function;
+import org.apache.yetus.audience.InterfaceAudience;
+
+/**
+ * This is a simple example of getting records in HBase with the bulkGet function.
+ */
+@InterfaceAudience.Private
+final public class JavaHBaseBulkGetExample {
+
+ private JavaHBaseBulkGetExample() {
+ }
+
+ public static void main(String[] args) {
+ if (args.length < 1) {
+ System.out.println("JavaHBaseBulkGetExample {tableName}");
+ return;
+ }
+
+ String tableName = args[0];
+
+ SparkConf sparkConf = new SparkConf().setAppName("JavaHBaseBulkGetExample " + tableName);
+ JavaSparkContext jsc = new JavaSparkContext(sparkConf);
+
+ try {
+ List<byte[]> list = new ArrayList<>(5);
+ list.add(Bytes.toBytes("1"));
+ list.add(Bytes.toBytes("2"));
+ list.add(Bytes.toBytes("3"));
+ list.add(Bytes.toBytes("4"));
+ list.add(Bytes.toBytes("5"));
+
+ JavaRDD<byte[]> rdd = jsc.parallelize(list);
+
+ Configuration conf = HBaseConfiguration.create();
+
+ JavaHBaseContext hbaseContext = new JavaHBaseContext(jsc, conf);
+
+ hbaseContext.bulkGet(TableName.valueOf(tableName), 2, rdd, new GetFunction(),
+ new ResultFunction());
+ } finally {
+ jsc.stop();
+ }
+ }
+
+ public static class GetFunction implements Function<byte[], Get> {
+
+ private static final long serialVersionUID = 1L;
+
+ public Get call(byte[] v) throws Exception {
+ return new Get(v);
+ }
+ }
+
+ public static class ResultFunction implements Function<Result, String> {
+
+ private static final long serialVersionUID = 1L;
+
+ public String call(Result result) throws Exception {
+ Iterator<Cell> it = result.listCells().iterator();
+ StringBuilder b = new StringBuilder();
+
+ b.append(Bytes.toString(result.getRow())).append(":");
+
+ while (it.hasNext()) {
+ Cell cell = it.next();
+ String q = Bytes.toString(cell.getQualifierArray());
+ if (q.equals("counter")) {
+ b.append("(").append(Bytes.toString(cell.getQualifierArray())).append(",")
+ .append(Bytes.toLong(cell.getValueArray())).append(")");
+ } else {
+ b.append("(").append(Bytes.toString(cell.getQualifierArray())).append(",")
+ .append(Bytes.toString(cell.getValueArray())).append(")");
+ }
+ }
+ return b.toString();
+ }
+ }
+}
diff --git a/spark4/hbase-spark4/src/main/java/org/apache/hadoop/hbase/spark/example/hbasecontext/JavaHBaseBulkLoadExample.java b/spark4/hbase-spark4/src/main/java/org/apache/hadoop/hbase/spark/example/hbasecontext/JavaHBaseBulkLoadExample.java
new file mode 100644
index 00000000..44f1b339
--- /dev/null
+++ b/spark4/hbase-spark4/src/main/java/org/apache/hadoop/hbase/spark/example/hbasecontext/JavaHBaseBulkLoadExample.java
@@ -0,0 +1,106 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.hadoop.hbase.spark.example.hbasecontext;
+
+import java.util.ArrayList;
+import java.util.HashMap;
+import java.util.List;
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.hbase.HBaseConfiguration;
+import org.apache.hadoop.hbase.HConstants;
+import org.apache.hadoop.hbase.TableName;
+import org.apache.hadoop.hbase.spark.FamilyHFileWriteOptions;
+import org.apache.hadoop.hbase.spark.JavaHBaseContext;
+import org.apache.hadoop.hbase.spark.KeyFamilyQualifier;
+import org.apache.hadoop.hbase.util.Bytes;
+import org.apache.hadoop.hbase.util.Pair;
+import org.apache.spark.SparkConf;
+import org.apache.spark.api.java.JavaRDD;
+import org.apache.spark.api.java.JavaSparkContext;
+import org.apache.spark.api.java.function.Function;
+import org.apache.yetus.audience.InterfaceAudience;
+
+/**
+ * Run this example using command below: SPARK_HOME/bin/spark-submit --master local[2] --class
+ * org.apache.hadoop.hbase.spark.example.hbasecontext.JavaHBaseBulkLoadExample
+ * path/to/hbase-spark.jar {path/to/output/HFiles} This example will write HFiles to
+ * {path/to/output/HFiles}, and user can run 'hbase
+ * org.apache.hadoop.hbase.tool.LoadIncrementalHFiles' to load the HFiles into table to verify this
+ * example.
+ */
+@InterfaceAudience.Private
+final public class JavaHBaseBulkLoadExample {
+ private JavaHBaseBulkLoadExample() {
+ }
+
+ public static void main(String[] args) {
+ if (args.length < 1) {
+ System.out.println("JavaHBaseBulkLoadExample " + "{outputPath}");
+ return;
+ }
+
+ String tableName = "bulkload-table-test";
+ String columnFamily1 = "f1";
+ String columnFamily2 = "f2";
+
+ SparkConf sparkConf = new SparkConf().setAppName("JavaHBaseBulkLoadExample " + tableName);
+ JavaSparkContext jsc = new JavaSparkContext(sparkConf);
+
+ try {
+ List<String> list = new ArrayList<>();
+ // row1
+ list.add("1," + columnFamily1 + ",b,1");
+ // row3
+ list.add("3," + columnFamily1 + ",a,2");
+ list.add("3," + columnFamily1 + ",b,1");
+ list.add("3," + columnFamily2 + ",a,1");
+ /* row2 */
+ list.add("2," + columnFamily2 + ",a,3");
+ list.add("2," + columnFamily2 + ",b,3");
+
+ JavaRDD<String> rdd = jsc.parallelize(list);
+
+ Configuration conf = HBaseConfiguration.create();
+ JavaHBaseContext hbaseContext = new JavaHBaseContext(jsc, conf);
+
+ hbaseContext.bulkLoad(rdd, TableName.valueOf(tableName), new BulkLoadFunction(), args[0],
+ new HashMap<byte[], FamilyHFileWriteOptions>(), false, HConstants.DEFAULT_MAX_FILE_SIZE);
+ } finally {
+ jsc.stop();
+ }
+ }
+
+ public static class BulkLoadFunction
+ implements Function<String, Pair<KeyFamilyQualifier, byte[]>> {
+ @Override
+ public Pair<KeyFamilyQualifier, byte[]> call(String v1) throws Exception {
+ if (v1 == null) {
+ return null;
+ }
+
+ String[] strs = v1.split(",");
+ if (strs.length != 4) {
+ return null;
+ }
+
+ KeyFamilyQualifier kfq = new KeyFamilyQualifier(Bytes.toBytes(strs[0]),
+ Bytes.toBytes(strs[1]), Bytes.toBytes(strs[2]));
+ return new Pair<>(kfq, Bytes.toBytes(strs[3]));
+ }
+ }
+}
diff --git a/spark4/hbase-spark4/src/main/java/org/apache/hadoop/hbase/spark/example/hbasecontext/JavaHBaseBulkPutExample.java b/spark4/hbase-spark4/src/main/java/org/apache/hadoop/hbase/spark/example/hbasecontext/JavaHBaseBulkPutExample.java
new file mode 100644
index 00000000..5f685a14
--- /dev/null
+++ b/spark4/hbase-spark4/src/main/java/org/apache/hadoop/hbase/spark/example/hbasecontext/JavaHBaseBulkPutExample.java
@@ -0,0 +1,88 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.hadoop.hbase.spark.example.hbasecontext;
+
+import java.util.ArrayList;
+import java.util.List;
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.hbase.HBaseConfiguration;
+import org.apache.hadoop.hbase.TableName;
+import org.apache.hadoop.hbase.client.Put;
+import org.apache.hadoop.hbase.spark.JavaHBaseContext;
+import org.apache.hadoop.hbase.util.Bytes;
+import org.apache.spark.SparkConf;
+import org.apache.spark.api.java.JavaRDD;
+import org.apache.spark.api.java.JavaSparkContext;
+import org.apache.spark.api.java.function.Function;
+import org.apache.yetus.audience.InterfaceAudience;
+
+/**
+ * This is a simple example of putting records in HBase with the bulkPut function.
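+ * Run it with spark-submit like the other examples, e.g. (jar path, table name and
+ * column family are placeholders): SPARK_HOME/bin/spark-submit --master local[2] --class
+ * org.apache.hadoop.hbase.spark.example.hbasecontext.JavaHBaseBulkPutExample
+ * path/to/hbase-spark.jar {tableName} {columnFamily}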
+ */
+@InterfaceAudience.Private
+final public class JavaHBaseBulkPutExample {
+
+ private JavaHBaseBulkPutExample() {
+ }
+
+ public static void main(String[] args) {
+ if (args.length < 2) {
+ System.out.println("JavaHBaseBulkPutExample " + "{tableName} {columnFamily}");
+ return;
+ }
+
+ String tableName = args[0];
+ String columnFamily = args[1];
+
+ SparkConf sparkConf = new SparkConf().setAppName("JavaHBaseBulkPutExample " + tableName);
+ JavaSparkContext jsc = new JavaSparkContext(sparkConf);
+
+ try {
+ List<String> list = new ArrayList<>(5);
+ list.add("1," + columnFamily + ",a,1");
+ list.add("2," + columnFamily + ",a,2");
+ list.add("3," + columnFamily + ",a,3");
+ list.add("4," + columnFamily + ",a,4");
+ list.add("5," + columnFamily + ",a,5");
+
+ JavaRDD<String> rdd = jsc.parallelize(list);
+
+ Configuration conf = HBaseConfiguration.create();
+
+ JavaHBaseContext hbaseContext = new JavaHBaseContext(jsc, conf);
+
+ hbaseContext.bulkPut(rdd, TableName.valueOf(tableName), new PutFunction());
+ } finally {
+ jsc.stop();
+ }
+ }
+
+ public static class PutFunction implements Function<String, Put> {
+
+ private static final long serialVersionUID = 1L;
+
+ public Put call(String v) throws Exception {
+ String[] cells = v.split(",");
+ Put put = new Put(Bytes.toBytes(cells[0]));
+
+ put.addColumn(Bytes.toBytes(cells[1]), Bytes.toBytes(cells[2]), Bytes.toBytes(cells[3]));
+ return put;
+ }
+
+ }
+}
diff --git a/spark4/hbase-spark4/src/main/java/org/apache/hadoop/hbase/spark/example/hbasecontext/JavaHBaseDistributedScan.java b/spark4/hbase-spark4/src/main/java/org/apache/hadoop/hbase/spark/example/hbasecontext/JavaHBaseDistributedScan.java
new file mode 100644
index 00000000..76c7f6d8
--- /dev/null
+++ b/spark4/hbase-spark4/src/main/java/org/apache/hadoop/hbase/spark/example/hbasecontext/JavaHBaseDistributedScan.java
@@ -0,0 +1,82 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.hadoop.hbase.spark.example.hbasecontext;
+
+import java.util.List;
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.hbase.HBaseConfiguration;
+import org.apache.hadoop.hbase.TableName;
+import org.apache.hadoop.hbase.client.Result;
+import org.apache.hadoop.hbase.client.Scan;
+import org.apache.hadoop.hbase.io.ImmutableBytesWritable;
+import org.apache.hadoop.hbase.spark.JavaHBaseContext;
+import org.apache.hadoop.hbase.util.Bytes;
+import org.apache.spark.SparkConf;
+import org.apache.spark.api.java.JavaRDD;
+import org.apache.spark.api.java.JavaSparkContext;
+import org.apache.spark.api.java.function.Function;
+import org.apache.yetus.audience.InterfaceAudience;
+import scala.Tuple2;
+
+/**
+ * This is a simple example of scanning records from HBase with the hbaseRDD function.
+ */
+@InterfaceAudience.Private
+final public class JavaHBaseDistributedScan {
+
+ private JavaHBaseDistributedScan() {
+ }
+
+ public static void main(String[] args) {
+ if (args.length < 1) {
+ System.out.println("JavaHBaseDistributedScan {tableName}");
+ return;
+ }
+
+ String tableName = args[0];
+
+ SparkConf sparkConf = new SparkConf().setAppName("JavaHBaseDistributedScan " + tableName);
+ JavaSparkContext jsc = new JavaSparkContext(sparkConf);
+
+ try {
+ Configuration conf = HBaseConfiguration.create();
+
+ JavaHBaseContext hbaseContext = new JavaHBaseContext(jsc, conf);
+
+ Scan scan = new Scan();
+ scan.setCaching(100);
+
+ JavaRDD<Tuple2<ImmutableBytesWritable, Result>> javaRdd =
+ hbaseContext.hbaseRDD(TableName.valueOf(tableName), scan);
+
+ List<String> results = javaRdd.map(new ScanConvertFunction()).collect();
+
+ System.out.println("Result Size: " + results.size());
+ } finally {
+ jsc.stop();
+ }
+ }
+
+ private static class ScanConvertFunction
+ implements Function<Tuple2<ImmutableBytesWritable, Result>, String> {
+ @Override
+ public String call(Tuple2<ImmutableBytesWritable, Result> v1) throws Exception {
+ return Bytes.toString(v1._1().copyBytes());
+ }
+ }
+}
diff --git a/spark4/hbase-spark4/src/main/java/org/apache/hadoop/hbase/spark/example/hbasecontext/JavaHBaseMapGetPutExample.java b/spark4/hbase-spark4/src/main/java/org/apache/hadoop/hbase/spark/example/hbasecontext/JavaHBaseMapGetPutExample.java
new file mode 100644
index 00000000..c516ab35
--- /dev/null
+++ b/spark4/hbase-spark4/src/main/java/org/apache/hadoop/hbase/spark/example/hbasecontext/JavaHBaseMapGetPutExample.java
@@ -0,0 +1,105 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.hadoop.hbase.spark.example.hbasecontext;
+
+import java.util.ArrayList;
+import java.util.Iterator;
+import java.util.List;
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.hbase.HBaseConfiguration;
+import org.apache.hadoop.hbase.TableName;
+import org.apache.hadoop.hbase.client.BufferedMutator;
+import org.apache.hadoop.hbase.client.Connection;
+import org.apache.hadoop.hbase.client.Get;
+import org.apache.hadoop.hbase.client.Put;
+import org.apache.hadoop.hbase.client.Result;
+import org.apache.hadoop.hbase.client.Table;
+import org.apache.hadoop.hbase.spark.JavaHBaseContext;
+import org.apache.hadoop.hbase.util.Bytes;
+import org.apache.spark.SparkConf;
+import org.apache.spark.api.java.JavaRDD;
+import org.apache.spark.api.java.JavaSparkContext;
+import org.apache.spark.api.java.function.Function;
+import org.apache.spark.api.java.function.VoidFunction;
+import org.apache.yetus.audience.InterfaceAudience;
+import scala.Tuple2;
+
+/**
+ * This is a simple example of using the foreachPartition method with a HBase connection
+ */
+@InterfaceAudience.Private
+final public class JavaHBaseMapGetPutExample {
+
+ private JavaHBaseMapGetPutExample() {
+ }
+
+ public static void main(String[] args) {
+ if (args.length < 1) {
+ System.out.println("JavaHBaseBulkGetExample {tableName}");
+ return;
+ }
+
+ final String tableName = args[0];
+
+ SparkConf sparkConf = new SparkConf().setAppName("JavaHBaseMapGetPutExample " + tableName);
+ JavaSparkContext jsc = new JavaSparkContext(sparkConf);
+
+ try {
+ List<byte[]> list = new ArrayList<>(5);
+ list.add(Bytes.toBytes("1"));
+ list.add(Bytes.toBytes("2"));
+ list.add(Bytes.toBytes("3"));
+ list.add(Bytes.toBytes("4"));
+ list.add(Bytes.toBytes("5"));
+
+ JavaRDD<byte[]> rdd = jsc.parallelize(list);
+ Configuration conf = HBaseConfiguration.create();
+
+ JavaHBaseContext hbaseContext = new JavaHBaseContext(jsc, conf);
+
+ hbaseContext.foreachPartition(rdd, new VoidFunction<Tuple2<Iterator<byte[]>, Connection>>() {
+ public void call(Tuple2<Iterator<byte[]>, Connection> t) throws Exception {
+ Table table = t._2().getTable(TableName.valueOf(tableName));
+ BufferedMutator mutator = t._2().getBufferedMutator(TableName.valueOf(tableName));
+
+ while (t._1().hasNext()) {
+ byte[] b = t._1().next();
+ Result r = table.get(new Get(b));
+ if (r.getExists()) {
+ mutator.mutate(new Put(b));
+ }
+ }
+
+ mutator.flush();
+ mutator.close();
+ table.close();
+ }
+ });
+ } finally {
+ jsc.stop();
+ }
+ }
+
+ public static class GetFunction implements Function<byte[], Get> {
+ private static final long serialVersionUID = 1L;
+
+ public Get call(byte[] v) throws Exception {
+ return new Get(v);
+ }
+ }
+}
diff --git a/spark4/hbase-spark4/src/main/java/org/apache/hadoop/hbase/spark/example/hbasecontext/JavaHBaseStreamingBulkPutExample.java b/spark4/hbase-spark4/src/main/java/org/apache/hadoop/hbase/spark/example/hbasecontext/JavaHBaseStreamingBulkPutExample.java
new file mode 100644
index 00000000..dc07a562
--- /dev/null
+++ b/spark4/hbase-spark4/src/main/java/org/apache/hadoop/hbase/spark/example/hbasecontext/JavaHBaseStreamingBulkPutExample.java
@@ -0,0 +1,87 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.hadoop.hbase.spark.example.hbasecontext;
+
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.hbase.HBaseConfiguration;
+import org.apache.hadoop.hbase.TableName;
+import org.apache.hadoop.hbase.client.Put;
+import org.apache.hadoop.hbase.spark.JavaHBaseContext;
+import org.apache.hadoop.hbase.util.Bytes;
+import org.apache.spark.SparkConf;
+import org.apache.spark.api.java.JavaSparkContext;
+import org.apache.spark.api.java.function.Function;
+import org.apache.spark.streaming.Duration;
+import org.apache.spark.streaming.api.java.JavaReceiverInputDStream;
+import org.apache.spark.streaming.api.java.JavaStreamingContext;
+import org.apache.yetus.audience.InterfaceAudience;
+
+/**
+ * This is a simple example of BulkPut with Spark Streaming
+ */
+@InterfaceAudience.Private
+final public class JavaHBaseStreamingBulkPutExample {
+
+ private JavaHBaseStreamingBulkPutExample() {
+ }
+
+ public static void main(String[] args) {
+ if (args.length < 3) {
+ System.out.println("JavaHBaseStreamingBulkPutExample " + "{host} {port} {tableName}");
+ return;
+ }
+
+ String host = args[0];
+ String port = args[1];
+ String tableName = args[2];
+
+ SparkConf sparkConf = new SparkConf()
+ .setAppName("JavaHBaseStreamingBulkPutExample " + tableName + ":" + port + ":" + tableName);
+
+ JavaSparkContext jsc = new JavaSparkContext(sparkConf);
+
+ try {
+ JavaStreamingContext jssc = new JavaStreamingContext(jsc, new Duration(1000));
+
+ JavaReceiverInputDStream<String> javaDstream =
+ jssc.socketTextStream(host, Integer.parseInt(port));
+
+ Configuration conf = HBaseConfiguration.create();
+
+ JavaHBaseContext hbaseContext = new JavaHBaseContext(jsc, conf);
+
+ hbaseContext.streamBulkPut(javaDstream, TableName.valueOf(tableName), new PutFunction());
+ } finally {
+ jsc.stop();
+ }
+ }
+
+ public static class PutFunction implements Function<String, Put> {
+
+ private static final long serialVersionUID = 1L;
+
+ public Put call(String v) throws Exception {
+ String[] part = v.split(",");
+ Put put = new Put(Bytes.toBytes(part[0]));
+
+ put.addColumn(Bytes.toBytes(part[1]), Bytes.toBytes(part[2]), Bytes.toBytes(part[3]));
+ return put;
+ }
+
+ }
+}
diff --git a/spark4/hbase-spark4/src/main/scala/org/apache/hadoop/hbase/spark/BulkLoadPartitioner.scala b/spark4/hbase-spark4/src/main/scala/org/apache/hadoop/hbase/spark/BulkLoadPartitioner.scala
new file mode 100644
index 00000000..17316638
--- /dev/null
+++ b/spark4/hbase-spark4/src/main/scala/org/apache/hadoop/hbase/spark/BulkLoadPartitioner.scala
@@ -0,0 +1,61 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.hadoop.hbase.spark
+
+import java.util
+import java.util.Comparator
+import org.apache.hadoop.hbase.util.Bytes
+import org.apache.spark.Partitioner
+import org.apache.yetus.audience.InterfaceAudience
+
+/**
+ * A Partitioner implementation that will separate records to different
+ * HBase Regions based on region splits
+ *
+ * @param startKeys The start keys for the given table
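+ *
+ * A minimal usage sketch (names are illustrative): this partitioner is typically passed
+ * to repartitionAndSortWithinPartitions on an RDD keyed by KeyFamilyQualifier, which is
+ * essentially how HBaseContext.bulkLoad uses it.
+ * {{{
+ * keyedRdd.repartitionAndSortWithinPartitions(new BulkLoadPartitioner(startKeys))
+ * }}}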
+ */
+@InterfaceAudience.Public
+class BulkLoadPartitioner(startKeys: Array[Array[Byte]]) extends Partitioner {
+ // when the table does not exist, startKeys is an empty array
+ override def numPartitions: Int = if (startKeys.length == 0) 1 else startKeys.length
+
+ override def getPartition(key: Any): Int = {
+
+ val comparator: Comparator[Array[Byte]] = new Comparator[Array[Byte]] {
+ override def compare(o1: Array[Byte], o2: Array[Byte]): Int = {
+ Bytes.compareTo(o1, o2)
+ }
+ }
+
+ val rowKey: Array[Byte] =
+ key match {
+ case qualifier: KeyFamilyQualifier =>
+ qualifier.rowKey
+ case wrapper: ByteArrayWrapper =>
+ wrapper.value
+ case _ =>
+ key.asInstanceOf[Array[Byte]]
+ }
+ var partition = util.Arrays.binarySearch(startKeys, rowKey, comparator)
+ if (partition < 0)
+ partition = partition * -1 + -2
+ if (partition < 0)
+ partition = 0
+ partition
+ }
+}
diff --git a/spark4/hbase-spark4/src/main/scala/org/apache/hadoop/hbase/spark/ByteArrayComparable.scala b/spark4/hbase-spark4/src/main/scala/org/apache/hadoop/hbase/spark/ByteArrayComparable.scala
new file mode 100644
index 00000000..78cd3abc
--- /dev/null
+++ b/spark4/hbase-spark4/src/main/scala/org/apache/hadoop/hbase/spark/ByteArrayComparable.scala
@@ -0,0 +1,47 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.hadoop.hbase.spark
+
+import org.apache.hadoop.hbase.util.Bytes
+import org.apache.yetus.audience.InterfaceAudience
+
+@InterfaceAudience.Public
+class ByteArrayComparable(val bytes: Array[Byte], val offset: Int = 0, var length: Int = -1)
+ extends Comparable[ByteArrayComparable] {
+
+ if (length == -1) {
+ length = bytes.length
+ }
+
+ override def compareTo(o: ByteArrayComparable): Int = {
+ Bytes.compareTo(bytes, offset, length, o.bytes, o.offset, o.length)
+ }
+
+ override def hashCode(): Int = {
+ Bytes.hashCode(bytes, offset, length)
+ }
+
+ override def equals(obj: Any): Boolean = {
+ obj match {
+ case b: ByteArrayComparable =>
+ Bytes.equals(bytes, offset, length, b.bytes, b.offset, b.length)
+ case _ =>
+ false
+ }
+ }
+}
diff --git a/spark4/hbase-spark4/src/main/scala/org/apache/hadoop/hbase/spark/ByteArrayWrapper.scala b/spark4/hbase-spark4/src/main/scala/org/apache/hadoop/hbase/spark/ByteArrayWrapper.scala
new file mode 100644
index 00000000..a774838e
--- /dev/null
+++ b/spark4/hbase-spark4/src/main/scala/org/apache/hadoop/hbase/spark/ByteArrayWrapper.scala
@@ -0,0 +1,48 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.hadoop.hbase.spark
+
+import java.io.Serializable
+import org.apache.hadoop.hbase.util.Bytes
+import org.apache.yetus.audience.InterfaceAudience
+
+/**
+ * This is a wrapper over a byte array so it can work as
+ * a key in a hashMap
+ *
+ * @param value The Byte Array value
+ */
+@InterfaceAudience.Public
+class ByteArrayWrapper(var value: Array[Byte])
+ extends Comparable[ByteArrayWrapper]
+ with Serializable {
+ override def compareTo(valueOther: ByteArrayWrapper): Int = {
+ Bytes.compareTo(value, valueOther.value)
+ }
+ override def equals(o2: Any): Boolean = {
+ o2 match {
+ case wrapper: ByteArrayWrapper =>
+ Bytes.equals(value, wrapper.value)
+ case _ =>
+ false
+ }
+ }
+ override def hashCode(): Int = {
+ Bytes.hashCode(value)
+ }
+}
diff --git a/spark4/hbase-spark4/src/main/scala/org/apache/hadoop/hbase/spark/ColumnFamilyQualifierMapKeyWrapper.scala b/spark4/hbase-spark4/src/main/scala/org/apache/hadoop/hbase/spark/ColumnFamilyQualifierMapKeyWrapper.scala
new file mode 100644
index 00000000..84883992
--- /dev/null
+++ b/spark4/hbase-spark4/src/main/scala/org/apache/hadoop/hbase/spark/ColumnFamilyQualifierMapKeyWrapper.scala
@@ -0,0 +1,78 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.hadoop.hbase.spark
+
+import org.apache.hadoop.hbase.util.Bytes
+import org.apache.yetus.audience.InterfaceAudience
+
+/**
+ * A wrapper class that will allow both columnFamily and qualifier to
+ * be the key of a hashMap. It also allows finding the value in a hashMap
+ * without cloning the HBase value from the HBase Cell object
+ * @param columnFamily ColumnFamily byte array
+ * @param columnFamilyOffSet Offset of columnFamily value in the array
+ * @param columnFamilyLength Length of the columnFamily value in the columnFamily array
+ * @param qualifier Qualifier byte array
+ * @param qualifierOffSet Offset of qualifier value in the array
+ * @param qualifierLength Length of the qualifier value within the array
+ */
+@InterfaceAudience.Public
+class ColumnFamilyQualifierMapKeyWrapper(
+ val columnFamily: Array[Byte],
+ val columnFamilyOffSet: Int,
+ val columnFamilyLength: Int,
+ val qualifier: Array[Byte],
+ val qualifierOffSet: Int,
+ val qualifierLength: Int)
+ extends Serializable {
+
+ override def equals(other: Any): Boolean = {
+ val otherWrapper = other.asInstanceOf[ColumnFamilyQualifierMapKeyWrapper]
+
+ Bytes.compareTo(
+ columnFamily,
+ columnFamilyOffSet,
+ columnFamilyLength,
+ otherWrapper.columnFamily,
+ otherWrapper.columnFamilyOffSet,
+ otherWrapper.columnFamilyLength) == 0 && Bytes.compareTo(
+ qualifier,
+ qualifierOffSet,
+ qualifierLength,
+ otherWrapper.qualifier,
+ otherWrapper.qualifierOffSet,
+ otherWrapper.qualifierLength) == 0
+ }
+
+ override def hashCode(): Int = {
+ Bytes.hashCode(columnFamily, columnFamilyOffSet, columnFamilyLength) +
+ Bytes.hashCode(qualifier, qualifierOffSet, qualifierLength)
+ }
+
+ def cloneColumnFamily(): Array[Byte] = {
+ val resultArray = new Array[Byte](columnFamilyLength)
+ System.arraycopy(columnFamily, columnFamilyOffSet, resultArray, 0, columnFamilyLength)
+ resultArray
+ }
+
+ def cloneQualifier(): Array[Byte] = {
+ val resultArray = new Array[Byte](qualifierLength)
+ System.arraycopy(qualifier, qualifierOffSet, resultArray, 0, qualifierLength)
+ resultArray
+ }
+}
diff --git a/spark4/hbase-spark4/src/main/scala/org/apache/hadoop/hbase/spark/DefaultSource.scala b/spark4/hbase-spark4/src/main/scala/org/apache/hadoop/hbase/spark/DefaultSource.scala
new file mode 100644
index 00000000..06dea823
--- /dev/null
+++ b/spark4/hbase-spark4/src/main/scala/org/apache/hadoop/hbase/spark/DefaultSource.scala
@@ -0,0 +1,1232 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.hadoop.hbase.spark
+
+import java.util
+import java.util.concurrent.ConcurrentLinkedQueue
+import org.apache.hadoop.hbase.CellUtil
+import org.apache.hadoop.hbase.HBaseConfiguration
+import org.apache.hadoop.hbase.HColumnDescriptor
+import org.apache.hadoop.hbase.HTableDescriptor
+import org.apache.hadoop.hbase.TableName
+import org.apache.hadoop.hbase.client._
+import org.apache.hadoop.hbase.io.ImmutableBytesWritable
+import org.apache.hadoop.hbase.mapred.TableOutputFormat
+import org.apache.hadoop.hbase.spark.datasources._
+import org.apache.hadoop.hbase.types._
+import org.apache.hadoop.hbase.util.{Bytes, PositionedByteRange, SimplePositionedMutableByteRange}
+import org.apache.hadoop.mapred.JobConf
+import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.{DataFrame, Row, SaveMode, SQLContext}
+import org.apache.spark.sql.sources._
+import org.apache.spark.sql.types._
+import org.apache.yetus.audience.InterfaceAudience
+import scala.collection.mutable
+
+/**
+ * DefaultSource for integration with Spark's dataframe datasources.
+ * This class will produce a relationProvider based on input given to it from spark
+ *
+ * This class needs to stay in the current package 'org.apache.hadoop.hbase.spark'
+ * for Spark to match the hbase data source name.
+ *
+ * In all, this DefaultSource supports the following datasource functionality:
+ * - Scan range pruning through filter push down logic based on rowKeys
+ * - Filter push down logic on HBase Cells
+ * - Qualifier filtering based on columns used in the SparkSQL statement
+ * - Type conversions of basic SQL types. All conversions will be
+ * through the HBase Bytes object commands.
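+ *
+ * A minimal read sketch (assuming the table catalog JSON is held in `cat`; the variable
+ * names are illustrative only):
+ * {{{
+ * val df = sqlContext.read
+ *   .options(Map(HBaseTableCatalog.tableCatalog -> cat))
+ *   .format("org.apache.hadoop.hbase.spark")
+ *   .load()
+ * df.filter(df("col0") === "row005").show() // eligible filters are pushed down
+ * }}}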
+ */
+@InterfaceAudience.Private
+class DefaultSource extends RelationProvider with CreatableRelationProvider with Logging {
+
+ /**
+ * Is given input from SparkSQL to construct a BaseRelation
+ *
+ * @param sqlContext SparkSQL context
+ * @param parameters Parameters given to us from SparkSQL
+ * @return A BaseRelation Object
+ */
+ override def createRelation(
+ sqlContext: SQLContext,
+ parameters: Map[String, String]): BaseRelation = {
+ new HBaseRelation(parameters, None)(sqlContext)
+ }
+
+ override def createRelation(
+ sqlContext: SQLContext,
+ mode: SaveMode,
+ parameters: Map[String, String],
+ data: DataFrame): BaseRelation = {
+ val relation = HBaseRelation(parameters, Some(data.schema))(sqlContext)
+ relation.createTable()
+ relation.insert(data, false)
+ relation
+ }
+}
+
+/**
+ * Implementation of Spark BaseRelation that will build up our scan logic,
+ * do the scan pruning, filter push down, and value conversions
+ *
+ * @param sqlContext SparkSQL context
+ */
+@InterfaceAudience.Private
+case class HBaseRelation(
+ @transient parameters: Map[String, String],
+ userSpecifiedSchema: Option[StructType])(@transient val sqlContext: SQLContext)
+ extends BaseRelation
+ with PrunedFilteredScan
+ with InsertableRelation
+ with Logging {
+ val timestamp = parameters.get(HBaseSparkConf.TIMESTAMP).map(_.toLong)
+ val minTimestamp = parameters.get(HBaseSparkConf.TIMERANGE_START).map(_.toLong)
+ val maxTimestamp = parameters.get(HBaseSparkConf.TIMERANGE_END).map(_.toLong)
+ val maxVersions = parameters.get(HBaseSparkConf.MAX_VERSIONS).map(_.toInt)
+ val encoderClsName =
+ parameters.get(HBaseSparkConf.QUERY_ENCODER).getOrElse(HBaseSparkConf.DEFAULT_QUERY_ENCODER)
+
+ @transient val encoder = JavaBytesEncoder.create(encoderClsName)
+
+ val catalog = HBaseTableCatalog(parameters)
+ def tableName = s"${catalog.namespace}:${catalog.name}"
+ val configResources = parameters.get(HBaseSparkConf.HBASE_CONFIG_LOCATION)
+ val useHBaseContext = parameters
+ .get(HBaseSparkConf.USE_HBASECONTEXT)
+ .map(_.toBoolean)
+ .getOrElse(HBaseSparkConf.DEFAULT_USE_HBASECONTEXT)
+ val usePushDownColumnFilter = parameters
+ .get(HBaseSparkConf.PUSHDOWN_COLUMN_FILTER)
+ .map(_.toBoolean)
+ .getOrElse(HBaseSparkConf.DEFAULT_PUSHDOWN_COLUMN_FILTER)
+
+ // The user supplied per table parameter will overwrite global ones in SparkConf
+ val blockCacheEnable = parameters
+ .get(HBaseSparkConf.QUERY_CACHEBLOCKS)
+ .map(_.toBoolean)
+ .getOrElse(sqlContext.sparkContext.getConf
+ .getBoolean(HBaseSparkConf.QUERY_CACHEBLOCKS, HBaseSparkConf.DEFAULT_QUERY_CACHEBLOCKS))
+ val cacheSize = parameters
+ .get(HBaseSparkConf.QUERY_CACHEDROWS)
+ .map(_.toInt)
+ .getOrElse(sqlContext.sparkContext.getConf.getInt(HBaseSparkConf.QUERY_CACHEDROWS, -1))
+ val batchNum = parameters
+ .get(HBaseSparkConf.QUERY_BATCHSIZE)
+ .map(_.toInt)
+ .getOrElse(sqlContext.sparkContext.getConf.getInt(HBaseSparkConf.QUERY_BATCHSIZE, -1))
+
+ val bulkGetSize = parameters
+ .get(HBaseSparkConf.BULKGET_SIZE)
+ .map(_.toInt)
+ .getOrElse(sqlContext.sparkContext.getConf
+ .getInt(HBaseSparkConf.BULKGET_SIZE, HBaseSparkConf.DEFAULT_BULKGET_SIZE))
+
+ // create or get latest HBaseContext
+ val hbaseContext: HBaseContext = if (useHBaseContext) {
+ LatestHBaseContextCache.latest
+ } else {
+ val hadoopConfig = sqlContext.sparkContext.hadoopConfiguration
+ val config = HBaseConfiguration.create(hadoopConfig)
+ configResources.map(
+ resource =>
+ resource
+ .split(",")
+ .foreach(r => config.addResource(r)))
+ new HBaseContext(sqlContext.sparkContext, config)
+ }
+
+ val wrappedConf = new SerializableConfiguration(hbaseContext.config)
+ def hbaseConf = wrappedConf.value
+
+ /**
+ * Generates a Spark SQL schema object so Spark SQL knows what is being
+ * provided by this BaseRelation
+ *
+ * @return schema generated from the SCHEMA_COLUMNS_MAPPING_KEY value
+ */
+ override val schema: StructType = userSpecifiedSchema.getOrElse(catalog.toDataType)
+
+ def createTable() {
+ val numReg = parameters
+ .get(HBaseTableCatalog.newTable)
+ .map(x => x.toInt)
+ .getOrElse(0)
+ val startKey = Bytes.toBytes(
+ parameters
+ .get(HBaseTableCatalog.regionStart)
+ .getOrElse(HBaseTableCatalog.defaultRegionStart))
+ val endKey = Bytes.toBytes(
+ parameters
+ .get(HBaseTableCatalog.regionEnd)
+ .getOrElse(HBaseTableCatalog.defaultRegionEnd))
+ if (numReg > 3) {
+ val tName = TableName.valueOf(tableName)
+ val cfs = catalog.getColumnFamilies
+
+ val connection = HBaseConnectionCache.getConnection(hbaseConf)
+ // Initialize hBase table if necessary
+ val admin = connection.getAdmin
+ try {
+ if (!admin.tableExists(tName)) {
+ val tableDesc = new HTableDescriptor(tName)
+ cfs.foreach {
+ x =>
+ val cf = new HColumnDescriptor(x.getBytes())
+ logDebug(s"add family $x to ${tableName}")
+ tableDesc.addFamily(cf)
+ }
+ val splitKeys = Bytes.split(startKey, endKey, numReg);
+ admin.createTable(tableDesc, splitKeys)
+
+ }
+ } finally {
+ admin.close()
+ connection.close()
+ }
+ } else {
+ logInfo(s"""${HBaseTableCatalog.newTable}
+ |is not defined or not larger than 3; skipping table creation""".stripMargin)
+ }
+ }
+
+ /**
+ * @param data the DataFrame to write out to HBase
+ * @param overwrite currently not used; rows are always written as Puts
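+ *
+ * Write-side usage sketch (option keys come from HBaseTableCatalog; the values and
+ * variable names are illustrative):
+ * {{{
+ * df.write
+ *   .options(Map(HBaseTableCatalog.tableCatalog -> cat, HBaseTableCatalog.newTable -> "5"))
+ *   .format("org.apache.hadoop.hbase.spark")
+ *   .save()
+ * }}}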
+ */
+ override def insert(data: DataFrame, overwrite: Boolean): Unit = {
+ val jobConfig: JobConf = new JobConf(hbaseConf, this.getClass)
+ jobConfig.setOutputFormat(classOf[TableOutputFormat])
+ jobConfig.set(TableOutputFormat.OUTPUT_TABLE, tableName)
+ var count = 0
+ val rkFields = catalog.getRowKey
+ val rkIdxedFields = rkFields.map {
+ case x =>
+ (schema.fieldIndex(x.colName), x)
+ }
+ val colsIdxedFields = schema.fieldNames
+ .partition(x => rkFields.map(_.colName).contains(x))
+ ._2
+ .map(x => (schema.fieldIndex(x), catalog.getField(x)))
+ val rdd = data.rdd
+ def convertToPut(row: Row) = {
+ // construct bytes for row key
+ val rowBytes = rkIdxedFields.map {
+ case (x, y) =>
+ Utils.toBytes(row(x), y)
+ }
+ val rLen = rowBytes.foldLeft(0) {
+ case (x, y) =>
+ x + y.length
+ }
+ val rBytes = new Array[Byte](rLen)
+ var offset = 0
+ rowBytes.foreach {
+ x =>
+ System.arraycopy(x, 0, rBytes, offset, x.length)
+ offset += x.length
+ }
+ val put = timestamp.fold(new Put(rBytes))(new Put(rBytes, _))
+
+ colsIdxedFields.foreach {
+ case (x, y) =>
+ val r = row(x)
+ if (r != null) {
+ val b = Utils.toBytes(r, y)
+ put.addColumn(Bytes.toBytes(y.cf), Bytes.toBytes(y.col), b)
+ }
+ }
+ count += 1
+ (new ImmutableBytesWritable, put)
+ }
+ rdd.map(convertToPut(_)).saveAsHadoopDataset(jobConfig)
+ }
+
+ def getIndexedProjections(requiredColumns: Array[String]): Seq[(Field, Int)] = {
+ requiredColumns.map(catalog.sMap.getField(_)).zipWithIndex
+ }
+
+ /**
+ * Takes a HBase Row object and parses all of the fields from it.
+ * This is independent of which fields were requested from the key.
+ * Because we have all the data, it's less complex to parse everything.
+ *
+ * @param row the retrieved row from hbase.
+ * @param keyFields all of the fields in the row key, ORDERED by their order in the row key.
+ */
+ def parseRowKey(row: Array[Byte], keyFields: Seq[Field]): Map[Field, Any] = {
+ keyFields
+ .foldLeft((0, Seq[(Field, Any)]()))(
+ (state, field) => {
+ val idx = state._1
+ val parsed = state._2
+ if (field.length != -1) {
+ val value = Utils.hbaseFieldToScalaType(field, row, idx, field.length)
+ // Return the new index and appended value
+ (idx + field.length, parsed ++ Seq((field, value)))
+ } else {
+ field.dt match {
+ case StringType =>
+ val pos = row.indexOf(HBaseTableCatalog.delimiter, idx)
+ if (pos == -1 || pos > row.length) {
+ // this is at the last dimension
+ val value = Utils.hbaseFieldToScalaType(field, row, idx, row.length)
+ (row.length + 1, parsed ++ Seq((field, value)))
+ } else {
+ val value = Utils.hbaseFieldToScalaType(field, row, idx, pos - idx)
+ (pos, parsed ++ Seq((field, value)))
+ }
+ // We don't know the length, assume it extends to the end of the rowkey.
+ case _ =>
+ (
+ row.length + 1,
+ parsed ++ Seq((field, Utils.hbaseFieldToScalaType(field, row, idx, row.length))))
+ }
+ }
+ })
+ ._2
+ .toMap
+ }
+
+ def buildRow(fields: Seq[Field], result: Result): Row = {
+ val r = result.getRow
+ val keySeq = parseRowKey(r, catalog.getRowKey)
+ val valueSeq = fields
+ .filter(!_.isRowKey)
+ .map {
+ x =>
+ val kv = result.getColumnLatestCell(Bytes.toBytes(x.cf), Bytes.toBytes(x.col))
+ if (kv == null || kv.getValueLength == 0) {
+ (x, null)
+ } else {
+ val v = CellUtil.cloneValue(kv)
+ (
+ x,
+ x.dt match {
+ // Here, to avoid arraycopy, return v directly instead of calling hbaseFieldToScalaType
+ case BinaryType => v
+ case _ => Utils.hbaseFieldToScalaType(x, v, 0, v.length)
+ })
+ }
+ }
+ .toMap
+ val unionedRow = keySeq ++ valueSeq
+ // Return the row ordered by the requested order
+ Row.fromSeq(fields.map(unionedRow.get(_).getOrElse(null)))
+ }
+
+ /**
+ * Here we are building the functionality to populate the resulting RDD[Row]
+ * Here is where we will do the following:
+ * - Filter push down
+ * - Scan or GetList pruning
+ * - Executing our scan(s) or/and GetList to generate result
+ *
+ * @param requiredColumns The columns that are being requested by the requesting query
+ * @param filters The filters that are being applied by the requesting query
+ * @return RDD with all the results from HBase needed for SparkSQL to
+ * execute the query
+ */
+ override def buildScan(requiredColumns: Array[String], filters: Array[Filter]): RDD[Row] = {
+
+ val pushDownTuple = buildPushDownPredicatesResource(filters)
+ val pushDownRowKeyFilter = pushDownTuple._1
+ var pushDownDynamicLogicExpression = pushDownTuple._2
+ val valueArray = pushDownTuple._3
+
+ if (!usePushDownColumnFilter) {
+ pushDownDynamicLogicExpression = null
+ }
+
+ logDebug("pushDownRowKeyFilter: " + pushDownRowKeyFilter.ranges)
+ if (pushDownDynamicLogicExpression != null) {
+ logDebug(
+ "pushDownDynamicLogicExpression: " +
+ pushDownDynamicLogicExpression.toExpressionString)
+ }
+ logDebug("valueArray: " + valueArray.length)
+
+ val requiredQualifierDefinitionList =
+ new mutable.ListBuffer[Field]
+
+ requiredColumns.foreach(
+ c => {
+ val field = catalog.getField(c)
+ requiredQualifierDefinitionList += field
+ })
+
+ // retain the information for unit testing checks
+ DefaultSourceStaticUtils.populateLatestExecutionRules(
+ pushDownRowKeyFilter,
+ pushDownDynamicLogicExpression)
+
+ val getList = new util.ArrayList[Get]()
+ val rddList = new util.ArrayList[RDD[Row]]()
+
+ // add points to getList
+ pushDownRowKeyFilter.points.foreach(
+ p => {
+ val get = new Get(p)
+ requiredQualifierDefinitionList.foreach(
+ d => {
+ if (d.isRowKey)
+ get.addColumn(d.cfBytes, d.colBytes)
+ })
+ getList.add(get)
+ })
+
+ val pushDownFilterJava =
+ if (usePushDownColumnFilter && pushDownDynamicLogicExpression != null) {
+ Some(
+ new SparkSQLPushDownFilter(
+ pushDownDynamicLogicExpression,
+ valueArray,
+ requiredQualifierDefinitionList,
+ encoderClsName))
+ } else {
+ None
+ }
+ val hRdd = new HBaseTableScanRDD(
+ this,
+ hbaseContext,
+ pushDownFilterJava,
+ requiredQualifierDefinitionList.seq.toSeq)
+ pushDownRowKeyFilter.points.foreach(hRdd.addPoint(_))
+ pushDownRowKeyFilter.ranges.foreach(hRdd.addRange(_))
+
+ var resultRDD: RDD[Row] = {
+ val tmp = hRdd.map {
+ r =>
+ val indexedFields = getIndexedProjections(requiredColumns).map(_._1)
+ buildRow(indexedFields, r)
+
+ }
+ if (tmp.partitions.size > 0) {
+ tmp
+ } else {
+ null
+ }
+ }
+
+ if (resultRDD == null) {
+ val scan = new Scan()
+ scan.setCacheBlocks(blockCacheEnable)
+ scan.setBatch(batchNum)
+ scan.setCaching(cacheSize)
+ requiredQualifierDefinitionList.foreach(d => scan.addColumn(d.cfBytes, d.colBytes))
+
+ val rdd = hbaseContext
+ .hbaseRDD(TableName.valueOf(tableName), scan)
+ .map(
+ r => {
+ val indexedFields = getIndexedProjections(requiredColumns).map(_._1)
+ buildRow(indexedFields, r._2)
+ })
+ resultRDD = rdd
+ }
+ resultRDD
+ }
+
+ def buildPushDownPredicatesResource(
+ filters: Array[Filter]): (RowKeyFilter, DynamicLogicExpression, Array[Array[Byte]]) = {
+ var superRowKeyFilter: RowKeyFilter = null
+ val queryValueList = new mutable.ListBuffer[Array[Byte]]
+ var superDynamicLogicExpression: DynamicLogicExpression = null
+
+ filters.foreach(
+ f => {
+ val rowKeyFilter = new RowKeyFilter()
+ val logicExpression = transverseFilterTree(rowKeyFilter, queryValueList, f)
+ if (superDynamicLogicExpression == null) {
+ superDynamicLogicExpression = logicExpression
+ superRowKeyFilter = rowKeyFilter
+ } else {
+ superDynamicLogicExpression =
+ new AndLogicExpression(superDynamicLogicExpression, logicExpression)
+ superRowKeyFilter.mergeIntersect(rowKeyFilter)
+ }
+
+ })
+
+ val queryValueArray = queryValueList.toArray
+
+ if (superRowKeyFilter == null) {
+ superRowKeyFilter = new RowKeyFilter
+ }
+
+ (superRowKeyFilter, superDynamicLogicExpression, queryValueArray)
+ }
+
+ /**
+ * For some codecs, the order may be inconsistent between a java primitive
+ * type and its byte array. We may have to split the predicates on some
+ * of the java primitive types into multiple predicates. The encoder will take
+ * care of it and return the concrete ranges.
+ *
+ * For example in naive codec, some of the java primitive types have to be split into multiple
+ * predicates, and union these predicates together to make the predicates be performed correctly.
+ * For example, if we have "COLUMN < 2", we will transform it into
+ * "0 <= COLUMN < 2 OR Integer.MIN_VALUE <= COLUMN <= -1"
+ */
+
+ def transverseFilterTree(
+ parentRowKeyFilter: RowKeyFilter,
+ valueArray: mutable.ListBuffer[Array[Byte]],
+ filter: Filter): DynamicLogicExpression = {
+ filter match {
+ case EqualTo(attr, value) =>
+ val field = catalog.getField(attr)
+ if (field != null) {
+ if (field.isRowKey) {
+ parentRowKeyFilter.mergeIntersect(new RowKeyFilter(Utils.toBytes(value, field), null))
+ }
+ val byteValue = Utils.toBytes(value, field)
+ valueArray += byteValue
+ }
+ new EqualLogicExpression(attr, valueArray.length - 1, false)
+
+ /**
+ * The encoder may split the predicates into multiple byte array boundaries.
+ * Each boundary is mapped into the RowKeyFilter and then unioned by the reduce
+ * operation. If the data type is not supported, b will be None, and no
+ * operation happens on the parentRowKeyFilter.
+ *
+ * Note that because LessThan is not inclusive, the first bound should be exclusive,
+ * which is controlled by inc.
+ *
+ * The other predicates, i.e., GreaterThan/LessThanOrEqual/GreaterThanOrEqual, follow
+ * similar logic.
+ */
+ case LessThan(attr, value) =>
+ val field = catalog.getField(attr)
+ if (field != null) {
+ if (field.isRowKey) {
+ val b = encoder.ranges(value)
+ var inc = false
+ b.map(_.less.map {
+ x =>
+ val r = new RowKeyFilter(null, new ScanRange(x.upper, inc, x.low, true))
+ inc = true
+ r
+ }).map { x => x.reduce { (i, j) => i.mergeUnion(j) } }
+ .map(parentRowKeyFilter.mergeIntersect(_))
+ }
+ val byteValue = encoder.encode(field.dt, value)
+ valueArray += byteValue
+ }
+ new LessThanLogicExpression(attr, valueArray.length - 1)
+ case GreaterThan(attr, value) =>
+ val field = catalog.getField(attr)
+ if (field != null) {
+ if (field.isRowKey) {
+ val b = encoder.ranges(value)
+ var inc = false
+ b.map(_.greater.map {
+ x =>
+ val r = new RowKeyFilter(null, new ScanRange(x.upper, true, x.low, inc))
+ inc = true
+ r
+ }).map { x => x.reduce { (i, j) => i.mergeUnion(j) } }
+ .map(parentRowKeyFilter.mergeIntersect(_))
+ }
+ val byteValue = encoder.encode(field.dt, value)
+ valueArray += byteValue
+ }
+ new GreaterThanLogicExpression(attr, valueArray.length - 1)
+ case LessThanOrEqual(attr, value) =>
+ val field = catalog.getField(attr)
+ if (field != null) {
+ if (field.isRowKey) {
+ val b = encoder.ranges(value)
+ b.map(
+ _.less.map(x => new RowKeyFilter(null, new ScanRange(x.upper, true, x.low, true))))
+ .map { x => x.reduce { (i, j) => i.mergeUnion(j) } }
+ .map(parentRowKeyFilter.mergeIntersect(_))
+ }
+ val byteValue = encoder.encode(field.dt, value)
+ valueArray += byteValue
+ }
+ new LessThanOrEqualLogicExpression(attr, valueArray.length - 1)
+ case GreaterThanOrEqual(attr, value) =>
+ val field = catalog.getField(attr)
+ if (field != null) {
+ if (field.isRowKey) {
+ val b = encoder.ranges(value)
+ b.map(
+ _.greater.map(x => new RowKeyFilter(null, new ScanRange(x.upper, true, x.low, true))))
+ .map { x => x.reduce { (i, j) => i.mergeUnion(j) } }
+ .map(parentRowKeyFilter.mergeIntersect(_))
+ }
+ val byteValue = encoder.encode(field.dt, value)
+ valueArray += byteValue
+ }
+ new GreaterThanOrEqualLogicExpression(attr, valueArray.length - 1)
+ case StringStartsWith(attr, value) =>
+ val field = catalog.getField(attr)
+ if (field != null) {
+ if (field.isRowKey) {
+ val p = Utils.toBytes(value, field)
+ val endRange = Utils.incrementByteArray(p)
+ parentRowKeyFilter.mergeIntersect(
+ new RowKeyFilter(null, new ScanRange(endRange, false, p, true)))
+ }
+ val byteValue = Utils.toBytes(value, field)
+ valueArray += byteValue
+ }
+ new StartsWithLogicExpression(attr, valueArray.length - 1)
+ case Or(left, right) =>
+ val leftExpression = transverseFilterTree(parentRowKeyFilter, valueArray, left)
+ val rightSideRowKeyFilter = new RowKeyFilter
+ val rightExpression = transverseFilterTree(rightSideRowKeyFilter, valueArray, right)
+
+ parentRowKeyFilter.mergeUnion(rightSideRowKeyFilter)
+
+ new OrLogicExpression(leftExpression, rightExpression)
+ case And(left, right) =>
+ val leftExpression = transverseFilterTree(parentRowKeyFilter, valueArray, left)
+ val rightSideRowKeyFilter = new RowKeyFilter
+ val rightExpression = transverseFilterTree(rightSideRowKeyFilter, valueArray, right)
+ parentRowKeyFilter.mergeIntersect(rightSideRowKeyFilter)
+
+ new AndLogicExpression(leftExpression, rightExpression)
+ case IsNull(attr) =>
+ new IsNullLogicExpression(attr, false)
+ case IsNotNull(attr) =>
+ new IsNullLogicExpression(attr, true)
+ case _ =>
+ new PassThroughLogicExpression
+ }
+ }
+}
+
+/**
+ * Construct to contain a single scan ranges information. Also
+ * provide functions to merge with other scan ranges through AND
+ * or OR operators
+ *
+ * @param upperBound Upper bound of scan
+ * @param isUpperBoundEqualTo Include upper bound value in the results
+ * @param lowerBound Lower bound of scan
+ * @param isLowerBoundEqualTo Include lower bound value in the results
+ */
+@InterfaceAudience.Private
+class ScanRange(
+ var upperBound: Array[Byte],
+ var isUpperBoundEqualTo: Boolean,
+ var lowerBound: Array[Byte],
+ var isLowerBoundEqualTo: Boolean)
+ extends Serializable {
+
+ /**
+ * Function to merge another scan object through an AND operation
+ *
+ * @param other Other scan object
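+ *
+ * For example, intersecting ["b","f") with ["d","z"] keeps the larger lower bound
+ * and the smaller upper bound, giving ["d","f").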
+ */
+ def mergeIntersect(other: ScanRange): Unit = {
+ val upperBoundCompare = compareRange(upperBound, other.upperBound)
+ val lowerBoundCompare = compareRange(lowerBound, other.lowerBound)
+
+ upperBound = if (upperBoundCompare < 0) upperBound else other.upperBound
+ lowerBound = if (lowerBoundCompare > 0) lowerBound else other.lowerBound
+
+ isLowerBoundEqualTo =
+ if (lowerBoundCompare == 0)
+ isLowerBoundEqualTo && other.isLowerBoundEqualTo
+ else if (lowerBoundCompare < 0) other.isLowerBoundEqualTo
+ else isLowerBoundEqualTo
+
+ isUpperBoundEqualTo =
+ if (upperBoundCompare == 0)
+ isUpperBoundEqualTo && other.isUpperBoundEqualTo
+ else if (upperBoundCompare < 0) isUpperBoundEqualTo
+ else other.isUpperBoundEqualTo
+ }
+
+ /**
+ * Function to merge another scan object through an OR operation
+ *
+ * @param other Other scan object
+ */
+ def mergeUnion(other: ScanRange): Unit = {
+
+ val upperBoundCompare = compareRange(upperBound, other.upperBound)
+ val lowerBoundCompare = compareRange(lowerBound, other.lowerBound)
+
+ upperBound = if (upperBoundCompare > 0) upperBound else other.upperBound
+ lowerBound = if (lowerBoundCompare < 0) lowerBound else other.lowerBound
+
+ isLowerBoundEqualTo =
+ if (lowerBoundCompare == 0)
+ isLowerBoundEqualTo || other.isLowerBoundEqualTo
+ else if (lowerBoundCompare < 0) isLowerBoundEqualTo
+ else other.isLowerBoundEqualTo
+
+ isUpperBoundEqualTo =
+ if (upperBoundCompare == 0)
+ isUpperBoundEqualTo || other.isUpperBoundEqualTo
+ else if (upperBoundCompare < 0) other.isUpperBoundEqualTo
+ else isUpperBoundEqualTo
+ }
+
+ /**
+ * Common function to see if this scan overlaps with another
+ *
+ * Reference Visual
+ *
+ * A B
+ * |---------------------------|
+ * LL--------------LU
+ * RL--------------RU
+ *
+ * A = lowest value is byte[0]
+ * B = highest value is null
+ * LL = Left Lower Bound
+ * LU = Left Upper Bound
+ * RL = Right Lower Bound
+ * RU = Right Upper Bound
+ *
+ * @param other Other scan object
+ * @return The merged ScanRange if the two ranges overlap, otherwise null
+ */
+ def getOverLapScanRange(other: ScanRange): ScanRange = {
+ var leftRange: ScanRange = null
+ var rightRange: ScanRange = null
+
+ // First identify the Left range
+ // Also lower bound can't be null
+ if (compareRange(lowerBound, other.lowerBound) < 0 ||
+ compareRange(upperBound, other.upperBound) < 0) {
+ leftRange = this
+ rightRange = other
+ } else {
+ leftRange = other
+ rightRange = this
+ }
+
+ if (hasOverlap(leftRange, rightRange)) {
+ val result = new ScanRange(upperBound, isUpperBoundEqualTo, lowerBound, isLowerBoundEqualTo)
+ result.mergeIntersect(other)
+ result
+ } else {
+ null
+ }
+ }
+
+ /**
+ * The leftRange.upperBound has to be larger than the rightRange's lowerBound.
+ * Otherwise, there is no overlap.
+ *
+ * @param left The range with the smaller lowerBound
+ * @param right The range with the larger lowerBound
+ * @return Whether two ranges have overlap.
+ */
+
+ def hasOverlap(left: ScanRange, right: ScanRange): Boolean = {
+ compareRange(left.upperBound, right.lowerBound) >= 0
+ }
+
+ /**
+ * Special compare logic because we can have null values
+ * for left or right bound
+ *
+ * @param left Left byte array
+ * @param right Right byte array
+ * @return 0 if equal, 1 if left is greater and -1 if right is greater
+ */
+ def compareRange(left: Array[Byte], right: Array[Byte]): Int = {
+ if (left == null && right == null) 0
+ else if (left == null && right != null) 1
+ else if (left != null && right == null) -1
+ else Bytes.compareTo(left, right)
+ }
+
+ /**
+ * @return True if the given point falls within this scan range, false otherwise
+ */
+ def containsPoint(point: Array[Byte]): Boolean = {
+ val lowerCompare = compareRange(point, lowerBound)
+ val upperCompare = compareRange(point, upperBound)
+
+ ((isLowerBoundEqualTo && lowerCompare >= 0) ||
+ (!isLowerBoundEqualTo && lowerCompare > 0)) &&
+ ((isUpperBoundEqualTo && upperCompare <= 0) ||
+ (!isUpperBoundEqualTo && upperCompare < 0))
+
+ }
+ override def toString: String = {
+ "ScanRange:(upperBound:" + Bytes.toString(upperBound) +
+ ",isUpperBoundEqualTo:" + isUpperBoundEqualTo + ",lowerBound:" +
+ Bytes.toString(lowerBound) + ",isLowerBoundEqualTo:" + isLowerBoundEqualTo + ")"
+ }
+}
+
+/**
+ * Contains information related to the filters for a given column.
+ * This can contain many ranges or points.
+ *
+ * @param currentPoint the initial point when the filter is created
+ * @param currentRange the initial scanRange when the filter is created
+ */
+@InterfaceAudience.Private
+class ColumnFilter(
+ currentPoint: Array[Byte] = null,
+ currentRange: ScanRange = null,
+ var points: mutable.ListBuffer[Array[Byte]] = new mutable.ListBuffer[Array[Byte]](),
+ var ranges: mutable.ListBuffer[ScanRange] = new mutable.ListBuffer[ScanRange]())
+ extends Serializable {
+ // Collection of ranges
+ if (currentRange != null) ranges.+=(currentRange)
+
+ // Collection of points
+ if (currentPoint != null) points.+=(currentPoint)
+
+ /**
+ * This will validate a given value against the filter's points and/or ranges;
+ * the result indicates whether the value passes the filter
+ *
+ * @param value Value to be validated
+ * @param valueOffSet The offset of the value
+ * @param valueLength The length of the value
+ * @return True if the value passes the filter, false if not
+ */
+ def validate(value: Array[Byte], valueOffSet: Int, valueLength: Int): Boolean = {
+ var result = false
+
+ points.foreach(
+ p => {
+ if (Bytes.equals(p, 0, p.length, value, valueOffSet, valueLength)) {
+ result = true
+ }
+ })
+
+ ranges.foreach(
+ r => {
+ val upperBoundPass = r.upperBound == null ||
+ (r.isUpperBoundEqualTo &&
+ Bytes.compareTo(
+ r.upperBound,
+ 0,
+ r.upperBound.length,
+ value,
+ valueOffSet,
+ valueLength) >= 0) ||
+ (!r.isUpperBoundEqualTo &&
+ Bytes
+ .compareTo(r.upperBound, 0, r.upperBound.length, value, valueOffSet, valueLength) > 0)
+
+ val lowerBoundPass = r.lowerBound == null || r.lowerBound.length == 0 ||
+ (r.isLowerBoundEqualTo &&
+ Bytes
+ .compareTo(
+ r.lowerBound,
+ 0,
+ r.lowerBound.length,
+ value,
+ valueOffSet,
+ valueLength) <= 0) ||
+ (!r.isLowerBoundEqualTo &&
+ Bytes
+ .compareTo(r.lowerBound, 0, r.lowerBound.length, value, valueOffSet, valueLength) < 0)
+
+ result = result || (upperBoundPass && lowerBoundPass)
+ })
+ result
+ }
+
+ /**
+ * This will allow us to merge filter logic that is joined to the existing filter
+ * through an OR operator
+ *
+ * @param other Filter to merge
+ */
+ def mergeUnion(other: ColumnFilter): Unit = {
+ other.points.foreach(p => points += p)
+
+ other.ranges.foreach(
+ otherR => {
+ var doesOverLap = false
+ ranges.foreach {
+ r =>
+ if (r.getOverLapScanRange(otherR) != null) {
+ r.mergeUnion(otherR)
+ doesOverLap = true
+ }
+ }
+ if (!doesOverLap) ranges.+=(otherR)
+ })
+ }
+
+ /**
+ * This will allow us to merge filter logic that is joined to the existing filter
+ * through an AND operator
+ *
+ * @param other Filter to merge
+ */
+ def mergeIntersect(other: ColumnFilter): Unit = {
+ val survivingPoints = new mutable.ListBuffer[Array[Byte]]()
+ points.foreach(
+ p => {
+ other.points.foreach(
+ otherP => {
+ if (Bytes.equals(p, otherP)) {
+ survivingPoints.+=(p)
+ }
+ })
+ })
+ points = survivingPoints
+
+ val survivingRanges = new mutable.ListBuffer[ScanRange]()
+
+ other.ranges.foreach(
+ otherR => {
+ ranges.foreach(
+ r => {
+ if (r.getOverLapScanRange(otherR) != null) {
+ r.mergeIntersect(otherR)
+ survivingRanges += r
+ }
+ })
+ })
+ ranges = survivingRanges
+ }
+
+ override def toString: String = {
+ val strBuilder = new StringBuilder
+ strBuilder.append("(points:(")
+ var isFirst = true
+ points.foreach(
+ p => {
+ if (isFirst) isFirst = false
+ else strBuilder.append(",")
+ strBuilder.append(Bytes.toString(p))
+ })
+ strBuilder.append("),ranges:")
+ isFirst = true
+ ranges.foreach(
+ r => {
+ if (isFirst) isFirst = false
+ else strBuilder.append(",")
+ strBuilder.append(r)
+ })
+ strBuilder.append("))")
+ strBuilder.toString()
+ }
+}
+
+/**
+ * A collection of ColumnFilters indexed by column names.
+ *
+ * Also contains merge commands that will consolidate the filters
+ * per column name
+ */
+@InterfaceAudience.Private
+class ColumnFilterCollection {
+ val columnFilterMap = new mutable.HashMap[String, ColumnFilter]
+
+ def clear(): Unit = {
+ columnFilterMap.clear()
+ }
+
+ /**
+ * This will allow us to merge filter logic that is joined to the existing filter
+ * through an OR operator. This will merge a single column's filter
+ *
+ * @param column The column to be merged
+ * @param other The other ColumnFilter object to merge
+ */
+ def mergeUnion(column: String, other: ColumnFilter): Unit = {
+ val existingFilter = columnFilterMap.get(column)
+ if (existingFilter.isEmpty) {
+ columnFilterMap.+=((column, other))
+ } else {
+ existingFilter.get.mergeUnion(other)
+ }
+ }
+
+ /**
+ * This will allow us to merge all filters in the existing collection
+ * to the filters in the other collection. All merges are done as a result
+ * of an OR operator
+ *
+ * @param other The other Column Filter Collection to be merged
+ */
+ def mergeUnion(other: ColumnFilterCollection): Unit = {
+ other.columnFilterMap.foreach(
+ e => {
+ mergeUnion(e._1, e._2)
+ })
+ }
+
+ /**
+ * This will allow us to merge all filters in the existing collection
+ * to the filters in the other collection. All merges are done as a result
+ * of an AND operator
+ *
+ * @param other The column filter from the other collection
+ */
+ def mergeIntersect(other: ColumnFilterCollection): Unit = {
+ other.columnFilterMap.foreach(
+ e => {
+ val existingColumnFilter = columnFilterMap.get(e._1)
+ if (existingColumnFilter.isEmpty) {
+ columnFilterMap += e
+ } else {
+ existingColumnFilter.get.mergeIntersect(e._2)
+ }
+ })
+ }
+
+ override def toString: String = {
+ val strBuilder = new StringBuilder
+ columnFilterMap.foreach(e => strBuilder.append(e))
+ strBuilder.toString()
+ }
+}
+
+/**
+ * Status object to store static functions but also to hold last executed
+ * information that can be used for unit testing.
+ */
+@InterfaceAudience.Private
+object DefaultSourceStaticUtils {
+
+ val byteRange = new ThreadLocal[PositionedByteRange] {
+ override def initialValue(): PositionedByteRange = {
+ val range = new SimplePositionedMutableByteRange()
+ range.setOffset(0)
+ range.setPosition(0)
+ }
+ }
+
+ def getFreshByteRange(bytes: Array[Byte]): PositionedByteRange = {
+ getFreshByteRange(bytes, 0, bytes.length)
+ }
+
+ def getFreshByteRange(bytes: Array[Byte], offset: Int = 0, length: Int): PositionedByteRange = {
+ byteRange.get().set(bytes).setLength(length).setOffset(offset)
+ }
+
+ // This will contain the last 5 filters and required fields used in buildScan
+ // These values can be used in unit testing to make sure we are converting
+ // The Spark SQL input correctly
+ val lastFiveExecutionRules =
+ new ConcurrentLinkedQueue[ExecutionRuleForUnitTesting]()
+
+ /**
+   * This method populates the lastFiveExecutionRules for unit test purposes.
+ * This method is not thread safe.
+ *
+ * @param rowKeyFilter The rowKey Filter logic used in the last query
+ * @param dynamicLogicExpression The dynamicLogicExpression used in the last query
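+   *
+   * A minimal sketch of how a unit test might consume the recorded rules; the
+   * filter, expression and expected string are assumed to already exist:
+   * {{{
+   *   DefaultSourceStaticUtils.populateLatestExecutionRules(rowKeyFilter, dynamicLogicExpression)
+   *   val lastRule = DefaultSourceStaticUtils.lastFiveExecutionRules.poll()
+   *   assert(lastRule.dynamicLogicExpression.toExpressionString == expectedExpression)
+   * }}}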
+ */
+ def populateLatestExecutionRules(
+ rowKeyFilter: RowKeyFilter,
+ dynamicLogicExpression: DynamicLogicExpression): Unit = {
+ lastFiveExecutionRules.add(
+ new ExecutionRuleForUnitTesting(rowKeyFilter, dynamicLogicExpression))
+ while (lastFiveExecutionRules.size() > 5) {
+ lastFiveExecutionRules.poll()
+ }
+ }
+}
+
+/**
+ * Contains information related to the filters for a given row key.
+ * This can contain many ranges or points.
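+ *
+ * A minimal sketch of merging and validating row keys (the key values are
+ * illustrative):
+ * {{{
+ *   val filter = new RowKeyFilter(currentPoint = Bytes.toBytes("row1"), currentRange = null)
+ *   filter.mergeUnion(new RowKeyFilter(currentPoint = Bytes.toBytes("row7"), currentRange = null))
+ *   val key = Bytes.toBytes("row7")
+ *   filter.validate(key, 0, key.length)   // true: "row7" was added as a point by the union
+ * }}}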
+ *
+ * @param currentPoint the initial point when the filter is created
+ * @param currentRange the initial scanRange when the filter is created
+ */
+@InterfaceAudience.Private
+class RowKeyFilter(
+ currentPoint: Array[Byte] = null,
+ currentRange: ScanRange = new ScanRange(null, true, new Array[Byte](0), true),
+ var points: mutable.ListBuffer [Array[Byte]] = new mutable.ListBuffer [Array[Byte]](),
+ var ranges: mutable.ListBuffer [ScanRange] = new mutable.ListBuffer [ScanRange]())
+ extends Serializable {
+ // Collection of ranges
+ if (currentRange != null) ranges.+=(currentRange)
+
+ // Collection of points
+ if (currentPoint != null) points.+=(currentPoint)
+
+ /**
+   * This will validate a given value against the filter's points and/or ranges;
+   * the result indicates whether the value passed the filter
+ *
+ * @param value Value to be validated
+ * @param valueOffSet The offset of the value
+ * @param valueLength The length of the value
+   * @return True if the value passes the filter, false if not
+ */
+ def validate(value: Array[Byte], valueOffSet: Int, valueLength: Int): Boolean = {
+ var result = false
+
+ points.foreach(
+ p => {
+ if (Bytes.equals(p, 0, p.length, value, valueOffSet, valueLength)) {
+ result = true
+ }
+ })
+
+ ranges.foreach(
+ r => {
+ val upperBoundPass = r.upperBound == null ||
+ (r.isUpperBoundEqualTo &&
+ Bytes.compareTo(
+ r.upperBound,
+ 0,
+ r.upperBound.length,
+ value,
+ valueOffSet,
+ valueLength) >= 0) ||
+ (!r.isUpperBoundEqualTo &&
+ Bytes
+ .compareTo(r.upperBound, 0, r.upperBound.length, value, valueOffSet, valueLength) > 0)
+
+        val lowerBoundPass = r.lowerBound == null || r.lowerBound.length == 0 ||
+ (r.isLowerBoundEqualTo &&
+ Bytes
+ .compareTo(
+ r.lowerBound,
+ 0,
+ r.lowerBound.length,
+ value,
+ valueOffSet,
+ valueLength) <= 0) ||
+ (!r.isLowerBoundEqualTo &&
+ Bytes
+ .compareTo(r.lowerBound, 0, r.lowerBound.length, value, valueOffSet, valueLength) < 0)
+
+ result = result || (upperBoundPass && lowerBoundPass)
+ })
+ result
+ }
+
+ /**
+ * This will allow us to merge filter logic that is joined to the existing filter
+   * through an OR operator
+ *
+ * @param other Filter to merge
+ */
+ def mergeUnion(other: RowKeyFilter): RowKeyFilter = {
+ other.points.foreach(p => points += p)
+
+ other.ranges.foreach(
+ otherR => {
+ var doesOverLap = false
+ ranges.foreach {
+ r =>
+ if (r.getOverLapScanRange(otherR) != null) {
+ r.mergeUnion(otherR)
+ doesOverLap = true
+ }
+ }
+ if (!doesOverLap) ranges.+=(otherR)
+ })
+ this
+ }
+
+ /**
+ * This will allow us to merge filter logic that is joined to the existing filter
+   * through an AND operator
+ *
+ * @param other Filter to merge
+ */
+ def mergeIntersect(other: RowKeyFilter): RowKeyFilter = {
+ val survivingPoints = new mutable.ListBuffer [Array[Byte]]()
+ val didntSurviveFirstPassPoints = new mutable.ListBuffer [Array[Byte]]()
+ if (points == null || points.length == 0) {
+ other.points.foreach(
+ otherP => {
+ didntSurviveFirstPassPoints += otherP
+ })
+ } else {
+ points.foreach(
+ p => {
+ if (other.points.length == 0) {
+ didntSurviveFirstPassPoints += p
+ } else {
+ other.points.foreach(
+ otherP => {
+ if (Bytes.equals(p, otherP)) {
+ survivingPoints += p
+ } else {
+ didntSurviveFirstPassPoints += p
+ }
+ })
+ }
+ })
+ }
+
+ val survivingRanges = new mutable.ListBuffer [ScanRange]()
+
+ if (ranges.length == 0) {
+ didntSurviveFirstPassPoints.foreach(
+ p => {
+ survivingPoints += p
+ })
+ } else {
+ ranges.foreach(
+ r => {
+ other.ranges.foreach(
+ otherR => {
+ val overLapScanRange = r.getOverLapScanRange(otherR)
+ if (overLapScanRange != null) {
+ survivingRanges += overLapScanRange
+ }
+ })
+ didntSurviveFirstPassPoints.foreach(
+ p => {
+ if (r.containsPoint(p)) {
+ survivingPoints += p
+ }
+ })
+ })
+ }
+ points = survivingPoints
+ ranges = survivingRanges
+ this
+ }
+
+ override def toString: String = {
+ val strBuilder = new StringBuilder
+ strBuilder.append("(points:(")
+ var isFirst = true
+ points.foreach(
+ p => {
+ if (isFirst) isFirst = false
+ else strBuilder.append(",")
+ strBuilder.append(Bytes.toString(p))
+ })
+ strBuilder.append("),ranges:")
+ isFirst = true
+ ranges.foreach(
+ r => {
+ if (isFirst) isFirst = false
+ else strBuilder.append(",")
+ strBuilder.append(r)
+ })
+ strBuilder.append("))")
+ strBuilder.toString()
+ }
+}
+
+@InterfaceAudience.Private
+class ExecutionRuleForUnitTesting(
+ val rowKeyFilter: RowKeyFilter,
+ val dynamicLogicExpression: DynamicLogicExpression)
diff --git a/spark4/hbase-spark4/src/main/scala/org/apache/hadoop/hbase/spark/DynamicLogicExpression.scala b/spark4/hbase-spark4/src/main/scala/org/apache/hadoop/hbase/spark/DynamicLogicExpression.scala
new file mode 100644
index 00000000..398093a5
--- /dev/null
+++ b/spark4/hbase-spark4/src/main/scala/org/apache/hadoop/hbase/spark/DynamicLogicExpression.scala
@@ -0,0 +1,324 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.hadoop.hbase.spark
+
+import java.util
+import org.apache.hadoop.hbase.spark.datasources.{BytesEncoder, JavaBytesEncoder}
+import org.apache.hadoop.hbase.spark.datasources.JavaBytesEncoder.JavaBytesEncoder
+import org.apache.hadoop.hbase.util.Bytes
+import org.apache.yetus.audience.InterfaceAudience
+
+/**
+ * Dynamic logic for SQL push down. There is an instance for most
+ * common operations and a pass-through for other operations not covered here.
+ *
+ * Logic can be nested with And or Or operators.
+ *
+ * A logic tree can be written out as a string and reconstructed from that string
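+ *
+ * A minimal sketch of the string round trip; the column names and value
+ * indexes are illustrative, and `encoder` is assumed to be a BytesEncoder
+ * already in scope:
+ * {{{
+ *   val expr = new AndLogicExpression(
+ *     new GreaterThanLogicExpression("colA", 0),
+ *     new EqualLogicExpression("colB", 1, false))
+ *   val exprString = expr.toExpressionString   // "( colA > 0 AND colB == 1 )"
+ *   val rebuilt = DynamicLogicExpressionBuilder.build(exprString, encoder)
+ * }}}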
+ */
+@InterfaceAudience.Private
+trait DynamicLogicExpression {
+ def execute(
+ columnToCurrentRowValueMap: util.HashMap[String, ByteArrayComparable],
+ valueFromQueryValueArray: Array[Array[Byte]]): Boolean
+ def toExpressionString: String = {
+ val strBuilder = new StringBuilder
+ appendToExpression(strBuilder)
+ strBuilder.toString()
+ }
+ def filterOps: JavaBytesEncoder = JavaBytesEncoder.Unknown
+
+ def appendToExpression(strBuilder: StringBuilder)
+
+ var encoder: BytesEncoder = _
+
+ def setEncoder(enc: BytesEncoder): DynamicLogicExpression = {
+ encoder = enc
+ this
+ }
+}
+
+@InterfaceAudience.Private
+trait CompareTrait {
+ self: DynamicLogicExpression =>
+ def columnName: String
+ def valueFromQueryIndex: Int
+ def execute(
+ columnToCurrentRowValueMap: util.HashMap[String, ByteArrayComparable],
+ valueFromQueryValueArray: Array[Array[Byte]]): Boolean = {
+ val currentRowValue = columnToCurrentRowValueMap.get(columnName)
+ val valueFromQuery = valueFromQueryValueArray(valueFromQueryIndex)
+ currentRowValue != null &&
+ encoder.filter(
+ currentRowValue.bytes,
+ currentRowValue.offset,
+ currentRowValue.length,
+ valueFromQuery,
+ 0,
+ valueFromQuery.length,
+ filterOps)
+ }
+}
+
+@InterfaceAudience.Private
+class AndLogicExpression(
+ val leftExpression: DynamicLogicExpression,
+ val rightExpression: DynamicLogicExpression)
+ extends DynamicLogicExpression {
+ override def execute(
+ columnToCurrentRowValueMap: util.HashMap[String, ByteArrayComparable],
+ valueFromQueryValueArray: Array[Array[Byte]]): Boolean = {
+ leftExpression.execute(columnToCurrentRowValueMap, valueFromQueryValueArray) &&
+ rightExpression.execute(columnToCurrentRowValueMap, valueFromQueryValueArray)
+ }
+
+ override def appendToExpression(strBuilder: StringBuilder): Unit = {
+ strBuilder.append("( ")
+ strBuilder.append(leftExpression.toExpressionString)
+ strBuilder.append(" AND ")
+ strBuilder.append(rightExpression.toExpressionString)
+ strBuilder.append(" )")
+ }
+}
+
+@InterfaceAudience.Private
+class OrLogicExpression(
+ val leftExpression: DynamicLogicExpression,
+ val rightExpression: DynamicLogicExpression)
+ extends DynamicLogicExpression {
+ override def execute(
+ columnToCurrentRowValueMap: util.HashMap[String, ByteArrayComparable],
+ valueFromQueryValueArray: Array[Array[Byte]]): Boolean = {
+ leftExpression.execute(columnToCurrentRowValueMap, valueFromQueryValueArray) ||
+ rightExpression.execute(columnToCurrentRowValueMap, valueFromQueryValueArray)
+ }
+ override def appendToExpression(strBuilder: StringBuilder): Unit = {
+ strBuilder.append("( ")
+ strBuilder.append(leftExpression.toExpressionString)
+ strBuilder.append(" OR ")
+ strBuilder.append(rightExpression.toExpressionString)
+ strBuilder.append(" )")
+ }
+}
+
+@InterfaceAudience.Private
+class EqualLogicExpression(val columnName: String, val valueFromQueryIndex: Int, val isNot: Boolean)
+ extends DynamicLogicExpression {
+ override def execute(
+ columnToCurrentRowValueMap: util.HashMap[String, ByteArrayComparable],
+ valueFromQueryValueArray: Array[Array[Byte]]): Boolean = {
+ val currentRowValue = columnToCurrentRowValueMap.get(columnName)
+ val valueFromQuery = valueFromQueryValueArray(valueFromQueryIndex)
+
+ currentRowValue != null &&
+ Bytes.equals(
+ valueFromQuery,
+ 0,
+ valueFromQuery.length,
+ currentRowValue.bytes,
+ currentRowValue.offset,
+ currentRowValue.length) != isNot
+ }
+ override def appendToExpression(strBuilder: StringBuilder): Unit = {
+ val command = if (isNot) "!=" else "=="
+ strBuilder.append(columnName + " " + command + " " + valueFromQueryIndex)
+ }
+}
+
+@InterfaceAudience.Private
+class StartsWithLogicExpression(val columnName: String, val valueFromQueryIndex: Int)
+ extends DynamicLogicExpression {
+ override def execute(
+ columnToCurrentRowValueMap: util.HashMap[String, ByteArrayComparable],
+ valueFromQueryValueArray: Array[Array[Byte]]): Boolean = {
+ val currentRowValue = columnToCurrentRowValueMap.get(columnName)
+ val valueFromQuery = valueFromQueryValueArray(valueFromQueryIndex)
+
+ currentRowValue != null && valueFromQuery != null && currentRowValue.length >= valueFromQuery.length &&
+ Bytes.equals(
+ valueFromQuery,
+ 0,
+ valueFromQuery.length,
+ currentRowValue.bytes,
+ currentRowValue.offset,
+ valueFromQuery.length)
+ }
+ override def appendToExpression(strBuilder: StringBuilder): Unit = {
+ strBuilder.append(columnName + " startsWith " + valueFromQueryIndex)
+ }
+}
+
+@InterfaceAudience.Private
+class IsNullLogicExpression(val columnName: String, val isNot: Boolean)
+ extends DynamicLogicExpression {
+ override def execute(
+ columnToCurrentRowValueMap: util.HashMap[String, ByteArrayComparable],
+ valueFromQueryValueArray: Array[Array[Byte]]): Boolean = {
+ val currentRowValue = columnToCurrentRowValueMap.get(columnName)
+
+ (currentRowValue == null) != isNot
+ }
+ override def appendToExpression(strBuilder: StringBuilder): Unit = {
+ val command = if (isNot) "isNotNull" else "isNull"
+ strBuilder.append(columnName + " " + command)
+ }
+}
+
+@InterfaceAudience.Private
+class GreaterThanLogicExpression(
+ override val columnName: String,
+ override val valueFromQueryIndex: Int)
+ extends DynamicLogicExpression
+ with CompareTrait {
+ override val filterOps = JavaBytesEncoder.Greater
+ override def appendToExpression(strBuilder: StringBuilder): Unit = {
+ strBuilder.append(columnName + " > " + valueFromQueryIndex)
+ }
+}
+
+@InterfaceAudience.Private
+class GreaterThanOrEqualLogicExpression(
+ override val columnName: String,
+ override val valueFromQueryIndex: Int)
+ extends DynamicLogicExpression
+ with CompareTrait {
+ override val filterOps = JavaBytesEncoder.GreaterEqual
+ override def appendToExpression(strBuilder: StringBuilder): Unit = {
+ strBuilder.append(columnName + " >= " + valueFromQueryIndex)
+ }
+}
+
+@InterfaceAudience.Private
+class LessThanLogicExpression(
+ override val columnName: String,
+ override val valueFromQueryIndex: Int)
+ extends DynamicLogicExpression
+ with CompareTrait {
+ override val filterOps = JavaBytesEncoder.Less
+ override def appendToExpression(strBuilder: StringBuilder): Unit = {
+ strBuilder.append(columnName + " < " + valueFromQueryIndex)
+ }
+}
+
+@InterfaceAudience.Private
+class LessThanOrEqualLogicExpression(val columnName: String, val valueFromQueryIndex: Int)
+ extends DynamicLogicExpression
+ with CompareTrait {
+ override val filterOps = JavaBytesEncoder.LessEqual
+ override def appendToExpression(strBuilder: StringBuilder): Unit = {
+ strBuilder.append(columnName + " <= " + valueFromQueryIndex)
+ }
+}
+
+@InterfaceAudience.Private
+class PassThroughLogicExpression() extends DynamicLogicExpression {
+ override def execute(
+ columnToCurrentRowValueMap: util.HashMap[String, ByteArrayComparable],
+ valueFromQueryValueArray: Array[Array[Byte]]): Boolean = true
+
+ override def appendToExpression(strBuilder: StringBuilder): Unit = {
+    // Pad with a dummy token to avoid crashing the region server:
+    // in DynamicLogicExpressionBuilder.build the command is always retrieved from offset + 1, i.e.
+    //   val command = expressionArray(offSet + 1)
+    // so we pad the output so that `Pass` lands on the expected offset.
+ strBuilder.append("dummy Pass -1")
+ }
+}
+
+@InterfaceAudience.Private
+object DynamicLogicExpressionBuilder {
+ def build(expressionString: String, encoder: BytesEncoder): DynamicLogicExpression = {
+
+ val expressionAndOffset = build(expressionString.split(' '), 0, encoder)
+ expressionAndOffset._1
+ }
+
+ private def build(
+ expressionArray: Array[String],
+ offSet: Int,
+ encoder: BytesEncoder): (DynamicLogicExpression, Int) = {
+ val expr = {
+ if (expressionArray(offSet).equals("(")) {
+ val left = build(expressionArray, offSet + 1, encoder)
+ val right = build(expressionArray, left._2 + 1, encoder)
+ if (expressionArray(left._2).equals("AND")) {
+ (new AndLogicExpression(left._1, right._1), right._2 + 1)
+ } else if (expressionArray(left._2).equals("OR")) {
+ (new OrLogicExpression(left._1, right._1), right._2 + 1)
+ } else {
+ throw new Throwable("Unknown gate:" + expressionArray(left._2))
+ }
+ } else {
+ val command = expressionArray(offSet + 1)
+ if (command.equals("<")) {
+ (
+ new LessThanLogicExpression(expressionArray(offSet), expressionArray(offSet + 2).toInt),
+ offSet + 3)
+ } else if (command.equals("<=")) {
+ (
+ new LessThanOrEqualLogicExpression(
+ expressionArray(offSet),
+ expressionArray(offSet + 2).toInt),
+ offSet + 3)
+ } else if (command.equals(">")) {
+ (
+ new GreaterThanLogicExpression(
+ expressionArray(offSet),
+ expressionArray(offSet + 2).toInt),
+ offSet + 3)
+ } else if (command.equals(">=")) {
+ (
+ new GreaterThanOrEqualLogicExpression(
+ expressionArray(offSet),
+ expressionArray(offSet + 2).toInt),
+ offSet + 3)
+ } else if (command.equals("==")) {
+ (
+ new EqualLogicExpression(
+ expressionArray(offSet),
+ expressionArray(offSet + 2).toInt,
+ false),
+ offSet + 3)
+ } else if (command.equals("!=")) {
+ (
+ new EqualLogicExpression(
+ expressionArray(offSet),
+ expressionArray(offSet + 2).toInt,
+ true),
+ offSet + 3)
+ } else if (command.equals("startsWith")) {
+ (
+ new StartsWithLogicExpression(
+ expressionArray(offSet),
+ expressionArray(offSet + 2).toInt),
+ offSet + 3)
+ } else if (command.equals("isNull")) {
+ (new IsNullLogicExpression(expressionArray(offSet), false), offSet + 2)
+ } else if (command.equals("isNotNull")) {
+ (new IsNullLogicExpression(expressionArray(offSet), true), offSet + 2)
+ } else if (command.equals("Pass")) {
+ (new PassThroughLogicExpression, offSet + 3)
+ } else {
+ throw new Throwable("Unknown logic command:" + command)
+ }
+ }
+ }
+ expr._1.setEncoder(encoder)
+ expr
+ }
+}
diff --git a/spark4/hbase-spark4/src/main/scala/org/apache/hadoop/hbase/spark/FamiliesQualifiersValues.scala b/spark4/hbase-spark4/src/main/scala/org/apache/hadoop/hbase/spark/FamiliesQualifiersValues.scala
new file mode 100644
index 00000000..93410a42
--- /dev/null
+++ b/spark4/hbase-spark4/src/main/scala/org/apache/hadoop/hbase/spark/FamiliesQualifiersValues.scala
@@ -0,0 +1,67 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.hadoop.hbase.spark
+
+import java.util
+import org.apache.yetus.audience.InterfaceAudience;
+
+/**
+ * This object is a clean way to store and sort all cells that will be bulk
+ * loaded into a single row
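+ *
+ * A minimal sketch (the family, qualifier and value bytes are illustrative):
+ * {{{
+ *   val row = new FamiliesQualifiersValues
+ *   row += (Bytes.toBytes("cf"), Bytes.toBytes("q1"), Bytes.toBytes("v1"))
+ *   row.add(Bytes.toBytes("cf"), Bytes.toBytes("q2"), Bytes.toBytes("v2"))
+ *   // familyMap now holds one family ("cf") with two qualifiers, kept sorted
+ * }}}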
+ */
+@InterfaceAudience.Public
+class FamiliesQualifiersValues extends Serializable {
+ // Tree maps are used because we need the results to
+ // be sorted when we read them
+ val familyMap = new util.TreeMap[ByteArrayWrapper, util.TreeMap[ByteArrayWrapper, Array[Byte]]]()
+
+  // Normally a row has more columns than column families,
+  // so this wrapper is reused for column family look-ups
+ val reusableWrapper = new ByteArrayWrapper(null)
+
+ /**
+ * Adds a new cell to an existing row
+ * @param family HBase column family
+ * @param qualifier HBase column qualifier
+ * @param value HBase cell value
+ */
+ def +=(family: Array[Byte], qualifier: Array[Byte], value: Array[Byte]): Unit = {
+
+ reusableWrapper.value = family
+
+ var qualifierValues = familyMap.get(reusableWrapper)
+
+ if (qualifierValues == null) {
+ qualifierValues = new util.TreeMap[ByteArrayWrapper, Array[Byte]]()
+ familyMap.put(new ByteArrayWrapper(family), qualifierValues)
+ }
+
+ qualifierValues.put(new ByteArrayWrapper(qualifier), value)
+ }
+
+ /**
+ * A wrapper for "+=" method above, can be used by Java
+ * @param family HBase column family
+ * @param qualifier HBase column qualifier
+ * @param value HBase cell value
+ */
+ def add(family: Array[Byte], qualifier: Array[Byte], value: Array[Byte]): Unit = {
+ this += (family, qualifier, value)
+ }
+}
diff --git a/spark4/hbase-spark4/src/main/scala/org/apache/hadoop/hbase/spark/FamilyHFileWriteOptions.scala b/spark4/hbase-spark4/src/main/scala/org/apache/hadoop/hbase/spark/FamilyHFileWriteOptions.scala
new file mode 100644
index 00000000..05117c51
--- /dev/null
+++ b/spark4/hbase-spark4/src/main/scala/org/apache/hadoop/hbase/spark/FamilyHFileWriteOptions.scala
@@ -0,0 +1,39 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.hadoop.hbase.spark
+
+import java.io.Serializable
+import org.apache.yetus.audience.InterfaceAudience;
+
+/**
+ * This object will hold optional data for how a given column family's
+ * writer will work
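+ *
+ * A minimal sketch of registering per-family options for a bulk load (the
+ * family name and option values are illustrative):
+ * {{{
+ *   val options = new util.HashMap[Array[Byte], FamilyHFileWriteOptions]()
+ *   options.put(
+ *     Bytes.toBytes("cf"),
+ *     new FamilyHFileWriteOptions("GZ", "ROW", 65536, "FAST_DIFF"))
+ *   // pass `options` as familyHFileWriteOptionsMap to HBaseContext.bulkLoad
+ * }}}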
+ *
+ * @param compression String to define the Compression to be used in the HFile
+ * @param bloomType String to define the bloom type to be used in the HFile
+ * @param blockSize The block size to be used in the HFile
+ * @param dataBlockEncoding String to define the data block encoding to be used
+ * in the HFile
+ */
+@InterfaceAudience.Public
+class FamilyHFileWriteOptions(
+ val compression: String,
+ val bloomType: String,
+ val blockSize: Int,
+ val dataBlockEncoding: String)
+ extends Serializable
diff --git a/spark4/hbase-spark4/src/main/scala/org/apache/hadoop/hbase/spark/HBaseConnectionCache.scala b/spark4/hbase-spark4/src/main/scala/org/apache/hadoop/hbase/spark/HBaseConnectionCache.scala
new file mode 100644
index 00000000..812975f3
--- /dev/null
+++ b/spark4/hbase-spark4/src/main/scala/org/apache/hadoop/hbase/spark/HBaseConnectionCache.scala
@@ -0,0 +1,273 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.hadoop.hbase.spark
+
+import java.io.IOException
+import org.apache.hadoop.conf.Configuration
+import org.apache.hadoop.hbase.HConstants
+import org.apache.hadoop.hbase.TableName
+import org.apache.hadoop.hbase.client.Admin
+import org.apache.hadoop.hbase.client.Connection
+import org.apache.hadoop.hbase.client.ConnectionFactory
+import org.apache.hadoop.hbase.client.RegionLocator
+import org.apache.hadoop.hbase.client.Table
+import org.apache.hadoop.hbase.ipc.RpcControllerFactory
+import org.apache.hadoop.hbase.security.User
+import org.apache.hadoop.hbase.security.UserProvider
+import org.apache.hadoop.hbase.spark.datasources.HBaseSparkConf
+import org.apache.yetus.audience.InterfaceAudience
+import scala.collection.mutable
+
+@InterfaceAudience.Private
+private[spark] object HBaseConnectionCache extends Logging {
+
+ // A hashmap of Spark-HBase connections. Key is HBaseConnectionKey.
+ val connectionMap = new mutable.HashMap[HBaseConnectionKey, SmartConnection]()
+
+ val cacheStat = HBaseConnectionCacheStat(0, 0, 0)
+
+ // in milliseconds
+ private final val DEFAULT_TIME_OUT: Long = HBaseSparkConf.DEFAULT_CONNECTION_CLOSE_DELAY
+ private var timeout = DEFAULT_TIME_OUT
+ private var closed: Boolean = false
+
+ var housekeepingThread = new Thread(new Runnable {
+ override def run() {
+ while (true) {
+ try {
+ Thread.sleep(timeout)
+ } catch {
+ case e: InterruptedException =>
+ // setTimeout() and close() may interrupt the sleep and it's safe
+ // to ignore the exception
+ }
+ if (closed)
+ return
+ performHousekeeping(false)
+ }
+ }
+ })
+ housekeepingThread.setDaemon(true)
+ housekeepingThread.start()
+
+ def getStat: HBaseConnectionCacheStat = {
+ connectionMap.synchronized {
+ cacheStat.numActiveConnections = connectionMap.size
+ cacheStat.copy()
+ }
+ }
+
+ def close(): Unit = {
+ try {
+ connectionMap.synchronized {
+ if (closed)
+ return
+ closed = true
+ housekeepingThread.interrupt()
+ housekeepingThread = null
+ HBaseConnectionCache.performHousekeeping(true)
+ }
+ } catch {
+ case e: Exception => logWarning("Error in finalHouseKeeping", e)
+ }
+ }
+
+ def performHousekeeping(forceClean: Boolean) = {
+ val tsNow: Long = System.currentTimeMillis()
+ connectionMap.synchronized {
+ connectionMap.foreach {
+ x =>
+ {
+ if (x._2.refCount < 0) {
+ logError(s"Bug to be fixed: negative refCount of connection ${x._2}")
+ }
+
+ if (forceClean || ((x._2.refCount <= 0) && (tsNow - x._2.timestamp > timeout))) {
+ try {
+ x._2.connection.close()
+ } catch {
+ case e: IOException => logWarning(s"Fail to close connection ${x._2}", e)
+ }
+ connectionMap.remove(x._1)
+ }
+ }
+ }
+ }
+ }
+
+ // For testing purpose only
+ def getConnection(key: HBaseConnectionKey, conn: => Connection): SmartConnection = {
+ connectionMap.synchronized {
+ if (closed)
+ return null
+ cacheStat.numTotalRequests += 1
+ val sc = connectionMap.getOrElseUpdate(
+ key, {
+ cacheStat.numActualConnectionsCreated += 1
+ new SmartConnection(conn)
+ })
+ sc.refCount += 1
+ sc
+ }
+ }
+
+ def getConnection(conf: Configuration): SmartConnection =
+ getConnection(new HBaseConnectionKey(conf), ConnectionFactory.createConnection(conf))
+
+ // For testing purpose only
+ def setTimeout(to: Long): Unit = {
+ connectionMap.synchronized {
+ if (closed)
+ return
+ timeout = to
+ housekeepingThread.interrupt()
+ }
+ }
+}
+
+@InterfaceAudience.Private
+private[hbase] case class SmartConnection(
+ connection: Connection,
+ var refCount: Int = 0,
+ var timestamp: Long = 0) {
+ def getTable(tableName: TableName): Table = connection.getTable(tableName)
+ def getRegionLocator(tableName: TableName): RegionLocator = connection.getRegionLocator(tableName)
+ def isClosed: Boolean = connection.isClosed
+ def getAdmin: Admin = connection.getAdmin
+ def close() = {
+ HBaseConnectionCache.connectionMap.synchronized {
+ refCount -= 1
+ if (refCount <= 0)
+ timestamp = System.currentTimeMillis()
+ }
+ }
+}
+
+/**
+ * Denotes a unique key to an HBase Connection instance.
+ * Please refer to 'org.apache.hadoop.hbase.client.HConnectionKey'.
+ *
+ * In essence, this class captures the properties in Configuration
+ * that may be used in the process of establishing a connection.
+ */
+@InterfaceAudience.Private
+class HBaseConnectionKey(c: Configuration) extends Logging {
+ val conf: Configuration = c
+ val CONNECTION_PROPERTIES: Array[String] = Array[String](
+ HConstants.ZOOKEEPER_QUORUM,
+ HConstants.ZOOKEEPER_ZNODE_PARENT,
+ HConstants.ZOOKEEPER_CLIENT_PORT,
+ HConstants.HBASE_CLIENT_PAUSE,
+ HConstants.HBASE_CLIENT_RETRIES_NUMBER,
+ HConstants.HBASE_RPC_TIMEOUT_KEY,
+ HConstants.HBASE_META_SCANNER_CACHING,
+ HConstants.HBASE_CLIENT_INSTANCE_ID,
+ HConstants.RPC_CODEC_CONF_KEY,
+ HConstants.USE_META_REPLICAS,
+ RpcControllerFactory.CUSTOM_CONTROLLER_CONF_KEY)
+
+ var username: String = _
+ var m_properties = mutable.HashMap.empty[String, String]
+ if (conf != null) {
+ for (property <- CONNECTION_PROPERTIES) {
+ val value: String = conf.get(property)
+ if (value != null) {
+ m_properties.+=((property, value))
+ }
+ }
+ try {
+ val provider: UserProvider = UserProvider.instantiate(conf)
+ val currentUser: User = provider.getCurrent
+ if (currentUser != null) {
+ username = currentUser.getName
+ }
+ } catch {
+ case e: IOException => {
+ logWarning("Error obtaining current user, skipping username in HBaseConnectionKey", e)
+ }
+ }
+ }
+
+ // make 'properties' immutable
+ val properties = m_properties.toMap
+
+ override def hashCode: Int = {
+ val prime: Int = 31
+ var result: Int = 1
+ if (username != null) {
+ result = username.hashCode
+ }
+ for (property <- CONNECTION_PROPERTIES) {
+ val value: Option[String] = properties.get(property)
+ if (value.isDefined) {
+ result = prime * result + value.hashCode
+ }
+ }
+ result
+ }
+
+ override def equals(obj: Any): Boolean = {
+ if (obj == null) return false
+ if (getClass ne obj.getClass) return false
+ val that: HBaseConnectionKey = obj.asInstanceOf[HBaseConnectionKey]
+ if (this.username != null && !(this.username == that.username)) {
+ return false
+ } else if (this.username == null && that.username != null) {
+ return false
+ }
+ if (this.properties == null) {
+ if (that.properties != null) {
+ return false
+ }
+ } else {
+ if (that.properties == null) {
+ return false
+ }
+ var flag: Boolean = true
+ for (property <- CONNECTION_PROPERTIES) {
+ val thisValue: Option[String] = this.properties.get(property)
+ val thatValue: Option[String] = that.properties.get(property)
+ flag = true
+ if (thisValue eq thatValue) {
+ flag = false // continue, so make flag be false
+ }
+ if (flag && (thisValue == null || !(thisValue == thatValue))) {
+ return false
+ }
+ }
+ }
+ true
+ }
+
+ override def toString: String = {
+ "HBaseConnectionKey{" + "properties=" + properties + ", username='" + username + '\'' + '}'
+ }
+}
+
+/**
+ * To log the state of 'HBaseConnectionCache'
+ *
+ * @param numTotalRequests number of total connection requests to the cache
+ * @param numActualConnectionsCreated number of actual HBase connections the cache ever created
+ * @param numActiveConnections number of current alive HBase connections the cache is holding
+ */
+@InterfaceAudience.Private
+case class HBaseConnectionCacheStat(
+ var numTotalRequests: Long,
+ var numActualConnectionsCreated: Long,
+ var numActiveConnections: Long)
diff --git a/spark4/hbase-spark4/src/main/scala/org/apache/hadoop/hbase/spark/HBaseContext.scala b/spark4/hbase-spark4/src/main/scala/org/apache/hadoop/hbase/spark/HBaseContext.scala
new file mode 100644
index 00000000..e17051db
--- /dev/null
+++ b/spark4/hbase-spark4/src/main/scala/org/apache/hadoop/hbase/spark/HBaseContext.scala
@@ -0,0 +1,1134 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.hadoop.hbase.spark
+
+import java.io._
+import java.net.InetSocketAddress
+import java.util
+import java.util.UUID
+import javax.management.openmbean.KeyAlreadyExistsException
+import org.apache.hadoop.conf.Configuration
+import org.apache.hadoop.fs.{FileAlreadyExistsException, FileSystem, Path}
+import org.apache.hadoop.hbase._
+import org.apache.hadoop.hbase.client._
+import org.apache.hadoop.hbase.fs.HFileSystem
+import org.apache.hadoop.hbase.io.ImmutableBytesWritable
+import org.apache.hadoop.hbase.io.compress.Compression
+import org.apache.hadoop.hbase.io.compress.Compression.Algorithm
+import org.apache.hadoop.hbase.io.encoding.DataBlockEncoding
+import org.apache.hadoop.hbase.io.hfile.{CacheConfig, HFile, HFileContextBuilder, HFileWriterImpl}
+import org.apache.hadoop.hbase.mapreduce.{IdentityTableMapper, TableInputFormat, TableMapReduceUtil}
+import org.apache.hadoop.hbase.regionserver.{BloomType, HStoreFile, StoreFileWriter, StoreUtils}
+import org.apache.hadoop.hbase.spark.HBaseRDDFunctions._
+import org.apache.hadoop.hbase.util.{Bytes, ChecksumType}
+import org.apache.hadoop.mapred.JobConf
+import org.apache.hadoop.mapreduce.Job
+import org.apache.hadoop.security.UserGroupInformation
+import org.apache.hadoop.security.UserGroupInformation.AuthenticationMethod
+import org.apache.spark.{SerializableWritable, SparkContext}
+import org.apache.spark.broadcast.Broadcast
+import org.apache.spark.rdd.RDD
+import org.apache.spark.streaming.dstream.DStream
+import org.apache.yetus.audience.InterfaceAudience
+import scala.collection.mutable
+import scala.reflect.ClassTag
+
+/**
+ * HBaseContext is a façade for HBase operations
+ * like bulk put, get, increment, delete, and scan
+ *
+ * HBaseContext takes responsibility for disseminating
+ * the configuration information to the workers
+ * and for managing the life cycle of Connections.
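+ *
+ * A minimal construction sketch (the local master is illustrative and
+ * hbase-site.xml is assumed to be visible to the client configuration):
+ * {{{
+ *   val sc = new SparkContext("local", "example")
+ *   val hbaseContext = new HBaseContext(sc, HBaseConfiguration.create())
+ * }}}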
+ */
+@InterfaceAudience.Public
+class HBaseContext(
+ @transient val sc: SparkContext,
+ @transient val config: Configuration,
+ val tmpHdfsConfgFile: String = null)
+ extends Serializable
+ with Logging {
+
+ @transient var tmpHdfsConfiguration: Configuration = config
+ @transient var appliedCredentials = false
+ @transient val job = Job.getInstance(config)
+ TableMapReduceUtil.initCredentials(job)
+ val broadcastedConf = sc.broadcast(new SerializableWritable(config))
+
+ LatestHBaseContextCache.latest = this
+
+ if (tmpHdfsConfgFile != null && config != null) {
+ val fs = FileSystem.newInstance(config)
+ val tmpPath = new Path(tmpHdfsConfgFile)
+ if (!fs.exists(tmpPath)) {
+ val outputStream = fs.create(tmpPath)
+ config.write(outputStream)
+ outputStream.close()
+ } else {
+ logWarning("tmpHdfsConfigDir " + tmpHdfsConfgFile + " exist!!")
+ }
+ }
+
+ /**
+ * A simple enrichment of the traditional Spark RDD foreachPartition.
+ * This function differs from the original in that it offers the
+   * developer access to an already connected Connection object
+ *
+ * Note: Do not close the Connection object. All Connection
+ * management is handled outside this method
+ *
+ * @param rdd Original RDD with data to iterate over
+   * @param f Function to be given an iterator to iterate through
+ * the RDD values and a Connection object to interact
+ * with HBase
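+   *
+   * A minimal usage sketch, assuming an existing HBaseContext `hbaseContext`
+   * and an RDD of row keys (the table name and Get usage are illustrative):
+   * {{{
+   *   hbaseContext.foreachPartition(
+   *     rdd,
+   *     (it: Iterator[Array[Byte]], connection: Connection) => {
+   *       val table = connection.getTable(TableName.valueOf("t1"))
+   *       it.foreach(rowKey => table.get(new Get(rowKey)))
+   *       table.close()
+   *     })
+   * }}}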
+ */
+ def foreachPartition[T](rdd: RDD[T], f: (Iterator[T], Connection) => Unit): Unit = {
+ rdd.foreachPartition(it => hbaseForeachPartition(broadcastedConf, it, f))
+ }
+
+ /**
+ * A simple enrichment of the traditional Spark Streaming dStream foreach
+ * This function differs from the original in that it offers the
+   * developer access to an already connected Connection object
+ *
+ * Note: Do not close the Connection object. All Connection
+ * management is handled outside this method
+ *
+ * @param dstream Original DStream with data to iterate over
+   * @param f Function to be given an iterator to iterate through
+ * the DStream values and a Connection object to
+ * interact with HBase
+ */
+ def foreachPartition[T](dstream: DStream[T], f: (Iterator[T], Connection) => Unit): Unit = {
+ dstream.foreachRDD(
+ (rdd, time) => {
+ foreachPartition(rdd, f)
+ })
+ }
+
+ /**
+ * A simple enrichment of the traditional Spark RDD mapPartition.
+ * This function differs from the original in that it offers the
+   * developer access to an already connected Connection object
+ *
+ * Note: Do not close the Connection object. All Connection
+ * management is handled outside this method
+ *
+ * @param rdd Original RDD with data to iterate over
+   * @param mp Function to be given an iterator to iterate through
+ * the RDD values and a Connection object to interact
+ * with HBase
+   * @return Returns a new RDD generated by the user-defined
+   * function, just like a normal mapPartition
+ */
+ def mapPartitions[T, R: ClassTag](
+ rdd: RDD[T],
+ mp: (Iterator[T], Connection) => Iterator[R]): RDD[R] = {
+
+ rdd.mapPartitions[R](it => hbaseMapPartition[T, R](broadcastedConf, it, mp))
+
+ }
+
+ /**
+ * A simple enrichment of the traditional Spark Streaming DStream
+ * foreachPartition.
+ *
+ * This function differs from the original in that it offers the
+   * developer access to an already connected Connection object
+ *
+ * Note: Do not close the Connection object. All Connection
+ * management is handled outside this method
+ *
+   * Note: Make sure to partition correctly to avoid memory issues when
+ * getting data from HBase
+ *
+ * @param dstream Original DStream with data to iterate over
+   * @param f Function to be given an iterator to iterate through
+ * the DStream values and a Connection object to
+ * interact with HBase
+ */
+ def streamForeachPartition[T](dstream: DStream[T], f: (Iterator[T], Connection) => Unit): Unit = {
+
+ dstream.foreachRDD(rdd => this.foreachPartition(rdd, f))
+ }
+
+ /**
+ * A simple enrichment of the traditional Spark Streaming DStream
+ * mapPartition.
+ *
+ * This function differs from the original in that it offers the
+   * developer access to an already connected Connection object
+ *
+ * Note: Do not close the Connection object. All Connection
+ * management is handled outside this method
+ *
+   * Note: Make sure to partition correctly to avoid memory issues when
+ * getting data from HBase
+ *
+ * @param dstream Original DStream with data to iterate over
+   * @param f Function to be given an iterator to iterate through
+ * the DStream values and a Connection object to
+ * interact with HBase
+   * @return Returns a new DStream generated by the user-defined
+   * function, just like a normal mapPartition
+ */
+ def streamMapPartitions[T, U: ClassTag](
+ dstream: DStream[T],
+ f: (Iterator[T], Connection) => Iterator[U]): DStream[U] = {
+ dstream.mapPartitions(it => hbaseMapPartition[T, U](broadcastedConf, it, f))
+ }
+
+ /**
+ * A simple abstraction over the HBaseContext.foreachPartition method.
+ *
+   * It allows a user to take an RDD,
+   * generate Puts, and send them to HBase.
+ * The complexity of managing the Connection is
+ * removed from the developer
+ *
+ * @param rdd Original RDD with data to iterate over
+ * @param tableName The name of the table to put into
+ * @param f Function to convert a value in the RDD to a HBase Put
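+   *
+   * A minimal usage sketch, assuming an existing SparkContext `sc` and
+   * HBaseContext `hbaseContext` (table, family and qualifier are illustrative):
+   * {{{
+   *   val rdd = sc.parallelize(Seq((Bytes.toBytes("r1"), Bytes.toBytes("v1"))))
+   *   hbaseContext.bulkPut[(Array[Byte], Array[Byte])](
+   *     rdd,
+   *     TableName.valueOf("t1"),
+   *     t => new Put(t._1).addColumn(Bytes.toBytes("cf"), Bytes.toBytes("q"), t._2))
+   * }}}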
+ */
+ def bulkPut[T](rdd: RDD[T], tableName: TableName, f: (T) => Put) {
+
+ val tName = tableName.getName
+ rdd.foreachPartition(
+ it =>
+ hbaseForeachPartition[T](
+ broadcastedConf,
+ it,
+ (iterator, connection) => {
+ val m = connection.getBufferedMutator(TableName.valueOf(tName))
+ iterator.foreach(T => m.mutate(f(T)))
+ m.flush()
+ m.close()
+ }))
+ }
+
+ def applyCreds[T]() {
+ if (!appliedCredentials) {
+ appliedCredentials = true
+
+ @transient val ugi = UserGroupInformation.getCurrentUser
+ // specify that this is a proxy user
+ ugi.setAuthenticationMethod(AuthenticationMethod.PROXY)
+ }
+ }
+
+ /**
+ * A simple abstraction over the HBaseContext.streamMapPartition method.
+ *
+   * It allows a user to take a DStream,
+   * generate Puts, and send them to HBase.
+ *
+ * The complexity of managing the Connection is
+ * removed from the developer
+ *
+ * @param dstream Original DStream with data to iterate over
+ * @param tableName The name of the table to put into
+ * @param f Function to convert a value in
+ * the DStream to a HBase Put
+ */
+ def streamBulkPut[T](dstream: DStream[T], tableName: TableName, f: (T) => Put) = {
+ val tName = tableName.getName
+ dstream.foreachRDD(
+ (rdd, time) => {
+ bulkPut(rdd, TableName.valueOf(tName), f)
+ })
+ }
+
+ /**
+ * A simple abstraction over the HBaseContext.foreachPartition method.
+ *
+   * It allows a user to take an RDD, generate Deletes,
+   * and send them to HBase. The complexity of managing the Connection is
+ * removed from the developer
+ *
+ * @param rdd Original RDD with data to iterate over
+ * @param tableName The name of the table to delete from
+ * @param f Function to convert a value in the RDD to a
+   * HBase Delete
+   * @param batchSize The number of deletes to batch before sending to HBase
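+   *
+   * A minimal usage sketch, assuming an existing SparkContext `sc` and
+   * HBaseContext `hbaseContext` (row keys and batch size are illustrative):
+   * {{{
+   *   val rowKeys = sc.parallelize(Seq(Bytes.toBytes("r1"), Bytes.toBytes("r2")))
+   *   hbaseContext.bulkDelete[Array[Byte]](
+   *     rowKeys,
+   *     TableName.valueOf("t1"),
+   *     rowKey => new Delete(rowKey),
+   *     4)
+   * }}}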
+ */
+ def bulkDelete[T](rdd: RDD[T], tableName: TableName, f: (T) => Delete, batchSize: Integer) {
+ bulkMutation(rdd, tableName, f, batchSize)
+ }
+
+ /**
+ * A simple abstraction over the HBaseContext.streamBulkMutation method.
+ *
+   * It allows a user to take a DStream,
+   * generate Deletes, and send them to HBase.
+ *
+ * The complexity of managing the Connection is
+ * removed from the developer
+ *
+ * @param dstream Original DStream with data to iterate over
+ * @param tableName The name of the table to delete from
+ * @param f function to convert a value in the DStream to a
+ * HBase Delete
+ * @param batchSize The number of deletes to batch before sending to HBase
+ */
+ def streamBulkDelete[T](
+ dstream: DStream[T],
+ tableName: TableName,
+ f: (T) => Delete,
+ batchSize: Integer) = {
+ streamBulkMutation(dstream, tableName, f, batchSize)
+ }
+
+ /**
+   * Underlying function to support all bulk mutations
+ *
+ * May be opened up if requested
+ */
+ private def bulkMutation[T](
+ rdd: RDD[T],
+ tableName: TableName,
+ f: (T) => Mutation,
+ batchSize: Integer) {
+
+ val tName = tableName.getName
+ rdd.foreachPartition(
+ it =>
+ hbaseForeachPartition[T](
+ broadcastedConf,
+ it,
+ (iterator, connection) => {
+ val table = connection.getTable(TableName.valueOf(tName))
+ val mutationList = new java.util.ArrayList[Mutation]
+ iterator.foreach(
+ T => {
+ mutationList.add(f(T))
+ if (mutationList.size >= batchSize) {
+ table.batch(mutationList, null)
+ mutationList.clear()
+ }
+ })
+ if (mutationList.size() > 0) {
+ table.batch(mutationList, null)
+ mutationList.clear()
+ }
+ table.close()
+ }))
+ }
+
+ /**
+   * Underlying function to support all bulk streaming mutations
+ *
+ * May be opened up if requested
+ */
+ private def streamBulkMutation[T](
+ dstream: DStream[T],
+ tableName: TableName,
+ f: (T) => Mutation,
+ batchSize: Integer) = {
+ val tName = tableName.getName
+ dstream.foreachRDD(
+ (rdd, time) => {
+ bulkMutation(rdd, TableName.valueOf(tName), f, batchSize)
+ })
+ }
+
+ /**
+ * A simple abstraction over the HBaseContext.mapPartition method.
+ *
+   * It allows a user to take an RDD and generate a
+   * new RDD based on Gets and the results they bring back from HBase
+ *
+ * @param rdd Original RDD with data to iterate over
+ * @param tableName The name of the table to get from
+ * @param makeGet function to convert a value in the RDD to a
+ * HBase Get
+ * @param convertResult This will convert the HBase Result object to
+   * whatever the user wants to put in the resulting
+   * RDD
+   * @return New RDD that is created by the Get to HBase
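+   *
+   * A minimal usage sketch, assuming an existing SparkContext `sc` and
+   * HBaseContext `hbaseContext`; the conversion simply extracts the row key:
+   * {{{
+   *   val rowKeys = sc.parallelize(Seq(Bytes.toBytes("r1"), Bytes.toBytes("r2")))
+   *   val values = hbaseContext.bulkGet[Array[Byte], String](
+   *     TableName.valueOf("t1"),
+   *     2,
+   *     rowKeys,
+   *     rowKey => new Get(rowKey),
+   *     (result: Result) => Bytes.toString(result.getRow))
+   * }}}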
+ */
+ def bulkGet[T, U: ClassTag](
+ tableName: TableName,
+ batchSize: Integer,
+ rdd: RDD[T],
+ makeGet: (T) => Get,
+ convertResult: (Result) => U): RDD[U] = {
+
+ val getMapPartition = new GetMapPartition(tableName, batchSize, makeGet, convertResult)
+
+ rdd.mapPartitions[U](it => hbaseMapPartition[T, U](broadcastedConf, it, getMapPartition.run))
+ }
+
+ /**
+ * A simple abstraction over the HBaseContext.streamMap method.
+ *
+   * It allows a user to take a DStream and
+   * generate a new DStream based on Gets and the results
+ * they bring back from HBase
+ *
+ * @param tableName The name of the table to get from
+ * @param batchSize The number of Gets to be sent in a single batch
+ * @param dStream Original DStream with data to iterate over
+ * @param makeGet Function to convert a value in the DStream to a
+ * HBase Get
+ * @param convertResult This will convert the HBase Result object to
+   * whatever the user wants to put in the resulting
+ * DStream
+ * @return A new DStream that is created by the Get to HBase
+ */
+ def streamBulkGet[T, U: ClassTag](
+ tableName: TableName,
+ batchSize: Integer,
+ dStream: DStream[T],
+ makeGet: (T) => Get,
+ convertResult: (Result) => U): DStream[U] = {
+
+ val getMapPartition = new GetMapPartition(tableName, batchSize, makeGet, convertResult)
+
+ dStream.mapPartitions[U](
+ it => hbaseMapPartition[T, U](broadcastedConf, it, getMapPartition.run))
+ }
+
+ /**
+ * This function will use the native HBase TableInputFormat with the
+ * given scan object to generate a new RDD
+ *
+ * @param tableName the name of the table to scan
+ * @param scan the HBase scan object to use to read data from HBase
+ * @param f function to convert a Result object from HBase into
+ * what the user wants in the final generated RDD
+ * @return new RDD with results from scan
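+   *
+   * A minimal usage sketch, assuming an existing HBaseContext `hbaseContext`
+   * (the table name and scan bounds are illustrative):
+   * {{{
+   *   val scan = new Scan().withStartRow(Bytes.toBytes("r1")).withStopRow(Bytes.toBytes("r9"))
+   *   val rowKeys = hbaseContext.hbaseRDD[String](
+   *     TableName.valueOf("t1"),
+   *     scan,
+   *     (t: (ImmutableBytesWritable, Result)) => Bytes.toString(t._2.getRow))
+   * }}}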
+ */
+ def hbaseRDD[U: ClassTag](
+ tableName: TableName,
+ scan: Scan,
+ f: ((ImmutableBytesWritable, Result)) => U): RDD[U] = {
+
+ val job: Job = Job.getInstance(getConf(broadcastedConf))
+
+ TableMapReduceUtil.initCredentials(job)
+ TableMapReduceUtil.initTableMapperJob(
+ tableName,
+ scan,
+ classOf[IdentityTableMapper],
+ null,
+ null,
+ job)
+
+ val jconf = new JobConf(job.getConfiguration)
+ val jobCreds = jconf.getCredentials()
+ UserGroupInformation.setConfiguration(sc.hadoopConfiguration)
+ jobCreds.mergeAll(UserGroupInformation.getCurrentUser().getCredentials())
+
+ new NewHBaseRDD(
+ sc,
+ classOf[TableInputFormat],
+ classOf[ImmutableBytesWritable],
+ classOf[Result],
+ job.getConfiguration,
+ this).map(f)
+ }
+
+ /**
+   * An overloaded version of HBaseContext hbaseRDD that defines the
+ * type of the resulting RDD
+ *
+ * @param tableName the name of the table to scan
+ * @param scans the HBase scan object to use to read data from HBase
+ * @return New RDD with results from scan
+ */
+ def hbaseRDD(tableName: TableName, scans: Scan): RDD[(ImmutableBytesWritable, Result)] = {
+
+ hbaseRDD[(ImmutableBytesWritable, Result)](
+ tableName,
+ scans,
+ (r: (ImmutableBytesWritable, Result)) => r)
+ }
+
+ /**
+   * Underlying wrapper for all foreach functions in HBaseContext
+ */
+ private def hbaseForeachPartition[T](
+ configBroadcast: Broadcast[SerializableWritable[Configuration]],
+ it: Iterator[T],
+ f: (Iterator[T], Connection) => Unit) = {
+
+ val config = getConf(configBroadcast)
+
+ applyCreds
+ // specify that this is a proxy user
+ val smartConn = HBaseConnectionCache.getConnection(config)
+ try {
+ f(it, smartConn.connection)
+ } finally {
+ if (smartConn != null) smartConn.close()
+ }
+ }
+
+ private def getConf(
+ configBroadcast: Broadcast[SerializableWritable[Configuration]]): Configuration = {
+
+ if (tmpHdfsConfiguration == null && tmpHdfsConfgFile != null) {
+ val fs = FileSystem.newInstance(sc.hadoopConfiguration)
+ val inputStream = fs.open(new Path(tmpHdfsConfgFile))
+ tmpHdfsConfiguration = new Configuration(false)
+ tmpHdfsConfiguration.readFields(inputStream)
+ inputStream.close()
+ }
+
+ if (tmpHdfsConfiguration == null) {
+ try {
+ tmpHdfsConfiguration = configBroadcast.value.value
+ } catch {
+ case ex: Exception => logError("Unable to getConfig from broadcast", ex)
+ }
+ }
+ tmpHdfsConfiguration
+ }
+
+ /**
+   * Underlying wrapper for all mapPartition functions in HBaseContext
+ */
+ private def hbaseMapPartition[K, U](
+ configBroadcast: Broadcast[SerializableWritable[Configuration]],
+ it: Iterator[K],
+ mp: (Iterator[K], Connection) => Iterator[U]): Iterator[U] = {
+
+ val config = getConf(configBroadcast)
+ applyCreds
+
+ val smartConn = HBaseConnectionCache.getConnection(config)
+ try {
+ mp(it, smartConn.connection)
+ } finally {
+ if (smartConn != null) smartConn.close()
+ }
+ }
+
+ /**
+   * Underlying wrapper for all get mapPartition functions in HBaseContext
+ */
+ private class GetMapPartition[T, U: ClassTag](
+ tableName: TableName,
+ batchSize: Integer,
+ makeGet: (T) => Get,
+ convertResult: (Result) => U)
+ extends Serializable {
+
+ val tName = tableName.getName
+
+ def run(iterator: Iterator[T], connection: Connection): Iterator[U] = {
+ val table = connection.getTable(TableName.valueOf(tName))
+
+ val gets = new java.util.ArrayList[Get]()
+ var res = List[U]()
+
+ while (iterator.hasNext) {
+ gets.add(makeGet(iterator.next()))
+
+ if (gets.size() == batchSize) {
+ val results = table.get(gets)
+ res = res ++ results.map(convertResult)
+ gets.clear()
+ }
+ }
+ if (gets.size() > 0) {
+ val results = table.get(gets)
+ res = res ++ results.map(convertResult)
+ gets.clear()
+ }
+ table.close()
+ res.iterator
+ }
+ }
+
+ /**
+ * Produces a ClassTag[T], which is actually just a casted ClassTag[AnyRef].
+ *
+ * This method is used to keep ClassTags out of the external Java API, as
+ * the Java compiler cannot produce them automatically. While this
+ * ClassTag-faking does please the compiler, it can cause problems at runtime
+ * if the Scala API relies on ClassTags for correctness.
+ *
+ * Often, though, a ClassTag[AnyRef] will not lead to incorrect behavior,
+ * just worse performance or security issues.
+ * For instance, an Array of AnyRef can hold any type T, but may lose primitive
+ * specialization.
+ */
+ private[spark] def fakeClassTag[T]: ClassTag[T] = ClassTag.AnyRef.asInstanceOf[ClassTag[T]]
+
+ /**
+ * Spark Implementation of HBase Bulk load for wide rows or when
+ * values are not already combined at the time of the map process
+ *
+ * This will take the content from an existing RDD then sort and shuffle
+ * it with respect to region splits. The result of that sort and shuffle
+ * will be written to HFiles.
+ *
+ * After this function is executed the user will have to call
+ * LoadIncrementalHFiles.doBulkLoad(...) to move the files into HBase
+ *
+ * Also note this version of bulk load is different from past versions in
+ * that it includes the qualifier as part of the sort process. The
+   * reason for this is to be able to support rows with a very large number
+ * of columns.
+ *
+ * @param rdd The RDD we are bulk loading from
+ * @param tableName The HBase table we are loading into
+   * @param flatMap A flatMap function that will make every
+ * row in the RDD
+ * into N cells for the bulk load
+ * @param stagingDir The location on the FileSystem to bulk load into
+ * @param familyHFileWriteOptionsMap Options that will define how the HFile for a
+ * column family is written
+ * @param compactionExclude Compaction excluded for the HFiles
+ * @param maxSize Max size for the HFiles before they roll
+ * @param nowTimeStamp Version timestamp
+ * @tparam T The Type of values in the original RDD
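+   *
+   * A minimal usage sketch, assuming rdd: RDD[(Array[Byte], Array[Byte])] of
+   * (row key, value) pairs and that KeyFamilyQualifier takes (rowKey, family,
+   * qualifier); the table, column layout and staging directory are illustrative:
+   * {{{
+   *   hbaseContext.bulkLoad[(Array[Byte], Array[Byte])](
+   *     rdd,
+   *     TableName.valueOf("t1"),
+   *     t => Iterator(
+   *       (new KeyFamilyQualifier(t._1, Bytes.toBytes("cf"), Bytes.toBytes("q")), t._2)),
+   *     "/tmp/bulkload-staging")
+   * }}}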
+ */
+ def bulkLoad[T](
+ rdd: RDD[T],
+ tableName: TableName,
+ flatMap: (T) => Iterator[(KeyFamilyQualifier, Array[Byte])],
+ stagingDir: String,
+ familyHFileWriteOptionsMap: util.Map[Array[Byte], FamilyHFileWriteOptions] =
+ new util.HashMap[Array[Byte], FamilyHFileWriteOptions],
+ compactionExclude: Boolean = false,
+ maxSize: Long = HConstants.DEFAULT_MAX_FILE_SIZE,
+ nowTimeStamp: Long = System.currentTimeMillis()): Unit = {
+ val stagingPath = new Path(stagingDir)
+ val fs = stagingPath.getFileSystem(config)
+ if (fs.exists(stagingPath)) {
+ throw new FileAlreadyExistsException("Path " + stagingDir + " already exists")
+ }
+ val conn = HBaseConnectionCache.getConnection(config)
+ try {
+ val regionLocator = conn.getRegionLocator(tableName)
+ val startKeys = regionLocator.getStartKeys
+ if (startKeys.length == 0) {
+ logInfo("Table " + tableName.toString + " was not found")
+ }
+ val defaultCompressionStr =
+ config.get("hfile.compression", Compression.Algorithm.NONE.getName)
+ val hfileCompression = HFileWriterImpl
+ .compressionByName(defaultCompressionStr)
+ val tableRawName = tableName.getName
+
+ val familyHFileWriteOptionsMapInternal =
+ new util.HashMap[ByteArrayWrapper, FamilyHFileWriteOptions]
+
+ val entrySetIt = familyHFileWriteOptionsMap.entrySet().iterator()
+
+ while (entrySetIt.hasNext) {
+ val entry = entrySetIt.next()
+ familyHFileWriteOptionsMapInternal.put(new ByteArrayWrapper(entry.getKey), entry.getValue)
+ }
+
+ val regionSplitPartitioner =
+ new BulkLoadPartitioner(startKeys)
+
+ // This is where all the magic happens
+ // Here we are going to do the following things
+          // 1. FlatMap every row in the RDD into key column value tuples
+ // 2. Then we are going to repartition sort and shuffle
+ // 3. Finally we are going to write out our HFiles
+ rdd
+ .flatMap(r => flatMap(r))
+ .repartitionAndSortWithinPartitions(regionSplitPartitioner)
+ .hbaseForeachPartition(
+ this,
+ (it, conn) => {
+
+ val conf = broadcastedConf.value.value
+ val fs = new Path(stagingDir).getFileSystem(conf)
+ val writerMap = new mutable.HashMap[ByteArrayWrapper, WriterLength]
+ var previousRow: Array[Byte] = HConstants.EMPTY_BYTE_ARRAY
+ var rollOverRequested = false
+ val localTableName = TableName.valueOf(tableRawName)
+
+ // Here is where we finally iterate through the data in this partition of the
+ // RDD that has been sorted and partitioned
+ it.foreach {
+ case (keyFamilyQualifier, cellValue: Array[Byte]) =>
+ val wl = writeValueToHFile(
+ keyFamilyQualifier.rowKey,
+ keyFamilyQualifier.family,
+ keyFamilyQualifier.qualifier,
+ cellValue,
+ nowTimeStamp,
+ fs,
+ conn,
+ localTableName,
+ conf,
+ familyHFileWriteOptionsMapInternal,
+ hfileCompression,
+ writerMap,
+ stagingDir)
+
+ rollOverRequested = rollOverRequested || wl.written > maxSize
+
+ // This will only roll if we have at least one column family file that is
+ // bigger then maxSize and we have finished a given row key
+ if (rollOverRequested && Bytes
+ .compareTo(previousRow, keyFamilyQualifier.rowKey) != 0) {
+ rollWriters(fs, writerMap, regionSplitPartitioner, previousRow, compactionExclude)
+ rollOverRequested = false
+ }
+
+ previousRow = keyFamilyQualifier.rowKey
+ }
+ // We have finished all the data so lets close up the writers
+ rollWriters(fs, writerMap, regionSplitPartitioner, previousRow, compactionExclude)
+ rollOverRequested = false
+ })
+ } finally {
+ if (null != conn) conn.close()
+ }
+ }
+
+ /**
+   * Spark implementation of HBase bulk load for short rows, somewhere less than
+   * 1000 columns. This bulk load should be faster for tables with thinner
+   * rows than the other Spark implementation of bulk load, which puts only one
+   * value into a record going into a shuffle
+ *
+ * This will take the content from an existing RDD then sort and shuffle
+ * it with respect to region splits. The result of that sort and shuffle
+ * will be written to HFiles.
+ *
+ * After this function is executed the user will have to call
+ * LoadIncrementalHFiles.doBulkLoad(...) to move the files into HBase
+ *
+ * In this implementation, only the rowKey is given to the shuffle as the key
+ * and all the columns are already linked to the RowKey before the shuffle
+   * stage. The sorting of the qualifiers is done in memory outside of the
+ * shuffle stage
+ *
+ * Also make sure that incoming RDDs only have one record for every row key.
+ *
+ * @param rdd The RDD we are bulk loading from
+ * @param tableName The HBase table we are loading into
+ * @param mapFunction A function that will convert the RDD records to
+ * the key value format used for the shuffle to prep
+ * for writing to the bulk loaded HFiles
+ * @param stagingDir The location on the FileSystem to bulk load into
+ * @param familyHFileWriteOptionsMap Options that will define how the HFile for a
+ * column family is written
+ * @param compactionExclude Compaction excluded for the HFiles
+ * @param maxSize Max size for the HFiles before they roll
+ * @tparam T The Type of values in the original RDD
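+   *
+   * A minimal usage sketch, assuming rdd: RDD[(Array[Byte], Array[Byte])] of
+   * (row key, value) pairs; the table, column layout and staging directory are
+   * illustrative:
+   * {{{
+   *   hbaseContext.bulkLoadThinRows[(Array[Byte], Array[Byte])](
+   *     rdd,
+   *     TableName.valueOf("t1"),
+   *     t => {
+   *       val familyQualifiersValues = new FamiliesQualifiersValues
+   *       familyQualifiersValues += (Bytes.toBytes("cf"), Bytes.toBytes("q"), t._2)
+   *       (new ByteArrayWrapper(t._1), familyQualifiersValues)
+   *     },
+   *     "/tmp/bulkload-staging")
+   * }}}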
+ */
+ def bulkLoadThinRows[T](
+ rdd: RDD[T],
+ tableName: TableName,
+ mapFunction: (T) => (ByteArrayWrapper, FamiliesQualifiersValues),
+ stagingDir: String,
+ familyHFileWriteOptionsMap: util.Map[Array[Byte], FamilyHFileWriteOptions] =
+ new util.HashMap[Array[Byte], FamilyHFileWriteOptions],
+ compactionExclude: Boolean = false,
+ maxSize: Long = HConstants.DEFAULT_MAX_FILE_SIZE): Unit = {
+ val stagingPath = new Path(stagingDir)
+ val fs = stagingPath.getFileSystem(config)
+ if (fs.exists(stagingPath)) {
+ throw new FileAlreadyExistsException("Path " + stagingDir + " already exists")
+ }
+ val conn = HBaseConnectionCache.getConnection(config)
+ try {
+ val regionLocator = conn.getRegionLocator(tableName)
+ val startKeys = regionLocator.getStartKeys
+ if (startKeys.length == 0) {
+ logInfo("Table " + tableName.toString + " was not found")
+ }
+ val defaultCompressionStr =
+ config.get("hfile.compression", Compression.Algorithm.NONE.getName)
+ val defaultCompression = HFileWriterImpl
+ .compressionByName(defaultCompressionStr)
+ val nowTimeStamp = System.currentTimeMillis()
+ val tableRawName = tableName.getName
+
+ val familyHFileWriteOptionsMapInternal =
+ new util.HashMap[ByteArrayWrapper, FamilyHFileWriteOptions]
+
+ val entrySetIt = familyHFileWriteOptionsMap.entrySet().iterator()
+
+ while (entrySetIt.hasNext) {
+ val entry = entrySetIt.next()
+ familyHFileWriteOptionsMapInternal.put(new ByteArrayWrapper(entry.getKey), entry.getValue)
+ }
+
+ val regionSplitPartitioner =
+ new BulkLoadPartitioner(startKeys)
+
+ // This is where all the magic happens
+ // Here we are going to do the following things
+ // 1. Map every row in the RDD into a (rowKey, families/qualifiers/values) tuple
+ // 2. Then we are going to repartition sort and shuffle
+ // 3. Finally we are going to write out our HFiles
+ rdd
+ .map(r => mapFunction(r))
+ .repartitionAndSortWithinPartitions(regionSplitPartitioner)
+ .hbaseForeachPartition(
+ this,
+ (it, conn) => {
+
+ val conf = broadcastedConf.value.value
+ val fs = new Path(stagingDir).getFileSystem(conf)
+ val writerMap = new mutable.HashMap[ByteArrayWrapper, WriterLength]
+ var previousRow: Array[Byte] = HConstants.EMPTY_BYTE_ARRAY
+ var rollOverRequested = false
+ val localTableName = TableName.valueOf(tableRawName)
+
+ // Here is where we finally iterate through the data in this partition of the
+ // RDD that has been sorted and partitioned
+ it.foreach {
+ case (rowKey: ByteArrayWrapper, familiesQualifiersValues: FamiliesQualifiersValues) =>
+ if (Bytes.compareTo(previousRow, rowKey.value) == 0) {
+ throw new KeyAlreadyExistsException(
+ "The following key was sent to the " +
+ "HFile load more then one: " + Bytes.toString(previousRow))
+ }
+
+ // The family map is a tree map so the families will be sorted
+ val familyIt = familiesQualifiersValues.familyMap.entrySet().iterator()
+ while (familyIt.hasNext) {
+ val familyEntry = familyIt.next()
+
+ val family = familyEntry.getKey.value
+
+ val qualifierIt = familyEntry.getValue.entrySet().iterator()
+
+ // The qualifier map is a tree map so the qualifiers will be sorted
+ while (qualifierIt.hasNext) {
+
+ val qualifierEntry = qualifierIt.next()
+ val qualifier = qualifierEntry.getKey
+ val cellValue = qualifierEntry.getValue
+
+ writeValueToHFile(
+ rowKey.value,
+ family,
+ qualifier.value, // qualifier
+ cellValue, // value
+ nowTimeStamp,
+ fs,
+ conn,
+ localTableName,
+ conf,
+ familyHFileWriteOptionsMapInternal,
+ defaultCompression,
+ writerMap,
+ stagingDir)
+
+ previousRow = rowKey.value
+ }
+
+ writerMap.values.foreach(
+ wl => {
+ rollOverRequested = rollOverRequested || wl.written > maxSize
+
+ // This will only roll if we have at least one column family file that is
+ // bigger than maxSize and we have finished a given row key
+ if (rollOverRequested) {
+ rollWriters(
+ fs,
+ writerMap,
+ regionSplitPartitioner,
+ previousRow,
+ compactionExclude)
+ rollOverRequested = false
+ }
+ })
+ }
+ }
+
+ // We have finished all the data, so let's close up the writers
+ rollWriters(fs, writerMap, regionSplitPartitioner, previousRow, compactionExclude)
+ rollOverRequested = false
+ })
+ } finally {
+ if (null != conn) conn.close()
+ }
+ }
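+
+ // A minimal usage sketch for bulkLoadThinRows (illustrative only: the RDD shape,
+ // table name, column family and staging path below are assumptions, not part of
+ // this API). FamiliesQualifiersValues collects all cells for one row key, and the
+ // incoming RDD must contain exactly one record per row key:
+ //
+ //   val hbaseContext = new HBaseContext(sc, config)
+ //   hbaseContext.bulkLoadThinRows[(String, String)](
+ //     rdd, // RDD[(String, String)] of (rowKey, value)
+ //     TableName.valueOf("thinRowsTable"),
+ //     t => {
+ //       val familyQualifiersValues = new FamiliesQualifiersValues
+ //       familyQualifiersValues += (Bytes.toBytes("cf"), Bytes.toBytes("q"), Bytes.toBytes(t._2))
+ //       (new ByteArrayWrapper(Bytes.toBytes(t._1)), familyQualifiersValues)
+ //     },
+ //     "/tmp/thinRowsTableBulkLoad")
+ //   // then move the generated HFiles into HBase, e.g. with LoadIncrementalHFiles.doBulkLoad(...)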
+
+ /**
+ * This will return a new HFile writer when requested
+ *
+ * @param family column family
+ * @param conf configuration to connect to HBase
+ * @param favoredNodes nodes that we would like to write to
+ * @param fs FileSystem object where we will be writing the HFiles to
+ * @return WriterLength object
+ */
+ private def getNewHFileWriter(
+ family: Array[Byte],
+ conf: Configuration,
+ favoredNodes: Array[InetSocketAddress],
+ fs: FileSystem,
+ familydir: Path,
+ familyHFileWriteOptionsMapInternal: util.HashMap[ByteArrayWrapper, FamilyHFileWriteOptions],
+ defaultCompression: Compression.Algorithm): WriterLength = {
+
+ var familyOptions = familyHFileWriteOptionsMapInternal.get(new ByteArrayWrapper(family))
+
+ if (familyOptions == null) {
+ familyOptions = new FamilyHFileWriteOptions(
+ defaultCompression.toString,
+ BloomType.NONE.toString,
+ HConstants.DEFAULT_BLOCKSIZE,
+ DataBlockEncoding.NONE.toString)
+ familyHFileWriteOptionsMapInternal.put(new ByteArrayWrapper(family), familyOptions)
+ }
+
+ val tempConf = new Configuration(conf)
+ tempConf.setFloat(HConstants.HFILE_BLOCK_CACHE_SIZE_KEY, 0.0f)
+
+ // HBASE-25249 introduced an incompatible change in the IA.Private HStore and StoreUtils
+ // so here, we directly use conf.get for CheckSumType and BytesPerCheckSum to make it
+ // compatible between hbase 2.3.x and 2.4.x
+ val contextBuilder = new HFileContextBuilder()
+ .withCompression(Algorithm.valueOf(familyOptions.compression))
+ // ChecksumType.nameToType is still an IA.Private Utils, but it's unlikely to be changed.
+ .withChecksumType(
+ ChecksumType
+ .nameToType(
+ conf.get(HConstants.CHECKSUM_TYPE_NAME, ChecksumType.getDefaultChecksumType.getName)))
+ .withCellComparator(CellComparator.getInstance())
+ .withBytesPerCheckSum(
+ conf.getInt(HConstants.BYTES_PER_CHECKSUM, HFile.DEFAULT_BYTES_PER_CHECKSUM))
+ .withBlockSize(familyOptions.blockSize)
+
+ if (HFile.getFormatVersion(conf) >= HFile.MIN_FORMAT_VERSION_WITH_TAGS) {
+ contextBuilder.withIncludesTags(true)
+ }
+
+ contextBuilder.withDataBlockEncoding(DataBlockEncoding.valueOf(familyOptions.dataBlockEncoding))
+ val hFileContext = contextBuilder.build()
+
+ // Add a '_' to the file name because this is an unfinished file. A rename will happen
+ // to remove the '_' when the file is closed.
+ new WriterLength(
+ 0,
+ new StoreFileWriter.Builder(conf, new CacheConfig(tempConf), new HFileSystem(fs))
+ .withBloomType(BloomType.valueOf(familyOptions.bloomType))
+ .withFileContext(hFileContext)
+ .withFilePath(new Path(familydir, "_" + UUID.randomUUID.toString.replaceAll("-", "")))
+ .withFavoredNodes(favoredNodes)
+ .build())
+
+ }
+
+ /**
+ * Encompasses the logic to write a value to an HFile
+ *
+ * @param rowKey The RowKey for the record
+ * @param family HBase column family for the record
+ * @param qualifier HBase column qualifier for the record
+ * @param cellValue HBase cell value
+ * @param nowTimeStamp The cell time stamp
+ * @param fs Connection to the FileSystem for the HFile
+ * @param conn Connection to HBase
+ * @param tableName HBase TableName object
+ * @param conf Configuration to be used when making a new HFile
+ * @param familyHFileWriteOptionsMapInternal Extra configs for the HFile
+ * @param hfileCompression The compression codec for the new HFile
+ * @param writerMap HashMap of existing writers and their offsets
+ * @param stagingDir The staging directory on the FileSystem to store
+ * the HFiles
+ * @return The writer for the given HFile that was written
+ * to
+ */
+ private def writeValueToHFile(
+ rowKey: Array[Byte],
+ family: Array[Byte],
+ qualifier: Array[Byte],
+ cellValue: Array[Byte],
+ nowTimeStamp: Long,
+ fs: FileSystem,
+ conn: Connection,
+ tableName: TableName,
+ conf: Configuration,
+ familyHFileWriteOptionsMapInternal: util.HashMap[ByteArrayWrapper, FamilyHFileWriteOptions],
+ hfileCompression: Compression.Algorithm,
+ writerMap: mutable.HashMap[ByteArrayWrapper, WriterLength],
+ stagingDir: String): WriterLength = {
+
+ val wl = writerMap.getOrElseUpdate(
+ new ByteArrayWrapper(family), {
+ val familyDir = new Path(stagingDir, Bytes.toString(family))
+
+ familyDir.getFileSystem(conf).mkdirs(familyDir);
+
+ val loc: HRegionLocation = {
+ try {
+ val locator =
+ conn.getRegionLocator(tableName)
+ locator.getRegionLocation(rowKey)
+ } catch {
+ case e: Throwable =>
+ logWarning(
+ "there's something wrong when locating rowkey: " +
+ Bytes.toString(rowKey))
+ null
+ }
+ }
+ if (null == loc) {
+ if (log.isTraceEnabled) {
+ logTrace(
+ "failed to get region location, so use default writer: " +
+ Bytes.toString(rowKey))
+ }
+ getNewHFileWriter(
+ family = family,
+ conf = conf,
+ favoredNodes = null,
+ fs = fs,
+ familydir = familyDir,
+ familyHFileWriteOptionsMapInternal,
+ hfileCompression)
+ } else {
+ if (log.isDebugEnabled) {
+ logDebug("first rowkey: [" + Bytes.toString(rowKey) + "]")
+ }
+ val initialIsa =
+ new InetSocketAddress(loc.getHostname, loc.getPort)
+ if (initialIsa.isUnresolved) {
+ if (log.isTraceEnabled) {
+ logTrace(
+ "failed to resolve bind address: " + loc.getHostname + ":"
+ + loc.getPort + ", so use default writer")
+ }
+ getNewHFileWriter(
+ family,
+ conf,
+ null,
+ fs,
+ familyDir,
+ familyHFileWriteOptionsMapInternal,
+ hfileCompression)
+ } else {
+ if (log.isDebugEnabled) {
+ logDebug("use favored nodes writer: " + initialIsa.getHostString)
+ }
+ getNewHFileWriter(
+ family,
+ conf,
+ Array[InetSocketAddress](initialIsa),
+ fs,
+ familyDir,
+ familyHFileWriteOptionsMapInternal,
+ hfileCompression)
+ }
+ }
+ })
+
+ val keyValue = new KeyValue(rowKey, family, qualifier, nowTimeStamp, cellValue)
+
+ wl.writer.append(keyValue)
+ wl.written += keyValue.getLength
+
+ wl
+ }
+
+ /**
+ * This will roll all Writers
+ * @param fs Hadoop FileSystem object
+ * @param writerMap HashMap that contains all the writers
+ * @param regionSplitPartitioner The partitioner with knowledge of how the
+ * Regions are split by row key
+ * @param previousRow The last row to fill the HFile ending range metadata
+ * @param compactionExclude The exclude compaction metadata flag for the HFile
+ */
+ private def rollWriters(
+ fs: FileSystem,
+ writerMap: mutable.HashMap[ByteArrayWrapper, WriterLength],
+ regionSplitPartitioner: BulkLoadPartitioner,
+ previousRow: Array[Byte],
+ compactionExclude: Boolean): Unit = {
+ writerMap.values.foreach(
+ wl => {
+ if (wl.writer != null) {
+ logDebug(
+ "Writer=" + wl.writer.getPath +
+ (if (wl.written == 0) "" else ", wrote=" + wl.written))
+ closeHFileWriter(fs, wl.writer, regionSplitPartitioner, previousRow, compactionExclude)
+ }
+ })
+ writerMap.clear()
+
+ }
+
+ /**
+ * Function to close an HFile
+ * @param fs Hadoop FileSystem object
+ * @param w HFile Writer
+ * @param regionSplitPartitioner The partitioner with knowledge of how the
+ * Regions are split by row key
+ * @param previousRow The last row to fill the HFile ending range metadata
+ * @param compactionExclude The exclude compaction metadata flag for the HFile
+ */
+ private def closeHFileWriter(
+ fs: FileSystem,
+ w: StoreFileWriter,
+ regionSplitPartitioner: BulkLoadPartitioner,
+ previousRow: Array[Byte],
+ compactionExclude: Boolean): Unit = {
+ if (w != null) {
+ w.appendFileInfo(HStoreFile.BULKLOAD_TIME_KEY, Bytes.toBytes(System.currentTimeMillis()))
+ w.appendFileInfo(
+ HStoreFile.BULKLOAD_TASK_KEY,
+ Bytes.toBytes(regionSplitPartitioner.getPartition(previousRow)))
+ w.appendFileInfo(HStoreFile.MAJOR_COMPACTION_KEY, Bytes.toBytes(true))
+ w.appendFileInfo(
+ HStoreFile.EXCLUDE_FROM_MINOR_COMPACTION_KEY,
+ Bytes.toBytes(compactionExclude))
+ w.appendTrackedTimestampsToMetadata()
+ w.close()
+
+ val srcPath = w.getPath
+
+ // In the new path you will see that we are using substring. This is to
+ // remove the '_' character in front of the HFile name. '_' is a character
+ // that will tell HBase that this file shouldn't be included in the bulk load.
+ // This feature is to protect against unfinished HFiles being submitted to HBase
+ val newPath = new Path(w.getPath.getParent, w.getPath.getName.substring(1))
+ if (!fs.rename(srcPath, newPath)) {
+ throw new IOException(
+ "Unable to rename '" + srcPath +
+ "' to " + newPath)
+ }
+ }
+ }
+
+ /**
+ * This is a wrapper class around StoreFileWriter. The reason for the
+ * wrapper is to keep the length of the file along side the writer
+ *
+ * @param written The number of bytes written to the writer
+ * @param writer The writer to be wrapped
+ */
+ class WriterLength(var written: Long, val writer: StoreFileWriter)
+}
+
+@InterfaceAudience.Private
+object LatestHBaseContextCache {
+ var latest: HBaseContext = null
+}
diff --git a/spark4/hbase-spark4/src/main/scala/org/apache/hadoop/hbase/spark/HBaseDStreamFunctions.scala b/spark4/hbase-spark4/src/main/scala/org/apache/hadoop/hbase/spark/HBaseDStreamFunctions.scala
new file mode 100644
index 00000000..0ca5038c
--- /dev/null
+++ b/spark4/hbase-spark4/src/main/scala/org/apache/hadoop/hbase/spark/HBaseDStreamFunctions.scala
@@ -0,0 +1,166 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.hadoop.hbase.spark
+
+import org.apache.hadoop.hbase.TableName
+import org.apache.hadoop.hbase.client._
+import org.apache.hadoop.hbase.io.ImmutableBytesWritable
+import org.apache.spark.streaming.dstream.DStream
+import org.apache.yetus.audience.InterfaceAudience
+import scala.reflect.ClassTag
+
+/**
+ * HBaseDStreamFunctions contains a set of implicit functions that can be
+ * applied to a Spark DStream so that we can easily interact with HBase
+ */
+@InterfaceAudience.Public
+object HBaseDStreamFunctions {
+
+ /**
+ * These are implicit methods for a DStream that contains any type of
+ * data.
+ *
+ * @param dStream This is for dStreams of any type
+ * @tparam T Type T
+ */
+ implicit class GenericHBaseDStreamFunctions[T](val dStream: DStream[T]) {
+
+ /**
+ * Implicit method that gives easy access to HBaseContext's bulk
+ * put. This will not return a new DStream. Think of it like a foreach
+ *
+ * @param hc The hbaseContext object to identify which
+ * HBase cluster connection to use
+ * @param tableName The tableName that the put will be sent to
+ * @param f The function that will turn the DStream values
+ * into HBase Put objects.
+ */
+ def hbaseBulkPut(hc: HBaseContext, tableName: TableName, f: (T) => Put): Unit = {
+ hc.streamBulkPut(dStream, tableName, f)
+ }
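+
+ // A minimal usage sketch (illustrative; the stream element type, table and
+ // column names are assumptions, not part of this API):
+ //
+ //   import org.apache.hadoop.hbase.util.Bytes
+ //   dStream.hbaseBulkPut(
+ //     hbaseContext,
+ //     TableName.valueOf("t1"),
+ //     (kv: (Array[Byte], Array[Byte])) =>
+ //       new Put(kv._1).addColumn(Bytes.toBytes("cf"), Bytes.toBytes("q"), kv._2))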
+
+ /**
+ * Implicit method that gives easy access to HBaseContext's bulk
+ * get. This will return a new DStream. Think of it as a DStream map
+ * function, in that every DStream value will get a new value out of
+ * HBase. That new value will populate the newly generated DStream.
+ *
+ * @param hc The hbaseContext object to identify which
+ * HBase cluster connection to use
+ * @param tableName The tableName that the gets will be sent to
+ * @param batchSize How many gets to execute in a single batch
+ * @param f The function that will turn the DStream values
+ * into HBase Get objects
+ * @param convertResult The function that will convert a HBase
+ * Result object into a value that will go
+ * into the resulting DStream
+ * @tparam R The type of Object that will be coming
+ * out of the resulting DStream
+ * @return A resulting DStream with type R objects
+ */
+ def hbaseBulkGet[R: ClassTag](
+ hc: HBaseContext,
+ tableName: TableName,
+ batchSize: Int,
+ f: (T) => Get,
+ convertResult: (Result) => R): DStream[R] = {
+ hc.streamBulkGet[T, R](tableName, batchSize, dStream, f, convertResult)
+ }
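+
+ // A minimal usage sketch (illustrative; the stream of row keys and the String
+ // conversion are assumptions). Each row key becomes a Get, and each Result is
+ // converted into a value of the new DStream:
+ //
+ //   import org.apache.hadoop.hbase.util.Bytes
+ //   val values: DStream[String] = rowKeyStream.hbaseBulkGet[String](
+ //     hbaseContext,
+ //     TableName.valueOf("t1"),
+ //     100, // batch size
+ //     (rowKey: Array[Byte]) => new Get(rowKey),
+ //     (result: Result) => Bytes.toString(result.value()))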
+
+ /**
+ * Implicit method that gives easy access to HBaseContext's bulk
+ * get. This will return a new DStream. Think of it as a DStream map
+ * function, in that every DStream value will get a new value out of
+ * HBase. That new value will populate the newly generated DStream.
+ *
+ * @param hc The hbaseContext object to identify which
+ * HBase cluster connection to use
+ * @param tableName The tableName that the gets will be sent to
+ * @param batchSize How many gets to execute in a single batch
+ * @param f The function that will turn the DStream values
+ * into HBase Get objects
+ * @return A resulting DStream with type R objects
+ */
+ def hbaseBulkGet(
+ hc: HBaseContext,
+ tableName: TableName,
+ batchSize: Int,
+ f: (T) => Get): DStream[(ImmutableBytesWritable, Result)] = {
+ hc.streamBulkGet[T, (ImmutableBytesWritable, Result)](
+ tableName,
+ batchSize,
+ dStream,
+ f,
+ result => (new ImmutableBytesWritable(result.getRow), result))
+ }
+
+ /**
+ * Implicit method that gives easy access to HBaseContext's bulk
+ * Delete. This will not return a new DStream.
+ *
+ * @param hc The hbaseContext object to identify which HBase
+ * cluster connection to use
+ * @param tableName The tableName that the deletes will be sent to
+ * @param f The function that will convert the DStream value into
+ * a HBase Delete Object
+ * @param batchSize The number of Deletes to be sent in a single batch
+ */
+ def hbaseBulkDelete(
+ hc: HBaseContext,
+ tableName: TableName,
+ f: (T) => Delete,
+ batchSize: Int): Unit = {
+ hc.streamBulkDelete(dStream, tableName, f, batchSize)
+ }
+
+ /**
+ * Implicit method that gives easy access to HBaseContext's
+ * foreachPartition method. This will act very much like a normal DStream
+ * foreach method, except that you will now have an HBase connection
+ * while iterating through the values.
+ *
+ * @param hc The hbaseContext object to identify which HBase
+ * cluster connection to use
+ * @param f This function will get an iterator for a Partition of a
+ * DStream along with a connection object to HBase
+ */
+ def hbaseForeachPartition(hc: HBaseContext, f: (Iterator[T], Connection) => Unit): Unit = {
+ hc.streamForeachPartition(dStream, f)
+ }
+
+ /**
+ * Implicit method that gives easy access to HBaseContext's
+ * mapPartitions method. This will act very much like a normal DStream
+ * mapPartitions method, except that you will now have an
+ * HBase connection while iterating through the values
+ *
+ * @param hc The hbaseContext object to identify which HBase
+ * cluster connection to use
+ * @param f This function will get an iterator for a Partition of a
+ * DStream along with a connection object to HBase
+ * @tparam R This is the type of objects that will go into the resulting
+ * DStream
+ * @return A resulting DStream of type R
+ */
+ def hbaseMapPartitions[R: ClassTag](
+ hc: HBaseContext,
+ f: (Iterator[T], Connection) => Iterator[R]): DStream[R] = {
+ hc.streamMapPartitions(dStream, f)
+ }
+ }
+}
diff --git a/spark4/hbase-spark4/src/main/scala/org/apache/hadoop/hbase/spark/HBaseRDDFunctions.scala b/spark4/hbase-spark4/src/main/scala/org/apache/hadoop/hbase/spark/HBaseRDDFunctions.scala
new file mode 100644
index 00000000..9e3da565
--- /dev/null
+++ b/spark4/hbase-spark4/src/main/scala/org/apache/hadoop/hbase/spark/HBaseRDDFunctions.scala
@@ -0,0 +1,268 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.hadoop.hbase.spark
+
+import java.util
+import org.apache.hadoop.hbase.{HConstants, TableName}
+import org.apache.hadoop.hbase.client._
+import org.apache.hadoop.hbase.io.ImmutableBytesWritable
+import org.apache.spark.rdd.RDD
+import org.apache.yetus.audience.InterfaceAudience
+import scala.reflect.ClassTag
+
+/**
+ * HBaseRDDFunctions contains a set of implicit functions that can be
+ * applied to a Spark RDD so that we can easily interact with HBase
+ */
+@InterfaceAudience.Public
+object HBaseRDDFunctions {
+
+ /**
+ * These are implicit methods for a RDD that contains any type of
+ * data.
+ *
+ * @param rdd This is for rdd of any type
+ * @tparam T This is any type
+ */
+ implicit class GenericHBaseRDDFunctions[T](val rdd: RDD[T]) {
+
+ /**
+ * Implicit method that gives easy access to HBaseContext's bulk
+ * put. This will not return a new RDD. Think of it like a foreach
+ *
+ * @param hc The hbaseContext object to identify which
+ * HBase cluster connection to use
+ * @param tableName The tableName that the put will be sent to
+ * @param f The function that will turn the RDD values
+ * into HBase Put objects.
+ */
+ def hbaseBulkPut(hc: HBaseContext, tableName: TableName, f: (T) => Put): Unit = {
+ hc.bulkPut(rdd, tableName, f)
+ }
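+
+ // A minimal usage sketch (illustrative; the RDD element type, table and column
+ // names are assumptions, not part of this API):
+ //
+ //   import org.apache.hadoop.hbase.util.Bytes
+ //   rdd.hbaseBulkPut(
+ //     hbaseContext,
+ //     TableName.valueOf("t1"),
+ //     (r: (String, String)) =>
+ //       new Put(Bytes.toBytes(r._1))
+ //         .addColumn(Bytes.toBytes("cf"), Bytes.toBytes("q"), Bytes.toBytes(r._2)))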
+
+ /**
+ * Implicit method that gives easy access to HBaseContext's bulk
+ * get. This will return a new RDD. Think of it as an RDD map
+ * function, in that every RDD value will get a new value out of
+ * HBase. That new value will populate the newly generated RDD.
+ *
+ * @param hc The hbaseContext object to identify which
+ * HBase cluster connection to use
+ * @param tableName The tableName that the gets will be sent to
+ * @param batchSize How many gets to execute in a single batch
+ * @param f The function that will turn the RDD values
+ * into HBase Get objects
+ * @param convertResult The function that will convert a HBase
+ * Result object into a value that will go
+ * into the resulting RDD
+ * @tparam R The type of Object that will be coming
+ * out of the resulting RDD
+ * @return A resulting RDD with type R objects
+ */
+ def hbaseBulkGet[R: ClassTag](
+ hc: HBaseContext,
+ tableName: TableName,
+ batchSize: Int,
+ f: (T) => Get,
+ convertResult: (Result) => R): RDD[R] = {
+ hc.bulkGet[T, R](tableName, batchSize, rdd, f, convertResult)
+ }
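+
+ // A minimal usage sketch (illustrative; the RDD of row keys and the conversion
+ // to (rowKey, value) pairs are assumptions):
+ //
+ //   import org.apache.hadoop.hbase.util.Bytes
+ //   val pairs: RDD[(String, Array[Byte])] = rowKeyRdd.hbaseBulkGet[(String, Array[Byte])](
+ //     hbaseContext,
+ //     TableName.valueOf("t1"),
+ //     100, // batch size
+ //     (rowKey: Array[Byte]) => new Get(rowKey),
+ //     (result: Result) => (Bytes.toString(result.getRow), result.value()))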
+
+ /**
+ * Implicit method that gives easy access to HBaseContext's bulk
+ * get. This will return a new RDD. Think of it as an RDD map
+ * function, in that every RDD value will get a new value out of
+ * HBase. That new value will populate the newly generated RDD.
+ *
+ * @param hc The hbaseContext object to identify which
+ * HBase cluster connection to use
+ * @param tableName The tableName that the gets will be sent to
+ * @param batchSize How many gets to execute in a single batch
+ * @param f The function that will turn the RDD values
+ * into HBase Get objects
+ * @return A resulting RDD with type R objects
+ */
+ def hbaseBulkGet(
+ hc: HBaseContext,
+ tableName: TableName,
+ batchSize: Int,
+ f: (T) => Get): RDD[(ImmutableBytesWritable, Result)] = {
+ hc.bulkGet[T, (ImmutableBytesWritable, Result)](
+ tableName,
+ batchSize,
+ rdd,
+ f,
+ result =>
+ if (result != null && result.getRow != null) {
+ (new ImmutableBytesWritable(result.getRow), result)
+ } else {
+ null
+ })
+ }
+
+ /**
+ * Implicit method that gives easy access to HBaseContext's bulk
+ * Delete. This will not return a new RDD.
+ *
+ * @param hc The hbaseContext object to identify which HBase
+ * cluster connection to use
+ * @param tableName The tableName that the deletes will be sent to
+ * @param f The function that will convert the RDD value into
+ * a HBase Delete Object
+ * @param batchSize The number of Deletes to be sent in a single batch
+ */
+ def hbaseBulkDelete(
+ hc: HBaseContext,
+ tableName: TableName,
+ f: (T) => Delete,
+ batchSize: Int): Unit = {
+ hc.bulkDelete(rdd, tableName, f, batchSize)
+ }
+
+ /**
+ * Implicit method that gives easy access to HBaseContext's
+ * foreachPartition method. This will act very much like a normal RDD
+ * foreach method, except that you will now have an HBase connection
+ * while iterating through the values.
+ *
+ * @param hc The hbaseContext object to identify which HBase
+ * cluster connection to use
+ * @param f This function will get an iterator for a Partition of an
+ * RDD along with a connection object to HBase
+ */
+ def hbaseForeachPartition(hc: HBaseContext, f: (Iterator[T], Connection) => Unit): Unit = {
+ hc.foreachPartition(rdd, f)
+ }
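+
+ // A minimal usage sketch (illustrative; table and column names are assumptions).
+ // The Connection is managed by HBaseContext, so it must not be closed here:
+ //
+ //   import org.apache.hadoop.hbase.util.Bytes
+ //   rdd.hbaseForeachPartition(
+ //     hbaseContext,
+ //     (it, conn) => {
+ //       val mutator = conn.getBufferedMutator(TableName.valueOf("t1"))
+ //       it.foreach(
+ //         rowKey => mutator.mutate(
+ //           new Put(rowKey).addColumn(Bytes.toBytes("cf"), Bytes.toBytes("q"), rowKey)))
+ //       mutator.flush()
+ //       mutator.close()
+ //     })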
+
+ /**
+ * Implicit method that gives easy access to HBaseContext's
+ * mapPartitions method. This will act very much like a normal RDD
+ * mapPartitions method, except that you will now have an
+ * HBase connection while iterating through the values
+ *
+ * @param hc The hbaseContext object to identify which HBase
+ * cluster connection to use
+ * @param f This function will get an iterator for a Partition of an
+ * RDD along with a connection object to HBase
+ * @tparam R This is the type of objects that will go into the resulting
+ * RDD
+ * @return A resulting RDD of type R
+ */
+ def hbaseMapPartitions[R: ClassTag](
+ hc: HBaseContext,
+ f: (Iterator[T], Connection) => Iterator[R]): RDD[R] = {
+ hc.mapPartitions[T, R](rdd, f)
+ }
+
+ /**
+ * Spark Implementation of HBase Bulk load for wide rows or when
+ * values are not already combined at the time of the map process
+ *
+ * A Spark Implementation of HBase Bulk load
+ *
+ * This will take the content from an existing RDD then sort and shuffle
+ * it with respect to region splits. The result of that sort and shuffle
+ * will be written to HFiles.
+ *
+ * After this function is executed the user will have to call
+ * LoadIncrementalHFiles.doBulkLoad(...) to move the files into HBase
+ *
+ * Also note this version of bulk load is different from past versions in
+ * that it includes the qualifier as part of the sort process. The
+ * reason for this is to be able to support rows with a very large number
+ * of columns.
+ *
+ * @param tableName The HBase table we are loading into
+ * @param flatMap A flatMap function that will turn every row in the RDD
+ * into N cells for the bulk load
+ * @param stagingDir The location on the FileSystem to bulk load into
+ * @param familyHFileWriteOptionsMap Options that will define how the HFile for a
+ * column family is written
+ * @param compactionExclude Compaction excluded for the HFiles
+ * @param maxSize Max size for the HFiles before they roll
+ */
+ def hbaseBulkLoad(
+ hc: HBaseContext,
+ tableName: TableName,
+ flatMap: (T) => Iterator[(KeyFamilyQualifier, Array[Byte])],
+ stagingDir: String,
+ familyHFileWriteOptionsMap: util.Map[Array[Byte], FamilyHFileWriteOptions] =
+ new util.HashMap[Array[Byte], FamilyHFileWriteOptions](),
+ compactionExclude: Boolean = false,
+ maxSize: Long = HConstants.DEFAULT_MAX_FILE_SIZE): Unit = {
+ hc.bulkLoad(
+ rdd,
+ tableName,
+ flatMap,
+ stagingDir,
+ familyHFileWriteOptionsMap,
+ compactionExclude,
+ maxSize)
+ }
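+
+ // A minimal usage sketch (illustrative; the RDD shape, table name and staging
+ // path are assumptions). Each record is flat-mapped into one or more
+ // (KeyFamilyQualifier, value) cells before the shuffle:
+ //
+ //   rdd.hbaseBulkLoad(
+ //     hbaseContext,
+ //     TableName.valueOf("wideTable"),
+ //     (t: (Array[Byte], Array[Byte], Array[Byte], Array[Byte])) =>
+ //       // one cell per record: (rowKey, family, qualifier) -> value
+ //       Seq((new KeyFamilyQualifier(t._1, t._2, t._3), t._4)).iterator,
+ //     "/tmp/wideTableBulkLoad")
+ //   // then move the generated HFiles into HBase, e.g. with LoadIncrementalHFiles.doBulkLoad(...)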
+
+ /**
+ * Implicit method that gives easy access to HBaseContext's
+ * bulkLoadThinRows method.
+ *
+ * Spark Implementation of HBase Bulk load for short rows, somewhere less than
+ * 1000 columns. This bulk load should be faster for tables with thinner
+ * rows than the other Spark implementation of bulk load, which puts only one
+ * value into a record going into a shuffle
+ *
+ * This will take the content from an existing RDD then sort and shuffle
+ * it with respect to region splits. The result of that sort and shuffle
+ * will be written to HFiles.
+ *
+ * After this function is executed the user will have to call
+ * LoadIncrementalHFiles.doBulkLoad(...) to move the files into HBase
+ *
+ * In this implementation only the rowKey is given to the shuffle as the key
+ * and all the columns are already linked to the RowKey before the shuffle
+ * stage. The sorting of the qualifiers is done in memory, outside of the
+ * shuffle stage
+ *
+ * @param tableName The HBase table we are loading into
+ * @param mapFunction A function that will convert the RDD records to
+ * the key value format used for the shuffle to prep
+ * for writing to the bulk loaded HFiles
+ * @param stagingDir The location on the FileSystem to bulk load into
+ * @param familyHFileWriteOptionsMap Options that will define how the HFile for a
+ * column family is written
+ * @param compactionExclude Compaction excluded for the HFiles
+ * @param maxSize Max size for the HFiles before they roll
+ */
+ def hbaseBulkLoadThinRows(
+ hc: HBaseContext,
+ tableName: TableName,
+ mapFunction: (T) => (ByteArrayWrapper, FamiliesQualifiersValues),
+ stagingDir: String,
+ familyHFileWriteOptionsMap: util.Map[Array[Byte], FamilyHFileWriteOptions] =
+ new util.HashMap[Array[Byte], FamilyHFileWriteOptions](),
+ compactionExclude: Boolean = false,
+ maxSize: Long = HConstants.DEFAULT_MAX_FILE_SIZE): Unit = {
+ hc.bulkLoadThinRows(
+ rdd,
+ tableName,
+ mapFunction,
+ stagingDir,
+ familyHFileWriteOptionsMap,
+ compactionExclude,
+ maxSize)
+ }
+ }
+}
diff --git a/spark4/hbase-spark4/src/main/scala/org/apache/hadoop/hbase/spark/JavaHBaseContext.scala b/spark4/hbase-spark4/src/main/scala/org/apache/hadoop/hbase/spark/JavaHBaseContext.scala
new file mode 100644
index 00000000..48ac3737
--- /dev/null
+++ b/spark4/hbase-spark4/src/main/scala/org/apache/hadoop/hbase/spark/JavaHBaseContext.scala
@@ -0,0 +1,420 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.hadoop.hbase.spark
+
+import java.util
+
+import java.lang.Iterable
+import java.util.Map
+import org.apache.hadoop.conf.Configuration
+import org.apache.hadoop.hbase.TableName
+import org.apache.hadoop.hbase.client.{Connection, Delete, Get, Put, Result, Scan}
+import org.apache.hadoop.hbase.io.ImmutableBytesWritable
+import org.apache.hadoop.hbase.util.Pair
+import org.apache.spark.api.java.{JavaRDD, JavaSparkContext}
+import org.apache.spark.api.java.function.{FlatMapFunction, Function, VoidFunction}
+import org.apache.spark.streaming.api.java.JavaDStream
+import org.apache.yetus.audience.InterfaceAudience
+import scala.jdk.CollectionConverters._
+import scala.reflect.ClassTag
+
+/**
+ * This is the Java Wrapper over HBaseContext which is written in
+ * Scala. This class will be used by developers that want to
+ * work with Spark or Spark Streaming in Java
+ *
+ * @param jsc This is the JavaSparkContext that we will wrap
+ * @param config This is the config information for our HBase cluster
+ */
+@InterfaceAudience.Public
+class JavaHBaseContext(@transient val jsc: JavaSparkContext, @transient val config: Configuration)
+ extends Serializable {
+ val hbaseContext = new HBaseContext(jsc.sc, config)
+
+ /**
+ * A simple enrichment of the traditional Spark javaRdd foreachPartition.
+ * This function differs from the original in that it offers the
+ * developer access to an already connected Connection object
+ *
+ * Note: Do not close the Connection object. All Connection
+ * management is handled outside this method
+ *
+ * @param javaRdd Original javaRdd with data to iterate over
+ * @param f Function to be given an iterator to iterate through
+ * the RDD values and a Connection object to interact
+ * with HBase
+ */
+ def foreachPartition[T](
+ javaRdd: JavaRDD[T],
+ f: VoidFunction[(java.util.Iterator[T], Connection)]) = {
+
+ hbaseContext.foreachPartition(
+ javaRdd.rdd,
+ (it: Iterator[T], conn: Connection) => {
+ f.call((it.asJava, conn))
+ })
+ }
+
+ /**
+ * A simple enrichment of the traditional Spark Streaming dStream foreach
+ * This function differs from the original in that it offers the
+ * developer access to an already connected Connection object
+ *
+ * Note: Do not close the Connection object. All Connection
+ * management is handled outside this method
+ *
+ * @param javaDstream Original DStream with data to iterate over
+ * @param f Function to be given an iterator to iterate through
+ * the JavaDStream values and a Connection object to
+ * interact with HBase
+ */
+ def foreachPartition[T](
+ javaDstream: JavaDStream[T],
+ f: VoidFunction[(Iterator[T], Connection)]) = {
+ hbaseContext.foreachPartition(
+ javaDstream.dstream,
+ (it: Iterator[T], conn: Connection) => f.call(it, conn))
+ }
+
+ /**
+ * A simple enrichment of the traditional Spark JavaRDD mapPartition.
+ * This function differs from the original in that it offers the
+ * developer access to an already connected Connection object
+ *
+ * Note: Do not close the Connection object. All Connection
+ * management is handled outside this method
+ *
+ * Note: Make sure to partition correctly to avoid memory issues when
+ * getting data from HBase
+ *
+ * @param javaRdd Original JavaRdd with data to iterate over
+ * @param f Function to be given an iterator to iterate through
+ * the RDD values and a Connection object to interact
+ * with HBase
+ * @return Returns a new RDD generated by the user-defined
+ * function just like normal mapPartition
+ */
+ def mapPartitions[T, R](
+ javaRdd: JavaRDD[T],
+ f: FlatMapFunction[(java.util.Iterator[T], Connection), R]): JavaRDD[R] = {
+ JavaRDD.fromRDD(
+ hbaseContext.mapPartitions(
+ javaRdd.rdd,
+ (it: Iterator[T], conn: Connection) => f.call((it.asJava, conn)).asScala)(fakeClassTag[R]))(fakeClassTag[R])
+ }
+
+ /**
+ * A simple enrichment of the traditional Spark Streaming JavaDStream
+ * mapPartition.
+ *
+ * This function differs from the original in that it offers the
+ * developer access to an already connected Connection object
+ *
+ * Note: Do not close the Connection object. All Connection
+ * management is handled outside this method
+ *
+ * Note: Make sure to partition correctly to avoid memory issues when
+ * getting data from HBase
+ *
+ * @param javaDstream Original JavaDStream with data to iterate over
+ * @param mp Function to be given an iterator to iterate through
+ * the JavaDStream values and a Connection object to
+ * interact with HBase
+ * @return Returns a new JavaDStream generated by the user-defined
+ * function just like normal mapPartition
+ */
+ def streamMap[T, U](
+ javaDstream: JavaDStream[T],
+ mp: Function[(Iterator[T], Connection), Iterator[U]]): JavaDStream[U] = {
+ JavaDStream.fromDStream(
+ hbaseContext.streamMapPartitions(
+ javaDstream.dstream,
+ (it: Iterator[T], conn: Connection) => mp.call(it, conn))(fakeClassTag[U]))(fakeClassTag[U])
+ }
+
+ /**
+ * A simple abstraction over the HBaseContext.foreachPartition method.
+ *
+ * It allows a user to take a JavaRDD,
+ * generate Puts and send them to HBase.
+ * The complexity of managing the Connection is
+ * removed from the developer
+ *
+ * @param javaRdd Original JavaRDD with data to iterate over
+ * @param tableName The name of the table to put into
+ * @param f Function to convert a value in the JavaRDD
+ * to a HBase Put
+ */
+ def bulkPut[T](javaRdd: JavaRDD[T], tableName: TableName, f: Function[(T), Put]) {
+
+ hbaseContext.bulkPut(javaRdd.rdd, tableName, (t: T) => f.call(t))
+ }
+
+ /**
+ * A simple abstraction over the HBaseContext.streamMapPartition method.
+ *
+ * It allows a user to take a JavaDStream,
+ * generate Puts and send them to HBase.
+ *
+ * The complexity of managing the Connection is
+ * removed from the developer
+ *
+ * @param javaDstream Original DStream with data to iterate over
+ * @param tableName The name of the table to put into
+ * @param f Function to convert a value in
+ * the JavaDStream to a HBase Put
+ */
+ def streamBulkPut[T](javaDstream: JavaDStream[T], tableName: TableName, f: Function[T, Put]) = {
+ hbaseContext.streamBulkPut(javaDstream.dstream, tableName, (t: T) => f.call(t))
+ }
+
+ /**
+ * A simple abstraction over the HBaseContext.foreachPartition method.
+ *
+ * It allows a user to take a JavaRDD,
+ * generate Deletes and send them to HBase.
+ *
+ * The complexity of managing the Connection is
+ * removed from the developer
+ *
+ * @param javaRdd Original JavaRDD with data to iterate over
+ * @param tableName The name of the table to delete from
+ * @param f Function to convert a value in the JavaRDD to an
+ * HBase Delete
+ * @param batchSize The number of deletes to batch before sending to HBase
+ */
+ def bulkDelete[T](
+ javaRdd: JavaRDD[T],
+ tableName: TableName,
+ f: Function[T, Delete],
+ batchSize: Integer) {
+ hbaseContext.bulkDelete(javaRdd.rdd, tableName, (t: T) => f.call(t), batchSize)
+ }
+
+ /**
+ * A simple abstraction over the HBaseContext.streamBulkMutation method.
+ *
+ * It allows a user to take a JavaDStream,
+ * generate Deletes and send them to HBase.
+ *
+ * The complexity of managing the Connection is
+ * removed from the developer
+ *
+ * @param javaDStream Original DStream with data to iterate over
+ * @param tableName The name of the table to delete from
+ * @param f Function to convert a value in the JavaDStream to a
+ * HBase Delete
+ * @param batchSize The number of deletes to be sent at once
+ */
+ def streamBulkDelete[T](
+ javaDStream: JavaDStream[T],
+ tableName: TableName,
+ f: Function[T, Delete],
+ batchSize: Integer) = {
+ hbaseContext.streamBulkDelete(javaDStream.dstream, tableName, (t: T) => f.call(t), batchSize)
+ }
+
+ /**
+ * A simple abstraction over the HBaseContext.mapPartition method.
+ *
+ * It allows a user to take a JavaRDD and generate a
+ * new RDD based on Gets and the results they bring back from HBase
+ *
+ * @param tableName The name of the table to get from
+ * @param batchSize batch size of how many gets to retrieve in a single fetch
+ * @param javaRdd Original JavaRDD with data to iterate over
+ * @param makeGet Function to convert a value in the JavaRDD to a
+ * HBase Get
+ * @param convertResult This will convert the HBase Result object to
+ * whatever the user wants to put in the resulting
+ * JavaRDD
+ * @return New JavaRDD that is created by the Get to HBase
+ */
+ def bulkGet[T, U](
+ tableName: TableName,
+ batchSize: Integer,
+ javaRdd: JavaRDD[T],
+ makeGet: Function[T, Get],
+ convertResult: Function[Result, U]): JavaRDD[U] = {
+
+ JavaRDD.fromRDD(
+ hbaseContext.bulkGet[T, U](
+ tableName,
+ batchSize,
+ javaRdd.rdd,
+ (t: T) => makeGet.call(t),
+ (r: Result) => {
+ convertResult.call(r)
+ })(fakeClassTag[U]))(fakeClassTag[U])
+
+ }
+
+ /**
+ * A simple abstraction over the HBaseContext.streamMap method.
+ *
+ * It allows a user to take a DStream and
+ * generate a new DStream based on Gets and the results
+ * they bring back from HBase
+ *
+ * @param tableName The name of the table to get from
+ * @param batchSize The number of gets to be batched together
+ * @param javaDStream Original DStream with data to iterate over
+ * @param makeGet Function to convert a value in the JavaDStream to a
+ * HBase Get
+ * @param convertResult This will convert the HBase Result object to
+ * whatever the user wants to put in the resulting
+ * JavaDStream
+ * @return New JavaDStream that is created by the Get to HBase
+ */
+ def streamBulkGet[T, U](
+ tableName: TableName,
+ batchSize: Integer,
+ javaDStream: JavaDStream[T],
+ makeGet: Function[T, Get],
+ convertResult: Function[Result, U]): JavaDStream[U] = {
+ JavaDStream.fromDStream(
+ hbaseContext.streamBulkGet(
+ tableName,
+ batchSize,
+ javaDStream.dstream,
+ (t: T) => makeGet.call(t),
+ (r: Result) => convertResult.call(r))(fakeClassTag[U]))(fakeClassTag[U])
+ }
+
+ /**
+ * A simple abstraction over the HBaseContext.bulkLoad method.
+ * It allows a user to take a JavaRDD and
+ * convert it into a new JavaRDD[Pair] based on a map function,
+ * and HFiles will be generated in stagingDir for bulk load
+ *
+ * @param javaRdd The javaRDD we are bulk loading from
+ * @param tableName The HBase table we are loading into
+ * @param mapFunc A Function that will convert a value in JavaRDD
+ * to Pair(KeyFamilyQualifier, Array[Byte])
+ * @param stagingDir The location on the FileSystem to bulk load into
+ * @param familyHFileWriteOptionsMap Options that will define how the HFile for a
+ * column family is written
+ * @param compactionExclude Compaction excluded for the HFiles
+ * @param maxSize Max size for the HFiles before they roll
+ */
+ def bulkLoad[T](
+ javaRdd: JavaRDD[T],
+ tableName: TableName,
+ mapFunc: Function[T, Pair[KeyFamilyQualifier, Array[Byte]]],
+ stagingDir: String,
+ familyHFileWriteOptionsMap: Map[Array[Byte], FamilyHFileWriteOptions],
+ compactionExclude: Boolean,
+ maxSize: Long): Unit = {
+ hbaseContext.bulkLoad[Pair[KeyFamilyQualifier, Array[Byte]]](
+ javaRdd.map(mapFunc).rdd,
+ tableName,
+ t => {
+ val keyFamilyQualifier = t.getFirst
+ val value = t.getSecond
+ Seq((keyFamilyQualifier, value)).iterator
+ },
+ stagingDir,
+ familyHFileWriteOptionsMap,
+ compactionExclude,
+ maxSize)
+ }
+
+ /**
+ * A simple abstraction over the HBaseContext.bulkLoadThinRows method.
+ * It allows a user to take a JavaRDD and
+ * convert it into a new JavaRDD[Pair] based on a map function,
+ * and HFiles will be generated in stagingDir for bulk load
+ *
+ * @param javaRdd The javaRDD we are bulk loading from
+ * @param tableName The HBase table we are loading into
+ * @param mapFunc A Function that will convert a value in JavaRDD
+ * to Pair(ByteArrayWrapper, FamiliesQualifiersValues)
+ * @param stagingDir The location on the FileSystem to bulk load into
+ * @param familyHFileWriteOptionsMap Options that will define how the HFile for a
+ * column family is written
+ * @param compactionExclude Compaction excluded for the HFiles
+ * @param maxSize Max size for the HFiles before they roll
+ */
+ def bulkLoadThinRows[T](
+ javaRdd: JavaRDD[T],
+ tableName: TableName,
+ mapFunc: Function[T, Pair[ByteArrayWrapper, FamiliesQualifiersValues]],
+ stagingDir: String,
+ familyHFileWriteOptionsMap: Map[Array[Byte], FamilyHFileWriteOptions],
+ compactionExclude: Boolean,
+ maxSize: Long): Unit = {
+ hbaseContext.bulkLoadThinRows[Pair[ByteArrayWrapper, FamiliesQualifiersValues]](
+ javaRdd.map(mapFunc).rdd,
+ tableName,
+ t => {
+ (t.getFirst, t.getSecond)
+ },
+ stagingDir,
+ familyHFileWriteOptionsMap,
+ compactionExclude,
+ maxSize)
+ }
+
+ /**
+ * This function will use the native HBase TableInputFormat with the
+ * given scan object to generate a new JavaRDD
+ *
+ * @param tableName The name of the table to scan
+ * @param scans The HBase scan object to use to read data from HBase
+ * @param f Function to convert a Result object from HBase into
+ * what the user wants in the final generated JavaRDD
+ * @return New JavaRDD with results from scan
+ */
+ def hbaseRDD[U](
+ tableName: TableName,
+ scans: Scan,
+ f: Function[(ImmutableBytesWritable, Result), U]): JavaRDD[U] = {
+ JavaRDD.fromRDD(
+ hbaseContext
+ .hbaseRDD[U](tableName, scans, (v: (ImmutableBytesWritable, Result)) => f.call(v._1, v._2))(
+ fakeClassTag[U]))(fakeClassTag[U])
+ }
+
+ /**
+ * An overloaded version of HBaseContext's hbaseRDD that defines the
+ * type of the resulting JavaRDD
+ *
+ * @param tableName The name of the table to scan
+ * @param scans The HBase scan object to use to read data from HBase
+ * @return New JavaRDD with results from scan
+ */
+ def hbaseRDD(tableName: TableName, scans: Scan): JavaRDD[(ImmutableBytesWritable, Result)] = {
+ JavaRDD.fromRDD(hbaseContext.hbaseRDD(tableName, scans))
+ }
+
+ /**
+ * Produces a ClassTag[T], which is actually just a casted ClassTag[AnyRef].
+ *
+ * This method is used to keep ClassTags out of the external Java API, as the Java compiler
+ * cannot produce them automatically. While this ClassTag-faking does please the compiler,
+ * it can cause problems at runtime if the Scala API relies on ClassTags for correctness.
+ *
+ * Often, though, a ClassTag[AnyRef] will not lead to incorrect behavior,
+ * just worse performance or security issues.
+ * For instance, an Array[AnyRef] can hold any type T,
+ * but may lose primitive
+ * specialization.
+ */
+ private[spark] def fakeClassTag[T]: ClassTag[T] = ClassTag.AnyRef.asInstanceOf[ClassTag[T]]
+
+}
diff --git a/spark4/hbase-spark4/src/main/scala/org/apache/hadoop/hbase/spark/KeyFamilyQualifier.scala b/spark4/hbase-spark4/src/main/scala/org/apache/hadoop/hbase/spark/KeyFamilyQualifier.scala
new file mode 100644
index 00000000..d71e31a5
--- /dev/null
+++ b/spark4/hbase-spark4/src/main/scala/org/apache/hadoop/hbase/spark/KeyFamilyQualifier.scala
@@ -0,0 +1,51 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.hadoop.hbase.spark
+
+import java.io.Serializable
+import org.apache.hadoop.hbase.util.Bytes
+import org.apache.yetus.audience.InterfaceAudience
+
+/**
+ * This is the key to be used for sorting and shuffling.
+ *
+ * We will only partition on the rowKey but we will sort on all three
+ *
+ * @param rowKey Record RowKey
+ * @param family Record ColumnFamily
+ * @param qualifier Cell Qualifier
+ */
+@InterfaceAudience.Public
+class KeyFamilyQualifier(
+ val rowKey: Array[Byte],
+ val family: Array[Byte],
+ val qualifier: Array[Byte])
+ extends Comparable[KeyFamilyQualifier]
+ with Serializable {
+ override def compareTo(o: KeyFamilyQualifier): Int = {
+ var result = Bytes.compareTo(rowKey, o.rowKey)
+ if (result == 0) {
+ result = Bytes.compareTo(family, o.family)
+ if (result == 0) result = Bytes.compareTo(qualifier, o.qualifier)
+ }
+ result
+ }
+ override def toString: String = {
+ Bytes.toString(rowKey) + ":" + Bytes.toString(family) + ":" + Bytes.toString(qualifier)
+ }
+}
diff --git a/spark4/hbase-spark4/src/main/scala/org/apache/hadoop/hbase/spark/Logging.scala b/spark4/hbase-spark4/src/main/scala/org/apache/hadoop/hbase/spark/Logging.scala
new file mode 100644
index 00000000..18636313
--- /dev/null
+++ b/spark4/hbase-spark4/src/main/scala/org/apache/hadoop/hbase/spark/Logging.scala
@@ -0,0 +1,121 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.hadoop.hbase.spark
+
+import org.apache.yetus.audience.InterfaceAudience
+import org.slf4j.Logger
+import org.slf4j.LoggerFactory
+import org.slf4j.impl.StaticLoggerBinder
+
+/**
+ * Utility trait for classes that want to log data. Creates a SLF4J logger for the class and allows
+ * logging messages at different levels using methods that only evaluate parameters lazily if the
+ * log level is enabled.
+ * Logging is private in Spark 2.0.
+ * This is to isolate incompatibilities across Spark releases.
+ */
+@InterfaceAudience.Private
+trait Logging {
+
+ // Make the log field transient so that objects with Logging can
+ // be serialized and used on another machine
+ @transient private var log_ : Logger = null
+
+ // Method to get the logger name for this object
+ protected def logName = {
+ // Ignore trailing $'s in the class names for Scala objects
+ this.getClass.getName.stripSuffix("$")
+ }
+
+ // Method to get or create the logger for this object
+ protected def log: Logger = {
+ if (log_ == null) {
+ initializeLogIfNecessary(false)
+ log_ = LoggerFactory.getLogger(logName)
+ }
+ log_
+ }
+
+ // Log methods that take only a String
+ protected def logInfo(msg: => String) {
+ if (log.isInfoEnabled) log.info(msg)
+ }
+
+ protected def logDebug(msg: => String) {
+ if (log.isDebugEnabled) log.debug(msg)
+ }
+
+ protected def logTrace(msg: => String) {
+ if (log.isTraceEnabled) log.trace(msg)
+ }
+
+ protected def logWarning(msg: => String) {
+ if (log.isWarnEnabled) log.warn(msg)
+ }
+
+ protected def logError(msg: => String) {
+ if (log.isErrorEnabled) log.error(msg)
+ }
+
+ // Log methods that take Throwables (Exceptions/Errors) too
+ protected def logInfo(msg: => String, throwable: Throwable) {
+ if (log.isInfoEnabled) log.info(msg, throwable)
+ }
+
+ protected def logDebug(msg: => String, throwable: Throwable) {
+ if (log.isDebugEnabled) log.debug(msg, throwable)
+ }
+
+ protected def logTrace(msg: => String, throwable: Throwable) {
+ if (log.isTraceEnabled) log.trace(msg, throwable)
+ }
+
+ protected def logWarning(msg: => String, throwable: Throwable) {
+ if (log.isWarnEnabled) log.warn(msg, throwable)
+ }
+
+ protected def logError(msg: => String, throwable: Throwable) {
+ if (log.isErrorEnabled) log.error(msg, throwable)
+ }
+
+ protected def initializeLogIfNecessary(isInterpreter: Boolean): Unit = {
+ if (!Logging.initialized) {
+ Logging.initLock.synchronized {
+ if (!Logging.initialized) {
+ initializeLogging(isInterpreter)
+ }
+ }
+ }
+ }
+
+ private def initializeLogging(isInterpreter: Boolean): Unit = {
+ // Don't use a logger in here, as this is itself occurring during initialization of a logger
+ // If Log4j 1.2 is being used, but is not initialized, load a default properties file
+ val binderClass = StaticLoggerBinder.getSingleton.getLoggerFactoryClassStr
+ Logging.initialized = true
+
+ // Force a call into slf4j to initialize it. Avoids this happening from multiple threads
+ // and triggering this: http://mailman.qos.ch/pipermail/slf4j-dev/2010-April/002956.html
+ log
+ }
+}
+
+private object Logging {
+ @volatile private var initialized = false
+ val initLock = new Object()
+}
diff --git a/spark4/hbase-spark4/src/main/scala/org/apache/hadoop/hbase/spark/NewHBaseRDD.scala b/spark4/hbase-spark4/src/main/scala/org/apache/hadoop/hbase/spark/NewHBaseRDD.scala
new file mode 100644
index 00000000..aeb502d9
--- /dev/null
+++ b/spark4/hbase-spark4/src/main/scala/org/apache/hadoop/hbase/spark/NewHBaseRDD.scala
@@ -0,0 +1,40 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.hadoop.hbase.spark
+
+import org.apache.hadoop.conf.Configuration
+import org.apache.hadoop.mapreduce.InputFormat
+import org.apache.spark.{InterruptibleIterator, Partition, SparkContext, TaskContext}
+import org.apache.spark.rdd.NewHadoopRDD
+import org.apache.yetus.audience.InterfaceAudience
+
+@InterfaceAudience.Public
+class NewHBaseRDD[K, V](
+ @transient val sc: SparkContext,
+ @transient val inputFormatClass: Class[_ <: InputFormat[K, V]],
+ @transient val keyClass: Class[K],
+ @transient val valueClass: Class[V],
+ @transient private val __conf: Configuration,
+ val hBaseContext: HBaseContext)
+ extends NewHadoopRDD(sc, inputFormatClass, keyClass, valueClass, __conf) {
+
+ override def compute(theSplit: Partition, context: TaskContext): InterruptibleIterator[(K, V)] = {
+ hBaseContext.applyCreds()
+ super.compute(theSplit, context)
+ }
+}
diff --git a/spark4/hbase-spark4/src/main/scala/org/apache/hadoop/hbase/spark/datasources/Bound.scala b/spark4/hbase-spark4/src/main/scala/org/apache/hadoop/hbase/spark/datasources/Bound.scala
new file mode 100644
index 00000000..ef0d287c
--- /dev/null
+++ b/spark4/hbase-spark4/src/main/scala/org/apache/hadoop/hbase/spark/datasources/Bound.scala
@@ -0,0 +1,142 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.hadoop.hbase.spark.datasources
+
+import org.apache.hadoop.hbase.spark.hbase._
+import org.apache.yetus.audience.InterfaceAudience
+
+/**
+ * The Bound represents the boundary for the scan
+ *
+ * @param b The byte array of the bound
+ * @param inc inclusive or not.
+ */
+@InterfaceAudience.Private
+case class Bound(b: Array[Byte], inc: Boolean)
+// The non-overlapping ranges we need to scan; if lower is equal to upper, it is a get request
+
+@InterfaceAudience.Private
+case class Range(lower: Option[Bound], upper: Option[Bound])
+
+@InterfaceAudience.Private
+object Range {
+ def apply(region: HBaseRegion): Range = {
+ Range(
+ region.start.map(Bound(_, true)),
+ if (region.end.get.size == 0) {
+ None
+ } else {
+ region.end.map((Bound(_, false)))
+ })
+ }
+}
+
+@InterfaceAudience.Private
+object Ranges {
+ // We assume that
+ // 1. r.lower.inc is true, and r.upper.inc is false
+ // 2. for each range in rs, its upper.inc is false
+ def and(r: Range, rs: Seq[Range]): Seq[Range] = {
+ rs.flatMap {
+ s =>
+ val lower = s.lower
+ .map {
+ x =>
+ // the scan has lower bound
+ r.lower
+ .map {
+ y =>
+ // the region has lower bound
+ if (ord.compare(x.b, y.b) < 0) {
+ // scan lower bound is smaller than region server lower bound
+ Some(y)
+ } else {
+ // scan low bound is greater or equal to region server lower bound
+ Some(x)
+ }
+ }
+ .getOrElse(Some(x))
+ }
+ .getOrElse(r.lower)
+
+ val upper = s.upper
+ .map {
+ x =>
+ // the scan has upper bound
+ r.upper
+ .map {
+ y =>
+ // the region has upper bound
+ if (ord.compare(x.b, y.b) >= 0) {
+ // scan upper bound is greater than or equal to the region upper bound;
+ // the region stop row is exclusive, so using it here is fine
+ Some(y)
+ } else {
+ // scan upper bound is strictly less than the region upper bound
+ Some(x)
+ }
+ }
+ .getOrElse(Some(x))
+ }
+ .getOrElse(r.upper)
+
+ val c = lower
+ .map {
+ case x =>
+ upper
+ .map {
+ case y =>
+ ord.compare(x.b, y.b)
+ }
+ .getOrElse(-1)
+ }
+ .getOrElse(-1)
+ if (c < 0) {
+ Some(Range(lower, upper))
+ } else {
+ None
+ }
+ }.seq
+ }
+}
+
+@InterfaceAudience.Private
+object Points {
+ def and(r: Range, ps: Seq[Array[Byte]]): Seq[Array[Byte]] = {
+ ps.flatMap {
+ p =>
+ if (ord.compare(r.lower.get.b, p) <= 0) {
+ // if region lower bound is less or equal to the point
+ if (r.upper.isDefined) {
+ // if region upper bound is defined
+ if (ord.compare(r.upper.get.b, p) > 0) {
+ // if the upper bound is greater than the point (because upper bound is exclusive)
+ Some(p)
+ } else {
+ None
+ }
+ } else {
+ // if the region upper bound is not defined (infinity)
+ Some(p)
+ }
+ } else {
+ None
+ }
+ }
+ }
+}
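
An illustrative sketch of the intersection semantics above (these classes are `@InterfaceAudience.Private`, so this is for understanding only, not a supported API); the row keys are made-up placeholders.

```
import org.apache.hadoop.hbase.spark.datasources.{Bound, Points, Range, Ranges}
import org.apache.hadoop.hbase.util.Bytes

object RangeIntersectionExample {
  def main(args: Array[String]): Unit = {
    // A region covering [b, f): lower bound inclusive, upper bound exclusive.
    val region = Range(
      Some(Bound(Bytes.toBytes("b"), inc = true)),
      Some(Bound(Bytes.toBytes("f"), inc = false)))

    // A query range [a, d); intersecting it with the region yields [b, d).
    val query = Range(
      Some(Bound(Bytes.toBytes("a"), inc = true)),
      Some(Bound(Bytes.toBytes("d"), inc = false)))
    Ranges.and(region, Seq(query)).foreach {
      r => println(Bytes.toString(r.lower.get.b) + " .. " + Bytes.toString(r.upper.get.b))
    }

    // Point pruning: only "c" falls inside [b, f).
    Points.and(region, Seq(Bytes.toBytes("a"), Bytes.toBytes("c")))
      .foreach(p => println(Bytes.toString(p)))
  }
}
```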
diff --git a/spark4/hbase-spark4/src/main/scala/org/apache/hadoop/hbase/spark/datasources/DataTypeParserWrapper.scala b/spark4/hbase-spark4/src/main/scala/org/apache/hadoop/hbase/spark/datasources/DataTypeParserWrapper.scala
new file mode 100644
index 00000000..936276da
--- /dev/null
+++ b/spark4/hbase-spark4/src/main/scala/org/apache/hadoop/hbase/spark/datasources/DataTypeParserWrapper.scala
@@ -0,0 +1,32 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.hadoop.hbase.spark.datasources
+
+import org.apache.spark.sql.catalyst.parser.CatalystSqlParser
+import org.apache.spark.sql.types.DataType
+import org.apache.yetus.audience.InterfaceAudience
+
+@InterfaceAudience.Private
+trait DataTypeParser {
+ def parse(dataTypeString: String): DataType
+}
+
+@InterfaceAudience.Private
+object DataTypeParserWrapper extends DataTypeParser {
+ def parse(dataTypeString: String): DataType = CatalystSqlParser.parseDataType(dataTypeString)
+}
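
A small sketch of what the wrapper does (assumed behavior of Spark's `CatalystSqlParser` for a simple DDL type string):

```
import org.apache.hadoop.hbase.spark.datasources.DataTypeParserWrapper
import org.apache.spark.sql.types.{ArrayType, IntegerType}

object DataTypeParserExample {
  def main(args: Array[String]): Unit = {
    // Catalog "type" entries such as "string", "double" or "array<int>" are parsed
    // by Spark's CatalystSqlParser under the hood.
    val dt = DataTypeParserWrapper.parse("array<int>")
    assert(dt == ArrayType(IntegerType, containsNull = true))
    println(dt.catalogString) // array<int>
  }
}
```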
diff --git a/spark4/hbase-spark4/src/main/scala/org/apache/hadoop/hbase/spark/datasources/HBaseResources.scala b/spark4/hbase-spark4/src/main/scala/org/apache/hadoop/hbase/spark/datasources/HBaseResources.scala
new file mode 100644
index 00000000..a179d514
--- /dev/null
+++ b/spark4/hbase-spark4/src/main/scala/org/apache/hadoop/hbase/spark/datasources/HBaseResources.scala
@@ -0,0 +1,173 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.hadoop.hbase.spark.datasources
+
+import org.apache.hadoop.hbase.TableName
+import org.apache.hadoop.hbase.client._
+import org.apache.hadoop.hbase.spark.{HBaseConnectionCache, HBaseConnectionKey, HBaseRelation, SmartConnection}
+import org.apache.yetus.audience.InterfaceAudience
+import scala.language.implicitConversions
+
+// Resource and ReferencedResource are defined for extensibility,
+// e.g., to consolidate scan and bulkGet in future work.
+
+// The user has to invoke release explicitly to release the resource,
+// and potentially its parent resources.
+@InterfaceAudience.Private
+trait Resource {
+ def release(): Unit
+}
+
+@InterfaceAudience.Private
+case class ScanResource(tbr: TableResource, rs: ResultScanner) extends Resource {
+ def release() {
+ rs.close()
+ tbr.release()
+ }
+}
+
+@InterfaceAudience.Private
+case class GetResource(tbr: TableResource, rs: Array[Result]) extends Resource {
+ def release() {
+ tbr.release()
+ }
+}
+
+@InterfaceAudience.Private
+trait ReferencedResource {
+ var count: Int = 0
+ def init(): Unit
+ def destroy(): Unit
+ def acquire() = synchronized {
+ try {
+ count += 1
+ if (count == 1) {
+ init()
+ }
+ } catch {
+ case e: Throwable =>
+ release()
+ throw e
+ }
+ }
+
+ def release() = synchronized {
+ count -= 1
+ if (count == 0) {
+ destroy()
+ }
+ }
+
+ def releaseOnException[T](func: => T): T = {
+ acquire()
+ val ret = {
+ try {
+ func
+ } catch {
+ case e: Throwable =>
+ release()
+ throw e
+ }
+ }
+ ret
+ }
+}
+
+@InterfaceAudience.Private
+case class TableResource(relation: HBaseRelation) extends ReferencedResource {
+ var connection: SmartConnection = _
+ var table: Table = _
+
+ override def init(): Unit = {
+ connection = HBaseConnectionCache.getConnection(relation.hbaseConf)
+ table = connection.getTable(TableName.valueOf(relation.tableName))
+ }
+
+ override def destroy(): Unit = {
+ if (table != null) {
+ table.close()
+ table = null
+ }
+ if (connection != null) {
+ connection.close()
+ connection = null
+ }
+ }
+
+ def getScanner(scan: Scan): ScanResource = releaseOnException {
+ ScanResource(this, table.getScanner(scan))
+ }
+
+ def get(list: java.util.List[org.apache.hadoop.hbase.client.Get]) = releaseOnException {
+ GetResource(this, table.get(list))
+ }
+}
+
+@InterfaceAudience.Private
+case class RegionResource(relation: HBaseRelation) extends ReferencedResource {
+ var connection: SmartConnection = _
+ var rl: RegionLocator = _
+ val regions = releaseOnException {
+ val keys = rl.getStartEndKeys
+ keys.getFirst
+ .zip(keys.getSecond)
+ .zipWithIndex
+ .map(
+ x =>
+ HBaseRegion(
+ x._2,
+ Some(x._1._1),
+ Some(x._1._2),
+ Some(rl.getRegionLocation(x._1._1).getHostname)))
+ }
+
+ override def init(): Unit = {
+ connection = HBaseConnectionCache.getConnection(relation.hbaseConf)
+ rl = connection.getRegionLocator(TableName.valueOf(relation.tableName))
+ }
+
+ override def destroy(): Unit = {
+ if (rl != null) {
+ rl.close()
+ rl = null
+ }
+ if (connection != null) {
+ connection.close()
+ connection = null
+ }
+ }
+}
+
+@InterfaceAudience.Private
+object HBaseResources {
+ implicit def ScanResToScan(sr: ScanResource): ResultScanner = {
+ sr.rs
+ }
+
+ implicit def GetResToResult(gr: GetResource): Array[Result] = {
+ gr.rs
+ }
+
+ implicit def TableResToTable(tr: TableResource): Table = {
+ tr.table
+ }
+
+ implicit def RegionResToRegions(rr: RegionResource): Seq[HBaseRegion] = {
+ rr.regions
+ }
+}
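
A sketch of the `ReferencedResource` reference-counting contract, using a hypothetical `CountingResource` that exists only for illustration:

```
import org.apache.hadoop.hbase.spark.datasources.ReferencedResource

// A hypothetical resource used only to illustrate the contract: init() runs on the
// first acquire, destroy() on the last release, and releaseOnException() drops the
// reference again if the wrapped block throws.
class CountingResource extends ReferencedResource {
  override def init(): Unit = println("opened underlying resource")
  override def destroy(): Unit = println("closed underlying resource")

  def use[T](body: => T): T = releaseOnException(body)
}

object ReferencedResourceExample {
  def main(args: Array[String]): Unit = {
    val res = new CountingResource
    val value = res.use { 40 + 2 } // acquire -> init -> run body (reference kept on success)
    res.release()                  // count drops to 0 -> destroy
    println(value)
  }
}
```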
diff --git a/spark4/hbase-spark4/src/main/scala/org/apache/hadoop/hbase/spark/datasources/HBaseSparkConf.scala b/spark4/hbase-spark4/src/main/scala/org/apache/hadoop/hbase/spark/datasources/HBaseSparkConf.scala
new file mode 100644
index 00000000..77c15316
--- /dev/null
+++ b/spark4/hbase-spark4/src/main/scala/org/apache/hadoop/hbase/spark/datasources/HBaseSparkConf.scala
@@ -0,0 +1,76 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.hadoop.hbase.spark.datasources
+
+import org.apache.yetus.audience.InterfaceAudience;
+
+/**
+ * These are the HBase-Spark configuration options. Users can either set them in SparkConf,
+ * where they take effect globally, or configure them per table, which overrides the values
+ * set in SparkConf. If an option is not set, its default value takes effect.
+ */
+@InterfaceAudience.Public
+object HBaseSparkConf {
+
+ /**
+ * Set to false to disable server-side caching of blocks for this scan;
+ * false by default, since full table scans generate too much block cache (BC) churn.
+ */
+ val QUERY_CACHEBLOCKS = "hbase.spark.query.cacheblocks"
+ val DEFAULT_QUERY_CACHEBLOCKS = false
+
+ /** The number of rows for caching that will be passed to scan. */
+ val QUERY_CACHEDROWS = "hbase.spark.query.cachedrows"
+
+ /** Set the maximum number of values to return for each call to next() in scan. */
+ val QUERY_BATCHSIZE = "hbase.spark.query.batchsize"
+
+ /** The number of Gets grouped into each bulk get sent to HBase. */
+ val BULKGET_SIZE = "hbase.spark.bulkget.size"
+ val DEFAULT_BULKGET_SIZE = 1000
+
+ /** Set to specify the location of the HBase configuration file. */
+ val HBASE_CONFIG_LOCATION = "hbase.spark.config.location"
+
+ /** Set to specify whether to create a new HBaseContext or reuse the latest cached one. */
+ val USE_HBASECONTEXT = "hbase.spark.use.hbasecontext"
+ val DEFAULT_USE_HBASECONTEXT = true
+
+ /** Push the filter down to the data source engine to improve query performance. */
+ val PUSHDOWN_COLUMN_FILTER = "hbase.spark.pushdown.columnfilter"
+ val DEFAULT_PUSHDOWN_COLUMN_FILTER = true
+
+ /** Class name of the encoder, which encodes Spark data types into HBase bytes. */
+ val QUERY_ENCODER = "hbase.spark.query.encoder"
+ val DEFAULT_QUERY_ENCODER = classOf[NaiveEncoder].getCanonicalName
+
+ /** The timestamp used to filter columns with a specific timestamp. */
+ val TIMESTAMP = "hbase.spark.query.timestamp"
+
+ /** The starting timestamp used to filter columns with a specific range of versions. */
+ val TIMERANGE_START = "hbase.spark.query.timerange.start"
+
+ /** The ending timestamp used to filter columns with a specific range of versions. */
+ val TIMERANGE_END = "hbase.spark.query.timerange.end"
+
+ /** The maximum number of versions to return. */
+ val MAX_VERSIONS = "hbase.spark.query.maxVersions"
+
+ /** Delay, in milliseconds, before closing an hbase-spark connection that has no remaining references. */
+ val DEFAULT_CONNECTION_CLOSE_DELAY = 10 * 60 * 1000
+}
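
A hedged usage sketch showing how these keys are typically passed as per-table read options; the data source name `org.apache.hadoop.hbase.spark` and the catalog JSON (table `t1`, column family `cf1`) are assumptions for illustration only.

```
import org.apache.hadoop.hbase.spark.datasources.{HBaseSparkConf, HBaseTableCatalog}
import org.apache.spark.sql.SparkSession

object HBaseSparkConfExample {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().appName("HBaseSparkConfExample").getOrCreate()

    // Hypothetical catalog for a table "t1" with a single column family "cf1".
    val catalog =
      """{
        |"table":{"namespace":"default", "name":"t1"},
        |"rowkey":"key",
        |"columns":{
        |"col0":{"cf":"rowkey", "col":"key", "type":"string"},
        |"col1":{"cf":"cf1", "col":"c1", "type":"double"}
        |}
        |}""".stripMargin

    // Per-table options override anything set globally in SparkConf.
    val df = spark.read
      .format("org.apache.hadoop.hbase.spark")
      .option(HBaseTableCatalog.tableCatalog, catalog)
      .option(HBaseSparkConf.QUERY_CACHEDROWS, "1000")
      .option(HBaseSparkConf.PUSHDOWN_COLUMN_FILTER, "true")
      .load()
    df.show()
  }
}
```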
diff --git a/spark4/hbase-spark4/src/main/scala/org/apache/hadoop/hbase/spark/datasources/HBaseTableCatalog.scala b/spark4/hbase-spark4/src/main/scala/org/apache/hadoop/hbase/spark/datasources/HBaseTableCatalog.scala
new file mode 100644
index 00000000..d88d306c
--- /dev/null
+++ b/spark4/hbase-spark4/src/main/scala/org/apache/hadoop/hbase/spark/datasources/HBaseTableCatalog.scala
@@ -0,0 +1,409 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.hadoop.hbase.spark.datasources
+
+import org.apache.avro.Schema
+import org.apache.hadoop.hbase.spark.{Logging, SchemaConverters}
+import org.apache.hadoop.hbase.util.Bytes
+import org.apache.spark.sql.types._
+import org.apache.yetus.audience.InterfaceAudience
+import org.json4s.DefaultFormats
+import org.json4s.Formats
+import org.json4s.jackson.JsonMethods
+import scala.collection.mutable
+
+// The definition of each column cell, which may be of a composite type
+// TODO: add avro support
+@InterfaceAudience.Private
+case class Field(
+ colName: String,
+ cf: String,
+ col: String,
+ sType: Option[String] = None,
+ avroSchema: Option[String] = None,
+ serdes: Option[SerDes] = None,
+ len: Int = -1)
+ extends Logging {
+ override def toString = s"$colName $cf $col"
+ val isRowKey = cf == HBaseTableCatalog.rowKey
+ var start: Int = _
+ def schema: Option[Schema] = avroSchema.map {
+ x =>
+ logDebug(s"avro: $x")
+ val p = new Schema.Parser
+ p.parse(x)
+ }
+
+ lazy val exeSchema = schema
+
+ // converter from avro to catalyst structure
+ lazy val avroToCatalyst: Option[Any => Any] = {
+ schema.map(SchemaConverters.createConverterToSQL(_))
+ }
+
+ // converter from catalyst to avro
+ lazy val catalystToAvro: (Any) => Any = {
+ SchemaConverters.createConverterToAvro(dt, colName, "recordNamespace")
+ }
+
+ def cfBytes: Array[Byte] = {
+ if (isRowKey) {
+ Bytes.toBytes("")
+ } else {
+ Bytes.toBytes(cf)
+ }
+ }
+ def colBytes: Array[Byte] = {
+ if (isRowKey) {
+ Bytes.toBytes("key")
+ } else {
+ Bytes.toBytes(col)
+ }
+ }
+
+ val dt = {
+ sType.map(DataTypeParserWrapper.parse(_)).getOrElse {
+ schema.map { x => SchemaConverters.toSqlType(x).dataType }.get
+ }
+ }
+
+ var length: Int = {
+ if (len == -1) {
+ dt match {
+ case BinaryType | StringType => -1
+ case BooleanType => Bytes.SIZEOF_BOOLEAN
+ case ByteType => 1
+ case DoubleType => Bytes.SIZEOF_DOUBLE
+ case FloatType => Bytes.SIZEOF_FLOAT
+ case IntegerType => Bytes.SIZEOF_INT
+ case LongType => Bytes.SIZEOF_LONG
+ case ShortType => Bytes.SIZEOF_SHORT
+ case _ => -1
+ }
+ } else {
+ len
+ }
+
+ }
+
+ override def equals(other: Any): Boolean = other match {
+ case that: Field =>
+ colName == that.colName && cf == that.cf && col == that.col
+ case _ => false
+ }
+}
+
+// The row key definition, with each key referring to a col defined in Field, e.g.,
+// key1:key2:key3
+@InterfaceAudience.Private
+case class RowKey(k: String) {
+ val keys = k.split(":")
+ var fields: Seq[Field] = _
+ var varLength = false
+ def length = {
+ if (varLength) {
+ -1
+ } else {
+ fields.foldLeft(0) {
+ case (x, y) =>
+ x + y.length
+ }
+ }
+ }
+}
+// The map between the column presented to Spark and the HBase field
+@InterfaceAudience.Private
+case class SchemaMap(map: mutable.HashMap[String, Field]) {
+ def toFields = map.map {
+ case (name, field) =>
+ StructField(name, field.dt)
+ }.toSeq
+
+ def fields = map.values
+
+ def getField(name: String) = map(name)
+}
+
+// The definition of the mapping between the Spark relation schema and the HBase table
+@InterfaceAudience.Private
+case class HBaseTableCatalog(
+ namespace: String,
+ name: String,
+ row: RowKey,
+ sMap: SchemaMap,
+ @transient params: Map[String, String])
+ extends Logging {
+ def toDataType = StructType(sMap.toFields)
+ def getField(name: String) = sMap.getField(name)
+ def getRowKey: Seq[Field] = row.fields
+ def getPrimaryKey = row.keys(0)
+ def getColumnFamilies = {
+ sMap.fields.map(_.cf).filter(_ != HBaseTableCatalog.rowKey).toSeq.distinct
+ }
+
+ def get(key: String) = params.get(key)
+
+ // Set up the start offset and length for each dimension of the row key at runtime.
+ def dynSetupRowKey(rowKey: Array[Byte]) {
+ logDebug(s"length: ${rowKey.length}")
+ if (row.varLength) {
+ var start = 0
+ row.fields.foreach {
+ f =>
+ logDebug(s"start: $start")
+ f.start = start
+ f.length = {
+ // If the length is not defined
+ if (f.length == -1) {
+ f.dt match {
+ case StringType =>
+ var pos = rowKey.indexOf(HBaseTableCatalog.delimiter, start)
+ if (pos == -1 || pos > rowKey.length) {
+ // this is at the last dimension
+ pos = rowKey.length
+ }
+ pos - start
+ // We don't know the length; assume it extends to the end of the rowkey.
+ case _ => rowKey.length - start
+ }
+ } else {
+ f.length
+ }
+ }
+ start += f.length
+ }
+ }
+ }
+
+ def initRowKey = {
+ val fields = sMap.fields.filter(_.cf == HBaseTableCatalog.rowKey)
+ row.fields = row.keys.flatMap(n => fields.find(_.col == n))
+ // The length is determined at run time if it is string or binary and the length is undefined.
+ if (row.fields.filter(_.length == -1).isEmpty) {
+ var start = 0
+ row.fields.foreach {
+ f =>
+ f.start = start
+ start += f.length
+ }
+ } else {
+ row.varLength = true
+ }
+ }
+ initRowKey
+}
+
+@InterfaceAudience.Public
+object HBaseTableCatalog {
+ // If defined and larger than 3, a new table will be created with the number of regions specified.
+ val newTable = "newtable"
+ // The json string specifying hbase catalog information
+ val regionStart = "regionStart"
+ val defaultRegionStart = "aaaaaaa"
+ val regionEnd = "regionEnd"
+ val defaultRegionEnd = "zzzzzzz"
+ val tableCatalog = "catalog"
+ // The row key with format key1:key2 specifying table row key
+ val rowKey = "rowkey"
+ // The key for hbase table whose value specify namespace and table name
+ val table = "table"
+ // The namespace of hbase table
+ val nameSpace = "namespace"
+ // The name of hbase table
+ val tableName = "name"
+ // The name of columns in hbase catalog
+ val columns = "columns"
+ val cf = "cf"
+ val col = "col"
+ val `type` = "type"
+ // the name of avro schema json string
+ val avro = "avro"
+ val delimiter: Byte = 0
+ val serdes = "serdes"
+ val length = "length"
+
+ /**
+ * User-provided table schema definition, e.g.
+ * {"table":{"namespace":"default", "name":"tableName"}, "rowkey":"key1:key2",
+ * "columns":{"col1":{"cf":"cf1", "col":"col1", "type":"type1"},
+ * "col2":{"cf":"cf2", "col":"col2", "type":"type2"}}}
+ * Note that for any col in the rowkey, there has to be a corresponding col defined in columns.
+ */
+ def apply(params: Map[String, String]): HBaseTableCatalog = {
+ val parameters = convert(params)
+ // println(jString)
+ val jString = parameters(tableCatalog)
+ implicit val formats: Formats = DefaultFormats
+ val map = JsonMethods.parse(jString)
+ val tableMeta = map \ table
+ val nSpace = (tableMeta \ nameSpace).extractOrElse("default")
+ val tName = (tableMeta \ tableName).extract[String]
+ val cIter = (map \ columns).extract[Map[String, Map[String, String]]]
+ val schemaMap = mutable.HashMap.empty[String, Field]
+ cIter.foreach {
+ case (name, column) =>
+ val sd = {
+ column
+ .get(serdes)
+ .asInstanceOf[Option[String]]
+ .map(n => Class.forName(n).newInstance().asInstanceOf[SerDes])
+ }
+ val len = column.get(length).map(_.toInt).getOrElse(-1)
+ val sAvro = column.get(avro).map(parameters(_))
+ val f = Field(
+ name,
+ column.getOrElse(cf, rowKey),
+ column.get(col).get,
+ column.get(`type`),
+ sAvro,
+ sd,
+ len)
+ schemaMap.+=((name, f))
+ }
+ val rKey = RowKey((map \ rowKey).extract[String])
+ HBaseTableCatalog(nSpace, tName, rKey, SchemaMap(schemaMap), parameters)
+ }
+
+ val TABLE_KEY: String = "hbase.table"
+ val SCHEMA_COLUMNS_MAPPING_KEY: String = "hbase.columns.mapping"
+
+ /* For backward compatibility: convert the old definition to the new json-based definition, formatted as below.
+ val catalog = s"""{
+ |"table":{"namespace":"default", "name":"htable"},
+ |"rowkey":"key1:key2",
+ |"columns":{
+ |"col1":{"cf":"rowkey", "col":"key1", "type":"string"},
+ |"col2":{"cf":"rowkey", "col":"key2", "type":"double"},
+ |"col3":{"cf":"cf1", "col":"col2", "type":"binary"},
+ |"col4":{"cf":"cf1", "col":"col3", "type":"timestamp"},
+ |"col5":{"cf":"cf1", "col":"col4", "type":"double", "serdes":"${classOf[DoubleSerDes].getName}"},
+ |"col6":{"cf":"cf1", "col":"col5", "type":"$map"},
+ |"col7":{"cf":"cf1", "col":"col6", "type":"$array"},
+ |"col8":{"cf":"cf1", "col":"col7", "type":"$arrayMap"}
+ |}
+ |}""".stripMargin
+ */
+ @deprecated("Please use new json format to define HBaseCatalog")
+ // TODO: There is no need to deprecate since this is the first release.
+ def convert(parameters: Map[String, String]): Map[String, String] = {
+ val nsTableName = parameters.get(TABLE_KEY).getOrElse(null)
+ // if the hbase.table is not defined, we assume it is json format already.
+ if (nsTableName == null) return parameters
+ val tableParts = nsTableName.trim.split(':')
+ val tableNamespace = if (tableParts.length == 1) {
+ "default"
+ } else if (tableParts.length == 2) {
+ tableParts(0)
+ } else {
+ throw new IllegalArgumentException(
+ "Invalid table name '" + nsTableName +
+ "' should be ':' or '' ")
+ }
+ val tableName = tableParts(tableParts.length - 1)
+ val schemaMappingString = parameters.getOrElse(SCHEMA_COLUMNS_MAPPING_KEY, "")
+ import scala.collection.JavaConverters._
+ val schemaMap = generateSchemaMappingMap(schemaMappingString).asScala.map(
+ _._2.asInstanceOf[SchemaQualifierDefinition])
+
+ val rowkey = schemaMap
+ .filter {
+ _.columnFamily == "rowkey"
+ }
+ .map(_.columnName)
+ val cols = schemaMap.map {
+ x =>
+ s""""${x.columnName}":{"cf":"${x.columnFamily}", "col":"${x.qualifier}", "type":"${x.colType}"}""".stripMargin
+ }
+ val jsonCatalog =
+ s"""{
+ |"table":{"namespace":"${tableNamespace}", "name":"${tableName}"},
+ |"rowkey":"${rowkey.mkString(":")}",
+ |"columns":{
+ |${cols.mkString(",")}
+ |}
+ |}
+ """.stripMargin
+ parameters ++ Map(HBaseTableCatalog.tableCatalog -> jsonCatalog)
+ }
+
+ /**
+ * Reads the SCHEMA_COLUMNS_MAPPING_KEY and converts it to a map of
+ * SchemaQualifierDefinitions with the original sql column name as the key
+ *
+ * @param schemaMappingString The schema mapping string from the SparkSQL map
+ * @return A map of definitions keyed by the SparkSQL column name
+ */
+ @InterfaceAudience.Private
+ def generateSchemaMappingMap(
+ schemaMappingString: String): java.util.HashMap[String, SchemaQualifierDefinition] = {
+ println(schemaMappingString)
+ try {
+ val columnDefinitions = schemaMappingString.split(',')
+ val resultingMap = new java.util.HashMap[String, SchemaQualifierDefinition]()
+ columnDefinitions.map(
+ cd => {
+ val parts = cd.trim.split(' ')
+
+ // Make sure we get three parts
+ //
+ if (parts.length == 3) {
+ val hbaseDefinitionParts = if (parts(2).charAt(0) == ':') {
+ Array[String]("rowkey", parts(0))
+ } else {
+ parts(2).split(':')
+ }
+ resultingMap.put(
+ parts(0),
+ new SchemaQualifierDefinition(
+ parts(0),
+ parts(1),
+ hbaseDefinitionParts(0),
+ hbaseDefinitionParts(1)))
+ } else {
+ throw new IllegalArgumentException(
+ "Invalid value for schema mapping '" + cd +
+ "' should be ':' " +
+ "for columns and ' :' for rowKeys")
+ }
+ })
+ resultingMap
+ } catch {
+ case e: Exception =>
+ throw new IllegalArgumentException(
+ "Invalid value for " + SCHEMA_COLUMNS_MAPPING_KEY +
+ " '" +
+ schemaMappingString + "'",
+ e)
+ }
+ }
+}
+
+/**
+ * Construct that contains column data spanning SparkSQL and HBase
+ *
+ * @param columnName SparkSQL column name
+ * @param colType SparkSQL column type
+ * @param columnFamily HBase column family
+ * @param qualifier HBase qualifier name
+ */
+@InterfaceAudience.Private
+case class SchemaQualifierDefinition(
+ columnName: String,
+ colType: String,
+ columnFamily: String,
+ qualifier: String)
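
A sketch of both catalog styles handled above; the table and column names are hypothetical.

```
import org.apache.hadoop.hbase.spark.datasources.HBaseTableCatalog

object HBaseTableCatalogExample {
  def main(args: Array[String]): Unit = {
    // JSON-based definition (hypothetical table "htable").
    val catalog =
      """{
        |"table":{"namespace":"default", "name":"htable"},
        |"rowkey":"key1:key2",
        |"columns":{
        |"col1":{"cf":"rowkey", "col":"key1", "type":"string"},
        |"col2":{"cf":"rowkey", "col":"key2", "type":"double"},
        |"col3":{"cf":"cf1", "col":"c3", "type":"binary"}
        |}
        |}""".stripMargin
    val cat = HBaseTableCatalog(Map(HBaseTableCatalog.tableCatalog -> catalog))
    println(cat.getRowKey.map(_.colName)) // the row-key fields, col1 and col2
    println(cat.getColumnFamilies)        // the non-rowkey column families, here cf1
    println(cat.toDataType)               // the Spark StructType exposed to SQL

    // Legacy definition: convert() rewrites it into the JSON form shown above.
    val legacy = Map(
      HBaseTableCatalog.TABLE_KEY -> "default:htable",
      HBaseTableCatalog.SCHEMA_COLUMNS_MAPPING_KEY -> "col1 STRING :key, col3 BINARY cf1:c3")
    println(HBaseTableCatalog.convert(legacy)(HBaseTableCatalog.tableCatalog))
  }
}
```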
diff --git a/spark4/hbase-spark4/src/main/scala/org/apache/hadoop/hbase/spark/datasources/HBaseTableScanRDD.scala b/spark4/hbase-spark4/src/main/scala/org/apache/hadoop/hbase/spark/datasources/HBaseTableScanRDD.scala
new file mode 100644
index 00000000..82f7e8c4
--- /dev/null
+++ b/spark4/hbase-spark4/src/main/scala/org/apache/hadoop/hbase/spark/datasources/HBaseTableScanRDD.scala
@@ -0,0 +1,345 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.hadoop.hbase.spark.datasources
+
+import java.util.ArrayList
+import org.apache.hadoop.hbase.client._
+import org.apache.hadoop.hbase.spark._
+import org.apache.hadoop.hbase.spark.datasources.HBaseResources._
+import org.apache.hadoop.hbase.spark.hbase._
+import org.apache.hadoop.hbase.util.ShutdownHookManager
+import org.apache.spark.{Partition, SparkEnv, TaskContext}
+import org.apache.spark.rdd.RDD
+import org.apache.yetus.audience.InterfaceAudience
+import scala.collection.mutable
+
+@InterfaceAudience.Private
+class HBaseTableScanRDD(
+ relation: HBaseRelation,
+ val hbaseContext: HBaseContext,
+ @transient val filter: Option[SparkSQLPushDownFilter] = None,
+ val columns: Seq[Field] = Seq.empty)
+ extends RDD[Result](relation.sqlContext.sparkContext, Nil) {
+ private def sparkConf = SparkEnv.get.conf
+ @transient var ranges = Seq.empty[Range]
+ @transient var points = Seq.empty[Array[Byte]]
+ def addPoint(p: Array[Byte]) {
+ points :+= p
+ }
+
+ def addRange(r: ScanRange) = {
+ val lower = if (r.lowerBound != null && r.lowerBound.length > 0) {
+ Some(Bound(r.lowerBound, r.isLowerBoundEqualTo))
+ } else {
+ None
+ }
+ val upper = if (r.upperBound != null && r.upperBound.length > 0) {
+ if (!r.isUpperBoundEqualTo) {
+ Some(Bound(r.upperBound, false))
+ } else {
+
+ // HBase's stopRow is exclusive, so by default it does not honor isUpperBoundEqualTo.
+ // Append a trailing minimum byte (ByteMin) to the stopRow so the requested upper
+ // bound itself is still included in the scan.
+ val newArray = new Array[Byte](r.upperBound.length + 1)
+ System.arraycopy(r.upperBound, 0, newArray, 0, r.upperBound.length)
+
+ // Append the minimum byte value
+ newArray(r.upperBound.length) = ByteMin
+ Some(Bound(newArray, false))
+ }
+ } else {
+ None
+ }
+ ranges :+= Range(lower, upper)
+ }
+
+ override def getPartitions: Array[Partition] = {
+ val regions = RegionResource(relation)
+ var idx = 0
+ logDebug(s"There are ${regions.size} regions")
+ val ps = regions.flatMap {
+ x =>
+ val rs = Ranges.and(Range(x), ranges)
+ val ps = Points.and(Range(x), points)
+ if (rs.size > 0 || ps.size > 0) {
+ if (log.isDebugEnabled) {
+ rs.foreach(x => logDebug(x.toString))
+ }
+ idx += 1
+ Some(
+ HBaseScanPartition(
+ idx - 1,
+ x,
+ rs,
+ ps,
+ SerializedFilter.toSerializedTypedFilter(filter)))
+ } else {
+ None
+ }
+ }.toArray
+ if (log.isDebugEnabled) {
+ logDebug(s"Partitions: ${ps.size}");
+ ps.foreach(x => logDebug(x.toString))
+ }
+ regions.release()
+ ShutdownHookManager.affixShutdownHook(
+ new Thread() {
+ override def run() {
+ HBaseConnectionCache.close()
+ }
+ },
+ 0)
+ ps.asInstanceOf[Array[Partition]]
+ }
+
+ override def getPreferredLocations(split: Partition): Seq[String] = {
+ split
+ .asInstanceOf[HBaseScanPartition]
+ .regions
+ .server
+ .map {
+ identity
+ }
+ .toSeq
+ }
+
+ private def buildGets(
+ tbr: TableResource,
+ g: Seq[Array[Byte]],
+ filter: Option[SparkSQLPushDownFilter],
+ columns: Seq[Field],
+ hbaseContext: HBaseContext): Iterator[Result] = {
+ g.grouped(relation.bulkGetSize).flatMap {
+ x =>
+ val gets = new ArrayList[Get](x.size)
+ val rowkeySet = new mutable.HashSet[String]()
+ x.foreach {
+ y =>
+ if (!rowkeySet.contains(y.mkString("Array(", ", ", ")"))) {
+ val g = new Get(y)
+ handleTimeSemantics(g)
+ columns.foreach {
+ d =>
+ if (!d.isRowKey) {
+ g.addColumn(d.cfBytes, d.colBytes)
+ }
+ }
+ filter.foreach(g.setFilter(_))
+ gets.add(g)
+ rowkeySet.add(y.mkString("Array(", ", ", ")"))
+ }
+ }
+ hbaseContext.applyCreds()
+ val tmp = tbr.get(gets)
+ rddResources.addResource(tmp)
+ toResultIterator(tmp)
+ }
+ }
+
+ private def toResultIterator(result: GetResource): Iterator[Result] = {
+ val iterator = new Iterator[Result] {
+ var idx = 0
+ var cur: Option[Result] = None
+ override def hasNext: Boolean = {
+ while (idx < result.length && cur.isEmpty) {
+ val r = result(idx)
+ idx += 1
+ if (!r.isEmpty) {
+ cur = Some(r)
+ }
+ }
+ if (cur.isEmpty) {
+ rddResources.release(result)
+ }
+ cur.isDefined
+ }
+ override def next(): Result = {
+ hasNext
+ val ret = cur.get
+ cur = None
+ ret
+ }
+ }
+ iterator
+ }
+
+ private def buildScan(
+ range: Range,
+ filter: Option[SparkSQLPushDownFilter],
+ columns: Seq[Field]): Scan = {
+ val scan = (range.lower, range.upper) match {
+ case (Some(Bound(a, b)), Some(Bound(c, d))) => new Scan(a, c)
+ case (None, Some(Bound(c, d))) => new Scan(Array[Byte](), c)
+ case (Some(Bound(a, b)), None) => new Scan(a)
+ case (None, None) => new Scan()
+ }
+ handleTimeSemantics(scan)
+
+ columns.foreach {
+ d =>
+ if (!d.isRowKey) {
+ scan.addColumn(d.cfBytes, d.colBytes)
+ }
+ }
+ scan.setCacheBlocks(relation.blockCacheEnable)
+ scan.setBatch(relation.batchNum)
+ scan.setCaching(relation.cacheSize)
+ filter.foreach(scan.setFilter(_))
+ scan
+ }
+ private def toResultIterator(scanner: ScanResource): Iterator[Result] = {
+ val iterator = new Iterator[Result] {
+ var cur: Option[Result] = None
+ override def hasNext: Boolean = {
+ if (cur.isEmpty) {
+ val r = scanner.next()
+ if (r == null) {
+ rddResources.release(scanner)
+ } else {
+ cur = Some(r)
+ }
+ }
+ cur.isDefined
+ }
+ override def next(): Result = {
+ hasNext
+ val ret = cur.get
+ cur = None
+ ret
+ }
+ }
+ iterator
+ }
+
+ lazy val rddResources = RDDResources(new mutable.HashSet[Resource]())
+
+ private def close() {
+ rddResources.release()
+ }
+
+ override def compute(split: Partition, context: TaskContext): Iterator[Result] = {
+ val partition = split.asInstanceOf[HBaseScanPartition]
+ val filter = SerializedFilter.fromSerializedFilter(partition.sf)
+ val scans = partition.scanRanges
+ .map(buildScan(_, filter, columns))
+ val tableResource = TableResource(relation)
+ context.addTaskCompletionListener[Unit](context => close())
+ val points = partition.points
+ val gIt: Iterator[Result] = {
+ if (points.isEmpty) {
+ Iterator.empty: Iterator[Result]
+ } else {
+ buildGets(tableResource, points, filter, columns, hbaseContext)
+ }
+ }
+ val rIts = scans
+ .map {
+ scan =>
+ hbaseContext.applyCreds()
+ val scanner = tableResource.getScanner(scan)
+ rddResources.addResource(scanner)
+ scanner
+ }
+ .map(toResultIterator(_))
+ .fold(Iterator.empty: Iterator[Result]) {
+ case (x, y) =>
+ x ++ y
+ } ++ gIt
+ ShutdownHookManager.affixShutdownHook(
+ new Thread() {
+ override def run() {
+ HBaseConnectionCache.close()
+ }
+ },
+ 0)
+ rIts
+ }
+
+ private def handleTimeSemantics(query: Query): Unit = {
+ // Set timestamp related values if present
+ (query, relation.timestamp, relation.minTimestamp, relation.maxTimestamp) match {
+ case (q: Scan, Some(ts), None, None) => q.setTimeStamp(ts)
+ case (q: Get, Some(ts), None, None) => q.setTimeStamp(ts)
+
+ case (q: Scan, None, Some(minStamp), Some(maxStamp)) => q.setTimeRange(minStamp, maxStamp)
+ case (q: Get, None, Some(minStamp), Some(maxStamp)) => q.setTimeRange(minStamp, maxStamp)
+
+ case (q, None, None, None) =>
+
+ case _ =>
+ throw new IllegalArgumentException(
+ s"Invalid combination of query/timestamp/time range provided. " +
+ s"timeStamp is: ${relation.timestamp.get}, minTimeStamp is: ${relation.minTimestamp.get}, " +
+ s"maxTimeStamp is: ${relation.maxTimestamp.get}")
+ }
+ if (relation.maxVersions.isDefined) {
+ query match {
+ case q: Scan => q.setMaxVersions(relation.maxVersions.get)
+ case q: Get => q.setMaxVersions(relation.maxVersions.get)
+ case _ => throw new IllegalArgumentException("Invalid query provided with maxVersions")
+ }
+ }
+ }
+}
+
+@InterfaceAudience.Private
+case class SerializedFilter(b: Option[Array[Byte]])
+
+object SerializedFilter {
+ def toSerializedTypedFilter(f: Option[SparkSQLPushDownFilter]): SerializedFilter = {
+ SerializedFilter(f.map(_.toByteArray))
+ }
+
+ def fromSerializedFilter(sf: SerializedFilter): Option[SparkSQLPushDownFilter] = {
+ sf.b.map(SparkSQLPushDownFilter.parseFrom(_))
+ }
+}
+
+@InterfaceAudience.Private
+private[hbase] case class HBaseRegion(
+ override val index: Int,
+ val start: Option[HBaseType] = None,
+ val end: Option[HBaseType] = None,
+ val server: Option[String] = None)
+ extends Partition
+
+@InterfaceAudience.Private
+private[hbase] case class HBaseScanPartition(
+ override val index: Int,
+ val regions: HBaseRegion,
+ val scanRanges: Seq[Range],
+ val points: Seq[Array[Byte]],
+ val sf: SerializedFilter)
+ extends Partition
+
+@InterfaceAudience.Private
+case class RDDResources(set: mutable.HashSet[Resource]) {
+ def addResource(s: Resource) {
+ set += s
+ }
+ def release() {
+ set.foreach(release(_))
+ }
+ def release(rs: Resource) {
+ try {
+ rs.release()
+ } finally {
+ set.remove(rs)
+ }
+ }
+}
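
A standalone sketch of the inclusive-upper-bound trick used by `addRange`: it appends a trailing `0x00` byte (assumed here to be what the `ByteMin` constant resolves to), which produces the smallest key that sorts strictly after the requested bound.

```
import org.apache.hadoop.hbase.util.Bytes

object InclusiveStopRowExample {
  def main(args: Array[String]): Unit = {
    val upper = Bytes.toBytes("row-10")

    // Append a trailing 0x00 byte to build the smallest key that sorts strictly
    // after the requested upper bound.
    val stopRow = new Array[Byte](upper.length + 1)
    System.arraycopy(upper, 0, stopRow, 0, upper.length)
    stopRow(upper.length) = 0

    // The original upper bound sorts strictly below the new stopRow, so a scan
    // with this exclusive stopRow still returns the row "row-10" itself.
    assert(Bytes.compareTo(upper, stopRow) < 0)
    println(Bytes.toStringBinary(stopRow)) // row-10\x00
  }
}
```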
diff --git a/spark4/hbase-spark4/src/main/scala/org/apache/hadoop/hbase/spark/datasources/JavaBytesEncoder.scala b/spark4/hbase-spark4/src/main/scala/org/apache/hadoop/hbase/spark/datasources/JavaBytesEncoder.scala
new file mode 100644
index 00000000..eac4feb6
--- /dev/null
+++ b/spark4/hbase-spark4/src/main/scala/org/apache/hadoop/hbase/spark/datasources/JavaBytesEncoder.scala
@@ -0,0 +1,120 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.hadoop.hbase.spark.datasources
+
+import org.apache.hadoop.hbase.HBaseInterfaceAudience
+import org.apache.hadoop.hbase.spark.Logging
+import org.apache.hadoop.hbase.spark.datasources.JavaBytesEncoder.JavaBytesEncoder
+import org.apache.spark.sql.types._
+import org.apache.yetus.audience.InterfaceAudience
+import org.apache.yetus.audience.InterfaceStability
+
+/**
+ * The range for a data type whose size is known. Whether the bound is inclusive
+ * or exclusive is undefined, and it is up to the caller to decide.
+ *
+ * @param low: the lower bound of the range.
+ * @param upper: the upper bound of the range.
+ */
+@InterfaceAudience.LimitedPrivate(Array(HBaseInterfaceAudience.SPARK))
+@InterfaceStability.Evolving
+case class BoundRange(low: Array[Byte], upper: Array[Byte])
+
+/**
+ * The class identifies the ranges for a Java primitive type. The caller needs
+ * to decide on its own whether each bound is inclusive or exclusive.
+ *
+ * @param less: the set of ranges for LessThan/LessOrEqualThan
+ * @param greater: the set of ranges for GreaterThan/GreaterThanOrEqualTo
+ * @param value: the byte array of the original value
+ */
+@InterfaceAudience.LimitedPrivate(Array(HBaseInterfaceAudience.SPARK))
+@InterfaceStability.Evolving
+case class BoundRanges(less: Array[BoundRange], greater: Array[BoundRange], value: Array[Byte])
+
+/**
+ * The trait supporting a plugin architecture for different encoders/decoders.
+ * encode is used for serializing a value of the given data type to a byte array,
+ * and filter is used to filter out unnecessary records.
+ */
+@InterfaceAudience.LimitedPrivate(Array(HBaseInterfaceAudience.SPARK))
+@InterfaceStability.Evolving
+trait BytesEncoder {
+ def encode(dt: DataType, value: Any): Array[Byte]
+
+ /**
+ * The function performing real filtering operations. The format of filterBytes depends on the
+ * implementation of the BytesEncoder.
+ *
+ * @param input: the current input byte array that needs to be filtered out
+ * @param offset1: the starting offset of the input byte array.
+ * @param length1: the length of the input byte array.
+ * @param filterBytes: the byte array provided by query condition.
+ * @param offset2: the starting offset in the filterBytes.
+ * @param length2: the length of the bytes in the filterBytes
+ * @param ops: The operation of the filter operator.
+ * @return true: the record satisfies the predicates
+ * false: the record does not satisfy the predicates.
+ */
+ def filter(
+ input: Array[Byte],
+ offset1: Int,
+ length1: Int,
+ filterBytes: Array[Byte],
+ offset2: Int,
+ length2: Int,
+ ops: JavaBytesEncoder): Boolean
+
+ /**
+ * Currently, this is used for partition pruning.
+ * For some codecs, the ordering of a Java primitive type may be inconsistent with
+ * the ordering of its byte-array representation, so predicates on such types may
+ * have to be split into multiple predicates.
+ *
+ * For example, with the naive codec some Java primitive types have to be split into
+ * multiple predicates whose union produces the correct result: the predicate
+ * "COLUMN < 2" is transformed into
+ * "0 <= COLUMN < 2 OR Integer.MIN_VALUE <= COLUMN <= -1".
+ */
+ def ranges(in: Any): Option[BoundRanges]
+}
+
+@InterfaceAudience.LimitedPrivate(Array(HBaseInterfaceAudience.SPARK))
+@InterfaceStability.Evolving
+object JavaBytesEncoder extends Enumeration with Logging {
+ type JavaBytesEncoder = Value
+ val Greater, GreaterEqual, Less, LessEqual, Equal, Unknown = Value
+
+ /**
+ * create the encoder/decoder
+ *
+ * @param clsName: the class name of the encoder/decoder class
+ * @return the instance of the encoder plugin.
+ */
+ def create(clsName: String): BytesEncoder = {
+ try {
+ Class.forName(clsName).newInstance.asInstanceOf[BytesEncoder]
+ } catch {
+ case _: Throwable =>
+ logWarning(s"$clsName cannot be initiated, falling back to naive encoder")
+ new NaiveEncoder()
+ }
+ }
+}
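
A minimal sketch exercising the plugin factory with the default `NaiveEncoder`; the literal values are arbitrary.

```
import org.apache.hadoop.hbase.spark.datasources.{JavaBytesEncoder, NaiveEncoder}
import org.apache.hadoop.hbase.util.Bytes
import org.apache.spark.sql.types.IntegerType

object BytesEncoderExample {
  def main(args: Array[String]): Unit = {
    // Instantiate the default encoder through the plugin factory.
    val encoder = JavaBytesEncoder.create(classOf[NaiveEncoder].getCanonicalName)

    // The serialized predicate value carries the type code in its first byte.
    val filterBytes = encoder.encode(IntegerType, 10)

    // A cell holding the raw int 7 satisfies the pushed-down predicate "< 10".
    val cell = Bytes.toBytes(7)
    val matches = encoder.filter(
      cell, 0, cell.length,
      filterBytes, 0, filterBytes.length,
      JavaBytesEncoder.Less)
    println(matches) // true
  }
}
```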
diff --git a/spark4/hbase-spark4/src/main/scala/org/apache/hadoop/hbase/spark/datasources/NaiveEncoder.scala b/spark4/hbase-spark4/src/main/scala/org/apache/hadoop/hbase/spark/datasources/NaiveEncoder.scala
new file mode 100644
index 00000000..b54d2797
--- /dev/null
+++ b/spark4/hbase-spark4/src/main/scala/org/apache/hadoop/hbase/spark/datasources/NaiveEncoder.scala
@@ -0,0 +1,303 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.hadoop.hbase.spark.datasources
+
+import org.apache.hadoop.hbase.spark.Logging
+import org.apache.hadoop.hbase.spark.datasources.JavaBytesEncoder.JavaBytesEncoder
+import org.apache.hadoop.hbase.spark.hbase._
+import org.apache.hadoop.hbase.util.Bytes
+import org.apache.spark.sql.types._
+import org.apache.spark.unsafe.types.UTF8String
+import org.apache.yetus.audience.InterfaceAudience
+
+/**
+ * This is the naive, non-order-preserving encoder/decoder.
+ * Because the ordering of Java primitive types is not always consistent with the
+ * ordering of their byte-array representations, the data type has to be passed in
+ * so that the filter can work correctly; this is done by writing the type into the
+ * first byte of the serialized array.
+ */
+@InterfaceAudience.Private
+class NaiveEncoder extends BytesEncoder with Logging {
+ var code = 0
+ def nextCode: Byte = {
+ code += 1
+ (code - 1).asInstanceOf[Byte]
+ }
+ val BooleanEnc = nextCode
+ val ShortEnc = nextCode
+ val IntEnc = nextCode
+ val LongEnc = nextCode
+ val FloatEnc = nextCode
+ val DoubleEnc = nextCode
+ val StringEnc = nextCode
+ val BinaryEnc = nextCode
+ val TimestampEnc = nextCode
+ val UnknownEnc = nextCode
+
+ /**
+ * Evaluate the Java primitive value and return the BoundRanges. A single value may
+ * produce multiple output ranges because the numeric order is not always consistent
+ * with the byte-array order.
+ *
+ * For short, integer, and long, the numeric order is consistent with the byte-array
+ * order when two numbers have the same sign bit, but negative numbers compare larger
+ * than positive numbers in byte-array order.
+ *
+ * For double and float, positive numbers order consistently with their byte arrays,
+ * while negative numbers order in reverse. Please refer to IEEE-754
+ * and https://en.wikipedia.org/wiki/Single-precision_floating-point_format
+ */
+ def ranges(in: Any): Option[BoundRanges] = in match {
+ case a: Integer =>
+ val b = Bytes.toBytes(a)
+ if (a >= 0) {
+ logDebug(s"range is 0 to $a and ${Integer.MIN_VALUE} to -1")
+ Some(
+ BoundRanges(
+ Array(
+ BoundRange(Bytes.toBytes(0: Int), b),
+ BoundRange(Bytes.toBytes(Integer.MIN_VALUE), Bytes.toBytes(-1: Int))),
+ Array(BoundRange(b, Bytes.toBytes(Integer.MAX_VALUE))),
+ b))
+ } else {
+ Some(
+ BoundRanges(
+ Array(BoundRange(Bytes.toBytes(Integer.MIN_VALUE), b)),
+ Array(
+ BoundRange(b, Bytes.toBytes(-1: Integer)),
+ BoundRange(Bytes.toBytes(0: Int), Bytes.toBytes(Integer.MAX_VALUE))),
+ b))
+ }
+ case a: Long =>
+ val b = Bytes.toBytes(a)
+ if (a >= 0) {
+ Some(
+ BoundRanges(
+ Array(
+ BoundRange(Bytes.toBytes(0: Long), b),
+ BoundRange(Bytes.toBytes(Long.MinValue), Bytes.toBytes(-1: Long))),
+ Array(BoundRange(b, Bytes.toBytes(Long.MaxValue))),
+ b))
+ } else {
+ Some(
+ BoundRanges(
+ Array(BoundRange(Bytes.toBytes(Long.MinValue), b)),
+ Array(
+ BoundRange(b, Bytes.toBytes(-1: Long)),
+ BoundRange(Bytes.toBytes(0: Long), Bytes.toBytes(Long.MaxValue))),
+ b))
+ }
+ case a: Short =>
+ val b = Bytes.toBytes(a)
+ if (a >= 0) {
+ Some(
+ BoundRanges(
+ Array(
+ BoundRange(Bytes.toBytes(0: Short), b),
+ BoundRange(Bytes.toBytes(Short.MinValue), Bytes.toBytes(-1: Short))),
+ Array(BoundRange(b, Bytes.toBytes(Short.MaxValue))),
+ b))
+ } else {
+ Some(
+ BoundRanges(
+ Array(BoundRange(Bytes.toBytes(Short.MinValue), b)),
+ Array(
+ BoundRange(b, Bytes.toBytes(-1: Short)),
+ BoundRange(Bytes.toBytes(0: Short), Bytes.toBytes(Short.MaxValue))),
+ b))
+ }
+ case a: Double =>
+ val b = Bytes.toBytes(a)
+ if (a >= 0.0f) {
+ Some(
+ BoundRanges(
+ Array(
+ BoundRange(Bytes.toBytes(0.0d), b),
+ BoundRange(Bytes.toBytes(-0.0d), Bytes.toBytes(Double.MinValue))),
+ Array(BoundRange(b, Bytes.toBytes(Double.MaxValue))),
+ b))
+ } else {
+ Some(
+ BoundRanges(
+ Array(BoundRange(b, Bytes.toBytes(Double.MinValue))),
+ Array(
+ BoundRange(Bytes.toBytes(-0.0d), b),
+ BoundRange(Bytes.toBytes(0.0d), Bytes.toBytes(Double.MaxValue))),
+ b))
+ }
+ case a: Float =>
+ val b = Bytes.toBytes(a)
+ if (a >= 0.0f) {
+ Some(
+ BoundRanges(
+ Array(
+ BoundRange(Bytes.toBytes(0.0f), b),
+ BoundRange(Bytes.toBytes(-0.0f), Bytes.toBytes(Float.MinValue))),
+ Array(BoundRange(b, Bytes.toBytes(Float.MaxValue))),
+ b))
+ } else {
+ Some(
+ BoundRanges(
+ Array(BoundRange(b, Bytes.toBytes(Float.MinValue))),
+ Array(
+ BoundRange(Bytes.toBytes(-0.0f), b),
+ BoundRange(Bytes.toBytes(0.0f), Bytes.toBytes(Float.MaxValue))),
+ b))
+ }
+ case a: Array[Byte] =>
+ Some(BoundRanges(Array(BoundRange(bytesMin, a)), Array(BoundRange(a, bytesMax)), a))
+ case a: Byte =>
+ val b = Array(a)
+ Some(BoundRanges(Array(BoundRange(bytesMin, b)), Array(BoundRange(b, bytesMax)), b))
+ case a: String =>
+ val b = Bytes.toBytes(a)
+ Some(BoundRanges(Array(BoundRange(bytesMin, b)), Array(BoundRange(b, bytesMax)), b))
+ case a: UTF8String =>
+ val b = a.getBytes
+ Some(BoundRanges(Array(BoundRange(bytesMin, b)), Array(BoundRange(b, bytesMax)), b))
+ case _ => None
+ }
+
+ def compare(c: Int, ops: JavaBytesEncoder): Boolean = {
+ ops match {
+ case JavaBytesEncoder.Greater => c > 0
+ case JavaBytesEncoder.GreaterEqual => c >= 0
+ case JavaBytesEncoder.Less => c < 0
+ case JavaBytesEncoder.LessEqual => c <= 0
+ }
+ }
+
+ /**
+ * Encode the value into a byte array. Note that this is a naive implementation that
+ * prepends the data type byte to the head of the serialized byte array.
+ *
+ * @param dt: The data type of the input
+ * @param value: the value of the input
+ * @return the byte array with the first byte indicating the data type.
+ */
+ override def encode(dt: DataType, value: Any): Array[Byte] = {
+ dt match {
+ case BooleanType =>
+ val result = new Array[Byte](Bytes.SIZEOF_BOOLEAN + 1)
+ result(0) = BooleanEnc
+ value.asInstanceOf[Boolean] match {
+ case true => result(1) = -1: Byte
+ case false => result(1) = 0: Byte
+ }
+ result
+ case ShortType =>
+ val result = new Array[Byte](Bytes.SIZEOF_SHORT + 1)
+ result(0) = ShortEnc
+ Bytes.putShort(result, 1, value.asInstanceOf[Short])
+ result
+ case IntegerType =>
+ val result = new Array[Byte](Bytes.SIZEOF_INT + 1)
+ result(0) = IntEnc
+ Bytes.putInt(result, 1, value.asInstanceOf[Int])
+ result
+ case LongType | TimestampType =>
+ val result = new Array[Byte](Bytes.SIZEOF_LONG + 1)
+ result(0) = LongEnc
+ Bytes.putLong(result, 1, value.asInstanceOf[Long])
+ result
+ case FloatType =>
+ val result = new Array[Byte](Bytes.SIZEOF_FLOAT + 1)
+ result(0) = FloatEnc
+ Bytes.putFloat(result, 1, value.asInstanceOf[Float])
+ result
+ case DoubleType =>
+ val result = new Array[Byte](Bytes.SIZEOF_DOUBLE + 1)
+ result(0) = DoubleEnc
+ Bytes.putDouble(result, 1, value.asInstanceOf[Double])
+ result
+ case BinaryType =>
+ val v = value.asInstanceOf[Array[Byte]]
+ val result = new Array[Byte](v.length + 1)
+ result(0) = BinaryEnc
+ System.arraycopy(v, 0, result, 1, v.length)
+ result
+ case StringType =>
+ val bytes = Bytes.toBytes(value.asInstanceOf[String])
+ val result = new Array[Byte](bytes.length + 1)
+ result(0) = StringEnc
+ System.arraycopy(bytes, 0, result, 1, bytes.length)
+ result
+ case _ =>
+ val bytes = Bytes.toBytes(value.toString)
+ val result = new Array[Byte](bytes.length + 1)
+ result(0) = UnknownEnc
+ System.arraycopy(bytes, 0, result, 1, bytes.length)
+ result
+ }
+ }
+
+ override def filter(
+ input: Array[Byte],
+ offset1: Int,
+ length1: Int,
+ filterBytes: Array[Byte],
+ offset2: Int,
+ length2: Int,
+ ops: JavaBytesEncoder): Boolean = {
+ filterBytes(offset2) match {
+ case ShortEnc =>
+ val in = Bytes.toShort(input, offset1)
+ val value = Bytes.toShort(filterBytes, offset2 + 1)
+ compare(in.compareTo(value), ops)
+ case IntEnc =>
+ val in = Bytes.toInt(input, offset1)
+ val value = Bytes.toInt(filterBytes, offset2 + 1)
+ compare(in.compareTo(value), ops)
+ case LongEnc | TimestampEnc =>
+ val in = Bytes.toLong(input, offset1)
+ val value = Bytes.toLong(filterBytes, offset2 + 1)
+ compare(in.compareTo(value), ops)
+ case FloatEnc =>
+ val in = Bytes.toFloat(input, offset1)
+ val value = Bytes.toFloat(filterBytes, offset2 + 1)
+ compare(in.compareTo(value), ops)
+ case DoubleEnc =>
+ val in = Bytes.toDouble(input, offset1)
+ val value = Bytes.toDouble(filterBytes, offset2 + 1)
+ compare(in.compareTo(value), ops)
+ case _ =>
+ // for String, Byte, Binary, Boolean and other types
+ // we can use the order of byte array directly.
+ compare(
+ Bytes.compareTo(input, offset1, length1, filterBytes, offset2 + 1, length2 - 1),
+ ops)
+ }
+ }
+}
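
A sketch of the `ranges` partition-pruning helper, matching the "COLUMN < 2" example mentioned in the `BytesEncoder.ranges` documentation above.

```
import org.apache.hadoop.hbase.spark.datasources.NaiveEncoder
import org.apache.hadoop.hbase.util.Bytes

object NaiveEncoderRangesExample {
  def main(args: Array[String]): Unit = {
    val encoder = new NaiveEncoder()

    // For "COLUMN < 2" the numeric interval is not contiguous in unsigned byte-array
    // order, so the "less" side is split into [toBytes(0), toBytes(2)] and
    // [toBytes(Integer.MIN_VALUE), toBytes(-1)].
    val bounds = encoder.ranges(2).get
    bounds.less.foreach(r => println(Bytes.toInt(r.low) + " .. " + Bytes.toInt(r.upper)))

    // The "greater" side is what a "COLUMN > 2" predicate would use.
    bounds.greater.foreach(r => println(Bytes.toInt(r.low) + " .. " + Bytes.toInt(r.upper)))
  }
}
```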
diff --git a/spark4/hbase-spark4/src/main/scala/org/apache/hadoop/hbase/spark/datasources/SchemaConverters.scala b/spark4/hbase-spark4/src/main/scala/org/apache/hadoop/hbase/spark/datasources/SchemaConverters.scala
new file mode 100644
index 00000000..3eafe172
--- /dev/null
+++ b/spark4/hbase-spark4/src/main/scala/org/apache/hadoop/hbase/spark/datasources/SchemaConverters.scala
@@ -0,0 +1,450 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.hadoop.hbase.spark
+
+import java.io.ByteArrayInputStream
+import java.nio.ByteBuffer
+import java.sql.Timestamp
+import java.util
+import java.util.HashMap
+import org.apache.avro.{Schema, SchemaBuilder}
+import org.apache.avro.Schema.Type._
+import org.apache.avro.SchemaBuilder.BaseFieldTypeBuilder
+import org.apache.avro.SchemaBuilder.BaseTypeBuilder
+import org.apache.avro.SchemaBuilder.FieldAssembler
+import org.apache.avro.SchemaBuilder.FieldDefault
+import org.apache.avro.SchemaBuilder.RecordBuilder
+import org.apache.avro.generic.{GenericData, GenericDatumReader, GenericDatumWriter, GenericRecord}
+import org.apache.avro.generic.GenericData.{Fixed, Record}
+import org.apache.avro.io._
+import org.apache.commons.io.output.ByteArrayOutputStream
+import org.apache.hadoop.hbase.util.Bytes
+import org.apache.spark.sql.Row
+import org.apache.spark.sql.types._
+import org.apache.yetus.audience.InterfaceAudience
+import scala.jdk.CollectionConverters._
+import scala.collection.immutable.Map
+
+@InterfaceAudience.Private
+abstract class AvroException(msg: String) extends Exception(msg)
+
+@InterfaceAudience.Private
+case class SchemaConversionException(msg: String) extends AvroException(msg)
+
+/**
+ * At the top level, the converters provide three high-level interfaces:
+ * 1. toSqlType: takes an Avro schema and returns a SQL schema.
+ * 2. createConverterToSQL: returns a function that is used to convert Avro types to their
+ * corresponding SparkSQL representations.
+ * 3. convertTypeToAvro: constructs a converter function for a given SparkSQL data type;
+ * this is used when writing Avro records out to disk.
+ */
+@InterfaceAudience.Private
+object SchemaConverters {
+
+ case class SchemaType(dataType: DataType, nullable: Boolean)
+
+ /**
+ * This function takes an avro schema and returns a sql schema.
+ */
+ def toSqlType(avroSchema: Schema): SchemaType = {
+ avroSchema.getType match {
+ case INT => SchemaType(IntegerType, nullable = false)
+ case STRING => SchemaType(StringType, nullable = false)
+ case BOOLEAN => SchemaType(BooleanType, nullable = false)
+ case BYTES => SchemaType(BinaryType, nullable = false)
+ case DOUBLE => SchemaType(DoubleType, nullable = false)
+ case FLOAT => SchemaType(FloatType, nullable = false)
+ case LONG => SchemaType(LongType, nullable = false)
+ case FIXED => SchemaType(BinaryType, nullable = false)
+ case ENUM => SchemaType(StringType, nullable = false)
+
+ case RECORD =>
+ val fields = avroSchema.getFields.asScala.map {
+ f =>
+ val schemaType = toSqlType(f.schema())
+ StructField(f.name, schemaType.dataType, schemaType.nullable)
+ }
+
+ SchemaType(StructType(fields.toList.asJava), nullable = false)
+
+ case ARRAY =>
+ val schemaType = toSqlType(avroSchema.getElementType)
+ SchemaType(
+ ArrayType(schemaType.dataType, containsNull = schemaType.nullable),
+ nullable = false)
+
+ case MAP =>
+ val schemaType = toSqlType(avroSchema.getValueType)
+ SchemaType(
+ MapType(StringType, schemaType.dataType, valueContainsNull = schemaType.nullable),
+ nullable = false)
+
+ case UNION =>
+ if (avroSchema.getTypes.asScala.exists(_.getType == NULL)) {
+ // In case of a union with null, eliminate it and make a recursive call
+ val remainingUnionTypes = avroSchema.getTypes.asScala.filterNot(_.getType == NULL)
+ if (remainingUnionTypes.size == 1) {
+ toSqlType(remainingUnionTypes.head).copy(nullable = true)
+ } else {
+ toSqlType(Schema.createUnion(remainingUnionTypes.toList.asJava)).copy(nullable = true)
+ }
+ } else
+ avroSchema.getTypes.asScala.map(_.getType) match {
+ case Seq(t1, t2) if Set(t1, t2) == Set(INT, LONG) =>
+ SchemaType(LongType, nullable = false)
+ case Seq(t1, t2) if Set(t1, t2) == Set(FLOAT, DOUBLE) =>
+ SchemaType(DoubleType, nullable = false)
+ case other =>
+ throw new SchemaConversionException(
+ s"This mix of union types is not supported: $other")
+ }
+
+ case other => throw new SchemaConversionException(s"Unsupported type $other")
+ }
+ }
+
+ /**
+ * This function converts sparkSQL StructType into avro schema. This method uses two other
+ * converter methods in order to do the conversion.
+ */
+ private def convertStructToAvro[T](
+ structType: StructType,
+ schemaBuilder: RecordBuilder[T],
+ recordNamespace: String): T = {
+ val fieldsAssembler: FieldAssembler[T] = schemaBuilder.fields()
+ structType.fields.foreach {
+ field =>
+ val newField = fieldsAssembler.name(field.name).`type`()
+
+ if (field.nullable) {
+ convertFieldTypeToAvro(
+ field.dataType,
+ newField.nullable(),
+ field.name,
+ recordNamespace).noDefault
+ } else {
+ convertFieldTypeToAvro(field.dataType, newField, field.name, recordNamespace).noDefault
+ }
+ }
+ fieldsAssembler.endRecord()
+ }
+
+ /**
+ * Returns a function that is used to convert avro types to their
+ * corresponding sparkSQL representations.
+ */
+ def createConverterToSQL(schema: Schema): Any => Any = {
+ schema.getType match {
+ // Avro strings are in Utf8, so we have to call toString on them
+ case STRING | ENUM => (item: Any) => if (item == null) null else item.toString
+ case INT | BOOLEAN | DOUBLE | FLOAT | LONG => identity
+ // Byte arrays are reused by avro, so we have to make a copy of them.
+ case FIXED =>
+ (item: Any) =>
+ if (item == null) {
+ null
+ } else {
+ item.asInstanceOf[Fixed].bytes().clone()
+ }
+ case BYTES =>
+ (item: Any) =>
+ if (item == null) {
+ null
+ } else {
+ val bytes = item.asInstanceOf[ByteBuffer]
+ val javaBytes = new Array[Byte](bytes.remaining)
+ bytes.get(javaBytes)
+ javaBytes
+ }
+ case RECORD =>
+ val fieldConverters = schema.getFields.asScala.map(f => createConverterToSQL(f.schema))
+ (item: Any) =>
+ if (item == null) {
+ null
+ } else {
+ val record = item.asInstanceOf[GenericRecord]
+ val converted = new Array[Any](fieldConverters.size)
+ var idx = 0
+ while (idx < fieldConverters.size) {
+ converted(idx) = fieldConverters.apply(idx)(record.get(idx))
+ idx += 1
+ }
+ Row.fromSeq(converted.toSeq)
+ }
+ case ARRAY =>
+ val elementConverter = createConverterToSQL(schema.getElementType)
+ (item: Any) =>
+ if (item == null) {
+ null
+ } else {
+ try {
+ item.asInstanceOf[GenericData.Array[Any]].asScala.map(elementConverter)
+ } catch {
+ case e: Throwable =>
+ item.asInstanceOf[util.ArrayList[Any]].asScala.map(elementConverter)
+ }
+ }
+ case MAP =>
+ val valueConverter = createConverterToSQL(schema.getValueType)
+ (item: Any) =>
+ if (item == null) {
+ null
+ } else {
+ item
+ .asInstanceOf[HashMap[Any, Any]].asScala
+ .map(x => (x._1.toString, valueConverter(x._2)))
+ .toMap
+ }
+ case UNION =>
+ if (schema.getTypes.asScala.exists(_.getType == NULL)) {
+ val remainingUnionTypes = schema.getTypes.asScala.filterNot(_.getType == NULL)
+ if (remainingUnionTypes.size == 1) {
+ createConverterToSQL(remainingUnionTypes.head)
+ } else {
+ createConverterToSQL(Schema.createUnion(remainingUnionTypes.toList.asJava))
+ }
+ } else
+ schema.getTypes.asScala.map(_.getType) match {
+ case Seq(t1, t2) if Set(t1, t2) == Set(INT, LONG) =>
+ (item: Any) => {
+ item match {
+ case l: Long => l
+ case i: Int => i.toLong
+ case null => null
+ }
+ }
+ case Seq(t1, t2) if Set(t1, t2) == Set(FLOAT, DOUBLE) =>
+ (item: Any) => {
+ item match {
+ case d: Double => d
+ case f: Float => f.toDouble
+ case null => null
+ }
+ }
+ case other =>
+ throw new SchemaConversionException(
+ s"This mix of union types is not supported (see README): $other")
+ }
+ case other => throw new SchemaConversionException(s"invalid avro type: $other")
+ }
+ }
+
+ /**
+ * This function is used to convert a sparkSQL type to an avro type. Note that this function is not
+ * used to construct the fields of an avro record (convertFieldTypeToAvro is used for that).
+ */
+ private def convertTypeToAvro[T](
+ dataType: DataType,
+ schemaBuilder: BaseTypeBuilder[T],
+ structName: String,
+ recordNamespace: String): T = {
+ dataType match {
+ case ByteType => schemaBuilder.intType()
+ case ShortType => schemaBuilder.intType()
+ case IntegerType => schemaBuilder.intType()
+ case LongType => schemaBuilder.longType()
+ case FloatType => schemaBuilder.floatType()
+ case DoubleType => schemaBuilder.doubleType()
+ case _: DecimalType => schemaBuilder.stringType()
+ case StringType => schemaBuilder.stringType()
+ case BinaryType => schemaBuilder.bytesType()
+ case BooleanType => schemaBuilder.booleanType()
+ case TimestampType => schemaBuilder.longType()
+
+ case ArrayType(elementType, _) =>
+ val builder = getSchemaBuilder(dataType.asInstanceOf[ArrayType].containsNull)
+ val elementSchema = convertTypeToAvro(elementType, builder, structName, recordNamespace)
+ schemaBuilder.array().items(elementSchema)
+
+ case MapType(StringType, valueType, _) =>
+ val builder = getSchemaBuilder(dataType.asInstanceOf[MapType].valueContainsNull)
+ val valueSchema = convertTypeToAvro(valueType, builder, structName, recordNamespace)
+ schemaBuilder.map().values(valueSchema)
+
+ case structType: StructType =>
+ convertStructToAvro(
+ structType,
+ schemaBuilder.record(structName).namespace(recordNamespace),
+ recordNamespace)
+
+ case other => throw new IllegalArgumentException(s"Unexpected type $dataType.")
+ }
+ }
+
+ /**
+ * This function is used to construct the fields of an avro record, where the schema of a field is
+ * specified by the avro representation of its dataType. Since builders for record fields are different
+ * from those for everything else, we have to use a separate method.
+ */
+ private def convertFieldTypeToAvro[T](
+ dataType: DataType,
+ newFieldBuilder: BaseFieldTypeBuilder[T],
+ structName: String,
+ recordNamespace: String): FieldDefault[T, _] = {
+ dataType match {
+ case ByteType => newFieldBuilder.intType()
+ case ShortType => newFieldBuilder.intType()
+ case IntegerType => newFieldBuilder.intType()
+ case LongType => newFieldBuilder.longType()
+ case FloatType => newFieldBuilder.floatType()
+ case DoubleType => newFieldBuilder.doubleType()
+ case _: DecimalType => newFieldBuilder.stringType()
+ case StringType => newFieldBuilder.stringType()
+ case BinaryType => newFieldBuilder.bytesType()
+ case BooleanType => newFieldBuilder.booleanType()
+ case TimestampType => newFieldBuilder.longType()
+
+ case ArrayType(elementType, _) =>
+ val builder = getSchemaBuilder(dataType.asInstanceOf[ArrayType].containsNull)
+ val elementSchema = convertTypeToAvro(elementType, builder, structName, recordNamespace)
+ newFieldBuilder.array().items(elementSchema)
+
+ case MapType(StringType, valueType, _) =>
+ val builder = getSchemaBuilder(dataType.asInstanceOf[MapType].valueContainsNull)
+ val valueSchema = convertTypeToAvro(valueType, builder, structName, recordNamespace)
+ newFieldBuilder.map().values(valueSchema)
+
+ case structType: StructType =>
+ convertStructToAvro(
+ structType,
+ newFieldBuilder.record(structName).namespace(recordNamespace),
+ recordNamespace)
+
+ case other => throw new IllegalArgumentException(s"Unexpected type $dataType.")
+ }
+ }
+
+ private def getSchemaBuilder(isNullable: Boolean): BaseTypeBuilder[Schema] = {
+ if (isNullable) {
+ SchemaBuilder.builder().nullable()
+ } else {
+ SchemaBuilder.builder()
+ }
+ }
+
+ /**
+ * This function constructs a converter function for a given sparkSQL datatype. It is used when
+ * writing Avro records out to disk.
+ */
+ def createConverterToAvro(
+ dataType: DataType,
+ structName: String,
+ recordNamespace: String): (Any) => Any = {
+ dataType match {
+ case BinaryType =>
+ (item: Any) =>
+ item match {
+ case null => null
+ case bytes: Array[Byte] => ByteBuffer.wrap(bytes)
+ }
+ case ByteType | ShortType | IntegerType | LongType | FloatType | DoubleType | StringType |
+ BooleanType =>
+ identity
+ case _: DecimalType => (item: Any) => if (item == null) null else item.toString
+ case TimestampType =>
+ (item: Any) => if (item == null) null else item.asInstanceOf[Timestamp].getTime
+ case ArrayType(elementType, _) =>
+ val elementConverter = createConverterToAvro(elementType, structName, recordNamespace)
+ (item: Any) => {
+ if (item == null) {
+ null
+ } else {
+ val sourceArray = item.asInstanceOf[Seq[Any]]
+ val sourceArraySize = sourceArray.size
+ val targetArray = new util.ArrayList[Any](sourceArraySize)
+ var idx = 0
+ while (idx < sourceArraySize) {
+ targetArray.add(elementConverter(sourceArray(idx)))
+ idx += 1
+ }
+ targetArray
+ }
+ }
+ case MapType(StringType, valueType, _) =>
+ val valueConverter = createConverterToAvro(valueType, structName, recordNamespace)
+ (item: Any) => {
+ if (item == null) {
+ null
+ } else {
+ val javaMap = new HashMap[String, Any]()
+ item.asInstanceOf[Map[String, Any]].foreach {
+ case (key, value) =>
+ javaMap.put(key, valueConverter(value))
+ }
+ javaMap
+ }
+ }
+ case structType: StructType =>
+ val builder = SchemaBuilder.record(structName).namespace(recordNamespace)
+ val schema: Schema =
+ SchemaConverters.convertStructToAvro(structType, builder, recordNamespace)
+ val fieldConverters = structType.fields.map(
+ field => createConverterToAvro(field.dataType, field.name, recordNamespace))
+ (item: Any) => {
+ if (item == null) {
+ null
+ } else {
+ val record = new Record(schema)
+ val convertersIterator = fieldConverters.iterator
+ val fieldNamesIterator = dataType.asInstanceOf[StructType].fieldNames.iterator
+ val rowIterator = item.asInstanceOf[Row].toSeq.iterator
+
+ while (convertersIterator.hasNext) {
+ val converter = convertersIterator.next()
+ record.put(fieldNamesIterator.next(), converter(rowIterator.next()))
+ }
+ record
+ }
+ }
+ }
+ }
+}
+
+@InterfaceAudience.Private
+object AvroSerdes {
+ // For now we only handle the case where the top-level type is a record or a primitive type
+ def serialize(input: Any, schema: Schema): Array[Byte] = {
+ schema.getType match {
+ case BOOLEAN => Bytes.toBytes(input.asInstanceOf[Boolean])
+ case BYTES | FIXED => input.asInstanceOf[Array[Byte]]
+ case DOUBLE => Bytes.toBytes(input.asInstanceOf[Double])
+ case FLOAT => Bytes.toBytes(input.asInstanceOf[Float])
+ case INT => Bytes.toBytes(input.asInstanceOf[Int])
+ case LONG => Bytes.toBytes(input.asInstanceOf[Long])
+ case STRING => Bytes.toBytes(input.asInstanceOf[String])
+ case RECORD =>
+ val gr = input.asInstanceOf[GenericRecord]
+ val writer2 = new GenericDatumWriter[GenericRecord](schema)
+ val bao2 = new ByteArrayOutputStream()
+ val encoder2: BinaryEncoder = EncoderFactory.get().directBinaryEncoder(bao2, null)
+ writer2.write(gr, encoder2)
+ bao2.toByteArray()
+ case _ => throw new Exception(s"unsupported data type ${schema.getType}") // TODO
+ }
+ }
+
+ def deserialize(input: Array[Byte], schema: Schema): GenericRecord = {
+ val reader2: DatumReader[GenericRecord] = new GenericDatumReader[GenericRecord](schema)
+ val bai2 = new ByteArrayInputStream(input)
+ val decoder2: BinaryDecoder = DecoderFactory.get().directBinaryDecoder(bai2, null)
+ val gr2: GenericRecord = reader2.read(null, decoder2)
+ gr2
+ }
+}
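
For orientation, here is a minimal sketch of a round trip through `AvroSerdes` (the schema and record construction mirror the `AvroHBaseRecord` example further down in this patch; the concrete field values are only illustrative):

```
import org.apache.avro.Schema
import org.apache.avro.generic.GenericData
import org.apache.hadoop.hbase.spark.AvroSerdes

// Parse a one-field record schema and build a matching GenericRecord.
val schema = new Schema.Parser().parse(
  """{"namespace": "example.avro", "type": "record", "name": "User",
    | "fields": [{"name": "name", "type": "string"}]}""".stripMargin)
val user = new GenericData.Record(schema)
user.put("name", "name001")

// RECORD schemas go through the binary encoder; primitives use Bytes.toBytes.
val bytes: Array[Byte] = AvroSerdes.serialize(user, schema)
val roundTripped = AvroSerdes.deserialize(bytes, schema)
assert(roundTripped.get("name").toString == "name001")
```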
diff --git a/spark4/hbase-spark4/src/main/scala/org/apache/hadoop/hbase/spark/datasources/SerDes.scala b/spark4/hbase-spark4/src/main/scala/org/apache/hadoop/hbase/spark/datasources/SerDes.scala
new file mode 100644
index 00000000..c3d5c634
--- /dev/null
+++ b/spark4/hbase-spark4/src/main/scala/org/apache/hadoop/hbase/spark/datasources/SerDes.scala
@@ -0,0 +1,37 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.hadoop.hbase.spark.datasources
+
+import org.apache.hadoop.hbase.util.Bytes
+import org.apache.yetus.audience.InterfaceAudience
+
+// TODO: This is not actually used anywhere in the code.
+@InterfaceAudience.Public
+trait SerDes {
+ def serialize(value: Any): Array[Byte]
+ def deserialize(bytes: Array[Byte], start: Int, end: Int): Any
+}
+
+// TODO: This is not actually used anywhere in the code.
+@InterfaceAudience.Private
+class DoubleSerDes extends SerDes {
+ override def serialize(value: Any): Array[Byte] = Bytes.toBytes(value.asInstanceOf[Double])
+ override def deserialize(bytes: Array[Byte], start: Int, end: Int): Any = {
+ Bytes.toDouble(bytes, start)
+ }
+}
diff --git a/spark4/hbase-spark4/src/main/scala/org/apache/hadoop/hbase/spark/datasources/SerializableConfiguration.scala b/spark4/hbase-spark4/src/main/scala/org/apache/hadoop/hbase/spark/datasources/SerializableConfiguration.scala
new file mode 100644
index 00000000..f5b4c5a2
--- /dev/null
+++ b/spark4/hbase-spark4/src/main/scala/org/apache/hadoop/hbase/spark/datasources/SerializableConfiguration.scala
@@ -0,0 +1,45 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.hadoop.hbase.spark.datasources
+
+import java.io.{IOException, ObjectInputStream, ObjectOutputStream}
+import org.apache.hadoop.conf.Configuration
+import org.apache.yetus.audience.InterfaceAudience
+import scala.util.control.NonFatal
+
+@InterfaceAudience.Private
+class SerializableConfiguration(@transient var value: Configuration) extends Serializable {
+ private def writeObject(out: ObjectOutputStream): Unit = tryOrIOException {
+ out.defaultWriteObject()
+ value.write(out)
+ }
+
+ private def readObject(in: ObjectInputStream): Unit = tryOrIOException {
+ value = new Configuration(false)
+ value.readFields(in)
+ }
+
+ def tryOrIOException(block: => Unit) {
+ try {
+ block
+ } catch {
+ case e: IOException => throw e
+ case NonFatal(t) => throw new IOException(t)
+ }
+ }
+}
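
A minimal usage sketch (assuming a `SparkContext` named `sc` and an existing RDD `rdd`, both hypothetical here): because Hadoop's `Configuration` is not `Serializable`, it is wrapped in `SerializableConfiguration` before being shipped to executors, typically via a broadcast variable:

```
import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.hbase.spark.datasources.SerializableConfiguration

val conf = new Configuration()
// Broadcast the wrapper once; each task unwraps a usable Configuration from it.
val confBroadcast = sc.broadcast(new SerializableConfiguration(conf))

rdd.map { row =>
  val taskConf: Configuration = confBroadcast.value.value
  taskConf.get("fs.defaultFS")
}
```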
diff --git a/spark4/hbase-spark4/src/main/scala/org/apache/hadoop/hbase/spark/datasources/Utils.scala b/spark4/hbase-spark4/src/main/scala/org/apache/hadoop/hbase/spark/datasources/Utils.scala
new file mode 100644
index 00000000..de7eec9a
--- /dev/null
+++ b/spark4/hbase-spark4/src/main/scala/org/apache/hadoop/hbase/spark/datasources/Utils.scala
@@ -0,0 +1,118 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.hadoop.hbase.spark.datasources
+
+import java.sql.{Date, Timestamp}
+import org.apache.hadoop.hbase.spark.AvroSerdes
+import org.apache.hadoop.hbase.util.Bytes
+import org.apache.spark.sql.types._
+import org.apache.spark.unsafe.types.UTF8String
+import org.apache.yetus.audience.InterfaceAudience
+
+@InterfaceAudience.Private
+object Utils {
+
+ /**
+ * Parses the hbase field into its corresponding
+ * scala type, which can then be put into a Spark GenericRow
+ * that is then automatically converted by Spark.
+ */
+ def hbaseFieldToScalaType(f: Field, src: Array[Byte], offset: Int, length: Int): Any = {
+ if (f.exeSchema.isDefined) {
+ // If an avro schema is defined, use it to deserialize the record and then convert it to the catalyst data type
+ val m = AvroSerdes.deserialize(src, f.exeSchema.get)
+ val n = f.avroToCatalyst.map(_(m))
+ n.get
+ } else {
+ // Fall back to atomic type
+ f.dt match {
+ case BooleanType => src(offset) != 0
+ case ByteType => src(offset)
+ case ShortType => Bytes.toShort(src, offset)
+ case IntegerType => Bytes.toInt(src, offset)
+ case LongType => Bytes.toLong(src, offset)
+ case FloatType => Bytes.toFloat(src, offset)
+ case DoubleType => Bytes.toDouble(src, offset)
+ case DateType => new Date(Bytes.toLong(src, offset))
+ case TimestampType => new Timestamp(Bytes.toLong(src, offset))
+ case StringType => Bytes.toString(src, offset, length)
+ case BinaryType =>
+ val newArray = new Array[Byte](length)
+ System.arraycopy(src, offset, newArray, 0, length)
+ newArray
+ // TODO: SparkSqlSerializer.deserialize[Any](src)
+ case _ => throw new Exception(s"unsupported data type ${f.dt}")
+ }
+ }
+ }
+
+ // Convert the input to a byte array according to the field's data type
+ def toBytes(input: Any, field: Field): Array[Byte] = {
+ if (field.schema.isDefined) {
+ // Here we assume the top level type is structType
+ val record = field.catalystToAvro(input)
+ AvroSerdes.serialize(record, field.schema.get)
+ } else {
+ field.dt match {
+ case BooleanType => Bytes.toBytes(input.asInstanceOf[Boolean])
+ case ByteType => Array(input.asInstanceOf[Number].byteValue)
+ case ShortType => Bytes.toBytes(input.asInstanceOf[Number].shortValue)
+ case IntegerType => Bytes.toBytes(input.asInstanceOf[Number].intValue)
+ case LongType => Bytes.toBytes(input.asInstanceOf[Number].longValue)
+ case FloatType => Bytes.toBytes(input.asInstanceOf[Number].floatValue)
+ case DoubleType => Bytes.toBytes(input.asInstanceOf[Number].doubleValue)
+ case DateType | TimestampType => Bytes.toBytes(input.asInstanceOf[java.util.Date].getTime)
+ case StringType => Bytes.toBytes(input.toString)
+ case BinaryType => input.asInstanceOf[Array[Byte]]
+ case _ => throw new Exception(s"unsupported data type ${field.dt}")
+ }
+ }
+ }
+
+ // Increment the byte array's value by 1
+ def incrementByteArray(array: Array[Byte]): Array[Byte] = {
+ if (array.length == 0) {
+ return null
+ }
+ var index = -1 // index of the byte we have to increment
+ var a = array.length - 1
+
+ while (a >= 0) {
+ if (array(a) != (-1).toByte) {
+ index = a
+ a = -1 // force the loop to end because we found a byte that is not -1 (0xFF)
+ }
+ a = a - 1
+ }
+
+ if (index < 0) {
+ return null
+ }
+ val returnArray = new Array[Byte](array.length)
+
+ for (a <- 0 until index) {
+ returnArray(a) = array(a)
+ }
+ returnArray(index) = (array(index) + 1).toByte
+ for (a <- index + 1 until array.length) {
+ returnArray(a) = 0.toByte
+ }
+
+ returnArray
+ }
+}
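
To make the increment logic above concrete, a small sketch of `incrementByteArray`'s behavior (values follow directly from the implementation above):

```
import org.apache.hadoop.hbase.spark.datasources.Utils

// Trailing 0xFF bytes roll over to 0x00 and the next byte up is incremented.
Utils.incrementByteArray(Array[Byte](0x01, 0xFF.toByte)) // => Array(0x02, 0x00)

// An all-0xFF array cannot be incremented within the same length, so null is returned.
Utils.incrementByteArray(Array[Byte](0xFF.toByte, 0xFF.toByte)) // => null

// An empty array also returns null.
Utils.incrementByteArray(Array.empty[Byte]) // => null
```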
diff --git a/spark4/hbase-spark4/src/main/scala/org/apache/hadoop/hbase/spark/datasources/package.scala b/spark4/hbase-spark4/src/main/scala/org/apache/hadoop/hbase/spark/datasources/package.scala
new file mode 100644
index 00000000..66e3897e
--- /dev/null
+++ b/spark4/hbase-spark4/src/main/scala/org/apache/hadoop/hbase/spark/datasources/package.scala
@@ -0,0 +1,38 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.hadoop.hbase.spark
+
+import org.apache.hadoop.hbase.util.Bytes
+import scala.math.Ordering
+
+// TODO: add @InterfaceAudience.Private if https://issues.scala-lang.org/browse/SI-3600 is resolved
+package object hbase {
+ type HBaseType = Array[Byte]
+ def bytesMin = new Array[Byte](0)
+ def bytesMax = null
+ val ByteMax = -1.asInstanceOf[Byte]
+ val ByteMin = 0.asInstanceOf[Byte]
+ val ord: Ordering[HBaseType] = new Ordering[HBaseType] {
+ def compare(x: Array[Byte], y: Array[Byte]): Int = {
+ return Bytes.compareTo(x, y)
+ }
+ }
+ // Do not use BinaryType.ordering
+ implicit val order: Ordering[HBaseType] = ord
+
+}
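
A short sketch of the ordering this package object provides (implicitly picked up by operations such as `sorted`; `Bytes.compareTo` compares byte arrays lexicographically, treating bytes as unsigned):

```
import org.apache.hadoop.hbase.spark.hbase._
import org.apache.hadoop.hbase.util.Bytes

val rows: Seq[HBaseType] = Seq(Bytes.toBytes("row2"), Bytes.toBytes("row1"))
// `sorted` resolves the implicit Ordering[HBaseType] defined in the package object.
val sorted = rows.sorted
sorted.map(b => Bytes.toString(b)) // => Seq("row1", "row2")
```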
diff --git a/spark4/hbase-spark4/src/main/scala/org/apache/hadoop/hbase/spark/example/datasources/AvroSource.scala b/spark4/hbase-spark4/src/main/scala/org/apache/hadoop/hbase/spark/example/datasources/AvroSource.scala
new file mode 100644
index 00000000..b0d50290
--- /dev/null
+++ b/spark4/hbase-spark4/src/main/scala/org/apache/hadoop/hbase/spark/example/datasources/AvroSource.scala
@@ -0,0 +1,168 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.hadoop.hbase.spark.example.datasources
+
+import org.apache.avro.Schema
+import org.apache.avro.generic.GenericData
+import org.apache.hadoop.hbase.spark.AvroSerdes
+import org.apache.hadoop.hbase.spark.datasources.HBaseTableCatalog
+import org.apache.spark.SparkConf
+import org.apache.spark.SparkContext
+import org.apache.spark.sql.DataFrame
+import org.apache.spark.sql.SQLContext
+import org.apache.yetus.audience.InterfaceAudience
+
+/**
+ * @param col0 Column #0, Type is String
+ * @param col1 Column #1, Type is Array[Byte]
+ */
+@InterfaceAudience.Private
+case class AvroHBaseRecord(col0: String, col1: Array[Byte])
+@InterfaceAudience.Private
+object AvroHBaseRecord {
+ val schemaString =
+ s"""{"namespace": "example.avro",
+ | "type": "record", "name": "User",
+ | "fields": [
+ | {"name": "name", "type": "string"},
+ | {"name": "favorite_number", "type": ["int", "null"]},
+ | {"name": "favorite_color", "type": ["string", "null"]},
+ | {"name": "favorite_array", "type": {"type": "array", "items": "string"}},
+ | {"name": "favorite_map", "type": {"type": "map", "values": "int"}}
+ | ] }""".stripMargin
+
+ val avroSchema: Schema = {
+ val p = new Schema.Parser
+ p.parse(schemaString)
+ }
+
+ def apply(i: Int): AvroHBaseRecord = {
+
+ val user = new GenericData.Record(avroSchema);
+ user.put("name", s"name${"%03d".format(i)}")
+ user.put("favorite_number", i)
+ user.put("favorite_color", s"color${"%03d".format(i)}")
+ val favoriteArray =
+ new GenericData.Array[String](2, avroSchema.getField("favorite_array").schema())
+ favoriteArray.add(s"number${i}")
+ favoriteArray.add(s"number${i + 1}")
+ user.put("favorite_array", favoriteArray)
+ import scala.collection.JavaConverters._
+ val favoriteMap = Map[String, Int](("key1" -> i), ("key2" -> (i + 1))).asJava
+ user.put("favorite_map", favoriteMap)
+ val avroByte = AvroSerdes.serialize(user, avroSchema)
+ AvroHBaseRecord(s"name${"%03d".format(i)}", avroByte)
+ }
+}
+
+@InterfaceAudience.Private
+object AvroSource {
+ def catalog = s"""{
+ |"table":{"namespace":"default", "name":"ExampleAvrotable"},
+ |"rowkey":"key",
+ |"columns":{
+ |"col0":{"cf":"rowkey", "col":"key", "type":"string"},
+ |"col1":{"cf":"cf1", "col":"col1", "type":"binary"}
+ |}
+ |}""".stripMargin
+
+ def avroCatalog = s"""{
+ |"table":{"namespace":"default", "name":"ExampleAvrotable"},
+ |"rowkey":"key",
+ |"columns":{
+ |"col0":{"cf":"rowkey", "col":"key", "type":"string"},
+ |"col1":{"cf":"cf1", "col":"col1", "avro":"avroSchema"}
+ |}
+ |}""".stripMargin
+
+ def avroCatalogInsert = s"""{
+ |"table":{"namespace":"default", "name":"ExampleAvrotableInsert"},
+ |"rowkey":"key",
+ |"columns":{
+ |"col0":{"cf":"rowkey", "col":"key", "type":"string"},
+ |"col1":{"cf":"cf1", "col":"col1", "avro":"avroSchema"}
+ |}
+ |}""".stripMargin
+
+ def main(args: Array[String]) {
+ val sparkConf = new SparkConf().setAppName("AvroSourceExample")
+ val sc = new SparkContext(sparkConf)
+ val sqlContext = new SQLContext(sc)
+
+ import sqlContext.implicits._
+
+ def withCatalog(cat: String): DataFrame = {
+ sqlContext.read
+ .options(
+ Map(
+ "avroSchema" -> AvroHBaseRecord.schemaString,
+ HBaseTableCatalog.tableCatalog -> avroCatalog))
+ .format("org.apache.hadoop.hbase.spark")
+ .load()
+ }
+
+ val data = (0 to 255).map { i => AvroHBaseRecord(i) }
+
+ sc.parallelize(data)
+ .toDF
+ .write
+ .options(Map(HBaseTableCatalog.tableCatalog -> catalog, HBaseTableCatalog.newTable -> "5"))
+ .format("org.apache.hadoop.hbase.spark")
+ .save()
+
+ val df = withCatalog(catalog)
+ df.show()
+ df.printSchema()
+ df.registerTempTable("ExampleAvrotable")
+ val c = sqlContext.sql("select count(1) from ExampleAvrotable")
+ c.show()
+
+ val filtered = df.select($"col0", $"col1.favorite_array").where($"col0" === "name001")
+ filtered.show()
+ val collected = filtered.collect()
+ if (collected(0).getSeq[String](1)(0) != "number1") {
+ throw new UserCustomizedSampleException("value invalid")
+ }
+ if (collected(0).getSeq[String](1)(1) != "number2") {
+ throw new UserCustomizedSampleException("value invalid")
+ }
+
+ df.write
+ .options(
+ Map(
+ "avroSchema" -> AvroHBaseRecord.schemaString,
+ HBaseTableCatalog.tableCatalog -> avroCatalogInsert,
+ HBaseTableCatalog.newTable -> "5"))
+ .format("org.apache.hadoop.hbase.spark")
+ .save()
+ val newDF = withCatalog(avroCatalogInsert)
+ newDF.show()
+ newDF.printSchema()
+ if (newDF.count() != 256) {
+ throw new UserCustomizedSampleException("value invalid")
+ }
+
+ df.filter($"col1.name" === "name005" || $"col1.name" <= "name005")
+ .select("col0", "col1.favorite_color", "col1.favorite_number")
+ .show()
+
+ df.filter($"col1.name" <= "name005" || $"col1.name".contains("name007"))
+ .select("col0", "col1.favorite_color", "col1.favorite_number")
+ .show()
+ }
+}
diff --git a/spark4/hbase-spark4/src/main/scala/org/apache/hadoop/hbase/spark/example/datasources/DataType.scala b/spark4/hbase-spark4/src/main/scala/org/apache/hadoop/hbase/spark/example/datasources/DataType.scala
new file mode 100644
index 00000000..d314dc8c
--- /dev/null
+++ b/spark4/hbase-spark4/src/main/scala/org/apache/hadoop/hbase/spark/example/datasources/DataType.scala
@@ -0,0 +1,172 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.hadoop.hbase.spark.example.datasources
+
+import org.apache.hadoop.hbase.spark.datasources.HBaseTableCatalog
+import org.apache.spark.SparkConf
+import org.apache.spark.SparkContext
+import org.apache.spark.sql.DataFrame
+import org.apache.spark.sql.SQLContext
+import org.apache.yetus.audience.InterfaceAudience
+
+@InterfaceAudience.Private
+class UserCustomizedSampleException(message: String = null, cause: Throwable = null)
+ extends RuntimeException(UserCustomizedSampleException.message(message, cause), cause)
+
+@InterfaceAudience.Private
+object UserCustomizedSampleException {
+ def message(message: String, cause: Throwable) =
+ if (message != null) message
+ else if (cause != null) cause.toString()
+ else null
+}
+
+@InterfaceAudience.Private
+case class IntKeyRecord(
+ col0: Integer,
+ col1: Boolean,
+ col2: Double,
+ col3: Float,
+ col4: Int,
+ col5: Long,
+ col6: Short,
+ col7: String,
+ col8: Byte)
+
+object IntKeyRecord {
+ def apply(i: Int): IntKeyRecord = {
+ IntKeyRecord(
+ if (i % 2 == 0) i else -i,
+ i % 2 == 0,
+ i.toDouble,
+ i.toFloat,
+ i,
+ i.toLong,
+ i.toShort,
+ s"String$i extra",
+ i.toByte)
+ }
+}
+
+@InterfaceAudience.Private
+object DataType {
+ val cat = s"""{
+ |"table":{"namespace":"default", "name":"DataTypeExampleTable"},
+ |"rowkey":"key",
+ |"columns":{
+ |"col0":{"cf":"rowkey", "col":"key", "type":"int"},
+ |"col1":{"cf":"cf1", "col":"col1", "type":"boolean"},
+ |"col2":{"cf":"cf2", "col":"col2", "type":"double"},
+ |"col3":{"cf":"cf3", "col":"col3", "type":"float"},
+ |"col4":{"cf":"cf4", "col":"col4", "type":"int"},
+ |"col5":{"cf":"cf5", "col":"col5", "type":"bigint"},
+ |"col6":{"cf":"cf6", "col":"col6", "type":"smallint"},
+ |"col7":{"cf":"cf7", "col":"col7", "type":"string"},
+ |"col8":{"cf":"cf8", "col":"col8", "type":"tinyint"}
+ |}
+ |}""".stripMargin
+
+ def main(args: Array[String]) {
+ val sparkConf = new SparkConf().setAppName("DataTypeExample")
+ val sc = new SparkContext(sparkConf)
+ val sqlContext = new SQLContext(sc)
+
+ import sqlContext.implicits._
+
+ def withCatalog(cat: String): DataFrame = {
+ sqlContext.read
+ .options(Map(HBaseTableCatalog.tableCatalog -> cat))
+ .format("org.apache.hadoop.hbase.spark")
+ .load()
+ }
+
+ // test populate table
+ val data = (0 until 32).map { i => IntKeyRecord(i) }
+ sc.parallelize(data)
+ .toDF
+ .write
+ .options(Map(HBaseTableCatalog.tableCatalog -> cat, HBaseTableCatalog.newTable -> "5"))
+ .format("org.apache.hadoop.hbase.spark")
+ .save()
+
+ // test less than 0
+ val df = withCatalog(cat)
+ val s = df.filter($"col0" < 0)
+ s.show()
+ if (s.count() != 16) {
+ throw new UserCustomizedSampleException("value invalid")
+ }
+
+ // test less than or equal to -10. The number of results is 11
+ val num1 = df.filter($"col0" <= -10)
+ num1.show()
+ val c1 = num1.count()
+ println(s"test result count should be 11: $c1")
+
+ // test less than or equal to -9. The number of results is 12
+ val num2 = df.filter($"col0" <= -9)
+ num2.show()
+ val c2 = num2.count()
+ println(s"test result count should be 12: $c2")
+
+ // test greater than or equal to -9. The number of results is 21
+ val num3 = df.filter($"col0" >= -9)
+ num3.show()
+ val c3 = num3.count()
+ println(s"test result count should be 21: $c3")
+
+ // test greater than or equal to 0. The number of results is 16
+ val num4 = df.filter($"col0" >= 0)
+ num4.show()
+ val c4 = num4.count()
+ println(s"test result count should be 16: $c4")
+
+ // test greater than 10. The number of results is 10
+ val num5 = df.filter($"col0" > 10)
+ num5.show()
+ val c5 = num5.count()
+ println(s"test result count should be 10: $c5")
+
+ // test "and". The number of results is 11
+ val num6 = df.filter($"col0" > -10 && $"col0" <= 10)
+ num6.show()
+ val c6 = num6.count()
+ println(s"test result count should be 11: $c6")
+
+ // test "or". The number of results is 21
+ val num7 = df.filter($"col0" <= -10 || $"col0" > 10)
+ num7.show()
+ val c7 = num7.count()
+ println(s"test result count should be 21: $c7")
+
+ // test "all". The number of results is 32
+ val num8 = df.filter($"col0" >= -100)
+ num8.show()
+ val c8 = num8.count()
+ println(s"test result count should be 32: $c8")
+
+ // test "full query"
+ val df1 = withCatalog(cat)
+ df1.show()
+ val c_df = df1.count()
+ println(s"df count should be 32: $c_df")
+ if (c_df != 32) {
+ throw new UserCustomizedSampleException("value invalid")
+ }
+ }
+}
diff --git a/spark4/hbase-spark4/src/main/scala/org/apache/hadoop/hbase/spark/example/datasources/HBaseSource.scala b/spark4/hbase-spark4/src/main/scala/org/apache/hadoop/hbase/spark/example/datasources/HBaseSource.scala
new file mode 100644
index 00000000..d7101e2e
--- /dev/null
+++ b/spark4/hbase-spark4/src/main/scala/org/apache/hadoop/hbase/spark/example/datasources/HBaseSource.scala
@@ -0,0 +1,112 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.hadoop.hbase.spark.example.datasources
+
+import org.apache.hadoop.hbase.spark.datasources.HBaseTableCatalog
+import org.apache.spark.SparkConf
+import org.apache.spark.SparkContext
+import org.apache.spark.sql.DataFrame
+import org.apache.spark.sql.SQLContext
+import org.apache.yetus.audience.InterfaceAudience
+
+@InterfaceAudience.Private
+case class HBaseRecord(
+ col0: String,
+ col1: Boolean,
+ col2: Double,
+ col3: Float,
+ col4: Int,
+ col5: Long,
+ col6: Short,
+ col7: String,
+ col8: Byte)
+
+@InterfaceAudience.Private
+object HBaseRecord {
+ def apply(i: Int): HBaseRecord = {
+ val s = s"""row${"%03d".format(i)}"""
+ HBaseRecord(
+ s,
+ i % 2 == 0,
+ i.toDouble,
+ i.toFloat,
+ i,
+ i.toLong,
+ i.toShort,
+ s"String$i extra",
+ i.toByte)
+ }
+}
+
+@InterfaceAudience.Private
+object HBaseSource {
+ val cat = s"""{
+ |"table":{"namespace":"default", "name":"HBaseSourceExampleTable"},
+ |"rowkey":"key",
+ |"columns":{
+ |"col0":{"cf":"rowkey", "col":"key", "type":"string"},
+ |"col1":{"cf":"cf1", "col":"col1", "type":"boolean"},
+ |"col2":{"cf":"cf2", "col":"col2", "type":"double"},
+ |"col3":{"cf":"cf3", "col":"col3", "type":"float"},
+ |"col4":{"cf":"cf4", "col":"col4", "type":"int"},
+ |"col5":{"cf":"cf5", "col":"col5", "type":"bigint"},
+ |"col6":{"cf":"cf6", "col":"col6", "type":"smallint"},
+ |"col7":{"cf":"cf7", "col":"col7", "type":"string"},
+ |"col8":{"cf":"cf8", "col":"col8", "type":"tinyint"}
+ |}
+ |}""".stripMargin
+
+ def main(args: Array[String]) {
+ val sparkConf = new SparkConf().setAppName("HBaseSourceExample")
+ val sc = new SparkContext(sparkConf)
+ val sqlContext = new SQLContext(sc)
+
+ import sqlContext.implicits._
+
+ def withCatalog(cat: String): DataFrame = {
+ sqlContext.read
+ .options(Map(HBaseTableCatalog.tableCatalog -> cat))
+ .format("org.apache.hadoop.hbase.spark")
+ .load()
+ }
+
+ val data = (0 to 255).map { i => HBaseRecord(i) }
+
+ sc.parallelize(data)
+ .toDF
+ .write
+ .options(Map(HBaseTableCatalog.tableCatalog -> cat, HBaseTableCatalog.newTable -> "5"))
+ .format("org.apache.hadoop.hbase.spark")
+ .save()
+
+ val df = withCatalog(cat)
+ df.show()
+ df.filter($"col0" <= "row005")
+ .select($"col0", $"col1")
+ .show
+ df.filter($"col0" === "row005" || $"col0" <= "row005")
+ .select($"col0", $"col1")
+ .show
+ df.filter($"col0" > "row250")
+ .select($"col0", $"col1")
+ .show
+ df.registerTempTable("table1")
+ val c = sqlContext.sql("select count(col1) from table1 where col0 < 'row050'")
+ c.show()
+ }
+}
diff --git a/spark4/hbase-spark4/src/main/scala/org/apache/hadoop/hbase/spark/example/hbasecontext/HBaseBulkDeleteExample.scala b/spark4/hbase-spark4/src/main/scala/org/apache/hadoop/hbase/spark/example/hbasecontext/HBaseBulkDeleteExample.scala
new file mode 100644
index 00000000..94659779
--- /dev/null
+++ b/spark4/hbase-spark4/src/main/scala/org/apache/hadoop/hbase/spark/example/hbasecontext/HBaseBulkDeleteExample.scala
@@ -0,0 +1,67 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.hadoop.hbase.spark.example.hbasecontext
+
+import org.apache.hadoop.hbase.HBaseConfiguration
+import org.apache.hadoop.hbase.TableName
+import org.apache.hadoop.hbase.client.Delete
+import org.apache.hadoop.hbase.spark.HBaseContext
+import org.apache.hadoop.hbase.util.Bytes
+import org.apache.spark.SparkConf
+import org.apache.spark.SparkContext
+import org.apache.yetus.audience.InterfaceAudience
+
+/**
+ * This is a simple example of deleting records in HBase
+ * with the bulkDelete function.
+ */
+@InterfaceAudience.Private
+object HBaseBulkDeleteExample {
+ def main(args: Array[String]) {
+ if (args.length < 1) {
+ println("HBaseBulkDeleteExample {tableName} missing an argument")
+ return
+ }
+
+ val tableName = args(0)
+
+ val sparkConf = new SparkConf().setAppName("HBaseBulkDeleteExample " + tableName)
+ val sc = new SparkContext(sparkConf)
+ try {
+ // [Array[Byte]]
+ val rdd = sc.parallelize(
+ Array[Array[Byte]](
+ Bytes.toBytes("1"),
+ Bytes.toBytes("2"),
+ Bytes.toBytes("3"),
+ Bytes.toBytes("4"),
+ Bytes.toBytes("5")))
+
+ val conf = HBaseConfiguration.create()
+
+ val hbaseContext = new HBaseContext(sc, conf)
+ hbaseContext.bulkDelete[Array[Byte]](
+ rdd,
+ TableName.valueOf(tableName),
+ putRecord => new Delete(putRecord),
+ 4)
+ } finally {
+ sc.stop()
+ }
+ }
+}
diff --git a/spark4/hbase-spark4/src/main/scala/org/apache/hadoop/hbase/spark/example/hbasecontext/HBaseBulkGetExample.scala b/spark4/hbase-spark4/src/main/scala/org/apache/hadoop/hbase/spark/example/hbasecontext/HBaseBulkGetExample.scala
new file mode 100644
index 00000000..cae3dc11
--- /dev/null
+++ b/spark4/hbase-spark4/src/main/scala/org/apache/hadoop/hbase/spark/example/hbasecontext/HBaseBulkGetExample.scala
@@ -0,0 +1,100 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.hadoop.hbase.spark.example.hbasecontext
+
+import org.apache.hadoop.hbase.CellUtil
+import org.apache.hadoop.hbase.HBaseConfiguration
+import org.apache.hadoop.hbase.TableName
+import org.apache.hadoop.hbase.client.Get
+import org.apache.hadoop.hbase.client.Result
+import org.apache.hadoop.hbase.spark.HBaseContext
+import org.apache.hadoop.hbase.util.Bytes
+import org.apache.spark.SparkConf
+import org.apache.spark.SparkContext
+import org.apache.yetus.audience.InterfaceAudience
+
+/**
+ * This is a simple example of getting records from HBase
+ * with the bulkGet function.
+ */
+@InterfaceAudience.Private
+object HBaseBulkGetExample {
+ def main(args: Array[String]) {
+ if (args.length < 1) {
+ println("HBaseBulkGetExample {tableName} missing an argument")
+ return
+ }
+
+ val tableName = args(0)
+
+ val sparkConf = new SparkConf().setAppName("HBaseBulkGetExample " + tableName)
+ val sc = new SparkContext(sparkConf)
+
+ try {
+
+ // [(Array[Byte])]
+ val rdd = sc.parallelize(
+ Array[Array[Byte]](
+ Bytes.toBytes("1"),
+ Bytes.toBytes("2"),
+ Bytes.toBytes("3"),
+ Bytes.toBytes("4"),
+ Bytes.toBytes("5"),
+ Bytes.toBytes("6"),
+ Bytes.toBytes("7")))
+
+ val conf = HBaseConfiguration.create()
+
+ val hbaseContext = new HBaseContext(sc, conf)
+
+ val getRdd = hbaseContext.bulkGet[Array[Byte], String](
+ TableName.valueOf(tableName),
+ 2,
+ rdd,
+ record => {
+ System.out.println("making Get")
+ new Get(record)
+ },
+ (result: Result) => {
+
+ val it = result.listCells().iterator()
+ val b = new StringBuilder
+
+ b.append(Bytes.toString(result.getRow) + ":")
+
+ while (it.hasNext) {
+ val cell = it.next()
+ val q = Bytes.toString(CellUtil.cloneQualifier(cell))
+ if (q.equals("counter")) {
+ b.append("(" + q + "," + Bytes.toLong(CellUtil.cloneValue(cell)) + ")")
+ } else {
+ b.append("(" + q + "," + Bytes.toString(CellUtil.cloneValue(cell)) + ")")
+ }
+ }
+ b.toString()
+ })
+
+ getRdd
+ .collect()
+ .foreach(v => println(v))
+
+ } finally {
+ sc.stop()
+ }
+ }
+}
diff --git a/spark4/hbase-spark4/src/main/scala/org/apache/hadoop/hbase/spark/example/hbasecontext/HBaseBulkPutExample.scala b/spark4/hbase-spark4/src/main/scala/org/apache/hadoop/hbase/spark/example/hbasecontext/HBaseBulkPutExample.scala
new file mode 100644
index 00000000..59e2d298
--- /dev/null
+++ b/spark4/hbase-spark4/src/main/scala/org/apache/hadoop/hbase/spark/example/hbasecontext/HBaseBulkPutExample.scala
@@ -0,0 +1,84 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.hadoop.hbase.spark.example.hbasecontext
+
+import org.apache.hadoop.hbase.HBaseConfiguration
+import org.apache.hadoop.hbase.TableName
+import org.apache.hadoop.hbase.client.Put
+import org.apache.hadoop.hbase.spark.HBaseContext
+import org.apache.hadoop.hbase.util.Bytes
+import org.apache.spark.SparkConf
+import org.apache.spark.SparkContext
+import org.apache.yetus.audience.InterfaceAudience
+
+/**
+ * This is a simple example of putting records in HBase
+ * with the bulkPut function.
+ */
+@InterfaceAudience.Private
+object HBaseBulkPutExample {
+ def main(args: Array[String]) {
+ if (args.length < 2) {
+ println("HBaseBulkPutExample {tableName} {columnFamily} are missing an arguments")
+ return
+ }
+
+ val tableName = args(0)
+ val columnFamily = args(1)
+
+ val sparkConf = new SparkConf().setAppName(
+ "HBaseBulkPutExample " +
+ tableName + " " + columnFamily)
+ val sc = new SparkContext(sparkConf)
+
+ try {
+ // [(Array[Byte], Array[(Array[Byte], Array[Byte], Array[Byte])])]
+ val rdd = sc.parallelize(
+ Array[(Array[Byte], Array[(Array[Byte], Array[Byte], Array[Byte])])](
+ (
+ Bytes.toBytes("1"),
+ Array((Bytes.toBytes(columnFamily), Bytes.toBytes("1"), Bytes.toBytes("1")))),
+ (
+ Bytes.toBytes("2"),
+ Array((Bytes.toBytes(columnFamily), Bytes.toBytes("1"), Bytes.toBytes("2")))),
+ (
+ Bytes.toBytes("3"),
+ Array((Bytes.toBytes(columnFamily), Bytes.toBytes("1"), Bytes.toBytes("3")))),
+ (
+ Bytes.toBytes("4"),
+ Array((Bytes.toBytes(columnFamily), Bytes.toBytes("1"), Bytes.toBytes("4")))),
+ (
+ Bytes.toBytes("5"),
+ Array((Bytes.toBytes(columnFamily), Bytes.toBytes("1"), Bytes.toBytes("5"))))))
+
+ val conf = HBaseConfiguration.create()
+
+ val hbaseContext = new HBaseContext(sc, conf)
+ hbaseContext.bulkPut[(Array[Byte], Array[(Array[Byte], Array[Byte], Array[Byte])])](
+ rdd,
+ TableName.valueOf(tableName),
+ (putRecord) => {
+ val put = new Put(putRecord._1)
+ putRecord._2.foreach((putValue) => put.addColumn(putValue._1, putValue._2, putValue._3))
+ put
+ });
+ } finally {
+ sc.stop()
+ }
+ }
+}
diff --git a/spark4/hbase-spark4/src/main/scala/org/apache/hadoop/hbase/spark/example/hbasecontext/HBaseBulkPutExampleFromFile.scala b/spark4/hbase-spark4/src/main/scala/org/apache/hadoop/hbase/spark/example/hbasecontext/HBaseBulkPutExampleFromFile.scala
new file mode 100644
index 00000000..e7895763
--- /dev/null
+++ b/spark4/hbase-spark4/src/main/scala/org/apache/hadoop/hbase/spark/example/hbasecontext/HBaseBulkPutExampleFromFile.scala
@@ -0,0 +1,80 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.hadoop.hbase.spark.example.hbasecontext
+
+import org.apache.hadoop.hbase.HBaseConfiguration
+import org.apache.hadoop.hbase.TableName
+import org.apache.hadoop.hbase.client.Put
+import org.apache.hadoop.hbase.spark.HBaseContext
+import org.apache.hadoop.hbase.util.Bytes
+import org.apache.hadoop.io.LongWritable
+import org.apache.hadoop.io.Text
+import org.apache.hadoop.mapred.TextInputFormat
+import org.apache.spark.SparkConf
+import org.apache.spark.SparkContext
+import org.apache.yetus.audience.InterfaceAudience
+
+/**
+ * This is a simple example of putting records in HBase
+ * with the bulkPut function. In this example we are
+ * getting the put information from a file
+ */
+@InterfaceAudience.Private
+object HBaseBulkPutExampleFromFile {
+ def main(args: Array[String]) {
+ if (args.length < 3) {
+ println(
+ "HBaseBulkPutExampleFromFile {tableName} {columnFamily} {inputFile} are missing an argument")
+ return
+ }
+
+ val tableName = args(0)
+ val columnFamily = args(1)
+ val inputFile = args(2)
+
+ val sparkConf = new SparkConf().setAppName(
+ "HBaseBulkPutExampleFromFile " +
+ tableName + " " + columnFamily + " " + inputFile)
+ val sc = new SparkContext(sparkConf)
+
+ try {
+ var rdd = sc
+ .hadoopFile(inputFile, classOf[TextInputFormat], classOf[LongWritable], classOf[Text])
+ .map(
+ v => {
+ System.out.println("reading-" + v._2.toString)
+ v._2.toString
+ })
+
+ val conf = HBaseConfiguration.create()
+
+ val hbaseContext = new HBaseContext(sc, conf)
+ hbaseContext.bulkPut[String](
+ rdd,
+ TableName.valueOf(tableName),
+ (putRecord) => {
+ System.out.println("hbase-" + putRecord)
+ val put = new Put(Bytes.toBytes("Value- " + putRecord))
+ put.addColumn(Bytes.toBytes("c"), Bytes.toBytes("1"), Bytes.toBytes(putRecord.length()))
+ put
+ });
+ } finally {
+ sc.stop()
+ }
+ }
+}
diff --git a/spark4/hbase-spark4/src/main/scala/org/apache/hadoop/hbase/spark/example/hbasecontext/HBaseBulkPutTimestampExample.scala b/spark4/hbase-spark4/src/main/scala/org/apache/hadoop/hbase/spark/example/hbasecontext/HBaseBulkPutTimestampExample.scala
new file mode 100644
index 00000000..5e84f2e2
--- /dev/null
+++ b/spark4/hbase-spark4/src/main/scala/org/apache/hadoop/hbase/spark/example/hbasecontext/HBaseBulkPutTimestampExample.scala
@@ -0,0 +1,88 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.hadoop.hbase.spark.example.hbasecontext
+
+import org.apache.hadoop.hbase.{HBaseConfiguration, TableName}
+import org.apache.hadoop.hbase.client.Put
+import org.apache.hadoop.hbase.spark.HBaseContext
+import org.apache.hadoop.hbase.util.Bytes
+import org.apache.spark.SparkConf
+import org.apache.spark.SparkContext
+import org.apache.yetus.audience.InterfaceAudience
+
+/**
+ * This is a simple example of putting records in HBase
+ * with the bulkPut function. In this example we are
+ * also setting the timestamp in the put
+ */
+@InterfaceAudience.Private
+object HBaseBulkPutTimestampExample {
+ def main(args: Array[String]) {
+ if (args.length < 2) {
+ System.out.println(
+ "HBaseBulkPutTimestampExample {tableName} {columnFamily} are missing an argument")
+ return
+ }
+
+ val tableName = args(0)
+ val columnFamily = args(1)
+
+ val sparkConf = new SparkConf().setAppName(
+ "HBaseBulkPutTimestampExample " +
+ tableName + " " + columnFamily)
+ val sc = new SparkContext(sparkConf)
+
+ try {
+
+ val rdd = sc.parallelize(
+ Array[(Array[Byte], Array[(Array[Byte], Array[Byte], Array[Byte])])](
+ (
+ Bytes.toBytes("6"),
+ Array((Bytes.toBytes(columnFamily), Bytes.toBytes("1"), Bytes.toBytes("1")))),
+ (
+ Bytes.toBytes("7"),
+ Array((Bytes.toBytes(columnFamily), Bytes.toBytes("1"), Bytes.toBytes("2")))),
+ (
+ Bytes.toBytes("8"),
+ Array((Bytes.toBytes(columnFamily), Bytes.toBytes("1"), Bytes.toBytes("3")))),
+ (
+ Bytes.toBytes("9"),
+ Array((Bytes.toBytes(columnFamily), Bytes.toBytes("1"), Bytes.toBytes("4")))),
+ (
+ Bytes.toBytes("10"),
+ Array((Bytes.toBytes(columnFamily), Bytes.toBytes("1"), Bytes.toBytes("5"))))))
+
+ val conf = HBaseConfiguration.create()
+
+ val timeStamp = System.currentTimeMillis()
+
+ val hbaseContext = new HBaseContext(sc, conf)
+ hbaseContext.bulkPut[(Array[Byte], Array[(Array[Byte], Array[Byte], Array[Byte])])](
+ rdd,
+ TableName.valueOf(tableName),
+ (putRecord) => {
+ val put = new Put(putRecord._1)
+ putRecord._2.foreach(
+ (putValue) => put.addColumn(putValue._1, putValue._2, timeStamp, putValue._3))
+ put
+ })
+ } finally {
+ sc.stop()
+ }
+ }
+}
diff --git a/spark4/hbase-spark4/src/main/scala/org/apache/hadoop/hbase/spark/example/hbasecontext/HBaseDistributedScanExample.scala b/spark4/hbase-spark4/src/main/scala/org/apache/hadoop/hbase/spark/example/hbasecontext/HBaseDistributedScanExample.scala
new file mode 100644
index 00000000..2f3619b2
--- /dev/null
+++ b/spark4/hbase-spark4/src/main/scala/org/apache/hadoop/hbase/spark/example/hbasecontext/HBaseDistributedScanExample.scala
@@ -0,0 +1,68 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.hadoop.hbase.spark.example.hbasecontext
+
+import org.apache.hadoop.hbase.HBaseConfiguration
+import org.apache.hadoop.hbase.TableName
+import org.apache.hadoop.hbase.client.Scan
+import org.apache.hadoop.hbase.spark.HBaseContext
+import org.apache.hadoop.hbase.util.Bytes
+import org.apache.spark.SparkConf
+import org.apache.spark.SparkContext
+import org.apache.yetus.audience.InterfaceAudience
+
+/**
+ * This is a simple example of scanning records from HBase
+ * with the hbaseRDD function in a distributed fashion.
+ */
+@InterfaceAudience.Private
+object HBaseDistributedScanExample {
+ def main(args: Array[String]) {
+ if (args.length < 1) {
+ println("HBaseDistributedScanExample {tableName} missing an argument")
+ return
+ }
+
+ val tableName = args(0)
+
+ val sparkConf = new SparkConf().setAppName("HBaseDistributedScanExample " + tableName)
+ val sc = new SparkContext(sparkConf)
+
+ try {
+ val conf = HBaseConfiguration.create()
+
+ val hbaseContext = new HBaseContext(sc, conf)
+
+ val scan = new Scan()
+ scan.setCaching(100)
+
+ val getRdd = hbaseContext.hbaseRDD(TableName.valueOf(tableName), scan)
+
+ getRdd.foreach(v => println(Bytes.toString(v._1.get())))
+
+ println(
+ "Length: " + getRdd
+ .map(r => r._1.copyBytes())
+ .collect()
+ .length);
+ } finally {
+ sc.stop()
+ }
+ }
+
+}
diff --git a/spark4/hbase-spark4/src/main/scala/org/apache/hadoop/hbase/spark/example/hbasecontext/HBaseStreamingBulkPutExample.scala b/spark4/hbase-spark4/src/main/scala/org/apache/hadoop/hbase/spark/example/hbasecontext/HBaseStreamingBulkPutExample.scala
new file mode 100644
index 00000000..4774a066
--- /dev/null
+++ b/spark4/hbase-spark4/src/main/scala/org/apache/hadoop/hbase/spark/example/hbasecontext/HBaseStreamingBulkPutExample.scala
@@ -0,0 +1,80 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.hadoop.hbase.spark.example.hbasecontext
+
+import org.apache.hadoop.hbase.HBaseConfiguration
+import org.apache.hadoop.hbase.TableName
+import org.apache.hadoop.hbase.client.Put
+import org.apache.hadoop.hbase.spark.HBaseContext
+import org.apache.hadoop.hbase.util.Bytes
+import org.apache.spark.SparkConf
+import org.apache.spark.SparkContext
+import org.apache.spark.streaming.Seconds
+import org.apache.spark.streaming.StreamingContext
+import org.apache.yetus.audience.InterfaceAudience
+
+/**
+ * This is a simple example of BulkPut with Spark Streaming
+ */
+@InterfaceAudience.Private
+object HBaseStreamingBulkPutExample {
+ def main(args: Array[String]) {
+ if (args.length < 4) {
+ println(
+ "HBaseStreamingBulkPutExample " +
+ "{host} {port} {tableName} {columnFamily} are missing an argument")
+ return
+ }
+
+ val host = args(0)
+ val port = args(1)
+ val tableName = args(2)
+ val columnFamily = args(3)
+
+ val sparkConf = new SparkConf().setAppName(
+ "HBaseStreamingBulkPutExample " +
+ tableName + " " + columnFamily)
+ val sc = new SparkContext(sparkConf)
+ try {
+ val ssc = new StreamingContext(sc, Seconds(1))
+
+ val lines = ssc.socketTextStream(host, port.toInt)
+
+ val conf = HBaseConfiguration.create()
+
+ val hbaseContext = new HBaseContext(sc, conf)
+
+ hbaseContext.streamBulkPut[String](
+ lines,
+ TableName.valueOf(tableName),
+ (putRecord) => {
+ if (putRecord.length() > 0) {
+ val put = new Put(Bytes.toBytes(putRecord))
+ put.addColumn(Bytes.toBytes("c"), Bytes.toBytes("foo"), Bytes.toBytes("bar"))
+ put
+ } else {
+ null
+ }
+ })
+ ssc.start()
+ ssc.awaitTerminationOrTimeout(60000)
+ } finally {
+ sc.stop()
+ }
+ }
+}
diff --git a/spark4/hbase-spark4/src/main/scala/org/apache/hadoop/hbase/spark/example/rdd/HBaseBulkDeleteExample.scala b/spark4/hbase-spark4/src/main/scala/org/apache/hadoop/hbase/spark/example/rdd/HBaseBulkDeleteExample.scala
new file mode 100644
index 00000000..1498df05
--- /dev/null
+++ b/spark4/hbase-spark4/src/main/scala/org/apache/hadoop/hbase/spark/example/rdd/HBaseBulkDeleteExample.scala
@@ -0,0 +1,70 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.hadoop.hbase.spark.example.rdd
+
+import org.apache.hadoop.hbase.HBaseConfiguration
+import org.apache.hadoop.hbase.TableName
+import org.apache.hadoop.hbase.client.Delete
+import org.apache.hadoop.hbase.spark.HBaseContext
+import org.apache.hadoop.hbase.spark.HBaseRDDFunctions._
+import org.apache.hadoop.hbase.util.Bytes
+import org.apache.spark.SparkConf
+import org.apache.spark.SparkContext
+import org.apache.yetus.audience.InterfaceAudience
+
+/**
+ * This is a simple example of deleting records in HBase
+ * with the bulkDelete function.
+ */
+@InterfaceAudience.Private
+object HBaseBulkDeleteExample {
+ def main(args: Array[String]) {
+ if (args.length < 1) {
+ println("HBaseBulkDeleteExample {tableName} are missing an argument")
+ return
+ }
+
+ val tableName = args(0)
+
+ val sparkConf = new SparkConf().setAppName("HBaseBulkDeleteExample " + tableName)
+ val sc = new SparkContext(sparkConf)
+ try {
+ // [Array[Byte]]
+ val rdd = sc.parallelize(
+ Array[Array[Byte]](
+ Bytes.toBytes("1"),
+ Bytes.toBytes("2"),
+ Bytes.toBytes("3"),
+ Bytes.toBytes("4"),
+ Bytes.toBytes("5")))
+
+ val conf = HBaseConfiguration.create()
+
+ val hbaseContext = new HBaseContext(sc, conf)
+
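+      // The last argument (4) is the batch size: deletes are grouped and sent
+      // to the table four at a time.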
+ rdd.hbaseBulkDelete(
+ hbaseContext,
+ TableName.valueOf(tableName),
+ putRecord => new Delete(putRecord),
+ 4)
+
+ } finally {
+ sc.stop()
+ }
+ }
+}
diff --git a/spark4/hbase-spark4/src/main/scala/org/apache/hadoop/hbase/spark/example/rdd/HBaseBulkGetExample.scala b/spark4/hbase-spark4/src/main/scala/org/apache/hadoop/hbase/spark/example/rdd/HBaseBulkGetExample.scala
new file mode 100644
index 00000000..e57c7690
--- /dev/null
+++ b/spark4/hbase-spark4/src/main/scala/org/apache/hadoop/hbase/spark/example/rdd/HBaseBulkGetExample.scala
@@ -0,0 +1,101 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.hadoop.hbase.spark.example.rdd
+
+import org.apache.hadoop.hbase.CellUtil
+import org.apache.hadoop.hbase.HBaseConfiguration
+import org.apache.hadoop.hbase.TableName
+import org.apache.hadoop.hbase.client.Get
+import org.apache.hadoop.hbase.client.Result
+import org.apache.hadoop.hbase.spark.HBaseContext
+import org.apache.hadoop.hbase.spark.HBaseRDDFunctions._
+import org.apache.hadoop.hbase.util.Bytes
+import org.apache.spark.SparkConf
+import org.apache.spark.SparkContext
+import org.apache.yetus.audience.InterfaceAudience
+
+/**
+ * This is a simple example of getting records from HBase
+ * with the bulkGet function.
+ */
+@InterfaceAudience.Private
+object HBaseBulkGetExample {
+ def main(args: Array[String]) {
+ if (args.length < 1) {
+ println("HBaseBulkGetExample {tableName} is missing an argument")
+ return
+ }
+
+ val tableName = args(0)
+
+ val sparkConf = new SparkConf().setAppName("HBaseBulkGetExample " + tableName)
+ val sc = new SparkContext(sparkConf)
+
+ try {
+
+ // [(Array[Byte])]
+ val rdd = sc.parallelize(
+ Array[Array[Byte]](
+ Bytes.toBytes("1"),
+ Bytes.toBytes("2"),
+ Bytes.toBytes("3"),
+ Bytes.toBytes("4"),
+ Bytes.toBytes("5"),
+ Bytes.toBytes("6"),
+ Bytes.toBytes("7")))
+
+ val conf = HBaseConfiguration.create()
+
+ val hbaseContext = new HBaseContext(sc, conf)
+
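+      // bulkGet groups the row keys into batches of 2 (the batchSize argument),
+      // issues the Gets and converts each Result to a String with the function below.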
+ val getRdd = rdd.hbaseBulkGet[String](
+ hbaseContext,
+ TableName.valueOf(tableName),
+ 2,
+ record => {
+ System.out.println("making Get")
+ new Get(record)
+ },
+ (result: Result) => {
+
+ val it = result.listCells().iterator()
+ val b = new StringBuilder
+
+ b.append(Bytes.toString(result.getRow) + ":")
+
+ while (it.hasNext) {
+ val cell = it.next()
+ val q = Bytes.toString(CellUtil.cloneQualifier(cell))
+ if (q.equals("counter")) {
+ b.append("(" + q + "," + Bytes.toLong(CellUtil.cloneValue(cell)) + ")")
+ } else {
+ b.append("(" + q + "," + Bytes.toString(CellUtil.cloneValue(cell)) + ")")
+ }
+ }
+ b.toString()
+ })
+
+ getRdd
+ .collect()
+ .foreach(v => println(v))
+
+ } finally {
+ sc.stop()
+ }
+ }
+}
diff --git a/spark4/hbase-spark4/src/main/scala/org/apache/hadoop/hbase/spark/example/rdd/HBaseBulkPutExample.scala b/spark4/hbase-spark4/src/main/scala/org/apache/hadoop/hbase/spark/example/rdd/HBaseBulkPutExample.scala
new file mode 100644
index 00000000..6bc7ac2c
--- /dev/null
+++ b/spark4/hbase-spark4/src/main/scala/org/apache/hadoop/hbase/spark/example/rdd/HBaseBulkPutExample.scala
@@ -0,0 +1,87 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.hadoop.hbase.spark.example.rdd
+
+import org.apache.hadoop.hbase.HBaseConfiguration
+import org.apache.hadoop.hbase.TableName
+import org.apache.hadoop.hbase.client.Put
+import org.apache.hadoop.hbase.spark.HBaseContext
+import org.apache.hadoop.hbase.spark.HBaseRDDFunctions._
+import org.apache.hadoop.hbase.util.Bytes
+import org.apache.spark.SparkConf
+import org.apache.spark.SparkContext
+import org.apache.yetus.audience.InterfaceAudience
+
+/**
+ * This is a simple example of putting records in HBase
+ * with the bulkPut function.
+ */
+@InterfaceAudience.Private
+object HBaseBulkPutExample {
+ def main(args: Array[String]) {
+ if (args.length < 2) {
+      println("HBaseBulkPutExample {tableName} {columnFamily} are missing arguments")
+ return
+ }
+
+ val tableName = args(0)
+ val columnFamily = args(1)
+
+ val sparkConf = new SparkConf().setAppName(
+ "HBaseBulkPutExample " +
+ tableName + " " + columnFamily)
+ val sc = new SparkContext(sparkConf)
+
+ try {
+ // [(Array[Byte], Array[(Array[Byte], Array[Byte], Array[Byte])])]
+ val rdd = sc.parallelize(
+ Array[(Array[Byte], Array[(Array[Byte], Array[Byte], Array[Byte])])](
+ (
+ Bytes.toBytes("1"),
+ Array((Bytes.toBytes(columnFamily), Bytes.toBytes("1"), Bytes.toBytes("1")))),
+ (
+ Bytes.toBytes("2"),
+ Array((Bytes.toBytes(columnFamily), Bytes.toBytes("1"), Bytes.toBytes("2")))),
+ (
+ Bytes.toBytes("3"),
+ Array((Bytes.toBytes(columnFamily), Bytes.toBytes("1"), Bytes.toBytes("3")))),
+ (
+ Bytes.toBytes("4"),
+ Array((Bytes.toBytes(columnFamily), Bytes.toBytes("1"), Bytes.toBytes("4")))),
+ (
+ Bytes.toBytes("5"),
+ Array((Bytes.toBytes(columnFamily), Bytes.toBytes("1"), Bytes.toBytes("5"))))))
+
+ val conf = HBaseConfiguration.create()
+
+ val hbaseContext = new HBaseContext(sc, conf)
+
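+      // Each tuple becomes a single Put: the first element is the row key, the
+      // second is the list of (family, qualifier, value) cells to add.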
+ rdd.hbaseBulkPut(
+ hbaseContext,
+ TableName.valueOf(tableName),
+ (putRecord) => {
+ val put = new Put(putRecord._1)
+ putRecord._2.foreach((putValue) => put.addColumn(putValue._1, putValue._2, putValue._3))
+ put
+ })
+
+ } finally {
+ sc.stop()
+ }
+ }
+}
diff --git a/spark4/hbase-spark4/src/main/scala/org/apache/hadoop/hbase/spark/example/rdd/HBaseForeachPartitionExample.scala b/spark4/hbase-spark4/src/main/scala/org/apache/hadoop/hbase/spark/example/rdd/HBaseForeachPartitionExample.scala
new file mode 100644
index 00000000..a952c548
--- /dev/null
+++ b/spark4/hbase-spark4/src/main/scala/org/apache/hadoop/hbase/spark/example/rdd/HBaseForeachPartitionExample.scala
@@ -0,0 +1,93 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.hadoop.hbase.spark.example.rdd
+
+import org.apache.hadoop.hbase.HBaseConfiguration
+import org.apache.hadoop.hbase.TableName
+import org.apache.hadoop.hbase.client.Put
+import org.apache.hadoop.hbase.spark.HBaseContext
+import org.apache.hadoop.hbase.spark.HBaseRDDFunctions._
+import org.apache.hadoop.hbase.util.Bytes
+import org.apache.spark.SparkConf
+import org.apache.spark.SparkContext
+import org.apache.yetus.audience.InterfaceAudience
+
+/**
+ * This is a simple example of using the foreachPartition
+ * method with an HBase connection
+ */
+@InterfaceAudience.Private
+object HBaseForeachPartitionExample {
+ def main(args: Array[String]) {
+ if (args.length < 2) {
+      println("HBaseForeachPartitionExample {tableName} {columnFamily} are missing arguments")
+ return
+ }
+
+ val tableName = args(0)
+ val columnFamily = args(1)
+
+ val sparkConf = new SparkConf().setAppName(
+ "HBaseForeachPartitionExample " +
+ tableName + " " + columnFamily)
+ val sc = new SparkContext(sparkConf)
+
+ try {
+ // [(Array[Byte], Array[(Array[Byte], Array[Byte], Array[Byte])])]
+ val rdd = sc.parallelize(
+ Array[(Array[Byte], Array[(Array[Byte], Array[Byte], Array[Byte])])](
+ (
+ Bytes.toBytes("1"),
+ Array((Bytes.toBytes(columnFamily), Bytes.toBytes("1"), Bytes.toBytes("1")))),
+ (
+ Bytes.toBytes("2"),
+ Array((Bytes.toBytes(columnFamily), Bytes.toBytes("1"), Bytes.toBytes("2")))),
+ (
+ Bytes.toBytes("3"),
+ Array((Bytes.toBytes(columnFamily), Bytes.toBytes("1"), Bytes.toBytes("3")))),
+ (
+ Bytes.toBytes("4"),
+ Array((Bytes.toBytes(columnFamily), Bytes.toBytes("1"), Bytes.toBytes("4")))),
+ (
+ Bytes.toBytes("5"),
+ Array((Bytes.toBytes(columnFamily), Bytes.toBytes("1"), Bytes.toBytes("5"))))))
+
+ val conf = HBaseConfiguration.create()
+
+ val hbaseContext = new HBaseContext(sc, conf)
+
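+      // Each partition shares one Connection; puts are buffered through a
+      // BufferedMutator and flushed once the partition's iterator is exhausted.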
+ rdd.hbaseForeachPartition(
+ hbaseContext,
+ (it, connection) => {
+ val m = connection.getBufferedMutator(TableName.valueOf(tableName))
+
+ it.foreach(
+ r => {
+ val put = new Put(r._1)
+ r._2.foreach((putValue) => put.addColumn(putValue._1, putValue._2, putValue._3))
+ m.mutate(put)
+ })
+ m.flush()
+ m.close()
+ })
+
+ } finally {
+ sc.stop()
+ }
+ }
+}
diff --git a/spark4/hbase-spark4/src/main/scala/org/apache/hadoop/hbase/spark/example/rdd/HBaseMapPartitionExample.scala b/spark4/hbase-spark4/src/main/scala/org/apache/hadoop/hbase/spark/example/rdd/HBaseMapPartitionExample.scala
new file mode 100644
index 00000000..cac41d21
--- /dev/null
+++ b/spark4/hbase-spark4/src/main/scala/org/apache/hadoop/hbase/spark/example/rdd/HBaseMapPartitionExample.scala
@@ -0,0 +1,99 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.hadoop.hbase.spark.example.rdd
+
+import org.apache.hadoop.hbase.CellUtil
+import org.apache.hadoop.hbase.HBaseConfiguration
+import org.apache.hadoop.hbase.TableName
+import org.apache.hadoop.hbase.client.Get
+import org.apache.hadoop.hbase.spark.HBaseContext
+import org.apache.hadoop.hbase.spark.HBaseRDDFunctions._
+import org.apache.hadoop.hbase.util.Bytes
+import org.apache.spark.SparkConf
+import org.apache.spark.SparkContext
+import org.apache.yetus.audience.InterfaceAudience
+
+/**
+ * This is a simple example of using the mapPartitions
+ * method with an HBase connection
+ */
+@InterfaceAudience.Private
+object HBaseMapPartitionExample {
+ def main(args: Array[String]) {
+ if (args.length < 1) {
+ println("HBaseMapPartitionExample {tableName} is missing an argument")
+ return
+ }
+
+ val tableName = args(0)
+
+ val sparkConf = new SparkConf().setAppName("HBaseMapPartitionExample " + tableName)
+ val sc = new SparkContext(sparkConf)
+
+ try {
+
+ // [(Array[Byte])]
+ val rdd = sc.parallelize(
+ Array[Array[Byte]](
+ Bytes.toBytes("1"),
+ Bytes.toBytes("2"),
+ Bytes.toBytes("3"),
+ Bytes.toBytes("4"),
+ Bytes.toBytes("5"),
+ Bytes.toBytes("6"),
+ Bytes.toBytes("7")))
+
+ val conf = HBaseConfiguration.create()
+
+ val hbaseContext = new HBaseContext(sc, conf)
+
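+      // hbaseMapPartitions shares one Connection per partition; every row key is
+      // fetched with an individual Get and rendered as "rowKey:(qualifier,value)...".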
+ val getRdd = rdd.hbaseMapPartitions[String](
+ hbaseContext,
+ (it, connection) => {
+ val table = connection.getTable(TableName.valueOf(tableName))
+ it.map {
+ r =>
+ // batching would be faster. This is just an example
+ val result = table.get(new Get(r))
+
+ val it = result.listCells().iterator()
+ val b = new StringBuilder
+
+ b.append(Bytes.toString(result.getRow) + ":")
+
+ while (it.hasNext) {
+ val cell = it.next()
+            val q = Bytes.toString(CellUtil.cloneQualifier(cell))
+            if (q.equals("counter")) {
+              b.append("(" + q + "," + Bytes.toLong(CellUtil.cloneValue(cell)) + ")")
+            } else {
+              b.append("(" + q + "," + Bytes.toString(CellUtil.cloneValue(cell)) + ")")
+ }
+ }
+ b.toString()
+ }
+ })
+
+ getRdd
+ .collect()
+ .foreach(v => println(v))
+
+ } finally {
+ sc.stop()
+ }
+ }
+}
diff --git a/spark4/hbase-spark4/src/test/java/org/apache/hadoop/hbase/spark/TestJavaHBaseContext.java b/spark4/hbase-spark4/src/test/java/org/apache/hadoop/hbase/spark/TestJavaHBaseContext.java
new file mode 100644
index 00000000..0f8c4aa7
--- /dev/null
+++ b/spark4/hbase-spark4/src/test/java/org/apache/hadoop/hbase/spark/TestJavaHBaseContext.java
@@ -0,0 +1,515 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.hadoop.hbase.spark;
+
+import java.io.IOException;
+import java.io.Serializable;
+import java.util.ArrayList;
+import java.util.HashMap;
+import java.util.Iterator;
+import java.util.List;
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.fs.Path;
+import org.apache.hadoop.hbase.Cell;
+import org.apache.hadoop.hbase.CellUtil;
+import org.apache.hadoop.hbase.HBaseClassTestRule;
+import org.apache.hadoop.hbase.HBaseTestingUtility;
+import org.apache.hadoop.hbase.HConstants;
+import org.apache.hadoop.hbase.TableName;
+import org.apache.hadoop.hbase.client.Admin;
+import org.apache.hadoop.hbase.client.Connection;
+import org.apache.hadoop.hbase.client.ConnectionFactory;
+import org.apache.hadoop.hbase.client.Delete;
+import org.apache.hadoop.hbase.client.Get;
+import org.apache.hadoop.hbase.client.Put;
+import org.apache.hadoop.hbase.client.Result;
+import org.apache.hadoop.hbase.client.Scan;
+import org.apache.hadoop.hbase.client.Table;
+import org.apache.hadoop.hbase.io.ImmutableBytesWritable;
+import org.apache.hadoop.hbase.spark.example.hbasecontext.JavaHBaseBulkDeleteExample;
+import org.apache.hadoop.hbase.testclassification.MediumTests;
+import org.apache.hadoop.hbase.testclassification.MiscTests;
+import org.apache.hadoop.hbase.tool.LoadIncrementalHFiles;
+import org.apache.hadoop.hbase.util.Bytes;
+import org.apache.hadoop.hbase.util.Pair;
+import org.apache.spark.SparkConf;
+import org.apache.spark.api.java.JavaRDD;
+import org.apache.spark.api.java.JavaSparkContext;
+import org.apache.spark.api.java.function.Function;
+import org.junit.After;
+import org.junit.AfterClass;
+import org.junit.Assert;
+import org.junit.Before;
+import org.junit.BeforeClass;
+import org.junit.ClassRule;
+import org.junit.Test;
+import org.junit.experimental.categories.Category;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+import scala.Tuple2;
+
+@Category({ MiscTests.class, MediumTests.class })
+public class TestJavaHBaseContext implements Serializable {
+
+ @ClassRule
+ public static final HBaseClassTestRule TIMEOUT =
+ HBaseClassTestRule.forClass(TestJavaHBaseContext.class);
+
+ protected static transient JavaSparkContext JSC;
+ private static HBaseTestingUtility TEST_UTIL;
+ private static JavaHBaseContext HBASE_CONTEXT;
+ private static final Logger LOG = LoggerFactory.getLogger(TestJavaHBaseContext.class);
+
+ protected byte[] tableName = Bytes.toBytes("t1");
+ protected byte[] columnFamily = Bytes.toBytes("c");
+ byte[] columnFamily1 = Bytes.toBytes("d");
+ String columnFamilyStr = Bytes.toString(columnFamily);
+ String columnFamilyStr1 = Bytes.toString(columnFamily1);
+
+ @BeforeClass
+ public static void setUpBeforeClass() throws Exception {
+    // NOTE: We need to do this due to a behaviour change in Spark 3.2, where the conf below
+    // defaults to true. Left at true, we would get an empty table as the result (for small
+    // tables) on HBase versions that do not have HBASE-26340.
+ SparkConf sparkConf = new SparkConf().set("spark.hadoopRDD.ignoreEmptySplits", "false");
+ JSC = new JavaSparkContext("local", "JavaHBaseContextSuite", sparkConf);
+
+ init();
+ }
+
+ protected static void init() throws Exception {
+ TEST_UTIL = new HBaseTestingUtility();
+ Configuration conf = TEST_UTIL.getConfiguration();
+
+ HBASE_CONTEXT = new JavaHBaseContext(JSC, conf);
+
+ LOG.info("cleaning up test dir");
+
+ TEST_UTIL.cleanupTestDir();
+
+ LOG.info("starting minicluster");
+
+ TEST_UTIL.startMiniCluster();
+
+ LOG.info(" - minicluster started");
+ }
+
+ @AfterClass
+ public static void tearDownAfterClass() throws Exception {
+    LOG.info("shutting down minicluster");
+ TEST_UTIL.shutdownMiniCluster();
+ LOG.info(" - minicluster shut down");
+ TEST_UTIL.cleanupTestDir();
+
+ JSC.stop();
+ JSC = null;
+ }
+
+ @Before
+ public void setUp() throws Exception {
+
+ try {
+ TEST_UTIL.deleteTable(TableName.valueOf(tableName));
+ } catch (Exception e) {
+ LOG.info(" - no table {} found", Bytes.toString(tableName));
+ }
+
+ LOG.info(" - creating table {}", Bytes.toString(tableName));
+ TEST_UTIL.createTable(TableName.valueOf(tableName),
+ new byte[][] { columnFamily, columnFamily1 });
+ LOG.info(" - created table");
+ }
+
+ @After
+ public void tearDown() throws Exception {
+ TEST_UTIL.deleteTable(TableName.valueOf(tableName));
+ }
+
+ @Test
+ public void testBulkPut() throws IOException {
+
+    List<String> list = new ArrayList<>(5);
+ list.add("1," + columnFamilyStr + ",a,1");
+ list.add("2," + columnFamilyStr + ",a,2");
+ list.add("3," + columnFamilyStr + ",a,3");
+ list.add("4," + columnFamilyStr + ",a,4");
+ list.add("5," + columnFamilyStr + ",a,5");
+
+    JavaRDD<String> rdd = JSC.parallelize(list);
+
+ Configuration conf = TEST_UTIL.getConfiguration();
+
+ Connection conn = ConnectionFactory.createConnection(conf);
+ Table table = conn.getTable(TableName.valueOf(tableName));
+
+ try {
+      List<Delete> deletes = new ArrayList<>(5);
+ for (int i = 1; i < 6; i++) {
+ deletes.add(new Delete(Bytes.toBytes(Integer.toString(i))));
+ }
+ table.delete(deletes);
+ } finally {
+ table.close();
+ }
+
+ HBASE_CONTEXT.bulkPut(rdd, TableName.valueOf(tableName), new PutFunction());
+
+ table = conn.getTable(TableName.valueOf(tableName));
+
+ try {
+ Result result1 = table.get(new Get(Bytes.toBytes("1")));
+      Assert.assertNotNull("Row 1 should have been put", result1.getRow());
+
+      Result result2 = table.get(new Get(Bytes.toBytes("2")));
+      Assert.assertNotNull("Row 2 should have been put", result2.getRow());
+
+      Result result3 = table.get(new Get(Bytes.toBytes("3")));
+      Assert.assertNotNull("Row 3 should have been put", result3.getRow());
+
+      Result result4 = table.get(new Get(Bytes.toBytes("4")));
+      Assert.assertNotNull("Row 4 should have been put", result4.getRow());
+
+      Result result5 = table.get(new Get(Bytes.toBytes("5")));
+      Assert.assertNotNull("Row 5 should have been put", result5.getRow());
+ } finally {
+ table.close();
+ conn.close();
+ }
+ }
+
+  public static class PutFunction implements Function<String, Put> {
+
+ private static final long serialVersionUID = 1L;
+
+ @Override
+ public Put call(String v) throws Exception {
+ String[] cells = v.split(",");
+ Put put = new Put(Bytes.toBytes(cells[0]));
+
+ put.addColumn(Bytes.toBytes(cells[1]), Bytes.toBytes(cells[2]), Bytes.toBytes(cells[3]));
+ return put;
+ }
+ }
+
+ @Test
+ public void testBulkDelete() throws IOException {
+    List<byte[]> list = new ArrayList<>(3);
+ list.add(Bytes.toBytes("1"));
+ list.add(Bytes.toBytes("2"));
+ list.add(Bytes.toBytes("3"));
+
+    JavaRDD<byte[]> rdd = JSC.parallelize(list);
+
+ Configuration conf = TEST_UTIL.getConfiguration();
+
+ populateTableWithMockData(conf, TableName.valueOf(tableName));
+
+ HBASE_CONTEXT.bulkDelete(rdd, TableName.valueOf(tableName),
+ new JavaHBaseBulkDeleteExample.DeleteFunction(), 2);
+
+ try (Connection conn = ConnectionFactory.createConnection(conf);
+ Table table = conn.getTable(TableName.valueOf(tableName))) {
+ Result result1 = table.get(new Get(Bytes.toBytes("1")));
+      Assert.assertNull("Row 1 should have been deleted", result1.getRow());
+
+      Result result2 = table.get(new Get(Bytes.toBytes("2")));
+      Assert.assertNull("Row 2 should have been deleted", result2.getRow());
+
+      Result result3 = table.get(new Get(Bytes.toBytes("3")));
+      Assert.assertNull("Row 3 should have been deleted", result3.getRow());
+
+      Result result4 = table.get(new Get(Bytes.toBytes("4")));
+      Assert.assertNotNull("Row 4 should not have been deleted", result4.getRow());
+
+      Result result5 = table.get(new Get(Bytes.toBytes("5")));
+      Assert.assertNotNull("Row 5 should not have been deleted", result5.getRow());
+ }
+ }
+
+ @Test
+ public void testDistributedScan() throws IOException {
+ Configuration conf = TEST_UTIL.getConfiguration();
+
+ populateTableWithMockData(conf, TableName.valueOf(tableName));
+
+ Scan scan = new Scan();
+ scan.setCaching(100);
+
+    JavaRDD<String> javaRdd =
+      HBASE_CONTEXT.hbaseRDD(TableName.valueOf(tableName), scan).map(new ScanConvertFunction());
+
+    List<String> results = javaRdd.collect();
+
+ Assert.assertEquals(results.size(), 5);
+ }
+
+  private static class ScanConvertFunction
+    implements Function<Tuple2<ImmutableBytesWritable, Result>, String> {
+    @Override
+    public String call(Tuple2<ImmutableBytesWritable, Result> v1) throws Exception {
+ return Bytes.toString(v1._1().copyBytes());
+ }
+ }
+
+ @Test
+ public void testBulkGet() throws IOException {
+    List<byte[]> list = new ArrayList<>(5);
+ list.add(Bytes.toBytes("1"));
+ list.add(Bytes.toBytes("2"));
+ list.add(Bytes.toBytes("3"));
+ list.add(Bytes.toBytes("4"));
+ list.add(Bytes.toBytes("5"));
+
+    JavaRDD<byte[]> rdd = JSC.parallelize(list);
+
+ Configuration conf = TEST_UTIL.getConfiguration();
+
+ populateTableWithMockData(conf, TableName.valueOf(tableName));
+
+    final JavaRDD<String> stringJavaRDD = HBASE_CONTEXT.bulkGet(TableName.valueOf(tableName), 2,
+ rdd, new GetFunction(), new ResultFunction());
+
+ Assert.assertEquals(stringJavaRDD.count(), 5);
+ }
+
+ @Test
+ public void testBulkLoad() throws Exception {
+
+ Path output = TEST_UTIL.getDataTestDir("testBulkLoad");
+    // Add cell as String: "row,family,qualifier,value"
+    List<String> list = new ArrayList<String>();
+ // row1
+ list.add("1," + columnFamilyStr + ",b,1");
+ // row3
+ list.add("3," + columnFamilyStr + ",a,2");
+ list.add("3," + columnFamilyStr + ",b,1");
+ list.add("3," + columnFamilyStr1 + ",a,1");
+ // row2
+ list.add("2," + columnFamilyStr + ",a,3");
+ list.add("2," + columnFamilyStr + ",b,3");
+
+    JavaRDD<String> rdd = JSC.parallelize(list);
+
+ Configuration conf = TEST_UTIL.getConfiguration();
+
+ HBASE_CONTEXT.bulkLoad(rdd, TableName.valueOf(tableName), new BulkLoadFunction(),
+      output.toUri().getPath(), new HashMap<byte[], FamilyHFileWriteOptions>(), false,
+ HConstants.DEFAULT_MAX_FILE_SIZE);
+
+ try (Connection conn = ConnectionFactory.createConnection(conf);
+ Admin admin = conn.getAdmin()) {
+ Table table = conn.getTable(TableName.valueOf(tableName));
+ // Do bulk load
+ LoadIncrementalHFiles load = new LoadIncrementalHFiles(conf);
+ load.doBulkLoad(output, admin, table, conn.getRegionLocator(TableName.valueOf(tableName)));
+
+ // Check row1
+      List<Cell> cell1 = table.get(new Get(Bytes.toBytes("1"))).listCells();
+ Assert.assertEquals(cell1.size(), 1);
+ Assert.assertEquals(Bytes.toString(CellUtil.cloneFamily(cell1.get(0))), columnFamilyStr);
+ Assert.assertEquals(Bytes.toString(CellUtil.cloneQualifier(cell1.get(0))), "b");
+ Assert.assertEquals(Bytes.toString(CellUtil.cloneValue(cell1.get(0))), "1");
+
+ // Check row3
+      List<Cell> cell3 = table.get(new Get(Bytes.toBytes("3"))).listCells();
+ Assert.assertEquals(cell3.size(), 3);
+ Assert.assertEquals(Bytes.toString(CellUtil.cloneFamily(cell3.get(0))), columnFamilyStr);
+ Assert.assertEquals(Bytes.toString(CellUtil.cloneQualifier(cell3.get(0))), "a");
+ Assert.assertEquals(Bytes.toString(CellUtil.cloneValue(cell3.get(0))), "2");
+ Assert.assertEquals(Bytes.toString(CellUtil.cloneFamily(cell3.get(1))), columnFamilyStr);
+ Assert.assertEquals(Bytes.toString(CellUtil.cloneQualifier(cell3.get(1))), "b");
+ Assert.assertEquals(Bytes.toString(CellUtil.cloneValue(cell3.get(1))), "1");
+ Assert.assertEquals(Bytes.toString(CellUtil.cloneFamily(cell3.get(2))), columnFamilyStr1);
+ Assert.assertEquals(Bytes.toString(CellUtil.cloneQualifier(cell3.get(2))), "a");
+ Assert.assertEquals(Bytes.toString(CellUtil.cloneValue(cell3.get(2))), "1");
+
+ // Check row2
+      List<Cell> cell2 = table.get(new Get(Bytes.toBytes("2"))).listCells();
+ Assert.assertEquals(cell2.size(), 2);
+ Assert.assertEquals(Bytes.toString(CellUtil.cloneFamily(cell2.get(0))), columnFamilyStr);
+ Assert.assertEquals(Bytes.toString(CellUtil.cloneQualifier(cell2.get(0))), "a");
+ Assert.assertEquals(Bytes.toString(CellUtil.cloneValue(cell2.get(0))), "3");
+ Assert.assertEquals(Bytes.toString(CellUtil.cloneFamily(cell2.get(1))), columnFamilyStr);
+ Assert.assertEquals(Bytes.toString(CellUtil.cloneQualifier(cell2.get(1))), "b");
+ Assert.assertEquals(Bytes.toString(CellUtil.cloneValue(cell2.get(1))), "3");
+ }
+ }
+
+ @Test
+ public void testBulkLoadThinRows() throws Exception {
+ Path output = TEST_UTIL.getDataTestDir("testBulkLoadThinRows");
+    // because of the limitation of the scala bulkLoadThinRows API
+    // we need to provide the data as List<List<String>>, one inner list per row
+    List<List<String>> list = new ArrayList<List<String>>();
+ // row1
+    List<String> list1 = new ArrayList<String>();
+ list1.add("1," + columnFamilyStr + ",b,1");
+ list.add(list1);
+ // row3
+    List<String> list3 = new ArrayList<String>();
+ list3.add("3," + columnFamilyStr + ",a,2");
+ list3.add("3," + columnFamilyStr + ",b,1");
+ list3.add("3," + columnFamilyStr1 + ",a,1");
+ list.add(list3);
+ // row2
+    List<String> list2 = new ArrayList<String>();
+ list2.add("2," + columnFamilyStr + ",a,3");
+ list2.add("2," + columnFamilyStr + ",b,3");
+ list.add(list2);
+
+    JavaRDD<List<String>> rdd = JSC.parallelize(list);
+
+ Configuration conf = TEST_UTIL.getConfiguration();
+
+ HBASE_CONTEXT.bulkLoadThinRows(rdd, TableName.valueOf(tableName),
+ new BulkLoadThinRowsFunction(), output.toString(), new HashMap<>(), false,
+ HConstants.DEFAULT_MAX_FILE_SIZE);
+
+ try (Connection conn = ConnectionFactory.createConnection(conf);
+ Admin admin = conn.getAdmin()) {
+ Table table = conn.getTable(TableName.valueOf(tableName));
+ // Do bulk load
+ LoadIncrementalHFiles load = new LoadIncrementalHFiles(conf);
+ load.doBulkLoad(output, admin, table, conn.getRegionLocator(TableName.valueOf(tableName)));
+
+ // Check row1
+      List<Cell> cell1 = table.get(new Get(Bytes.toBytes("1"))).listCells();
+ Assert.assertEquals(cell1.size(), 1);
+ Assert.assertEquals(Bytes.toString(CellUtil.cloneFamily(cell1.get(0))), columnFamilyStr);
+ Assert.assertEquals(Bytes.toString(CellUtil.cloneQualifier(cell1.get(0))), "b");
+ Assert.assertEquals(Bytes.toString(CellUtil.cloneValue(cell1.get(0))), "1");
+
+ // Check row3
+      List<Cell> cell3 = table.get(new Get(Bytes.toBytes("3"))).listCells();
+ Assert.assertEquals(cell3.size(), 3);
+ Assert.assertEquals(Bytes.toString(CellUtil.cloneFamily(cell3.get(0))), columnFamilyStr);
+ Assert.assertEquals(Bytes.toString(CellUtil.cloneQualifier(cell3.get(0))), "a");
+ Assert.assertEquals(Bytes.toString(CellUtil.cloneValue(cell3.get(0))), "2");
+ Assert.assertEquals(Bytes.toString(CellUtil.cloneFamily(cell3.get(1))), columnFamilyStr);
+ Assert.assertEquals(Bytes.toString(CellUtil.cloneQualifier(cell3.get(1))), "b");
+ Assert.assertEquals(Bytes.toString(CellUtil.cloneValue(cell3.get(1))), "1");
+ Assert.assertEquals(Bytes.toString(CellUtil.cloneFamily(cell3.get(2))), columnFamilyStr1);
+ Assert.assertEquals(Bytes.toString(CellUtil.cloneQualifier(cell3.get(2))), "a");
+ Assert.assertEquals(Bytes.toString(CellUtil.cloneValue(cell3.get(2))), "1");
+
+ // Check row2
+      List<Cell> cell2 = table.get(new Get(Bytes.toBytes("2"))).listCells();
+ Assert.assertEquals(cell2.size(), 2);
+ Assert.assertEquals(Bytes.toString(CellUtil.cloneFamily(cell2.get(0))), columnFamilyStr);
+ Assert.assertEquals(Bytes.toString(CellUtil.cloneQualifier(cell2.get(0))), "a");
+ Assert.assertEquals(Bytes.toString(CellUtil.cloneValue(cell2.get(0))), "3");
+ Assert.assertEquals(Bytes.toString(CellUtil.cloneFamily(cell2.get(1))), columnFamilyStr);
+ Assert.assertEquals(Bytes.toString(CellUtil.cloneQualifier(cell2.get(1))), "b");
+ Assert.assertEquals(Bytes.toString(CellUtil.cloneValue(cell2.get(1))), "3");
+ }
+
+ }
+
+  public static class BulkLoadFunction
+    implements Function<String, Pair<KeyFamilyQualifier, byte[]>> {
+    @Override
+    public Pair<KeyFamilyQualifier, byte[]> call(String v1) throws Exception {
+ if (v1 == null) {
+ return null;
+ }
+
+ String[] strs = v1.split(",");
+ if (strs.length != 4) {
+ return null;
+ }
+
+ KeyFamilyQualifier kfq = new KeyFamilyQualifier(Bytes.toBytes(strs[0]),
+ Bytes.toBytes(strs[1]), Bytes.toBytes(strs[2]));
+      return new Pair<KeyFamilyQualifier, byte[]>(kfq, Bytes.toBytes(strs[3]));
+ }
+ }
+
+  public static class BulkLoadThinRowsFunction
+    implements Function<List<String>, Pair<ByteArrayWrapper, FamiliesQualifiersValues>> {
+    @Override
+    public Pair<ByteArrayWrapper, FamiliesQualifiersValues> call(List<String> list) {
+ if (list == null) {
+ return null;
+ }
+
+ ByteArrayWrapper rowKey = null;
+ FamiliesQualifiersValues fqv = new FamiliesQualifiersValues();
+ for (String cell : list) {
+ String[] strs = cell.split(",");
+ if (rowKey == null) {
+ rowKey = new ByteArrayWrapper(Bytes.toBytes(strs[0]));
+ }
+ fqv.add(Bytes.toBytes(strs[1]), Bytes.toBytes(strs[2]), Bytes.toBytes(strs[3]));
+ }
+      return new Pair<ByteArrayWrapper, FamiliesQualifiersValues>(rowKey, fqv);
+ }
+ }
+
+  public static class GetFunction implements Function<byte[], Get> {
+
+ private static final long serialVersionUID = 1L;
+
+ @Override
+ public Get call(byte[] v) throws Exception {
+ return new Get(v);
+ }
+ }
+
+  public static class ResultFunction implements Function<Result, String> {
+
+ private static final long serialVersionUID = 1L;
+
+ @Override
+ public String call(Result result) throws Exception {
+      Iterator<Cell> it = result.listCells().iterator();
+ StringBuilder b = new StringBuilder();
+
+ b.append(Bytes.toString(result.getRow())).append(":");
+
+ while (it.hasNext()) {
+ Cell cell = it.next();
+ String q = Bytes.toString(CellUtil.cloneQualifier(cell));
+ if ("counter".equals(q)) {
+ b.append("(").append(q).append(",").append(Bytes.toLong(CellUtil.cloneValue(cell)))
+ .append(")");
+ } else {
+ b.append("(").append(q).append(",").append(Bytes.toString(CellUtil.cloneValue(cell)))
+ .append(")");
+ }
+ }
+ return b.toString();
+ }
+ }
+
+ protected void populateTableWithMockData(Configuration conf, TableName tableName)
+ throws IOException {
+ try (Connection conn = ConnectionFactory.createConnection(conf);
+ Table table = conn.getTable(tableName); Admin admin = conn.getAdmin()) {
+
+      List<Put> puts = new ArrayList<>(5);
+
+ for (int i = 1; i < 6; i++) {
+ Put put = new Put(Bytes.toBytes(Integer.toString(i)));
+ put.addColumn(columnFamily, columnFamily, columnFamily);
+ puts.add(put);
+ }
+ table.put(puts);
+ admin.flush(tableName);
+ }
+ }
+}
diff --git a/spark4/hbase-spark4/src/test/java/org/apache/hadoop/hbase/spark/TestJavaHBaseContextForLargeRows.java b/spark4/hbase-spark4/src/test/java/org/apache/hadoop/hbase/spark/TestJavaHBaseContextForLargeRows.java
new file mode 100644
index 00000000..824b53d4
--- /dev/null
+++ b/spark4/hbase-spark4/src/test/java/org/apache/hadoop/hbase/spark/TestJavaHBaseContextForLargeRows.java
@@ -0,0 +1,74 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.hadoop.hbase.spark;
+
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.List;
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.hbase.HBaseClassTestRule;
+import org.apache.hadoop.hbase.TableName;
+import org.apache.hadoop.hbase.client.Admin;
+import org.apache.hadoop.hbase.client.Connection;
+import org.apache.hadoop.hbase.client.ConnectionFactory;
+import org.apache.hadoop.hbase.client.Put;
+import org.apache.hadoop.hbase.client.Table;
+import org.apache.hadoop.hbase.testclassification.MediumTests;
+import org.apache.hadoop.hbase.testclassification.MiscTests;
+import org.apache.hadoop.hbase.util.Bytes;
+import org.apache.spark.api.java.JavaSparkContext;
+import org.junit.BeforeClass;
+import org.junit.ClassRule;
+import org.junit.experimental.categories.Category;
+
+@Category({ MiscTests.class, MediumTests.class })
+public class TestJavaHBaseContextForLargeRows extends TestJavaHBaseContext {
+
+ @ClassRule
+ public static final HBaseClassTestRule TIMEOUT =
+ HBaseClassTestRule.forClass(TestJavaHBaseContextForLargeRows.class);
+
+ @BeforeClass
+ public static void setUpBeforeClass() throws Exception {
+ JSC = new JavaSparkContext("local", "JavaHBaseContextSuite");
+
+ init();
+ }
+
+ protected void populateTableWithMockData(Configuration conf, TableName tableName)
+ throws IOException {
+ try (Connection conn = ConnectionFactory.createConnection(conf);
+ Table table = conn.getTable(tableName); Admin admin = conn.getAdmin()) {
+
+      List<Put> puts = new ArrayList<>(5);
+
+ for (int i = 1; i < 6; i++) {
+ Put put = new Put(Bytes.toBytes(Integer.toString(i)));
+ // We are trying to generate a large row value here
+ char[] chars = new char[1024 * 1024 * 2];
+ // adding '0' to convert int to char
+ Arrays.fill(chars, (char) (i + '0'));
+ put.addColumn(columnFamily, columnFamily, Bytes.toBytes(String.valueOf(chars)));
+ puts.add(put);
+ }
+ table.put(puts);
+ admin.flush(tableName);
+ }
+ }
+}
diff --git a/spark4/hbase-spark4/src/test/resources/hbase-site.xml b/spark4/hbase-spark4/src/test/resources/hbase-site.xml
new file mode 100644
index 00000000..b3fb0d90
--- /dev/null
+++ b/spark4/hbase-spark4/src/test/resources/hbase-site.xml
@@ -0,0 +1,157 @@
+<?xml version="1.0"?>
+<?xml-stylesheet type="text/xsl" href="configuration.xsl"?>
+<!--
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+-->
+<configuration>
+  <property>
+    <name>hbase.regionserver.msginterval</name>
+    <value>1000</value>
+    <description>Interval between messages from the RegionServer to HMaster
+      in milliseconds. Default is 15. Set this value low if you want unit
+      tests to be responsive.
+    </description>
+  </property>
+  <property>
+    <name>hbase.defaults.for.version.skip</name>
+    <value>true</value>
+  </property>
+  <property>
+    <name>hbase.server.thread.wakefrequency</name>
+    <value>1000</value>
+    <description>Time to sleep in between searches for work (in milliseconds).
+      Used as sleep interval by service threads such as hbase:meta scanner and log roller.
+    </description>
+  </property>
+  <property>
+    <name>hbase.master.event.waiting.time</name>
+    <value>50</value>
+    <description>Time to sleep between checks to see if a table event took place.
+    </description>
+  </property>
+  <property>
+    <name>hbase.regionserver.handler.count</name>
+    <value>5</value>
+  </property>
+  <property>
+    <name>hbase.regionserver.metahandler.count</name>
+    <value>5</value>
+  </property>
+  <property>
+    <name>hbase.ipc.server.read.threadpool.size</name>
+    <value>3</value>
+  </property>
+  <property>
+    <name>hbase.master.info.port</name>
+    <value>-1</value>
+    <description>The port for the hbase master web UI
+      Set to -1 if you do not want the info server to run.
+    </description>
+  </property>
+  <property>
+    <name>hbase.master.port</name>
+    <value>0</value>
+    <description>Always have masters and regionservers come up on port '0' so we don't clash over
+      default ports.
+    </description>
+  </property>
+  <property>
+    <name>hbase.regionserver.port</name>
+    <value>0</value>
+    <description>Always have masters and regionservers come up on port '0' so we don't clash over
+      default ports.
+    </description>
+  </property>
+  <property>
+    <name>hbase.ipc.client.fallback-to-simple-auth-allowed</name>
+    <value>true</value>
+  </property>
+  <property>
+    <name>hbase.regionserver.info.port</name>
+    <value>-1</value>
+    <description>The port for the hbase regionserver web UI
+      Set to -1 if you do not want the info server to run.
+    </description>
+  </property>
+  <property>
+    <name>hbase.regionserver.info.port.auto</name>
+    <value>true</value>
+    <description>Info server auto port bind. Enables automatic port
+      search if hbase.regionserver.info.port is already in use.
+      Enabled for testing to run multiple tests on one machine.
+    </description>
+  </property>
+  <property>
+    <name>hbase.regionserver.safemode</name>
+    <value>false</value>
+    <description>
+      Turn on/off safe mode in region server. Always on for production, always off
+      for tests.
+    </description>
+  </property>
+  <property>
+    <name>hbase.hregion.max.filesize</name>
+    <value>67108864</value>
+    <description>
+      Maximum desired file size for an HRegion. If filesize exceeds
+      value + (value / 2), the HRegion is split in two. Default: 256M.
+
+      Keep the maximum filesize small so we split more often in tests.
+    </description>
+  </property>
+  <property>
+    <name>hadoop.log.dir</name>
+    <value>${user.dir}/../logs</value>
+  </property>
+  <property>
+    <name>hbase.zookeeper.property.clientPort</name>
+    <value>21818</value>
+    <description>Property from ZooKeeper's config zoo.cfg.
+      The port at which the clients will connect.
+    </description>
+  </property>
+  <property>
+    <name>hbase.defaults.for.version.skip</name>
+    <value>true</value>
+    <description>
+      Set to true to skip the 'hbase.defaults.for.version'.
+      Setting this to true can be useful in contexts other than
+      the other side of a maven generation; i.e. running in an
+      ide. You'll want to set this boolean to true to avoid
+      seeing the RuntimeException complaint: "hbase-default.xml file
+      seems to be for and old version of HBase (@@@VERSION@@@), this
+      version is X.X.X-SNAPSHOT"
+    </description>
+  </property>
+  <property>
+    <name>hbase.table.sanity.checks</name>
+    <value>false</value>
+    <description>Skip sanity checks in tests</description>
+  </property>
+  <property>
+    <name>hbase.procedure.fail.on.corruption</name>
+    <value>true</value>
+    <description>
+      Enable replay sanity checks on procedure tests.
+    </description>
+  </property>
+</configuration>
diff --git a/spark4/hbase-spark4/src/test/resources/log4j.properties b/spark4/hbase-spark4/src/test/resources/log4j.properties
new file mode 100644
index 00000000..cd3b8e9d
--- /dev/null
+++ b/spark4/hbase-spark4/src/test/resources/log4j.properties
@@ -0,0 +1,76 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# Define some default values that can be overridden by system properties
+hbase.root.logger=INFO,FA
+hbase.log.dir=.
+hbase.log.file=hbase.log
+
+# Define the root logger to the system property "hbase.root.logger".
+log4j.rootLogger=${hbase.root.logger}
+
+# Logging Threshold
+log4j.threshold=ALL
+
+#
+# Daily Rolling File Appender
+#
+log4j.appender.DRFA=org.apache.log4j.DailyRollingFileAppender
+log4j.appender.DRFA.File=${hbase.log.dir}/${hbase.log.file}
+
+# Rollover at midnight
+log4j.appender.DRFA.DatePattern=.yyyy-MM-dd
+
+# 30-day backup
+#log4j.appender.DRFA.MaxBackupIndex=30
+log4j.appender.DRFA.layout=org.apache.log4j.PatternLayout
+# Debugging Pattern format
+log4j.appender.DRFA.layout.ConversionPattern=%d{ISO8601} %-5p [%t] %C{2}(%L): %m%n
+
+
+#
+# console
+# Add "console" to rootlogger above if you want to use this
+#
+log4j.appender.console=org.apache.log4j.ConsoleAppender
+log4j.appender.console.target=System.err
+log4j.appender.console.layout=org.apache.log4j.PatternLayout
+log4j.appender.console.layout.ConversionPattern=%d{ISO8601} %-5p [%t] %C{2}(%L): %m%n
+
+#File Appender
+log4j.appender.FA=org.apache.log4j.FileAppender
+log4j.appender.FA.append=false
+log4j.appender.FA.file=target/log-output.txt
+log4j.appender.FA.layout=org.apache.log4j.PatternLayout
+log4j.appender.FA.layout.ConversionPattern=%d{ISO8601} %-5p [%t] %C{2}(%L): %m%n
+log4j.appender.FA.Threshold = INFO
+
+# Custom Logging levels
+
+#log4j.logger.org.apache.hadoop.fs.FSNamesystem=DEBUG
+
+log4j.logger.org.apache.hadoop=WARN
+log4j.logger.org.apache.zookeeper=ERROR
+log4j.logger.org.apache.hadoop.hbase=DEBUG
+
+#These settings are workarounds against spurious logs from the minicluster.
+#See HBASE-4709
+log4j.logger.org.apache.hadoop.metrics2.impl.MetricsConfig=WARN
+log4j.logger.org.apache.hadoop.metrics2.impl.MetricsSinkAdapter=WARN
+log4j.logger.org.apache.hadoop.metrics2.impl.MetricsSystemImpl=WARN
+log4j.logger.org.apache.hadoop.metrics2.util.MBeans=WARN
+# Enable this to get detailed connection error/retry logging.
+# log4j.logger.org.apache.hadoop.hbase.client.ConnectionImplementation=TRACE
diff --git a/spark4/hbase-spark4/src/test/scala/org/apache/hadoop/hbase/spark/BulkLoadSuite.scala b/spark4/hbase-spark4/src/test/scala/org/apache/hadoop/hbase/spark/BulkLoadSuite.scala
new file mode 100644
index 00000000..2b03c0f9
--- /dev/null
+++ b/spark4/hbase-spark4/src/test/scala/org/apache/hadoop/hbase/spark/BulkLoadSuite.scala
@@ -0,0 +1,1055 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.hadoop.hbase.spark
+
+import java.io.File
+import java.net.URI
+import java.nio.file.Files
+import org.apache.hadoop.fs.{FileSystem, Path}
+import org.apache.hadoop.hbase.{CellUtil, HBaseTestingUtility, HConstants, TableName}
+import org.apache.hadoop.hbase.client.{ConnectionFactory, Get}
+import org.apache.hadoop.hbase.io.hfile.{CacheConfig, HFile}
+import org.apache.hadoop.hbase.spark.HBaseRDDFunctions._
+import org.apache.hadoop.hbase.tool.LoadIncrementalHFiles
+import org.apache.hadoop.hbase.util.Bytes
+import org.apache.spark.SparkContext
+import org.junit.rules.TemporaryFolder
+import org.scalatest.funsuite.AnyFunSuite
+import org.scalatest.{BeforeAndAfterAll, BeforeAndAfterEach}
+
+class BulkLoadSuite extends AnyFunSuite with BeforeAndAfterEach with BeforeAndAfterAll with Logging {
+ @transient var sc: SparkContext = null
+ var TEST_UTIL = new HBaseTestingUtility
+
+ val tableName = "t1"
+ val columnFamily1 = "f1"
+ val columnFamily2 = "f2"
+ val testFolder = new TemporaryFolder()
+
+ override def beforeAll() {
+ TEST_UTIL.startMiniCluster()
+ logInfo(" - minicluster started")
+
+ try {
+ TEST_UTIL.deleteTable(TableName.valueOf(tableName))
+ } catch {
+ case e: Exception =>
+ logInfo(" - no table " + tableName + " found")
+ }
+
+ logInfo(" - created table")
+
+ val envMap = Map[String, String](("Xmx", "512m"))
+
+ sc = new SparkContext("local", "test", null, Nil, envMap)
+ }
+
+ override def afterAll() {
+    logInfo("shutting down minicluster")
+ TEST_UTIL.shutdownMiniCluster()
+ logInfo(" - minicluster shut down")
+ TEST_UTIL.cleanupTestDir()
+ sc.stop()
+ }
+
+ test("Staging dir: Test usage of staging dir on a separate filesystem") {
+ val config = TEST_UTIL.getConfiguration
+
+ logInfo(" - creating table " + tableName)
+ TEST_UTIL.createTable(
+ TableName.valueOf(tableName),
+ Array(Bytes.toBytes(columnFamily1), Bytes.toBytes(columnFamily2)))
+
+    // This test creates an RDD with 2 column families and writes it to HFiles
+    // on the local filesystem using the bulkLoad functionality. We don't check
+    // the load itself due to the limitations of the HBase minicluster.
+
+ val rdd = sc.parallelize(
+ Array[(Array[Byte], (Array[Byte], Array[Byte], Array[Byte]))](
+ (
+ Bytes.toBytes("1"),
+ (Bytes.toBytes(columnFamily1), Bytes.toBytes("a"), Bytes.toBytes("foo1"))),
+ (
+ Bytes.toBytes("2"),
+ (Bytes.toBytes(columnFamily2), Bytes.toBytes("b"), Bytes.toBytes("bar.2")))))
+
+ val hbaseContext = new HBaseContext(sc, config)
+ val uri = Files.createTempDirectory("tmpDirPrefix").toUri
+ val stagingUri = new URI(uri + "staging_dir")
+ val stagingFolder = new File(stagingUri)
+ val fs = new Path(stagingUri.toString).getFileSystem(config)
+ try {
+ hbaseContext.bulkLoad[(Array[Byte], (Array[Byte], Array[Byte], Array[Byte]))](
+ rdd,
+ TableName.valueOf(tableName),
+ t => {
+ val rowKey = t._1
+ val family: Array[Byte] = t._2._1
+ val qualifier = t._2._2
+ val value: Array[Byte] = t._2._3
+
+ val keyFamilyQualifier = new KeyFamilyQualifier(rowKey, family, qualifier)
+
+ Seq((keyFamilyQualifier, value)).iterator
+ },
+ stagingUri.toString)
+
+ assert(fs.listStatus(new Path(stagingFolder.getPath)).length == 2)
+
+ } finally {
+ val admin = ConnectionFactory.createConnection(config).getAdmin
+ try {
+ admin.disableTable(TableName.valueOf(tableName))
+ admin.deleteTable(TableName.valueOf(tableName))
+ } finally {
+ admin.close()
+ }
+ fs.delete(new Path(stagingFolder.getPath), true)
+
+ testFolder.delete()
+
+ }
+ }
+
+ test(
+ "Wide Row Bulk Load: Test multi family and multi column tests " +
+ "with all default HFile Configs.") {
+ val config = TEST_UTIL.getConfiguration
+
+ logInfo(" - creating table " + tableName)
+ TEST_UTIL.createTable(
+ TableName.valueOf(tableName),
+ Array(Bytes.toBytes(columnFamily1), Bytes.toBytes(columnFamily2)))
+
+ // There are a number of tests in here.
+ // 1. Row keys are not in order
+ // 2. Qualifiers are not in order
+ // 3. Column Families are not in order
+ // 4. There are tests for records in one column family and some in two column families
+    // 5. There are records with a single qualifier and some with two
+ val rdd = sc.parallelize(
+ Array[(Array[Byte], (Array[Byte], Array[Byte], Array[Byte]))](
+ (
+ Bytes.toBytes("1"),
+ (Bytes.toBytes(columnFamily1), Bytes.toBytes("a"), Bytes.toBytes("foo1"))),
+ (
+ Bytes.toBytes("3"),
+ (Bytes.toBytes(columnFamily2), Bytes.toBytes("b"), Bytes.toBytes("foo2.a"))),
+ (
+ Bytes.toBytes("3"),
+ (Bytes.toBytes(columnFamily2), Bytes.toBytes("a"), Bytes.toBytes("foo2.b"))),
+ (
+ Bytes.toBytes("3"),
+ (Bytes.toBytes(columnFamily1), Bytes.toBytes("a"), Bytes.toBytes("foo2.c"))),
+ (
+ Bytes.toBytes("5"),
+ (Bytes.toBytes(columnFamily1), Bytes.toBytes("a"), Bytes.toBytes("foo3"))),
+ (
+ Bytes.toBytes("4"),
+ (Bytes.toBytes(columnFamily1), Bytes.toBytes("a"), Bytes.toBytes("foo.1"))),
+ (
+ Bytes.toBytes("4"),
+ (Bytes.toBytes(columnFamily2), Bytes.toBytes("b"), Bytes.toBytes("foo.2"))),
+ (
+ Bytes.toBytes("2"),
+ (Bytes.toBytes(columnFamily1), Bytes.toBytes("a"), Bytes.toBytes("bar.1"))),
+ (
+ Bytes.toBytes("2"),
+ (Bytes.toBytes(columnFamily1), Bytes.toBytes("b"), Bytes.toBytes("bar.2")))))
+
+ val hbaseContext = new HBaseContext(sc, config)
+
+ testFolder.create()
+ val stagingFolder = testFolder.newFolder()
+
+ hbaseContext.bulkLoad[(Array[Byte], (Array[Byte], Array[Byte], Array[Byte]))](
+ rdd,
+ TableName.valueOf(tableName),
+ t => {
+ val rowKey = t._1
+ val family: Array[Byte] = t._2._1
+ val qualifier = t._2._2
+ val value: Array[Byte] = t._2._3
+
+ val keyFamilyQualifier = new KeyFamilyQualifier(rowKey, family, qualifier)
+
+ Seq((keyFamilyQualifier, value)).iterator
+ },
+ stagingFolder.getPath)
+
+ val fs = FileSystem.get(config)
+ assert(fs.listStatus(new Path(stagingFolder.getPath)).length == 2)
+
+ val conn = ConnectionFactory.createConnection(config)
+
+ val load = new LoadIncrementalHFiles(config)
+ val table = conn.getTable(TableName.valueOf(tableName))
+ try {
+ load.doBulkLoad(
+ new Path(stagingFolder.getPath),
+ conn.getAdmin,
+ table,
+ conn.getRegionLocator(TableName.valueOf(tableName)))
+
+ val cells5 = table.get(new Get(Bytes.toBytes("5"))).listCells()
+ assert(cells5.size == 1)
+ assert(Bytes.toString(CellUtil.cloneValue(cells5.get(0))).equals("foo3"))
+ assert(Bytes.toString(CellUtil.cloneFamily(cells5.get(0))).equals("f1"))
+ assert(Bytes.toString(CellUtil.cloneQualifier(cells5.get(0))).equals("a"))
+
+ val cells4 = table.get(new Get(Bytes.toBytes("4"))).listCells()
+ assert(cells4.size == 2)
+ assert(Bytes.toString(CellUtil.cloneValue(cells4.get(0))).equals("foo.1"))
+ assert(Bytes.toString(CellUtil.cloneFamily(cells4.get(0))).equals("f1"))
+ assert(Bytes.toString(CellUtil.cloneQualifier(cells4.get(0))).equals("a"))
+ assert(Bytes.toString(CellUtil.cloneValue(cells4.get(1))).equals("foo.2"))
+ assert(Bytes.toString(CellUtil.cloneFamily(cells4.get(1))).equals("f2"))
+ assert(Bytes.toString(CellUtil.cloneQualifier(cells4.get(1))).equals("b"))
+
+ val cells3 = table.get(new Get(Bytes.toBytes("3"))).listCells()
+ assert(cells3.size == 3)
+ assert(Bytes.toString(CellUtil.cloneValue(cells3.get(0))).equals("foo2.c"))
+ assert(Bytes.toString(CellUtil.cloneFamily(cells3.get(0))).equals("f1"))
+ assert(Bytes.toString(CellUtil.cloneQualifier(cells3.get(0))).equals("a"))
+ assert(Bytes.toString(CellUtil.cloneValue(cells3.get(1))).equals("foo2.b"))
+ assert(Bytes.toString(CellUtil.cloneFamily(cells3.get(1))).equals("f2"))
+ assert(Bytes.toString(CellUtil.cloneQualifier(cells3.get(1))).equals("a"))
+ assert(Bytes.toString(CellUtil.cloneValue(cells3.get(2))).equals("foo2.a"))
+ assert(Bytes.toString(CellUtil.cloneFamily(cells3.get(2))).equals("f2"))
+ assert(Bytes.toString(CellUtil.cloneQualifier(cells3.get(2))).equals("b"))
+
+ val cells2 = table.get(new Get(Bytes.toBytes("2"))).listCells()
+ assert(cells2.size == 2)
+ assert(Bytes.toString(CellUtil.cloneValue(cells2.get(0))).equals("bar.1"))
+ assert(Bytes.toString(CellUtil.cloneFamily(cells2.get(0))).equals("f1"))
+ assert(Bytes.toString(CellUtil.cloneQualifier(cells2.get(0))).equals("a"))
+ assert(Bytes.toString(CellUtil.cloneValue(cells2.get(1))).equals("bar.2"))
+ assert(Bytes.toString(CellUtil.cloneFamily(cells2.get(1))).equals("f1"))
+ assert(Bytes.toString(CellUtil.cloneQualifier(cells2.get(1))).equals("b"))
+
+ val cells1 = table.get(new Get(Bytes.toBytes("1"))).listCells()
+ assert(cells1.size == 1)
+ assert(Bytes.toString(CellUtil.cloneValue(cells1.get(0))).equals("foo1"))
+ assert(Bytes.toString(CellUtil.cloneFamily(cells1.get(0))).equals("f1"))
+ assert(Bytes.toString(CellUtil.cloneQualifier(cells1.get(0))).equals("a"))
+
+ } finally {
+ table.close()
+ val admin = ConnectionFactory.createConnection(config).getAdmin
+ try {
+ admin.disableTable(TableName.valueOf(tableName))
+ admin.deleteTable(TableName.valueOf(tableName))
+ } finally {
+ admin.close()
+ }
+ fs.delete(new Path(stagingFolder.getPath), true)
+
+ testFolder.delete()
+
+ }
+ }
+
+ test(
+ "Wide Row Bulk Load: Test HBase client: Test Roll Over and " +
+ "using an implicit call to bulk load") {
+ val config = TEST_UTIL.getConfiguration
+
+ logInfo(" - creating table " + tableName)
+ TEST_UTIL.createTable(
+ TableName.valueOf(tableName),
+ Array(Bytes.toBytes(columnFamily1), Bytes.toBytes(columnFamily2)))
+
+ // There are a number of tests in here.
+ // 1. Row keys are not in order
+ // 2. Qualifiers are not in order
+ // 3. Column Families are not in order
+ // 4. There are tests for records in one column family and some in two column families
+    // 5. There are records with a single qualifier and some with two
+ val rdd = sc.parallelize(
+ Array[(Array[Byte], (Array[Byte], Array[Byte], Array[Byte]))](
+ (
+ Bytes.toBytes("1"),
+ (Bytes.toBytes(columnFamily1), Bytes.toBytes("a"), Bytes.toBytes("foo1"))),
+ (
+ Bytes.toBytes("3"),
+ (Bytes.toBytes(columnFamily1), Bytes.toBytes("b"), Bytes.toBytes("foo2.b"))),
+ (
+ Bytes.toBytes("3"),
+ (Bytes.toBytes(columnFamily1), Bytes.toBytes("a"), Bytes.toBytes("foo2.a"))),
+ (
+ Bytes.toBytes("3"),
+ (Bytes.toBytes(columnFamily1), Bytes.toBytes("c"), Bytes.toBytes("foo2.c"))),
+ (
+ Bytes.toBytes("5"),
+ (Bytes.toBytes(columnFamily1), Bytes.toBytes("a"), Bytes.toBytes("foo3"))),
+ (
+ Bytes.toBytes("4"),
+ (Bytes.toBytes(columnFamily1), Bytes.toBytes("a"), Bytes.toBytes("foo.1"))),
+ (
+ Bytes.toBytes("4"),
+ (Bytes.toBytes(columnFamily1), Bytes.toBytes("b"), Bytes.toBytes("foo.2"))),
+ (
+ Bytes.toBytes("2"),
+ (Bytes.toBytes(columnFamily1), Bytes.toBytes("a"), Bytes.toBytes("bar.1"))),
+ (
+ Bytes.toBytes("2"),
+ (Bytes.toBytes(columnFamily1), Bytes.toBytes("b"), Bytes.toBytes("bar.2")))))
+
+ val hbaseContext = new HBaseContext(sc, config)
+
+ testFolder.create()
+ val stagingFolder = testFolder.newFolder()
+
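+    // A maximum HFile size of just 20 bytes forces the writer to roll over to a
+    // new HFile repeatedly; the assertions below expect a single f1 directory
+    // containing 5 HFiles.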
+ rdd.hbaseBulkLoad(
+ hbaseContext,
+ TableName.valueOf(tableName),
+ t => {
+ val rowKey = t._1
+ val family: Array[Byte] = t._2._1
+ val qualifier = t._2._2
+ val value = t._2._3
+
+ val keyFamilyQualifier = new KeyFamilyQualifier(rowKey, family, qualifier)
+
+ Seq((keyFamilyQualifier, value)).iterator
+ },
+ stagingFolder.getPath,
+ new java.util.HashMap[Array[Byte], FamilyHFileWriteOptions],
+ compactionExclude = false,
+ 20)
+
+ val fs = FileSystem.get(config)
+ assert(fs.listStatus(new Path(stagingFolder.getPath)).length == 1)
+
+ assert(fs.listStatus(new Path(stagingFolder.getPath + "/f1")).length == 5)
+
+ val conn = ConnectionFactory.createConnection(config)
+
+ val load = new LoadIncrementalHFiles(config)
+ val table = conn.getTable(TableName.valueOf(tableName))
+ try {
+ load.doBulkLoad(
+ new Path(stagingFolder.getPath),
+ conn.getAdmin,
+ table,
+ conn.getRegionLocator(TableName.valueOf(tableName)))
+
+ val cells5 = table.get(new Get(Bytes.toBytes("5"))).listCells()
+ assert(cells5.size == 1)
+ assert(Bytes.toString(CellUtil.cloneValue(cells5.get(0))).equals("foo3"))
+ assert(Bytes.toString(CellUtil.cloneFamily(cells5.get(0))).equals("f1"))
+ assert(Bytes.toString(CellUtil.cloneQualifier(cells5.get(0))).equals("a"))
+
+ val cells4 = table.get(new Get(Bytes.toBytes("4"))).listCells()
+ assert(cells4.size == 2)
+ assert(Bytes.toString(CellUtil.cloneValue(cells4.get(0))).equals("foo.1"))
+ assert(Bytes.toString(CellUtil.cloneFamily(cells4.get(0))).equals("f1"))
+ assert(Bytes.toString(CellUtil.cloneQualifier(cells4.get(0))).equals("a"))
+ assert(Bytes.toString(CellUtil.cloneValue(cells4.get(1))).equals("foo.2"))
+ assert(Bytes.toString(CellUtil.cloneFamily(cells4.get(1))).equals("f1"))
+ assert(Bytes.toString(CellUtil.cloneQualifier(cells4.get(1))).equals("b"))
+
+ val cells3 = table.get(new Get(Bytes.toBytes("3"))).listCells()
+ assert(cells3.size == 3)
+ assert(Bytes.toString(CellUtil.cloneValue(cells3.get(0))).equals("foo2.a"))
+ assert(Bytes.toString(CellUtil.cloneFamily(cells3.get(0))).equals("f1"))
+ assert(Bytes.toString(CellUtil.cloneQualifier(cells3.get(0))).equals("a"))
+ assert(Bytes.toString(CellUtil.cloneValue(cells3.get(1))).equals("foo2.b"))
+ assert(Bytes.toString(CellUtil.cloneFamily(cells3.get(1))).equals("f1"))
+ assert(Bytes.toString(CellUtil.cloneQualifier(cells3.get(1))).equals("b"))
+ assert(Bytes.toString(CellUtil.cloneValue(cells3.get(2))).equals("foo2.c"))
+ assert(Bytes.toString(CellUtil.cloneFamily(cells3.get(2))).equals("f1"))
+ assert(Bytes.toString(CellUtil.cloneQualifier(cells3.get(2))).equals("c"))
+
+ val cells2 = table.get(new Get(Bytes.toBytes("2"))).listCells()
+ assert(cells2.size == 2)
+ assert(Bytes.toString(CellUtil.cloneValue(cells2.get(0))).equals("bar.1"))
+ assert(Bytes.toString(CellUtil.cloneFamily(cells2.get(0))).equals("f1"))
+ assert(Bytes.toString(CellUtil.cloneQualifier(cells2.get(0))).equals("a"))
+ assert(Bytes.toString(CellUtil.cloneValue(cells2.get(1))).equals("bar.2"))
+ assert(Bytes.toString(CellUtil.cloneFamily(cells2.get(1))).equals("f1"))
+ assert(Bytes.toString(CellUtil.cloneQualifier(cells2.get(1))).equals("b"))
+
+ val cells1 = table.get(new Get(Bytes.toBytes("1"))).listCells()
+ assert(cells1.size == 1)
+ assert(Bytes.toString(CellUtil.cloneValue(cells1.get(0))).equals("foo1"))
+ assert(Bytes.toString(CellUtil.cloneFamily(cells1.get(0))).equals("f1"))
+ assert(Bytes.toString(CellUtil.cloneQualifier(cells1.get(0))).equals("a"))
+
+ } finally {
+ table.close()
+ val admin = ConnectionFactory.createConnection(config).getAdmin
+ try {
+ admin.disableTable(TableName.valueOf(tableName))
+ admin.deleteTable(TableName.valueOf(tableName))
+ } finally {
+ admin.close()
+ }
+ fs.delete(new Path(stagingFolder.getPath), true)
+
+ testFolder.delete()
+ }
+ }
+
+ test(
+ "Wide Row Bulk Load: Test multi family and multi column tests" +
+ " with one column family with custom configs plus multi region") {
+ val config = TEST_UTIL.getConfiguration
+
+ val splitKeys: Array[Array[Byte]] = new Array[Array[Byte]](2)
+ splitKeys(0) = Bytes.toBytes("2")
+ splitKeys(1) = Bytes.toBytes("4")
+
+ logInfo(" - creating table " + tableName)
+ TEST_UTIL.createTable(
+ TableName.valueOf(tableName),
+ Array(Bytes.toBytes(columnFamily1), Bytes.toBytes(columnFamily2)),
+ splitKeys)
+
+ // There are a number of tests in here.
+ // 1. Row keys are not in order
+ // 2. Qualifiers are not in order
+ // 3. Column Families are not in order
+ // 4. There are tests for records in one column family and some in two column families
+    // 5. There are records with a single qualifier and some with two
+ val rdd = sc.parallelize(
+ Array[(Array[Byte], (Array[Byte], Array[Byte], Array[Byte]))](
+ (
+ Bytes.toBytes("1"),
+ (Bytes.toBytes(columnFamily1), Bytes.toBytes("a"), Bytes.toBytes("foo1"))),
+ (
+ Bytes.toBytes("3"),
+ (Bytes.toBytes(columnFamily2), Bytes.toBytes("b"), Bytes.toBytes("foo2.a"))),
+ (
+ Bytes.toBytes("3"),
+ (Bytes.toBytes(columnFamily2), Bytes.toBytes("a"), Bytes.toBytes("foo2.b"))),
+ (
+ Bytes.toBytes("3"),
+ (Bytes.toBytes(columnFamily1), Bytes.toBytes("a"), Bytes.toBytes("foo2.c"))),
+ (
+ Bytes.toBytes("5"),
+ (Bytes.toBytes(columnFamily1), Bytes.toBytes("a"), Bytes.toBytes("foo3"))),
+ (
+ Bytes.toBytes("4"),
+ (Bytes.toBytes(columnFamily1), Bytes.toBytes("a"), Bytes.toBytes("foo.1"))),
+ (
+ Bytes.toBytes("4"),
+ (Bytes.toBytes(columnFamily2), Bytes.toBytes("b"), Bytes.toBytes("foo.2"))),
+ (
+ Bytes.toBytes("2"),
+ (Bytes.toBytes(columnFamily1), Bytes.toBytes("a"), Bytes.toBytes("bar.1"))),
+ (
+ Bytes.toBytes("2"),
+ (Bytes.toBytes(columnFamily1), Bytes.toBytes("b"), Bytes.toBytes("bar.2")))))
+
+ val hbaseContext = new HBaseContext(sc, config)
+
+ testFolder.create()
+ val stagingFolder = testFolder.newFolder()
+
+ val familyHBaseWriterOptions = new java.util.HashMap[Array[Byte], FamilyHFileWriteOptions]
+
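+    // Custom HFile write options for family f1 only: GZ compression and PREFIX data block
+    // encoding (verified by the HFile readers below); f2 keeps the defaults.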
+ val f1Options = new FamilyHFileWriteOptions("GZ", "ROW", 128, "PREFIX")
+
+ familyHBaseWriterOptions.put(Bytes.toBytes(columnFamily1), f1Options)
+
+ hbaseContext.bulkLoad[(Array[Byte], (Array[Byte], Array[Byte], Array[Byte]))](
+ rdd,
+ TableName.valueOf(tableName),
+ t => {
+ val rowKey = t._1
+ val family: Array[Byte] = t._2._1
+ val qualifier = t._2._2
+ val value = t._2._3
+
+ val keyFamilyQualifier = new KeyFamilyQualifier(rowKey, family, qualifier)
+
+ Seq((keyFamilyQualifier, value)).iterator
+ },
+ stagingFolder.getPath,
+ familyHBaseWriterOptions,
+ compactionExclude = false,
+ HConstants.DEFAULT_MAX_FILE_SIZE)
+
+ val fs = FileSystem.get(config)
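+    // The staging directory gets one subdirectory per column family that received data (f1 and f2).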
+ assert(fs.listStatus(new Path(stagingFolder.getPath)).length == 2)
+
+ val f1FileList = fs.listStatus(new Path(stagingFolder.getPath + "/f1"))
+ for (i <- 0 until f1FileList.length) {
+ val reader =
+ HFile.createReader(fs, f1FileList(i).getPath, new CacheConfig(config), true, config)
+ assert(reader.getTrailer.getCompressionCodec().getName.equals("gz"))
+ assert(reader.getDataBlockEncoding.name().equals("PREFIX"))
+ }
+
+ assert(3 == f1FileList.length)
+
+ val f2FileList = fs.listStatus(new Path(stagingFolder.getPath + "/f2"))
+ for (i <- 0 until f2FileList.length) {
+ val reader =
+ HFile.createReader(fs, f2FileList(i).getPath, new CacheConfig(config), true, config)
+ assert(reader.getTrailer.getCompressionCodec().getName.equals("none"))
+ assert(reader.getDataBlockEncoding.name().equals("NONE"))
+ }
+
+ assert(2 == f2FileList.length)
+
+ val conn = ConnectionFactory.createConnection(config)
+
+ val load = new LoadIncrementalHFiles(config)
+ val table = conn.getTable(TableName.valueOf(tableName))
+ try {
+ load.doBulkLoad(
+ new Path(stagingFolder.getPath),
+ conn.getAdmin,
+ table,
+ conn.getRegionLocator(TableName.valueOf(tableName)))
+
+ val cells5 = table.get(new Get(Bytes.toBytes("5"))).listCells()
+ assert(cells5.size == 1)
+ assert(Bytes.toString(CellUtil.cloneValue(cells5.get(0))).equals("foo3"))
+ assert(Bytes.toString(CellUtil.cloneFamily(cells5.get(0))).equals("f1"))
+ assert(Bytes.toString(CellUtil.cloneQualifier(cells5.get(0))).equals("a"))
+
+ val cells4 = table.get(new Get(Bytes.toBytes("4"))).listCells()
+ assert(cells4.size == 2)
+ assert(Bytes.toString(CellUtil.cloneValue(cells4.get(0))).equals("foo.1"))
+ assert(Bytes.toString(CellUtil.cloneFamily(cells4.get(0))).equals("f1"))
+ assert(Bytes.toString(CellUtil.cloneQualifier(cells4.get(0))).equals("a"))
+ assert(Bytes.toString(CellUtil.cloneValue(cells4.get(1))).equals("foo.2"))
+ assert(Bytes.toString(CellUtil.cloneFamily(cells4.get(1))).equals("f2"))
+ assert(Bytes.toString(CellUtil.cloneQualifier(cells4.get(1))).equals("b"))
+
+ val cells3 = table.get(new Get(Bytes.toBytes("3"))).listCells()
+ assert(cells3.size == 3)
+ assert(Bytes.toString(CellUtil.cloneValue(cells3.get(0))).equals("foo2.c"))
+ assert(Bytes.toString(CellUtil.cloneFamily(cells3.get(0))).equals("f1"))
+ assert(Bytes.toString(CellUtil.cloneQualifier(cells3.get(0))).equals("a"))
+ assert(Bytes.toString(CellUtil.cloneValue(cells3.get(1))).equals("foo2.b"))
+ assert(Bytes.toString(CellUtil.cloneFamily(cells3.get(1))).equals("f2"))
+ assert(Bytes.toString(CellUtil.cloneQualifier(cells3.get(1))).equals("a"))
+ assert(Bytes.toString(CellUtil.cloneValue(cells3.get(2))).equals("foo2.a"))
+ assert(Bytes.toString(CellUtil.cloneFamily(cells3.get(2))).equals("f2"))
+ assert(Bytes.toString(CellUtil.cloneQualifier(cells3.get(2))).equals("b"))
+
+ val cells2 = table.get(new Get(Bytes.toBytes("2"))).listCells()
+ assert(cells2.size == 2)
+ assert(Bytes.toString(CellUtil.cloneValue(cells2.get(0))).equals("bar.1"))
+ assert(Bytes.toString(CellUtil.cloneFamily(cells2.get(0))).equals("f1"))
+ assert(Bytes.toString(CellUtil.cloneQualifier(cells2.get(0))).equals("a"))
+ assert(Bytes.toString(CellUtil.cloneValue(cells2.get(1))).equals("bar.2"))
+ assert(Bytes.toString(CellUtil.cloneFamily(cells2.get(1))).equals("f1"))
+ assert(Bytes.toString(CellUtil.cloneQualifier(cells2.get(1))).equals("b"))
+
+ val cells1 = table.get(new Get(Bytes.toBytes("1"))).listCells()
+ assert(cells1.size == 1)
+ assert(Bytes.toString(CellUtil.cloneValue(cells1.get(0))).equals("foo1"))
+ assert(Bytes.toString(CellUtil.cloneFamily(cells1.get(0))).equals("f1"))
+ assert(Bytes.toString(CellUtil.cloneQualifier(cells1.get(0))).equals("a"))
+
+ } finally {
+ table.close()
+ val admin = ConnectionFactory.createConnection(config).getAdmin
+ try {
+ admin.disableTable(TableName.valueOf(tableName))
+ admin.deleteTable(TableName.valueOf(tableName))
+ } finally {
+ admin.close()
+ }
+ fs.delete(new Path(stagingFolder.getPath), true)
+
+ testFolder.delete()
+
+ }
+ }
+
+ test("Test partitioner") {
+
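+    // BulkLoadPartitioner assigns a row key to the partition of the greatest split key that is
+    // less than or equal to it; the empty first split key catches everything before the first real split.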
+ var splitKeys: Array[Array[Byte]] = new Array[Array[Byte]](3)
+ splitKeys(0) = Bytes.toBytes("")
+ splitKeys(1) = Bytes.toBytes("3")
+ splitKeys(2) = Bytes.toBytes("7")
+
+ var partitioner = new BulkLoadPartitioner(splitKeys)
+
+ assert(0 == partitioner.getPartition(Bytes.toBytes("")))
+ assert(0 == partitioner.getPartition(Bytes.toBytes("1")))
+ assert(0 == partitioner.getPartition(Bytes.toBytes("2")))
+ assert(1 == partitioner.getPartition(Bytes.toBytes("3")))
+ assert(1 == partitioner.getPartition(Bytes.toBytes("4")))
+ assert(1 == partitioner.getPartition(Bytes.toBytes("6")))
+ assert(2 == partitioner.getPartition(Bytes.toBytes("7")))
+ assert(2 == partitioner.getPartition(Bytes.toBytes("8")))
+
+ splitKeys = new Array[Array[Byte]](1)
+ splitKeys(0) = Bytes.toBytes("")
+
+ partitioner = new BulkLoadPartitioner(splitKeys)
+
+ assert(0 == partitioner.getPartition(Bytes.toBytes("")))
+ assert(0 == partitioner.getPartition(Bytes.toBytes("1")))
+ assert(0 == partitioner.getPartition(Bytes.toBytes("2")))
+ assert(0 == partitioner.getPartition(Bytes.toBytes("3")))
+ assert(0 == partitioner.getPartition(Bytes.toBytes("4")))
+ assert(0 == partitioner.getPartition(Bytes.toBytes("6")))
+ assert(0 == partitioner.getPartition(Bytes.toBytes("7")))
+
+ splitKeys = new Array[Array[Byte]](7)
+ splitKeys(0) = Bytes.toBytes("")
+ splitKeys(1) = Bytes.toBytes("02")
+ splitKeys(2) = Bytes.toBytes("04")
+ splitKeys(3) = Bytes.toBytes("06")
+ splitKeys(4) = Bytes.toBytes("08")
+ splitKeys(5) = Bytes.toBytes("10")
+ splitKeys(6) = Bytes.toBytes("12")
+
+ partitioner = new BulkLoadPartitioner(splitKeys)
+
+ assert(0 == partitioner.getPartition(Bytes.toBytes("")))
+ assert(0 == partitioner.getPartition(Bytes.toBytes("01")))
+ assert(1 == partitioner.getPartition(Bytes.toBytes("02")))
+ assert(1 == partitioner.getPartition(Bytes.toBytes("03")))
+ assert(2 == partitioner.getPartition(Bytes.toBytes("04")))
+ assert(2 == partitioner.getPartition(Bytes.toBytes("05")))
+ assert(3 == partitioner.getPartition(Bytes.toBytes("06")))
+ assert(3 == partitioner.getPartition(Bytes.toBytes("07")))
+ assert(4 == partitioner.getPartition(Bytes.toBytes("08")))
+ assert(4 == partitioner.getPartition(Bytes.toBytes("09")))
+ assert(5 == partitioner.getPartition(Bytes.toBytes("10")))
+ assert(5 == partitioner.getPartition(Bytes.toBytes("11")))
+ assert(6 == partitioner.getPartition(Bytes.toBytes("12")))
+ assert(6 == partitioner.getPartition(Bytes.toBytes("13")))
+ }
+
+ test(
+ "Thin Row Bulk Load: Test multi family and multi column tests " +
+ "with all default HFile Configs") {
+ val config = TEST_UTIL.getConfiguration
+
+ logInfo(" - creating table " + tableName)
+ TEST_UTIL.createTable(
+ TableName.valueOf(tableName),
+ Array(Bytes.toBytes(columnFamily1), Bytes.toBytes(columnFamily2)))
+
+ // There are a number of tests in here.
+ // 1. Row keys are not in order
+ // 2. Qualifiers are not in order
+ // 3. Column Families are not in order
+ // 4. There are tests for records in one column family and some in two column families
+    // 5. There are records with a single qualifier and some with two
+ val rdd = sc
+ .parallelize(
+        Array[(String, (Array[Byte], Array[Byte], Array[Byte]))](
+ ("1", (Bytes.toBytes(columnFamily1), Bytes.toBytes("a"), Bytes.toBytes("foo1"))),
+ ("3", (Bytes.toBytes(columnFamily2), Bytes.toBytes("b"), Bytes.toBytes("foo2.a"))),
+ ("3", (Bytes.toBytes(columnFamily2), Bytes.toBytes("a"), Bytes.toBytes("foo2.b"))),
+ ("3", (Bytes.toBytes(columnFamily1), Bytes.toBytes("a"), Bytes.toBytes("foo2.c"))),
+ ("5", (Bytes.toBytes(columnFamily1), Bytes.toBytes("a"), Bytes.toBytes("foo3"))),
+ ("4", (Bytes.toBytes(columnFamily1), Bytes.toBytes("a"), Bytes.toBytes("foo.1"))),
+ ("4", (Bytes.toBytes(columnFamily2), Bytes.toBytes("b"), Bytes.toBytes("foo.2"))),
+ ("2", (Bytes.toBytes(columnFamily1), Bytes.toBytes("a"), Bytes.toBytes("bar.1"))),
+ ("2", (Bytes.toBytes(columnFamily1), Bytes.toBytes("b"), Bytes.toBytes("bar.2")))))
+ .groupByKey()
+
+ val hbaseContext = new HBaseContext(sc, config)
+
+ testFolder.create()
+ val stagingFolder = testFolder.newFolder()
+
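+    // Thin-row bulk load expects all cells for a row key together, so the RDD is grouped by key
+    // and each group is flattened into a single FamiliesQualifiersValues.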
+ hbaseContext.bulkLoadThinRows[(String, Iterable[(Array[Byte], Array[Byte], Array[Byte])])](
+ rdd,
+ TableName.valueOf(tableName),
+ t => {
+ val rowKey = Bytes.toBytes(t._1)
+
+ val familyQualifiersValues = new FamiliesQualifiersValues
+ t._2.foreach(
+ f => {
+ val family: Array[Byte] = f._1
+ val qualifier = f._2
+ val value: Array[Byte] = f._3
+
+ familyQualifiersValues += (family, qualifier, value)
+ })
+ (new ByteArrayWrapper(rowKey), familyQualifiersValues)
+ },
+ stagingFolder.getPath)
+
+ val fs = FileSystem.get(config)
+ assert(fs.listStatus(new Path(stagingFolder.getPath)).length == 2)
+
+ val conn = ConnectionFactory.createConnection(config)
+
+ val load = new LoadIncrementalHFiles(config)
+ val table = conn.getTable(TableName.valueOf(tableName))
+ try {
+ load.doBulkLoad(
+ new Path(stagingFolder.getPath),
+ conn.getAdmin,
+ table,
+ conn.getRegionLocator(TableName.valueOf(tableName)))
+
+ val cells5 = table.get(new Get(Bytes.toBytes("5"))).listCells()
+ assert(cells5.size == 1)
+ assert(Bytes.toString(CellUtil.cloneValue(cells5.get(0))).equals("foo3"))
+ assert(Bytes.toString(CellUtil.cloneFamily(cells5.get(0))).equals("f1"))
+ assert(Bytes.toString(CellUtil.cloneQualifier(cells5.get(0))).equals("a"))
+
+ val cells4 = table.get(new Get(Bytes.toBytes("4"))).listCells()
+ assert(cells4.size == 2)
+ assert(Bytes.toString(CellUtil.cloneValue(cells4.get(0))).equals("foo.1"))
+ assert(Bytes.toString(CellUtil.cloneFamily(cells4.get(0))).equals("f1"))
+ assert(Bytes.toString(CellUtil.cloneQualifier(cells4.get(0))).equals("a"))
+ assert(Bytes.toString(CellUtil.cloneValue(cells4.get(1))).equals("foo.2"))
+ assert(Bytes.toString(CellUtil.cloneFamily(cells4.get(1))).equals("f2"))
+ assert(Bytes.toString(CellUtil.cloneQualifier(cells4.get(1))).equals("b"))
+
+ val cells3 = table.get(new Get(Bytes.toBytes("3"))).listCells()
+ assert(cells3.size == 3)
+ assert(Bytes.toString(CellUtil.cloneValue(cells3.get(0))).equals("foo2.c"))
+ assert(Bytes.toString(CellUtil.cloneFamily(cells3.get(0))).equals("f1"))
+ assert(Bytes.toString(CellUtil.cloneQualifier(cells3.get(0))).equals("a"))
+ assert(Bytes.toString(CellUtil.cloneValue(cells3.get(1))).equals("foo2.b"))
+ assert(Bytes.toString(CellUtil.cloneFamily(cells3.get(1))).equals("f2"))
+ assert(Bytes.toString(CellUtil.cloneQualifier(cells3.get(1))).equals("a"))
+ assert(Bytes.toString(CellUtil.cloneValue(cells3.get(2))).equals("foo2.a"))
+ assert(Bytes.toString(CellUtil.cloneFamily(cells3.get(2))).equals("f2"))
+ assert(Bytes.toString(CellUtil.cloneQualifier(cells3.get(2))).equals("b"))
+
+ val cells2 = table.get(new Get(Bytes.toBytes("2"))).listCells()
+ assert(cells2.size == 2)
+ assert(Bytes.toString(CellUtil.cloneValue(cells2.get(0))).equals("bar.1"))
+ assert(Bytes.toString(CellUtil.cloneFamily(cells2.get(0))).equals("f1"))
+ assert(Bytes.toString(CellUtil.cloneQualifier(cells2.get(0))).equals("a"))
+ assert(Bytes.toString(CellUtil.cloneValue(cells2.get(1))).equals("bar.2"))
+ assert(Bytes.toString(CellUtil.cloneFamily(cells2.get(1))).equals("f1"))
+ assert(Bytes.toString(CellUtil.cloneQualifier(cells2.get(1))).equals("b"))
+
+ val cells1 = table.get(new Get(Bytes.toBytes("1"))).listCells()
+ assert(cells1.size == 1)
+ assert(Bytes.toString(CellUtil.cloneValue(cells1.get(0))).equals("foo1"))
+ assert(Bytes.toString(CellUtil.cloneFamily(cells1.get(0))).equals("f1"))
+ assert(Bytes.toString(CellUtil.cloneQualifier(cells1.get(0))).equals("a"))
+
+ } finally {
+ table.close()
+ val admin = ConnectionFactory.createConnection(config).getAdmin
+ try {
+ admin.disableTable(TableName.valueOf(tableName))
+ admin.deleteTable(TableName.valueOf(tableName))
+ } finally {
+ admin.close()
+ }
+ fs.delete(new Path(stagingFolder.getPath), true)
+
+ testFolder.delete()
+
+ }
+ }
+
+ test(
+ "Thin Row Bulk Load: Test HBase client: Test Roll Over and " +
+ "using an implicit call to bulk load") {
+ val config = TEST_UTIL.getConfiguration
+
+ logInfo(" - creating table " + tableName)
+ TEST_UTIL.createTable(
+ TableName.valueOf(tableName),
+ Array(Bytes.toBytes(columnFamily1), Bytes.toBytes(columnFamily2)))
+
+ // There are a number of tests in here.
+ // 1. Row keys are not in order
+ // 2. Qualifiers are not in order
+ // 3. Column Families are not in order
+ // 4. There are tests for records in one column family and some in two column families
+    // 5. There are records with a single qualifier and some with two
+ val rdd = sc
+ .parallelize(
+        Array[(String, (Array[Byte], Array[Byte], Array[Byte]))](
+ ("1", (Bytes.toBytes(columnFamily1), Bytes.toBytes("a"), Bytes.toBytes("foo1"))),
+ ("3", (Bytes.toBytes(columnFamily1), Bytes.toBytes("b"), Bytes.toBytes("foo2.b"))),
+ ("3", (Bytes.toBytes(columnFamily1), Bytes.toBytes("a"), Bytes.toBytes("foo2.a"))),
+ ("3", (Bytes.toBytes(columnFamily1), Bytes.toBytes("c"), Bytes.toBytes("foo2.c"))),
+ ("5", (Bytes.toBytes(columnFamily1), Bytes.toBytes("a"), Bytes.toBytes("foo3"))),
+ ("4", (Bytes.toBytes(columnFamily1), Bytes.toBytes("a"), Bytes.toBytes("foo.1"))),
+ ("4", (Bytes.toBytes(columnFamily1), Bytes.toBytes("b"), Bytes.toBytes("foo.2"))),
+ ("2", (Bytes.toBytes(columnFamily1), Bytes.toBytes("a"), Bytes.toBytes("bar.1"))),
+ ("2", (Bytes.toBytes(columnFamily1), Bytes.toBytes("b"), Bytes.toBytes("bar.2")))))
+ .groupByKey()
+
+ val hbaseContext = new HBaseContext(sc, config)
+
+ testFolder.create()
+ val stagingFolder = testFolder.newFolder()
+
+ rdd.hbaseBulkLoadThinRows(
+ hbaseContext,
+ TableName.valueOf(tableName),
+ t => {
+ val rowKey = t._1
+
+ val familyQualifiersValues = new FamiliesQualifiersValues
+ t._2.foreach(
+ f => {
+ val family: Array[Byte] = f._1
+ val qualifier = f._2
+ val value: Array[Byte] = f._3
+
+ familyQualifiersValues += (family, qualifier, value)
+ })
+ (new ByteArrayWrapper(Bytes.toBytes(rowKey)), familyQualifiersValues)
+ },
+ stagingFolder.getPath,
+ new java.util.HashMap[Array[Byte], FamilyHFileWriteOptions],
+ compactionExclude = false,
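+        // As in the wide-row roll over test, a tiny maximum HFile size (20) forces multiple
+        // HFiles per family (five for f1, asserted below).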
+ 20)
+
+ val fs = FileSystem.get(config)
+ assert(fs.listStatus(new Path(stagingFolder.getPath)).length == 1)
+
+ assert(fs.listStatus(new Path(stagingFolder.getPath + "/f1")).length == 5)
+
+ val conn = ConnectionFactory.createConnection(config)
+
+ val load = new LoadIncrementalHFiles(config)
+ val table = conn.getTable(TableName.valueOf(tableName))
+ try {
+ load.doBulkLoad(
+ new Path(stagingFolder.getPath),
+ conn.getAdmin,
+ table,
+ conn.getRegionLocator(TableName.valueOf(tableName)))
+
+ val cells5 = table.get(new Get(Bytes.toBytes("5"))).listCells()
+ assert(cells5.size == 1)
+ assert(Bytes.toString(CellUtil.cloneValue(cells5.get(0))).equals("foo3"))
+ assert(Bytes.toString(CellUtil.cloneFamily(cells5.get(0))).equals("f1"))
+ assert(Bytes.toString(CellUtil.cloneQualifier(cells5.get(0))).equals("a"))
+
+ val cells4 = table.get(new Get(Bytes.toBytes("4"))).listCells()
+ assert(cells4.size == 2)
+ assert(Bytes.toString(CellUtil.cloneValue(cells4.get(0))).equals("foo.1"))
+ assert(Bytes.toString(CellUtil.cloneFamily(cells4.get(0))).equals("f1"))
+ assert(Bytes.toString(CellUtil.cloneQualifier(cells4.get(0))).equals("a"))
+ assert(Bytes.toString(CellUtil.cloneValue(cells4.get(1))).equals("foo.2"))
+ assert(Bytes.toString(CellUtil.cloneFamily(cells4.get(1))).equals("f1"))
+ assert(Bytes.toString(CellUtil.cloneQualifier(cells4.get(1))).equals("b"))
+
+ val cells3 = table.get(new Get(Bytes.toBytes("3"))).listCells()
+ assert(cells3.size == 3)
+ assert(Bytes.toString(CellUtil.cloneValue(cells3.get(0))).equals("foo2.a"))
+ assert(Bytes.toString(CellUtil.cloneFamily(cells3.get(0))).equals("f1"))
+ assert(Bytes.toString(CellUtil.cloneQualifier(cells3.get(0))).equals("a"))
+ assert(Bytes.toString(CellUtil.cloneValue(cells3.get(1))).equals("foo2.b"))
+ assert(Bytes.toString(CellUtil.cloneFamily(cells3.get(1))).equals("f1"))
+ assert(Bytes.toString(CellUtil.cloneQualifier(cells3.get(1))).equals("b"))
+ assert(Bytes.toString(CellUtil.cloneValue(cells3.get(2))).equals("foo2.c"))
+ assert(Bytes.toString(CellUtil.cloneFamily(cells3.get(2))).equals("f1"))
+ assert(Bytes.toString(CellUtil.cloneQualifier(cells3.get(2))).equals("c"))
+
+ val cells2 = table.get(new Get(Bytes.toBytes("2"))).listCells()
+ assert(cells2.size == 2)
+ assert(Bytes.toString(CellUtil.cloneValue(cells2.get(0))).equals("bar.1"))
+ assert(Bytes.toString(CellUtil.cloneFamily(cells2.get(0))).equals("f1"))
+ assert(Bytes.toString(CellUtil.cloneQualifier(cells2.get(0))).equals("a"))
+ assert(Bytes.toString(CellUtil.cloneValue(cells2.get(1))).equals("bar.2"))
+ assert(Bytes.toString(CellUtil.cloneFamily(cells2.get(1))).equals("f1"))
+ assert(Bytes.toString(CellUtil.cloneQualifier(cells2.get(1))).equals("b"))
+
+ val cells1 = table.get(new Get(Bytes.toBytes("1"))).listCells()
+ assert(cells1.size == 1)
+ assert(Bytes.toString(CellUtil.cloneValue(cells1.get(0))).equals("foo1"))
+ assert(Bytes.toString(CellUtil.cloneFamily(cells1.get(0))).equals("f1"))
+ assert(Bytes.toString(CellUtil.cloneQualifier(cells1.get(0))).equals("a"))
+
+ } finally {
+ table.close()
+ val admin = ConnectionFactory.createConnection(config).getAdmin
+ try {
+ admin.disableTable(TableName.valueOf(tableName))
+ admin.deleteTable(TableName.valueOf(tableName))
+ } finally {
+ admin.close()
+ }
+ fs.delete(new Path(stagingFolder.getPath), true)
+
+ testFolder.delete()
+ }
+ }
+
+ test(
+ "Thin Row Bulk Load: Test multi family and multi column tests" +
+ " with one column family with custom configs plus multi region") {
+ val config = TEST_UTIL.getConfiguration
+
+ val splitKeys: Array[Array[Byte]] = new Array[Array[Byte]](2)
+ splitKeys(0) = Bytes.toBytes("2")
+ splitKeys(1) = Bytes.toBytes("4")
+
+ logInfo(" - creating table " + tableName)
+ TEST_UTIL.createTable(
+ TableName.valueOf(tableName),
+ Array(Bytes.toBytes(columnFamily1), Bytes.toBytes(columnFamily2)),
+ splitKeys)
+
+ // There are a number of tests in here.
+ // 1. Row keys are not in order
+ // 2. Qualifiers are not in order
+ // 3. Column Families are not in order
+ // 4. There are tests for records in one column family and some in two column families
+    // 5. There are records with a single qualifier and some with two
+ val rdd = sc
+ .parallelize(
+ Array[(String, (Array[Byte], Array[Byte], Array[Byte]))](
+ ("1", (Bytes.toBytes(columnFamily1), Bytes.toBytes("a"), Bytes.toBytes("foo1"))),
+ ("3", (Bytes.toBytes(columnFamily2), Bytes.toBytes("b"), Bytes.toBytes("foo2.a"))),
+ ("3", (Bytes.toBytes(columnFamily2), Bytes.toBytes("a"), Bytes.toBytes("foo2.b"))),
+ ("3", (Bytes.toBytes(columnFamily1), Bytes.toBytes("a"), Bytes.toBytes("foo2.c"))),
+ ("5", (Bytes.toBytes(columnFamily1), Bytes.toBytes("a"), Bytes.toBytes("foo3"))),
+ ("4", (Bytes.toBytes(columnFamily1), Bytes.toBytes("a"), Bytes.toBytes("foo.1"))),
+ ("4", (Bytes.toBytes(columnFamily2), Bytes.toBytes("b"), Bytes.toBytes("foo.2"))),
+ ("2", (Bytes.toBytes(columnFamily1), Bytes.toBytes("a"), Bytes.toBytes("bar.1"))),
+ ("2", (Bytes.toBytes(columnFamily1), Bytes.toBytes("b"), Bytes.toBytes("bar.2")))))
+ .groupByKey()
+
+ val hbaseContext = new HBaseContext(sc, config)
+
+ testFolder.create()
+ val stagingFolder = testFolder.newFolder()
+
+ val familyHBaseWriterOptions = new java.util.HashMap[Array[Byte], FamilyHFileWriteOptions]
+
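+    // As in the wide-row test, only f1 gets custom write options (GZ compression, PREFIX encoding);
+    // f2 keeps the defaults.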
+ val f1Options = new FamilyHFileWriteOptions("GZ", "ROW", 128, "PREFIX")
+
+ familyHBaseWriterOptions.put(Bytes.toBytes(columnFamily1), f1Options)
+
+ hbaseContext.bulkLoadThinRows[(String, Iterable[(Array[Byte], Array[Byte], Array[Byte])])](
+ rdd,
+ TableName.valueOf(tableName),
+ t => {
+ val rowKey = t._1
+
+ val familyQualifiersValues = new FamiliesQualifiersValues
+ t._2.foreach(
+ f => {
+ val family: Array[Byte] = f._1
+ val qualifier = f._2
+ val value: Array[Byte] = f._3
+
+ familyQualifiersValues += (family, qualifier, value)
+ })
+ (new ByteArrayWrapper(Bytes.toBytes(rowKey)), familyQualifiersValues)
+ },
+ stagingFolder.getPath,
+ familyHBaseWriterOptions,
+ compactionExclude = false,
+ HConstants.DEFAULT_MAX_FILE_SIZE)
+
+ val fs = FileSystem.get(config)
+ assert(fs.listStatus(new Path(stagingFolder.getPath)).length == 2)
+
+ val f1FileList = fs.listStatus(new Path(stagingFolder.getPath + "/f1"))
+ for (i <- 0 until f1FileList.length) {
+ val reader =
+ HFile.createReader(fs, f1FileList(i).getPath, new CacheConfig(config), true, config)
+ assert(reader.getTrailer.getCompressionCodec().getName.equals("gz"))
+ assert(reader.getDataBlockEncoding.name().equals("PREFIX"))
+ }
+
+ assert(3 == f1FileList.length)
+
+ val f2FileList = fs.listStatus(new Path(stagingFolder.getPath + "/f2"))
+ for (i <- 0 until f2FileList.length) {
+ val reader =
+ HFile.createReader(fs, f2FileList(i).getPath, new CacheConfig(config), true, config)
+ assert(reader.getTrailer.getCompressionCodec().getName.equals("none"))
+ assert(reader.getDataBlockEncoding.name().equals("NONE"))
+ }
+
+ assert(2 == f2FileList.length)
+
+ val conn = ConnectionFactory.createConnection(config)
+
+ val load = new LoadIncrementalHFiles(config)
+ val table = conn.getTable(TableName.valueOf(tableName))
+ try {
+ load.doBulkLoad(
+ new Path(stagingFolder.getPath),
+ conn.getAdmin,
+ table,
+ conn.getRegionLocator(TableName.valueOf(tableName)))
+
+ val cells5 = table.get(new Get(Bytes.toBytes("5"))).listCells()
+ assert(cells5.size == 1)
+ assert(Bytes.toString(CellUtil.cloneValue(cells5.get(0))).equals("foo3"))
+ assert(Bytes.toString(CellUtil.cloneFamily(cells5.get(0))).equals("f1"))
+ assert(Bytes.toString(CellUtil.cloneQualifier(cells5.get(0))).equals("a"))
+
+ val cells4 = table.get(new Get(Bytes.toBytes("4"))).listCells()
+ assert(cells4.size == 2)
+ assert(Bytes.toString(CellUtil.cloneValue(cells4.get(0))).equals("foo.1"))
+ assert(Bytes.toString(CellUtil.cloneFamily(cells4.get(0))).equals("f1"))
+ assert(Bytes.toString(CellUtil.cloneQualifier(cells4.get(0))).equals("a"))
+ assert(Bytes.toString(CellUtil.cloneValue(cells4.get(1))).equals("foo.2"))
+ assert(Bytes.toString(CellUtil.cloneFamily(cells4.get(1))).equals("f2"))
+ assert(Bytes.toString(CellUtil.cloneQualifier(cells4.get(1))).equals("b"))
+
+ val cells3 = table.get(new Get(Bytes.toBytes("3"))).listCells()
+ assert(cells3.size == 3)
+ assert(Bytes.toString(CellUtil.cloneValue(cells3.get(0))).equals("foo2.c"))
+ assert(Bytes.toString(CellUtil.cloneFamily(cells3.get(0))).equals("f1"))
+ assert(Bytes.toString(CellUtil.cloneQualifier(cells3.get(0))).equals("a"))
+ assert(Bytes.toString(CellUtil.cloneValue(cells3.get(1))).equals("foo2.b"))
+ assert(Bytes.toString(CellUtil.cloneFamily(cells3.get(1))).equals("f2"))
+ assert(Bytes.toString(CellUtil.cloneQualifier(cells3.get(1))).equals("a"))
+ assert(Bytes.toString(CellUtil.cloneValue(cells3.get(2))).equals("foo2.a"))
+ assert(Bytes.toString(CellUtil.cloneFamily(cells3.get(2))).equals("f2"))
+ assert(Bytes.toString(CellUtil.cloneQualifier(cells3.get(2))).equals("b"))
+
+ val cells2 = table.get(new Get(Bytes.toBytes("2"))).listCells()
+ assert(cells2.size == 2)
+ assert(Bytes.toString(CellUtil.cloneValue(cells2.get(0))).equals("bar.1"))
+ assert(Bytes.toString(CellUtil.cloneFamily(cells2.get(0))).equals("f1"))
+ assert(Bytes.toString(CellUtil.cloneQualifier(cells2.get(0))).equals("a"))
+ assert(Bytes.toString(CellUtil.cloneValue(cells2.get(1))).equals("bar.2"))
+ assert(Bytes.toString(CellUtil.cloneFamily(cells2.get(1))).equals("f1"))
+ assert(Bytes.toString(CellUtil.cloneQualifier(cells2.get(1))).equals("b"))
+
+ val cells1 = table.get(new Get(Bytes.toBytes("1"))).listCells()
+ assert(cells1.size == 1)
+ assert(Bytes.toString(CellUtil.cloneValue(cells1.get(0))).equals("foo1"))
+ assert(Bytes.toString(CellUtil.cloneFamily(cells1.get(0))).equals("f1"))
+ assert(Bytes.toString(CellUtil.cloneQualifier(cells1.get(0))).equals("a"))
+
+ } finally {
+ table.close()
+ val admin = ConnectionFactory.createConnection(config).getAdmin
+ try {
+ admin.disableTable(TableName.valueOf(tableName))
+ admin.deleteTable(TableName.valueOf(tableName))
+ } finally {
+ admin.close()
+ }
+ fs.delete(new Path(stagingFolder.getPath), true)
+
+ testFolder.delete()
+
+ }
+ }
+}
diff --git a/spark4/hbase-spark4/src/test/scala/org/apache/hadoop/hbase/spark/DefaultSourceSuite.scala b/spark4/hbase-spark4/src/test/scala/org/apache/hadoop/hbase/spark/DefaultSourceSuite.scala
new file mode 100644
index 00000000..58e4d37f
--- /dev/null
+++ b/spark4/hbase-spark4/src/test/scala/org/apache/hadoop/hbase/spark/DefaultSourceSuite.scala
@@ -0,0 +1,1365 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.hadoop.hbase.spark
+
+import java.sql.{Date, Timestamp}
+import org.apache.avro.Schema
+import org.apache.avro.generic.GenericData
+import org.apache.hadoop.hbase.{HBaseTestingUtility, TableName}
+import org.apache.hadoop.hbase.client.{ConnectionFactory, Put}
+import org.apache.hadoop.hbase.spark.datasources.{HBaseSparkConf, HBaseTableCatalog}
+import org.apache.hadoop.hbase.util.Bytes
+import org.apache.spark.{SparkConf, SparkContext}
+import org.apache.spark.sql.{DataFrame, SQLContext}
+import org.apache.spark.sql.functions._
+import org.scalatest.funsuite.AnyFunSuite
+import org.scalatest.{BeforeAndAfterAll, BeforeAndAfterEach}
+import org.xml.sax.SAXParseException
+
+import scala.collection.IterableOnce.iterableOnceExtensionMethods
+
+case class HBaseRecord(
+ col0: String,
+ col1: Boolean,
+ col2: Double,
+ col3: Float,
+ col4: Int,
+ col5: Long,
+ col6: Short,
+ col7: String,
+ col8: Byte)
+
+object HBaseRecord {
+ def apply(i: Int, t: String): HBaseRecord = {
+ val s = s"""row${"%03d".format(i)}"""
+ HBaseRecord(
+ s,
+ i % 2 == 0,
+ i.toDouble,
+ i.toFloat,
+ i,
+ i.toLong,
+ i.toShort,
+ s"String$i: $t",
+ i.toByte)
+ }
+}
+
+case class AvroHBaseKeyRecord(col0: Array[Byte], col1: Array[Byte])
+
+object AvroHBaseKeyRecord {
+ val schemaString =
+ s"""{"namespace": "example.avro",
+ | "type": "record", "name": "User",
+ | "fields": [ {"name": "name", "type": "string"},
+ | {"name": "favorite_number", "type": ["int", "null"]},
+ | {"name": "favorite_color", "type": ["string", "null"]} ] }""".stripMargin
+
+ val avroSchema: Schema = {
+ val p = new Schema.Parser
+ p.parse(schemaString)
+ }
+
+ def apply(i: Int): AvroHBaseKeyRecord = {
+    val user = new GenericData.Record(avroSchema)
+ user.put("name", s"name${"%03d".format(i)}")
+ user.put("favorite_number", i)
+ user.put("favorite_color", s"color${"%03d".format(i)}")
+ val avroByte = AvroSerdes.serialize(user, avroSchema)
+ AvroHBaseKeyRecord(avroByte, avroByte)
+ }
+}
+
+class DefaultSourceSuite
+ extends AnyFunSuite
+ with BeforeAndAfterEach
+ with BeforeAndAfterAll
+ with Logging {
+ @transient var sc: SparkContext = null
+ var TEST_UTIL: HBaseTestingUtility = new HBaseTestingUtility
+
+ val t1TableName = "t1"
+ val t2TableName = "t2"
+ val t3TableName = "t3"
+ val columnFamily = "c"
+
+ val timestamp = 1234567890000L
+
+ var sqlContext: SQLContext = null
+ var df: DataFrame = null
+
+  override def beforeAll(): Unit = {
+
+ TEST_UTIL.startMiniCluster
+
+ logInfo(" - minicluster started")
+ try
+ TEST_UTIL.deleteTable(TableName.valueOf(t1TableName))
+ catch {
+ case e: Exception => logInfo(" - no table " + t1TableName + " found")
+ }
+ try
+ TEST_UTIL.deleteTable(TableName.valueOf(t2TableName))
+ catch {
+ case e: Exception => logInfo(" - no table " + t2TableName + " found")
+ }
+ try
+ TEST_UTIL.deleteTable(TableName.valueOf(t3TableName))
+ catch {
+ case e: Exception => logInfo(" - no table " + t3TableName + " found")
+ }
+
+ logInfo(" - creating table " + t1TableName)
+ TEST_UTIL.createTable(TableName.valueOf(t1TableName), Bytes.toBytes(columnFamily))
+ logInfo(" - created table")
+ logInfo(" - creating table " + t2TableName)
+ TEST_UTIL.createTable(TableName.valueOf(t2TableName), Bytes.toBytes(columnFamily))
+ logInfo(" - created table")
+ logInfo(" - creating table " + t3TableName)
+ TEST_UTIL.createTable(TableName.valueOf(t3TableName), Bytes.toBytes(columnFamily))
+ logInfo(" - created table")
+
+ val sparkConf = new SparkConf
+ sparkConf.set(HBaseSparkConf.QUERY_CACHEBLOCKS, "true")
+ sparkConf.set(HBaseSparkConf.QUERY_BATCHSIZE, "100")
+ sparkConf.set(HBaseSparkConf.QUERY_CACHEDROWS, "100")
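+    // These query settings are what the HBaseTestSource-backed "Test HBaseSparkConf matching"
+    // test further down expects to find.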
+
+ sc = new SparkContext("local", "test", sparkConf)
+
+ val connection = ConnectionFactory.createConnection(TEST_UTIL.getConfiguration)
+ try {
+ val t1Table = connection.getTable(TableName.valueOf(t1TableName))
+
+ try {
+ var put = new Put(Bytes.toBytes("get1"))
+ put.addColumn(Bytes.toBytes(columnFamily), Bytes.toBytes("a"), Bytes.toBytes("foo1"))
+ put.addColumn(Bytes.toBytes(columnFamily), Bytes.toBytes("b"), Bytes.toBytes("1"))
+ put.addColumn(Bytes.toBytes(columnFamily), Bytes.toBytes("i"), Bytes.toBytes(1))
+ t1Table.put(put)
+ put = new Put(Bytes.toBytes("get2"))
+ put.addColumn(Bytes.toBytes(columnFamily), Bytes.toBytes("a"), Bytes.toBytes("foo2"))
+ put.addColumn(Bytes.toBytes(columnFamily), Bytes.toBytes("b"), Bytes.toBytes("4"))
+ put.addColumn(Bytes.toBytes(columnFamily), Bytes.toBytes("i"), Bytes.toBytes(4))
+ put.addColumn(Bytes.toBytes(columnFamily), Bytes.toBytes("z"), Bytes.toBytes("FOO"))
+ t1Table.put(put)
+ put = new Put(Bytes.toBytes("get3"))
+ put.addColumn(Bytes.toBytes(columnFamily), Bytes.toBytes("a"), Bytes.toBytes("foo3"))
+ put.addColumn(Bytes.toBytes(columnFamily), Bytes.toBytes("b"), Bytes.toBytes("8"))
+ put.addColumn(Bytes.toBytes(columnFamily), Bytes.toBytes("i"), Bytes.toBytes(8))
+ t1Table.put(put)
+ put = new Put(Bytes.toBytes("get4"))
+ put.addColumn(Bytes.toBytes(columnFamily), Bytes.toBytes("a"), Bytes.toBytes("foo4"))
+ put.addColumn(Bytes.toBytes(columnFamily), Bytes.toBytes("b"), Bytes.toBytes("10"))
+ put.addColumn(Bytes.toBytes(columnFamily), Bytes.toBytes("i"), Bytes.toBytes(10))
+ put.addColumn(Bytes.toBytes(columnFamily), Bytes.toBytes("z"), Bytes.toBytes("BAR"))
+ t1Table.put(put)
+ put = new Put(Bytes.toBytes("get5"))
+ put.addColumn(Bytes.toBytes(columnFamily), Bytes.toBytes("a"), Bytes.toBytes("foo5"))
+ put.addColumn(Bytes.toBytes(columnFamily), Bytes.toBytes("b"), Bytes.toBytes("8"))
+ put.addColumn(Bytes.toBytes(columnFamily), Bytes.toBytes("i"), Bytes.toBytes(8))
+ t1Table.put(put)
+ } finally {
+ t1Table.close()
+ }
+
+ val t2Table = connection.getTable(TableName.valueOf(t2TableName))
+
+ try {
+ var put = new Put(Bytes.toBytes(1))
+ put.addColumn(Bytes.toBytes(columnFamily), Bytes.toBytes("a"), Bytes.toBytes("foo1"))
+ put.addColumn(Bytes.toBytes(columnFamily), Bytes.toBytes("b"), Bytes.toBytes("1"))
+ put.addColumn(Bytes.toBytes(columnFamily), Bytes.toBytes("i"), Bytes.toBytes(1))
+ t2Table.put(put)
+ put = new Put(Bytes.toBytes(2))
+ put.addColumn(Bytes.toBytes(columnFamily), Bytes.toBytes("a"), Bytes.toBytes("foo2"))
+ put.addColumn(Bytes.toBytes(columnFamily), Bytes.toBytes("b"), Bytes.toBytes("4"))
+ put.addColumn(Bytes.toBytes(columnFamily), Bytes.toBytes("i"), Bytes.toBytes(4))
+ put.addColumn(Bytes.toBytes(columnFamily), Bytes.toBytes("z"), Bytes.toBytes("FOO"))
+ t2Table.put(put)
+ put = new Put(Bytes.toBytes(3))
+ put.addColumn(Bytes.toBytes(columnFamily), Bytes.toBytes("a"), Bytes.toBytes("foo3"))
+ put.addColumn(Bytes.toBytes(columnFamily), Bytes.toBytes("b"), Bytes.toBytes("8"))
+ put.addColumn(Bytes.toBytes(columnFamily), Bytes.toBytes("i"), Bytes.toBytes(8))
+ t2Table.put(put)
+ put = new Put(Bytes.toBytes(4))
+ put.addColumn(Bytes.toBytes(columnFamily), Bytes.toBytes("a"), Bytes.toBytes("foo4"))
+ put.addColumn(Bytes.toBytes(columnFamily), Bytes.toBytes("b"), Bytes.toBytes("10"))
+ put.addColumn(Bytes.toBytes(columnFamily), Bytes.toBytes("i"), Bytes.toBytes(10))
+ put.addColumn(Bytes.toBytes(columnFamily), Bytes.toBytes("z"), Bytes.toBytes("BAR"))
+ t2Table.put(put)
+ put = new Put(Bytes.toBytes(5))
+ put.addColumn(Bytes.toBytes(columnFamily), Bytes.toBytes("a"), Bytes.toBytes("foo5"))
+ put.addColumn(Bytes.toBytes(columnFamily), Bytes.toBytes("b"), Bytes.toBytes("8"))
+ put.addColumn(Bytes.toBytes(columnFamily), Bytes.toBytes("i"), Bytes.toBytes(8))
+ t2Table.put(put)
+ } finally {
+ t2Table.close()
+ }
+
+ val t3Table = connection.getTable(TableName.valueOf(t3TableName))
+
+ try {
+ val put = new Put(Bytes.toBytes("row"))
+ put.addColumn(
+ Bytes.toBytes(columnFamily),
+ Bytes.toBytes("binary"),
+ Array(1.toByte, 2.toByte, 3.toByte))
+ put.addColumn(Bytes.toBytes(columnFamily), Bytes.toBytes("boolean"), Bytes.toBytes(true))
+ put.addColumn(Bytes.toBytes(columnFamily), Bytes.toBytes("byte"), Array(127.toByte))
+ put.addColumn(
+ Bytes.toBytes(columnFamily),
+ Bytes.toBytes("short"),
+ Bytes.toBytes(32767.toShort))
+ put.addColumn(Bytes.toBytes(columnFamily), Bytes.toBytes("int"), Bytes.toBytes(1000000))
+ put.addColumn(
+ Bytes.toBytes(columnFamily),
+ Bytes.toBytes("long"),
+ Bytes.toBytes(10000000000L))
+ put.addColumn(Bytes.toBytes(columnFamily), Bytes.toBytes("float"), Bytes.toBytes(0.5f))
+ put.addColumn(Bytes.toBytes(columnFamily), Bytes.toBytes("double"), Bytes.toBytes(0.125))
+ put.addColumn(Bytes.toBytes(columnFamily), Bytes.toBytes("date"), Bytes.toBytes(timestamp))
+ put.addColumn(
+ Bytes.toBytes(columnFamily),
+ Bytes.toBytes("timestamp"),
+ Bytes.toBytes(timestamp))
+ put.addColumn(Bytes.toBytes(columnFamily), Bytes.toBytes("string"), Bytes.toBytes("string"))
+ t3Table.put(put)
+ } finally {
+ t3Table.close()
+ }
+ } finally {
+ connection.close()
+ }
+
+ def hbaseTable1Catalog = s"""{
+ |"table":{"namespace":"default", "name":"t1"},
+ |"rowkey":"key",
+ |"columns":{
+ |"KEY_FIELD":{"cf":"rowkey", "col":"key", "type":"string"},
+ |"A_FIELD":{"cf":"c", "col":"a", "type":"string"},
+ |"B_FIELD":{"cf":"c", "col":"b", "type":"string"}
+ |}
+ |}""".stripMargin
+
+ new HBaseContext(sc, TEST_UTIL.getConfiguration)
+ sqlContext = new SQLContext(sc)
+
+ df = sqlContext.load(
+ "org.apache.hadoop.hbase.spark",
+ Map(HBaseTableCatalog.tableCatalog -> hbaseTable1Catalog))
+
+ df.registerTempTable("hbaseTable1")
+
+ def hbaseTable2Catalog = s"""{
+ |"table":{"namespace":"default", "name":"t2"},
+ |"rowkey":"key",
+ |"columns":{
+ |"KEY_FIELD":{"cf":"rowkey", "col":"key", "type":"int"},
+ |"A_FIELD":{"cf":"c", "col":"a", "type":"string"},
+ |"B_FIELD":{"cf":"c", "col":"b", "type":"string"}
+ |}
+ |}""".stripMargin
+
+ df = sqlContext.load(
+ "org.apache.hadoop.hbase.spark",
+ Map(HBaseTableCatalog.tableCatalog -> hbaseTable2Catalog))
+
+ df.registerTempTable("hbaseTable2")
+ }
+
+  override def afterAll(): Unit = {
+ TEST_UTIL.deleteTable(TableName.valueOf(t1TableName))
+ logInfo("shuting down minicluster")
+ TEST_UTIL.shutdownMiniCluster()
+
+ sc.stop()
+ }
+
+ override def beforeEach(): Unit = {
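+    // Queries record their pushed-down execution rules (dynamic logic expression plus row key
+    // points and ranges) in this static queue; clear it so each test inspects only its own rules.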
+ DefaultSourceStaticUtils.lastFiveExecutionRules.clear()
+ }
+
+ /**
+   * An example of a query on three fields using only rowkey points for the filter
+ */
+ test("Test rowKey point only rowKey query") {
+ val results = sqlContext
+ .sql(
+ "SELECT KEY_FIELD, B_FIELD, A_FIELD FROM hbaseTable1 " +
+ "WHERE " +
+ "(KEY_FIELD = 'get1' or KEY_FIELD = 'get2' or KEY_FIELD = 'get3')")
+ .take(10)
+
+ val executionRules = DefaultSourceStaticUtils.lastFiveExecutionRules.poll()
+
+ assert(results.length == 3)
+
+ assert(
+ executionRules.dynamicLogicExpression.toExpressionString.equals(
+ "( ( KEY_FIELD == 0 OR KEY_FIELD == 1 ) OR KEY_FIELD == 2 )"))
+
+ assert(executionRules.rowKeyFilter.points.size == 3)
+ assert(executionRules.rowKeyFilter.ranges.size == 0)
+ }
+
+ /**
+   * An example of a query on three fields using only rowkey points for the filter,
+   * where some rowkey points are duplicated.
+ */
+ test("Test rowKey point only rowKey query, which contains duplicate rowkey") {
+ val results = sqlContext
+ .sql(
+ "SELECT KEY_FIELD, B_FIELD, A_FIELD FROM hbaseTable1 " +
+ "WHERE " +
+ "(KEY_FIELD = 'get1' or KEY_FIELD = 'get2' or KEY_FIELD = 'get1')")
+ .take(10)
+ val executionRules = DefaultSourceStaticUtils.lastFiveExecutionRules.poll()
+ assert(results.length == 2)
+ assert(
+ executionRules.dynamicLogicExpression.toExpressionString.equals(
+ "( KEY_FIELD == 0 OR KEY_FIELD == 1 )"))
+ assert(executionRules.rowKeyFilter.points.size == 2)
+ assert(executionRules.rowKeyFilter.ranges.size == 0)
+ }
+
+ /**
+   * An example of a query on three fields using only cell points for the filter
+ */
+ test("Test cell point only rowKey query") {
+ val results = sqlContext
+ .sql(
+ "SELECT KEY_FIELD, B_FIELD, A_FIELD FROM hbaseTable1 " +
+ "WHERE " +
+ "(B_FIELD = '4' or B_FIELD = '10' or A_FIELD = 'foo1')")
+ .take(10)
+
+ val executionRules = DefaultSourceStaticUtils.lastFiveExecutionRules.poll()
+
+ assert(results.length == 3)
+
+ assert(
+ executionRules.dynamicLogicExpression.toExpressionString.equals(
+ "( ( B_FIELD == 0 OR B_FIELD == 1 ) OR A_FIELD == 2 )"))
+ }
+
+ /**
+   * An example of an OR merge between two ranges where the result is one range.
+   * Also an example of less than and greater than.
+ */
+ test("Test two range rowKey query") {
+ val results = sqlContext
+ .sql(
+ "SELECT KEY_FIELD, B_FIELD, A_FIELD FROM hbaseTable1 " +
+ "WHERE " +
+ "( KEY_FIELD < 'get2' or KEY_FIELD > 'get3')")
+ .take(10)
+
+ val executionRules = DefaultSourceStaticUtils.lastFiveExecutionRules.poll()
+
+ assert(results.length == 3)
+
+ assert(
+ executionRules.dynamicLogicExpression.toExpressionString.equals(
+ "( KEY_FIELD < 0 OR KEY_FIELD > 1 )"))
+
+ assert(executionRules.rowKeyFilter.points.size == 0)
+ assert(executionRules.rowKeyFilter.ranges.size == 2)
+
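+    // KEY_FIELD < 'get2' becomes the range ["", "get2") and KEY_FIELD > 'get3' becomes
+    // ("get3", unbounded), as checked below.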
+ val scanRange1 = executionRules.rowKeyFilter.ranges.toList(0)
+ assert(Bytes.equals(scanRange1.lowerBound, Bytes.toBytes("")))
+ assert(Bytes.equals(scanRange1.upperBound, Bytes.toBytes("get2")))
+ assert(scanRange1.isLowerBoundEqualTo)
+ assert(!scanRange1.isUpperBoundEqualTo)
+
+ val scanRange2 = executionRules.rowKeyFilter.ranges.toList(1)
+ assert(Bytes.equals(scanRange2.lowerBound, Bytes.toBytes("get3")))
+ assert(scanRange2.upperBound == null)
+ assert(!scanRange2.isLowerBoundEqualTo)
+ assert(scanRange2.isUpperBoundEqualTo)
+ }
+
+ /**
+   * An example of an OR merge between two ranges where the result is one range.
+   * Also an example of less than and greater than.
+   *
+   * This example makes sure the code works for an int rowKey
+ */
+ test("Test two range rowKey query where the rowKey is Int and there is a range over lap") {
+ val results = sqlContext
+ .sql(
+ "SELECT KEY_FIELD, B_FIELD, A_FIELD FROM hbaseTable2 " +
+ "WHERE " +
+ "( KEY_FIELD < 4 or KEY_FIELD > 2)")
+ .take(10)
+
+ val executionRules = DefaultSourceStaticUtils.lastFiveExecutionRules.poll()
+
+ assert(
+ executionRules.dynamicLogicExpression.toExpressionString.equals(
+ "( KEY_FIELD < 0 OR KEY_FIELD > 1 )"))
+
+ assert(executionRules.rowKeyFilter.points.size == 0)
+ assert(executionRules.rowKeyFilter.ranges.size == 2)
+ assert(results.length == 5)
+ }
+
+ /**
+   * An example of an OR merge between two ranges where the result is two ranges.
+   * Also an example of less than and greater than.
+   *
+   * This example makes sure the code works for an int rowKey
+ */
+ test("Test two range rowKey query where the rowKey is Int and the ranges don't over lap") {
+ val results = sqlContext
+ .sql(
+ "SELECT KEY_FIELD, B_FIELD, A_FIELD FROM hbaseTable2 " +
+ "WHERE " +
+ "( KEY_FIELD < 2 or KEY_FIELD > 4)")
+ .take(10)
+
+ val executionRules = DefaultSourceStaticUtils.lastFiveExecutionRules.poll()
+
+ assert(
+ executionRules.dynamicLogicExpression.toExpressionString.equals(
+ "( KEY_FIELD < 0 OR KEY_FIELD > 1 )"))
+
+ assert(executionRules.rowKeyFilter.points.size == 0)
+
+ assert(executionRules.rowKeyFilter.ranges.size == 3)
+
+ val scanRange1 = executionRules.rowKeyFilter.ranges.toList(0)
+ assert(Bytes.equals(scanRange1.upperBound, Bytes.toBytes(2)))
+ assert(scanRange1.isLowerBoundEqualTo)
+ assert(!scanRange1.isUpperBoundEqualTo)
+
+ val scanRange2 = executionRules.rowKeyFilter.ranges.toList(1)
+ assert(scanRange2.isUpperBoundEqualTo)
+
+ assert(results.length == 2)
+ }
+
+ /**
+   * An example of an AND merge between two ranges where the result is one range.
+   * Also an example of less than or equal to and greater than or equal to.
+ */
+ test("Test one combined range rowKey query") {
+ val results = sqlContext
+ .sql(
+ "SELECT KEY_FIELD, B_FIELD, A_FIELD FROM hbaseTable1 " +
+ "WHERE " +
+ "(KEY_FIELD <= 'get3' and KEY_FIELD >= 'get2')")
+ .take(10)
+
+ val executionRules = DefaultSourceStaticUtils.lastFiveExecutionRules.poll()
+
+ assert(results.length == 2)
+
+ val expr = executionRules.dynamicLogicExpression.toExpressionString
+ assert(expr.equals("( ( KEY_FIELD isNotNull AND KEY_FIELD <= 0 ) AND KEY_FIELD >= 1 )"), expr)
+
+ assert(executionRules.rowKeyFilter.points.size == 0)
+ assert(executionRules.rowKeyFilter.ranges.size == 1)
+
+ val scanRange1 = executionRules.rowKeyFilter.ranges.toList(0)
+ assert(Bytes.equals(scanRange1.lowerBound, Bytes.toBytes("get2")))
+ assert(Bytes.equals(scanRange1.upperBound, Bytes.toBytes("get3")))
+ assert(scanRange1.isLowerBoundEqualTo)
+ assert(scanRange1.isUpperBoundEqualTo)
+
+ }
+
+ /**
+ * Do a select with no filters
+ */
+ test("Test select only query") {
+
+ val results = df.select("KEY_FIELD").take(10)
+ assert(results.length == 5)
+
+ val executionRules = DefaultSourceStaticUtils.lastFiveExecutionRules.poll()
+
+ assert(executionRules.dynamicLogicExpression == null)
+
+ }
+
+ /**
+   * A complex query with one point and one range for both the
+   * rowKey and a column
+ */
+ test("Test SQL point and range combo") {
+ val results = sqlContext
+ .sql(
+ "SELECT KEY_FIELD FROM hbaseTable1 " +
+ "WHERE " +
+ "(KEY_FIELD = 'get1' and B_FIELD < '3') or " +
+ "(KEY_FIELD >= 'get3' and B_FIELD = '8')")
+ .take(5)
+
+ val executionRules = DefaultSourceStaticUtils.lastFiveExecutionRules.poll()
+
+ assert(
+ executionRules.dynamicLogicExpression.toExpressionString.equals(
+ "( ( KEY_FIELD == 0 AND B_FIELD < 1 ) OR " +
+ "( KEY_FIELD >= 2 AND B_FIELD == 3 ) )"))
+
+ assert(executionRules.rowKeyFilter.points.size == 1)
+ assert(executionRules.rowKeyFilter.ranges.size == 1)
+
+ val scanRange1 = executionRules.rowKeyFilter.ranges.toList(0)
+ assert(Bytes.equals(scanRange1.lowerBound, Bytes.toBytes("get3")))
+ assert(scanRange1.upperBound == null)
+ assert(scanRange1.isLowerBoundEqualTo)
+ assert(scanRange1.isUpperBoundEqualTo)
+
+ assert(results.length == 3)
+ }
+
+ /**
+   * A complex query with two complex ranges that don't merge into one
+ */
+ test("Test two complete range non merge rowKey query") {
+
+ val results = sqlContext
+ .sql(
+ "SELECT KEY_FIELD, B_FIELD, A_FIELD FROM hbaseTable2 " +
+ "WHERE " +
+ "( KEY_FIELD >= 1 and KEY_FIELD <= 2) or" +
+ "( KEY_FIELD > 3 and KEY_FIELD <= 5)")
+ .take(10)
+
+ assert(results.length == 4)
+ val executionRules = DefaultSourceStaticUtils.lastFiveExecutionRules.poll()
+ assert(
+ executionRules.dynamicLogicExpression.toExpressionString.equals(
+ "( ( KEY_FIELD >= 0 AND KEY_FIELD <= 1 ) OR " +
+ "( KEY_FIELD > 2 AND KEY_FIELD <= 3 ) )"))
+
+ assert(executionRules.rowKeyFilter.points.size == 0)
+ assert(executionRules.rowKeyFilter.ranges.size == 2)
+
+ val scanRange1 = executionRules.rowKeyFilter.ranges.toList(0)
+ assert(Bytes.equals(scanRange1.lowerBound, Bytes.toBytes(1)))
+ assert(Bytes.equals(scanRange1.upperBound, Bytes.toBytes(2)))
+ assert(scanRange1.isLowerBoundEqualTo)
+ assert(scanRange1.isUpperBoundEqualTo)
+
+ val scanRange2 = executionRules.rowKeyFilter.ranges.toList(1)
+ assert(Bytes.equals(scanRange2.lowerBound, Bytes.toBytes(3)))
+ assert(Bytes.equals(scanRange2.upperBound, Bytes.toBytes(5)))
+ assert(!scanRange2.isLowerBoundEqualTo)
+ assert(scanRange2.isUpperBoundEqualTo)
+
+ }
+
+ /**
+   * A complex query with two complex ranges that do merge into one
+ */
+ test("Test two complete range merge rowKey query") {
+ val results = sqlContext
+ .sql(
+ "SELECT KEY_FIELD, B_FIELD, A_FIELD FROM hbaseTable1 " +
+ "WHERE " +
+ "( KEY_FIELD >= 'get1' and KEY_FIELD <= 'get2') or" +
+ "( KEY_FIELD > 'get3' and KEY_FIELD <= 'get5')")
+ .take(10)
+
+ val executionRules = DefaultSourceStaticUtils.lastFiveExecutionRules.poll()
+
+ assert(results.length == 4)
+
+ assert(
+ executionRules.dynamicLogicExpression.toExpressionString.equals(
+ "( ( KEY_FIELD >= 0 AND KEY_FIELD <= 1 ) OR " +
+ "( KEY_FIELD > 2 AND KEY_FIELD <= 3 ) )"))
+
+ assert(executionRules.rowKeyFilter.points.size == 0)
+ assert(executionRules.rowKeyFilter.ranges.size == 2)
+
+ val scanRange1 = executionRules.rowKeyFilter.ranges.toList(0)
+ assert(Bytes.equals(scanRange1.lowerBound, Bytes.toBytes("get1")))
+ assert(Bytes.equals(scanRange1.upperBound, Bytes.toBytes("get2")))
+ assert(scanRange1.isLowerBoundEqualTo)
+ assert(scanRange1.isUpperBoundEqualTo)
+
+ val scanRange2 = executionRules.rowKeyFilter.ranges.toList(1)
+ assert(Bytes.equals(scanRange2.lowerBound, Bytes.toBytes("get3")))
+ assert(Bytes.equals(scanRange2.upperBound, Bytes.toBytes("get5")))
+ assert(!scanRange2.isLowerBoundEqualTo)
+ assert(scanRange2.isUpperBoundEqualTo)
+ }
+
+ test("Test OR logic with a one RowKey and One column") {
+
+ val results = sqlContext
+ .sql(
+ "SELECT KEY_FIELD, B_FIELD, A_FIELD FROM hbaseTable1 " +
+ "WHERE " +
+ "( KEY_FIELD >= 'get1' or A_FIELD <= 'foo2') or" +
+ "( KEY_FIELD > 'get3' or B_FIELD <= '4')")
+ .take(10)
+
+ val executionRules = DefaultSourceStaticUtils.lastFiveExecutionRules.poll()
+
+ assert(results.length == 5)
+
+ assert(
+ executionRules.dynamicLogicExpression.toExpressionString.equals(
+ "( ( KEY_FIELD >= 0 OR A_FIELD <= 1 ) OR " +
+ "( KEY_FIELD > 2 OR B_FIELD <= 3 ) )"))
+
+ assert(executionRules.rowKeyFilter.points.size == 0)
+ assert(executionRules.rowKeyFilter.ranges.size == 1)
+
+ val scanRange1 = executionRules.rowKeyFilter.ranges.toList(0)
+    // This is the main test for 14406:
+    // because the key is joined through an OR with a qualifier,
+    // there is no filter on the rowKey
+ assert(Bytes.equals(scanRange1.lowerBound, Bytes.toBytes("")))
+ assert(scanRange1.upperBound == null)
+ assert(scanRange1.isLowerBoundEqualTo)
+ assert(scanRange1.isUpperBoundEqualTo)
+ }
+
+ test("Test OR logic with a two columns") {
+ val results = sqlContext
+ .sql(
+ "SELECT KEY_FIELD, B_FIELD, A_FIELD FROM hbaseTable1 " +
+ "WHERE " +
+ "( B_FIELD > '4' or A_FIELD <= 'foo2') or" +
+ "( A_FIELD > 'foo2' or B_FIELD < '4')")
+ .take(10)
+
+ val executionRules = DefaultSourceStaticUtils.lastFiveExecutionRules.poll()
+
+ assert(results.length == 5)
+
+ assert(
+ executionRules.dynamicLogicExpression.toExpressionString.equals(
+ "( ( B_FIELD > 0 OR A_FIELD <= 1 ) OR " +
+ "( A_FIELD > 2 OR B_FIELD < 3 ) )"))
+
+ assert(executionRules.rowKeyFilter.points.size == 0)
+ assert(executionRules.rowKeyFilter.ranges.size == 1)
+
+ val scanRange1 = executionRules.rowKeyFilter.ranges.toList(0)
+ assert(Bytes.equals(scanRange1.lowerBound, Bytes.toBytes("")))
+ assert(scanRange1.upperBound == null)
+ assert(scanRange1.isLowerBoundEqualTo)
+ assert(scanRange1.isUpperBoundEqualTo)
+
+ }
+
+ test("Test single RowKey Or Column logic") {
+ val results = sqlContext
+ .sql(
+ "SELECT KEY_FIELD, B_FIELD, A_FIELD FROM hbaseTable1 " +
+ "WHERE " +
+ "( KEY_FIELD >= 'get4' or A_FIELD <= 'foo2' )")
+ .take(10)
+
+ val executionRules = DefaultSourceStaticUtils.lastFiveExecutionRules.poll()
+
+ assert(results.length == 4)
+
+ assert(
+ executionRules.dynamicLogicExpression.toExpressionString.equals(
+ "( KEY_FIELD >= 0 OR A_FIELD <= 1 )"))
+
+ assert(executionRules.rowKeyFilter.points.size == 0)
+ assert(executionRules.rowKeyFilter.ranges.size == 1)
+
+ val scanRange1 = executionRules.rowKeyFilter.ranges.toList(0)
+ assert(Bytes.equals(scanRange1.lowerBound, Bytes.toBytes("")))
+ assert(scanRange1.upperBound == null)
+ assert(scanRange1.isLowerBoundEqualTo)
+ assert(scanRange1.isUpperBoundEqualTo)
+ }
+
+ test("Test Rowkey And with complex logic (HBASE-26863)") {
+ val results = sqlContext
+ .sql(
+ "SELECT KEY_FIELD, B_FIELD, A_FIELD FROM hbaseTable1 " +
+ "WHERE " +
+ "( KEY_FIELD >= 'get1' AND KEY_FIELD <= 'get3' ) AND (A_FIELD = 'foo1' OR B_FIELD = '8')")
+ .take(10)
+ val executionRules = DefaultSourceStaticUtils.lastFiveExecutionRules.poll()
+ assert(results.length == 2)
+
+ assert(
+ executionRules.dynamicLogicExpression.toExpressionString
+ == "( ( ( KEY_FIELD isNotNull AND KEY_FIELD >= 0 ) AND KEY_FIELD <= 1 ) AND ( A_FIELD == 2 OR B_FIELD == 3 ) )")
+
+ assert(executionRules.rowKeyFilter.points.size == 0)
+ assert(executionRules.rowKeyFilter.ranges.size == 1)
+
+ val scanRange1 = executionRules.rowKeyFilter.ranges.toList(0)
+ assert(Bytes.equals(scanRange1.lowerBound, Bytes.toBytes("get1")))
+ assert(Bytes.equals(scanRange1.upperBound, Bytes.toBytes("get3")))
+ assert(scanRange1.isLowerBoundEqualTo)
+ assert(scanRange1.isUpperBoundEqualTo)
+ }
+
+ test("Test table that doesn't exist") {
+ val catalog = s"""{
+ |"table":{"namespace":"default", "name":"t1NotThere"},
+ |"rowkey":"key",
+ |"columns":{
+ |"KEY_FIELD":{"cf":"rowkey", "col":"key", "type":"string"},
+ |"A_FIELD":{"cf":"c", "col":"a", "type":"string"},
+ |"B_FIELD":{"cf":"c", "col":"c", "type":"string"}
+ |}
+ |}""".stripMargin
+
+ intercept[Exception] {
+ df = sqlContext.load(
+ "org.apache.hadoop.hbase.spark",
+ Map(HBaseTableCatalog.tableCatalog -> catalog))
+
+ df.registerTempTable("hbaseNonExistingTmp")
+
+ sqlContext
+ .sql(
+ "SELECT KEY_FIELD, B_FIELD, A_FIELD FROM hbaseNonExistingTmp " +
+ "WHERE " +
+ "( KEY_FIELD >= 'get1' and KEY_FIELD <= 'get3') or" +
+ "( KEY_FIELD > 'get3' and KEY_FIELD <= 'get5')")
+ .count()
+ }
+ DefaultSourceStaticUtils.lastFiveExecutionRules.poll()
+ }
+
+ test("Test table with column that doesn't exist") {
+ val catalog = s"""{
+ |"table":{"namespace":"default", "name":"t1"},
+ |"rowkey":"key",
+ |"columns":{
+ |"KEY_FIELD":{"cf":"rowkey", "col":"key", "type":"string"},
+ |"A_FIELD":{"cf":"c", "col":"a", "type":"string"},
+ |"B_FIELD":{"cf":"c", "col":"b", "type":"string"},
+ |"C_FIELD":{"cf":"c", "col":"c", "type":"string"}
+ |}
+ |}""".stripMargin
+ df = sqlContext.load(
+ "org.apache.hadoop.hbase.spark",
+ Map(HBaseTableCatalog.tableCatalog -> catalog))
+
+ df.registerTempTable("hbaseFactColumnTmp")
+
+ val result = sqlContext.sql(
+ "SELECT KEY_FIELD, " +
+ "B_FIELD, A_FIELD FROM hbaseFactColumnTmp")
+
+ assert(result.count() == 5)
+
+ val executionRules = DefaultSourceStaticUtils.lastFiveExecutionRules.poll()
+ assert(executionRules.dynamicLogicExpression == null)
+
+ }
+
+ test("Test table with INT column") {
+ val catalog = s"""{
+ |"table":{"namespace":"default", "name":"t1"},
+ |"rowkey":"key",
+ |"columns":{
+ |"KEY_FIELD":{"cf":"rowkey", "col":"key", "type":"string"},
+ |"A_FIELD":{"cf":"c", "col":"a", "type":"string"},
+ |"B_FIELD":{"cf":"c", "col":"b", "type":"string"},
+ |"I_FIELD":{"cf":"c", "col":"i", "type":"int"}
+ |}
+ |}""".stripMargin
+ df = sqlContext.load(
+ "org.apache.hadoop.hbase.spark",
+ Map(HBaseTableCatalog.tableCatalog -> catalog))
+
+ df.registerTempTable("hbaseIntTmp")
+
+ val result = sqlContext.sql(
+ "SELECT KEY_FIELD, B_FIELD, I_FIELD FROM hbaseIntTmp" +
+ " where I_FIELD > 4 and I_FIELD < 10")
+
+ val localResult = result.take(5)
+
+ assert(localResult.length == 2)
+ assert(localResult(0).getInt(2) == 8)
+
+ val executionRules = DefaultSourceStaticUtils.lastFiveExecutionRules.poll()
+ val expr = executionRules.dynamicLogicExpression.toExpressionString
+ logInfo(expr)
+ assert(expr.equals("( ( I_FIELD isNotNull AND I_FIELD > 0 ) AND I_FIELD < 1 )"), expr)
+
+ }
+
+ test("Test table with INT column defined at wrong type") {
+ val catalog = s"""{
+ |"table":{"namespace":"default", "name":"t1"},
+ |"rowkey":"key",
+ |"columns":{
+ |"KEY_FIELD":{"cf":"rowkey", "col":"key", "type":"string"},
+ |"A_FIELD":{"cf":"c", "col":"a", "type":"string"},
+ |"B_FIELD":{"cf":"c", "col":"b", "type":"string"},
+ |"I_FIELD":{"cf":"c", "col":"i", "type":"string"}
+ |}
+ |}""".stripMargin
+ df = sqlContext.load(
+ "org.apache.hadoop.hbase.spark",
+ Map(HBaseTableCatalog.tableCatalog -> catalog))
+
+ df.registerTempTable("hbaseIntWrongTypeTmp")
+
+ val result = sqlContext.sql(
+ "SELECT KEY_FIELD, " +
+ "B_FIELD, I_FIELD FROM hbaseIntWrongTypeTmp")
+
+ val localResult = result.take(10)
+ assert(localResult.length == 5)
+
+ val executionRules = DefaultSourceStaticUtils.lastFiveExecutionRules.poll()
+ assert(executionRules.dynamicLogicExpression == null)
+
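+    // The stored value was written with Bytes.toBytes(1), a 4-byte big-endian int; read back
+    // as a string it becomes four characters whose byte values are 0, 0, 0, 1.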
+ assert(localResult(0).getString(2).length == 4)
+ assert(localResult(0).getString(2).charAt(0).toByte == 0)
+ assert(localResult(0).getString(2).charAt(1).toByte == 0)
+ assert(localResult(0).getString(2).charAt(2).toByte == 0)
+ assert(localResult(0).getString(2).charAt(3).toByte == 1)
+ }
+
+ test("Test bad column type") {
+ val catalog = s"""{
+ |"table":{"namespace":"default", "name":"t1"},
+ |"rowkey":"key",
+ |"columns":{
+ |"KEY_FIELD":{"cf":"rowkey", "col":"key", "type":"FOOBAR"},
+ |"A_FIELD":{"cf":"c", "col":"a", "type":"string"},
+ |"I_FIELD":{"cf":"c", "col":"i", "type":"string"}
+ |}
+ |}""".stripMargin
+ intercept[Exception] {
+ df = sqlContext.load(
+ "org.apache.hadoop.hbase.spark",
+ Map(HBaseTableCatalog.tableCatalog -> catalog))
+
+ df.registerTempTable("hbaseIntWrongTypeTmp")
+
+ val result = sqlContext.sql(
+ "SELECT KEY_FIELD, " +
+ "B_FIELD, I_FIELD FROM hbaseIntWrongTypeTmp")
+
+ val localResult = result.take(10)
+ assert(localResult.length == 5)
+
+ val executionRules = DefaultSourceStaticUtils.lastFiveExecutionRules.poll()
+ assert(executionRules.dynamicLogicExpression == null)
+
+ }
+ }
+
+ test("Test HBaseSparkConf matching") {
+ val df = sqlContext.load(
+ "org.apache.hadoop.hbase.spark.HBaseTestSource",
+ Map(
+ "cacheSize" -> "100",
+ "batchNum" -> "100",
+ "blockCacheingEnable" -> "true",
+ "rowNum" -> "10"))
+ assert(df.count() == 10)
+
+ val df1 = sqlContext.load(
+ "org.apache.hadoop.hbase.spark.HBaseTestSource",
+ Map(
+ "cacheSize" -> "1000",
+ "batchNum" -> "100",
+ "blockCacheingEnable" -> "true",
+ "rowNum" -> "10"))
+ intercept[Exception] {
+ assert(df1.count() == 10)
+ }
+
+ val df2 = sqlContext.load(
+ "org.apache.hadoop.hbase.spark.HBaseTestSource",
+ Map(
+ "cacheSize" -> "100",
+ "batchNum" -> "1000",
+ "blockCacheingEnable" -> "true",
+ "rowNum" -> "10"))
+ intercept[Exception] {
+ assert(df2.count() == 10)
+ }
+
+ val df3 = sqlContext.load(
+ "org.apache.hadoop.hbase.spark.HBaseTestSource",
+ Map(
+ "cacheSize" -> "100",
+ "batchNum" -> "100",
+ "blockCacheingEnable" -> "false",
+ "rowNum" -> "10"))
+ intercept[Exception] {
+ assert(df3.count() == 10)
+ }
+ }
+
+ test("Test table with sparse column") {
+ val catalog = s"""{
+ |"table":{"namespace":"default", "name":"t1"},
+ |"rowkey":"key",
+ |"columns":{
+ |"KEY_FIELD":{"cf":"rowkey", "col":"key", "type":"string"},
+ |"A_FIELD":{"cf":"c", "col":"a", "type":"string"},
+ |"B_FIELD":{"cf":"c", "col":"b", "type":"string"},
+ |"Z_FIELD":{"cf":"c", "col":"z", "type":"string"}
+ |}
+ |}""".stripMargin
+ df = sqlContext.load(
+ "org.apache.hadoop.hbase.spark",
+ Map(HBaseTableCatalog.tableCatalog -> catalog))
+
+ df.registerTempTable("hbaseZTmp")
+
+ val result = sqlContext.sql("SELECT KEY_FIELD, B_FIELD, Z_FIELD FROM hbaseZTmp")
+
+ val localResult = result.take(10)
+ assert(localResult.length == 5)
+
+ assert(localResult(0).getString(2) == null)
+ assert(localResult(1).getString(2) == "FOO")
+ assert(localResult(2).getString(2) == null)
+ assert(localResult(3).getString(2) == "BAR")
+ assert(localResult(4).getString(2) == null)
+
+ val executionRules = DefaultSourceStaticUtils.lastFiveExecutionRules.poll()
+ assert(executionRules.dynamicLogicExpression == null)
+ }
+
+ test("Test with column logic disabled") {
+ val catalog = s"""{
+ |"table":{"namespace":"default", "name":"t1"},
+ |"rowkey":"key",
+ |"columns":{
+ |"KEY_FIELD":{"cf":"rowkey", "col":"key", "type":"string"},
+ |"A_FIELD":{"cf":"c", "col":"a", "type":"string"},
+ |"B_FIELD":{"cf":"c", "col":"b", "type":"string"},
+ |"Z_FIELD":{"cf":"c", "col":"z", "type":"string"}
+ |}
+ |}""".stripMargin
+ df = sqlContext.load(
+ "org.apache.hadoop.hbase.spark",
+ Map(
+ HBaseTableCatalog.tableCatalog -> catalog,
+ HBaseSparkConf.PUSHDOWN_COLUMN_FILTER -> "false"))
+
+ df.registerTempTable("hbaseNoPushDownTmp")
+
+ val results = sqlContext
+ .sql(
+ "SELECT KEY_FIELD, B_FIELD, A_FIELD FROM hbaseNoPushDownTmp " +
+ "WHERE " +
+ "(KEY_FIELD <= 'get3' and KEY_FIELD >= 'get2')")
+ .take(10)
+
+ val executionRules = DefaultSourceStaticUtils.lastFiveExecutionRules.poll()
+
+ assert(results.length == 2)
+
+ assert(executionRules.dynamicLogicExpression == null)
+ }
+
+ test("Test mapping") {
+ val catalog = s"""{
+ |"table":{"namespace":"default", "name":"t3"},
+ |"rowkey":"key",
+ |"columns":{
+ |"KEY_FIELD":{"cf":"rowkey", "col":"key", "type":"string"},
+ |"BINARY_FIELD":{"cf":"c", "col":"binary", "type":"binary"},
+ |"BOOLEAN_FIELD":{"cf":"c", "col":"boolean", "type":"boolean"},
+ |"BYTE_FIELD":{"cf":"c", "col":"byte", "type":"byte"},
+ |"SHORT_FIELD":{"cf":"c", "col":"short", "type":"short"},
+ |"INT_FIELD":{"cf":"c", "col":"int", "type":"int"},
+ |"LONG_FIELD":{"cf":"c", "col":"long", "type":"long"},
+ |"FLOAT_FIELD":{"cf":"c", "col":"float", "type":"float"},
+ |"DOUBLE_FIELD":{"cf":"c", "col":"double", "type":"double"},
+ |"DATE_FIELD":{"cf":"c", "col":"date", "type":"date"},
+ |"TIMESTAMP_FIELD":{"cf":"c", "col":"timestamp", "type":"timestamp"},
+ |"STRING_FIELD":{"cf":"c", "col":"string", "type":"string"}
+ |}
+ |}""".stripMargin
+ df = sqlContext.load(
+ "org.apache.hadoop.hbase.spark",
+ Map(HBaseTableCatalog.tableCatalog -> catalog))
+
+ df.registerTempTable("hbaseTestMapping")
+
+ val results = sqlContext
+ .sql(
+ "SELECT binary_field, boolean_field, " +
+ "byte_field, short_field, int_field, long_field, " +
+ "float_field, double_field, date_field, timestamp_field, " +
+ "string_field FROM hbaseTestMapping")
+ .collect()
+
+ assert(results.length == 1)
+
+ val result = results(0)
+
+ System.out.println("row: " + result)
+ System.out.println("0: " + result.get(0))
+ System.out.println("1: " + result.get(1))
+ System.out.println("2: " + result.get(2))
+ System.out.println("3: " + result.get(3))
+
+ assert(
+ result.get(0).asInstanceOf[Array[Byte]].sameElements(Array(1.toByte, 2.toByte, 3.toByte)))
+ assert(result.get(1) == true)
+ assert(result.get(2) == 127)
+ assert(result.get(3) == 32767)
+ assert(result.get(4) == 1000000)
+ assert(result.get(5) == 10000000000L)
+ assert(result.get(6) == 0.5)
+ assert(result.get(7) == 0.125)
+ // java.sql.Date stores only year, month and day, so check that the value is within one day of the timestamp
+ assert(Math.abs(result.get(8).asInstanceOf[Date].getTime - timestamp) <= 86400000)
+ assert(result.get(9).asInstanceOf[Timestamp].getTime == timestamp)
+ assert(result.get(10) == "string")
+ }
+
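+ // Catalog used by the write/read round-trip tests below: col0 is the row key and
+ // col1..col8 map to qualifiers spread across several column families of "table1".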
+ def writeCatalog = s"""{
+ |"table":{"namespace":"default", "name":"table1"},
+ |"rowkey":"key",
+ |"columns":{
+ |"col0":{"cf":"rowkey", "col":"key", "type":"string"},
+ |"col1":{"cf":"cf1", "col":"col1", "type":"boolean"},
+ |"col2":{"cf":"cf1", "col":"col2", "type":"double"},
+ |"col3":{"cf":"cf3", "col":"col3", "type":"float"},
+ |"col4":{"cf":"cf3", "col":"col4", "type":"int"},
+ |"col5":{"cf":"cf5", "col":"col5", "type":"bigint"},
+ |"col6":{"cf":"cf6", "col":"col6", "type":"smallint"},
+ |"col7":{"cf":"cf7", "col":"col7", "type":"string"},
+ |"col8":{"cf":"cf8", "col":"col8", "type":"tinyint"}
+ |}
+ |}""".stripMargin
+
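+ // Helper that builds a DataFrame over the given catalog via the HBase data source.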
+ def withCatalog(cat: String): DataFrame = {
+ sqlContext.read
+ .options(Map(HBaseTableCatalog.tableCatalog -> cat))
+ .format("org.apache.hadoop.hbase.spark")
+ .load()
+ }
+
+ test("populate table") {
+ val sql = sqlContext
+ import sql.implicits._
+ val data = (0 to 255).map { i => HBaseRecord(i, "extra") }
+ sc.parallelize(data)
+ .toDF
+ .write
+ .options(
+ Map(HBaseTableCatalog.tableCatalog -> writeCatalog, HBaseTableCatalog.newTable -> "5"))
+ .format("org.apache.hadoop.hbase.spark")
+ .save()
+ }
+
+ test("empty column") {
+ val df = withCatalog(writeCatalog)
+ df.registerTempTable("table0")
+ val c = sqlContext.sql("select count(1) from table0").rdd.collect()(0)(0).asInstanceOf[Long]
+ assert(c == 256)
+ }
+
+ test("full query") {
+ val df = withCatalog(writeCatalog)
+ df.show()
+ assert(df.count() == 256)
+ }
+
+ test("filtered query0") {
+ val sql = sqlContext
+ import sql.implicits._
+ val df = withCatalog(writeCatalog)
+ val s = df
+ .filter($"col0" <= "row005")
+ .select("col0", "col1")
+ s.show()
+ assert(s.count() == 6)
+ }
+
+ test("filtered query01") {
+ val sql = sqlContext
+ import sql.implicits._
+ val df = withCatalog(writeCatalog)
+ val s = df
+ .filter(col("col0").startsWith("row00"))
+ .select("col0", "col1")
+ s.show()
+ assert(s.count() == 10)
+ }
+
+ test("startsWith filtered query 1") {
+ val sql = sqlContext
+ import sql.implicits._
+ val df = withCatalog(writeCatalog)
+ val s = df
+ .filter(col("col0").startsWith("row005"))
+ .select("col0", "col1")
+ s.show()
+ assert(s.count() == 1)
+ }
+
+ test("startsWith filtered query 2") {
+ val sql = sqlContext
+ import sql.implicits._
+ val df = withCatalog(writeCatalog)
+ val s = df
+ .filter(col("col0").startsWith("row"))
+ .select("col0", "col1")
+ s.show()
+ assert(s.count() == 256)
+ }
+
+ test("startsWith filtered query 3") {
+ val sql = sqlContext
+ import sql.implicits._
+ val df = withCatalog(writeCatalog)
+ val s = df
+ .filter(col("col0").startsWith("row19"))
+ .select("col0", "col1")
+ s.show()
+ assert(s.count() == 10)
+ }
+
+ test("startsWith filtered query 4") {
+ val sql = sqlContext
+ import sql.implicits._
+ val df = withCatalog(writeCatalog)
+ val s = df
+ .filter(col("col0").startsWith(""))
+ .select("col0", "col1")
+ s.show()
+ assert(s.count() == 256)
+ }
+
+ test("Timestamp semantics") {
+ val sql = sqlContext
+ import sql.implicits._
+
+ // The table already contains recently written data. Write some rows with an explicit
+ // 1993 timestamp that we can include or exclude, plus more rows with the implicit (now)
+ // timestamp. We can then cross-section the table: read only the points in between,
+ // the most recent view, or an old view.
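+ // oldMs is a fixed timestamp in late 1993; startMs captures "now" before the new rows are written.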
+ val oldMs = 754869600000L
+ val startMs = System.currentTimeMillis()
+ val oldData = (0 to 100).map { i => HBaseRecord(i, "old") }
+ val newData = (200 to 255).map { i => HBaseRecord(i, "new") }
+
+ sc.parallelize(oldData)
+ .toDF
+ .write
+ .options(
+ Map(
+ HBaseTableCatalog.tableCatalog -> writeCatalog,
+ HBaseTableCatalog.tableName -> "5",
+ HBaseSparkConf.TIMESTAMP -> oldMs.toString))
+ .format("org.apache.hadoop.hbase.spark")
+ .save()
+ sc.parallelize(newData)
+ .toDF
+ .write
+ .options(
+ Map(HBaseTableCatalog.tableCatalog -> writeCatalog, HBaseTableCatalog.tableName -> "5"))
+ .format("org.apache.hadoop.hbase.spark")
+ .save()
+
+ // Test specific timestamp -- Full scan, Timestamp
+ val individualTimestamp = sqlContext.read
+ .options(
+ Map(
+ HBaseTableCatalog.tableCatalog -> writeCatalog,
+ HBaseSparkConf.TIMESTAMP -> oldMs.toString))
+ .format("org.apache.hadoop.hbase.spark")
+ .load()
+ assert(individualTimestamp.count() == 101)
+
+ // Test getting everything -- Full Scan, No range
+ val everything = sqlContext.read
+ .options(Map(HBaseTableCatalog.tableCatalog -> writeCatalog))
+ .format("org.apache.hadoop.hbase.spark")
+ .load()
+ assert(everything.count() == 256)
+ // Test getting everything -- Pruned Scan, TimeRange
+ val element50 = everything.where(col("col0") === lit("row050")).select("col7").collect()(0)(0)
+ assert(element50 == "String50: extra")
+ val element200 = everything.where(col("col0") === lit("row200")).select("col7").collect()(0)(0)
+ assert(element200 == "String200: new")
+
+ // Test Getting old stuff -- Full Scan, TimeRange
+ val oldRange = sqlContext.read
+ .options(
+ Map(
+ HBaseTableCatalog.tableCatalog -> writeCatalog,
+ HBaseSparkConf.TIMERANGE_START -> "0",
+ HBaseSparkConf.TIMERANGE_END -> (oldMs + 100).toString))
+ .format("org.apache.hadoop.hbase.spark")
+ .load()
+ assert(oldRange.count() == 101)
+ // Test Getting old stuff -- Pruned Scan, TimeRange
+ val oldElement50 = oldRange.where(col("col0") === lit("row050")).select("col7").collect()(0)(0)
+ assert(oldElement50 == "String50: old")
+
+ // Test Getting middle stuff -- Full Scan, TimeRange
+ val middleRange = sqlContext.read
+ .options(
+ Map(
+ HBaseTableCatalog.tableCatalog -> writeCatalog,
+ HBaseSparkConf.TIMERANGE_START -> "0",
+ HBaseSparkConf.TIMERANGE_END -> (startMs + 100).toString))
+ .format("org.apache.hadoop.hbase.spark")
+ .load()
+ assert(middleRange.count() == 256)
+ // Test Getting middle stuff -- Pruned Scan, TimeRange
+ val middleElement200 =
+ middleRange.where(col("col0") === lit("row200")).select("col7").collect()(0)(0)
+ assert(middleElement200 == "String200: extra")
+ }
+
+ // catalog for insertion
+ def avroWriteCatalog = s"""{
+ |"table":{"namespace":"default", "name":"avrotable"},
+ |"rowkey":"key",
+ |"columns":{
+ |"col0":{"cf":"rowkey", "col":"key", "type":"binary"},
+ |"col1":{"cf":"cf1", "col":"col1", "type":"binary"}
+ |}
+ |}""".stripMargin
+
+ // catalog for read
+ def avroCatalog = s"""{
+ |"table":{"namespace":"default", "name":"avrotable"},
+ |"rowkey":"key",
+ |"columns":{
+ |"col0":{"cf":"rowkey", "col":"key", "avro":"avroSchema"},
+ |"col1":{"cf":"cf1", "col":"col1", "avro":"avroSchema"}
+ |}
+ |}""".stripMargin
+
+ // for insert to another table
+ def avroCatalogInsert = s"""{
+ |"table":{"namespace":"default", "name":"avrotableInsert"},
+ |"rowkey":"key",
+ |"columns":{
+ |"col0":{"cf":"rowkey", "col":"key", "avro":"avroSchema"},
+ |"col1":{"cf":"cf1", "col":"col1", "avro":"avroSchema"}
+ |}
+ |}""".stripMargin
+
+ def withAvroCatalog(cat: String): DataFrame = {
+ sqlContext.read
+ .options(
+ Map(
+ "avroSchema" -> AvroHBaseKeyRecord.schemaString,
+ HBaseTableCatalog.tableCatalog -> avroCatalog))
+ .format("org.apache.hadoop.hbase.spark")
+ .load()
+ }
+
+ test("populate avro table") {
+ val sql = sqlContext
+ import sql.implicits._
+
+ val data = (0 to 255).map { i => AvroHBaseKeyRecord(i) }
+ sc.parallelize(data)
+ .toDF
+ .write
+ .options(
+ Map(HBaseTableCatalog.tableCatalog -> avroWriteCatalog, HBaseTableCatalog.newTable -> "5"))
+ .format("org.apache.hadoop.hbase.spark")
+ .save()
+ }
+
+ test("avro empty column") {
+ val df = withAvroCatalog(avroCatalog)
+ df.registerTempTable("avrotable")
+ val c = sqlContext.sql("select count(1) from avrotable").rdd.collect()(0)(0).asInstanceOf[Long]
+ assert(c == 256)
+ }
+
+ test("avro full query") {
+ val df = withAvroCatalog(avroCatalog)
+ df.show()
+ df.printSchema()
+ assert(df.count() == 256)
+ }
+
+ test("avro serialization and deserialization query") {
+ val df = withAvroCatalog(avroCatalog)
+ df.write
+ .options(
+ Map(
+ "avroSchema" -> AvroHBaseKeyRecord.schemaString,
+ HBaseTableCatalog.tableCatalog -> avroCatalogInsert,
+ HBaseTableCatalog.newTable -> "5"))
+ .format("org.apache.hadoop.hbase.spark")
+ .save()
+ val newDF = withAvroCatalog(avroCatalogInsert)
+ newDF.show()
+ newDF.printSchema()
+ assert(newDF.count() == 256)
+ }
+
+ test("avro filtered query") {
+ val sql = sqlContext
+ import sql.implicits._
+ val df = withAvroCatalog(avroCatalog)
+ val r = df
+ .filter($"col1.name" === "name005" || $"col1.name" <= "name005")
+ .select("col0", "col1.favorite_color", "col1.favorite_number")
+ r.show()
+ assert(r.count() == 6)
+ }
+
+ test("avro Or filter") {
+ val sql = sqlContext
+ import sql.implicits._
+ val df = withAvroCatalog(avroCatalog)
+ val s = df
+ .filter($"col1.name" <= "name005" || $"col1.name".contains("name007"))
+ .select("col0", "col1.favorite_color", "col1.favorite_number")
+ s.show()
+ assert(s.count() == 7)
+ }
+
+ test("test create HBaseRelation with new context throws SAXParseException") {
+ val catalog = s"""{
+ |"table":{"namespace":"default", "name":"t1NotThere"},
+ |"rowkey":"key",
+ |"columns":{
+ |"KEY_FIELD":{"cf":"rowkey", "col":"key", "type":"string"},
+ |"A_FIELD":{"cf":"c", "col":"a", "type":"string"},
+ |"B_FIELD":{"cf":"c", "col":"c", "type":"string"}
+ |}
+ |}""".stripMargin
+ try {
+ HBaseRelation(
+ Map(HBaseTableCatalog.tableCatalog -> catalog, HBaseSparkConf.USE_HBASECONTEXT -> "false"),
+ None)(sqlContext)
+ } catch {
+ case e: Throwable =>
+ if (e.getCause.isInstanceOf[SAXParseException]) {
+ fail("SAXParseException due to configuration loading empty resource")
+ } else {
+ println("Failed due to some other exception, ignore " + e.getMessage)
+ }
+ }
+ }
+}
diff --git a/spark4/hbase-spark4/src/test/scala/org/apache/hadoop/hbase/spark/DynamicLogicExpressionSuite.scala b/spark4/hbase-spark4/src/test/scala/org/apache/hadoop/hbase/spark/DynamicLogicExpressionSuite.scala
new file mode 100644
index 00000000..7e913cec
--- /dev/null
+++ b/spark4/hbase-spark4/src/test/scala/org/apache/hadoop/hbase/spark/DynamicLogicExpressionSuite.scala
@@ -0,0 +1,402 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.hadoop.hbase.spark
+
+import java.util
+import org.apache.hadoop.hbase.spark.datasources.{HBaseSparkConf, JavaBytesEncoder}
+import org.apache.hadoop.hbase.util.Bytes
+import org.apache.spark.sql.types._
+import org.scalatest.funsuite.AnyFunSuite
+import org.scalatest.{BeforeAndAfterAll, BeforeAndAfterEach}
+
+class DynamicLogicExpressionSuite
+ extends AnyFunSuite
+ with BeforeAndAfterEach
+ with BeforeAndAfterAll
+ with Logging {
+
+ val encoder = JavaBytesEncoder.create(HBaseSparkConf.DEFAULT_QUERY_ENCODER)
+
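+ // Throughout this suite the integer passed to each *LogicExpression is an index into
+ // valueFromQueryValueArray (the encoded query-side values), not a literal to compare against.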
+ test("Basic And Test") {
+ val leftLogic = new LessThanLogicExpression("Col1", 0)
+ leftLogic.setEncoder(encoder)
+ val rightLogic = new GreaterThanLogicExpression("Col1", 1)
+ rightLogic.setEncoder(encoder)
+ val andLogic = new AndLogicExpression(leftLogic, rightLogic)
+
+ val columnToCurrentRowValueMap = new util.HashMap[String, ByteArrayComparable]()
+
+ columnToCurrentRowValueMap.put("Col1", new ByteArrayComparable(Bytes.toBytes(10)))
+ val valueFromQueryValueArray = new Array[Array[Byte]](2)
+ valueFromQueryValueArray(0) = encoder.encode(IntegerType, 15)
+ valueFromQueryValueArray(1) = encoder.encode(IntegerType, 5)
+ assert(andLogic.execute(columnToCurrentRowValueMap, valueFromQueryValueArray))
+
+ valueFromQueryValueArray(0) = encoder.encode(IntegerType, 10)
+ valueFromQueryValueArray(1) = encoder.encode(IntegerType, 5)
+ assert(!andLogic.execute(columnToCurrentRowValueMap, valueFromQueryValueArray))
+
+ valueFromQueryValueArray(0) = encoder.encode(IntegerType, 15)
+ valueFromQueryValueArray(1) = encoder.encode(IntegerType, 10)
+ assert(!andLogic.execute(columnToCurrentRowValueMap, valueFromQueryValueArray))
+
+ val expressionString = andLogic.toExpressionString
+
+ assert(expressionString.equals("( Col1 < 0 AND Col1 > 1 )"))
+
+ val builtExpression = DynamicLogicExpressionBuilder.build(expressionString, encoder)
+ valueFromQueryValueArray(0) = encoder.encode(IntegerType, 15)
+ valueFromQueryValueArray(1) = encoder.encode(IntegerType, 5)
+ assert(builtExpression.execute(columnToCurrentRowValueMap, valueFromQueryValueArray))
+
+ valueFromQueryValueArray(0) = encoder.encode(IntegerType, 10)
+ valueFromQueryValueArray(1) = encoder.encode(IntegerType, 5)
+ assert(!builtExpression.execute(columnToCurrentRowValueMap, valueFromQueryValueArray))
+
+ valueFromQueryValueArray(0) = encoder.encode(IntegerType, 15)
+ valueFromQueryValueArray(1) = encoder.encode(IntegerType, 10)
+ assert(!builtExpression.execute(columnToCurrentRowValueMap, valueFromQueryValueArray))
+
+ }
+
+ test("Basic OR Test") {
+ val leftLogic = new LessThanLogicExpression("Col1", 0)
+ leftLogic.setEncoder(encoder)
+ val rightLogic = new GreaterThanLogicExpression("Col1", 1)
+ rightLogic.setEncoder(encoder)
+ val OrLogic = new OrLogicExpression(leftLogic, rightLogic)
+
+ val columnToCurrentRowValueMap = new util.HashMap[String, ByteArrayComparable]()
+
+ columnToCurrentRowValueMap.put("Col1", new ByteArrayComparable(Bytes.toBytes(10)))
+ val valueFromQueryValueArray = new Array[Array[Byte]](2)
+ valueFromQueryValueArray(0) = encoder.encode(IntegerType, 15)
+ valueFromQueryValueArray(1) = encoder.encode(IntegerType, 5)
+ assert(OrLogic.execute(columnToCurrentRowValueMap, valueFromQueryValueArray))
+
+ valueFromQueryValueArray(0) = encoder.encode(IntegerType, 10)
+ valueFromQueryValueArray(1) = encoder.encode(IntegerType, 5)
+ assert(OrLogic.execute(columnToCurrentRowValueMap, valueFromQueryValueArray))
+
+ valueFromQueryValueArray(0) = encoder.encode(IntegerType, 15)
+ valueFromQueryValueArray(1) = encoder.encode(IntegerType, 10)
+ assert(OrLogic.execute(columnToCurrentRowValueMap, valueFromQueryValueArray))
+
+ valueFromQueryValueArray(0) = encoder.encode(IntegerType, 10)
+ valueFromQueryValueArray(1) = encoder.encode(IntegerType, 10)
+ assert(!OrLogic.execute(columnToCurrentRowValueMap, valueFromQueryValueArray))
+
+ val expressionString = OrLogic.toExpressionString
+
+ assert(expressionString.equals("( Col1 < 0 OR Col1 > 1 )"))
+
+ val builtExpression = DynamicLogicExpressionBuilder.build(expressionString, encoder)
+ valueFromQueryValueArray(0) = encoder.encode(IntegerType, 15)
+ valueFromQueryValueArray(1) = encoder.encode(IntegerType, 5)
+ assert(builtExpression.execute(columnToCurrentRowValueMap, valueFromQueryValueArray))
+
+ valueFromQueryValueArray(0) = encoder.encode(IntegerType, 10)
+ valueFromQueryValueArray(1) = encoder.encode(IntegerType, 5)
+ assert(builtExpression.execute(columnToCurrentRowValueMap, valueFromQueryValueArray))
+
+ valueFromQueryValueArray(0) = encoder.encode(IntegerType, 15)
+ valueFromQueryValueArray(1) = encoder.encode(IntegerType, 10)
+ assert(builtExpression.execute(columnToCurrentRowValueMap, valueFromQueryValueArray))
+
+ valueFromQueryValueArray(0) = encoder.encode(IntegerType, 10)
+ valueFromQueryValueArray(1) = encoder.encode(IntegerType, 10)
+ assert(!builtExpression.execute(columnToCurrentRowValueMap, valueFromQueryValueArray))
+ }
+
+ test("Basic Command Test") {
+ val greaterLogic = new GreaterThanLogicExpression("Col1", 0)
+ greaterLogic.setEncoder(encoder)
+ val greaterAndEqualLogic = new GreaterThanOrEqualLogicExpression("Col1", 0)
+ greaterAndEqualLogic.setEncoder(encoder)
+ val lessLogic = new LessThanLogicExpression("Col1", 0)
+ lessLogic.setEncoder(encoder)
+ val lessAndEqualLogic = new LessThanOrEqualLogicExpression("Col1", 0)
+ lessAndEqualLogic.setEncoder(encoder)
+ val equalLogic = new EqualLogicExpression("Col1", 0, false)
+ val notEqualLogic = new EqualLogicExpression("Col1", 0, true)
+ val passThrough = new PassThroughLogicExpression
+
+ val columnToCurrentRowValueMap = new util.HashMap[String, ByteArrayComparable]()
+ columnToCurrentRowValueMap.put("Col1", new ByteArrayComparable(Bytes.toBytes(10)))
+ val valueFromQueryValueArray = new Array[Array[Byte]](1)
+
+ // greater than
+ valueFromQueryValueArray(0) = encoder.encode(IntegerType, 10)
+ assert(!greaterLogic.execute(columnToCurrentRowValueMap, valueFromQueryValueArray))
+
+ valueFromQueryValueArray(0) = encoder.encode(IntegerType, 20)
+ assert(!greaterLogic.execute(columnToCurrentRowValueMap, valueFromQueryValueArray))
+
+ // greater than or equal
+ valueFromQueryValueArray(0) = encoder.encode(IntegerType, 5)
+ assert(greaterAndEqualLogic.execute(columnToCurrentRowValueMap, valueFromQueryValueArray))
+
+ valueFromQueryValueArray(0) = encoder.encode(IntegerType, 10)
+ assert(greaterAndEqualLogic.execute(columnToCurrentRowValueMap, valueFromQueryValueArray))
+
+ valueFromQueryValueArray(0) = encoder.encode(IntegerType, 20)
+ assert(!greaterAndEqualLogic.execute(columnToCurrentRowValueMap, valueFromQueryValueArray))
+
+ // less than
+ valueFromQueryValueArray(0) = encoder.encode(IntegerType, 10)
+ assert(!lessLogic.execute(columnToCurrentRowValueMap, valueFromQueryValueArray))
+
+ valueFromQueryValueArray(0) = encoder.encode(IntegerType, 5)
+ assert(!lessLogic.execute(columnToCurrentRowValueMap, valueFromQueryValueArray))
+
+ // less than or equal
+ valueFromQueryValueArray(0) = encoder.encode(IntegerType, 20)
+ assert(lessAndEqualLogic.execute(columnToCurrentRowValueMap, valueFromQueryValueArray))
+
+ valueFromQueryValueArray(0) = encoder.encode(IntegerType, 20)
+ assert(lessAndEqualLogic.execute(columnToCurrentRowValueMap, valueFromQueryValueArray))
+
+ valueFromQueryValueArray(0) = encoder.encode(IntegerType, 10)
+ assert(lessAndEqualLogic.execute(columnToCurrentRowValueMap, valueFromQueryValueArray))
+
+ // equal to
+ valueFromQueryValueArray(0) = Bytes.toBytes(10)
+ assert(equalLogic.execute(columnToCurrentRowValueMap, valueFromQueryValueArray))
+
+ valueFromQueryValueArray(0) = Bytes.toBytes(5)
+ assert(!equalLogic.execute(columnToCurrentRowValueMap, valueFromQueryValueArray))
+
+ // not equal to
+ valueFromQueryValueArray(0) = Bytes.toBytes(10)
+ assert(!notEqualLogic.execute(columnToCurrentRowValueMap, valueFromQueryValueArray))
+
+ valueFromQueryValueArray(0) = Bytes.toBytes(5)
+ assert(notEqualLogic.execute(columnToCurrentRowValueMap, valueFromQueryValueArray))
+
+ // pass through
+ valueFromQueryValueArray(0) = Bytes.toBytes(10)
+ assert(passThrough.execute(columnToCurrentRowValueMap, valueFromQueryValueArray))
+
+ valueFromQueryValueArray(0) = Bytes.toBytes(5)
+ assert(passThrough.execute(columnToCurrentRowValueMap, valueFromQueryValueArray))
+ }
+
+ test("Double Type") {
+ val leftLogic = new LessThanLogicExpression("Col1", 0)
+ leftLogic.setEncoder(encoder)
+ val rightLogic = new GreaterThanLogicExpression("Col1", 1)
+ rightLogic.setEncoder(encoder)
+ val andLogic = new AndLogicExpression(leftLogic, rightLogic)
+
+ val columnToCurrentRowValueMap = new util.HashMap[String, ByteArrayComparable]()
+
+ columnToCurrentRowValueMap.put("Col1", new ByteArrayComparable(Bytes.toBytes(-4.0d)))
+ val valueFromQueryValueArray = new Array[Array[Byte]](2)
+ valueFromQueryValueArray(0) = encoder.encode(DoubleType, 15.0d)
+ valueFromQueryValueArray(1) = encoder.encode(DoubleType, -5.0d)
+ assert(andLogic.execute(columnToCurrentRowValueMap, valueFromQueryValueArray))
+
+ valueFromQueryValueArray(0) = encoder.encode(DoubleType, 10.0d)
+ valueFromQueryValueArray(1) = encoder.encode(DoubleType, -1.0d)
+ assert(!andLogic.execute(columnToCurrentRowValueMap, valueFromQueryValueArray))
+
+ valueFromQueryValueArray(0) = encoder.encode(DoubleType, -10.0d)
+ valueFromQueryValueArray(1) = encoder.encode(DoubleType, -20.0d)
+ assert(!andLogic.execute(columnToCurrentRowValueMap, valueFromQueryValueArray))
+
+ val expressionString = andLogic.toExpressionString
+ // Note that here 0 and 1 are indices into the value array, not literal values.
+ assert(expressionString.equals("( Col1 < 0 AND Col1 > 1 )"))
+
+ val builtExpression = DynamicLogicExpressionBuilder.build(expressionString, encoder)
+ valueFromQueryValueArray(0) = encoder.encode(DoubleType, 15.0d)
+ valueFromQueryValueArray(1) = encoder.encode(DoubleType, -5.0d)
+ assert(builtExpression.execute(columnToCurrentRowValueMap, valueFromQueryValueArray))
+
+ valueFromQueryValueArray(0) = encoder.encode(DoubleType, 10.0d)
+ valueFromQueryValueArray(1) = encoder.encode(DoubleType, -1.0d)
+ assert(!builtExpression.execute(columnToCurrentRowValueMap, valueFromQueryValueArray))
+
+ valueFromQueryValueArray(0) = encoder.encode(DoubleType, -10.0d)
+ valueFromQueryValueArray(1) = encoder.encode(DoubleType, -20.0d)
+ assert(!builtExpression.execute(columnToCurrentRowValueMap, valueFromQueryValueArray))
+ }
+
+ test("Float Type") {
+ val leftLogic = new LessThanLogicExpression("Col1", 0)
+ leftLogic.setEncoder(encoder)
+ val rightLogic = new GreaterThanLogicExpression("Col1", 1)
+ rightLogic.setEncoder(encoder)
+ val andLogic = new AndLogicExpression(leftLogic, rightLogic)
+
+ val columnToCurrentRowValueMap = new util.HashMap[String, ByteArrayComparable]()
+
+ columnToCurrentRowValueMap.put("Col1", new ByteArrayComparable(Bytes.toBytes(-4.0f)))
+ val valueFromQueryValueArray = new Array[Array[Byte]](2)
+ valueFromQueryValueArray(0) = encoder.encode(FloatType, 15.0f)
+ valueFromQueryValueArray(1) = encoder.encode(FloatType, -5.0f)
+ assert(andLogic.execute(columnToCurrentRowValueMap, valueFromQueryValueArray))
+
+ valueFromQueryValueArray(0) = encoder.encode(FloatType, 10.0f)
+ valueFromQueryValueArray(1) = encoder.encode(FloatType, -1.0f)
+ assert(!andLogic.execute(columnToCurrentRowValueMap, valueFromQueryValueArray))
+
+ valueFromQueryValueArray(0) = encoder.encode(FloatType, -10.0f)
+ valueFromQueryValueArray(1) = encoder.encode(FloatType, -20.0f)
+ assert(!andLogic.execute(columnToCurrentRowValueMap, valueFromQueryValueArray))
+
+ val expressionString = andLogic.toExpressionString
+ // Note that here 0 and 1 are indices into the value array, not literal values.
+ assert(expressionString.equals("( Col1 < 0 AND Col1 > 1 )"))
+
+ val builtExpression = DynamicLogicExpressionBuilder.build(expressionString, encoder)
+ valueFromQueryValueArray(0) = encoder.encode(FloatType, 15.0f)
+ valueFromQueryValueArray(1) = encoder.encode(FloatType, -5.0f)
+ assert(builtExpression.execute(columnToCurrentRowValueMap, valueFromQueryValueArray))
+
+ valueFromQueryValueArray(0) = encoder.encode(FloatType, 10.0f)
+ valueFromQueryValueArray(1) = encoder.encode(FloatType, -1.0f)
+ assert(!builtExpression.execute(columnToCurrentRowValueMap, valueFromQueryValueArray))
+
+ valueFromQueryValueArray(0) = encoder.encode(FloatType, -10.0f)
+ valueFromQueryValueArray(1) = encoder.encode(FloatType, -20.0f)
+ assert(!builtExpression.execute(columnToCurrentRowValueMap, valueFromQueryValueArray))
+ }
+
+ test("Long Type") {
+ val greaterLogic = new GreaterThanLogicExpression("Col1", 0)
+ greaterLogic.setEncoder(encoder)
+ val greaterAndEqualLogic = new GreaterThanOrEqualLogicExpression("Col1", 0)
+ greaterAndEqualLogic.setEncoder(encoder)
+ val lessLogic = new LessThanLogicExpression("Col1", 0)
+ lessLogic.setEncoder(encoder)
+ val lessAndEqualLogic = new LessThanOrEqualLogicExpression("Col1", 0)
+ lessAndEqualLogic.setEncoder(encoder)
+ val equalLogic = new EqualLogicExpression("Col1", 0, false)
+ val notEqualLogic = new EqualLogicExpression("Col1", 0, true)
+
+ val columnToCurrentRowValueMap = new util.HashMap[String, ByteArrayComparable]()
+ columnToCurrentRowValueMap.put("Col1", new ByteArrayComparable(Bytes.toBytes(10L)))
+ val valueFromQueryValueArray = new Array[Array[Byte]](1)
+
+ // greater than
+ valueFromQueryValueArray(0) = encoder.encode(LongType, 10L)
+ assert(!greaterLogic.execute(columnToCurrentRowValueMap, valueFromQueryValueArray))
+
+ valueFromQueryValueArray(0) = encoder.encode(LongType, 20L)
+ assert(!greaterLogic.execute(columnToCurrentRowValueMap, valueFromQueryValueArray))
+
+ // greater than or equal
+ valueFromQueryValueArray(0) = encoder.encode(LongType, 5L)
+ assert(greaterAndEqualLogic.execute(columnToCurrentRowValueMap, valueFromQueryValueArray))
+
+ valueFromQueryValueArray(0) = encoder.encode(LongType, 10L)
+ assert(greaterAndEqualLogic.execute(columnToCurrentRowValueMap, valueFromQueryValueArray))
+
+ valueFromQueryValueArray(0) = encoder.encode(LongType, 20L)
+ assert(!greaterAndEqualLogic.execute(columnToCurrentRowValueMap, valueFromQueryValueArray))
+
+ // less than
+ valueFromQueryValueArray(0) = encoder.encode(LongType, 10L)
+ assert(!lessLogic.execute(columnToCurrentRowValueMap, valueFromQueryValueArray))
+
+ valueFromQueryValueArray(0) = encoder.encode(LongType, 5L)
+ assert(!lessLogic.execute(columnToCurrentRowValueMap, valueFromQueryValueArray))
+
+ // less than or equal
+ valueFromQueryValueArray(0) = encoder.encode(LongType, 20L)
+ assert(lessAndEqualLogic.execute(columnToCurrentRowValueMap, valueFromQueryValueArray))
+
+ valueFromQueryValueArray(0) = encoder.encode(LongType, 20L)
+ assert(lessAndEqualLogic.execute(columnToCurrentRowValueMap, valueFromQueryValueArray))
+
+ valueFromQueryValueArray(0) = encoder.encode(LongType, 10L)
+ assert(lessAndEqualLogic.execute(columnToCurrentRowValueMap, valueFromQueryValueArray))
+
+ // equal to
+ valueFromQueryValueArray(0) = Bytes.toBytes(10L)
+ assert(equalLogic.execute(columnToCurrentRowValueMap, valueFromQueryValueArray))
+
+ valueFromQueryValueArray(0) = Bytes.toBytes(5L)
+ assert(!equalLogic.execute(columnToCurrentRowValueMap, valueFromQueryValueArray))
+
+ // not equal to
+ valueFromQueryValueArray(0) = Bytes.toBytes(10L)
+ assert(!notEqualLogic.execute(columnToCurrentRowValueMap, valueFromQueryValueArray))
+
+ valueFromQueryValueArray(0) = Bytes.toBytes(5L)
+ assert(notEqualLogic.execute(columnToCurrentRowValueMap, valueFromQueryValueArray))
+ }
+
+ test("String Type") {
+ val leftLogic = new LessThanLogicExpression("Col1", 0)
+ leftLogic.setEncoder(encoder)
+ val rightLogic = new GreaterThanLogicExpression("Col1", 1)
+ rightLogic.setEncoder(encoder)
+ val andLogic = new AndLogicExpression(leftLogic, rightLogic)
+
+ val columnToCurrentRowValueMap = new util.HashMap[String, ByteArrayComparable]()
+
+ columnToCurrentRowValueMap.put("Col1", new ByteArrayComparable(Bytes.toBytes("row005")))
+ val valueFromQueryValueArray = new Array[Array[Byte]](2)
+ valueFromQueryValueArray(0) = encoder.encode(StringType, "row015")
+ valueFromQueryValueArray(1) = encoder.encode(StringType, "row000")
+ assert(andLogic.execute(columnToCurrentRowValueMap, valueFromQueryValueArray))
+
+ valueFromQueryValueArray(0) = encoder.encode(StringType, "row004")
+ valueFromQueryValueArray(1) = encoder.encode(StringType, "row000")
+ assert(!andLogic.execute(columnToCurrentRowValueMap, valueFromQueryValueArray))
+
+ valueFromQueryValueArray(0) = encoder.encode(StringType, "row020")
+ valueFromQueryValueArray(1) = encoder.encode(StringType, "row010")
+ assert(!andLogic.execute(columnToCurrentRowValueMap, valueFromQueryValueArray))
+
+ val expressionString = andLogic.toExpressionString
+ // Note that here 0 and 1 are indices into the value array, not literal values.
+ assert(expressionString.equals("( Col1 < 0 AND Col1 > 1 )"))
+
+ val builtExpression = DynamicLogicExpressionBuilder.build(expressionString, encoder)
+ valueFromQueryValueArray(0) = encoder.encode(StringType, "row015")
+ valueFromQueryValueArray(1) = encoder.encode(StringType, "row000")
+ assert(builtExpression.execute(columnToCurrentRowValueMap, valueFromQueryValueArray))
+
+ valueFromQueryValueArray(0) = encoder.encode(StringType, "row004")
+ valueFromQueryValueArray(1) = encoder.encode(StringType, "row000")
+ assert(!builtExpression.execute(columnToCurrentRowValueMap, valueFromQueryValueArray))
+
+ valueFromQueryValueArray(0) = encoder.encode(StringType, "row020")
+ valueFromQueryValueArray(1) = encoder.encode(StringType, "row010")
+ assert(!builtExpression.execute(columnToCurrentRowValueMap, valueFromQueryValueArray))
+ }
+
+ test("Boolean Type") {
+ val leftLogic = new LessThanLogicExpression("Col1", 0)
+ leftLogic.setEncoder(encoder)
+ val rightLogic = new GreaterThanLogicExpression("Col1", 1)
+ rightLogic.setEncoder(encoder)
+
+ val columnToCurrentRowValueMap = new util.HashMap[String, ByteArrayComparable]()
+
+ columnToCurrentRowValueMap.put("Col1", new ByteArrayComparable(Bytes.toBytes(false)))
+ val valueFromQueryValueArray = new Array[Array[Byte]](2)
+ valueFromQueryValueArray(0) = encoder.encode(BooleanType, true)
+ valueFromQueryValueArray(1) = encoder.encode(BooleanType, false)
+ assert(leftLogic.execute(columnToCurrentRowValueMap, valueFromQueryValueArray))
+ assert(!rightLogic.execute(columnToCurrentRowValueMap, valueFromQueryValueArray))
+ }
+}
diff --git a/spark4/hbase-spark4/src/test/scala/org/apache/hadoop/hbase/spark/HBaseCatalogSuite.scala b/spark4/hbase-spark4/src/test/scala/org/apache/hadoop/hbase/spark/HBaseCatalogSuite.scala
new file mode 100644
index 00000000..f0a4d238
--- /dev/null
+++ b/spark4/hbase-spark4/src/test/scala/org/apache/hadoop/hbase/spark/HBaseCatalogSuite.scala
@@ -0,0 +1,116 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.hadoop.hbase.spark
+
+import org.apache.hadoop.hbase.spark.datasources.{DataTypeParserWrapper, DoubleSerDes, HBaseTableCatalog}
+import org.apache.hadoop.hbase.util.Bytes
+import org.apache.spark.sql.types._
+import org.scalatest.funsuite.AnyFunSuite
+import org.scalatest.{BeforeAndAfterAll, BeforeAndAfterEach}
+
+class HBaseCatalogSuite
+ extends AnyFunSuite
+ with BeforeAndAfterEach
+ with BeforeAndAfterAll
+ with Logging {
+
+ val map = s"""MAP>"""
+ val array = s"""array>"""
+ val arrayMap = s"""MAp>"""
+ val catalog = s"""{
+ |"table":{"namespace":"default", "name":"htable"},
+ |"rowkey":"key1:key2",
+ |"columns":{
+ |"col1":{"cf":"rowkey", "col":"key1", "type":"string"},
+ |"col2":{"cf":"rowkey", "col":"key2", "type":"double"},
+ |"col3":{"cf":"cf1", "col":"col2", "type":"binary"},
+ |"col4":{"cf":"cf1", "col":"col3", "type":"timestamp"},
+ |"col5":{"cf":"cf1", "col":"col4", "type":"double", "serdes":"${classOf[
+ DoubleSerDes].getName}"},
+ |"col6":{"cf":"cf1", "col":"col5", "type":"$map"},
+ |"col7":{"cf":"cf1", "col":"col6", "type":"$array"},
+ |"col8":{"cf":"cf1", "col":"col7", "type":"$arrayMap"},
+ |"col9":{"cf":"cf1", "col":"col8", "type":"date"},
+ |"col10":{"cf":"cf1", "col":"col9", "type":"timestamp"}
+ |}
+ |}""".stripMargin
+ val parameters = Map(HBaseTableCatalog.tableCatalog -> catalog)
+ val t = HBaseTableCatalog(parameters)
+
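+ // checkDataType registers one test per type string, asserting that the parsed Spark
+ // DataType matches the type recorded for the corresponding catalog field.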
+ def checkDataType(dataTypeString: String, expectedDataType: DataType): Unit = {
+ test(s"parse ${dataTypeString.replace("\n", "")}") {
+ assert(DataTypeParserWrapper.parse(dataTypeString) == expectedDataType)
+ }
+ }
+ test("basic") {
+ assert(t.getField("col1").isRowKey == true)
+ assert(t.getPrimaryKey == "key1")
+ assert(t.getField("col3").dt == BinaryType)
+ assert(t.getField("col4").dt == TimestampType)
+ assert(t.getField("col5").dt == DoubleType)
+ assert(t.getField("col5").serdes != None)
+ assert(t.getField("col4").serdes == None)
+ assert(t.getField("col1").isRowKey)
+ assert(t.getField("col2").isRowKey)
+ assert(!t.getField("col3").isRowKey)
+ assert(t.getField("col2").length == Bytes.SIZEOF_DOUBLE)
+ assert(t.getField("col1").length == -1)
+ assert(t.getField("col8").length == -1)
+ assert(t.getField("col9").dt == DateType)
+ assert(t.getField("col10").dt == TimestampType)
+ }
+
+ checkDataType(map, t.getField("col6").dt)
+
+ checkDataType(array, t.getField("col7").dt)
+
+ checkDataType(arrayMap, t.getField("col8").dt)
+
+ test("convert") {
+ val m = Map(
+ "hbase.columns.mapping" ->
+ "KEY_FIELD STRING :key, A_FIELD STRING c:a, B_FIELD DOUBLE c:b, C_FIELD BINARY c:c,",
+ "hbase.table" -> "NAMESPACE:TABLE")
+ val map = HBaseTableCatalog.convert(m)
+ val json = map.get(HBaseTableCatalog.tableCatalog).get
+ val parameters = Map(HBaseTableCatalog.tableCatalog -> json)
+ val t = HBaseTableCatalog(parameters)
+ assert(t.namespace == "NAMESPACE")
+ assert(t.name == "TABLE")
+ assert(t.getField("KEY_FIELD").isRowKey)
+ assert(DataTypeParserWrapper.parse("STRING") == t.getField("A_FIELD").dt)
+ assert(!t.getField("A_FIELD").isRowKey)
+ assert(DataTypeParserWrapper.parse("DOUBLE") == t.getField("B_FIELD").dt)
+ assert(DataTypeParserWrapper.parse("BINARY") == t.getField("C_FIELD").dt)
+ }
+
+ test("compatibility") {
+ val m = Map(
+ "hbase.columns.mapping" ->
+ "KEY_FIELD STRING :key, A_FIELD STRING c:a, B_FIELD DOUBLE c:b, C_FIELD BINARY c:c,",
+ "hbase.table" -> "t1")
+ val t = HBaseTableCatalog(m)
+ assert(t.namespace == "default")
+ assert(t.name == "t1")
+ assert(t.getField("KEY_FIELD").isRowKey)
+ assert(DataTypeParserWrapper.parse("STRING") == t.getField("A_FIELD").dt)
+ assert(!t.getField("A_FIELD").isRowKey)
+ assert(DataTypeParserWrapper.parse("DOUBLE") == t.getField("B_FIELD").dt)
+ assert(DataTypeParserWrapper.parse("BINARY") == t.getField("C_FIELD").dt)
+ }
+}
diff --git a/spark4/hbase-spark4/src/test/scala/org/apache/hadoop/hbase/spark/HBaseConnectionCacheSuite.scala b/spark4/hbase-spark4/src/test/scala/org/apache/hadoop/hbase/spark/HBaseConnectionCacheSuite.scala
new file mode 100644
index 00000000..309880f7
--- /dev/null
+++ b/spark4/hbase-spark4/src/test/scala/org/apache/hadoop/hbase/spark/HBaseConnectionCacheSuite.scala
@@ -0,0 +1,245 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.hadoop.hbase.spark
+
+import java.util.concurrent.ExecutorService
+import org.apache.hadoop.conf.Configuration
+import org.apache.hadoop.hbase.TableName
+import org.apache.hadoop.hbase.client.{Admin, BufferedMutator, BufferedMutatorParams, Connection, RegionLocator, Table, TableBuilder}
+
+import org.scalatest.funsuite.AnyFunSuite
+
+import scala.util.Random
+
+case class HBaseConnectionKeyMocker(confId: Int) extends HBaseConnectionKey(null) {
+ override def hashCode: Int = {
+ confId
+ }
+
+ override def equals(obj: Any): Boolean = {
+ if (!obj.isInstanceOf[HBaseConnectionKeyMocker])
+ false
+ else
+ confId == obj.asInstanceOf[HBaseConnectionKeyMocker].confId
+ }
+}
+
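+ // Minimal Connection stub: every accessor returns null and only the close() state is
+ // tracked, so the tests can exercise HBaseConnectionCache without a real cluster.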
+class ConnectionMocker extends Connection {
+ var isClosed: Boolean = false
+
+ def getRegionLocator(tableName: TableName): RegionLocator = null
+ def getConfiguration: Configuration = null
+ override def getTable(tableName: TableName): Table = null
+ override def getTable(tableName: TableName, pool: ExecutorService): Table = null
+ def getBufferedMutator(params: BufferedMutatorParams): BufferedMutator = null
+ def getBufferedMutator(tableName: TableName): BufferedMutator = null
+ def getAdmin: Admin = null
+ def getTableBuilder(tableName: TableName, pool: ExecutorService): TableBuilder = null
+
+ def close(): Unit = {
+ if (isClosed)
+ throw new IllegalStateException()
+ isClosed = true
+ }
+
+ def isAborted: Boolean = true
+ def abort(why: String, e: Throwable) = {}
+
+ /* Declared without the override modifier so this also compiles against HBase 2.1. */
+ /* override */
+ def clearRegionLocationCache(): Unit = {}
+}
+
+class HBaseConnectionCacheSuite extends AnyFunSuite with Logging {
+ /*
+  * These tests must run sequentially because they share a single background
+  * thread and resource.
+  *
+  * There is no straightforward way to tell FunSuite to enforce this, so the
+  * individual cases are written as plain functions and called sequentially
+  * from a single test case.
+  */
+ test("all test cases") {
+ testBasic()
+ testWithPressureWithoutClose()
+ testWithPressureWithClose()
+ }
+
+ def cleanEnv() {
+ HBaseConnectionCache.connectionMap.synchronized {
+ HBaseConnectionCache.connectionMap.clear()
+ HBaseConnectionCache.cacheStat.numActiveConnections = 0
+ HBaseConnectionCache.cacheStat.numActualConnectionsCreated = 0
+ HBaseConnectionCache.cacheStat.numTotalRequests = 0
+ }
+ }
+
+ def testBasic() {
+ cleanEnv()
+ HBaseConnectionCache.setTimeout(1 * 1000)
+
+ val connKeyMocker1 = new HBaseConnectionKeyMocker(1)
+ val connKeyMocker1a = new HBaseConnectionKeyMocker(1)
+ val connKeyMocker2 = new HBaseConnectionKeyMocker(2)
+
+ val c1 = HBaseConnectionCache
+ .getConnection(connKeyMocker1, new ConnectionMocker)
+
+ assert(HBaseConnectionCache.connectionMap.size == 1)
+ assert(HBaseConnectionCache.getStat.numTotalRequests == 1)
+ assert(HBaseConnectionCache.getStat.numActualConnectionsCreated == 1)
+ assert(HBaseConnectionCache.getStat.numActiveConnections == 1)
+
+ val c1a = HBaseConnectionCache
+ .getConnection(connKeyMocker1a, new ConnectionMocker)
+
+ HBaseConnectionCache.connectionMap.synchronized {
+ assert(HBaseConnectionCache.connectionMap.size == 1)
+ assert(HBaseConnectionCache.getStat.numTotalRequests == 2)
+ assert(HBaseConnectionCache.getStat.numActualConnectionsCreated == 1)
+ assert(HBaseConnectionCache.getStat.numActiveConnections == 1)
+ }
+
+ val c2 = HBaseConnectionCache
+ .getConnection(connKeyMocker2, new ConnectionMocker)
+
+ HBaseConnectionCache.connectionMap.synchronized {
+ assert(HBaseConnectionCache.connectionMap.size == 2)
+ assert(HBaseConnectionCache.getStat.numTotalRequests == 3)
+ assert(HBaseConnectionCache.getStat.numActualConnectionsCreated == 2)
+ assert(HBaseConnectionCache.getStat.numActiveConnections == 2)
+ }
+
+ c1.close()
+ HBaseConnectionCache.connectionMap.synchronized {
+ assert(HBaseConnectionCache.connectionMap.size == 2)
+ assert(HBaseConnectionCache.getStat.numActiveConnections == 2)
+ }
+
+ c1a.close()
+ HBaseConnectionCache.connectionMap.synchronized {
+ assert(HBaseConnectionCache.connectionMap.size == 2)
+ assert(HBaseConnectionCache.getStat.numActiveConnections == 2)
+ }
+
+ Thread.sleep(3 * 1000) // Give the housekeeping thread enough time to evict the timed-out connections
+ HBaseConnectionCache.connectionMap.synchronized {
+ assert(HBaseConnectionCache.connectionMap.size == 1)
+ assert(
+ HBaseConnectionCache.connectionMap.iterator
+ .next()
+ ._1
+ .asInstanceOf[HBaseConnectionKeyMocker]
+ .confId == 2)
+ assert(HBaseConnectionCache.getStat.numActiveConnections == 1)
+ }
+
+ c2.close()
+ }
+
+ def testWithPressureWithoutClose() {
+ cleanEnv()
+
+ class TestThread extends Runnable {
+ override def run() {
+ for (i <- 0 to 999) {
+ val c = HBaseConnectionCache.getConnection(
+ new HBaseConnectionKeyMocker(Random.nextInt(10)),
+ new ConnectionMocker)
+ }
+ }
+ }
+
+ HBaseConnectionCache.setTimeout(500)
+ val threads: Array[Thread] = new Array[Thread](100)
+ for (i <- 0 to 99) {
+ threads.update(i, new Thread(new TestThread()))
+ threads(i).run()
+ }
+ try {
+ threads.foreach { x => x.join() }
+ } catch {
+ case e: InterruptedException => println(e.getMessage)
+ }
+
+ Thread.sleep(1000)
+ HBaseConnectionCache.connectionMap.synchronized {
+ assert(HBaseConnectionCache.connectionMap.size == 10)
+ assert(HBaseConnectionCache.getStat.numTotalRequests == 100 * 1000)
+ assert(HBaseConnectionCache.getStat.numActualConnectionsCreated == 10)
+ assert(HBaseConnectionCache.getStat.numActiveConnections == 10)
+
+ var totalRc: Int = 0
+ HBaseConnectionCache.connectionMap.foreach { x => totalRc += x._2.refCount }
+ assert(totalRc == 100 * 1000)
+ HBaseConnectionCache.connectionMap.foreach {
+ x =>
+ {
+ x._2.refCount = 0
+ x._2.timestamp = System.currentTimeMillis() - 1000
+ }
+ }
+ }
+ Thread.sleep(1000)
+ assert(HBaseConnectionCache.connectionMap.size == 0)
+ assert(HBaseConnectionCache.getStat.numActualConnectionsCreated == 10)
+ assert(HBaseConnectionCache.getStat.numActiveConnections == 0)
+ }
+
+ def testWithPressureWithClose() {
+ cleanEnv()
+
+ class TestThread extends Runnable {
+ override def run() {
+ for (i <- 0 to 999) {
+ val c = HBaseConnectionCache.getConnection(
+ new HBaseConnectionKeyMocker(Random.nextInt(10)),
+ new ConnectionMocker)
+ Thread.`yield`()
+ c.close()
+ }
+ }
+ }
+
+ HBaseConnectionCache.setTimeout(3 * 1000)
+ val threads: Array[Thread] = new Array[Thread](100)
+ for (i <- threads.indices) {
+ threads.update(i, new Thread(new TestThread()))
+ threads(i).run()
+ }
+ try {
+ threads.foreach { x => x.join() }
+ } catch {
+ case e: InterruptedException => println(e.getMessage)
+ }
+
+ HBaseConnectionCache.connectionMap.synchronized {
+ assert(HBaseConnectionCache.connectionMap.size == 10)
+ assert(HBaseConnectionCache.getStat.numTotalRequests == 100 * 1000)
+ assert(HBaseConnectionCache.getStat.numActualConnectionsCreated == 10)
+ assert(HBaseConnectionCache.getStat.numActiveConnections == 10)
+ }
+
+ Thread.sleep(6 * 1000)
+ HBaseConnectionCache.connectionMap.synchronized {
+ assert(HBaseConnectionCache.connectionMap.size == 0)
+ assert(HBaseConnectionCache.getStat.numActualConnectionsCreated == 10)
+ assert(HBaseConnectionCache.getStat.numActiveConnections == 0)
+ }
+ }
+}
diff --git a/spark4/hbase-spark4/src/test/scala/org/apache/hadoop/hbase/spark/HBaseContextSuite.scala b/spark4/hbase-spark4/src/test/scala/org/apache/hadoop/hbase/spark/HBaseContextSuite.scala
new file mode 100644
index 00000000..a45988be
--- /dev/null
+++ b/spark4/hbase-spark4/src/test/scala/org/apache/hadoop/hbase/spark/HBaseContextSuite.scala
@@ -0,0 +1,400 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.hadoop.hbase.spark
+
+import org.apache.hadoop.hbase.{CellUtil, HBaseTestingUtility, TableName}
+import org.apache.hadoop.hbase.client._
+import org.apache.hadoop.hbase.filter.FirstKeyOnlyFilter
+import org.apache.hadoop.hbase.util.Bytes
+import org.apache.spark.{SparkContext, SparkException}
+import org.scalatest.funsuite.AnyFunSuite
+import org.scalatest.{BeforeAndAfterAll, BeforeAndAfterEach}
+
+class HBaseContextSuite
+ extends AnyFunSuite
+ with BeforeAndAfterEach
+ with BeforeAndAfterAll
+ with Logging {
+
+ @transient var sc: SparkContext = null
+ var hbaseContext: HBaseContext = null
+ var TEST_UTIL = new HBaseTestingUtility
+
+ val tableName = "t1"
+ val columnFamily = "c"
+
+ override def beforeAll() {
+ TEST_UTIL.startMiniCluster()
+ logInfo(" - minicluster started")
+
+ try {
+ TEST_UTIL.deleteTable(TableName.valueOf(tableName))
+ } catch {
+ case e: Exception =>
+ logInfo(" - no table " + tableName + " found")
+ }
+ logInfo(" - creating table " + tableName)
+ TEST_UTIL.createTable(TableName.valueOf(tableName), Bytes.toBytes(columnFamily))
+ logInfo(" - created table")
+
+ val envMap = Map[String, String](("Xmx", "512m"))
+
+ sc = new SparkContext("local", "test", null, Nil, envMap)
+
+ val config = TEST_UTIL.getConfiguration
+ hbaseContext = new HBaseContext(sc, config)
+ }
+
+ override def afterAll() {
+ logInfo("shuting down minicluster")
+ TEST_UTIL.shutdownMiniCluster()
+ logInfo(" - minicluster shut down")
+ TEST_UTIL.cleanupTestDir()
+ sc.stop()
+ }
+
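+ // The tests below drive HBaseContext bulk operations (bulkPut, bulkDelete, bulkGet,
+ // hbaseRDD scans) against table "t1" on the mini cluster started in beforeAll.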
+ test("bulkput to test HBase client") {
+ val config = TEST_UTIL.getConfiguration
+ val rdd = sc.parallelize(
+ Array[(Array[Byte], Array[(Array[Byte], Array[Byte], Array[Byte])])](
+ (
+ Bytes.toBytes("1"),
+ Array((Bytes.toBytes(columnFamily), Bytes.toBytes("a"), Bytes.toBytes("foo1")))),
+ (
+ Bytes.toBytes("2"),
+ Array((Bytes.toBytes(columnFamily), Bytes.toBytes("b"), Bytes.toBytes("foo2")))),
+ (
+ Bytes.toBytes("3"),
+ Array((Bytes.toBytes(columnFamily), Bytes.toBytes("c"), Bytes.toBytes("foo3")))),
+ (
+ Bytes.toBytes("4"),
+ Array((Bytes.toBytes(columnFamily), Bytes.toBytes("d"), Bytes.toBytes("foo")))),
+ (
+ Bytes.toBytes("5"),
+ Array((Bytes.toBytes(columnFamily), Bytes.toBytes("e"), Bytes.toBytes("bar"))))))
+
+ hbaseContext.bulkPut[(Array[Byte], Array[(Array[Byte], Array[Byte], Array[Byte])])](
+ rdd,
+ TableName.valueOf(tableName),
+ (putRecord) => {
+ val put = new Put(putRecord._1)
+ putRecord._2.foreach((putValue) => put.addColumn(putValue._1, putValue._2, putValue._3))
+ put
+ })
+
+ val connection = ConnectionFactory.createConnection(config)
+ val table = connection.getTable(TableName.valueOf("t1"))
+
+ try {
+ val foo1 = Bytes.toString(
+ CellUtil.cloneValue(
+ table
+ .get(new Get(Bytes.toBytes("1")))
+ .getColumnLatestCell(Bytes.toBytes(columnFamily), Bytes.toBytes("a"))))
+ assert(foo1 == "foo1")
+
+ val foo2 = Bytes.toString(
+ CellUtil.cloneValue(
+ table
+ .get(new Get(Bytes.toBytes("2")))
+ .getColumnLatestCell(Bytes.toBytes(columnFamily), Bytes.toBytes("b"))))
+ assert(foo2 == "foo2")
+
+ val foo3 = Bytes.toString(
+ CellUtil.cloneValue(
+ table
+ .get(new Get(Bytes.toBytes("3")))
+ .getColumnLatestCell(Bytes.toBytes(columnFamily), Bytes.toBytes("c"))))
+ assert(foo3 == "foo3")
+
+ val foo4 = Bytes.toString(
+ CellUtil.cloneValue(
+ table
+ .get(new Get(Bytes.toBytes("4")))
+ .getColumnLatestCell(Bytes.toBytes(columnFamily), Bytes.toBytes("d"))))
+ assert(foo4 == "foo")
+
+ val foo5 = Bytes.toString(
+ CellUtil.cloneValue(
+ table
+ .get(new Get(Bytes.toBytes("5")))
+ .getColumnLatestCell(Bytes.toBytes(columnFamily), Bytes.toBytes("e"))))
+ assert(foo5 == "bar")
+
+ } finally {
+ table.close()
+ connection.close()
+ }
+ }
+
+ test("bulkDelete to test HBase client") {
+ val config = TEST_UTIL.getConfiguration
+ val connection = ConnectionFactory.createConnection(config)
+ val table = connection.getTable(TableName.valueOf("t1"))
+
+ try {
+ var put = new Put(Bytes.toBytes("delete1"))
+ put.addColumn(Bytes.toBytes(columnFamily), Bytes.toBytes("a"), Bytes.toBytes("foo1"))
+ table.put(put)
+ put = new Put(Bytes.toBytes("delete2"))
+ put.addColumn(Bytes.toBytes(columnFamily), Bytes.toBytes("a"), Bytes.toBytes("foo2"))
+ table.put(put)
+ put = new Put(Bytes.toBytes("delete3"))
+ put.addColumn(Bytes.toBytes(columnFamily), Bytes.toBytes("a"), Bytes.toBytes("foo3"))
+ table.put(put)
+
+ val rdd = sc.parallelize(Array[Array[Byte]](Bytes.toBytes("delete1"), Bytes.toBytes("delete3")))
+
+ hbaseContext.bulkDelete[Array[Byte]](
+ rdd,
+ TableName.valueOf(tableName),
+ putRecord => new Delete(putRecord),
+ 4)
+
+ assert(
+ table
+ .get(new Get(Bytes.toBytes("delete1")))
+ .getColumnLatestCell(Bytes.toBytes(columnFamily), Bytes.toBytes("a")) == null)
+ assert(
+ table
+ .get(new Get(Bytes.toBytes("delete3")))
+ .getColumnLatestCell(Bytes.toBytes(columnFamily), Bytes.toBytes("a")) == null)
+ assert(
+ Bytes
+ .toString(
+ CellUtil.cloneValue(table
+ .get(new Get(Bytes.toBytes("delete2")))
+ .getColumnLatestCell(Bytes.toBytes(columnFamily), Bytes.toBytes("a"))))
+ .equals("foo2"))
+ } finally {
+ table.close()
+ connection.close()
+ }
+ }
+
+ test("bulkGet to test HBase client") {
+ val config = TEST_UTIL.getConfiguration
+ val connection = ConnectionFactory.createConnection(config)
+ val table = connection.getTable(TableName.valueOf("t1"))
+
+ try {
+ var put = new Put(Bytes.toBytes("get1"))
+ put.addColumn(Bytes.toBytes(columnFamily), Bytes.toBytes("a"), Bytes.toBytes("foo1"))
+ table.put(put)
+ put = new Put(Bytes.toBytes("get2"))
+ put.addColumn(Bytes.toBytes(columnFamily), Bytes.toBytes("a"), Bytes.toBytes("foo2"))
+ table.put(put)
+ put = new Put(Bytes.toBytes("get3"))
+ put.addColumn(Bytes.toBytes(columnFamily), Bytes.toBytes("a"), Bytes.toBytes("foo3"))
+ table.put(put)
+ } finally {
+ table.close()
+ connection.close()
+ }
+ val rdd = sc.parallelize(
+ Array[Array[Byte]](
+ Bytes.toBytes("get1"),
+ Bytes.toBytes("get2"),
+ Bytes.toBytes("get3"),
+ Bytes.toBytes("get4")))
+
+ val getRdd = hbaseContext.bulkGet[Array[Byte], String](
+ TableName.valueOf(tableName),
+ 2,
+ rdd,
+ record => {
+ new Get(record)
+ },
+ (result: Result) => {
+ if (result.listCells() != null) {
+ val it = result.listCells().iterator()
+ val B = new StringBuilder
+
+ B.append(Bytes.toString(result.getRow) + ":")
+
+ while (it.hasNext) {
+ val cell = it.next()
+ val q = Bytes.toString(CellUtil.cloneQualifier(cell))
+ if (q.equals("counter")) {
+ B.append("(" + q + "," + Bytes.toLong(CellUtil.cloneValue(cell)) + ")")
+ } else {
+ B.append("(" + q + "," + Bytes.toString(CellUtil.cloneValue(cell)) + ")")
+ }
+ }
+ "" + B.toString
+ } else {
+ ""
+ }
+ })
+ val getArray = getRdd.collect()
+
+ assert(getArray.length == 4)
+ assert(getArray.contains("get1:(a,foo1)"))
+ assert(getArray.contains("get2:(a,foo2)"))
+ assert(getArray.contains("get3:(a,foo3)"))
+
+ }
+
+ test("BulkGet failure test: bad table") {
+ val config = TEST_UTIL.getConfiguration
+
+ val rdd = sc.parallelize(
+ Array[Array[Byte]](
+ Bytes.toBytes("get1"),
+ Bytes.toBytes("get2"),
+ Bytes.toBytes("get3"),
+ Bytes.toBytes("get4")))
+
+ intercept[SparkException] {
+ try {
+ val getRdd = hbaseContext.bulkGet[Array[Byte], String](
+ TableName.valueOf("badTableName"),
+ 2,
+ rdd,
+ record => {
+ new Get(record)
+ },
+ (result: Result) => "1")
+
+ getRdd.collect()
+
+ fail("We should have failed and not reached this line")
+ } catch {
+ case ex: SparkException => {
+ assert(
+ ex.getMessage.contains(
+ "org.apache.hadoop.hbase.client.RetriesExhaustedWithDetailsException"))
+ throw ex
+ }
+ }
+ }
+ }
+
+ test("BulkGet failure test: bad column") {
+
+ val config = TEST_UTIL.getConfiguration
+ val connection = ConnectionFactory.createConnection(config)
+ val table = connection.getTable(TableName.valueOf("t1"))
+
+ try {
+ var put = new Put(Bytes.toBytes("get1"))
+ put.addColumn(Bytes.toBytes(columnFamily), Bytes.toBytes("a"), Bytes.toBytes("foo1"))
+ table.put(put)
+ put = new Put(Bytes.toBytes("get2"))
+ put.addColumn(Bytes.toBytes(columnFamily), Bytes.toBytes("a"), Bytes.toBytes("foo2"))
+ table.put(put)
+ put = new Put(Bytes.toBytes("get3"))
+ put.addColumn(Bytes.toBytes(columnFamily), Bytes.toBytes("a"), Bytes.toBytes("foo3"))
+ table.put(put)
+ } finally {
+ table.close()
+ connection.close()
+ }
+
+ val rdd = sc.parallelize(
+ Array[Array[Byte]](
+ Bytes.toBytes("get1"),
+ Bytes.toBytes("get2"),
+ Bytes.toBytes("get3"),
+ Bytes.toBytes("get4")))
+
+ val getRdd = hbaseContext.bulkGet[Array[Byte], String](
+ TableName.valueOf(tableName),
+ 2,
+ rdd,
+ record => {
+ new Get(record)
+ },
+ (result: Result) => {
+ if (result.listCells() != null) {
+ val cellValue =
+ result.getColumnLatestCell(Bytes.toBytes("c"), Bytes.toBytes("bad_column"))
+ if (cellValue == null) "null" else "bad"
+ } else "noValue"
+ })
+ var nullCounter = 0
+ var noValueCounter = 0
+ getRdd
+ .collect()
+ .foreach(
+ r => {
+ if ("null".equals(r)) nullCounter += 1
+ else if ("noValue".equals(r)) noValueCounter += 1
+ })
+ assert(nullCounter == 3)
+ assert(noValueCounter == 1)
+ }
+
+ test("distributedScan to test HBase client") {
+ val config = TEST_UTIL.getConfiguration
+ val connection = ConnectionFactory.createConnection(config)
+ val table = connection.getTable(TableName.valueOf("t1"))
+
+ try {
+ var put = new Put(Bytes.toBytes("scan1"))
+ put.addColumn(Bytes.toBytes(columnFamily), Bytes.toBytes("a"), Bytes.toBytes("foo1"))
+ table.put(put)
+ put = new Put(Bytes.toBytes("scan2"))
+ put.addColumn(Bytes.toBytes(columnFamily), Bytes.toBytes("a"), Bytes.toBytes("foo2"))
+ table.put(put)
+ put = new Put(Bytes.toBytes("scan2"))
+ put.addColumn(Bytes.toBytes(columnFamily), Bytes.toBytes("b"), Bytes.toBytes("foo-2"))
+ table.put(put)
+ put = new Put(Bytes.toBytes("scan3"))
+ put.addColumn(Bytes.toBytes(columnFamily), Bytes.toBytes("a"), Bytes.toBytes("foo3"))
+ table.put(put)
+ put = new Put(Bytes.toBytes("scan4"))
+ put.addColumn(Bytes.toBytes(columnFamily), Bytes.toBytes("a"), Bytes.toBytes("foo3"))
+ table.put(put)
+ put = new Put(Bytes.toBytes("scan5"))
+ put.addColumn(Bytes.toBytes(columnFamily), Bytes.toBytes("a"), Bytes.toBytes("foo3"))
+ table.put(put)
+ } finally {
+ table.close()
+ connection.close()
+ }
+
+ val scan = new Scan()
+ val filter = new FirstKeyOnlyFilter()
+ scan.setCaching(100)
+ scan.setStartRow(Bytes.toBytes("scan2"))
+ scan.setStopRow(Bytes.toBytes("scan4_"))
+ scan.setFilter(filter)
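+ // FirstKeyOnlyFilter returns only the first cell of each row, so "scan2" (which has two columns) contributes a single cell.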
+
+ val scanRdd = hbaseContext.hbaseRDD(TableName.valueOf(tableName), scan)
+
+ try {
+ val scanList = scanRdd
+ .map(r => r._1.copyBytes())
+ .collect()
+ assert(scanList.length == 3)
+ var cnt = 0
+ scanRdd
+ .map(r => r._2.listCells().size())
+ .collect()
+ .foreach(
+ l => {
+ cnt += l
+ })
+ // the number of cells returned would be 4 without the Filter
+ assert(cnt == 3)
+ } catch {
+ case ex: Exception => ex.printStackTrace()
+ }
+ }
+}
diff --git a/spark4/hbase-spark4/src/test/scala/org/apache/hadoop/hbase/spark/HBaseDStreamFunctionsSuite.scala b/spark4/hbase-spark4/src/test/scala/org/apache/hadoop/hbase/spark/HBaseDStreamFunctionsSuite.scala
new file mode 100644
index 00000000..d7e99267
--- /dev/null
+++ b/spark4/hbase-spark4/src/test/scala/org/apache/hadoop/hbase/spark/HBaseDStreamFunctionsSuite.scala
@@ -0,0 +1,169 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.hadoop.hbase.spark
+
+import org.apache.hadoop.hbase.{CellUtil, HBaseTestingUtility, TableName}
+import org.apache.hadoop.hbase.client._
+import org.apache.hadoop.hbase.spark.HBaseDStreamFunctions._
+import org.apache.hadoop.hbase.util.Bytes
+import org.apache.spark.SparkContext
+import org.apache.spark.rdd.RDD
+import org.apache.spark.streaming.{Milliseconds, StreamingContext}
+import org.scalatest.funsuite.AnyFunSuite
+import org.scalatest.{BeforeAndAfterAll, BeforeAndAfterEach}
+
+import scala.collection.mutable
+
+class HBaseDStreamFunctionsSuite
+ extends AnyFunSuite
+ with BeforeAndAfterEach
+ with BeforeAndAfterAll
+ with Logging {
+ @transient var sc: SparkContext = null
+
+ var TEST_UTIL: HBaseTestingUtility = new HBaseTestingUtility
+
+ val tableName = "t1"
+ val columnFamily = "c"
+
+ override def beforeAll() {
+
+ TEST_UTIL.startMiniCluster()
+
+ logInfo(" - minicluster started")
+ try
+ TEST_UTIL.deleteTable(TableName.valueOf(tableName))
+ catch {
+ case e: Exception => logInfo(" - no table " + tableName + " found")
+
+ }
+ logInfo(" - creating table " + tableName)
+ TEST_UTIL.createTable(TableName.valueOf(tableName), Bytes.toBytes(columnFamily))
+ logInfo(" - created table")
+
+ sc = new SparkContext("local", "test")
+ }
+
+ override def afterAll() {
+ TEST_UTIL.deleteTable(TableName.valueOf(tableName))
+ TEST_UTIL.shutdownMiniCluster()
+ sc.stop()
+ }
+
+ test("bulkput to test HBase client") {
+ val config = TEST_UTIL.getConfiguration
+ val rdd1 = sc.parallelize(
+ Array[(Array[Byte], Array[(Array[Byte], Array[Byte], Array[Byte])])](
+ (
+ Bytes.toBytes("1"),
+ Array((Bytes.toBytes(columnFamily), Bytes.toBytes("a"), Bytes.toBytes("foo1")))),
+ (
+ Bytes.toBytes("2"),
+ Array((Bytes.toBytes(columnFamily), Bytes.toBytes("b"), Bytes.toBytes("foo2")))),
+ (
+ Bytes.toBytes("3"),
+ Array((Bytes.toBytes(columnFamily), Bytes.toBytes("c"), Bytes.toBytes("foo3"))))))
+
+ val rdd2 = sc.parallelize(
+ Array[(Array[Byte], Array[(Array[Byte], Array[Byte], Array[Byte])])](
+ (
+ Bytes.toBytes("4"),
+ Array((Bytes.toBytes(columnFamily), Bytes.toBytes("d"), Bytes.toBytes("foo")))),
+ (
+ Bytes.toBytes("5"),
+ Array((Bytes.toBytes(columnFamily), Bytes.toBytes("e"), Bytes.toBytes("bar"))))))
+
+ var isFinished = false
+
+ val hbaseContext = new HBaseContext(sc, config)
+ val ssc = new StreamingContext(sc, Milliseconds(200))
+
+ val queue = mutable.Queue[RDD[(Array[Byte], Array[(Array[Byte], Array[Byte], Array[Byte])])]]()
+ queue += rdd1
+ queue += rdd2
+ val dStream = ssc.queueStream(queue)
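+ // queueStream emits one queued RDD per 200 ms batch; hbaseBulkPut below writes each micro-batch to the table.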
+
+ dStream.hbaseBulkPut(
+ hbaseContext,
+ TableName.valueOf(tableName),
+ (putRecord) => {
+ val put = new Put(putRecord._1)
+ putRecord._2.foreach((putValue) => put.addColumn(putValue._1, putValue._2, putValue._3))
+ put
+ })
+
+ dStream.foreachRDD(
+ rdd => {
+ if (rdd.count() == 0) {
+ isFinished = true
+ }
+ })
+
+ ssc.start()
+
+ while (!isFinished) {
+ Thread.sleep(100)
+ }
+
+ ssc.stop(true, true)
+
+ val connection = ConnectionFactory.createConnection(config)
+ val table = connection.getTable(TableName.valueOf("t1"))
+
+ try {
+ val foo1 = Bytes.toString(
+ CellUtil.cloneValue(
+ table
+ .get(new Get(Bytes.toBytes("1")))
+ .getColumnLatestCell(Bytes.toBytes(columnFamily), Bytes.toBytes("a"))))
+ assert(foo1 == "foo1")
+
+ val foo2 = Bytes.toString(
+ CellUtil.cloneValue(
+ table
+ .get(new Get(Bytes.toBytes("2")))
+ .getColumnLatestCell(Bytes.toBytes(columnFamily), Bytes.toBytes("b"))))
+ assert(foo2 == "foo2")
+
+ val foo3 = Bytes.toString(
+ CellUtil.cloneValue(
+ table
+ .get(new Get(Bytes.toBytes("3")))
+ .getColumnLatestCell(Bytes.toBytes(columnFamily), Bytes.toBytes("c"))))
+ assert(foo3 == "foo3")
+
+ val foo4 = Bytes.toString(
+ CellUtil.cloneValue(
+ table
+ .get(new Get(Bytes.toBytes("4")))
+ .getColumnLatestCell(Bytes.toBytes(columnFamily), Bytes.toBytes("d"))))
+ assert(foo4 == "foo")
+
+ val foo5 = Bytes.toString(
+ CellUtil.cloneValue(
+ table
+ .get(new Get(Bytes.toBytes("5")))
+ .getColumnLatestCell(Bytes.toBytes(columnFamily), Bytes.toBytes("e"))))
+ assert(foo5 == "bar")
+ } finally {
+ table.close()
+ connection.close()
+ }
+ }
+
+}
diff --git a/spark4/hbase-spark4/src/test/scala/org/apache/hadoop/hbase/spark/HBaseRDDFunctionsSuite.scala b/spark4/hbase-spark4/src/test/scala/org/apache/hadoop/hbase/spark/HBaseRDDFunctionsSuite.scala
new file mode 100644
index 00000000..900a003d
--- /dev/null
+++ b/spark4/hbase-spark4/src/test/scala/org/apache/hadoop/hbase/spark/HBaseRDDFunctionsSuite.scala
@@ -0,0 +1,472 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.hadoop.hbase.spark
+
+import org.apache.hadoop.hbase.{CellUtil, HBaseTestingUtility, TableName}
+import org.apache.hadoop.hbase.client._
+import org.apache.hadoop.hbase.spark.HBaseRDDFunctions._
+import org.apache.hadoop.hbase.util.Bytes
+import org.apache.spark.SparkContext
+import org.scalatest.funsuite.AnyFunSuite
+import org.scalatest.{BeforeAndAfterAll, BeforeAndAfterEach}
+
+import scala.collection.mutable
+
+class HBaseRDDFunctionsSuite
+ extends AnyFunSuite
+ with BeforeAndAfterEach
+ with BeforeAndAfterAll
+ with Logging {
+ @transient var sc: SparkContext = null
+ var TEST_UTIL: HBaseTestingUtility = new HBaseTestingUtility
+
+ val tableName = "t1"
+ val columnFamily = "c"
+
+ override def beforeAll() {
+
+ TEST_UTIL.startMiniCluster
+
+ logInfo(" - minicluster started")
+ try
+ TEST_UTIL.deleteTable(TableName.valueOf(tableName))
+ catch {
+ case e: Exception => logInfo(" - no table " + tableName + " found")
+
+ }
+ logInfo(" - creating table " + tableName)
+ TEST_UTIL.createTable(TableName.valueOf(tableName), Bytes.toBytes(columnFamily))
+ logInfo(" - created table")
+
+ sc = new SparkContext("local", "test")
+ }
+
+ override def afterAll() {
+ TEST_UTIL.deleteTable(TableName.valueOf(tableName))
+ logInfo("shuting down minicluster")
+ TEST_UTIL.shutdownMiniCluster()
+
+ sc.stop()
+ }
+
+ test("bulkput to test HBase client") {
+ val config = TEST_UTIL.getConfiguration
+ val rdd = sc.parallelize(
+ Array[(Array[Byte], Array[(Array[Byte], Array[Byte], Array[Byte])])](
+ (
+ Bytes.toBytes("1"),
+ Array((Bytes.toBytes(columnFamily), Bytes.toBytes("a"), Bytes.toBytes("foo1")))),
+ (
+ Bytes.toBytes("2"),
+ Array((Bytes.toBytes(columnFamily), Bytes.toBytes("b"), Bytes.toBytes("foo2")))),
+ (
+ Bytes.toBytes("3"),
+ Array((Bytes.toBytes(columnFamily), Bytes.toBytes("c"), Bytes.toBytes("foo3")))),
+ (
+ Bytes.toBytes("4"),
+ Array((Bytes.toBytes(columnFamily), Bytes.toBytes("d"), Bytes.toBytes("foo")))),
+ (
+ Bytes.toBytes("5"),
+ Array((Bytes.toBytes(columnFamily), Bytes.toBytes("e"), Bytes.toBytes("bar"))))))
+
+ val hbaseContext = new HBaseContext(sc, config)
+
+ rdd.hbaseBulkPut(
+ hbaseContext,
+ TableName.valueOf(tableName),
+ (putRecord) => {
+ val put = new Put(putRecord._1)
+ putRecord._2.foreach((putValue) => put.addColumn(putValue._1, putValue._2, putValue._3))
+ put
+ })
+
+ val connection = ConnectionFactory.createConnection(config)
+ val table = connection.getTable(TableName.valueOf("t1"))
+
+ try {
+ val foo1 = Bytes.toString(
+ CellUtil.cloneValue(
+ table
+ .get(new Get(Bytes.toBytes("1")))
+ .getColumnLatestCell(Bytes.toBytes(columnFamily), Bytes.toBytes("a"))))
+ assert(foo1 == "foo1")
+
+ val foo2 = Bytes.toString(
+ CellUtil.cloneValue(
+ table
+ .get(new Get(Bytes.toBytes("2")))
+ .getColumnLatestCell(Bytes.toBytes(columnFamily), Bytes.toBytes("b"))))
+ assert(foo2 == "foo2")
+
+ val foo3 = Bytes.toString(
+ CellUtil.cloneValue(
+ table
+ .get(new Get(Bytes.toBytes("3")))
+ .getColumnLatestCell(Bytes.toBytes(columnFamily), Bytes.toBytes("c"))))
+ assert(foo3 == "foo3")
+
+ val foo4 = Bytes.toString(
+ CellUtil.cloneValue(
+ table
+ .get(new Get(Bytes.toBytes("4")))
+ .getColumnLatestCell(Bytes.toBytes(columnFamily), Bytes.toBytes("d"))))
+ assert(foo4 == "foo")
+
+ val foo5 = Bytes.toString(
+ CellUtil.cloneValue(
+ table
+ .get(new Get(Bytes.toBytes("5")))
+ .getColumnLatestCell(Bytes.toBytes(columnFamily), Bytes.toBytes("e"))))
+ assert(foo5 == "bar")
+ } finally {
+ table.close()
+ connection.close()
+ }
+ }
+
+ test("bulkDelete to test HBase client") {
+ val config = TEST_UTIL.getConfiguration
+ val connection = ConnectionFactory.createConnection(config)
+ val table = connection.getTable(TableName.valueOf("t1"))
+
+ try {
+ var put = new Put(Bytes.toBytes("delete1"))
+ put.addColumn(Bytes.toBytes(columnFamily), Bytes.toBytes("a"), Bytes.toBytes("foo1"))
+ table.put(put)
+ put = new Put(Bytes.toBytes("delete2"))
+ put.addColumn(Bytes.toBytes(columnFamily), Bytes.toBytes("a"), Bytes.toBytes("foo2"))
+ table.put(put)
+ put = new Put(Bytes.toBytes("delete3"))
+ put.addColumn(Bytes.toBytes(columnFamily), Bytes.toBytes("a"), Bytes.toBytes("foo3"))
+ table.put(put)
+
+ val rdd = sc.parallelize(Array[Array[Byte]](Bytes.toBytes("delete1"), Bytes.toBytes("delete3")))
+
+ val hbaseContext = new HBaseContext(sc, config)
+
+ rdd.hbaseBulkDelete(
+ hbaseContext,
+ TableName.valueOf(tableName),
+ putRecord => new Delete(putRecord),
+ 4)
+
+ assert(
+ table
+ .get(new Get(Bytes.toBytes("delete1")))
+ .getColumnLatestCell(Bytes.toBytes(columnFamily), Bytes.toBytes("a")) == null)
+ assert(
+ table
+ .get(new Get(Bytes.toBytes("delete3")))
+ .getColumnLatestCell(Bytes.toBytes(columnFamily), Bytes.toBytes("a")) == null)
+ assert(
+ Bytes
+ .toString(
+ CellUtil.cloneValue(table
+ .get(new Get(Bytes.toBytes("delete2")))
+ .getColumnLatestCell(Bytes.toBytes(columnFamily), Bytes.toBytes("a"))))
+ .equals("foo2"))
+ } finally {
+ table.close()
+ connection.close()
+ }
+
+ }
+
+ test("bulkGet to test HBase client") {
+ val config = TEST_UTIL.getConfiguration
+ val connection = ConnectionFactory.createConnection(config)
+ val table = connection.getTable(TableName.valueOf("t1"))
+
+ try {
+ var put = new Put(Bytes.toBytes("get1"))
+ put.addColumn(Bytes.toBytes(columnFamily), Bytes.toBytes("a"), Bytes.toBytes("foo1"))
+ table.put(put)
+ put = new Put(Bytes.toBytes("get2"))
+ put.addColumn(Bytes.toBytes(columnFamily), Bytes.toBytes("a"), Bytes.toBytes("foo2"))
+ table.put(put)
+ put = new Put(Bytes.toBytes("get3"))
+ put.addColumn(Bytes.toBytes(columnFamily), Bytes.toBytes("a"), Bytes.toBytes("foo3"))
+ table.put(put)
+ } finally {
+ table.close()
+ connection.close()
+ }
+
+ val rdd = sc.parallelize(
+ Array[Array[Byte]](
+ Bytes.toBytes("get1"),
+ Bytes.toBytes("get2"),
+ Bytes.toBytes("get3"),
+ Bytes.toBytes("get4")))
+ val hbaseContext = new HBaseContext(sc, config)
+
+ // Get with custom convert logic
+ val getRdd = rdd.hbaseBulkGet[String](
+ hbaseContext,
+ TableName.valueOf(tableName),
+ 2,
+ record => {
+ new Get(record)
+ },
+ (result: Result) => {
+ if (result.listCells() != null) {
+ val it = result.listCells().iterator()
+ val B = new StringBuilder
+
+ B.append(Bytes.toString(result.getRow) + ":")
+
+ while (it.hasNext) {
+ val cell = it.next
+ val q = Bytes.toString(CellUtil.cloneQualifier(cell))
+ if (q.equals("counter")) {
+ B.append("(" + q + "," + Bytes.toLong(CellUtil.cloneValue(cell)) + ")")
+ } else {
+ B.append("(" + q + "," + Bytes.toString(CellUtil.cloneValue(cell)) + ")")
+ }
+ }
+ "" + B.toString
+ } else {
+ ""
+ }
+ })
+
+ val getArray = getRdd.collect()
+
+ assert(getArray.length == 4)
+ assert(getArray.contains("get1:(a,foo1)"))
+ assert(getArray.contains("get2:(a,foo2)"))
+ assert(getArray.contains("get3:(a,foo3)"))
+ }
+
+ test("bulkGet default converter to test HBase client") {
+ val config = TEST_UTIL.getConfiguration
+ val connection = ConnectionFactory.createConnection(config)
+ val table = connection.getTable(TableName.valueOf("t1"))
+
+ try {
+ var put = new Put(Bytes.toBytes("get1"))
+ put.addColumn(Bytes.toBytes(columnFamily), Bytes.toBytes("a"), Bytes.toBytes("foo1"))
+ table.put(put)
+ put = new Put(Bytes.toBytes("get2"))
+ put.addColumn(Bytes.toBytes(columnFamily), Bytes.toBytes("a"), Bytes.toBytes("foo2"))
+ table.put(put)
+ put = new Put(Bytes.toBytes("get3"))
+ put.addColumn(Bytes.toBytes(columnFamily), Bytes.toBytes("a"), Bytes.toBytes("foo3"))
+ table.put(put)
+ } finally {
+ table.close()
+ connection.close()
+ }
+
+ val rdd = sc.parallelize(
+ Array[Array[Byte]](
+ Bytes.toBytes("get1"),
+ Bytes.toBytes("get2"),
+ Bytes.toBytes("get3"),
+ Bytes.toBytes("get4")))
+ val hbaseContext = new HBaseContext(sc, config)
+
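+ // Without a custom converter, hbaseBulkGet returns (key, Result) pairs; the map below builds the same string form from row._2.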
+ val getRdd = rdd
+ .hbaseBulkGet(
+ hbaseContext,
+ TableName.valueOf("t1"),
+ 2,
+ record => {
+ new Get(record)
+ })
+ .map(
+ (row) => {
+ if (row != null && row._2.listCells() != null) {
+ val it = row._2.listCells().iterator()
+ val B = new StringBuilder
+
+ B.append(Bytes.toString(row._2.getRow) + ":")
+
+ while (it.hasNext) {
+ val cell = it.next
+ val q = Bytes.toString(CellUtil.cloneQualifier(cell))
+ if (q.equals("counter")) {
+ B.append("(" + q + "," + Bytes.toLong(CellUtil.cloneValue(cell)) + ")")
+ } else {
+ B.append("(" + q + "," + Bytes.toString(CellUtil.cloneValue(cell)) + ")")
+ }
+ }
+ "" + B.toString
+ } else {
+ ""
+ }
+ })
+
+ val getArray = getRdd.collect()
+
+ assert(getArray.length == 4)
+ assert(getArray.contains("get1:(a,foo1)"))
+ assert(getArray.contains("get2:(a,foo2)"))
+ assert(getArray.contains("get3:(a,foo3)"))
+ }
+
+ test("foreachPartition with puts to test HBase client") {
+ val config = TEST_UTIL.getConfiguration
+ val rdd = sc.parallelize(
+ Array[(Array[Byte], Array[(Array[Byte], Array[Byte], Array[Byte])])](
+ (
+ Bytes.toBytes("1foreach"),
+ Array((Bytes.toBytes(columnFamily), Bytes.toBytes("a"), Bytes.toBytes("foo1")))),
+ (
+ Bytes.toBytes("2foreach"),
+ Array((Bytes.toBytes(columnFamily), Bytes.toBytes("b"), Bytes.toBytes("foo2")))),
+ (
+ Bytes.toBytes("3foreach"),
+ Array((Bytes.toBytes(columnFamily), Bytes.toBytes("c"), Bytes.toBytes("foo3")))),
+ (
+ Bytes.toBytes("4foreach"),
+ Array((Bytes.toBytes(columnFamily), Bytes.toBytes("d"), Bytes.toBytes("foo")))),
+ (
+ Bytes.toBytes("5foreach"),
+ Array((Bytes.toBytes(columnFamily), Bytes.toBytes("e"), Bytes.toBytes("bar"))))))
+
+ val hbaseContext = new HBaseContext(sc, config)
+
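+ // hbaseForeachPartition hands each partition an open HBase Connection; a BufferedMutator batches the puts before flushing.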
+ rdd.hbaseForeachPartition(
+ hbaseContext,
+ (it, conn) => {
+ val bufferedMutator = conn.getBufferedMutator(TableName.valueOf("t1"))
+ it.foreach(
+ (putRecord) => {
+ val put = new Put(putRecord._1)
+ putRecord._2.foreach((putValue) => put.addColumn(putValue._1, putValue._2, putValue._3))
+ bufferedMutator.mutate(put)
+ })
+ bufferedMutator.flush()
+ bufferedMutator.close()
+ })
+
+ val connection = ConnectionFactory.createConnection(config)
+ val table = connection.getTable(TableName.valueOf("t1"))
+
+ try {
+ val foo1 = Bytes.toString(
+ CellUtil.cloneValue(
+ table
+ .get(new Get(Bytes.toBytes("1foreach")))
+ .getColumnLatestCell(Bytes.toBytes(columnFamily), Bytes.toBytes("a"))))
+ assert(foo1 == "foo1")
+
+ val foo2 = Bytes.toString(
+ CellUtil.cloneValue(
+ table
+ .get(new Get(Bytes.toBytes("2foreach")))
+ .getColumnLatestCell(Bytes.toBytes(columnFamily), Bytes.toBytes("b"))))
+ assert(foo2 == "foo2")
+
+ val foo3 = Bytes.toString(
+ CellUtil.cloneValue(
+ table
+ .get(new Get(Bytes.toBytes("3foreach")))
+ .getColumnLatestCell(Bytes.toBytes(columnFamily), Bytes.toBytes("c"))))
+ assert(foo3 == "foo3")
+
+ val foo4 = Bytes.toString(
+ CellUtil.cloneValue(
+ table
+ .get(new Get(Bytes.toBytes("4foreach")))
+ .getColumnLatestCell(Bytes.toBytes(columnFamily), Bytes.toBytes("d"))))
+ assert(foo4 == "foo")
+
+ val foo5 = Bytes.toString(
+ CellUtil.cloneValue(
+ table
+ .get(new Get(Bytes.toBytes("5")))
+ .getColumnLatestCell(Bytes.toBytes(columnFamily), Bytes.toBytes("e"))))
+ assert(foo5 == "bar")
+ } finally {
+ table.close()
+ connection.close()
+ }
+ }
+
+ test("mapPartitions with Get from test HBase client") {
+ val config = TEST_UTIL.getConfiguration
+ val connection = ConnectionFactory.createConnection(config)
+ val table = connection.getTable(TableName.valueOf("t1"))
+
+ try {
+ var put = new Put(Bytes.toBytes("get1"))
+ put.addColumn(Bytes.toBytes(columnFamily), Bytes.toBytes("a"), Bytes.toBytes("foo1"))
+ table.put(put)
+ put = new Put(Bytes.toBytes("get2"))
+ put.addColumn(Bytes.toBytes(columnFamily), Bytes.toBytes("a"), Bytes.toBytes("foo2"))
+ table.put(put)
+ put = new Put(Bytes.toBytes("get3"))
+ put.addColumn(Bytes.toBytes(columnFamily), Bytes.toBytes("a"), Bytes.toBytes("foo3"))
+ table.put(put)
+ } finally {
+ table.close()
+ connection.close()
+ }
+
+ val rdd = sc.parallelize(
+ Array[Array[Byte]](
+ Bytes.toBytes("get1"),
+ Bytes.toBytes("get2"),
+ Bytes.toBytes("get3"),
+ Bytes.toBytes("get4")))
+ val hbaseContext = new HBaseContext(sc, config)
+
+ // Get with custom convert logic
+ val getRdd = rdd.hbaseMapPartitions(
+ hbaseContext,
+ (it, conn) => {
+ val table = conn.getTable(TableName.valueOf("t1"))
+ var res = mutable.ListBuffer[String]()
+
+ it.foreach(
+ r => {
+ val get = new Get(r)
+ val result = table.get(get)
+ if (result.listCells != null) {
+ val it = result.listCells().iterator()
+ val B = new StringBuilder
+
+ B.append(Bytes.toString(result.getRow) + ":")
+
+ while (it.hasNext) {
+ val cell = it.next()
+ val q = Bytes.toString(CellUtil.cloneQualifier(cell))
+ if (q.equals("counter")) {
+ B.append("(" + q + "," + Bytes.toLong(CellUtil.cloneValue(cell)) + ")")
+ } else {
+ B.append("(" + q + "," + Bytes.toString(CellUtil.cloneValue(cell)) + ")")
+ }
+ }
+ res += "" + B.toString
+ } else {
+ res += ""
+ }
+ })
+ res.iterator
+ })
+
+ val getArray = getRdd.collect()
+
+ assert(getArray.length == 4)
+ assert(getArray.contains("get1:(a,foo1)"))
+ assert(getArray.contains("get2:(a,foo2)"))
+ assert(getArray.contains("get3:(a,foo3)"))
+ }
+}
diff --git a/spark4/hbase-spark4/src/test/scala/org/apache/hadoop/hbase/spark/HBaseTestSource.scala b/spark4/hbase-spark4/src/test/scala/org/apache/hadoop/hbase/spark/HBaseTestSource.scala
new file mode 100644
index 00000000..a27876f1
--- /dev/null
+++ b/spark4/hbase-spark4/src/test/scala/org/apache/hadoop/hbase/spark/HBaseTestSource.scala
@@ -0,0 +1,59 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.hadoop.hbase.spark
+
+import org.apache.hadoop.hbase.spark.datasources.HBaseSparkConf
+import org.apache.spark.SparkEnv
+import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.{Row, SQLContext}
+import org.apache.spark.sql.sources._
+import org.apache.spark.sql.types._
+
+class HBaseTestSource extends RelationProvider {
+ override def createRelation(
+ sqlContext: SQLContext,
+ parameters: Map[String, String]): BaseRelation = {
+ DummyScan(
+ parameters("cacheSize").toInt,
+ parameters("batchNum").toInt,
+ parameters("blockCacheingEnable").toBoolean,
+ parameters("rowNum").toInt)(sqlContext)
+ }
+}
+
+case class DummyScan(cacheSize: Int, batchNum: Int, blockCachingEnable: Boolean, rowNum: Int)(
+ @transient val sqlContext: SQLContext)
+ extends BaseRelation
+ with TableScan {
+ private def sparkConf = SparkEnv.get.conf
+ override def schema: StructType =
+ StructType(StructField("i", IntegerType, nullable = false) :: Nil)
+
+ override def buildScan(): RDD[Row] = sqlContext.sparkContext
+ .parallelize(0 until rowNum)
+ .map(Row(_))
+ .map {
+ x =>
+ if (sparkConf.getInt(HBaseSparkConf.QUERY_BATCHSIZE, -1) != batchNum ||
+ sparkConf.getInt(HBaseSparkConf.QUERY_CACHEDROWS, -1) != cacheSize ||
+ sparkConf.getBoolean(HBaseSparkConf.QUERY_CACHEBLOCKS, false) != blockCachingEnable) {
+ throw new Exception("HBase Spark configuration cannot be set properly")
+ }
+ x
+ }
+}
diff --git a/spark4/hbase-spark4/src/test/scala/org/apache/hadoop/hbase/spark/PartitionFilterSuite.scala b/spark4/hbase-spark4/src/test/scala/org/apache/hadoop/hbase/spark/PartitionFilterSuite.scala
new file mode 100644
index 00000000..7718fa40
--- /dev/null
+++ b/spark4/hbase-spark4/src/test/scala/org/apache/hadoop/hbase/spark/PartitionFilterSuite.scala
@@ -0,0 +1,533 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.hadoop.hbase.spark
+
+import org.apache.hadoop.hbase.{HBaseTestingUtility, TableName}
+import org.apache.hadoop.hbase.spark.datasources.{HBaseSparkConf, HBaseTableCatalog}
+import org.apache.spark.{SparkConf, SparkContext}
+import org.apache.spark.sql.{DataFrame, SQLContext}
+import org.scalatest.funsuite.AnyFunSuite
+import org.scalatest.{BeforeAndAfterAll, BeforeAndAfterEach}
+
+import scala.reflect.ClassTag
+
+case class FilterRangeRecord(
+ intCol0: Int,
+ boolCol1: Boolean,
+ doubleCol2: Double,
+ floatCol3: Float,
+ intCol4: Int,
+ longCol5: Long,
+ shortCol6: Short,
+ stringCol7: String,
+ byteCol8: Byte)
+
+object FilterRangeRecord {
+ def apply(i: Int): FilterRangeRecord = {
+ FilterRangeRecord(
+ if (i % 2 == 0) i else -i,
+ i % 2 == 0,
+ if (i % 2 == 0) i.toDouble else -i.toDouble,
+ i.toFloat,
+ if (i % 2 == 0) i else -i,
+ i.toLong,
+ i.toShort,
+ s"String$i extra",
+ i.toByte)
+ }
+}
+
+class PartitionFilterSuite
+ extends AnyFunSuite
+ with BeforeAndAfterEach
+ with BeforeAndAfterAll
+ with Logging {
+ @transient var sc: SparkContext = null
+ var TEST_UTIL: HBaseTestingUtility = new HBaseTestingUtility
+
+ var sqlContext: SQLContext = null
+ var df: DataFrame = null
+
+ def withCatalog(cat: String): DataFrame = {
+ sqlContext.read
+ .options(Map(HBaseTableCatalog.tableCatalog -> cat))
+ .format("org.apache.hadoop.hbase.spark")
+ .load()
+ }
+
+ override def beforeAll() {
+
+ TEST_UTIL.startMiniCluster
+ val sparkConf = new SparkConf
+ sparkConf.set(HBaseSparkConf.QUERY_CACHEBLOCKS, "true")
+ sparkConf.set(HBaseSparkConf.QUERY_BATCHSIZE, "100")
+ sparkConf.set(HBaseSparkConf.QUERY_CACHEDROWS, "100")
+
+ sc = new SparkContext("local", "test", sparkConf)
+ new HBaseContext(sc, TEST_UTIL.getConfiguration)
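+ // Creating an HBaseContext up front registers it (and this configuration) for the "org.apache.hadoop.hbase.spark" data source used below.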
+ sqlContext = new SQLContext(sc)
+ }
+
+ override def afterAll() {
+ logInfo("shutting down minicluster")
+ TEST_UTIL.shutdownMiniCluster()
+
+ sc.stop()
+ }
+
+ override def beforeEach(): Unit = {
+ DefaultSourceStaticUtils.lastFiveExecutionRules.clear()
+ }
+
+ // The original raw data, used to construct the expected result set without going through
+ // the DataFrame logic; it is used to verify the result set retrieved through the DataFrame logic.
+ val rawResult = (0 until 32).map { i => FilterRangeRecord(i) }
+
+ def collectToSet[T: ClassTag](df: DataFrame): Set[T] = {
+ df.collect().map(_.getAs[T](0)).toSet
+ }
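+ // Catalog: intCol0 is stored in the row key; every other column maps to its own column family (cf1..cf8).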
+ val catalog = s"""{
+ |"table":{"namespace":"default", "name":"rangeTable"},
+ |"rowkey":"key",
+ |"columns":{
+ |"intCol0":{"cf":"rowkey", "col":"key", "type":"int"},
+ |"boolCol1":{"cf":"cf1", "col":"boolCol1", "type":"boolean"},
+ |"doubleCol2":{"cf":"cf2", "col":"doubleCol2", "type":"double"},
+ |"floatCol3":{"cf":"cf3", "col":"floatCol3", "type":"float"},
+ |"intCol4":{"cf":"cf4", "col":"intCol4", "type":"int"},
+ |"longCol5":{"cf":"cf5", "col":"longCol5", "type":"bigint"},
+ |"shortCol6":{"cf":"cf6", "col":"shortCol6", "type":"smallint"},
+ |"stringCol7":{"cf":"cf7", "col":"stringCol7", "type":"string"},
+ |"byteCol8":{"cf":"cf8", "col":"byteCol8", "type":"tinyint"}
+ |}
+ |}""".stripMargin
+
+ test("populate rangeTable") {
+ val sql = sqlContext
+ import sql.implicits._
+
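+ // HBaseTableCatalog.newTable -> "5" creates rangeTable with 5 regions, so the range filters below span multiple partitions.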
+ sc.parallelize(rawResult)
+ .toDF
+ .write
+ .options(Map(HBaseTableCatalog.tableCatalog -> catalog, HBaseTableCatalog.newTable -> "5"))
+ .format("org.apache.hadoop.hbase.spark")
+ .save()
+ }
+ test("rangeTable full query") {
+ val df = withCatalog(catalog)
+ df.show
+ assert(df.count() === 32)
+ }
+
+ /**
+ * expected result: only showing top 20 rows
+ * +-------+
+ * |intCol0|
+ * +-------+
+ * | -31 |
+ * | -29 |
+ * | -27 |
+ * | -25 |
+ * | -23 |
+ * | -21 |
+ * | -19 |
+ * | -17 |
+ * | -15 |
+ * | -13 |
+ * | -11 |
+ * | -9 |
+ * | -7 |
+ * | -5 |
+ * | -3 |
+ * | -1 |
+ * +-------+
+ */
+ test("rangeTable rowkey less than 0") {
+ val sql = sqlContext
+ import sql.implicits._
+ val df = withCatalog(catalog)
+ val s = df.filter($"intCol0" < 0).select($"intCol0")
+ s.show
+ // filter results without going through dataframe
+ val expected = rawResult.filter(_.intCol0 < 0).map(_.intCol0).toSet
+ // filter results going through dataframe
+ val result = collectToSet[Int](s)
+ assert(expected === result)
+ }
+
+ /**
+ * expected result: only showing top 20 rows
+ * +-------+
+ * |intCol4|
+ * +-------+
+ * | -31 |
+ * | -29 |
+ * | -27 |
+ * | -25 |
+ * | -23 |
+ * | -21 |
+ * | -19 |
+ * | -17 |
+ * | -15 |
+ * | -13 |
+ * | -11 |
+ * | -9 |
+ * | -7 |
+ * | -5 |
+ * | -3 |
+ * | -1 |
+ * +-------+
+ */
+ test("rangeTable int col less than 0") {
+ val sql = sqlContext
+ import sql.implicits._
+ val df = withCatalog(catalog)
+ val s = df.filter($"intCol4" < 0).select($"intCol4")
+ s.show
+ // filter results without going through dataframe
+ val expected = rawResult.filter(_.intCol4 < 0).map(_.intCol4).toSet
+ // filter results going through dataframe
+ val result = collectToSet[Int](s)
+ assert(expected === result)
+ }
+
+ /**
+ * expected result: only showing top 20 rows
+ * +-----------+
+ * | doubleCol2|
+ * +-----------+
+ * | 0.0 |
+ * | 2.0 |
+ * |-31.0 |
+ * |-29.0 |
+ * |-27.0 |
+ * |-25.0 |
+ * |-23.0 |
+ * |-21.0 |
+ * |-19.0 |
+ * |-17.0 |
+ * |-15.0 |
+ * |-13.0 |
+ * |-11.0 |
+ * | -9.0 |
+ * | -7.0 |
+ * | -5.0 |
+ * | -3.0 |
+ * | -1.0 |
+ * +-----------+
+ */
+ test("rangeTable double col less than 0") {
+ val sql = sqlContext
+ import sql.implicits._
+ val df = withCatalog(catalog)
+ val s = df.filter($"doubleCol2" < 3.0).select($"doubleCol2")
+ s.show
+ // filter results without going through dataframe
+ val expected = rawResult.filter(_.doubleCol2 < 3.0).map(_.doubleCol2).toSet
+ // filter results going through dataframe
+ val result = collectToSet[Double](s)
+ assert(expected === result)
+ }
+
+ /**
+ * expected result: only showing top 20 rows
+ * +-------+
+ * |intCol0|
+ * +-------+
+ * | -31 |
+ * | -29 |
+ * | -27 |
+ * | -25 |
+ * | -23 |
+ * | -21 |
+ * | -19 |
+ * | -17 |
+ * | -15 |
+ * | -13 |
+ * | -11 |
+ * +-------+
+ */
+ test("rangeTable lessequal than -10") {
+ val sql = sqlContext
+ import sql.implicits._
+ val df = withCatalog(catalog)
+ val s = df.filter($"intCol0" <= -10).select($"intCol0")
+ s.show
+ // filter results without going through dataframe
+ val expected = rawResult.filter(_.intCol0 <= -10).map(_.intCol0).toSet
+ // filter results going through dataframe
+ val result = collectToSet[Int](s)
+ assert(expected === result)
+ }
+
+ /**
+ * expected result: only showing top 20 rows
+ * +-------+
+ * |intCol0|
+ * +-------+
+ * | -31 |
+ * | -29 |
+ * | -27 |
+ * | -25 |
+ * | -23 |
+ * | -21 |
+ * | -19 |
+ * | -17 |
+ * | -15 |
+ * | -13 |
+ * | -11 |
+ * | -9 |
+ * +-------+
+ */
+ test("rangeTable lessequal than -9") {
+ val sql = sqlContext
+ import sql.implicits._
+ val df = withCatalog(catalog)
+ val s = df.filter($"intCol0" <= -9).select($"intCol0")
+ s.show
+ // filter results without going through dataframe
+ val expected = rawResult.filter(_.intCol0 <= -9).map(_.intCol0).toSet
+ // filter results going through dataframe
+ val result = collectToSet[Int](s)
+ assert(expected === result)
+ }
+
+ /**
+ * expected result: only showing top 20 rows
+ * +-------+
+ * |intCol0|
+ * +-------+
+ * | 0 |
+ * | 2 |
+ * | 4 |
+ * | 6 |
+ * | 8 |
+ * | 10 |
+ * | 12 |
+ * | 14 |
+ * | 16 |
+ * | 18 |
+ * | 20 |
+ * | 22 |
+ * | 24 |
+ * | 26 |
+ * | 28 |
+ * | 30 |
+ * | -9 |
+ * | -7 |
+ * | -5 |
+ * | -3 |
+ * +-------+
+ */
+ test("rangeTable greaterequal than -9") {
+ val sql = sqlContext
+ import sql.implicits._
+ val df = withCatalog(catalog)
+ val s = df.filter($"intCol0" >= -9).select($"intCol0")
+ s.show
+ // filter results without going through dataframe
+ val expected = rawResult.filter(_.intCol0 >= -9).map(_.intCol0).toSet
+ // filter results going through dataframe
+ val result = collectToSet[Int](s)
+ assert(expected === result)
+ }
+
+ /**
+ * expected result: only showing top 20 rows
+ * +-------+
+ * |intCol0|
+ * +-------+
+ * | 0 |
+ * | 2 |
+ * | 4 |
+ * | 6 |
+ * | 8 |
+ * | 10 |
+ * | 12 |
+ * | 14 |
+ * | 16 |
+ * | 18 |
+ * | 20 |
+ * | 22 |
+ * | 24 |
+ * | 26 |
+ * | 28 |
+ * | 30 |
+ * +-------+
+ */
+ test("rangeTable greaterequal than 0") {
+ val sql = sqlContext
+ import sql.implicits._
+ val df = withCatalog(catalog)
+ val s = df.filter($"intCol0" >= 0).select($"intCol0")
+ s.show
+ // filter results without going through dataframe
+ val expected = rawResult.filter(_.intCol0 >= 0).map(_.intCol0).toSet
+ // filter results going through dataframe
+ val result = collectToSet[Int](s)
+ assert(expected === result)
+ }
+
+ /**
+ * expected result: only showing top 20 rows
+ * +-------+
+ * |intCol0|
+ * +-------+
+ * | 12 |
+ * | 14 |
+ * | 16 |
+ * | 18 |
+ * | 20 |
+ * | 22 |
+ * | 24 |
+ * | 26 |
+ * | 28 |
+ * | 30 |
+ * +-------+
+ */
+ test("rangeTable greater than 10") {
+ val sql = sqlContext
+ import sql.implicits._
+ val df = withCatalog(catalog)
+ val s = df.filter($"intCol0" > 10).select($"intCol0")
+ s.show
+ // filter results without going through dataframe
+ val expected = rawResult.filter(_.intCol0 > 10).map(_.intCol0).toSet
+ // filter results going through dataframe
+ val result = collectToSet[Int](s)
+ assert(expected === result)
+ }
+
+ /**
+ * expected result: only showing top 20 rows
+ * +-------+
+ * |intCol0|
+ * +-------+
+ * | 0 |
+ * | 2 |
+ * | 4 |
+ * | 6 |
+ * | 8 |
+ * | 10 |
+ * | -9 |
+ * | -7 |
+ * | -5 |
+ * | -3 |
+ * | -1 |
+ * +-------+
+ */
+ test("rangeTable and") {
+ val sql = sqlContext
+ import sql.implicits._
+ val df = withCatalog(catalog)
+ val s = df.filter($"intCol0" > -10 && $"intCol0" <= 10).select($"intCol0")
+ s.show
+ // filter results without going through dataframe
+ val expected = rawResult
+ .filter(x => x.intCol0 > -10 && x.intCol0 <= 10)
+ .map(_.intCol0)
+ .toSet
+ // filter results going through dataframe
+ val result = collectToSet[Int](s)
+ assert(expected === result)
+ }
+
+ /**
+ * expected result: only showing top 20 rows
+ * +-------+
+ * |intCol0|
+ * +-------+
+ * | 12 |
+ * | 14 |
+ * | 16 |
+ * | 18 |
+ * | 20 |
+ * | 22 |
+ * | 24 |
+ * | 26 |
+ * | 28 |
+ * | 30 |
+ * | -31 |
+ * | -29 |
+ * | -27 |
+ * | -25 |
+ * | -23 |
+ * | -21 |
+ * | -19 |
+ * | -17 |
+ * | -15 |
+ * | -13 |
+ * +-------+
+ */
+
+ test("or") {
+ val sql = sqlContext
+ import sql.implicits._
+ val df = withCatalog(catalog)
+ val s = df.filter($"intCol0" <= -10 || $"intCol0" > 10).select($"intCol0")
+ s.show
+ // filter results without going through dataframe
+ val expected = rawResult
+ .filter(x => x.intCol0 <= -10 || x.intCol0 > 10)
+ .map(_.intCol0)
+ .toSet
+ // filter results going through dataframe
+ val result = collectToSet[Int](s)
+ assert(expected === result)
+ }
+
+ /**
+ * expected result: only showing top 20 rows
+ * +-------+
+ * |intCol0|
+ * +-------+
+ * | 0 |
+ * | 2 |
+ * | 4 |
+ * | 6 |
+ * | 8 |
+ * | 10 |
+ * | 12 |
+ * | 14 |
+ * | 16 |
+ * | 18 |
+ * | 20 |
+ * | 22 |
+ * | 24 |
+ * | 26 |
+ * | 28 |
+ * | 30 |
+ * | -31 |
+ * | -29 |
+ * | -27 |
+ * | -25 |
+ * +-------+
+ */
+ test("rangeTable all") {
+ val sql = sqlContext
+ import sql.implicits._
+ val df = withCatalog(catalog)
+ val s = df.filter($"intCol0" >= -100).select($"intCol0")
+ s.show
+ // filter results without going through dataframe
+ val expected = rawResult.filter(_.intCol0 >= -100).map(_.intCol0).toSet
+ // filter results going through dataframe
+ val result = collectToSet[Int](s)
+ assert(expected === result)
+ }
+}
diff --git a/spark4/hbase-spark4/src/test/scala/org/apache/hadoop/hbase/spark/StartsWithSuite.scala b/spark4/hbase-spark4/src/test/scala/org/apache/hadoop/hbase/spark/StartsWithSuite.scala
new file mode 100644
index 00000000..9a0d5675
--- /dev/null
+++ b/spark4/hbase-spark4/src/test/scala/org/apache/hadoop/hbase/spark/StartsWithSuite.scala
@@ -0,0 +1,101 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.hadoop.hbase.spark
+
+import org.apache.hadoop.hbase.spark.datasources.Utils
+import org.apache.hadoop.hbase.util.Bytes
+import org.scalatest.funsuite.AnyFunSuite
+import org.scalatest.{BeforeAndAfterAll, BeforeAndAfterEach}
+
+class StartsWithSuite
+ extends AnyFunSuite
+ with BeforeAndAfterEach
+ with BeforeAndAfterAll
+ with Logging {
+
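+ // Utils.incrementByteArray treats the array as an unsigned big-endian number and adds one (carrying overflow),
+ // giving the exclusive upper bound of a StartsWith row-key range; it returns null when every byte is 0xFF.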
+ test("simple1") {
+ val t = new Array[Byte](2)
+ t(0) = 1.toByte
+ t(1) = 2.toByte
+
+ val expected = new Array[Byte](2)
+ expected(0) = 1.toByte
+ expected(1) = 3.toByte
+
+ val res = Utils.incrementByteArray(t)
+ assert(res.sameElements(expected))
+ }
+
+ test("simple2") {
+ val t = new Array[Byte](1)
+ t(0) = 87.toByte
+
+ val expected = new Array[Byte](1)
+ expected(0) = 88.toByte
+
+ val res = Utils.incrementByteArray(t)
+ assert(res.sameElements(expected))
+ }
+
+ test("overflow1") {
+ val t = new Array[Byte](2)
+ t(0) = 1.toByte
+ t(1) = (-1).toByte
+
+ val expected = new Array[Byte](2)
+ expected(0) = 2.toByte
+ expected(1) = 0.toByte
+
+ val res = Utils.incrementByteArray(t)
+
+ assert(res.sameElements(expected))
+ }
+
+ test("overflow2") {
+ val t = new Array[Byte](2)
+ t(0) = (-1).toByte
+ t(1) = (-1).toByte
+
+ val expected = null
+
+ val res = Utils.incrementByteArray(t)
+
+ assert(res == expected)
+ }
+
+ test("max-min-value") {
+ val t = new Array[Byte](2)
+ t(0) = 1.toByte
+ t(1) = (127).toByte
+
+ val expected = new Array[Byte](2)
+ expected(0) = 1.toByte
+ expected(1) = (-128).toByte
+
+ val res = Utils.incrementByteArray(t)
+ assert(res.sameElements(expected))
+ }
+
+ test("complicated") {
+ val imput = "row005"
+ val expectedOutput = "row006"
+
+ val t = Bytes.toBytes(imput)
+ val expected = Bytes.toBytes(expectedOutput)
+
+ val res = Utils.incrementByteArray(t)
+ assert(res.sameElements(expected))
+ }
+
+}
diff --git a/spark4/hbase-spark4/src/test/scala/org/apache/hadoop/hbase/spark/TableOutputFormatSuite.scala b/spark4/hbase-spark4/src/test/scala/org/apache/hadoop/hbase/spark/TableOutputFormatSuite.scala
new file mode 100644
index 00000000..a180dc97
--- /dev/null
+++ b/spark4/hbase-spark4/src/test/scala/org/apache/hadoop/hbase/spark/TableOutputFormatSuite.scala
@@ -0,0 +1,133 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.hadoop.hbase.spark
+
+import java.text.SimpleDateFormat
+import java.util.{Date, Locale}
+import org.apache.hadoop.hbase.{HBaseTestingUtility, TableName, TableNotFoundException}
+import org.apache.hadoop.hbase.mapreduce.TableOutputFormat
+import org.apache.hadoop.hbase.util.Bytes
+import org.apache.hadoop.mapreduce.{Job, TaskAttemptID, TaskType}
+import org.apache.hadoop.mapreduce.task.TaskAttemptContextImpl
+import org.apache.spark.{SparkConf, SparkContext}
+import org.scalatest.funsuite.AnyFunSuite
+import org.scalatest.{BeforeAndAfterAll, BeforeAndAfterEach}
+
+import scala.util.{Failure, Success, Try}
+
+// Unit tests for HBASE-20521: TableOutputFormat now consults the configuration returned by getConf before the one in the jobContext.
+// The suite contains two tests. In the normal case, a fresh TableOutputFormat is used without initializing TableOutputFormat.conf,
+// so getConf returns null and checkOutputSpecs falls back to the configuration carried by the jobContext.
+// In the second case, setConf is called explicitly to initialize TableOutputFormat.conf; to make the difference observable,
+// that configuration names a nonexistent table, so checkOutputSpecs throws TableNotFoundException.
+class TableOutputFormatSuite
+ extends AnyFunSuite
+ with BeforeAndAfterEach
+ with BeforeAndAfterAll
+ with Logging {
+ @transient var sc: SparkContext = null
+ var TEST_UTIL = new HBaseTestingUtility
+
+ val tableName = "TableOutputFormatTest"
+ val tableNameTest = "NonExistentTable"
+ val columnFamily = "cf"
+
+ override protected def beforeAll(): Unit = {
+ TEST_UTIL.startMiniCluster
+
+ logInfo(" - minicluster started")
+ try {
+ TEST_UTIL.deleteTable(TableName.valueOf(tableName))
+ } catch {
+ case e: Exception => logInfo(" - no table " + tableName + " found")
+ }
+
+ TEST_UTIL.createTable(TableName.valueOf(tableName), Bytes.toBytes(columnFamily))
+ logInfo(" - created table")
+
+ // set "validateOutputSpecs" true anyway, force to validate output spec
+ val sparkConf = new SparkConf()
+ .setMaster("local")
+ .setAppName("test")
+
+ sc = new SparkContext(sparkConf)
+ }
+
+ override protected def afterAll(): Unit = {
+ logInfo(" - delete table: " + tableName)
+ TEST_UTIL.deleteTable(TableName.valueOf(tableName))
+ logInfo(" - shutting down minicluster")
+ TEST_UTIL.shutdownMiniCluster()
+
+ TEST_UTIL.cleanupTestDir()
+ sc.stop()
+ }
+
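+ // Builds a TaskAttemptContext whose job configuration points TableOutputFormat at the existing test table.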
+ def getJobContext() = {
+ val hConf = TEST_UTIL.getConfiguration
+ hConf.set(TableOutputFormat.OUTPUT_TABLE, tableName)
+ val job = Job.getInstance(hConf)
+ job.setOutputFormatClass(classOf[TableOutputFormat[String]])
+
+ val jobTrackerId = new SimpleDateFormat("yyyyMMddHHmmss", Locale.US).format(new Date())
+ val jobAttemptId = new TaskAttemptID(jobTrackerId, 1, TaskType.MAP, 0, 0)
+ new TaskAttemptContextImpl(job.getConfiguration, jobAttemptId)
+ }
+
+ // Mock up jobContext object and execute actions in "write" function
+ // from "org.apache.spark.internal.io.SparkHadoopMapReduceWriter"
+ // this case should run normally without any exceptions
+ test(
+ "TableOutputFormat.checkOutputSpecs test without setConf called, should return true and without exceptions") {
+ val jobContext = getJobContext()
+ val format = jobContext.getOutputFormatClass
+ val jobFormat = format.newInstance
+ Try {
+ jobFormat.checkOutputSpecs(jobContext)
+ } match {
+ case Success(_) => assert(true)
+ case Failure(_) => assert(false)
+ }
+ }
+
+ // Set the configuration externally; checkOutputSpecs should use the configuration object passed via "setConf"
+ // rather than the jobContext, so this case should throw a "TableNotFoundException"
+ test(
+ "TableOutputFormat.checkOutputSpecs test with setConf called, should throw TableNotFoundException") {
+ val jobContext = getJobContext()
+ val format = jobContext.getOutputFormatClass
+ val jobFormat = format.newInstance
+
+ val hConf = TEST_UTIL.getConfiguration
+ hConf.set(TableOutputFormat.OUTPUT_TABLE, tableNameTest)
+ jobFormat.asInstanceOf[TableOutputFormat[String]].setConf(hConf)
+ Try {
+ jobFormat.checkOutputSpecs(jobContext)
+ } match {
+ case Success(_) => assert(false)
+ case Failure(e: Exception) => {
+ if (e.isInstanceOf[TableNotFoundException])
+ assert(true)
+ else
+ assert(false)
+ }
+ case _ => None
+ }
+ }
+
+}
diff --git a/spark4/pom.xml b/spark4/pom.xml
new file mode 100644
index 00000000..f6bfe2b5
--- /dev/null
+++ b/spark4/pom.xml
@@ -0,0 +1,103 @@
+
+
+
+ 4.0.0
+
+
+ org.apache.hbase.connectors
+ hbase-connectors
+ ${revision}
+ ../pom.xml
+
+
+ spark4
+ ${revision}
+ pom
+ Apache HBase - Spark4
+ Spark4 Connectors for Apache HBase
+
+
+ hbase-spark4-protocol
+ hbase-spark4-protocol-shaded
+ hbase-spark4
+ hbase-spark4-it
+
+
+
+ 0.6.1
+ 2.12.5
+ 4.0.0-preview1
+
+ 2.13.14
+ 2.13
+ 3.2.3
+
+
+
+
+
+ org.apache.hbase.connectors.spark
+ hbase-spark4
+ ${revision}
+
+
+ org.apache.hbase.connectors.spark
+ hbase-spark4-protocol
+ ${revision}
+
+
+ org.apache.hbase.connectors.spark
+ hbase-spark4-protocol-shaded
+ ${revision}
+
+
+ org.apache.hbase.thirdparty
+ hbase-shaded-miscellaneous
+ ${hbase-thirdparty.version}
+
+
+
+
+
+
+
+
+ org.xolstice.maven.plugins
+ protobuf-maven-plugin
+ ${protobuf.plugin.version}
+
+ com.google.protobuf:protoc:${external.protobuf.version}:exe:${os.detected.classifier}
+ ${basedir}/src/main/protobuf/
+ false
+ true
+
+
+
+ net.revelc.code
+ warbucks-maven-plugin
+ 1.1.0
+
+
+
+
+
+