diff --git a/pom.xml b/pom.xml
index bdd41e9f..4f4b5860 100644
--- a/pom.xml
+++ b/pom.xml
@@ -106,6 +106,7 @@
     kafka
     spark
+    spark4
     hbase-connectors-assembly
diff --git a/spark4/README.md b/spark4/README.md
new file mode 100644
index 00000000..e4d596f1
--- /dev/null
+++ b/spark4/README.md
@@ -0,0 +1,40 @@
+# Apache HBase™ Spark Connector
+
+## Spark, Scala and Configurable Options
+
+To generate an artifact for a different [Spark version](https://mvnrepository.com/artifact/org.apache.spark/spark-core) and/or [Scala version](https://www.scala-lang.org/download/all.html),
+[Hadoop version](https://mvnrepository.com/artifact/org.apache.hadoop/hadoop-core), or [HBase version](https://mvnrepository.com/artifact/org.apache.hbase/hbase), pass command-line options as follows (changing version numbers appropriately):
+
+```
+$ mvn -Dcheckstyle.skip=true --add-opens java.base/java.util=ALL-UNNAMED --add-opens java.base/java.lang=ALL-UNNAMED --add-opens java.base/sun.net.util=ALL-UNNAMED -Dspark.version=4.0.0-preview1 -Dscala.version=2.13.14 -Dhadoop-three.version=3.2.3 -Dscala.binary.version=2.13 -Dhbase.version=2.4.8 clean install
+```
+
+## Configuration and Installation
+
+**Client-side** (Spark) configuration:
+- The HBase configuration file `hbase-site.xml` must be made available to Spark; it can be copied to `$SPARK_CONF_DIR` (default: `$SPARK_HOME/conf`).
+
+**Server-side** (HBase region servers) configuration:
+- The following jars need to be on the CLASSPATH of the HBase region servers:
+  - scala-library, hbase-spark, and hbase-spark-protocol-shaded.
+- The server-side configuration is needed only for column filter pushdown.
+  - If you cannot perform the server-side configuration, consider using `.option("hbase.spark.pushdown.columnfilter", false)`.
+- The Scala library version must match the Scala version (2.13) used for compiling the connector.
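Once the client-side configuration above is in place, a Spark application can read an HBase table through the connector's DataFrame API. The sketch below is illustrative only and not part of this patch: it assumes the data source name `org.apache.hadoop.hbase.spark` and the `hbase.table`/`hbase.columns.mapping` options used by the existing spark module, and the table name (`person`) and column family (`c`) are made up; `hbase.spark.pushdown.columnfilter` is the option referenced above for disabling server-side filter pushdown.

```
import org.apache.spark.sql.SparkSession

object HBaseSparkReadSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .appName("hbase-spark4-read-sketch")
      .getOrCreate()

    // Map HBase cells to DataFrame columns: the row key becomes "name",
    // and c:email becomes "email". Table and column names are hypothetical.
    val df = spark.read
      .format("org.apache.hadoop.hbase.spark")
      .option("hbase.columns.mapping", "name STRING :key, email STRING c:email")
      .option("hbase.table", "person")
      // Disable column filter pushdown when the region servers do not have
      // scala-library, hbase-spark and hbase-spark-protocol-shaded on their CLASSPATH.
      .option("hbase.spark.pushdown.columnfilter", false)
      .load()

    df.filter("email IS NOT NULL").show()
    spark.stop()
  }
}
```

With pushdown left enabled (the default), the same predicate would instead be evaluated on the region servers by the `SparkSQLPushDownFilter` added in this patch.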
diff --git a/spark4/hbase-spark4-it/pom.xml b/spark4/hbase-spark4-it/pom.xml new file mode 100644 index 00000000..db850594 --- /dev/null +++ b/spark4/hbase-spark4-it/pom.xml @@ -0,0 +1,360 @@ + + + + 4.0.0 + + org.apache.hbase.connectors + spark4 + ${revision} + ../pom.xml + + org.apache.hbase.connectors.spark + hbase-spark4-it + Apache HBase - Spark4 Integration Tests + Integration and System tests for HBase + + + **/Test*.java + **/IntegrationTest*.java + + 4g + + + + + + org.slf4j + slf4j-api + + + + org.apache.hbase + hbase-shaded-testing-util + + + org.apache.hbase + hbase-it + test-jar + + + org.apache.hbase.connectors.spark + hbase-spark4 + ${revision} + + + org.apache.hbase + ${compat.module} + + + org.apache.hbase.thirdparty + hbase-shaded-miscellaneous + + + org.apache.hadoop + hadoop-common + ${hadoop-two.version} + + + com.google.code.findbugs + jsr305 + + + log4j + log4j + + + org.slf4j + slf4j-log4j12 + + + + + org.apache.hadoop + hadoop-common + ${hadoop-two.version} + test-jar + test + + + com.google.code.findbugs + jsr305 + + + log4j + log4j + + + org.slf4j + slf4j-log4j12 + + + + + + com.fasterxml.jackson.module + jackson-module-scala_${scala.binary.version} + ${jackson.version} + test + + + org.apache.spark + spark-core_${scala.binary.version} + ${spark.version} + provided + + + + org.scala-lang + scala-library + + + + org.scala-lang + scalap + + + com.google.code.findbugs + jsr305 + + + org.apache.hadoop + hadoop-client-api + + + org.apache.hadoop + hadoop-client-runtime + + + + + org.scala-lang + scala-library + ${scala.version} + provided + + + org.scala-lang.modules + scala-xml_2.11 + 1.0.4 + provided + + + org.apache.spark + spark-sql_${scala.binary.version} + ${spark.version} + provided + + + org.apache.spark + spark-streaming_${scala.binary.version} + ${spark.version} + provided + + + org.apache.spark + spark-streaming_${scala.binary.version} + ${spark.version} + tests + test-jar + test + + + + junit + junit + test + + + org.mockito + mockito-all + test + + + + + + + + org.apache.maven.plugins + maven-source-plugin + + + + maven-assembly-plugin + + true + + + + org.apache.maven.plugins + maven-failsafe-plugin + ${surefire.version} + + + ${integrationtest.include} + + + ${unittest.include} + **/*$* + + ${test.output.tofile} + false + false + + + + org.apache.maven.surefire + surefire-junit4 + ${surefire.version} + + + + + integration-test + + integration-test + + integration-test + + + verify + + verify + + verify + + + + + + + + + + org.apache.maven.plugins + maven-failsafe-plugin + + false + 1 + true + + 1800 + -enableassertions -Xmx${failsafe.Xmx} + -Djava.security.egd=file:/dev/./urandom -XX:+CMSClassUnloadingEnabled + -verbose:gc -XX:+PrintCommandLineFlags -XX:+PrintFlagsFinal + + + + org.apache.maven.plugins + maven-enforcer-plugin + + + + banned-hbase-spark + + enforce + + + true + + + + banned-scala + + enforce + + + true + + + + + + maven-dependency-plugin + + + create-mrapp-generated-classpath + + build-classpath + + generate-test-resources + + + ${project.build.directory}/test-classes/spark-generated-classpath + + + + + + org.apache.maven.plugins + maven-checkstyle-plugin + + + net.revelc.code + warbucks-maven-plugin + + + + + + + + org.apache.maven.plugins + maven-surefire-report-plugin + 3.0.0-M4 + + + spark-integration-tests + + report-only + + + failsafe-report + + ${project.build.directory}/failsafe-reports + + + + + + + + + + + + skipIntegrationTests + + + skipIntegrationTests + + + + true + + + + + diff --git 
a/spark4/hbase-spark4-it/src/test/java/org/apache/hadoop/hbase/spark/IntegrationTestSparkBulkLoad.java b/spark4/hbase-spark4-it/src/test/java/org/apache/hadoop/hbase/spark/IntegrationTestSparkBulkLoad.java new file mode 100644 index 00000000..07669157 --- /dev/null +++ b/spark4/hbase-spark4-it/src/test/java/org/apache/hadoop/hbase/spark/IntegrationTestSparkBulkLoad.java @@ -0,0 +1,654 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.hadoop.hbase.spark; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.HashMap; +import java.util.Iterator; +import java.util.LinkedList; +import java.util.List; +import java.util.Map; +import java.util.Random; +import java.util.Set; +import org.apache.hadoop.conf.Configuration; +import org.apache.hadoop.fs.Path; +import org.apache.hadoop.hbase.Cell; +import org.apache.hadoop.hbase.CellUtil; +import org.apache.hadoop.hbase.HBaseConfiguration; +import org.apache.hadoop.hbase.HBaseTestingUtility; +import org.apache.hadoop.hbase.HConstants; +import org.apache.hadoop.hbase.HTableDescriptor; +import org.apache.hadoop.hbase.IntegrationTestBase; +import org.apache.hadoop.hbase.IntegrationTestingUtility; +import org.apache.hadoop.hbase.TableName; +import org.apache.hadoop.hbase.client.Admin; +import org.apache.hadoop.hbase.client.Connection; +import org.apache.hadoop.hbase.client.ConnectionFactory; +import org.apache.hadoop.hbase.client.Consistency; +import org.apache.hadoop.hbase.client.RegionLocator; +import org.apache.hadoop.hbase.client.Result; +import org.apache.hadoop.hbase.client.Scan; +import org.apache.hadoop.hbase.client.Table; +import org.apache.hadoop.hbase.io.ImmutableBytesWritable; +import org.apache.hadoop.hbase.mapreduce.IntegrationTestBulkLoad; +import org.apache.hadoop.hbase.tool.LoadIncrementalHFiles; +import org.apache.hadoop.hbase.util.Bytes; +import org.apache.hadoop.hbase.util.EnvironmentEdgeManager; +import org.apache.hadoop.hbase.util.Pair; +import org.apache.hadoop.hbase.util.RegionSplitter; +import org.apache.hadoop.util.StringUtils; +import org.apache.hadoop.util.ToolRunner; +import org.apache.spark.Partitioner; +import org.apache.spark.SerializableWritable; +import org.apache.spark.SparkConf; +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.api.java.function.Function; +import org.apache.spark.api.java.function.Function2; +import org.apache.spark.api.java.function.PairFlatMapFunction; +import org.apache.spark.api.java.function.VoidFunction; +import org.junit.Test; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import scala.Tuple2; + +import org.apache.hbase.thirdparty.com.google.common.collect.Sets; +import 
org.apache.hbase.thirdparty.org.apache.commons.cli.CommandLine; + +/** + * Test Bulk Load and Spark on a distributed cluster. It starts an Spark job that creates linked + * chains. This test mimic {@link IntegrationTestBulkLoad} in mapreduce. Usage on cluster: First add + * hbase related jars and hbase-spark.jar into spark classpath. spark-submit --class + * org.apache.hadoop.hbase.spark.IntegrationTestSparkBulkLoad + * HBASE_HOME/lib/hbase-spark-it-XXX-tests.jar -m slowDeterministic + * -Dhbase.spark.bulkload.chainlength=300 + */ +public class IntegrationTestSparkBulkLoad extends IntegrationTestBase { + + private static final Logger LOG = LoggerFactory.getLogger(IntegrationTestSparkBulkLoad.class); + + // The number of partitions for random generated data + private static String BULKLOAD_PARTITIONS_NUM = "hbase.spark.bulkload.partitionsnum"; + private static int DEFAULT_BULKLOAD_PARTITIONS_NUM = 3; + + private static String BULKLOAD_CHAIN_LENGTH = "hbase.spark.bulkload.chainlength"; + private static int DEFAULT_BULKLOAD_CHAIN_LENGTH = 200000; + + private static String BULKLOAD_IMPORT_ROUNDS = "hbase.spark.bulkload.importround"; + private static int DEFAULT_BULKLOAD_IMPORT_ROUNDS = 1; + + private static String CURRENT_ROUND_NUM = "hbase.spark.bulkload.current.roundnum"; + + private static String NUM_REPLICA_COUNT_KEY = "hbase.spark.bulkload.replica.countkey"; + private static int DEFAULT_NUM_REPLICA_COUNT = 1; + + private static String BULKLOAD_TABLE_NAME = "hbase.spark.bulkload.tableName"; + private static String DEFAULT_BULKLOAD_TABLE_NAME = "IntegrationTestSparkBulkLoad"; + + private static String BULKLOAD_OUTPUT_PATH = "hbase.spark.bulkload.output.path"; + + private static final String OPT_LOAD = "load"; + private static final String OPT_CHECK = "check"; + + private boolean load = false; + private boolean check = false; + + private static final byte[] CHAIN_FAM = Bytes.toBytes("L"); + private static final byte[] SORT_FAM = Bytes.toBytes("S"); + private static final byte[] DATA_FAM = Bytes.toBytes("D"); + + /** + * Running spark job to load data into hbase table + */ + public void runLoad() throws Exception { + setupTable(); + int numImportRounds = getConf().getInt(BULKLOAD_IMPORT_ROUNDS, DEFAULT_BULKLOAD_IMPORT_ROUNDS); + LOG.info("Running load with numIterations:" + numImportRounds); + for (int i = 0; i < numImportRounds; i++) { + runLinkedListSparkJob(i); + } + } + + /** + * Running spark job to create LinkedList for testing + * @param iteration iteration th of this job + * @throws Exception if an HBase operation or getting the test directory fails + */ + public void runLinkedListSparkJob(int iteration) throws Exception { + String jobName = IntegrationTestSparkBulkLoad.class.getSimpleName() + " _load " + + EnvironmentEdgeManager.currentTime(); + + LOG.info("Running iteration " + iteration + "in Spark Job"); + + Path output = null; + if (conf.get(BULKLOAD_OUTPUT_PATH) == null) { + output = util.getDataTestDirOnTestFS(getTablename() + "-" + iteration); + } else { + output = new Path(conf.get(BULKLOAD_OUTPUT_PATH)); + } + + SparkConf sparkConf = new SparkConf().setAppName(jobName).setMaster("local"); + Configuration hbaseConf = new Configuration(getConf()); + hbaseConf.setInt(CURRENT_ROUND_NUM, iteration); + int partitionNum = hbaseConf.getInt(BULKLOAD_PARTITIONS_NUM, DEFAULT_BULKLOAD_PARTITIONS_NUM); + + JavaSparkContext jsc = new JavaSparkContext(sparkConf); + JavaHBaseContext hbaseContext = new JavaHBaseContext(jsc, hbaseConf); + + LOG.info("Partition RDD into " + partitionNum + " 
parts"); + List temp = new ArrayList<>(); + JavaRDD> rdd = jsc.parallelize(temp, partitionNum).mapPartitionsWithIndex( + new LinkedListCreationMapper(new SerializableWritable<>(hbaseConf)), false); + + hbaseContext.bulkLoad(rdd, getTablename(), new ListToKeyValueFunc(), output.toUri().getPath(), + new HashMap<>(), false, HConstants.DEFAULT_MAX_FILE_SIZE); + + try (Connection conn = ConnectionFactory.createConnection(conf); Admin admin = conn.getAdmin(); + Table table = conn.getTable(getTablename()); + RegionLocator regionLocator = conn.getRegionLocator(getTablename())) { + // Create a new loader. + LoadIncrementalHFiles loader = new LoadIncrementalHFiles(conf); + + // Load the HFiles into table. + loader.doBulkLoad(output, admin, table, regionLocator); + } + + // Delete the files. + util.getTestFileSystem().delete(output, true); + jsc.close(); + } + + // See mapreduce.IntegrationTestBulkLoad#LinkedListCreationMapper + // Used to generate test data + public static class LinkedListCreationMapper + implements Function2, Iterator>> { + + SerializableWritable swConfig = null; + private Random rand = new Random(); + + public LinkedListCreationMapper(SerializableWritable conf) { + this.swConfig = conf; + } + + @Override + public Iterator> call(Integer v1, Iterator v2) throws Exception { + Configuration config = (Configuration) swConfig.value(); + int partitionId = v1.intValue(); + LOG.info("Starting create List in Partition " + partitionId); + + int partitionNum = config.getInt(BULKLOAD_PARTITIONS_NUM, DEFAULT_BULKLOAD_PARTITIONS_NUM); + int chainLength = config.getInt(BULKLOAD_CHAIN_LENGTH, DEFAULT_BULKLOAD_CHAIN_LENGTH); + int iterationsNum = config.getInt(BULKLOAD_IMPORT_ROUNDS, DEFAULT_BULKLOAD_IMPORT_ROUNDS); + int iterationsCur = config.getInt(CURRENT_ROUND_NUM, 0); + List> res = new LinkedList<>(); + + long tempId = partitionId + iterationsCur * partitionNum; + long totalPartitionNum = partitionNum * iterationsNum; + long chainId = Math.abs(rand.nextLong()); + chainId = chainId - (chainId % totalPartitionNum) + tempId; + + byte[] chainIdArray = Bytes.toBytes(chainId); + long currentRow = 0; + long nextRow = getNextRow(0, chainLength); + for (long i = 0; i < chainLength; i++) { + byte[] rk = Bytes.toBytes(currentRow); + // Insert record into a list + List tmp1 = Arrays.asList(rk, CHAIN_FAM, chainIdArray, Bytes.toBytes(nextRow)); + List tmp2 = Arrays.asList(rk, SORT_FAM, chainIdArray, Bytes.toBytes(i)); + List tmp3 = Arrays.asList(rk, DATA_FAM, chainIdArray, Bytes.toBytes("random" + i)); + res.add(tmp1); + res.add(tmp2); + res.add(tmp3); + + currentRow = nextRow; + nextRow = getNextRow(i + 1, chainLength); + } + return res.iterator(); + } + + /** Returns a unique row id within this chain for this index */ + private long getNextRow(long index, long chainLength) { + long nextRow = Math.abs(new Random().nextLong()); + // use significant bits from the random number, but pad with index to ensure it is unique + // this also ensures that we do not reuse row = 0 + // row collisions from multiple mappers are fine, since we guarantee unique chainIds + nextRow = nextRow - (nextRow % chainLength) + index; + return nextRow; + } + } + + public static class ListToKeyValueFunc + implements Function, Pair> { + @Override + public Pair call(List v1) throws Exception { + if (v1 == null || v1.size() != 4) { + return null; + } + KeyFamilyQualifier kfq = new KeyFamilyQualifier(v1.get(0), v1.get(1), v1.get(2)); + + return new Pair<>(kfq, v1.get(3)); + } + } + + /** + * After adding data to the table start a mr job 
to check the bulk load. + */ + public void runCheck() throws Exception { + LOG.info("Running check"); + String jobName = IntegrationTestSparkBulkLoad.class.getSimpleName() + "_check" + + EnvironmentEdgeManager.currentTime(); + + SparkConf sparkConf = new SparkConf().setAppName(jobName).setMaster("local"); + Configuration hbaseConf = new Configuration(getConf()); + JavaSparkContext jsc = new JavaSparkContext(sparkConf); + JavaHBaseContext hbaseContext = new JavaHBaseContext(jsc, hbaseConf); + + Scan scan = new Scan(); + scan.addFamily(CHAIN_FAM); + scan.addFamily(SORT_FAM); + scan.setMaxVersions(1); + scan.setCacheBlocks(false); + scan.setBatch(1000); + int replicaCount = conf.getInt(NUM_REPLICA_COUNT_KEY, DEFAULT_NUM_REPLICA_COUNT); + if (replicaCount != DEFAULT_NUM_REPLICA_COUNT) { + scan.setConsistency(Consistency.TIMELINE); + } + + // 1. Using TableInputFormat to get data from HBase table + // 2. Mimic LinkedListCheckingMapper in mapreduce.IntegrationTestBulkLoad + // 3. Sort LinkKey by its order ID + // 4. Group LinkKey if they have same chainId, and repartition RDD by NaturalKeyPartitioner + // 5. Check LinkList in each Partition using LinkedListCheckingFlatMapFunc + hbaseContext.hbaseRDD(getTablename(), scan).flatMapToPair(new LinkedListCheckingFlatMapFunc()) + .sortByKey() + .combineByKey(new createCombinerFunc(), new mergeValueFunc(), new mergeCombinersFunc(), + new NaturalKeyPartitioner(new SerializableWritable<>(hbaseConf))) + .foreach(new LinkedListCheckingForeachFunc(new SerializableWritable<>(hbaseConf))); + jsc.close(); + } + + private void runCheckWithRetry() throws Exception { + try { + runCheck(); + } catch (Throwable t) { + LOG.warn("Received " + StringUtils.stringifyException(t)); + LOG.warn("Running the check MR Job again to see whether an ephemeral problem or not"); + runCheck(); + throw t; // we should still fail the test even if second retry succeeds + } + // everything green + } + + /** + * PairFlatMapFunction used to transfer {@code } to + * {@code Tuple}. + */ + public static class LinkedListCheckingFlatMapFunc implements + PairFlatMapFunction, SparkLinkKey, SparkLinkChain> { + + @Override + public Iterator> + call(Tuple2 v) throws Exception { + Result value = v._2(); + long longRk = Bytes.toLong(value.getRow()); + List> list = new LinkedList<>(); + + for (Map.Entry entry : value.getFamilyMap(CHAIN_FAM).entrySet()) { + long chainId = Bytes.toLong(entry.getKey()); + long next = Bytes.toLong(entry.getValue()); + Cell c = value.getColumnCells(SORT_FAM, entry.getKey()).get(0); + long order = Bytes.toLong(CellUtil.cloneValue(c)); + Tuple2 tuple2 = + new Tuple2<>(new SparkLinkKey(chainId, order), new SparkLinkChain(longRk, next)); + list.add(tuple2); + } + return list.iterator(); + } + } + + public static class createCombinerFunc implements Function> { + @Override + public List call(SparkLinkChain v1) throws Exception { + List list = new LinkedList<>(); + list.add(v1); + return list; + } + } + + public static class mergeValueFunc + implements Function2, SparkLinkChain, List> { + @Override + public List call(List v1, SparkLinkChain v2) throws Exception { + if (v1 == null) { + v1 = new LinkedList<>(); + } + + v1.add(v2); + return v1; + } + } + + public static class mergeCombinersFunc + implements Function2, List, List> { + @Override + public List call(List v1, List v2) + throws Exception { + v1.addAll(v2); + return v1; + } + } + + /** + * Class to figure out what partition to send a link in the chain to. This is based upon the + * linkKey's ChainId. 
+ */ + public static class NaturalKeyPartitioner extends Partitioner { + + private int numPartions = 0; + + public NaturalKeyPartitioner(SerializableWritable swConf) { + Configuration hbaseConf = (Configuration) swConf.value(); + numPartions = hbaseConf.getInt(BULKLOAD_PARTITIONS_NUM, DEFAULT_BULKLOAD_PARTITIONS_NUM); + + } + + @Override + public int numPartitions() { + return numPartions; + } + + @Override + public int getPartition(Object key) { + if (!(key instanceof SparkLinkKey)) { + return -1; + } + + int hash = ((SparkLinkKey) key).getChainId().hashCode(); + return Math.abs(hash % numPartions); + + } + } + + /** + * Sort all LinkChain for one LinkKey, and test {@code List}. + */ + public static class LinkedListCheckingForeachFunc + implements VoidFunction>> { + + private SerializableWritable swConf = null; + + public LinkedListCheckingForeachFunc(SerializableWritable conf) { + swConf = conf; + } + + @Override + public void call(Tuple2> v1) throws Exception { + long next = -1L; + long prev = -1L; + long count = 0L; + + SparkLinkKey key = v1._1(); + List values = v1._2(); + + for (SparkLinkChain lc : values) { + + if (next == -1) { + if (lc.getRk() != 0L) { + String msg = "Chains should all start at rk 0, but read rk " + lc.getRk() + ". Chain:" + + key.getChainId() + ", order:" + key.getOrder(); + throw new RuntimeException(msg); + } + next = lc.getNext(); + } else { + if (next != lc.getRk()) { + String msg = "Missing a link in the chain. Prev rk " + prev + " was, expecting " + next + + " but got " + lc.getRk() + ". Chain:" + key.getChainId() + ", order:" + + key.getOrder(); + throw new RuntimeException(msg); + } + prev = lc.getRk(); + next = lc.getNext(); + } + count++; + } + Configuration hbaseConf = (Configuration) swConf.value(); + int expectedChainLen = hbaseConf.getInt(BULKLOAD_CHAIN_LENGTH, DEFAULT_BULKLOAD_CHAIN_LENGTH); + if (count != expectedChainLen) { + String msg = "Chain wasn't the correct length. Expected " + expectedChainLen + " got " + + count + ". Chain:" + key.getChainId() + ", order:" + key.getOrder(); + throw new RuntimeException(msg); + } + } + } + + /** + * Writable class used as the key to group links in the linked list. Used as the key emited from a + * pass over the table. + */ + public static class SparkLinkKey implements java.io.Serializable, Comparable { + + private Long chainId; + private Long order; + + public Long getOrder() { + return order; + } + + public Long getChainId() { + return chainId; + } + + public SparkLinkKey(long chainId, long order) { + this.chainId = chainId; + this.order = order; + } + + @Override + public int hashCode() { + return this.getChainId().hashCode(); + } + + @Override + public boolean equals(Object other) { + if (!(other instanceof SparkLinkKey)) { + return false; + } + + SparkLinkKey otherKey = (SparkLinkKey) other; + return this.getChainId().equals(otherKey.getChainId()); + } + + @Override + public int compareTo(SparkLinkKey other) { + int res = getChainId().compareTo(other.getChainId()); + + if (res == 0) { + res = getOrder().compareTo(other.getOrder()); + } + + return res; + } + } + + /** + * Writable used as the value emitted from a pass over the hbase table. 
+ */ + public static class SparkLinkChain implements java.io.Serializable, Comparable { + + public Long getNext() { + return next; + } + + public Long getRk() { + return rk; + } + + public SparkLinkChain(Long rk, Long next) { + this.rk = rk; + this.next = next; + } + + private Long rk; + private Long next; + + @Override + public int compareTo(SparkLinkChain linkChain) { + int res = getRk().compareTo(linkChain.getRk()); + if (res == 0) { + res = getNext().compareTo(linkChain.getNext()); + } + return res; + } + + @Override + public int hashCode() { + return getRk().hashCode() ^ getNext().hashCode(); + } + + @Override + public boolean equals(Object other) { + if (!(other instanceof SparkLinkChain)) { + return false; + } + + SparkLinkChain otherKey = (SparkLinkChain) other; + return this.getRk().equals(otherKey.getRk()) && this.getNext().equals(otherKey.getNext()); + } + } + + /** + * Allow the scan to go to replica, this would not affect the runCheck() Since data are BulkLoaded + * from HFile into table + * @throws IOException if an HBase operation fails + * @throws InterruptedException if modifying the table fails + */ + private void installSlowingCoproc() throws IOException, InterruptedException { + int replicaCount = conf.getInt(NUM_REPLICA_COUNT_KEY, DEFAULT_NUM_REPLICA_COUNT); + + if (replicaCount == DEFAULT_NUM_REPLICA_COUNT) { + return; + } + + TableName t = getTablename(); + Admin admin = util.getAdmin(); + HTableDescriptor desc = admin.getTableDescriptor(t); + desc.addCoprocessor(IntegrationTestBulkLoad.SlowMeCoproScanOperations.class.getName()); + HBaseTestingUtility.modifyTableSync(admin, desc); + } + + @Test + public void testBulkLoad() throws Exception { + runLoad(); + installSlowingCoproc(); + runCheckWithRetry(); + } + + private byte[][] getSplits(int numRegions) { + RegionSplitter.UniformSplit split = new RegionSplitter.UniformSplit(); + split.setFirstRow(Bytes.toBytes(0L)); + split.setLastRow(Bytes.toBytes(Long.MAX_VALUE)); + return split.split(numRegions); + } + + private void setupTable() throws IOException, InterruptedException { + if (util.getAdmin().tableExists(getTablename())) { + util.deleteTable(getTablename()); + } + + util.createTable(getTablename(), new byte[][] { CHAIN_FAM, SORT_FAM, DATA_FAM }, getSplits(16)); + + int replicaCount = conf.getInt(NUM_REPLICA_COUNT_KEY, DEFAULT_NUM_REPLICA_COUNT); + + if (replicaCount == DEFAULT_NUM_REPLICA_COUNT) { + return; + } + + TableName t = getTablename(); + HBaseTestingUtility.setReplicas(util.getAdmin(), t, replicaCount); + } + + @Override + public void setUpCluster() throws Exception { + util = getTestingUtil(getConf()); + util.initializeCluster(1); + int replicaCount = getConf().getInt(NUM_REPLICA_COUNT_KEY, DEFAULT_NUM_REPLICA_COUNT); + if (LOG.isDebugEnabled() && replicaCount != DEFAULT_NUM_REPLICA_COUNT) { + LOG.debug("Region Replicas enabled: " + replicaCount); + } + + // Scale this up on a real cluster + if (util.isDistributedCluster()) { + util.getConfiguration().setIfUnset(BULKLOAD_PARTITIONS_NUM, + String.valueOf(DEFAULT_BULKLOAD_PARTITIONS_NUM)); + util.getConfiguration().setIfUnset(BULKLOAD_IMPORT_ROUNDS, "1"); + } else { + util.startMiniMapReduceCluster(); + } + } + + @Override + protected void addOptions() { + super.addOptions(); + super.addOptNoArg(OPT_CHECK, "Run check only"); + super.addOptNoArg(OPT_LOAD, "Run load only"); + } + + @Override + protected void processOptions(CommandLine cmd) { + super.processOptions(cmd); + check = cmd.hasOption(OPT_CHECK); + load = cmd.hasOption(OPT_LOAD); + } + + @Override + 
public int runTestFromCommandLine() throws Exception { + if (load) { + runLoad(); + } else if (check) { + installSlowingCoproc(); + runCheckWithRetry(); + } else { + testBulkLoad(); + } + return 0; + } + + @Override + public TableName getTablename() { + return getTableName(getConf()); + } + + public static TableName getTableName(Configuration conf) { + return TableName.valueOf(conf.get(BULKLOAD_TABLE_NAME, DEFAULT_BULKLOAD_TABLE_NAME)); + } + + @Override + protected Set getColumnFamilies() { + return Sets.newHashSet(Bytes.toString(CHAIN_FAM), Bytes.toString(DATA_FAM), + Bytes.toString(SORT_FAM)); + } + + public static void main(String[] args) throws Exception { + Configuration conf = HBaseConfiguration.create(); + IntegrationTestingUtility.setUseDistributedCluster(conf); + int status = ToolRunner.run(conf, new IntegrationTestSparkBulkLoad(), args); + System.exit(status); + } +} diff --git a/spark4/hbase-spark4-it/src/test/resources/hbase-site.xml b/spark4/hbase-spark4-it/src/test/resources/hbase-site.xml new file mode 100644 index 00000000..99d2ab8d --- /dev/null +++ b/spark4/hbase-spark4-it/src/test/resources/hbase-site.xml @@ -0,0 +1,32 @@ + + + + + + hbase.defaults.for.version.skip + true + + + hbase.hconnection.threads.keepalivetime + 3 + + diff --git a/spark4/hbase-spark4-protocol-shaded/pom.xml b/spark4/hbase-spark4-protocol-shaded/pom.xml new file mode 100644 index 00000000..de113ecd --- /dev/null +++ b/spark4/hbase-spark4-protocol-shaded/pom.xml @@ -0,0 +1,89 @@ + + + + 4.0.0 + + + org.apache.hbase.connectors + spark4 + ${revision} + ../pom.xml + + + org.apache.hbase.connectors.spark + hbase-spark4-protocol-shaded + Apache HBase - Spark4 Protocol (Shaded) + + + + + org.apache.hbase.connectors.spark + hbase-spark4-protocol + true + + + org.apache.hbase.thirdparty + hbase-shaded-protobuf + ${hbase-thirdparty.version} + + + + + + + org.apache.maven.plugins + maven-shade-plugin + + + + shade + + package + + true + true + + false + + + com.google.protobuf + org.apache.hbase.thirdparty.com.google.protobuf + + + + + com.google.protobuf:protobuf-java + org.apache.hbase.thirdparty:* + + + + + + + + + + diff --git a/spark4/hbase-spark4-protocol/pom.xml b/spark4/hbase-spark4-protocol/pom.xml new file mode 100644 index 00000000..77a8d0d6 --- /dev/null +++ b/spark4/hbase-spark4-protocol/pom.xml @@ -0,0 +1,76 @@ + + + + 4.0.0 + + + org.apache.hbase.connectors + spark4 + ${revision} + ../ + + + org.apache.hbase.connectors.spark + hbase-spark4-protocol + Apache HBase - Spark4 Protocol + + + + com.google.protobuf + protobuf-java + + + + + + + org.apache.maven.plugins + maven-compiler-plugin + + + org.xolstice.maven.plugins + protobuf-maven-plugin + + + compile-protoc + + compile + + generate-sources + + + + + org.apache.maven.plugins + maven-source-plugin + + + attach-sources + + jar + + + + + + + + diff --git a/spark4/hbase-spark4-protocol/src/main/protobuf/SparkFilter.proto b/spark4/hbase-spark4-protocol/src/main/protobuf/SparkFilter.proto new file mode 100644 index 00000000..e16c5517 --- /dev/null +++ b/spark4/hbase-spark4-protocol/src/main/protobuf/SparkFilter.proto @@ -0,0 +1,40 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// This file contains protocol buffers that are used for Spark filters +// over in the hbase-spark module +package hbase.pb; + +option java_package = "org.apache.hadoop.hbase.spark.protobuf.generated"; +option java_outer_classname = "SparkFilterProtos"; +option java_generic_services = true; +option java_generate_equals_and_hash = true; +option optimize_for = SPEED; + +message SQLPredicatePushDownCellToColumnMapping { + required bytes column_family = 1; + required bytes qualifier = 2; + required string column_name = 3; +} + +message SQLPredicatePushDownFilter { + required string dynamic_logic_expression = 1; + repeated bytes value_from_query_array = 2; + repeated SQLPredicatePushDownCellToColumnMapping cell_to_column_mapping = 3; + optional string encoderClassName = 4; +} diff --git a/spark4/hbase-spark4/README.md b/spark4/hbase-spark4/README.md new file mode 100644 index 00000000..315fa640 --- /dev/null +++ b/spark4/hbase-spark4/README.md @@ -0,0 +1,24 @@ + + +##ON PROTOBUFS +This maven module has core protobuf definition files ('.protos') used by hbase +Spark that ship with hbase core including tests. + +Generation of java files from protobuf .proto files included here is done as +part of the build. diff --git a/spark4/hbase-spark4/pom.xml b/spark4/hbase-spark4/pom.xml new file mode 100644 index 00000000..7edafe56 --- /dev/null +++ b/spark4/hbase-spark4/pom.xml @@ -0,0 +1,621 @@ + + + + 4.0.0 + + + org.apache.hbase.connectors + spark4 + ${revision} + ../pom.xml + + + org.apache.hbase.connectors.spark + hbase-spark4 + Apache HBase - Spark4 Connector + + + + org.slf4j + slf4j-api + + + org.apache.hbase.thirdparty + hbase-shaded-miscellaneous + + + + javax.servlet + javax.servlet-api + 3.1.0 + test + + + + org.scala-lang + scala-library + ${scala.version} + provided + + + org.apache.spark + spark-core_${scala.binary.version} + ${spark.version} + provided + + + + org.scala-lang + scala-library + + + + org.scala-lang + scalap + + + com.google.code.findbugs + jsr305 + + + + org.xerial.snappy + snappy-java + + + xerces + xercesImpl + + + org.apache.hadoop + hadoop-client-api + + + org.apache.hadoop + hadoop-client-runtime + + + + + org.apache.spark + spark-sql_${scala.binary.version} + ${spark.version} + provided + + + org.apache.spark + spark-streaming_${scala.binary.version} + ${spark.version} + provided + + + org.apache.spark + spark-streaming_${scala.binary.version} + ${spark.version} + tests + test-jar + test + + + junit + junit + test + + + org.mockito + mockito-all + test + + + org.scalatest + scalatest_${scala.binary.version} + 3.2.19 + test + + + org.scala-lang + scala-library + + + + + org.apache.hbase + hbase-shaded-client + + + org.apache.hbase.connectors.spark + hbase-spark4-protocol-shaded + + + org.apache.yetus + audience-annotations + + + org.apache.hbase + hbase-shaded-testing-util + + + org.apache.hbase + hbase-annotations + test-jar + test + + + org.apache.hbase + hbase-shaded-mapreduce + ${hbase.version} + + + org.apache.avro + avro + + + org.scala-lang + scala-reflect + ${scala.version} + + + org.apache.spark + spark-unsafe_${scala.binary.version} + 
${spark.version} + + + org.apache.spark + spark-catalyst_${scala.binary.version} + ${spark.version} + + + + + + + org.apache.maven.plugins + maven-compiler-plugin + + + + org.codehaus.mojo + build-helper-maven-plugin + + + add-source + + add-source + + validate + + + src/main/scala + + + + + add-test-source + + add-test-source + + validate + + + src/test/scala + + + + + + + org.apache.maven.plugins + maven-enforcer-plugin + + + + banned-scala + + enforce + + + true + + + + + + org.apache.maven.plugins + maven-checkstyle-plugin + + + net.revelc.code + warbucks-maven-plugin + + + + true + + + + (?!.*(.generated.|.tmpl.|\$|org.apache.hadoop.hbase.spark.hbase.package)).* + false + true + false + false + false + org[.]apache[.]yetus[.]audience[.]InterfaceAudience.* + + + + + + + + + + + skipSparkTests + + + skipSparkTests + + + + true + true + true + + + + + hadoop-2.0 + + + hadoop.profile + 2.0 + + + + + org.apache.hadoop + hadoop-client + ${hadoop-two.version} + + + org.apache.hadoop + hadoop-common + ${hadoop-two.version} + + + com.google.code.findbugs + jsr305 + + + + + org.apache.hadoop + hadoop-common + ${hadoop-two.version} + test-jar + test + + + com.google.code.findbugs + jsr305 + + + + + org.apache.hadoop + hadoop-hdfs + ${hadoop-two.version} + test-jar + test + + + com.google.code.findbugs + jsr305 + + + xerces + xercesImpl + + + + + org.apache.hadoop + hadoop-minikdc + ${hadoop-two.version} + test + + + org.apache.directory.jdbm + apacheds-jdbm1 + + + + + + + + hadoop-3.0 + + + !hadoop.profile + + + + ${hadoop-three.version} + + + + org.apache.hadoop + hadoop-client + ${hadoop-three.version} + + + org.apache.hadoop + hadoop-common + ${hadoop-three.version} + + + com.google.code.findbugs + jsr305 + + + + + org.apache.hadoop + hadoop-common + ${hadoop-three.version} + test-jar + test + + + com.google.code.findbugs + jsr305 + + + log4j + log4j + + + org.slf4j + slf4j-log4j12 + + + + + org.apache.hadoop + hadoop-hdfs + ${hadoop-three.version} + test-jar + test + + + com.google.code.findbugs + jsr305 + + + + + org.apache.hadoop + hadoop-minikdc + ${hadoop-three.version} + test + + + + + + + build-scala-sources + + + scala.skip + !true + + + + + + + org.codehaus.gmaven + gmaven-plugin + 1.5 + + + + execute + + validate + + + + + + + + net.alchim31.maven + scala-maven-plugin + 4.9.2 + + ${project.build.sourceEncoding} + ${scala.version} + + -feature + + ${target.jvm} + + ${compileSource} + ${compileSource} + + + + scala-compile-first + + add-source + compile + + process-resources + + + scala-test-compile + + testCompile + + process-test-resources + + + + + org.scalatest + scalatest-maven-plugin + 2.0.2 + + ${project.build.directory}/surefire-reports + . 
+ WDF TestSuite.txt + false + + + + test + + test + + test + + -Xmx1536m -XX:ReservedCodeCacheSize=512m + false + + + + + + + + + + coverage + + false + + + src/main/ + ${project.parent.parent.basedir} + + + + + net.alchim31.maven + scala-maven-plugin + + src/test/scala + src/main/scala + + + + default-sbt-compile + + compile + testCompile + + + + + + org.apache.maven.plugins + maven-compiler-plugin + + true + true + + + + org.scoverage + scoverage-maven-plugin + + 1.4.11 + true + true + false + ${project.build.sourceEncoding} + ${scoverageReportDir} + ${scoverageReportDir} + + + + instrument + + pre-compile + post-compile + + + + package + + package + report + + + + scoverage-report + + + report-only + + prepare-package + + + + + + + + + org.scoverage + scoverage-maven-plugin + + + + report-only + + + + + + + + + + diff --git a/spark4/hbase-spark4/src/main/java/org/apache/hadoop/hbase/spark/SparkSQLPushDownFilter.java b/spark4/hbase-spark4/src/main/java/org/apache/hadoop/hbase/spark/SparkSQLPushDownFilter.java new file mode 100644 index 00000000..f965e602 --- /dev/null +++ b/spark4/hbase-spark4/src/main/java/org/apache/hadoop/hbase/spark/SparkSQLPushDownFilter.java @@ -0,0 +1,279 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.hadoop.hbase.spark; + +import java.io.IOException; +import java.util.Arrays; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Objects; +import org.apache.hadoop.hbase.Cell; +import org.apache.hadoop.hbase.exceptions.DeserializationException; +import org.apache.hadoop.hbase.filter.Filter.ReturnCode; +import org.apache.hadoop.hbase.filter.FilterBase; +import org.apache.hadoop.hbase.spark.datasources.BytesEncoder; +import org.apache.hadoop.hbase.spark.datasources.Field; +import org.apache.hadoop.hbase.spark.datasources.JavaBytesEncoder; +import org.apache.hadoop.hbase.spark.protobuf.generated.SparkFilterProtos; +import org.apache.hadoop.hbase.util.Bytes; +import org.apache.yetus.audience.InterfaceAudience; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import scala.collection.mutable.ListBuffer ; + +import org.apache.hbase.thirdparty.com.google.protobuf.ByteString; +import org.apache.hbase.thirdparty.com.google.protobuf.InvalidProtocolBufferException; + +/** + * This filter will push down all qualifier logic given to us by SparkSQL so that we have make the + * filters at the region server level and avoid sending the data back to the client to be filtered. 
+ */ +@InterfaceAudience.Private +public class SparkSQLPushDownFilter extends FilterBase { + protected static final Logger log = LoggerFactory.getLogger(SparkSQLPushDownFilter.class); + + // The following values are populated with protobuffer + DynamicLogicExpression dynamicLogicExpression; + byte[][] valueFromQueryArray; + HashMap> currentCellToColumnIndexMap; + + // The following values are transient + HashMap columnToCurrentRowValueMap = null; + + static final byte[] rowKeyFamily = new byte[0]; + static final byte[] rowKeyQualifier = Bytes.toBytes("key"); + + String encoderClassName; + + public SparkSQLPushDownFilter(DynamicLogicExpression dynamicLogicExpression, + byte[][] valueFromQueryArray, + HashMap> currentCellToColumnIndexMap, + String encoderClassName) { + this.dynamicLogicExpression = dynamicLogicExpression; + this.valueFromQueryArray = valueFromQueryArray; + this.currentCellToColumnIndexMap = currentCellToColumnIndexMap; + this.encoderClassName = encoderClassName; + } + + public SparkSQLPushDownFilter(DynamicLogicExpression dynamicLogicExpression, + byte[][] valueFromQueryArray, ListBuffer fields, String encoderClassName) { + this.dynamicLogicExpression = dynamicLogicExpression; + this.valueFromQueryArray = valueFromQueryArray; + this.encoderClassName = encoderClassName; + + // generate family qualifier to index mapping + this.currentCellToColumnIndexMap = new HashMap<>(); + + for (int i = 0; i < fields.size(); i++) { + Field field = fields.apply(i); + + byte[] cfBytes = field.cfBytes(); + ByteArrayComparable familyByteComparable = + new ByteArrayComparable(cfBytes, 0, cfBytes.length); + + HashMap qualifierIndexMap = + currentCellToColumnIndexMap.get(familyByteComparable); + + if (qualifierIndexMap == null) { + qualifierIndexMap = new HashMap<>(); + currentCellToColumnIndexMap.put(familyByteComparable, qualifierIndexMap); + } + byte[] qBytes = field.colBytes(); + ByteArrayComparable qualifierByteComparable = + new ByteArrayComparable(qBytes, 0, qBytes.length); + + qualifierIndexMap.put(qualifierByteComparable, field.colName()); + } + } + + @Override + public ReturnCode filterCell(final Cell c) throws IOException { + + // If the map RowValueMap is empty then we need to populate + // the row key + if (columnToCurrentRowValueMap == null) { + columnToCurrentRowValueMap = new HashMap<>(); + HashMap qualifierColumnMap = currentCellToColumnIndexMap + .get(new ByteArrayComparable(rowKeyFamily, 0, rowKeyFamily.length)); + + if (qualifierColumnMap != null) { + String rowKeyColumnName = qualifierColumnMap + .get(new ByteArrayComparable(rowKeyQualifier, 0, rowKeyQualifier.length)); + // Make sure that the rowKey is part of the where clause + if (rowKeyColumnName != null) { + columnToCurrentRowValueMap.put(rowKeyColumnName, + new ByteArrayComparable(c.getRowArray(), c.getRowOffset(), c.getRowLength())); + } + } + } + + // Always populate the column value into the RowValueMap + ByteArrayComparable currentFamilyByteComparable = + new ByteArrayComparable(c.getFamilyArray(), c.getFamilyOffset(), c.getFamilyLength()); + + HashMap qualifierColumnMap = + currentCellToColumnIndexMap.get(currentFamilyByteComparable); + + if (qualifierColumnMap != null) { + + String columnName = qualifierColumnMap.get(new ByteArrayComparable(c.getQualifierArray(), + c.getQualifierOffset(), c.getQualifierLength())); + + if (columnName != null) { + columnToCurrentRowValueMap.put(columnName, + new ByteArrayComparable(c.getValueArray(), c.getValueOffset(), c.getValueLength())); + } + } + + return ReturnCode.INCLUDE; + } 
+ + @Override + public boolean filterRow() throws IOException { + + try { + boolean result = + dynamicLogicExpression.execute(columnToCurrentRowValueMap, valueFromQueryArray); + columnToCurrentRowValueMap = null; + return !result; + } catch (Throwable e) { + log.error("Error running dynamic logic on row", e); + } + return false; + } + + /** + * @param pbBytes A pb serialized instance + * @return An instance of SparkSQLPushDownFilter + * @throws DeserializationException if the filter cannot be parsed from the given bytes + */ + @SuppressWarnings("unused") + public static SparkSQLPushDownFilter parseFrom(final byte[] pbBytes) + throws DeserializationException { + + SparkFilterProtos.SQLPredicatePushDownFilter proto; + try { + proto = SparkFilterProtos.SQLPredicatePushDownFilter.parseFrom(pbBytes); + } catch (InvalidProtocolBufferException e) { + throw new DeserializationException(e); + } + + String encoder = proto.getEncoderClassName(); + BytesEncoder enc = JavaBytesEncoder.create(encoder); + + // Load DynamicLogicExpression + DynamicLogicExpression dynamicLogicExpression = + DynamicLogicExpressionBuilder.build(proto.getDynamicLogicExpression(), enc); + + // Load valuesFromQuery + final List valueFromQueryArrayList = proto.getValueFromQueryArrayList(); + byte[][] valueFromQueryArray = new byte[valueFromQueryArrayList.size()][]; + for (int i = 0; i < valueFromQueryArrayList.size(); i++) { + valueFromQueryArray[i] = valueFromQueryArrayList.get(i).toByteArray(); + } + + // Load mapping from HBase family/qualifier to Spark SQL columnName + HashMap> currentCellToColumnIndexMap = + new HashMap<>(); + + for (SparkFilterProtos.SQLPredicatePushDownCellToColumnMapping sqlPredicatePushDownCellToColumnMapping : proto + .getCellToColumnMappingList()) { + + byte[] familyArray = sqlPredicatePushDownCellToColumnMapping.getColumnFamily().toByteArray(); + ByteArrayComparable familyByteComparable = + new ByteArrayComparable(familyArray, 0, familyArray.length); + HashMap qualifierMap = + currentCellToColumnIndexMap.get(familyByteComparable); + + if (qualifierMap == null) { + qualifierMap = new HashMap<>(); + currentCellToColumnIndexMap.put(familyByteComparable, qualifierMap); + } + byte[] qualifierArray = sqlPredicatePushDownCellToColumnMapping.getQualifier().toByteArray(); + + ByteArrayComparable qualifierByteComparable = + new ByteArrayComparable(qualifierArray, 0, qualifierArray.length); + + qualifierMap.put(qualifierByteComparable, + sqlPredicatePushDownCellToColumnMapping.getColumnName()); + } + + return new SparkSQLPushDownFilter(dynamicLogicExpression, valueFromQueryArray, + currentCellToColumnIndexMap, encoder); + } + + /** Returns The filter serialized using pb */ + public byte[] toByteArray() { + + SparkFilterProtos.SQLPredicatePushDownFilter.Builder builder = + SparkFilterProtos.SQLPredicatePushDownFilter.newBuilder(); + + SparkFilterProtos.SQLPredicatePushDownCellToColumnMapping.Builder columnMappingBuilder = + SparkFilterProtos.SQLPredicatePushDownCellToColumnMapping.newBuilder(); + + builder.setDynamicLogicExpression(dynamicLogicExpression.toExpressionString()); + for (byte[] valueFromQuery : valueFromQueryArray) { + builder.addValueFromQueryArray(ByteString.copyFrom(valueFromQuery)); + } + + for (Map.Entry> familyEntry : currentCellToColumnIndexMap.entrySet()) { + for (Map.Entry qualifierEntry : familyEntry.getValue() + .entrySet()) { + columnMappingBuilder.setColumnFamily(ByteString.copyFrom(familyEntry.getKey().bytes())); + 
columnMappingBuilder.setQualifier(ByteString.copyFrom(qualifierEntry.getKey().bytes())); + columnMappingBuilder.setColumnName(qualifierEntry.getValue()); + builder.addCellToColumnMapping(columnMappingBuilder.build()); + } + } + builder.setEncoderClassName(encoderClassName); + + return builder.build().toByteArray(); + } + + @Override + public boolean equals(Object obj) { + if (!(obj instanceof SparkSQLPushDownFilter)) { + return false; + } + if (this == obj) { + return true; + } + SparkSQLPushDownFilter f = (SparkSQLPushDownFilter) obj; + if (this.valueFromQueryArray.length != f.valueFromQueryArray.length) { + return false; + } + int i = 0; + for (byte[] val : this.valueFromQueryArray) { + if (!Bytes.equals(val, f.valueFromQueryArray[i])) { + return false; + } + i++; + } + return this.dynamicLogicExpression.equals(f.dynamicLogicExpression) + && this.currentCellToColumnIndexMap.equals(f.currentCellToColumnIndexMap) + && this.encoderClassName.equals(f.encoderClassName); + } + + @Override + public int hashCode() { + return Objects.hash(this.dynamicLogicExpression, Arrays.hashCode(this.valueFromQueryArray), + this.currentCellToColumnIndexMap, this.encoderClassName); + } +} diff --git a/spark4/hbase-spark4/src/main/java/org/apache/hadoop/hbase/spark/example/hbasecontext/JavaHBaseBulkDeleteExample.java b/spark4/hbase-spark4/src/main/java/org/apache/hadoop/hbase/spark/example/hbasecontext/JavaHBaseBulkDeleteExample.java new file mode 100644 index 00000000..8be0acc9 --- /dev/null +++ b/spark4/hbase-spark4/src/main/java/org/apache/hadoop/hbase/spark/example/hbasecontext/JavaHBaseBulkDeleteExample.java @@ -0,0 +1,82 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.hadoop.hbase.spark.example.hbasecontext; + +import java.util.ArrayList; +import java.util.List; +import org.apache.hadoop.conf.Configuration; +import org.apache.hadoop.hbase.HBaseConfiguration; +import org.apache.hadoop.hbase.TableName; +import org.apache.hadoop.hbase.client.Delete; +import org.apache.hadoop.hbase.spark.JavaHBaseContext; +import org.apache.hadoop.hbase.util.Bytes; +import org.apache.spark.SparkConf; +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.api.java.function.Function; +import org.apache.yetus.audience.InterfaceAudience; + +/** + * This is a simple example of deleting records in HBase with the bulkDelete function. 
+ */ +@InterfaceAudience.Private +final public class JavaHBaseBulkDeleteExample { + + private JavaHBaseBulkDeleteExample() { + } + + public static void main(String[] args) { + if (args.length < 1) { + System.out.println("JavaHBaseBulkDeleteExample {tableName}"); + return; + } + + String tableName = args[0]; + + SparkConf sparkConf = new SparkConf().setAppName("JavaHBaseBulkDeleteExample " + tableName); + JavaSparkContext jsc = new JavaSparkContext(sparkConf); + + try { + List list = new ArrayList<>(5); + list.add(Bytes.toBytes("1")); + list.add(Bytes.toBytes("2")); + list.add(Bytes.toBytes("3")); + list.add(Bytes.toBytes("4")); + list.add(Bytes.toBytes("5")); + + JavaRDD rdd = jsc.parallelize(list); + + Configuration conf = HBaseConfiguration.create(); + + JavaHBaseContext hbaseContext = new JavaHBaseContext(jsc, conf); + + hbaseContext.bulkDelete(rdd, TableName.valueOf(tableName), new DeleteFunction(), 4); + } finally { + jsc.stop(); + } + + } + + public static class DeleteFunction implements Function { + private static final long serialVersionUID = 1L; + + public Delete call(byte[] v) throws Exception { + return new Delete(v); + } + } +} diff --git a/spark4/hbase-spark4/src/main/java/org/apache/hadoop/hbase/spark/example/hbasecontext/JavaHBaseBulkGetExample.java b/spark4/hbase-spark4/src/main/java/org/apache/hadoop/hbase/spark/example/hbasecontext/JavaHBaseBulkGetExample.java new file mode 100644 index 00000000..8ff21ea4 --- /dev/null +++ b/spark4/hbase-spark4/src/main/java/org/apache/hadoop/hbase/spark/example/hbasecontext/JavaHBaseBulkGetExample.java @@ -0,0 +1,111 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.hadoop.hbase.spark.example.hbasecontext; + +import java.util.ArrayList; +import java.util.Iterator; +import java.util.List; +import org.apache.hadoop.conf.Configuration; +import org.apache.hadoop.hbase.Cell; +import org.apache.hadoop.hbase.HBaseConfiguration; +import org.apache.hadoop.hbase.TableName; +import org.apache.hadoop.hbase.client.Get; +import org.apache.hadoop.hbase.client.Result; +import org.apache.hadoop.hbase.spark.JavaHBaseContext; +import org.apache.hadoop.hbase.util.Bytes; +import org.apache.spark.SparkConf; +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.api.java.function.Function; +import org.apache.yetus.audience.InterfaceAudience; + +/** + * This is a simple example of getting records in HBase with the bulkGet function. 
+ */ +@InterfaceAudience.Private +final public class JavaHBaseBulkGetExample { + + private JavaHBaseBulkGetExample() { + } + + public static void main(String[] args) { + if (args.length < 1) { + System.out.println("JavaHBaseBulkGetExample {tableName}"); + return; + } + + String tableName = args[0]; + + SparkConf sparkConf = new SparkConf().setAppName("JavaHBaseBulkGetExample " + tableName); + JavaSparkContext jsc = new JavaSparkContext(sparkConf); + + try { + List list = new ArrayList<>(5); + list.add(Bytes.toBytes("1")); + list.add(Bytes.toBytes("2")); + list.add(Bytes.toBytes("3")); + list.add(Bytes.toBytes("4")); + list.add(Bytes.toBytes("5")); + + JavaRDD rdd = jsc.parallelize(list); + + Configuration conf = HBaseConfiguration.create(); + + JavaHBaseContext hbaseContext = new JavaHBaseContext(jsc, conf); + + hbaseContext.bulkGet(TableName.valueOf(tableName), 2, rdd, new GetFunction(), + new ResultFunction()); + } finally { + jsc.stop(); + } + } + + public static class GetFunction implements Function { + + private static final long serialVersionUID = 1L; + + public Get call(byte[] v) throws Exception { + return new Get(v); + } + } + + public static class ResultFunction implements Function { + + private static final long serialVersionUID = 1L; + + public String call(Result result) throws Exception { + Iterator it = result.listCells().iterator(); + StringBuilder b = new StringBuilder(); + + b.append(Bytes.toString(result.getRow())).append(":"); + + while (it.hasNext()) { + Cell cell = it.next(); + String q = Bytes.toString(cell.getQualifierArray()); + if (q.equals("counter")) { + b.append("(").append(Bytes.toString(cell.getQualifierArray())).append(",") + .append(Bytes.toLong(cell.getValueArray())).append(")"); + } else { + b.append("(").append(Bytes.toString(cell.getQualifierArray())).append(",") + .append(Bytes.toString(cell.getValueArray())).append(")"); + } + } + return b.toString(); + } + } +} diff --git a/spark4/hbase-spark4/src/main/java/org/apache/hadoop/hbase/spark/example/hbasecontext/JavaHBaseBulkLoadExample.java b/spark4/hbase-spark4/src/main/java/org/apache/hadoop/hbase/spark/example/hbasecontext/JavaHBaseBulkLoadExample.java new file mode 100644 index 00000000..44f1b339 --- /dev/null +++ b/spark4/hbase-spark4/src/main/java/org/apache/hadoop/hbase/spark/example/hbasecontext/JavaHBaseBulkLoadExample.java @@ -0,0 +1,106 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.hadoop.hbase.spark.example.hbasecontext; + +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import org.apache.hadoop.conf.Configuration; +import org.apache.hadoop.hbase.HBaseConfiguration; +import org.apache.hadoop.hbase.HConstants; +import org.apache.hadoop.hbase.TableName; +import org.apache.hadoop.hbase.spark.FamilyHFileWriteOptions; +import org.apache.hadoop.hbase.spark.JavaHBaseContext; +import org.apache.hadoop.hbase.spark.KeyFamilyQualifier; +import org.apache.hadoop.hbase.util.Bytes; +import org.apache.hadoop.hbase.util.Pair; +import org.apache.spark.SparkConf; +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.api.java.function.Function; +import org.apache.yetus.audience.InterfaceAudience; + +/** + * Run this example using command below: SPARK_HOME/bin/spark-submit --master local[2] --class + * org.apache.hadoop.hbase.spark.example.hbasecontext.JavaHBaseBulkLoadExample + * path/to/hbase-spark.jar {path/to/output/HFiles} This example will output put hfiles in + * {path/to/output/HFiles}, and user can run 'hbase + * org.apache.hadoop.hbase.tool.LoadIncrementalHFiles' to load the HFiles into table to verify this + * example. + */ +@InterfaceAudience.Private +final public class JavaHBaseBulkLoadExample { + private JavaHBaseBulkLoadExample() { + } + + public static void main(String[] args) { + if (args.length < 1) { + System.out.println("JavaHBaseBulkLoadExample " + "{outputPath}"); + return; + } + + String tableName = "bulkload-table-test"; + String columnFamily1 = "f1"; + String columnFamily2 = "f2"; + + SparkConf sparkConf = new SparkConf().setAppName("JavaHBaseBulkLoadExample " + tableName); + JavaSparkContext jsc = new JavaSparkContext(sparkConf); + + try { + List list = new ArrayList(); + // row1 + list.add("1," + columnFamily1 + ",b,1"); + // row3 + list.add("3," + columnFamily1 + ",a,2"); + list.add("3," + columnFamily1 + ",b,1"); + list.add("3," + columnFamily2 + ",a,1"); + /* row2 */ + list.add("2," + columnFamily2 + ",a,3"); + list.add("2," + columnFamily2 + ",b,3"); + + JavaRDD rdd = jsc.parallelize(list); + + Configuration conf = HBaseConfiguration.create(); + JavaHBaseContext hbaseContext = new JavaHBaseContext(jsc, conf); + + hbaseContext.bulkLoad(rdd, TableName.valueOf(tableName), new BulkLoadFunction(), args[0], + new HashMap(), false, HConstants.DEFAULT_MAX_FILE_SIZE); + } finally { + jsc.stop(); + } + } + + public static class BulkLoadFunction + implements Function> { + @Override + public Pair call(String v1) throws Exception { + if (v1 == null) { + return null; + } + + String[] strs = v1.split(","); + if (strs.length != 4) { + return null; + } + + KeyFamilyQualifier kfq = new KeyFamilyQualifier(Bytes.toBytes(strs[0]), + Bytes.toBytes(strs[1]), Bytes.toBytes(strs[2])); + return new Pair(kfq, Bytes.toBytes(strs[3])); + } + } +} diff --git a/spark4/hbase-spark4/src/main/java/org/apache/hadoop/hbase/spark/example/hbasecontext/JavaHBaseBulkPutExample.java b/spark4/hbase-spark4/src/main/java/org/apache/hadoop/hbase/spark/example/hbasecontext/JavaHBaseBulkPutExample.java new file mode 100644 index 00000000..5f685a14 --- /dev/null +++ b/spark4/hbase-spark4/src/main/java/org/apache/hadoop/hbase/spark/example/hbasecontext/JavaHBaseBulkPutExample.java @@ -0,0 +1,88 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. 
See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.hadoop.hbase.spark.example.hbasecontext; + +import java.util.ArrayList; +import java.util.List; +import org.apache.hadoop.conf.Configuration; +import org.apache.hadoop.hbase.HBaseConfiguration; +import org.apache.hadoop.hbase.TableName; +import org.apache.hadoop.hbase.client.Put; +import org.apache.hadoop.hbase.spark.JavaHBaseContext; +import org.apache.hadoop.hbase.util.Bytes; +import org.apache.spark.SparkConf; +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.api.java.function.Function; +import org.apache.yetus.audience.InterfaceAudience; + +/** + * This is a simple example of putting records in HBase with the bulkPut function. + */ +@InterfaceAudience.Private +final public class JavaHBaseBulkPutExample { + + private JavaHBaseBulkPutExample() { + } + + public static void main(String[] args) { + if (args.length < 2) { + System.out.println("JavaHBaseBulkPutExample " + "{tableName} {columnFamily}"); + return; + } + + String tableName = args[0]; + String columnFamily = args[1]; + + SparkConf sparkConf = new SparkConf().setAppName("JavaHBaseBulkPutExample " + tableName); + JavaSparkContext jsc = new JavaSparkContext(sparkConf); + + try { + List list = new ArrayList<>(5); + list.add("1," + columnFamily + ",a,1"); + list.add("2," + columnFamily + ",a,2"); + list.add("3," + columnFamily + ",a,3"); + list.add("4," + columnFamily + ",a,4"); + list.add("5," + columnFamily + ",a,5"); + + JavaRDD rdd = jsc.parallelize(list); + + Configuration conf = HBaseConfiguration.create(); + + JavaHBaseContext hbaseContext = new JavaHBaseContext(jsc, conf); + + hbaseContext.bulkPut(rdd, TableName.valueOf(tableName), new PutFunction()); + } finally { + jsc.stop(); + } + } + + public static class PutFunction implements Function { + + private static final long serialVersionUID = 1L; + + public Put call(String v) throws Exception { + String[] cells = v.split(","); + Put put = new Put(Bytes.toBytes(cells[0])); + + put.addColumn(Bytes.toBytes(cells[1]), Bytes.toBytes(cells[2]), Bytes.toBytes(cells[3])); + return put; + } + + } +} diff --git a/spark4/hbase-spark4/src/main/java/org/apache/hadoop/hbase/spark/example/hbasecontext/JavaHBaseDistributedScan.java b/spark4/hbase-spark4/src/main/java/org/apache/hadoop/hbase/spark/example/hbasecontext/JavaHBaseDistributedScan.java new file mode 100644 index 00000000..76c7f6d8 --- /dev/null +++ b/spark4/hbase-spark4/src/main/java/org/apache/hadoop/hbase/spark/example/hbasecontext/JavaHBaseDistributedScan.java @@ -0,0 +1,82 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.hadoop.hbase.spark.example.hbasecontext; + +import java.util.List; +import org.apache.hadoop.conf.Configuration; +import org.apache.hadoop.hbase.HBaseConfiguration; +import org.apache.hadoop.hbase.TableName; +import org.apache.hadoop.hbase.client.Result; +import org.apache.hadoop.hbase.client.Scan; +import org.apache.hadoop.hbase.io.ImmutableBytesWritable; +import org.apache.hadoop.hbase.spark.JavaHBaseContext; +import org.apache.hadoop.hbase.util.Bytes; +import org.apache.spark.SparkConf; +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.api.java.function.Function; +import org.apache.yetus.audience.InterfaceAudience; +import scala.Tuple2; + +/** + * This is a simple example of scanning records from HBase with the hbaseRDD function. + */ +@InterfaceAudience.Private +final public class JavaHBaseDistributedScan { + + private JavaHBaseDistributedScan() { + } + + public static void main(String[] args) { + if (args.length < 1) { + System.out.println("JavaHBaseDistributedScan {tableName}"); + return; + } + + String tableName = args[0]; + + SparkConf sparkConf = new SparkConf().setAppName("JavaHBaseDistributedScan " + tableName); + JavaSparkContext jsc = new JavaSparkContext(sparkConf); + + try { + Configuration conf = HBaseConfiguration.create(); + + JavaHBaseContext hbaseContext = new JavaHBaseContext(jsc, conf); + + Scan scan = new Scan(); + scan.setCaching(100); + + JavaRDD> javaRdd = + hbaseContext.hbaseRDD(TableName.valueOf(tableName), scan); + + List results = javaRdd.map(new ScanConvertFunction()).collect(); + + System.out.println("Result Size: " + results.size()); + } finally { + jsc.stop(); + } + } + + private static class ScanConvertFunction + implements Function, String> { + @Override + public String call(Tuple2 v1) throws Exception { + return Bytes.toString(v1._1().copyBytes()); + } + } +} diff --git a/spark4/hbase-spark4/src/main/java/org/apache/hadoop/hbase/spark/example/hbasecontext/JavaHBaseMapGetPutExample.java b/spark4/hbase-spark4/src/main/java/org/apache/hadoop/hbase/spark/example/hbasecontext/JavaHBaseMapGetPutExample.java new file mode 100644 index 00000000..c516ab35 --- /dev/null +++ b/spark4/hbase-spark4/src/main/java/org/apache/hadoop/hbase/spark/example/hbasecontext/JavaHBaseMapGetPutExample.java @@ -0,0 +1,105 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
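The Scala counterpart of the distributed scan above is `hbaseRDD`, assumed here to yield `(ImmutableBytesWritable, Result)` pairs just like the Java wrapper; a hedged sketch with a placeholder table name:

```scala
import org.apache.hadoop.hbase.{HBaseConfiguration, TableName}
import org.apache.hadoop.hbase.client.Scan
import org.apache.hadoop.hbase.spark.HBaseContext
import org.apache.hadoop.hbase.util.Bytes
import org.apache.spark.{SparkConf, SparkContext}

object ScalaDistributedScanSketch {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(new SparkConf().setAppName("ScalaDistributedScanSketch"))
    try {
      val hbaseContext = new HBaseContext(sc, HBaseConfiguration.create())
      val scan = new Scan()
      scan.setCaching(100) // fetch 100 rows per RPC, matching the Java example
      // "scanTable" is a placeholder; each element is a (row key writable, Result) pair
      val rdd = hbaseContext.hbaseRDD(TableName.valueOf("scanTable"), scan)
      val rowKeys = rdd.map { case (key, _) => Bytes.toString(key.copyBytes()) }
      println("Result Size: " + rowKeys.count())
    } finally {
      sc.stop()
    }
  }
}
```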
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.hadoop.hbase.spark.example.hbasecontext; + +import java.util.ArrayList; +import java.util.Iterator; +import java.util.List; +import org.apache.hadoop.conf.Configuration; +import org.apache.hadoop.hbase.HBaseConfiguration; +import org.apache.hadoop.hbase.TableName; +import org.apache.hadoop.hbase.client.BufferedMutator; +import org.apache.hadoop.hbase.client.Connection; +import org.apache.hadoop.hbase.client.Get; +import org.apache.hadoop.hbase.client.Put; +import org.apache.hadoop.hbase.client.Result; +import org.apache.hadoop.hbase.client.Table; +import org.apache.hadoop.hbase.spark.JavaHBaseContext; +import org.apache.hadoop.hbase.util.Bytes; +import org.apache.spark.SparkConf; +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.api.java.function.Function; +import org.apache.spark.api.java.function.VoidFunction; +import org.apache.yetus.audience.InterfaceAudience; +import scala.Tuple2; + +/** + * This is a simple example of using the foreachPartition method with a HBase connection + */ +@InterfaceAudience.Private +final public class JavaHBaseMapGetPutExample { + + private JavaHBaseMapGetPutExample() { + } + + public static void main(String[] args) { + if (args.length < 1) { + System.out.println("JavaHBaseBulkGetExample {tableName}"); + return; + } + + final String tableName = args[0]; + + SparkConf sparkConf = new SparkConf().setAppName("JavaHBaseBulkGetExample " + tableName); + JavaSparkContext jsc = new JavaSparkContext(sparkConf); + + try { + List list = new ArrayList<>(5); + list.add(Bytes.toBytes("1")); + list.add(Bytes.toBytes("2")); + list.add(Bytes.toBytes("3")); + list.add(Bytes.toBytes("4")); + list.add(Bytes.toBytes("5")); + + JavaRDD rdd = jsc.parallelize(list); + Configuration conf = HBaseConfiguration.create(); + + JavaHBaseContext hbaseContext = new JavaHBaseContext(jsc, conf); + + hbaseContext.foreachPartition(rdd, new VoidFunction, Connection>>() { + public void call(Tuple2, Connection> t) throws Exception { + Table table = t._2().getTable(TableName.valueOf(tableName)); + BufferedMutator mutator = t._2().getBufferedMutator(TableName.valueOf(tableName)); + + while (t._1().hasNext()) { + byte[] b = t._1().next(); + Result r = table.get(new Get(b)); + if (r.getExists()) { + mutator.mutate(new Put(b)); + } + } + + mutator.flush(); + mutator.close(); + table.close(); + } + }); + } finally { + jsc.stop(); + } + } + + public static class GetFunction implements Function { + private static final long serialVersionUID = 1L; + + public Get call(byte[] v) throws Exception { + return new Get(v); + } + } +} diff --git a/spark4/hbase-spark4/src/main/java/org/apache/hadoop/hbase/spark/example/hbasecontext/JavaHBaseStreamingBulkPutExample.java b/spark4/hbase-spark4/src/main/java/org/apache/hadoop/hbase/spark/example/hbasecontext/JavaHBaseStreamingBulkPutExample.java new file mode 100644 index 00000000..dc07a562 --- /dev/null +++ b/spark4/hbase-spark4/src/main/java/org/apache/hadoop/hbase/spark/example/hbasecontext/JavaHBaseStreamingBulkPutExample.java @@ -0,0 +1,87 @@ +/* + * 
Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.hadoop.hbase.spark.example.hbasecontext; + +import org.apache.hadoop.conf.Configuration; +import org.apache.hadoop.hbase.HBaseConfiguration; +import org.apache.hadoop.hbase.TableName; +import org.apache.hadoop.hbase.client.Put; +import org.apache.hadoop.hbase.spark.JavaHBaseContext; +import org.apache.hadoop.hbase.util.Bytes; +import org.apache.spark.SparkConf; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.api.java.function.Function; +import org.apache.spark.streaming.Duration; +import org.apache.spark.streaming.api.java.JavaReceiverInputDStream; +import org.apache.spark.streaming.api.java.JavaStreamingContext; +import org.apache.yetus.audience.InterfaceAudience; + +/** + * This is a simple example of BulkPut with Spark Streaming + */ +@InterfaceAudience.Private +final public class JavaHBaseStreamingBulkPutExample { + + private JavaHBaseStreamingBulkPutExample() { + } + + public static void main(String[] args) { + if (args.length < 4) { + System.out.println("JavaHBaseBulkPutExample " + "{host} {port} {tableName}"); + return; + } + + String host = args[0]; + String port = args[1]; + String tableName = args[2]; + + SparkConf sparkConf = new SparkConf() + .setAppName("JavaHBaseStreamingBulkPutExample " + tableName + ":" + port + ":" + tableName); + + JavaSparkContext jsc = new JavaSparkContext(sparkConf); + + try { + JavaStreamingContext jssc = new JavaStreamingContext(jsc, new Duration(1000)); + + JavaReceiverInputDStream javaDstream = + jssc.socketTextStream(host, Integer.parseInt(port)); + + Configuration conf = HBaseConfiguration.create(); + + JavaHBaseContext hbaseContext = new JavaHBaseContext(jsc, conf); + + hbaseContext.streamBulkPut(javaDstream, TableName.valueOf(tableName), new PutFunction()); + } finally { + jsc.stop(); + } + } + + public static class PutFunction implements Function { + + private static final long serialVersionUID = 1L; + + public Put call(String v) throws Exception { + String[] part = v.split(","); + Put put = new Put(Bytes.toBytes(part[0])); + + put.addColumn(Bytes.toBytes(part[1]), Bytes.toBytes(part[2]), Bytes.toBytes(part[3])); + return put; + } + + } +} diff --git a/spark4/hbase-spark4/src/main/scala/org/apache/hadoop/hbase/spark/BulkLoadPartitioner.scala b/spark4/hbase-spark4/src/main/scala/org/apache/hadoop/hbase/spark/BulkLoadPartitioner.scala new file mode 100644 index 00000000..17316638 --- /dev/null +++ b/spark4/hbase-spark4/src/main/scala/org/apache/hadoop/hbase/spark/BulkLoadPartitioner.scala @@ -0,0 +1,61 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. 
See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.hadoop.hbase.spark + +import java.util +import java.util.Comparator +import org.apache.hadoop.hbase.util.Bytes +import org.apache.spark.Partitioner +import org.apache.yetus.audience.InterfaceAudience + +/** + * A Partitioner implementation that will separate records to different + * HBase Regions based on region splits + * + * @param startKeys The start keys for the given table + */ +@InterfaceAudience.Public +class BulkLoadPartitioner(startKeys: Array[Array[Byte]]) extends Partitioner { + // when table not exist, startKeys = Byte[0][] + override def numPartitions: Int = if (startKeys.length == 0) 1 else startKeys.length + + override def getPartition(key: Any): Int = { + + val comparator: Comparator[Array[Byte]] = new Comparator[Array[Byte]] { + override def compare(o1: Array[Byte], o2: Array[Byte]): Int = { + Bytes.compareTo(o1, o2) + } + } + + val rowKey: Array[Byte] = + key match { + case qualifier: KeyFamilyQualifier => + qualifier.rowKey + case wrapper: ByteArrayWrapper => + wrapper.value + case _ => + key.asInstanceOf[Array[Byte]] + } + var partition = util.Arrays.binarySearch(startKeys, rowKey, comparator) + if (partition < 0) + partition = partition * -1 + -2 + if (partition < 0) + partition = 0 + partition + } +} diff --git a/spark4/hbase-spark4/src/main/scala/org/apache/hadoop/hbase/spark/ByteArrayComparable.scala b/spark4/hbase-spark4/src/main/scala/org/apache/hadoop/hbase/spark/ByteArrayComparable.scala new file mode 100644 index 00000000..78cd3abc --- /dev/null +++ b/spark4/hbase-spark4/src/main/scala/org/apache/hadoop/hbase/spark/ByteArrayComparable.scala @@ -0,0 +1,47 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
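A short sketch of how `BulkLoadPartitioner` above is typically wired up: the region start keys come from the table's `RegionLocator`, and `getPartition` binary-searches the row key into them. The connection setup and table name are illustrative only:

```scala
import org.apache.hadoop.hbase.{HBaseConfiguration, TableName}
import org.apache.hadoop.hbase.client.ConnectionFactory
import org.apache.hadoop.hbase.spark.BulkLoadPartitioner
import org.apache.hadoop.hbase.util.Bytes

object BulkLoadPartitionerSketch {
  def main(args: Array[String]): Unit = {
    val connection = ConnectionFactory.createConnection(HBaseConfiguration.create())
    try {
      // One start key per region; a single empty key means the table has one region
      val locator = connection.getRegionLocator(TableName.valueOf("bulkload-table-test"))
      val startKeys: Array[Array[Byte]] = locator.getStartKeys
      val partitioner = new BulkLoadPartitioner(startKeys)
      // Row key "3" lands in the partition of the region whose start key precedes it
      println(partitioner.getPartition(Bytes.toBytes("3")))
    } finally {
      connection.close()
    }
  }
}
```

In the bulk load path this partitioner is normally combined with `repartitionAndSortWithinPartitions`, so that each Spark partition writes HFiles for a single region.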
+ */ +package org.apache.hadoop.hbase.spark + +import org.apache.hadoop.hbase.util.Bytes +import org.apache.yetus.audience.InterfaceAudience + +@InterfaceAudience.Public +class ByteArrayComparable(val bytes: Array[Byte], val offset: Int = 0, var length: Int = -1) + extends Comparable[ByteArrayComparable] { + + if (length == -1) { + length = bytes.length + } + + override def compareTo(o: ByteArrayComparable): Int = { + Bytes.compareTo(bytes, offset, length, o.bytes, o.offset, o.length) + } + + override def hashCode(): Int = { + Bytes.hashCode(bytes, offset, length) + } + + override def equals(obj: Any): Boolean = { + obj match { + case b: ByteArrayComparable => + Bytes.equals(bytes, offset, length, b.bytes, b.offset, b.length) + case _ => + false + } + } +} diff --git a/spark4/hbase-spark4/src/main/scala/org/apache/hadoop/hbase/spark/ByteArrayWrapper.scala b/spark4/hbase-spark4/src/main/scala/org/apache/hadoop/hbase/spark/ByteArrayWrapper.scala new file mode 100644 index 00000000..a774838e --- /dev/null +++ b/spark4/hbase-spark4/src/main/scala/org/apache/hadoop/hbase/spark/ByteArrayWrapper.scala @@ -0,0 +1,48 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.hadoop.hbase.spark + +import java.io.Serializable +import org.apache.hadoop.hbase.util.Bytes +import org.apache.yetus.audience.InterfaceAudience + +/** + * This is a wrapper over a byte array so it can work as + * a key in a hashMap + * + * @param value The Byte Array value + */ +@InterfaceAudience.Public +class ByteArrayWrapper(var value: Array[Byte]) + extends Comparable[ByteArrayWrapper] + with Serializable { + override def compareTo(valueOther: ByteArrayWrapper): Int = { + Bytes.compareTo(value, valueOther.value) + } + override def equals(o2: Any): Boolean = { + o2 match { + case wrapper: ByteArrayWrapper => + Bytes.equals(value, wrapper.value) + case _ => + false + } + } + override def hashCode(): Int = { + Bytes.hashCode(value) + } +} diff --git a/spark4/hbase-spark4/src/main/scala/org/apache/hadoop/hbase/spark/ColumnFamilyQualifierMapKeyWrapper.scala b/spark4/hbase-spark4/src/main/scala/org/apache/hadoop/hbase/spark/ColumnFamilyQualifierMapKeyWrapper.scala new file mode 100644 index 00000000..84883992 --- /dev/null +++ b/spark4/hbase-spark4/src/main/scala/org/apache/hadoop/hbase/spark/ColumnFamilyQualifierMapKeyWrapper.scala @@ -0,0 +1,78 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.hadoop.hbase.spark + +import org.apache.hadoop.hbase.util.Bytes +import org.apache.yetus.audience.InterfaceAudience + +/** + * A wrapper class that will allow both columnFamily and qualifier to + * be the key of a hashMap. Also allow for finding the value in a hashmap + * with out cloning the HBase value from the HBase Cell object + * @param columnFamily ColumnFamily byte array + * @param columnFamilyOffSet Offset of columnFamily value in the array + * @param columnFamilyLength Length of the columnFamily value in the columnFamily array + * @param qualifier Qualifier byte array + * @param qualifierOffSet Offset of qualifier value in the array + * @param qualifierLength Length of the qualifier value with in the array + */ +@InterfaceAudience.Public +class ColumnFamilyQualifierMapKeyWrapper( + val columnFamily: Array[Byte], + val columnFamilyOffSet: Int, + val columnFamilyLength: Int, + val qualifier: Array[Byte], + val qualifierOffSet: Int, + val qualifierLength: Int) + extends Serializable { + + override def equals(other: Any): Boolean = { + val otherWrapper = other.asInstanceOf[ColumnFamilyQualifierMapKeyWrapper] + + Bytes.compareTo( + columnFamily, + columnFamilyOffSet, + columnFamilyLength, + otherWrapper.columnFamily, + otherWrapper.columnFamilyOffSet, + otherWrapper.columnFamilyLength) == 0 && Bytes.compareTo( + qualifier, + qualifierOffSet, + qualifierLength, + otherWrapper.qualifier, + otherWrapper.qualifierOffSet, + otherWrapper.qualifierLength) == 0 + } + + override def hashCode(): Int = { + Bytes.hashCode(columnFamily, columnFamilyOffSet, columnFamilyLength) + + Bytes.hashCode(qualifier, qualifierOffSet, qualifierLength) + } + + def cloneColumnFamily(): Array[Byte] = { + val resultArray = new Array[Byte](columnFamilyLength) + System.arraycopy(columnFamily, columnFamilyOffSet, resultArray, 0, columnFamilyLength) + resultArray + } + + def cloneQualifier(): Array[Byte] = { + val resultArray = new Array[Byte](qualifierLength) + System.arraycopy(qualifier, qualifierOffSet, resultArray, 0, qualifierLength) + resultArray + } +} diff --git a/spark4/hbase-spark4/src/main/scala/org/apache/hadoop/hbase/spark/DefaultSource.scala b/spark4/hbase-spark4/src/main/scala/org/apache/hadoop/hbase/spark/DefaultSource.scala new file mode 100644 index 00000000..06dea823 --- /dev/null +++ b/spark4/hbase-spark4/src/main/scala/org/apache/hadoop/hbase/spark/DefaultSource.scala @@ -0,0 +1,1232 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.hadoop.hbase.spark + +import java.util +import java.util.concurrent.ConcurrentLinkedQueue +import org.apache.hadoop.hbase.CellUtil +import org.apache.hadoop.hbase.HBaseConfiguration +import org.apache.hadoop.hbase.HColumnDescriptor +import org.apache.hadoop.hbase.HTableDescriptor +import org.apache.hadoop.hbase.TableName +import org.apache.hadoop.hbase.client._ +import org.apache.hadoop.hbase.io.ImmutableBytesWritable +import org.apache.hadoop.hbase.mapred.TableOutputFormat +import org.apache.hadoop.hbase.spark.datasources._ +import org.apache.hadoop.hbase.types._ +import org.apache.hadoop.hbase.util.{Bytes, PositionedByteRange, SimplePositionedMutableByteRange} +import org.apache.hadoop.mapred.JobConf +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.{DataFrame, Row, SaveMode, SQLContext} +import org.apache.spark.sql.sources._ +import org.apache.spark.sql.types._ +import org.apache.yetus.audience.InterfaceAudience +import scala.collection.mutable + +/** + * DefaultSource for integration with Spark's dataframe datasources. + * This class will produce a relationProvider based on input given to it from spark + * + * This class needs to stay in the current package 'org.apache.hadoop.hbase.spark' + * for Spark to match the hbase data source name. + * + * In all this DefaultSource support the following datasource functionality + * - Scan range pruning through filter push down logic based on rowKeys + * - Filter push down logic on HBase Cells + * - Qualifier filtering based on columns used in the SparkSQL statement + * - Type conversions of basic SQL types. All conversions will be + * Through the HBase Bytes object commands. 
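As a usage sketch of the data source this file defines: the relation is addressed by the package name `org.apache.hadoop.hbase.spark` and driven by an `HBaseTableCatalog` JSON mapping. The catalog layout and table name are illustrative, and the explicit `HBaseContext` registration reflects the default `hbase.spark.use.hbasecontext=true` path:

```scala
import org.apache.hadoop.hbase.HBaseConfiguration
import org.apache.hadoop.hbase.spark.HBaseContext
import org.apache.hadoop.hbase.spark.datasources.HBaseTableCatalog
import org.apache.spark.sql.SparkSession

object DefaultSourceSketch {
  // Catalog mapping: "col0" is the row key, "col1" lives in column family "cf1"
  val catalog: String =
    """{
      |"table":{"namespace":"default", "name":"table1"},
      |"rowkey":"key",
      |"columns":{
      |  "col0":{"cf":"rowkey", "col":"key", "type":"string"},
      |  "col1":{"cf":"cf1", "col":"col1", "type":"int"}
      |}
      |}""".stripMargin

  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().appName("DefaultSourceSketch").getOrCreate()
    // Register an HBaseContext so HBaseRelation can pick up the latest one
    new HBaseContext(spark.sparkContext, HBaseConfiguration.create())
    import spark.implicits._

    // Write: HBaseTableCatalog.newTable > 3 asks createTable() to pre-split the table
    Seq(("row1", 1), ("row2", 2)).toDF("col0", "col1")
      .write
      .options(Map(HBaseTableCatalog.tableCatalog -> catalog, HBaseTableCatalog.newTable -> "5"))
      .format("org.apache.hadoop.hbase.spark")
      .save()

    // Read: column pruning and filter pushdown are handled by buildScan below
    val df = spark.read
      .options(Map(HBaseTableCatalog.tableCatalog -> catalog))
      .format("org.apache.hadoop.hbase.spark")
      .load()
    df.filter($"col1" > 1).select($"col0").show()

    spark.stop()
  }
}
```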
+ */ +@InterfaceAudience.Private +class DefaultSource extends RelationProvider with CreatableRelationProvider with Logging { + + /** + * Is given input from SparkSQL to construct a BaseRelation + * + * @param sqlContext SparkSQL context + * @param parameters Parameters given to us from SparkSQL + * @return A BaseRelation Object + */ + override def createRelation( + sqlContext: SQLContext, + parameters: Map[String, String]): BaseRelation = { + new HBaseRelation(parameters, None)(sqlContext) + } + + override def createRelation( + sqlContext: SQLContext, + mode: SaveMode, + parameters: Map[String, String], + data: DataFrame): BaseRelation = { + val relation = HBaseRelation(parameters, Some(data.schema))(sqlContext) + relation.createTable() + relation.insert(data, false) + relation + } +} + +/** + * Implementation of Spark BaseRelation that will build up our scan logic + * , do the scan pruning, filter push down, and value conversions + * + * @param sqlContext SparkSQL context + */ +@InterfaceAudience.Private +case class HBaseRelation( + @transient parameters: Map[String, String], + userSpecifiedSchema: Option[StructType])(@transient val sqlContext: SQLContext) + extends BaseRelation + with PrunedFilteredScan + with InsertableRelation + with Logging { + val timestamp = parameters.get(HBaseSparkConf.TIMESTAMP).map(_.toLong) + val minTimestamp = parameters.get(HBaseSparkConf.TIMERANGE_START).map(_.toLong) + val maxTimestamp = parameters.get(HBaseSparkConf.TIMERANGE_END).map(_.toLong) + val maxVersions = parameters.get(HBaseSparkConf.MAX_VERSIONS).map(_.toInt) + val encoderClsName = + parameters.get(HBaseSparkConf.QUERY_ENCODER).getOrElse(HBaseSparkConf.DEFAULT_QUERY_ENCODER) + + @transient val encoder = JavaBytesEncoder.create(encoderClsName) + + val catalog = HBaseTableCatalog(parameters) + def tableName = s"${catalog.namespace}:${catalog.name}" + val configResources = parameters.get(HBaseSparkConf.HBASE_CONFIG_LOCATION) + val useHBaseContext = parameters + .get(HBaseSparkConf.USE_HBASECONTEXT) + .map(_.toBoolean) + .getOrElse(HBaseSparkConf.DEFAULT_USE_HBASECONTEXT) + val usePushDownColumnFilter = parameters + .get(HBaseSparkConf.PUSHDOWN_COLUMN_FILTER) + .map(_.toBoolean) + .getOrElse(HBaseSparkConf.DEFAULT_PUSHDOWN_COLUMN_FILTER) + + // The user supplied per table parameter will overwrite global ones in SparkConf + val blockCacheEnable = parameters + .get(HBaseSparkConf.QUERY_CACHEBLOCKS) + .map(_.toBoolean) + .getOrElse(sqlContext.sparkContext.getConf + .getBoolean(HBaseSparkConf.QUERY_CACHEBLOCKS, HBaseSparkConf.DEFAULT_QUERY_CACHEBLOCKS)) + val cacheSize = parameters + .get(HBaseSparkConf.QUERY_CACHEDROWS) + .map(_.toInt) + .getOrElse(sqlContext.sparkContext.getConf.getInt(HBaseSparkConf.QUERY_CACHEDROWS, -1)) + val batchNum = parameters + .get(HBaseSparkConf.QUERY_BATCHSIZE) + .map(_.toInt) + .getOrElse(sqlContext.sparkContext.getConf.getInt(HBaseSparkConf.QUERY_BATCHSIZE, -1)) + + val bulkGetSize = parameters + .get(HBaseSparkConf.BULKGET_SIZE) + .map(_.toInt) + .getOrElse(sqlContext.sparkContext.getConf + .getInt(HBaseSparkConf.BULKGET_SIZE, HBaseSparkConf.DEFAULT_BULKGET_SIZE)) + + // create or get latest HBaseContext + val hbaseContext: HBaseContext = if (useHBaseContext) { + LatestHBaseContextCache.latest + } else { + val hadoopConfig = sqlContext.sparkContext.hadoopConfiguration + val config = HBaseConfiguration.create(hadoopConfig) + configResources.map( + resource => + resource + .split(",") + .foreach(r => config.addResource(r))) + new HBaseContext(sqlContext.sparkContext, 
config) + } + + val wrappedConf = new SerializableConfiguration(hbaseContext.config) + def hbaseConf = wrappedConf.value + + /** + * Generates a Spark SQL schema objeparametersct so Spark SQL knows what is being + * provided by this BaseRelation + * + * @return schema generated from the SCHEMA_COLUMNS_MAPPING_KEY value + */ + override val schema: StructType = userSpecifiedSchema.getOrElse(catalog.toDataType) + + def createTable() { + val numReg = parameters + .get(HBaseTableCatalog.newTable) + .map(x => x.toInt) + .getOrElse(0) + val startKey = Bytes.toBytes( + parameters + .get(HBaseTableCatalog.regionStart) + .getOrElse(HBaseTableCatalog.defaultRegionStart)) + val endKey = Bytes.toBytes( + parameters + .get(HBaseTableCatalog.regionEnd) + .getOrElse(HBaseTableCatalog.defaultRegionEnd)) + if (numReg > 3) { + val tName = TableName.valueOf(tableName) + val cfs = catalog.getColumnFamilies + + val connection = HBaseConnectionCache.getConnection(hbaseConf) + // Initialize hBase table if necessary + val admin = connection.getAdmin + try { + if (!admin.tableExists(tName)) { + val tableDesc = new HTableDescriptor(tName) + cfs.foreach { + x => + val cf = new HColumnDescriptor(x.getBytes()) + logDebug(s"add family $x to ${tableName}") + tableDesc.addFamily(cf) + } + val splitKeys = Bytes.split(startKey, endKey, numReg); + admin.createTable(tableDesc, splitKeys) + + } + } finally { + admin.close() + connection.close() + } + } else { + logInfo(s"""${HBaseTableCatalog.newTable} + |is not defined or no larger than 3, skip the create table""".stripMargin) + } + } + + /** + * @param data + * @param overwrite + */ + override def insert(data: DataFrame, overwrite: Boolean): Unit = { + val jobConfig: JobConf = new JobConf(hbaseConf, this.getClass) + jobConfig.setOutputFormat(classOf[TableOutputFormat]) + jobConfig.set(TableOutputFormat.OUTPUT_TABLE, tableName) + var count = 0 + val rkFields = catalog.getRowKey + val rkIdxedFields = rkFields.map { + case x => + (schema.fieldIndex(x.colName), x) + } + val colsIdxedFields = schema.fieldNames + .partition(x => rkFields.map(_.colName).contains(x)) + ._2 + .map(x => (schema.fieldIndex(x), catalog.getField(x))) + val rdd = data.rdd + def convertToPut(row: Row) = { + // construct bytes for row key + val rowBytes = rkIdxedFields.map { + case (x, y) => + Utils.toBytes(row(x), y) + } + val rLen = rowBytes.foldLeft(0) { + case (x, y) => + x + y.length + } + val rBytes = new Array[Byte](rLen) + var offset = 0 + rowBytes.foreach { + x => + System.arraycopy(x, 0, rBytes, offset, x.length) + offset += x.length + } + val put = timestamp.fold(new Put(rBytes))(new Put(rBytes, _)) + + colsIdxedFields.foreach { + case (x, y) => + val r = row(x) + if (r != null) { + val b = Utils.toBytes(r, y) + put.addColumn(Bytes.toBytes(y.cf), Bytes.toBytes(y.col), b) + } + } + count += 1 + (new ImmutableBytesWritable, put) + } + rdd.map(convertToPut(_)).saveAsHadoopDataset(jobConfig) + } + + def getIndexedProjections(requiredColumns: Array[String]): Seq[(Field, Int)] = { + requiredColumns.map(catalog.sMap.getField(_)).zipWithIndex + } + + /** + * Takes a HBase Row object and parses all of the fields from it. + * This is independent of which fields were requested from the key + * Because we have all the data it's less complex to parse everything. + * + * @param row the retrieved row from hbase. + * @param keyFields all of the fields in the row key, ORDERED by their order in the row key. 
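The `parseRowKey` helper documented here consumes composite row keys field by field: fixed-width fields by their type length, string fields up to the delimiter or the end of the key. A hedged sketch of a catalog that exercises that path; names are placeholders and the exact type attributes may need adjusting for a given schema:

```scala
import org.apache.hadoop.hbase.spark.datasources.HBaseTableCatalog
import org.apache.spark.sql.SparkSession

object CompositeRowKeySketch {
  // Row key = "key1" (4-byte int) followed by "key2" (string); both map to cf "rowkey"
  val catalog: String =
    """{
      |"table":{"namespace":"default", "name":"compositeKeyTable"},
      |"rowkey":"key1:key2",
      |"columns":{
      |  "col00":{"cf":"rowkey", "col":"key1", "type":"int"},
      |  "col01":{"cf":"rowkey", "col":"key2", "type":"string"},
      |  "col1":{"cf":"cf1", "col":"col1", "type":"double"}
      |}
      |}""".stripMargin

  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().appName("CompositeRowKeySketch").getOrCreate()
    val df = spark.read
      .options(Map(HBaseTableCatalog.tableCatalog -> catalog))
      .format("org.apache.hadoop.hbase.spark")
      .load()
    df.printSchema() // col00 and col01 are reassembled from the row key by parseRowKey
    spark.stop()
  }
}
```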
+ */ + def parseRowKey(row: Array[Byte], keyFields: Seq[Field]): Map[Field, Any] = { + keyFields + .foldLeft((0, Seq[(Field, Any)]()))( + (state, field) => { + val idx = state._1 + val parsed = state._2 + if (field.length != -1) { + val value = Utils.hbaseFieldToScalaType(field, row, idx, field.length) + // Return the new index and appended value + (idx + field.length, parsed ++ Seq((field, value))) + } else { + field.dt match { + case StringType => + val pos = row.indexOf(HBaseTableCatalog.delimiter, idx) + if (pos == -1 || pos > row.length) { + // this is at the last dimension + val value = Utils.hbaseFieldToScalaType(field, row, idx, row.length) + (row.length + 1, parsed ++ Seq((field, value))) + } else { + val value = Utils.hbaseFieldToScalaType(field, row, idx, pos - idx) + (pos, parsed ++ Seq((field, value))) + } + // We don't know the length, assume it extends to the end of the rowkey. + case _ => + ( + row.length + 1, + parsed ++ Seq((field, Utils.hbaseFieldToScalaType(field, row, idx, row.length)))) + } + } + }) + ._2 + .toMap + } + + def buildRow(fields: Seq[Field], result: Result): Row = { + val r = result.getRow + val keySeq = parseRowKey(r, catalog.getRowKey) + val valueSeq = fields + .filter(!_.isRowKey) + .map { + x => + val kv = result.getColumnLatestCell(Bytes.toBytes(x.cf), Bytes.toBytes(x.col)) + if (kv == null || kv.getValueLength == 0) { + (x, null) + } else { + val v = CellUtil.cloneValue(kv) + ( + x, + x.dt match { + // Here, to avoid arraycopy, return v directly instead of calling hbaseFieldToScalaType + case BinaryType => v + case _ => Utils.hbaseFieldToScalaType(x, v, 0, v.length) + }) + } + } + .toMap + val unionedRow = keySeq ++ valueSeq + // Return the row ordered by the requested order + Row.fromSeq(fields.map(unionedRow.get(_).getOrElse(null))) + } + + /** + * Here we are building the functionality to populate the resulting RDD[Row] + * Here is where we will do the following: + * - Filter push down + * - Scan or GetList pruning + * - Executing our scan(s) or/and GetList to generate result + * + * @param requiredColumns The columns that are being requested by the requesting query + * @param filters The filters that are being applied by the requesting query + * @return RDD will all the results from HBase needed for SparkSQL to + * execute the query on + */ + override def buildScan(requiredColumns: Array[String], filters: Array[Filter]): RDD[Row] = { + + val pushDownTuple = buildPushDownPredicatesResource(filters) + val pushDownRowKeyFilter = pushDownTuple._1 + var pushDownDynamicLogicExpression = pushDownTuple._2 + val valueArray = pushDownTuple._3 + + if (!usePushDownColumnFilter) { + pushDownDynamicLogicExpression = null + } + + logDebug("pushDownRowKeyFilter: " + pushDownRowKeyFilter.ranges) + if (pushDownDynamicLogicExpression != null) { + logDebug( + "pushDownDynamicLogicExpression: " + + pushDownDynamicLogicExpression.toExpressionString) + } + logDebug("valueArray: " + valueArray.length) + + val requiredQualifierDefinitionList = + new mutable.ListBuffer[Field] + + requiredColumns.foreach( + c => { + val field = catalog.getField(c) + requiredQualifierDefinitionList += field + }) + + // retain the information for unit testing checks + DefaultSourceStaticUtils.populateLatestExecutionRules( + pushDownRowKeyFilter, + pushDownDynamicLogicExpression) + + val getList = new util.ArrayList[Get]() + val rddList = new util.ArrayList[RDD[Row]]() + + // add points to getList + pushDownRowKeyFilter.points.foreach( + p => { + val get = new Get(p) + 
requiredQualifierDefinitionList.foreach( + d => { + if (d.isRowKey) + get.addColumn(d.cfBytes, d.colBytes) + }) + getList.add(get) + }) + + val pushDownFilterJava = + if (usePushDownColumnFilter && pushDownDynamicLogicExpression != null) { + Some( + new SparkSQLPushDownFilter( + pushDownDynamicLogicExpression, + valueArray, + requiredQualifierDefinitionList, + encoderClsName)) + } else { + None + } + val hRdd = new HBaseTableScanRDD( + this, + hbaseContext, + pushDownFilterJava, + requiredQualifierDefinitionList.seq.toSeq) + pushDownRowKeyFilter.points.foreach(hRdd.addPoint(_)) + pushDownRowKeyFilter.ranges.foreach(hRdd.addRange(_)) + + var resultRDD: RDD[Row] = { + val tmp = hRdd.map { + r => + val indexedFields = getIndexedProjections(requiredColumns).map(_._1) + buildRow(indexedFields, r) + + } + if (tmp.partitions.size > 0) { + tmp + } else { + null + } + } + + if (resultRDD == null) { + val scan = new Scan() + scan.setCacheBlocks(blockCacheEnable) + scan.setBatch(batchNum) + scan.setCaching(cacheSize) + requiredQualifierDefinitionList.foreach(d => scan.addColumn(d.cfBytes, d.colBytes)) + + val rdd = hbaseContext + .hbaseRDD(TableName.valueOf(tableName), scan) + .map( + r => { + val indexedFields = getIndexedProjections(requiredColumns).map(_._1) + buildRow(indexedFields, r._2) + }) + resultRDD = rdd + } + resultRDD + } + + def buildPushDownPredicatesResource( + filters: Array[Filter]): (RowKeyFilter, DynamicLogicExpression, Array[Array[Byte]]) = { + var superRowKeyFilter: RowKeyFilter = null + val queryValueList = new mutable.ListBuffer [Array[Byte]] + var superDynamicLogicExpression: DynamicLogicExpression = null + + filters.foreach( + f => { + val rowKeyFilter = new RowKeyFilter() + val logicExpression = transverseFilterTree(rowKeyFilter, queryValueList, f) + if (superDynamicLogicExpression == null) { + superDynamicLogicExpression = logicExpression + superRowKeyFilter = rowKeyFilter + } else { + superDynamicLogicExpression = + new AndLogicExpression(superDynamicLogicExpression, logicExpression) + superRowKeyFilter.mergeIntersect(rowKeyFilter) + } + + }) + + val queryValueArray = queryValueList.toArray + + if (superRowKeyFilter == null) { + superRowKeyFilter = new RowKeyFilter + } + + (superRowKeyFilter, superDynamicLogicExpression, queryValueArray) + } + + /** + * For some codec, the order may be inconsistent between java primitive + * type and its byte array. We may have to split the predicates on some + * of the java primitive type into multiple predicates. The encoder will take + * care of it and returning the concrete ranges. + * + * For example in naive codec, some of the java primitive types have to be split into multiple + * predicates, and union these predicates together to make the predicates be performed correctly. + * For example, if we have "COLUMN < 2", we will transform it into + * "0 <= COLUMN < 2 OR Integer.MIN_VALUE <= COLUMN <= -1" + */ + + def transverseFilterTree( + parentRowKeyFilter: RowKeyFilter, + valueArray: mutable.ListBuffer [Array[Byte]], + filter: Filter): DynamicLogicExpression = { + filter match { + case EqualTo(attr, value) => + val field = catalog.getField(attr) + if (field != null) { + if (field.isRowKey) { + parentRowKeyFilter.mergeIntersect(new RowKeyFilter(Utils.toBytes(value, field), null)) + } + val byteValue = Utils.toBytes(value, field) + valueArray += byteValue + } + new EqualLogicExpression(attr, valueArray.length - 1, false) + + /** + * encoder may split the predicates into multiple byte array boundaries. 
+ * Each boundaries is mapped into the RowKeyFilter and then is unioned by the reduce + * operation. If the data type is not supported, b will be None, and there is + * no operation happens on the parentRowKeyFilter. + * + * Note that because LessThan is not inclusive, thus the first bound should be exclusive, + * which is controlled by inc. + * + * The other predicates, i.e., GreaterThan/LessThanOrEqual/GreaterThanOrEqual follows + * the similar logic. + */ + case LessThan(attr, value) => + val field = catalog.getField(attr) + if (field != null) { + if (field.isRowKey) { + val b = encoder.ranges(value) + var inc = false + b.map(_.less.map { + x => + val r = new RowKeyFilter(null, new ScanRange(x.upper, inc, x.low, true)) + inc = true + r + }).map { x => x.reduce { (i, j) => i.mergeUnion(j) } } + .map(parentRowKeyFilter.mergeIntersect(_)) + } + val byteValue = encoder.encode(field.dt, value) + valueArray += byteValue + } + new LessThanLogicExpression(attr, valueArray.length - 1) + case GreaterThan(attr, value) => + val field = catalog.getField(attr) + if (field != null) { + if (field.isRowKey) { + val b = encoder.ranges(value) + var inc = false + b.map(_.greater.map { + x => + val r = new RowKeyFilter(null, new ScanRange(x.upper, true, x.low, inc)) + inc = true + r + }).map { x => x.reduce { (i, j) => i.mergeUnion(j) } } + .map(parentRowKeyFilter.mergeIntersect(_)) + } + val byteValue = encoder.encode(field.dt, value) + valueArray += byteValue + } + new GreaterThanLogicExpression(attr, valueArray.length - 1) + case LessThanOrEqual(attr, value) => + val field = catalog.getField(attr) + if (field != null) { + if (field.isRowKey) { + val b = encoder.ranges(value) + b.map( + _.less.map(x => new RowKeyFilter(null, new ScanRange(x.upper, true, x.low, true)))) + .map { x => x.reduce { (i, j) => i.mergeUnion(j) } } + .map(parentRowKeyFilter.mergeIntersect(_)) + } + val byteValue = encoder.encode(field.dt, value) + valueArray += byteValue + } + new LessThanOrEqualLogicExpression(attr, valueArray.length - 1) + case GreaterThanOrEqual(attr, value) => + val field = catalog.getField(attr) + if (field != null) { + if (field.isRowKey) { + val b = encoder.ranges(value) + b.map( + _.greater.map(x => new RowKeyFilter(null, new ScanRange(x.upper, true, x.low, true)))) + .map { x => x.reduce { (i, j) => i.mergeUnion(j) } } + .map(parentRowKeyFilter.mergeIntersect(_)) + } + val byteValue = encoder.encode(field.dt, value) + valueArray += byteValue + } + new GreaterThanOrEqualLogicExpression(attr, valueArray.length - 1) + case StringStartsWith(attr, value) => + val field = catalog.getField(attr) + if (field != null) { + if (field.isRowKey) { + val p = Utils.toBytes(value, field) + val endRange = Utils.incrementByteArray(p) + parentRowKeyFilter.mergeIntersect( + new RowKeyFilter(null, new ScanRange(endRange, false, p, true))) + } + val byteValue = Utils.toBytes(value, field) + valueArray += byteValue + } + new StartsWithLogicExpression(attr, valueArray.length - 1) + case Or(left, right) => + val leftExpression = transverseFilterTree(parentRowKeyFilter, valueArray, left) + val rightSideRowKeyFilter = new RowKeyFilter + val rightExpression = transverseFilterTree(rightSideRowKeyFilter, valueArray, right) + + parentRowKeyFilter.mergeUnion(rightSideRowKeyFilter) + + new OrLogicExpression(leftExpression, rightExpression) + case And(left, right) => + val leftExpression = transverseFilterTree(parentRowKeyFilter, valueArray, left) + val rightSideRowKeyFilter = new RowKeyFilter + val rightExpression = 
transverseFilterTree(rightSideRowKeyFilter, valueArray, right) + parentRowKeyFilter.mergeIntersect(rightSideRowKeyFilter) + + new AndLogicExpression(leftExpression, rightExpression) + case IsNull(attr) => + new IsNullLogicExpression(attr, false) + case IsNotNull(attr) => + new IsNullLogicExpression(attr, true) + case _ => + new PassThroughLogicExpression + } + } +} + +/** + * Construct to contain a single scan ranges information. Also + * provide functions to merge with other scan ranges through AND + * or OR operators + * + * @param upperBound Upper bound of scan + * @param isUpperBoundEqualTo Include upper bound value in the results + * @param lowerBound Lower bound of scan + * @param isLowerBoundEqualTo Include lower bound value in the results + */ +@InterfaceAudience.Private +class ScanRange( + var upperBound: Array[Byte], + var isUpperBoundEqualTo: Boolean, + var lowerBound: Array[Byte], + var isLowerBoundEqualTo: Boolean) + extends Serializable { + + /** + * Function to merge another scan object through a AND operation + * + * @param other Other scan object + */ + def mergeIntersect(other: ScanRange): Unit = { + val upperBoundCompare = compareRange(upperBound, other.upperBound) + val lowerBoundCompare = compareRange(lowerBound, other.lowerBound) + + upperBound = if (upperBoundCompare < 0) upperBound else other.upperBound + lowerBound = if (lowerBoundCompare > 0) lowerBound else other.lowerBound + + isLowerBoundEqualTo = + if (lowerBoundCompare == 0) + isLowerBoundEqualTo && other.isLowerBoundEqualTo + else if (lowerBoundCompare < 0) other.isLowerBoundEqualTo + else isLowerBoundEqualTo + + isUpperBoundEqualTo = + if (upperBoundCompare == 0) + isUpperBoundEqualTo && other.isUpperBoundEqualTo + else if (upperBoundCompare < 0) isUpperBoundEqualTo + else other.isUpperBoundEqualTo + } + + /** + * Function to merge another scan object through a OR operation + * + * @param other Other scan object + */ + def mergeUnion(other: ScanRange): Unit = { + + val upperBoundCompare = compareRange(upperBound, other.upperBound) + val lowerBoundCompare = compareRange(lowerBound, other.lowerBound) + + upperBound = if (upperBoundCompare > 0) upperBound else other.upperBound + lowerBound = if (lowerBoundCompare < 0) lowerBound else other.lowerBound + + isLowerBoundEqualTo = + if (lowerBoundCompare == 0) + isLowerBoundEqualTo || other.isLowerBoundEqualTo + else if (lowerBoundCompare < 0) isLowerBoundEqualTo + else other.isLowerBoundEqualTo + + isUpperBoundEqualTo = + if (upperBoundCompare == 0) + isUpperBoundEqualTo || other.isUpperBoundEqualTo + else if (upperBoundCompare < 0) other.isUpperBoundEqualTo + else isUpperBoundEqualTo + } + + /** + * Common function to see if this scan over laps with another + * + * Reference Visual + * + * A B + * |---------------------------| + * LL--------------LU + * RL--------------RU + * + * A = lowest value is byte[0] + * B = highest value is null + * LL = Left Lower Bound + * LU = Left Upper Bound + * RL = Right Lower Bound + * RU = Right Upper Bound + * + * @param other Other scan object + * @return True is overlap false is not overlap + */ + def getOverLapScanRange(other: ScanRange): ScanRange = { + var leftRange: ScanRange = null + var rightRange: ScanRange = null + + // First identify the Left range + // Also lower bound can't be null + if (compareRange(lowerBound, other.lowerBound) < 0 || + compareRange(upperBound, other.upperBound) < 0) { + leftRange = this + rightRange = other + } else { + leftRange = other + rightRange = this + } + + if (hasOverlap(leftRange, 
rightRange)) { + val result = new ScanRange(upperBound, isUpperBoundEqualTo, lowerBound, isLowerBoundEqualTo) + result.mergeIntersect(other) + result + } else { + null + } + } + + /** + * The leftRange.upperBound has to be larger than the rightRange's lowerBound. + * Otherwise, there is no overlap. + * + * @param left: The range with the smaller lowBound + * @param right: The range with the larger lowBound + * @return Whether two ranges have overlap. + */ + + def hasOverlap(left: ScanRange, right: ScanRange): Boolean = { + compareRange(left.upperBound, right.lowerBound) >= 0 + } + + /** + * Special compare logic because we can have null values + * for left or right bound + * + * @param left Left byte array + * @param right Right byte array + * @return 0 for equals 1 is left is greater and -1 is right is greater + */ + def compareRange(left: Array[Byte], right: Array[Byte]): Int = { + if (left == null && right == null) 0 + else if (left == null && right != null) 1 + else if (left != null && right == null) -1 + else Bytes.compareTo(left, right) + } + + /** + * @return + */ + def containsPoint(point: Array[Byte]): Boolean = { + val lowerCompare = compareRange(point, lowerBound) + val upperCompare = compareRange(point, upperBound) + + ((isLowerBoundEqualTo && lowerCompare >= 0) || + (!isLowerBoundEqualTo && lowerCompare > 0)) && + ((isUpperBoundEqualTo && upperCompare <= 0) || + (!isUpperBoundEqualTo && upperCompare < 0)) + + } + override def toString: String = { + "ScanRange:(upperBound:" + Bytes.toString(upperBound) + + ",isUpperBoundEqualTo:" + isUpperBoundEqualTo + ",lowerBound:" + + Bytes.toString(lowerBound) + ",isLowerBoundEqualTo:" + isLowerBoundEqualTo + ")" + } +} + +/** + * Contains information related to a filters for a given column. + * This can contain many ranges or points. 
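To make the bound conventions in `ScanRange` above concrete (a null upper bound means "no upper limit", and an empty byte array is the lowest possible lower bound), a small illustrative sketch against the private API, for understanding only:

```scala
import org.apache.hadoop.hbase.spark.ScanRange
import org.apache.hadoop.hbase.util.Bytes

object ScanRangeSketch {
  def main(args: Array[String]): Unit = {
    // Intersect [b, f] with [d, unbounded): a null upper bound means "no upper limit"
    val bounded = new ScanRange(Bytes.toBytes("f"), true, Bytes.toBytes("b"), true)
    val openAbove = new ScanRange(null, true, Bytes.toBytes("d"), true)
    bounded.mergeIntersect(openAbove)
    println(bounded) // expect lowerBound "d", upperBound "f"

    // Union [b, d] with [c, g]: the merged range widens to [b, g]
    val left = new ScanRange(Bytes.toBytes("d"), true, Bytes.toBytes("b"), true)
    val right = new ScanRange(Bytes.toBytes("g"), true, Bytes.toBytes("c"), true)
    left.mergeUnion(right)
    println(left)
  }
}
```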
+ * + * @param currentPoint the initial point when the filter is created + * @param currentRange the initial scanRange when the filter is created + */ +@InterfaceAudience.Private +class ColumnFilter( + currentPoint: Array[Byte] = null, + currentRange: ScanRange = null, + var points: mutable.ListBuffer [Array[Byte]] = new mutable.ListBuffer [Array[Byte]](), + var ranges: mutable.ListBuffer [ScanRange] = new mutable.ListBuffer [ScanRange]()) + extends Serializable { + // Collection of ranges + if (currentRange != null) ranges.+=(currentRange) + + // Collection of points + if (currentPoint != null) points.+=(currentPoint) + + /** + * This will validate a give value through the filter's points and/or ranges + * the result will be if the value passed the filter + * + * @param value Value to be validated + * @param valueOffSet The offset of the value + * @param valueLength The length of the value + * @return True is the value passes the filter false if not + */ + def validate(value: Array[Byte], valueOffSet: Int, valueLength: Int): Boolean = { + var result = false + + points.foreach( + p => { + if (Bytes.equals(p, 0, p.length, value, valueOffSet, valueLength)) { + result = true + } + }) + + ranges.foreach( + r => { + val upperBoundPass = r.upperBound == null || + (r.isUpperBoundEqualTo && + Bytes.compareTo( + r.upperBound, + 0, + r.upperBound.length, + value, + valueOffSet, + valueLength) >= 0) || + (!r.isUpperBoundEqualTo && + Bytes + .compareTo(r.upperBound, 0, r.upperBound.length, value, valueOffSet, valueLength) > 0) + + val lowerBoundPass = r.lowerBound == null || r.lowerBound.length == 0 + (r.isLowerBoundEqualTo && + Bytes + .compareTo( + r.lowerBound, + 0, + r.lowerBound.length, + value, + valueOffSet, + valueLength) <= 0) || + (!r.isLowerBoundEqualTo && + Bytes + .compareTo(r.lowerBound, 0, r.lowerBound.length, value, valueOffSet, valueLength) < 0) + + result = result || (upperBoundPass && lowerBoundPass) + }) + result + } + + /** + * This will allow us to merge filter logic that is joined to the existing filter + * through a OR operator + * + * @param other Filter to merge + */ + def mergeUnion(other: ColumnFilter): Unit = { + other.points.foreach(p => points += p) + + other.ranges.foreach( + otherR => { + var doesOverLap = false + ranges.foreach { + r => + if (r.getOverLapScanRange(otherR) != null) { + r.mergeUnion(otherR) + doesOverLap = true + } + } + if (!doesOverLap) ranges.+=(otherR) + }) + } + + /** + * This will allow us to merge filter logic that is joined to the existing filter + * through a AND operator + * + * @param other Filter to merge + */ + def mergeIntersect(other: ColumnFilter): Unit = { + val survivingPoints = new mutable.ListBuffer [Array[Byte]]() + points.foreach( + p => { + other.points.foreach( + otherP => { + if (Bytes.equals(p, otherP)) { + survivingPoints.+=(p) + } + }) + }) + points = survivingPoints + + val survivingRanges = new mutable.ListBuffer [ScanRange]() + + other.ranges.foreach( + otherR => { + ranges.foreach( + r => { + if (r.getOverLapScanRange(otherR) != null) { + r.mergeIntersect(otherR) + survivingRanges += r + } + }) + }) + ranges = survivingRanges + } + + override def toString: String = { + val strBuilder = new StringBuilder + strBuilder.append("(points:(") + var isFirst = true + points.foreach( + p => { + if (isFirst) isFirst = false + else strBuilder.append(",") + strBuilder.append(Bytes.toString(p)) + }) + strBuilder.append("),ranges:") + isFirst = true + ranges.foreach( + r => { + if (isFirst) isFirst = false + else strBuilder.append(",") 
+ strBuilder.append(r) + }) + strBuilder.append("))") + strBuilder.toString() + } +} + +/** + * A collection of ColumnFilters indexed by column names. + * + * Also contains merge commends that will consolidate the filters + * per column name + */ +@InterfaceAudience.Private +class ColumnFilterCollection { + val columnFilterMap = new mutable.HashMap[String, ColumnFilter] + + def clear(): Unit = { + columnFilterMap.clear() + } + + /** + * This will allow us to merge filter logic that is joined to the existing filter + * through a OR operator. This will merge a single columns filter + * + * @param column The column to be merged + * @param other The other ColumnFilter object to merge + */ + def mergeUnion(column: String, other: ColumnFilter): Unit = { + val existingFilter = columnFilterMap.get(column) + if (existingFilter.isEmpty) { + columnFilterMap.+=((column, other)) + } else { + existingFilter.get.mergeUnion(other) + } + } + + /** + * This will allow us to merge all filters in the existing collection + * to the filters in the other collection. All merges are done as a result + * of a OR operator + * + * @param other The other Column Filter Collection to be merged + */ + def mergeUnion(other: ColumnFilterCollection): Unit = { + other.columnFilterMap.foreach( + e => { + mergeUnion(e._1, e._2) + }) + } + + /** + * This will allow us to merge all filters in the existing collection + * to the filters in the other collection. All merges are done as a result + * of a AND operator + * + * @param other The column filter from the other collection + */ + def mergeIntersect(other: ColumnFilterCollection): Unit = { + other.columnFilterMap.foreach( + e => { + val existingColumnFilter = columnFilterMap.get(e._1) + if (existingColumnFilter.isEmpty) { + columnFilterMap += e + } else { + existingColumnFilter.get.mergeIntersect(e._2) + } + }) + } + + override def toString: String = { + val strBuilder = new StringBuilder + columnFilterMap.foreach(e => strBuilder.append(e)) + strBuilder.toString() + } +} + +/** + * Status object to store static functions but also to hold last executed + * information that can be used for unit testing. + */ +@InterfaceAudience.Private +object DefaultSourceStaticUtils { + + val byteRange = new ThreadLocal[PositionedByteRange] { + override def initialValue(): PositionedByteRange = { + val range = new SimplePositionedMutableByteRange() + range.setOffset(0) + range.setPosition(0) + } + } + + def getFreshByteRange(bytes: Array[Byte]): PositionedByteRange = { + getFreshByteRange(bytes, 0, bytes.length) + } + + def getFreshByteRange(bytes: Array[Byte], offset: Int = 0, length: Int): PositionedByteRange = { + byteRange.get().set(bytes).setLength(length).setOffset(offset) + } + + // This will contain the last 5 filters and required fields used in buildScan + // These values can be used in unit testing to make sure we are converting + // The Spark SQL input correctly + val lastFiveExecutionRules = + new ConcurrentLinkedQueue[ExecutionRuleForUnitTesting]() + + /** + * This method is to populate the lastFiveExecutionRules for unit test perposes + * This method is not thread safe. 
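The `lastFiveExecutionRules` queue above is what unit tests poll to verify that a query's filters were actually pushed down. A hedged sketch of that pattern; it assumes a DataFrame query against the relation has just been executed and prints `None` otherwise:

```scala
import org.apache.hadoop.hbase.spark.DefaultSourceStaticUtils

object PushdownAssertionSketch {
  // Call right after issuing a DataFrame query against the HBase relation.
  def lastPushdownExpression(): Option[String] = {
    // poll() removes the oldest of the (up to five) retained execution rules
    Option(DefaultSourceStaticUtils.lastFiveExecutionRules.poll())
      .flatMap(rule => Option(rule.dynamicLogicExpression))
      .map(_.toExpressionString)
  }

  def main(args: Array[String]): Unit = {
    // Expect a non-empty expression string when a column filter was pushed down,
    // and None when pushdown was disabled (hbase.spark.pushdown.columnfilter=false)
    println(lastPushdownExpression())
  }
}
```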
+ * + * @param rowKeyFilter The rowKey Filter logic used in the last query + * @param dynamicLogicExpression The dynamicLogicExpression used in the last query + */ + def populateLatestExecutionRules( + rowKeyFilter: RowKeyFilter, + dynamicLogicExpression: DynamicLogicExpression): Unit = { + lastFiveExecutionRules.add( + new ExecutionRuleForUnitTesting(rowKeyFilter, dynamicLogicExpression)) + while (lastFiveExecutionRules.size() > 5) { + lastFiveExecutionRules.poll() + } + } +} + +/** + * Contains information related to a filters for a given column. + * This can contain many ranges or points. + * + * @param currentPoint the initial point when the filter is created + * @param currentRange the initial scanRange when the filter is created + */ +@InterfaceAudience.Private +class RowKeyFilter( + currentPoint: Array[Byte] = null, + currentRange: ScanRange = new ScanRange(null, true, new Array[Byte](0), true), + var points: mutable.ListBuffer [Array[Byte]] = new mutable.ListBuffer [Array[Byte]](), + var ranges: mutable.ListBuffer [ScanRange] = new mutable.ListBuffer [ScanRange]()) + extends Serializable { + // Collection of ranges + if (currentRange != null) ranges.+=(currentRange) + + // Collection of points + if (currentPoint != null) points.+=(currentPoint) + + /** + * This will validate a give value through the filter's points and/or ranges + * the result will be if the value passed the filter + * + * @param value Value to be validated + * @param valueOffSet The offset of the value + * @param valueLength The length of the value + * @return True is the value passes the filter false if not + */ + def validate(value: Array[Byte], valueOffSet: Int, valueLength: Int): Boolean = { + var result = false + + points.foreach( + p => { + if (Bytes.equals(p, 0, p.length, value, valueOffSet, valueLength)) { + result = true + } + }) + + ranges.foreach( + r => { + val upperBoundPass = r.upperBound == null || + (r.isUpperBoundEqualTo && + Bytes.compareTo( + r.upperBound, + 0, + r.upperBound.length, + value, + valueOffSet, + valueLength) >= 0) || + (!r.isUpperBoundEqualTo && + Bytes + .compareTo(r.upperBound, 0, r.upperBound.length, value, valueOffSet, valueLength) > 0) + + val lowerBoundPass = r.lowerBound == null || r.lowerBound.length == 0 + (r.isLowerBoundEqualTo && + Bytes + .compareTo( + r.lowerBound, + 0, + r.lowerBound.length, + value, + valueOffSet, + valueLength) <= 0) || + (!r.isLowerBoundEqualTo && + Bytes + .compareTo(r.lowerBound, 0, r.lowerBound.length, value, valueOffSet, valueLength) < 0) + + result = result || (upperBoundPass && lowerBoundPass) + }) + result + } + + /** + * This will allow us to merge filter logic that is joined to the existing filter + * through a OR operator + * + * @param other Filter to merge + */ + def mergeUnion(other: RowKeyFilter): RowKeyFilter = { + other.points.foreach(p => points += p) + + other.ranges.foreach( + otherR => { + var doesOverLap = false + ranges.foreach { + r => + if (r.getOverLapScanRange(otherR) != null) { + r.mergeUnion(otherR) + doesOverLap = true + } + } + if (!doesOverLap) ranges.+=(otherR) + }) + this + } + + /** + * This will allow us to merge filter logic that is joined to the existing filter + * through a AND operator + * + * @param other Filter to merge + */ + def mergeIntersect(other: RowKeyFilter): RowKeyFilter = { + val survivingPoints = new mutable.ListBuffer [Array[Byte]]() + val didntSurviveFirstPassPoints = new mutable.ListBuffer [Array[Byte]]() + if (points == null || points.length == 0) { + other.points.foreach( + otherP => { + 
didntSurviveFirstPassPoints += otherP + }) + } else { + points.foreach( + p => { + if (other.points.length == 0) { + didntSurviveFirstPassPoints += p + } else { + other.points.foreach( + otherP => { + if (Bytes.equals(p, otherP)) { + survivingPoints += p + } else { + didntSurviveFirstPassPoints += p + } + }) + } + }) + } + + val survivingRanges = new mutable.ListBuffer [ScanRange]() + + if (ranges.length == 0) { + didntSurviveFirstPassPoints.foreach( + p => { + survivingPoints += p + }) + } else { + ranges.foreach( + r => { + other.ranges.foreach( + otherR => { + val overLapScanRange = r.getOverLapScanRange(otherR) + if (overLapScanRange != null) { + survivingRanges += overLapScanRange + } + }) + didntSurviveFirstPassPoints.foreach( + p => { + if (r.containsPoint(p)) { + survivingPoints += p + } + }) + }) + } + points = survivingPoints + ranges = survivingRanges + this + } + + override def toString: String = { + val strBuilder = new StringBuilder + strBuilder.append("(points:(") + var isFirst = true + points.foreach( + p => { + if (isFirst) isFirst = false + else strBuilder.append(",") + strBuilder.append(Bytes.toString(p)) + }) + strBuilder.append("),ranges:") + isFirst = true + ranges.foreach( + r => { + if (isFirst) isFirst = false + else strBuilder.append(",") + strBuilder.append(r) + }) + strBuilder.append("))") + strBuilder.toString() + } +} + +@InterfaceAudience.Private +class ExecutionRuleForUnitTesting( + val rowKeyFilter: RowKeyFilter, + val dynamicLogicExpression: DynamicLogicExpression) diff --git a/spark4/hbase-spark4/src/main/scala/org/apache/hadoop/hbase/spark/DynamicLogicExpression.scala b/spark4/hbase-spark4/src/main/scala/org/apache/hadoop/hbase/spark/DynamicLogicExpression.scala new file mode 100644 index 00000000..398093a5 --- /dev/null +++ b/spark4/hbase-spark4/src/main/scala/org/apache/hadoop/hbase/spark/DynamicLogicExpression.scala @@ -0,0 +1,324 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.hadoop.hbase.spark + +import java.util +import org.apache.hadoop.hbase.spark.datasources.{BytesEncoder, JavaBytesEncoder} +import org.apache.hadoop.hbase.spark.datasources.JavaBytesEncoder.JavaBytesEncoder +import org.apache.hadoop.hbase.util.Bytes +import org.apache.yetus.audience.InterfaceAudience + +/** + * Dynamic logic for SQL push down logic there is an instance for most + * common operations and a pass through for other operations not covered here + * + * Logic can be nested with And or Or operators. 
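+ *
+ * For illustration, a hypothetical expression over a few columns could be
+ * rendered and parsed in the following string form, where the integers are
+ * indexes into the value-from-query array rather than literals:
+ * {{{
+ *   ( KEY_FIELD == 0 AND ( A_FIELD > 1 OR B_FIELD <= 2 ) )
+ * }}}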
+ * + * A logic tree can be written out as a string and reconstructed from that string + */ +@InterfaceAudience.Private +trait DynamicLogicExpression { + def execute( + columnToCurrentRowValueMap: util.HashMap[String, ByteArrayComparable], + valueFromQueryValueArray: Array[Array[Byte]]): Boolean + def toExpressionString: String = { + val strBuilder = new StringBuilder + appendToExpression(strBuilder) + strBuilder.toString() + } + def filterOps: JavaBytesEncoder = JavaBytesEncoder.Unknown + + def appendToExpression(strBuilder: StringBuilder) + + var encoder: BytesEncoder = _ + + def setEncoder(enc: BytesEncoder): DynamicLogicExpression = { + encoder = enc + this + } +} + +@InterfaceAudience.Private +trait CompareTrait { + self: DynamicLogicExpression => + def columnName: String + def valueFromQueryIndex: Int + def execute( + columnToCurrentRowValueMap: util.HashMap[String, ByteArrayComparable], + valueFromQueryValueArray: Array[Array[Byte]]): Boolean = { + val currentRowValue = columnToCurrentRowValueMap.get(columnName) + val valueFromQuery = valueFromQueryValueArray(valueFromQueryIndex) + currentRowValue != null && + encoder.filter( + currentRowValue.bytes, + currentRowValue.offset, + currentRowValue.length, + valueFromQuery, + 0, + valueFromQuery.length, + filterOps) + } +} + +@InterfaceAudience.Private +class AndLogicExpression( + val leftExpression: DynamicLogicExpression, + val rightExpression: DynamicLogicExpression) + extends DynamicLogicExpression { + override def execute( + columnToCurrentRowValueMap: util.HashMap[String, ByteArrayComparable], + valueFromQueryValueArray: Array[Array[Byte]]): Boolean = { + leftExpression.execute(columnToCurrentRowValueMap, valueFromQueryValueArray) && + rightExpression.execute(columnToCurrentRowValueMap, valueFromQueryValueArray) + } + + override def appendToExpression(strBuilder: StringBuilder): Unit = { + strBuilder.append("( ") + strBuilder.append(leftExpression.toExpressionString) + strBuilder.append(" AND ") + strBuilder.append(rightExpression.toExpressionString) + strBuilder.append(" )") + } +} + +@InterfaceAudience.Private +class OrLogicExpression( + val leftExpression: DynamicLogicExpression, + val rightExpression: DynamicLogicExpression) + extends DynamicLogicExpression { + override def execute( + columnToCurrentRowValueMap: util.HashMap[String, ByteArrayComparable], + valueFromQueryValueArray: Array[Array[Byte]]): Boolean = { + leftExpression.execute(columnToCurrentRowValueMap, valueFromQueryValueArray) || + rightExpression.execute(columnToCurrentRowValueMap, valueFromQueryValueArray) + } + override def appendToExpression(strBuilder: StringBuilder): Unit = { + strBuilder.append("( ") + strBuilder.append(leftExpression.toExpressionString) + strBuilder.append(" OR ") + strBuilder.append(rightExpression.toExpressionString) + strBuilder.append(" )") + } +} + +@InterfaceAudience.Private +class EqualLogicExpression(val columnName: String, val valueFromQueryIndex: Int, val isNot: Boolean) + extends DynamicLogicExpression { + override def execute( + columnToCurrentRowValueMap: util.HashMap[String, ByteArrayComparable], + valueFromQueryValueArray: Array[Array[Byte]]): Boolean = { + val currentRowValue = columnToCurrentRowValueMap.get(columnName) + val valueFromQuery = valueFromQueryValueArray(valueFromQueryIndex) + + currentRowValue != null && + Bytes.equals( + valueFromQuery, + 0, + valueFromQuery.length, + currentRowValue.bytes, + currentRowValue.offset, + currentRowValue.length) != isNot + } + override def appendToExpression(strBuilder: 
StringBuilder): Unit = { + val command = if (isNot) "!=" else "==" + strBuilder.append(columnName + " " + command + " " + valueFromQueryIndex) + } +} + +@InterfaceAudience.Private +class StartsWithLogicExpression(val columnName: String, val valueFromQueryIndex: Int) + extends DynamicLogicExpression { + override def execute( + columnToCurrentRowValueMap: util.HashMap[String, ByteArrayComparable], + valueFromQueryValueArray: Array[Array[Byte]]): Boolean = { + val currentRowValue = columnToCurrentRowValueMap.get(columnName) + val valueFromQuery = valueFromQueryValueArray(valueFromQueryIndex) + + currentRowValue != null && valueFromQuery != null && currentRowValue.length >= valueFromQuery.length && + Bytes.equals( + valueFromQuery, + 0, + valueFromQuery.length, + currentRowValue.bytes, + currentRowValue.offset, + valueFromQuery.length) + } + override def appendToExpression(strBuilder: StringBuilder): Unit = { + strBuilder.append(columnName + " startsWith " + valueFromQueryIndex) + } +} + +@InterfaceAudience.Private +class IsNullLogicExpression(val columnName: String, val isNot: Boolean) + extends DynamicLogicExpression { + override def execute( + columnToCurrentRowValueMap: util.HashMap[String, ByteArrayComparable], + valueFromQueryValueArray: Array[Array[Byte]]): Boolean = { + val currentRowValue = columnToCurrentRowValueMap.get(columnName) + + (currentRowValue == null) != isNot + } + override def appendToExpression(strBuilder: StringBuilder): Unit = { + val command = if (isNot) "isNotNull" else "isNull" + strBuilder.append(columnName + " " + command) + } +} + +@InterfaceAudience.Private +class GreaterThanLogicExpression( + override val columnName: String, + override val valueFromQueryIndex: Int) + extends DynamicLogicExpression + with CompareTrait { + override val filterOps = JavaBytesEncoder.Greater + override def appendToExpression(strBuilder: StringBuilder): Unit = { + strBuilder.append(columnName + " > " + valueFromQueryIndex) + } +} + +@InterfaceAudience.Private +class GreaterThanOrEqualLogicExpression( + override val columnName: String, + override val valueFromQueryIndex: Int) + extends DynamicLogicExpression + with CompareTrait { + override val filterOps = JavaBytesEncoder.GreaterEqual + override def appendToExpression(strBuilder: StringBuilder): Unit = { + strBuilder.append(columnName + " >= " + valueFromQueryIndex) + } +} + +@InterfaceAudience.Private +class LessThanLogicExpression( + override val columnName: String, + override val valueFromQueryIndex: Int) + extends DynamicLogicExpression + with CompareTrait { + override val filterOps = JavaBytesEncoder.Less + override def appendToExpression(strBuilder: StringBuilder): Unit = { + strBuilder.append(columnName + " < " + valueFromQueryIndex) + } +} + +@InterfaceAudience.Private +class LessThanOrEqualLogicExpression(val columnName: String, val valueFromQueryIndex: Int) + extends DynamicLogicExpression + with CompareTrait { + override val filterOps = JavaBytesEncoder.LessEqual + override def appendToExpression(strBuilder: StringBuilder): Unit = { + strBuilder.append(columnName + " <= " + valueFromQueryIndex) + } +} + +@InterfaceAudience.Private +class PassThroughLogicExpression() extends DynamicLogicExpression { + override def execute( + columnToCurrentRowValueMap: util.HashMap[String, ByteArrayComparable], + valueFromQueryValueArray: Array[Array[Byte]]): Boolean = true + + override def appendToExpression(strBuilder: StringBuilder): Unit = { + // Fix the offset bug by add dummy to avoid crash the region server. 
+ // because in the DynamicLogicExpressionBuilder.build function, the command is always retrieved from offset + 1 as below + // val command = expressionArray(offSet + 1) + // we have to padding it so that `Pass` is on the right offset. + strBuilder.append("dummy Pass -1") + } +} + +@InterfaceAudience.Private +object DynamicLogicExpressionBuilder { + def build(expressionString: String, encoder: BytesEncoder): DynamicLogicExpression = { + + val expressionAndOffset = build(expressionString.split(' '), 0, encoder) + expressionAndOffset._1 + } + + private def build( + expressionArray: Array[String], + offSet: Int, + encoder: BytesEncoder): (DynamicLogicExpression, Int) = { + val expr = { + if (expressionArray(offSet).equals("(")) { + val left = build(expressionArray, offSet + 1, encoder) + val right = build(expressionArray, left._2 + 1, encoder) + if (expressionArray(left._2).equals("AND")) { + (new AndLogicExpression(left._1, right._1), right._2 + 1) + } else if (expressionArray(left._2).equals("OR")) { + (new OrLogicExpression(left._1, right._1), right._2 + 1) + } else { + throw new Throwable("Unknown gate:" + expressionArray(left._2)) + } + } else { + val command = expressionArray(offSet + 1) + if (command.equals("<")) { + ( + new LessThanLogicExpression(expressionArray(offSet), expressionArray(offSet + 2).toInt), + offSet + 3) + } else if (command.equals("<=")) { + ( + new LessThanOrEqualLogicExpression( + expressionArray(offSet), + expressionArray(offSet + 2).toInt), + offSet + 3) + } else if (command.equals(">")) { + ( + new GreaterThanLogicExpression( + expressionArray(offSet), + expressionArray(offSet + 2).toInt), + offSet + 3) + } else if (command.equals(">=")) { + ( + new GreaterThanOrEqualLogicExpression( + expressionArray(offSet), + expressionArray(offSet + 2).toInt), + offSet + 3) + } else if (command.equals("==")) { + ( + new EqualLogicExpression( + expressionArray(offSet), + expressionArray(offSet + 2).toInt, + false), + offSet + 3) + } else if (command.equals("!=")) { + ( + new EqualLogicExpression( + expressionArray(offSet), + expressionArray(offSet + 2).toInt, + true), + offSet + 3) + } else if (command.equals("startsWith")) { + ( + new StartsWithLogicExpression( + expressionArray(offSet), + expressionArray(offSet + 2).toInt), + offSet + 3) + } else if (command.equals("isNull")) { + (new IsNullLogicExpression(expressionArray(offSet), false), offSet + 2) + } else if (command.equals("isNotNull")) { + (new IsNullLogicExpression(expressionArray(offSet), true), offSet + 2) + } else if (command.equals("Pass")) { + (new PassThroughLogicExpression, offSet + 3) + } else { + throw new Throwable("Unknown logic command:" + command) + } + } + } + expr._1.setEncoder(encoder) + expr + } +} diff --git a/spark4/hbase-spark4/src/main/scala/org/apache/hadoop/hbase/spark/FamiliesQualifiersValues.scala b/spark4/hbase-spark4/src/main/scala/org/apache/hadoop/hbase/spark/FamiliesQualifiersValues.scala new file mode 100644 index 00000000..93410a42 --- /dev/null +++ b/spark4/hbase-spark4/src/main/scala/org/apache/hadoop/hbase/spark/FamiliesQualifiersValues.scala @@ -0,0 +1,67 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.hadoop.hbase.spark + +import java.util +import org.apache.yetus.audience.InterfaceAudience; + +/** + * This object is a clean way to store and sort all cells that will be bulk + * loaded into a single row + */ +@InterfaceAudience.Public +class FamiliesQualifiersValues extends Serializable { + // Tree maps are used because we need the results to + // be sorted when we read them + val familyMap = new util.TreeMap[ByteArrayWrapper, util.TreeMap[ByteArrayWrapper, Array[Byte]]]() + + // normally in a row there are more columns then + // column families this wrapper is reused for column + // family look ups + val reusableWrapper = new ByteArrayWrapper(null) + + /** + * Adds a new cell to an existing row + * @param family HBase column family + * @param qualifier HBase column qualifier + * @param value HBase cell value + */ + def +=(family: Array[Byte], qualifier: Array[Byte], value: Array[Byte]): Unit = { + + reusableWrapper.value = family + + var qualifierValues = familyMap.get(reusableWrapper) + + if (qualifierValues == null) { + qualifierValues = new util.TreeMap[ByteArrayWrapper, Array[Byte]]() + familyMap.put(new ByteArrayWrapper(family), qualifierValues) + } + + qualifierValues.put(new ByteArrayWrapper(qualifier), value) + } + + /** + * A wrapper for "+=" method above, can be used by Java + * @param family HBase column family + * @param qualifier HBase column qualifier + * @param value HBase cell value + */ + def add(family: Array[Byte], qualifier: Array[Byte], value: Array[Byte]): Unit = { + this += (family, qualifier, value) + } +} diff --git a/spark4/hbase-spark4/src/main/scala/org/apache/hadoop/hbase/spark/FamilyHFileWriteOptions.scala b/spark4/hbase-spark4/src/main/scala/org/apache/hadoop/hbase/spark/FamilyHFileWriteOptions.scala new file mode 100644 index 00000000..05117c51 --- /dev/null +++ b/spark4/hbase-spark4/src/main/scala/org/apache/hadoop/hbase/spark/FamilyHFileWriteOptions.scala @@ -0,0 +1,39 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.hadoop.hbase.spark + +import java.io.Serializable +import org.apache.yetus.audience.InterfaceAudience; + +/** + * This object will hold optional data for how a given column family's + * writer will work + * + * @param compression String to define the Compression to be used in the HFile + * @param bloomType String to define the bloom type to be used in the HFile + * @param blockSize The block size to be used in the HFile + * @param dataBlockEncoding String to define the data block encoding to be used + * in the HFile + */ +@InterfaceAudience.Public +class FamilyHFileWriteOptions( + val compression: String, + val bloomType: String, + val blockSize: Int, + val dataBlockEncoding: String) + extends Serializable diff --git a/spark4/hbase-spark4/src/main/scala/org/apache/hadoop/hbase/spark/HBaseConnectionCache.scala b/spark4/hbase-spark4/src/main/scala/org/apache/hadoop/hbase/spark/HBaseConnectionCache.scala new file mode 100644 index 00000000..812975f3 --- /dev/null +++ b/spark4/hbase-spark4/src/main/scala/org/apache/hadoop/hbase/spark/HBaseConnectionCache.scala @@ -0,0 +1,273 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.hadoop.hbase.spark + +import java.io.IOException +import org.apache.hadoop.conf.Configuration +import org.apache.hadoop.hbase.HConstants +import org.apache.hadoop.hbase.TableName +import org.apache.hadoop.hbase.client.Admin +import org.apache.hadoop.hbase.client.Connection +import org.apache.hadoop.hbase.client.ConnectionFactory +import org.apache.hadoop.hbase.client.RegionLocator +import org.apache.hadoop.hbase.client.Table +import org.apache.hadoop.hbase.ipc.RpcControllerFactory +import org.apache.hadoop.hbase.security.User +import org.apache.hadoop.hbase.security.UserProvider +import org.apache.hadoop.hbase.spark.datasources.HBaseSparkConf +import org.apache.yetus.audience.InterfaceAudience +import scala.collection.mutable + +@InterfaceAudience.Private +private[spark] object HBaseConnectionCache extends Logging { + + // A hashmap of Spark-HBase connections. Key is HBaseConnectionKey. 
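+ // Callers normally obtain a connection through getConnection(conf) and call
+ // close() on the returned SmartConnection, which only decrements its reference
+ // count; the housekeeping thread below closes connections that stay idle past
+ // the timeout. Illustrative sketch (conf is assumed to be a valid HBase
+ // Configuration and tableName an existing TableName):
+ //   val smartConn = HBaseConnectionCache.getConnection(conf)
+ //   try { val table = smartConn.getTable(tableName) /* ... */ } finally { smartConn.close() }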
+ val connectionMap = new mutable.HashMap[HBaseConnectionKey, SmartConnection]() + + val cacheStat = HBaseConnectionCacheStat(0, 0, 0) + + // in milliseconds + private final val DEFAULT_TIME_OUT: Long = HBaseSparkConf.DEFAULT_CONNECTION_CLOSE_DELAY + private var timeout = DEFAULT_TIME_OUT + private var closed: Boolean = false + + var housekeepingThread = new Thread(new Runnable { + override def run() { + while (true) { + try { + Thread.sleep(timeout) + } catch { + case e: InterruptedException => + // setTimeout() and close() may interrupt the sleep and it's safe + // to ignore the exception + } + if (closed) + return + performHousekeeping(false) + } + } + }) + housekeepingThread.setDaemon(true) + housekeepingThread.start() + + def getStat: HBaseConnectionCacheStat = { + connectionMap.synchronized { + cacheStat.numActiveConnections = connectionMap.size + cacheStat.copy() + } + } + + def close(): Unit = { + try { + connectionMap.synchronized { + if (closed) + return + closed = true + housekeepingThread.interrupt() + housekeepingThread = null + HBaseConnectionCache.performHousekeeping(true) + } + } catch { + case e: Exception => logWarning("Error in finalHouseKeeping", e) + } + } + + def performHousekeeping(forceClean: Boolean) = { + val tsNow: Long = System.currentTimeMillis() + connectionMap.synchronized { + connectionMap.foreach { + x => + { + if (x._2.refCount < 0) { + logError(s"Bug to be fixed: negative refCount of connection ${x._2}") + } + + if (forceClean || ((x._2.refCount <= 0) && (tsNow - x._2.timestamp > timeout))) { + try { + x._2.connection.close() + } catch { + case e: IOException => logWarning(s"Fail to close connection ${x._2}", e) + } + connectionMap.remove(x._1) + } + } + } + } + } + + // For testing purpose only + def getConnection(key: HBaseConnectionKey, conn: => Connection): SmartConnection = { + connectionMap.synchronized { + if (closed) + return null + cacheStat.numTotalRequests += 1 + val sc = connectionMap.getOrElseUpdate( + key, { + cacheStat.numActualConnectionsCreated += 1 + new SmartConnection(conn) + }) + sc.refCount += 1 + sc + } + } + + def getConnection(conf: Configuration): SmartConnection = + getConnection(new HBaseConnectionKey(conf), ConnectionFactory.createConnection(conf)) + + // For testing purpose only + def setTimeout(to: Long): Unit = { + connectionMap.synchronized { + if (closed) + return + timeout = to + housekeepingThread.interrupt() + } + } +} + +@InterfaceAudience.Private +private[hbase] case class SmartConnection( + connection: Connection, + var refCount: Int = 0, + var timestamp: Long = 0) { + def getTable(tableName: TableName): Table = connection.getTable(tableName) + def getRegionLocator(tableName: TableName): RegionLocator = connection.getRegionLocator(tableName) + def isClosed: Boolean = connection.isClosed + def getAdmin: Admin = connection.getAdmin + def close() = { + HBaseConnectionCache.connectionMap.synchronized { + refCount -= 1 + if (refCount <= 0) + timestamp = System.currentTimeMillis() + } + } +} + +/** + * Denotes a unique key to an HBase Connection instance. + * Please refer to 'org.apache.hadoop.hbase.client.HConnectionKey'. + * + * In essence, this class captures the properties in Configuration + * that may be used in the process of establishing a connection. 
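+ *
+ * Two Configuration instances that agree on these properties (and on the
+ * current user) produce equal keys and therefore share a single cached
+ * Connection in HBaseConnectionCache.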
+ */ +@InterfaceAudience.Private +class HBaseConnectionKey(c: Configuration) extends Logging { + val conf: Configuration = c + val CONNECTION_PROPERTIES: Array[String] = Array[String]( + HConstants.ZOOKEEPER_QUORUM, + HConstants.ZOOKEEPER_ZNODE_PARENT, + HConstants.ZOOKEEPER_CLIENT_PORT, + HConstants.HBASE_CLIENT_PAUSE, + HConstants.HBASE_CLIENT_RETRIES_NUMBER, + HConstants.HBASE_RPC_TIMEOUT_KEY, + HConstants.HBASE_META_SCANNER_CACHING, + HConstants.HBASE_CLIENT_INSTANCE_ID, + HConstants.RPC_CODEC_CONF_KEY, + HConstants.USE_META_REPLICAS, + RpcControllerFactory.CUSTOM_CONTROLLER_CONF_KEY) + + var username: String = _ + var m_properties = mutable.HashMap.empty[String, String] + if (conf != null) { + for (property <- CONNECTION_PROPERTIES) { + val value: String = conf.get(property) + if (value != null) { + m_properties.+=((property, value)) + } + } + try { + val provider: UserProvider = UserProvider.instantiate(conf) + val currentUser: User = provider.getCurrent + if (currentUser != null) { + username = currentUser.getName + } + } catch { + case e: IOException => { + logWarning("Error obtaining current user, skipping username in HBaseConnectionKey", e) + } + } + } + + // make 'properties' immutable + val properties = m_properties.toMap + + override def hashCode: Int = { + val prime: Int = 31 + var result: Int = 1 + if (username != null) { + result = username.hashCode + } + for (property <- CONNECTION_PROPERTIES) { + val value: Option[String] = properties.get(property) + if (value.isDefined) { + result = prime * result + value.hashCode + } + } + result + } + + override def equals(obj: Any): Boolean = { + if (obj == null) return false + if (getClass ne obj.getClass) return false + val that: HBaseConnectionKey = obj.asInstanceOf[HBaseConnectionKey] + if (this.username != null && !(this.username == that.username)) { + return false + } else if (this.username == null && that.username != null) { + return false + } + if (this.properties == null) { + if (that.properties != null) { + return false + } + } else { + if (that.properties == null) { + return false + } + var flag: Boolean = true + for (property <- CONNECTION_PROPERTIES) { + val thisValue: Option[String] = this.properties.get(property) + val thatValue: Option[String] = that.properties.get(property) + flag = true + if (thisValue eq thatValue) { + flag = false // continue, so make flag be false + } + if (flag && (thisValue == null || !(thisValue == thatValue))) { + return false + } + } + } + true + } + + override def toString: String = { + "HBaseConnectionKey{" + "properties=" + properties + ", username='" + username + '\'' + '}' + } +} + +/** + * To log the state of 'HBaseConnectionCache' + * + * @param numTotalRequests number of total connection requests to the cache + * @param numActualConnectionsCreated number of actual HBase connections the cache ever created + * @param numActiveConnections number of current alive HBase connections the cache is holding + */ +@InterfaceAudience.Private +case class HBaseConnectionCacheStat( + var numTotalRequests: Long, + var numActualConnectionsCreated: Long, + var numActiveConnections: Long) diff --git a/spark4/hbase-spark4/src/main/scala/org/apache/hadoop/hbase/spark/HBaseContext.scala b/spark4/hbase-spark4/src/main/scala/org/apache/hadoop/hbase/spark/HBaseContext.scala new file mode 100644 index 00000000..e17051db --- /dev/null +++ b/spark4/hbase-spark4/src/main/scala/org/apache/hadoop/hbase/spark/HBaseContext.scala @@ -0,0 +1,1134 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * 
or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.hadoop.hbase.spark + +import java.io._ +import java.net.InetSocketAddress +import java.util +import java.util.UUID +import javax.management.openmbean.KeyAlreadyExistsException +import org.apache.hadoop.conf.Configuration +import org.apache.hadoop.fs.{FileAlreadyExistsException, FileSystem, Path} +import org.apache.hadoop.hbase._ +import org.apache.hadoop.hbase.client._ +import org.apache.hadoop.hbase.fs.HFileSystem +import org.apache.hadoop.hbase.io.ImmutableBytesWritable +import org.apache.hadoop.hbase.io.compress.Compression +import org.apache.hadoop.hbase.io.compress.Compression.Algorithm +import org.apache.hadoop.hbase.io.encoding.DataBlockEncoding +import org.apache.hadoop.hbase.io.hfile.{CacheConfig, HFile, HFileContextBuilder, HFileWriterImpl} +import org.apache.hadoop.hbase.mapreduce.{IdentityTableMapper, TableInputFormat, TableMapReduceUtil} +import org.apache.hadoop.hbase.regionserver.{BloomType, HStoreFile, StoreFileWriter, StoreUtils} +import org.apache.hadoop.hbase.spark.HBaseRDDFunctions._ +import org.apache.hadoop.hbase.util.{Bytes, ChecksumType} +import org.apache.hadoop.mapred.JobConf +import org.apache.hadoop.mapreduce.Job +import org.apache.hadoop.security.UserGroupInformation +import org.apache.hadoop.security.UserGroupInformation.AuthenticationMethod +import org.apache.spark.{SerializableWritable, SparkContext} +import org.apache.spark.broadcast.Broadcast +import org.apache.spark.rdd.RDD +import org.apache.spark.streaming.dstream.DStream +import org.apache.yetus.audience.InterfaceAudience +import scala.collection.mutable +import scala.reflect.ClassTag + +/** + * HBaseContext is a façade for HBase operations + * like bulk put, get, increment, delete, and scan + * + * HBaseContext will take the responsibilities + * of disseminating the configuration information + * to the working and managing the life cycle of Connections. 
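+ *
+ * A minimal, illustrative sketch (sc, config and rdd are assumed to exist;
+ * the table, column family and qualifier names are made up):
+ * {{{
+ *   val hbaseContext = new HBaseContext(sc, config)
+ *   hbaseContext.bulkPut[(Array[Byte], Array[Byte])](
+ *     rdd,
+ *     TableName.valueOf("t1"),
+ *     { case (rowKey, value) =>
+ *       new Put(rowKey).addColumn(Bytes.toBytes("cf"), Bytes.toBytes("q"), value)
+ *     })
+ * }}}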
+ */ +@InterfaceAudience.Public +class HBaseContext( + @transient val sc: SparkContext, + @transient val config: Configuration, + val tmpHdfsConfgFile: String = null) + extends Serializable + with Logging { + + @transient var tmpHdfsConfiguration: Configuration = config + @transient var appliedCredentials = false + @transient val job = Job.getInstance(config) + TableMapReduceUtil.initCredentials(job) + val broadcastedConf = sc.broadcast(new SerializableWritable(config)) + + LatestHBaseContextCache.latest = this + + if (tmpHdfsConfgFile != null && config != null) { + val fs = FileSystem.newInstance(config) + val tmpPath = new Path(tmpHdfsConfgFile) + if (!fs.exists(tmpPath)) { + val outputStream = fs.create(tmpPath) + config.write(outputStream) + outputStream.close() + } else { + logWarning("tmpHdfsConfigDir " + tmpHdfsConfgFile + " exist!!") + } + } + + /** + * A simple enrichment of the traditional Spark RDD foreachPartition. + * This function differs from the original in that it offers the + * developer access to a already connected Connection object + * + * Note: Do not close the Connection object. All Connection + * management is handled outside this method + * + * @param rdd Original RDD with data to iterate over + * @param f Function to be given a iterator to iterate through + * the RDD values and a Connection object to interact + * with HBase + */ + def foreachPartition[T](rdd: RDD[T], f: (Iterator[T], Connection) => Unit): Unit = { + rdd.foreachPartition(it => hbaseForeachPartition(broadcastedConf, it, f)) + } + + /** + * A simple enrichment of the traditional Spark Streaming dStream foreach + * This function differs from the original in that it offers the + * developer access to a already connected Connection object + * + * Note: Do not close the Connection object. All Connection + * management is handled outside this method + * + * @param dstream Original DStream with data to iterate over + * @param f Function to be given a iterator to iterate through + * the DStream values and a Connection object to + * interact with HBase + */ + def foreachPartition[T](dstream: DStream[T], f: (Iterator[T], Connection) => Unit): Unit = { + dstream.foreachRDD( + (rdd, time) => { + foreachPartition(rdd, f) + }) + } + + /** + * A simple enrichment of the traditional Spark RDD mapPartition. + * This function differs from the original in that it offers the + * developer access to a already connected Connection object + * + * Note: Do not close the Connection object. All Connection + * management is handled outside this method + * + * @param rdd Original RDD with data to iterate over + * @param mp Function to be given a iterator to iterate through + * the RDD values and a Connection object to interact + * with HBase + * @return Returns a new RDD generated by the user definition + * function just like normal mapPartition + */ + def mapPartitions[T, R: ClassTag]( + rdd: RDD[T], + mp: (Iterator[T], Connection) => Iterator[R]): RDD[R] = { + + rdd.mapPartitions[R](it => hbaseMapPartition[T, R](broadcastedConf, it, mp)) + + } + + /** + * A simple enrichment of the traditional Spark Streaming DStream + * foreachPartition. + * + * This function differs from the original in that it offers the + * developer access to a already connected Connection object + * + * Note: Do not close the Connection object. 
All Connection + * management is handled outside this method + * + * Note: Make sure to partition correctly to avoid memory issue when + * getting data from HBase + * + * @param dstream Original DStream with data to iterate over + * @param f Function to be given a iterator to iterate through + * the DStream values and a Connection object to + * interact with HBase + * @return Returns a new DStream generated by the user + * definition function just like normal mapPartition + */ + def streamForeachPartition[T](dstream: DStream[T], f: (Iterator[T], Connection) => Unit): Unit = { + + dstream.foreachRDD(rdd => this.foreachPartition(rdd, f)) + } + + /** + * A simple enrichment of the traditional Spark Streaming DStream + * mapPartition. + * + * This function differs from the original in that it offers the + * developer access to a already connected Connection object + * + * Note: Do not close the Connection object. All Connection + * management is handled outside this method + * + * Note: Make sure to partition correctly to avoid memory issue when + * getting data from HBase + * + * @param dstream Original DStream with data to iterate over + * @param f Function to be given a iterator to iterate through + * the DStream values and a Connection object to + * interact with HBase + * @return Returns a new DStream generated by the user + * definition function just like normal mapPartition + */ + def streamMapPartitions[T, U: ClassTag]( + dstream: DStream[T], + f: (Iterator[T], Connection) => Iterator[U]): DStream[U] = { + dstream.mapPartitions(it => hbaseMapPartition[T, U](broadcastedConf, it, f)) + } + + /** + * A simple abstraction over the HBaseContext.foreachPartition method. + * + * It allow addition support for a user to take RDD + * and generate puts and send them to HBase. + * The complexity of managing the Connection is + * removed from the developer + * + * @param rdd Original RDD with data to iterate over + * @param tableName The name of the table to put into + * @param f Function to convert a value in the RDD to a HBase Put + */ + def bulkPut[T](rdd: RDD[T], tableName: TableName, f: (T) => Put) { + + val tName = tableName.getName + rdd.foreachPartition( + it => + hbaseForeachPartition[T]( + broadcastedConf, + it, + (iterator, connection) => { + val m = connection.getBufferedMutator(TableName.valueOf(tName)) + iterator.foreach(T => m.mutate(f(T))) + m.flush() + m.close() + })) + } + + def applyCreds[T]() { + if (!appliedCredentials) { + appliedCredentials = true + + @transient val ugi = UserGroupInformation.getCurrentUser + // specify that this is a proxy user + ugi.setAuthenticationMethod(AuthenticationMethod.PROXY) + } + } + + /** + * A simple abstraction over the HBaseContext.streamMapPartition method. + * + * It allow addition support for a user to take a DStream and + * generate puts and send them to HBase. + * + * The complexity of managing the Connection is + * removed from the developer + * + * @param dstream Original DStream with data to iterate over + * @param tableName The name of the table to put into + * @param f Function to convert a value in + * the DStream to a HBase Put + */ + def streamBulkPut[T](dstream: DStream[T], tableName: TableName, f: (T) => Put) = { + val tName = tableName.getName + dstream.foreachRDD( + (rdd, time) => { + bulkPut(rdd, TableName.valueOf(tName), f) + }) + } + + /** + * A simple abstraction over the HBaseContext.foreachPartition method. + * + * It allow addition support for a user to take a RDD and generate delete + * and send them to HBase. 
The complexity of managing the Connection is + * removed from the developer + * + * @param rdd Original RDD with data to iterate over + * @param tableName The name of the table to delete from + * @param f Function to convert a value in the RDD to a + * HBase Deletes + * @param batchSize The number of delete to batch before sending to HBase + */ + def bulkDelete[T](rdd: RDD[T], tableName: TableName, f: (T) => Delete, batchSize: Integer) { + bulkMutation(rdd, tableName, f, batchSize) + } + + /** + * A simple abstraction over the HBaseContext.streamBulkMutation method. + * + * It allow addition support for a user to take a DStream and + * generate Delete and send them to HBase. + * + * The complexity of managing the Connection is + * removed from the developer + * + * @param dstream Original DStream with data to iterate over + * @param tableName The name of the table to delete from + * @param f function to convert a value in the DStream to a + * HBase Delete + * @param batchSize The number of deletes to batch before sending to HBase + */ + def streamBulkDelete[T]( + dstream: DStream[T], + tableName: TableName, + f: (T) => Delete, + batchSize: Integer) = { + streamBulkMutation(dstream, tableName, f, batchSize) + } + + /** + * Under lining function to support all bulk mutations + * + * May be opened up if requested + */ + private def bulkMutation[T]( + rdd: RDD[T], + tableName: TableName, + f: (T) => Mutation, + batchSize: Integer) { + + val tName = tableName.getName + rdd.foreachPartition( + it => + hbaseForeachPartition[T]( + broadcastedConf, + it, + (iterator, connection) => { + val table = connection.getTable(TableName.valueOf(tName)) + val mutationList = new java.util.ArrayList[Mutation] + iterator.foreach( + T => { + mutationList.add(f(T)) + if (mutationList.size >= batchSize) { + table.batch(mutationList, null) + mutationList.clear() + } + }) + if (mutationList.size() > 0) { + table.batch(mutationList, null) + mutationList.clear() + } + table.close() + })) + } + + /** + * Under lining function to support all bulk streaming mutations + * + * May be opened up if requested + */ + private def streamBulkMutation[T]( + dstream: DStream[T], + tableName: TableName, + f: (T) => Mutation, + batchSize: Integer) = { + val tName = tableName.getName + dstream.foreachRDD( + (rdd, time) => { + bulkMutation(rdd, TableName.valueOf(tName), f, batchSize) + }) + } + + /** + * A simple abstraction over the HBaseContext.mapPartition method. + * + * It allow addition support for a user to take a RDD and generates a + * new RDD based on Gets and the results they bring back from HBase + * + * @param rdd Original RDD with data to iterate over + * @param tableName The name of the table to get from + * @param makeGet function to convert a value in the RDD to a + * HBase Get + * @param convertResult This will convert the HBase Result object to + * what ever the user wants to put in the resulting + * RDD + * return new RDD that is created by the Get to HBase + */ + def bulkGet[T, U: ClassTag]( + tableName: TableName, + batchSize: Integer, + rdd: RDD[T], + makeGet: (T) => Get, + convertResult: (Result) => U): RDD[U] = { + + val getMapPartition = new GetMapPartition(tableName, batchSize, makeGet, convertResult) + + rdd.mapPartitions[U](it => hbaseMapPartition[T, U](broadcastedConf, it, getMapPartition.run)) + } + + /** + * A simple abstraction over the HBaseContext.streamMap method. 
+ * + * It allow addition support for a user to take a DStream and + * generates a new DStream based on Gets and the results + * they bring back from HBase + * + * @param tableName The name of the table to get from + * @param batchSize The number of Gets to be sent in a single batch + * @param dStream Original DStream with data to iterate over + * @param makeGet Function to convert a value in the DStream to a + * HBase Get + * @param convertResult This will convert the HBase Result object to + * what ever the user wants to put in the resulting + * DStream + * @return A new DStream that is created by the Get to HBase + */ + def streamBulkGet[T, U: ClassTag]( + tableName: TableName, + batchSize: Integer, + dStream: DStream[T], + makeGet: (T) => Get, + convertResult: (Result) => U): DStream[U] = { + + val getMapPartition = new GetMapPartition(tableName, batchSize, makeGet, convertResult) + + dStream.mapPartitions[U]( + it => hbaseMapPartition[T, U](broadcastedConf, it, getMapPartition.run)) + } + + /** + * This function will use the native HBase TableInputFormat with the + * given scan object to generate a new RDD + * + * @param tableName the name of the table to scan + * @param scan the HBase scan object to use to read data from HBase + * @param f function to convert a Result object from HBase into + * what the user wants in the final generated RDD + * @return new RDD with results from scan + */ + def hbaseRDD[U: ClassTag]( + tableName: TableName, + scan: Scan, + f: ((ImmutableBytesWritable, Result)) => U): RDD[U] = { + + val job: Job = Job.getInstance(getConf(broadcastedConf)) + + TableMapReduceUtil.initCredentials(job) + TableMapReduceUtil.initTableMapperJob( + tableName, + scan, + classOf[IdentityTableMapper], + null, + null, + job) + + val jconf = new JobConf(job.getConfiguration) + val jobCreds = jconf.getCredentials() + UserGroupInformation.setConfiguration(sc.hadoopConfiguration) + jobCreds.mergeAll(UserGroupInformation.getCurrentUser().getCredentials()) + + new NewHBaseRDD( + sc, + classOf[TableInputFormat], + classOf[ImmutableBytesWritable], + classOf[Result], + job.getConfiguration, + this).map(f) + } + + /** + * A overloaded version of HBaseContext hbaseRDD that defines the + * type of the resulting RDD + * + * @param tableName the name of the table to scan + * @param scans the HBase scan object to use to read data from HBase + * @return New RDD with results from scan + */ + def hbaseRDD(tableName: TableName, scans: Scan): RDD[(ImmutableBytesWritable, Result)] = { + + hbaseRDD[(ImmutableBytesWritable, Result)]( + tableName, + scans, + (r: (ImmutableBytesWritable, Result)) => r) + } + + /** + * underlining wrapper all foreach functions in HBaseContext + */ + private def hbaseForeachPartition[T]( + configBroadcast: Broadcast[SerializableWritable[Configuration]], + it: Iterator[T], + f: (Iterator[T], Connection) => Unit) = { + + val config = getConf(configBroadcast) + + applyCreds + // specify that this is a proxy user + val smartConn = HBaseConnectionCache.getConnection(config) + try { + f(it, smartConn.connection) + } finally { + if (smartConn != null) smartConn.close() + } + } + + private def getConf( + configBroadcast: Broadcast[SerializableWritable[Configuration]]): Configuration = { + + if (tmpHdfsConfiguration == null && tmpHdfsConfgFile != null) { + val fs = FileSystem.newInstance(sc.hadoopConfiguration) + val inputStream = fs.open(new Path(tmpHdfsConfgFile)) + tmpHdfsConfiguration = new Configuration(false) + tmpHdfsConfiguration.readFields(inputStream) + inputStream.close() 
+ } + + if (tmpHdfsConfiguration == null) { + try { + tmpHdfsConfiguration = configBroadcast.value.value + } catch { + case ex: Exception => logError("Unable to getConfig from broadcast", ex) + } + } + tmpHdfsConfiguration + } + + /** + * underlining wrapper all mapPartition functions in HBaseContext + */ + private def hbaseMapPartition[K, U]( + configBroadcast: Broadcast[SerializableWritable[Configuration]], + it: Iterator[K], + mp: (Iterator[K], Connection) => Iterator[U]): Iterator[U] = { + + val config = getConf(configBroadcast) + applyCreds + + val smartConn = HBaseConnectionCache.getConnection(config) + try { + mp(it, smartConn.connection) + } finally { + if (smartConn != null) smartConn.close() + } + } + + /** + * underlining wrapper all get mapPartition functions in HBaseContext + */ + private class GetMapPartition[T, U: ClassTag]( + tableName: TableName, + batchSize: Integer, + makeGet: (T) => Get, + convertResult: (Result) => U) + extends Serializable { + + val tName = tableName.getName + + def run(iterator: Iterator[T], connection: Connection): Iterator[U] = { + val table = connection.getTable(TableName.valueOf(tName)) + + val gets = new java.util.ArrayList[Get]() + var res = List[U]() + + while (iterator.hasNext) { + gets.add(makeGet(iterator.next())) + + if (gets.size() == batchSize) { + val results = table.get(gets) + res = res ++ results.map(convertResult) + gets.clear() + } + } + if (gets.size() > 0) { + val results = table.get(gets) + res = res ++ results.map(convertResult) + gets.clear() + } + table.close() + res.iterator + } + } + + /** + * Produces a ClassTag[T], which is actually just a casted ClassTag[AnyRef]. + * + * This method is used to keep ClassTags out of the external Java API, as + * the Java compiler cannot produce them automatically. While this + * ClassTag-faking does please the compiler, it can cause problems at runtime + * if the Scala API relies on ClassTags for correctness. + * + * Often, though, a ClassTag[AnyRef] will not lead to incorrect behavior, + * just worse performance or security issues. + * For instance, an Array of AnyRef can hold any type T, but may lose primitive + * specialization. + */ + private[spark] def fakeClassTag[T]: ClassTag[T] = ClassTag.AnyRef.asInstanceOf[ClassTag[T]] + + /** + * Spark Implementation of HBase Bulk load for wide rows or when + * values are not already combined at the time of the map process + * + * This will take the content from an existing RDD then sort and shuffle + * it with respect to region splits. The result of that sort and shuffle + * will be written to HFiles. + * + * After this function is executed the user will have to call + * LoadIncrementalHFiles.doBulkLoad(...) to move the files into HBase + * + * Also note this version of bulk load is different from past versions in + * that it includes the qualifier as part of the sort process. The + * reason for this is to be able to support rows will very large number + * of columns. 
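+ *
+ * A hedged usage sketch (hbaseContext, rdd, the table name and the staging
+ * path are illustrative; KeyFamilyQualifier wraps row key, column family and
+ * qualifier):
+ * {{{
+ *   hbaseContext.bulkLoad[(Array[Byte], Array[Byte])](
+ *     rdd,
+ *     TableName.valueOf("t1"),
+ *     { case (rowKey, value) =>
+ *       Iterator((new KeyFamilyQualifier(rowKey, Bytes.toBytes("cf"), Bytes.toBytes("q")), value))
+ *     },
+ *     "/tmp/bulkload-staging")
+ *   // afterwards, LoadIncrementalHFiles.doBulkLoad(...) moves the HFiles into the table
+ * }}}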
+ * + * @param rdd The RDD we are bulk loading from + * @param tableName The HBase table we are loading into + * @param flatMap A flapMap function that will make every + * row in the RDD + * into N cells for the bulk load + * @param stagingDir The location on the FileSystem to bulk load into + * @param familyHFileWriteOptionsMap Options that will define how the HFile for a + * column family is written + * @param compactionExclude Compaction excluded for the HFiles + * @param maxSize Max size for the HFiles before they roll + * @param nowTimeStamp Version timestamp + * @tparam T The Type of values in the original RDD + */ + def bulkLoad[T]( + rdd: RDD[T], + tableName: TableName, + flatMap: (T) => Iterator[(KeyFamilyQualifier, Array[Byte])], + stagingDir: String, + familyHFileWriteOptionsMap: util.Map[Array[Byte], FamilyHFileWriteOptions] = + new util.HashMap[Array[Byte], FamilyHFileWriteOptions], + compactionExclude: Boolean = false, + maxSize: Long = HConstants.DEFAULT_MAX_FILE_SIZE, + nowTimeStamp: Long = System.currentTimeMillis()): Unit = { + val stagingPath = new Path(stagingDir) + val fs = stagingPath.getFileSystem(config) + if (fs.exists(stagingPath)) { + throw new FileAlreadyExistsException("Path " + stagingDir + " already exists") + } + val conn = HBaseConnectionCache.getConnection(config) + try { + val regionLocator = conn.getRegionLocator(tableName) + val startKeys = regionLocator.getStartKeys + if (startKeys.length == 0) { + logInfo("Table " + tableName.toString + " was not found") + } + val defaultCompressionStr = + config.get("hfile.compression", Compression.Algorithm.NONE.getName) + val hfileCompression = HFileWriterImpl + .compressionByName(defaultCompressionStr) + val tableRawName = tableName.getName + + val familyHFileWriteOptionsMapInternal = + new util.HashMap[ByteArrayWrapper, FamilyHFileWriteOptions] + + val entrySetIt = familyHFileWriteOptionsMap.entrySet().iterator() + + while (entrySetIt.hasNext) { + val entry = entrySetIt.next() + familyHFileWriteOptionsMapInternal.put(new ByteArrayWrapper(entry.getKey), entry.getValue) + } + + val regionSplitPartitioner = + new BulkLoadPartitioner(startKeys) + + // This is where all the magic happens + // Here we are going to do the following things + // 1. FlapMap every row in the RDD into key column value tuples + // 2. Then we are going to repartition sort and shuffle + // 3. 
Finally we are going to write out our HFiles + rdd + .flatMap(r => flatMap(r)) + .repartitionAndSortWithinPartitions(regionSplitPartitioner) + .hbaseForeachPartition( + this, + (it, conn) => { + + val conf = broadcastedConf.value.value + val fs = new Path(stagingDir).getFileSystem(conf) + val writerMap = new mutable.HashMap[ByteArrayWrapper, WriterLength] + var previousRow: Array[Byte] = HConstants.EMPTY_BYTE_ARRAY + var rollOverRequested = false + val localTableName = TableName.valueOf(tableRawName) + + // Here is where we finally iterate through the data in this partition of the + // RDD that has been sorted and partitioned + it.foreach { + case (keyFamilyQualifier, cellValue: Array[Byte]) => + val wl = writeValueToHFile( + keyFamilyQualifier.rowKey, + keyFamilyQualifier.family, + keyFamilyQualifier.qualifier, + cellValue, + nowTimeStamp, + fs, + conn, + localTableName, + conf, + familyHFileWriteOptionsMapInternal, + hfileCompression, + writerMap, + stagingDir) + + rollOverRequested = rollOverRequested || wl.written > maxSize + + // This will only roll if we have at least one column family file that is + // bigger then maxSize and we have finished a given row key + if (rollOverRequested && Bytes + .compareTo(previousRow, keyFamilyQualifier.rowKey) != 0) { + rollWriters(fs, writerMap, regionSplitPartitioner, previousRow, compactionExclude) + rollOverRequested = false + } + + previousRow = keyFamilyQualifier.rowKey + } + // We have finished all the data so lets close up the writers + rollWriters(fs, writerMap, regionSplitPartitioner, previousRow, compactionExclude) + rollOverRequested = false + }) + } finally { + if (null != conn) conn.close() + } + } + + /** + * Spark Implementation of HBase Bulk load for short rows some where less then + * a 1000 columns. This bulk load should be faster for tables will thinner + * rows then the other spark implementation of bulk load that puts only one + * value into a record going into a shuffle + * + * This will take the content from an existing RDD then sort and shuffle + * it with respect to region splits. The result of that sort and shuffle + * will be written to HFiles. + * + * After this function is executed the user will have to call + * LoadIncrementalHFiles.doBulkLoad(...) to move the files into HBase + * + * In this implementation, only the rowKey is given to the shuffle as the key + * and all the columns are already linked to the RowKey before the shuffle + * stage. The sorting of the qualifier is done in memory out side of the + * shuffle stage + * + * Also make sure that incoming RDDs only have one record for every row key. 
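+ *
+ * A hedged usage sketch (hbaseContext, rdd, the table name and the staging
+ * path are illustrative); the map function gathers every cell of a row into
+ * one FamiliesQualifiersValues value:
+ * {{{
+ *   hbaseContext.bulkLoadThinRows[(Array[Byte], Seq[(Array[Byte], Array[Byte], Array[Byte])])](
+ *     rdd,
+ *     TableName.valueOf("t1"),
+ *     { case (rowKey, cells) =>
+ *       val familyQualifiersValues = new FamiliesQualifiersValues
+ *       cells.foreach { case (family, qualifier, value) =>
+ *         familyQualifiersValues.add(family, qualifier, value)
+ *       }
+ *       (new ByteArrayWrapper(rowKey), familyQualifiersValues)
+ *     },
+ *     "/tmp/bulkload-staging")
+ * }}}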
+ * + * @param rdd The RDD we are bulk loading from + * @param tableName The HBase table we are loading into + * @param mapFunction A function that will convert the RDD records to + * the key value format used for the shuffle to prep + * for writing to the bulk loaded HFiles + * @param stagingDir The location on the FileSystem to bulk load into + * @param familyHFileWriteOptionsMap Options that will define how the HFile for a + * column family is written + * @param compactionExclude Compaction excluded for the HFiles + * @param maxSize Max size for the HFiles before they roll + * @tparam T The Type of values in the original RDD + */ + def bulkLoadThinRows[T]( + rdd: RDD[T], + tableName: TableName, + mapFunction: (T) => (ByteArrayWrapper, FamiliesQualifiersValues), + stagingDir: String, + familyHFileWriteOptionsMap: util.Map[Array[Byte], FamilyHFileWriteOptions] = + new util.HashMap[Array[Byte], FamilyHFileWriteOptions], + compactionExclude: Boolean = false, + maxSize: Long = HConstants.DEFAULT_MAX_FILE_SIZE): Unit = { + val stagingPath = new Path(stagingDir) + val fs = stagingPath.getFileSystem(config) + if (fs.exists(stagingPath)) { + throw new FileAlreadyExistsException("Path " + stagingDir + " already exists") + } + val conn = HBaseConnectionCache.getConnection(config) + try { + val regionLocator = conn.getRegionLocator(tableName) + val startKeys = regionLocator.getStartKeys + if (startKeys.length == 0) { + logInfo("Table " + tableName.toString + " was not found") + } + val defaultCompressionStr = + config.get("hfile.compression", Compression.Algorithm.NONE.getName) + val defaultCompression = HFileWriterImpl + .compressionByName(defaultCompressionStr) + val nowTimeStamp = System.currentTimeMillis() + val tableRawName = tableName.getName + + val familyHFileWriteOptionsMapInternal = + new util.HashMap[ByteArrayWrapper, FamilyHFileWriteOptions] + + val entrySetIt = familyHFileWriteOptionsMap.entrySet().iterator() + + while (entrySetIt.hasNext) { + val entry = entrySetIt.next() + familyHFileWriteOptionsMapInternal.put(new ByteArrayWrapper(entry.getKey), entry.getValue) + } + + val regionSplitPartitioner = + new BulkLoadPartitioner(startKeys) + + // This is where all the magic happens + // Here we are going to do the following things + // 1. FlapMap every row in the RDD into key column value tuples + // 2. Then we are going to repartition sort and shuffle + // 3. 
Finally we are going to write out our HFiles + rdd + .map(r => mapFunction(r)) + .repartitionAndSortWithinPartitions(regionSplitPartitioner) + .hbaseForeachPartition( + this, + (it, conn) => { + + val conf = broadcastedConf.value.value + val fs = new Path(stagingDir).getFileSystem(conf) + val writerMap = new mutable.HashMap[ByteArrayWrapper, WriterLength] + var previousRow: Array[Byte] = HConstants.EMPTY_BYTE_ARRAY + var rollOverRequested = false + val localTableName = TableName.valueOf(tableRawName) + + // Here is where we finally iterate through the data in this partition of the + // RDD that has been sorted and partitioned + it.foreach { + case (rowKey: ByteArrayWrapper, familiesQualifiersValues: FamiliesQualifiersValues) => + if (Bytes.compareTo(previousRow, rowKey.value) == 0) { + throw new KeyAlreadyExistsException( + "The following key was sent to the " + + "HFile load more then one: " + Bytes.toString(previousRow)) + } + + // The family map is a tree map so the families will be sorted + val familyIt = familiesQualifiersValues.familyMap.entrySet().iterator() + while (familyIt.hasNext) { + val familyEntry = familyIt.next() + + val family = familyEntry.getKey.value + + val qualifierIt = familyEntry.getValue.entrySet().iterator() + + // The qualifier map is a tree map so the families will be sorted + while (qualifierIt.hasNext) { + + val qualifierEntry = qualifierIt.next() + val qualifier = qualifierEntry.getKey + val cellValue = qualifierEntry.getValue + + writeValueToHFile( + rowKey.value, + family, + qualifier.value, // qualifier + cellValue, // value + nowTimeStamp, + fs, + conn, + localTableName, + conf, + familyHFileWriteOptionsMapInternal, + defaultCompression, + writerMap, + stagingDir) + + previousRow = rowKey.value + } + + writerMap.values.foreach( + wl => { + rollOverRequested = rollOverRequested || wl.written > maxSize + + // This will only roll if we have at least one column family file that is + // bigger then maxSize and we have finished a given row key + if (rollOverRequested) { + rollWriters( + fs, + writerMap, + regionSplitPartitioner, + previousRow, + compactionExclude) + rollOverRequested = false + } + }) + } + } + + // This will get a writer for the column family + // If there is no writer for a given column family then + // it will get created here. 
+ // We have finished all the data so lets close up the writers + rollWriters(fs, writerMap, regionSplitPartitioner, previousRow, compactionExclude) + rollOverRequested = false + }) + } finally { + if (null != conn) conn.close() + } + } + + /** + * This will return a new HFile writer when requested + * + * @param family column family + * @param conf configuration to connect to HBase + * @param favoredNodes nodes that we would like to write too + * @param fs FileSystem object where we will be writing the HFiles to + * @return WriterLength object + */ + private def getNewHFileWriter( + family: Array[Byte], + conf: Configuration, + favoredNodes: Array[InetSocketAddress], + fs: FileSystem, + familydir: Path, + familyHFileWriteOptionsMapInternal: util.HashMap[ByteArrayWrapper, FamilyHFileWriteOptions], + defaultCompression: Compression.Algorithm): WriterLength = { + + var familyOptions = familyHFileWriteOptionsMapInternal.get(new ByteArrayWrapper(family)) + + if (familyOptions == null) { + familyOptions = new FamilyHFileWriteOptions( + defaultCompression.toString, + BloomType.NONE.toString, + HConstants.DEFAULT_BLOCKSIZE, + DataBlockEncoding.NONE.toString) + familyHFileWriteOptionsMapInternal.put(new ByteArrayWrapper(family), familyOptions) + } + + val tempConf = new Configuration(conf) + tempConf.setFloat(HConstants.HFILE_BLOCK_CACHE_SIZE_KEY, 0.0f) + + // HBASE-25249 introduced an incompatible change in the IA.Private HStore and StoreUtils + // so here, we directly use conf.get for CheckSumType and BytesPerCheckSum to make it + // compatible between hbase 2.3.x and 2.4.x + val contextBuilder = new HFileContextBuilder() + .withCompression(Algorithm.valueOf(familyOptions.compression)) + // ChecksumType.nameToType is still an IA.Private Utils, but it's unlikely to be changed. + .withChecksumType( + ChecksumType + .nameToType( + conf.get(HConstants.CHECKSUM_TYPE_NAME, ChecksumType.getDefaultChecksumType.getName))) + .withCellComparator(CellComparator.getInstance()) + .withBytesPerCheckSum( + conf.getInt(HConstants.BYTES_PER_CHECKSUM, HFile.DEFAULT_BYTES_PER_CHECKSUM)) + .withBlockSize(familyOptions.blockSize) + + if (HFile.getFormatVersion(conf) >= HFile.MIN_FORMAT_VERSION_WITH_TAGS) { + contextBuilder.withIncludesTags(true) + } + + contextBuilder.withDataBlockEncoding(DataBlockEncoding.valueOf(familyOptions.dataBlockEncoding)) + val hFileContext = contextBuilder.build() + + // Add a '_' to the file name because this is a unfinished file. A rename will happen + // to remove the '_' when the file is closed. 
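+ // WriterLength pairs the writer with a running count of bytes written so the
+ // caller can decide when to roll to a new HFile.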
+ new WriterLength( + 0, + new StoreFileWriter.Builder(conf, new CacheConfig(tempConf), new HFileSystem(fs)) + .withBloomType(BloomType.valueOf(familyOptions.bloomType)) + .withFileContext(hFileContext) + .withFilePath(new Path(familydir, "_" + UUID.randomUUID.toString.replaceAll("-", ""))) + .withFavoredNodes(favoredNodes) + .build()) + + } + + /** + * Encompasses the logic to write a value to an HFile + * + * @param rowKey The RowKey for the record + * @param family HBase column family for the record + * @param qualifier HBase column qualifier for the record + * @param cellValue HBase cell value + * @param nowTimeStamp The cell time stamp + * @param fs Connection to the FileSystem for the HFile + * @param conn Connection to HBaes + * @param tableName HBase TableName object + * @param conf Configuration to be used when making a new HFile + * @param familyHFileWriteOptionsMapInternal Extra configs for the HFile + * @param hfileCompression The compression codec for the new HFile + * @param writerMap HashMap of existing writers and their offsets + * @param stagingDir The staging directory on the FileSystem to store + * the HFiles + * @return The writer for the given HFile that was writen + * too + */ + private def writeValueToHFile( + rowKey: Array[Byte], + family: Array[Byte], + qualifier: Array[Byte], + cellValue: Array[Byte], + nowTimeStamp: Long, + fs: FileSystem, + conn: Connection, + tableName: TableName, + conf: Configuration, + familyHFileWriteOptionsMapInternal: util.HashMap[ByteArrayWrapper, FamilyHFileWriteOptions], + hfileCompression: Compression.Algorithm, + writerMap: mutable.HashMap[ByteArrayWrapper, WriterLength], + stagingDir: String): WriterLength = { + + val wl = writerMap.getOrElseUpdate( + new ByteArrayWrapper(family), { + val familyDir = new Path(stagingDir, Bytes.toString(family)) + + familyDir.getFileSystem(conf).mkdirs(familyDir); + + val loc: HRegionLocation = { + try { + val locator = + conn.getRegionLocator(tableName) + locator.getRegionLocation(rowKey) + } catch { + case e: Throwable => + logWarning( + "there's something wrong when locating rowkey: " + + Bytes.toString(rowKey)) + null + } + } + if (null == loc) { + if (log.isTraceEnabled) { + logTrace( + "failed to get region location, so use default writer: " + + Bytes.toString(rowKey)) + } + getNewHFileWriter( + family = family, + conf = conf, + favoredNodes = null, + fs = fs, + familydir = familyDir, + familyHFileWriteOptionsMapInternal, + hfileCompression) + } else { + if (log.isDebugEnabled) { + logDebug("first rowkey: [" + Bytes.toString(rowKey) + "]") + } + val initialIsa = + new InetSocketAddress(loc.getHostname, loc.getPort) + if (initialIsa.isUnresolved) { + if (log.isTraceEnabled) { + logTrace( + "failed to resolve bind address: " + loc.getHostname + ":" + + loc.getPort + ", so use default writer") + } + getNewHFileWriter( + family, + conf, + null, + fs, + familyDir, + familyHFileWriteOptionsMapInternal, + hfileCompression) + } else { + if (log.isDebugEnabled) { + logDebug("use favored nodes writer: " + initialIsa.getHostString) + } + getNewHFileWriter( + family, + conf, + Array[InetSocketAddress](initialIsa), + fs, + familyDir, + familyHFileWriteOptionsMapInternal, + hfileCompression) + } + } + }) + + val keyValue = new KeyValue(rowKey, family, qualifier, nowTimeStamp, cellValue) + + wl.writer.append(keyValue) + wl.written += keyValue.getLength + + wl + } + + /** + * This will roll all Writers + * @param fs Hadoop FileSystem object + * @param writerMap HashMap that contains all the writers + * @param 
regionSplitPartitioner The partitioner with knowledge of how the + * Region's are split by row key + * @param previousRow The last row to fill the HFile ending range metadata + * @param compactionExclude The exclude compaction metadata flag for the HFile + */ + private def rollWriters( + fs: FileSystem, + writerMap: mutable.HashMap[ByteArrayWrapper, WriterLength], + regionSplitPartitioner: BulkLoadPartitioner, + previousRow: Array[Byte], + compactionExclude: Boolean): Unit = { + writerMap.values.foreach( + wl => { + if (wl.writer != null) { + logDebug( + "Writer=" + wl.writer.getPath + + (if (wl.written == 0) "" else ", wrote=" + wl.written)) + closeHFileWriter(fs, wl.writer, regionSplitPartitioner, previousRow, compactionExclude) + } + }) + writerMap.clear() + + } + + /** + * Function to close an HFile + * @param fs Hadoop FileSystem object + * @param w HFile Writer + * @param regionSplitPartitioner The partitioner with knowledge of how the + * Region's are split by row key + * @param previousRow The last row to fill the HFile ending range metadata + * @param compactionExclude The exclude compaction metadata flag for the HFile + */ + private def closeHFileWriter( + fs: FileSystem, + w: StoreFileWriter, + regionSplitPartitioner: BulkLoadPartitioner, + previousRow: Array[Byte], + compactionExclude: Boolean): Unit = { + if (w != null) { + w.appendFileInfo(HStoreFile.BULKLOAD_TIME_KEY, Bytes.toBytes(System.currentTimeMillis())) + w.appendFileInfo( + HStoreFile.BULKLOAD_TASK_KEY, + Bytes.toBytes(regionSplitPartitioner.getPartition(previousRow))) + w.appendFileInfo(HStoreFile.MAJOR_COMPACTION_KEY, Bytes.toBytes(true)) + w.appendFileInfo( + HStoreFile.EXCLUDE_FROM_MINOR_COMPACTION_KEY, + Bytes.toBytes(compactionExclude)) + w.appendTrackedTimestampsToMetadata() + w.close() + + val srcPath = w.getPath + + // In the new path you will see that we are using substring. This is to + // remove the '_' character in front of the HFile name. '_' is a character + // that will tell HBase that this file shouldn't be included in the bulk load + // This feature is to protect for unfinished HFiles being submitted to HBase + val newPath = new Path(w.getPath.getParent, w.getPath.getName.substring(1)) + if (!fs.rename(srcPath, newPath)) { + throw new IOException( + "Unable to rename '" + srcPath + + "' to " + newPath) + } + } + } + + /** + * This is a wrapper class around StoreFileWriter. The reason for the + * wrapper is to keep the length of the file along side the writer + * + * @param written The writer to be wrapped + * @param writer The number of bytes written to the writer + */ + class WriterLength(var written: Long, val writer: StoreFileWriter) +} + +@InterfaceAudience.Private +object LatestHBaseContextCache { + var latest: HBaseContext = null +} diff --git a/spark4/hbase-spark4/src/main/scala/org/apache/hadoop/hbase/spark/HBaseDStreamFunctions.scala b/spark4/hbase-spark4/src/main/scala/org/apache/hadoop/hbase/spark/HBaseDStreamFunctions.scala new file mode 100644 index 00000000..0ca5038c --- /dev/null +++ b/spark4/hbase-spark4/src/main/scala/org/apache/hadoop/hbase/spark/HBaseDStreamFunctions.scala @@ -0,0 +1,166 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.hadoop.hbase.spark + +import org.apache.hadoop.hbase.TableName +import org.apache.hadoop.hbase.client._ +import org.apache.hadoop.hbase.io.ImmutableBytesWritable +import org.apache.spark.streaming.dstream.DStream +import org.apache.yetus.audience.InterfaceAudience +import scala.reflect.ClassTag + +/** + * HBaseDStreamFunctions contains a set of implicit functions that can be + * applied to a Spark DStream so that we can easily interact with HBase + */ +@InterfaceAudience.Public +object HBaseDStreamFunctions { + + /** + * These are implicit methods for a DStream that contains any type of + * data. + * + * @param dStream This is for dStreams of any type + * @tparam T Type T + */ + implicit class GenericHBaseDStreamFunctions[T](val dStream: DStream[T]) { + + /** + * Implicit method that gives easy access to HBaseContext's bulk + * put. This will not return a new Stream. Think of it like a foreach + * + * @param hc The hbaseContext object to identify which + * HBase cluster connection to use + * @param tableName The tableName that the put will be sent to + * @param f The function that will turn the DStream values + * into HBase Put objects. + */ + def hbaseBulkPut(hc: HBaseContext, tableName: TableName, f: (T) => Put): Unit = { + hc.streamBulkPut(dStream, tableName, f) + } + + /** + * Implicit method that gives easy access to HBaseContext's bulk + * get. This will return a new DStream. Think about it as a DStream map + * function. In that every DStream value will get a new value out of + * HBase. That new value will populate the newly generated DStream. + * + * @param hc The hbaseContext object to identify which + * HBase cluster connection to use + * @param tableName The tableName that the put will be sent to + * @param batchSize How many gets to execute in a single batch + * @param f The function that will turn the RDD values + * in HBase Get objects + * @param convertResult The function that will convert a HBase + * Result object into a value that will go + * into the resulting DStream + * @tparam R The type of Object that will be coming + * out of the resulting DStream + * @return A resulting DStream with type R objects + */ + def hbaseBulkGet[R: ClassTag]( + hc: HBaseContext, + tableName: TableName, + batchSize: Int, + f: (T) => Get, + convertResult: (Result) => R): DStream[R] = { + hc.streamBulkGet[T, R](tableName, batchSize, dStream, f, convertResult) + } + + /** + * Implicit method that gives easy access to HBaseContext's bulk + * get. This will return a new DStream. Think about it as a DStream map + * function. In that every DStream value will get a new value out of + * HBase. That new value will populate the newly generated DStream. 
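A short usage sketch of the implicits above, writing each element of a stream as a single cell; the stream contents, table name and column names are assumptions:

```
import org.apache.hadoop.hbase.TableName
import org.apache.hadoop.hbase.client.Put
import org.apache.hadoop.hbase.spark.HBaseDStreamFunctions._
import org.apache.hadoop.hbase.util.Bytes

// dStream: DStream[(String, String)] and hbaseContext: HBaseContext are assumed to exist
dStream.hbaseBulkPut(
  hbaseContext,
  TableName.valueOf("t1"),
  (kv: (String, String)) =>
    new Put(Bytes.toBytes(kv._1))
      .addColumn(Bytes.toBytes("cf"), Bytes.toBytes("q"), Bytes.toBytes(kv._2)))
```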
+ * + * @param hc The hbaseContext object to identify which + * HBase cluster connection to use + * @param tableName The tableName that the put will be sent to + * @param batchSize How many gets to execute in a single batch + * @param f The function that will turn the RDD values + * in HBase Get objects + * @return A resulting DStream with type R objects + */ + def hbaseBulkGet( + hc: HBaseContext, + tableName: TableName, + batchSize: Int, + f: (T) => Get): DStream[(ImmutableBytesWritable, Result)] = { + hc.streamBulkGet[T, (ImmutableBytesWritable, Result)]( + tableName, + batchSize, + dStream, + f, + result => (new ImmutableBytesWritable(result.getRow), result)) + } + + /** + * Implicit method that gives easy access to HBaseContext's bulk + * Delete. This will not return a new DStream. + * + * @param hc The hbaseContext object to identify which HBase + * cluster connection to use + * @param tableName The tableName that the deletes will be sent to + * @param f The function that will convert the DStream value into + * a HBase Delete Object + * @param batchSize The number of Deletes to be sent in a single batch + */ + def hbaseBulkDelete( + hc: HBaseContext, + tableName: TableName, + f: (T) => Delete, + batchSize: Int): Unit = { + hc.streamBulkDelete(dStream, tableName, f, batchSize) + } + + /** + * Implicit method that gives easy access to HBaseContext's + * foreachPartition method. This will ack very much like a normal DStream + * foreach method but for the fact that you will now have a HBase connection + * while iterating through the values. + * + * @param hc The hbaseContext object to identify which HBase + * cluster connection to use + * @param f This function will get an iterator for a Partition of an + * DStream along with a connection object to HBase + */ + def hbaseForeachPartition(hc: HBaseContext, f: (Iterator[T], Connection) => Unit): Unit = { + hc.streamForeachPartition(dStream, f) + } + + /** + * Implicit method that gives easy access to HBaseContext's + * mapPartitions method. This will ask very much like a normal DStream + * map partitions method but for the fact that you will now have a + * HBase connection while iterating through the values + * + * @param hc The hbaseContext object to identify which HBase + * cluster connection to use + * @param f This function will get an iterator for a Partition of an + * DStream along with a connection object to HBase + * @tparam R This is the type of objects that will go into the resulting + * DStream + * @return A resulting DStream of type R + */ + def hbaseMapPartitions[R: ClassTag]( + hc: HBaseContext, + f: (Iterator[T], Connection) => Iterator[R]): DStream[R] = { + hc.streamMapPartitions(dStream, f) + } + } +} diff --git a/spark4/hbase-spark4/src/main/scala/org/apache/hadoop/hbase/spark/HBaseRDDFunctions.scala b/spark4/hbase-spark4/src/main/scala/org/apache/hadoop/hbase/spark/HBaseRDDFunctions.scala new file mode 100644 index 00000000..9e3da565 --- /dev/null +++ b/spark4/hbase-spark4/src/main/scala/org/apache/hadoop/hbase/spark/HBaseRDDFunctions.scala @@ -0,0 +1,268 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.hadoop.hbase.spark + +import java.util +import org.apache.hadoop.hbase.{HConstants, TableName} +import org.apache.hadoop.hbase.client._ +import org.apache.hadoop.hbase.io.ImmutableBytesWritable +import org.apache.spark.rdd.RDD +import org.apache.yetus.audience.InterfaceAudience +import scala.reflect.ClassTag + +/** + * HBaseRDDFunctions contains a set of implicit functions that can be + * applied to a Spark RDD so that we can easily interact with HBase + */ +@InterfaceAudience.Public +object HBaseRDDFunctions { + + /** + * These are implicit methods for a RDD that contains any type of + * data. + * + * @param rdd This is for rdd of any type + * @tparam T This is any type + */ + implicit class GenericHBaseRDDFunctions[T](val rdd: RDD[T]) { + + /** + * Implicit method that gives easy access to HBaseContext's bulk + * put. This will not return a new RDD. Think of it like a foreach + * + * @param hc The hbaseContext object to identify which + * HBase cluster connection to use + * @param tableName The tableName that the put will be sent to + * @param f The function that will turn the RDD values + * into HBase Put objects. + */ + def hbaseBulkPut(hc: HBaseContext, tableName: TableName, f: (T) => Put): Unit = { + hc.bulkPut(rdd, tableName, f) + } + + /** + * Implicit method that gives easy access to HBaseContext's bulk + * get. This will return a new RDD. Think about it as a RDD map + * function. In that every RDD value will get a new value out of + * HBase. That new value will populate the newly generated RDD. + * + * @param hc The hbaseContext object to identify which + * HBase cluster connection to use + * @param tableName The tableName that the put will be sent to + * @param batchSize How many gets to execute in a single batch + * @param f The function that will turn the RDD values + * in HBase Get objects + * @param convertResult The function that will convert a HBase + * Result object into a value that will go + * into the resulting RDD + * @tparam R The type of Object that will be coming + * out of the resulting RDD + * @return A resulting RDD with type R objects + */ + def hbaseBulkGet[R: ClassTag]( + hc: HBaseContext, + tableName: TableName, + batchSize: Int, + f: (T) => Get, + convertResult: (Result) => R): RDD[R] = { + hc.bulkGet[T, R](tableName, batchSize, rdd, f, convertResult) + } + + /** + * Implicit method that gives easy access to HBaseContext's bulk + * get. This will return a new RDD. Think about it as a RDD map + * function. In that every RDD value will get a new value out of + * HBase. That new value will populate the newly generated RDD. 
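A sketch of the typed `hbaseBulkGet` above, turning each `Result` back into a plain value; the table, column names and the source RDD of row keys are placeholders:

```
import org.apache.hadoop.hbase.TableName
import org.apache.hadoop.hbase.client.{Get, Result}
import org.apache.hadoop.hbase.spark.HBaseRDDFunctions._
import org.apache.hadoop.hbase.util.Bytes

// rowKeys: RDD[String] and hbaseContext: HBaseContext are assumed to exist
val values = rowKeys.hbaseBulkGet[String](
  hbaseContext,
  TableName.valueOf("t1"),
  100, // gets per batch
  (k: String) => new Get(Bytes.toBytes(k)),
  (r: Result) => Bytes.toString(r.getValue(Bytes.toBytes("cf"), Bytes.toBytes("q"))))
```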
+ * + * @param hc The hbaseContext object to identify which + * HBase cluster connection to use + * @param tableName The tableName that the put will be sent to + * @param batchSize How many gets to execute in a single batch + * @param f The function that will turn the RDD values + * in HBase Get objects + * @return A resulting RDD with type R objects + */ + def hbaseBulkGet( + hc: HBaseContext, + tableName: TableName, + batchSize: Int, + f: (T) => Get): RDD[(ImmutableBytesWritable, Result)] = { + hc.bulkGet[T, (ImmutableBytesWritable, Result)]( + tableName, + batchSize, + rdd, + f, + result => + if (result != null && result.getRow != null) { + (new ImmutableBytesWritable(result.getRow), result) + } else { + null + }) + } + + /** + * Implicit method that gives easy access to HBaseContext's bulk + * Delete. This will not return a new RDD. + * + * @param hc The hbaseContext object to identify which HBase + * cluster connection to use + * @param tableName The tableName that the deletes will be sent to + * @param f The function that will convert the RDD value into + * a HBase Delete Object + * @param batchSize The number of Deletes to be sent in a single batch + */ + def hbaseBulkDelete( + hc: HBaseContext, + tableName: TableName, + f: (T) => Delete, + batchSize: Int): Unit = { + hc.bulkDelete(rdd, tableName, f, batchSize) + } + + /** + * Implicit method that gives easy access to HBaseContext's + * foreachPartition method. This will ack very much like a normal RDD + * foreach method but for the fact that you will now have a HBase connection + * while iterating through the values. + * + * @param hc The hbaseContext object to identify which HBase + * cluster connection to use + * @param f This function will get an iterator for a Partition of an + * RDD along with a connection object to HBase + */ + def hbaseForeachPartition(hc: HBaseContext, f: (Iterator[T], Connection) => Unit): Unit = { + hc.foreachPartition(rdd, f) + } + + /** + * Implicit method that gives easy access to HBaseContext's + * mapPartitions method. This will ask very much like a normal RDD + * map partitions method but for the fact that you will now have a + * HBase connection while iterating through the values + * + * @param hc The hbaseContext object to identify which HBase + * cluster connection to use + * @param f This function will get an iterator for a Partition of an + * RDD along with a connection object to HBase + * @tparam R This is the type of objects that will go into the resulting + * RDD + * @return A resulting RDD of type R + */ + def hbaseMapPartitions[R: ClassTag]( + hc: HBaseContext, + f: (Iterator[T], Connection) => Iterator[R]): RDD[R] = { + hc.mapPartitions[T, R](rdd, f) + } + + /** + * Spark Implementation of HBase Bulk load for wide rows or when + * values are not already combined at the time of the map process + * + * A Spark Implementation of HBase Bulk load + * + * This will take the content from an existing RDD then sort and shuffle + * it with respect to region splits. The result of that sort and shuffle + * will be written to HFiles. + * + * After this function is executed the user will have to call + * LoadIncrementalHFiles.doBulkLoad(...) to move the files into HBase + * + * Also note this version of bulk load is different from past versions in + * that it includes the qualifier as part of the sort process. The + * reason for this is to be able to support rows will very large number + * of columns. 
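A hedged sketch of that wide-row bulk load, emitting one `(KeyFamilyQualifier, value)` pair per cell and then handing the staged HFiles to HBase; the staging path, table and input RDD shape are assumptions:

```
import org.apache.hadoop.hbase.TableName
import org.apache.hadoop.hbase.spark.KeyFamilyQualifier
import org.apache.hadoop.hbase.spark.HBaseRDDFunctions._
import org.apache.hadoop.hbase.util.Bytes

// rdd: RDD[(String, String)] of (rowKey, value) pairs and hbaseContext: HBaseContext are assumed
rdd.hbaseBulkLoad(
  hbaseContext,
  TableName.valueOf("t1"),
  (kv: (String, String)) => {
    // Wide rows simply emit more (KeyFamilyQualifier, value) pairs from this iterator
    val kfq = new KeyFamilyQualifier(Bytes.toBytes(kv._1), Bytes.toBytes("cf"), Bytes.toBytes("q"))
    Iterator((kfq, Bytes.toBytes(kv._2)))
  },
  "/tmp/hbase-bulkload-staging")

// The staged HFiles still have to be moved into the table afterwards,
// e.g. via LoadIncrementalHFiles.doBulkLoad(...) as the scaladoc notes.
```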
+ * + * @param tableName The HBase table we are loading into + * @param flatMap A flapMap function that will make every row in the RDD + * into N cells for the bulk load + * @param stagingDir The location on the FileSystem to bulk load into + * @param familyHFileWriteOptionsMap Options that will define how the HFile for a + * column family is written + * @param compactionExclude Compaction excluded for the HFiles + * @param maxSize Max size for the HFiles before they roll + */ + def hbaseBulkLoad( + hc: HBaseContext, + tableName: TableName, + flatMap: (T) => Iterator[(KeyFamilyQualifier, Array[Byte])], + stagingDir: String, + familyHFileWriteOptionsMap: util.Map[Array[Byte], FamilyHFileWriteOptions] = + new util.HashMap[Array[Byte], FamilyHFileWriteOptions](), + compactionExclude: Boolean = false, + maxSize: Long = HConstants.DEFAULT_MAX_FILE_SIZE): Unit = { + hc.bulkLoad( + rdd, + tableName, + flatMap, + stagingDir, + familyHFileWriteOptionsMap, + compactionExclude, + maxSize) + } + + /** + * Implicit method that gives easy access to HBaseContext's + * bulkLoadThinRows method. + * + * Spark Implementation of HBase Bulk load for short rows some where less then + * a 1000 columns. This bulk load should be faster for tables will thinner + * rows then the other spark implementation of bulk load that puts only one + * value into a record going into a shuffle + * + * This will take the content from an existing RDD then sort and shuffle + * it with respect to region splits. The result of that sort and shuffle + * will be written to HFiles. + * + * After this function is executed the user will have to call + * LoadIncrementalHFiles.doBulkLoad(...) to move the files into HBase + * + * In this implementation only the rowKey is given to the shuffle as the key + * and all the columns are already linked to the RowKey before the shuffle + * stage. The sorting of the qualifier is done in memory out side of the + * shuffle stage + * + * @param tableName The HBase table we are loading into + * @param mapFunction A function that will convert the RDD records to + * the key value format used for the shuffle to prep + * for writing to the bulk loaded HFiles + * @param stagingDir The location on the FileSystem to bulk load into + * @param familyHFileWriteOptionsMap Options that will define how the HFile for a + * column family is written + * @param compactionExclude Compaction excluded for the HFiles + * @param maxSize Max size for the HFiles before they roll + */ + def hbaseBulkLoadThinRows( + hc: HBaseContext, + tableName: TableName, + mapFunction: (T) => (ByteArrayWrapper, FamiliesQualifiersValues), + stagingDir: String, + familyHFileWriteOptionsMap: util.Map[Array[Byte], FamilyHFileWriteOptions] = + new util.HashMap[Array[Byte], FamilyHFileWriteOptions](), + compactionExclude: Boolean = false, + maxSize: Long = HConstants.DEFAULT_MAX_FILE_SIZE): Unit = { + hc.bulkLoadThinRows( + rdd, + tableName, + mapFunction, + stagingDir, + familyHFileWriteOptionsMap, + compactionExclude, + maxSize) + } + } +} diff --git a/spark4/hbase-spark4/src/main/scala/org/apache/hadoop/hbase/spark/JavaHBaseContext.scala b/spark4/hbase-spark4/src/main/scala/org/apache/hadoop/hbase/spark/JavaHBaseContext.scala new file mode 100644 index 00000000..48ac3737 --- /dev/null +++ b/spark4/hbase-spark4/src/main/scala/org/apache/hadoop/hbase/spark/JavaHBaseContext.scala @@ -0,0 +1,420 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. 
See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.hadoop.hbase.spark + +import java.util + +import java.lang.Iterable +import java.util.Map +import org.apache.hadoop.conf.Configuration +import org.apache.hadoop.hbase.TableName +import org.apache.hadoop.hbase.client.{Connection, Delete, Get, Put, Result, Scan} +import org.apache.hadoop.hbase.io.ImmutableBytesWritable +import org.apache.hadoop.hbase.util.Pair +import org.apache.spark.api.java.{JavaRDD, JavaSparkContext} +import org.apache.spark.api.java.function.{FlatMapFunction, Function, VoidFunction} +import org.apache.spark.streaming.api.java.JavaDStream +import org.apache.yetus.audience.InterfaceAudience +import scala.jdk.CollectionConverters._ +import scala.reflect.ClassTag + +/** + * This is the Java Wrapper over HBaseContext which is written in + * Scala. This class will be used by developers that want to + * work with Spark or Spark Streaming in Java + * + * @param jsc This is the JavaSparkContext that we will wrap + * @param config This is the config information to out HBase cluster + */ +@InterfaceAudience.Public +class JavaHBaseContext(@transient val jsc: JavaSparkContext, @transient val config: Configuration) + extends Serializable { + val hbaseContext = new HBaseContext(jsc.sc, config) + + /** + * A simple enrichment of the traditional Spark javaRdd foreachPartition. + * This function differs from the original in that it offers the + * developer access to a already connected Connection object + * + * Note: Do not close the Connection object. All Connection + * management is handled outside this method + * + * @param javaRdd Original javaRdd with data to iterate over + * @param f Function to be given a iterator to iterate through + * the RDD values and a Connection object to interact + * with HBase + */ + def foreachPartition[T]( + javaRdd: JavaRDD[T], + f: VoidFunction[(java.util.Iterator[T], Connection)]) = { + + hbaseContext.foreachPartition( + javaRdd.rdd, + (it: Iterator[T], conn: Connection) => { + f.call((it.asJava, conn)) + }) + } + + /** + * A simple enrichment of the traditional Spark Streaming dStream foreach + * This function differs from the original in that it offers the + * developer access to a already connected Connection object + * + * Note: Do not close the Connection object. All Connection + * management is handled outside this method + * + * @param javaDstream Original DStream with data to iterate over + * @param f Function to be given a iterator to iterate through + * the JavaDStream values and a Connection object to + * interact with HBase + */ + def foreachPartition[T]( + javaDstream: JavaDStream[T], + f: VoidFunction[(Iterator[T], Connection)]) = { + hbaseContext.foreachPartition( + javaDstream.dstream, + (it: Iterator[T], conn: Connection) => f.call(it, conn)) + } + + /** + * A simple enrichment of the traditional Spark JavaRDD mapPartition. 
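The same per-partition pattern, shown from the Scala side for brevity (table name and record type are placeholders); the essential point from the notes above is that the supplied `Connection` is owned by the framework and must not be closed inside `f`:

```
import org.apache.hadoop.hbase.TableName
import org.apache.hadoop.hbase.client.{Connection, Put}
import org.apache.hadoop.hbase.spark.HBaseRDDFunctions._
import org.apache.hadoop.hbase.util.Bytes

// rdd: RDD[(String, String)] and hbaseContext: HBaseContext are assumed to exist
rdd.hbaseForeachPartition(
  hbaseContext,
  (it: Iterator[(String, String)], conn: Connection) => {
    val mutator = conn.getBufferedMutator(TableName.valueOf("t1"))
    it.foreach {
      case (k, v) =>
        mutator.mutate(
          new Put(Bytes.toBytes(k))
            .addColumn(Bytes.toBytes("cf"), Bytes.toBytes("q"), Bytes.toBytes(v)))
    }
    mutator.flush()
    mutator.close() // close what we opened, but never the Connection itself
  })
```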
+ * This function differs from the original in that it offers the + * developer access to a already connected Connection object + * + * Note: Do not close the Connection object. All Connection + * management is handled outside this method + * + * Note: Make sure to partition correctly to avoid memory issue when + * getting data from HBase + * + * @param javaRdd Original JavaRdd with data to iterate over + * @param f Function to be given a iterator to iterate through + * the RDD values and a Connection object to interact + * with HBase + * @return Returns a new RDD generated by the user definition + * function just like normal mapPartition + */ + def mapPartitions[T, R]( + javaRdd: JavaRDD[T], + f: FlatMapFunction[(java.util.Iterator[T], Connection), R]): JavaRDD[R] = { + JavaRDD.fromRDD( + hbaseContext.mapPartitions( + javaRdd.rdd, + (it: Iterator[T], conn: Connection) => f.call((it.asJava, conn)).asScala)(fakeClassTag[R]))(fakeClassTag[R]) + } + + /** + * A simple enrichment of the traditional Spark Streaming JavaDStream + * mapPartition. + * + * This function differs from the original in that it offers the + * developer access to a already connected Connection object + * + * Note: Do not close the Connection object. All Connection + * management is handled outside this method + * + * Note: Make sure to partition correctly to avoid memory issue when + * getting data from HBase + * + * @param javaDstream Original JavaDStream with data to iterate over + * @param mp Function to be given a iterator to iterate through + * the JavaDStream values and a Connection object to + * interact with HBase + * @return Returns a new JavaDStream generated by the user + * definition function just like normal mapPartition + */ + def streamMap[T, U]( + javaDstream: JavaDStream[T], + mp: Function[(Iterator[T], Connection), Iterator[U]]): JavaDStream[U] = { + JavaDStream.fromDStream( + hbaseContext.streamMapPartitions( + javaDstream.dstream, + (it: Iterator[T], conn: Connection) => mp.call(it, conn))(fakeClassTag[U]))(fakeClassTag[U]) + } + + /** + * A simple abstraction over the HBaseContext.foreachPartition method. + * + * It allow addition support for a user to take JavaRDD + * and generate puts and send them to HBase. + * The complexity of managing the Connection is + * removed from the developer + * + * @param javaRdd Original JavaRDD with data to iterate over + * @param tableName The name of the table to put into + * @param f Function to convert a value in the JavaRDD + * to a HBase Put + */ + def bulkPut[T](javaRdd: JavaRDD[T], tableName: TableName, f: Function[(T), Put]) { + + hbaseContext.bulkPut(javaRdd.rdd, tableName, (t: T) => f.call(t)) + } + + /** + * A simple abstraction over the HBaseContext.streamMapPartition method. + * + * It allow addition support for a user to take a JavaDStream and + * generate puts and send them to HBase. + * + * The complexity of managing the Connection is + * removed from the developer + * + * @param javaDstream Original DStream with data to iterate over + * @param tableName The name of the table to put into + * @param f Function to convert a value in + * the JavaDStream to a HBase Put + */ + def streamBulkPut[T](javaDstream: JavaDStream[T], tableName: TableName, f: Function[T, Put]) = { + hbaseContext.streamBulkPut(javaDstream.dstream, tableName, (t: T) => f.call(t)) + } + + /** + * A simple abstraction over the HBaseContext.foreachPartition method. + * + * It allow addition support for a user to take a JavaRDD and + * generate delete and send them to HBase. 
+ * + * The complexity of managing the Connection is + * removed from the developer + * + * @param javaRdd Original JavaRDD with data to iterate over + * @param tableName The name of the table to delete from + * @param f Function to convert a value in the JavaRDD to a + * HBase Deletes + * @param batchSize The number of deletes to batch before sending to HBase + */ + def bulkDelete[T]( + javaRdd: JavaRDD[T], + tableName: TableName, + f: Function[T, Delete], + batchSize: Integer) { + hbaseContext.bulkDelete(javaRdd.rdd, tableName, (t: T) => f.call(t), batchSize) + } + + /** + * A simple abstraction over the HBaseContext.streamBulkMutation method. + * + * It allow addition support for a user to take a JavaDStream and + * generate Delete and send them to HBase. + * + * The complexity of managing the Connection is + * removed from the developer + * + * @param javaDStream Original DStream with data to iterate over + * @param tableName The name of the table to delete from + * @param f Function to convert a value in the JavaDStream to a + * HBase Delete + * @param batchSize The number of deletes to be sent at once + */ + def streamBulkDelete[T]( + javaDStream: JavaDStream[T], + tableName: TableName, + f: Function[T, Delete], + batchSize: Integer) = { + hbaseContext.streamBulkDelete(javaDStream.dstream, tableName, (t: T) => f.call(t), batchSize) + } + + /** + * A simple abstraction over the HBaseContext.mapPartition method. + * + * It allow addition support for a user to take a JavaRDD and generates a + * new RDD based on Gets and the results they bring back from HBase + * + * @param tableName The name of the table to get from + * @param batchSize batch size of how many gets to retrieve in a single fetch + * @param javaRdd Original JavaRDD with data to iterate over + * @param makeGet Function to convert a value in the JavaRDD to a + * HBase Get + * @param convertResult This will convert the HBase Result object to + * what ever the user wants to put in the resulting + * JavaRDD + * @return New JavaRDD that is created by the Get to HBase + */ + def bulkGet[T, U]( + tableName: TableName, + batchSize: Integer, + javaRdd: JavaRDD[T], + makeGet: Function[T, Get], + convertResult: Function[Result, U]): JavaRDD[U] = { + + JavaRDD.fromRDD( + hbaseContext.bulkGet[T, U]( + tableName, + batchSize, + javaRdd.rdd, + (t: T) => makeGet.call(t), + (r: Result) => { + convertResult.call(r) + })(fakeClassTag[U]))(fakeClassTag[U]) + + } + + /** + * A simple abstraction over the HBaseContext.streamMap method. 
+ * + * It allow addition support for a user to take a DStream and + * generates a new DStream based on Gets and the results + * they bring back from HBase + * + * @param tableName The name of the table to get from + * @param batchSize The number of gets to be batched together + * @param javaDStream Original DStream with data to iterate over + * @param makeGet Function to convert a value in the JavaDStream to a + * HBase Get + * @param convertResult This will convert the HBase Result object to + * what ever the user wants to put in the resulting + * JavaDStream + * @return New JavaDStream that is created by the Get to HBase + */ + def streamBulkGet[T, U]( + tableName: TableName, + batchSize: Integer, + javaDStream: JavaDStream[T], + makeGet: Function[T, Get], + convertResult: Function[Result, U]): JavaDStream[U] = { + JavaDStream.fromDStream( + hbaseContext.streamBulkGet( + tableName, + batchSize, + javaDStream.dstream, + (t: T) => makeGet.call(t), + (r: Result) => convertResult.call(r))(fakeClassTag[U]))(fakeClassTag[U]) + } + + /** + * A simple abstraction over the HBaseContext.bulkLoad method. + * It allow addition support for a user to take a JavaRDD and + * convert into new JavaRDD[Pair] based on MapFunction, + * and HFiles will be generated in stagingDir for bulk load + * + * @param javaRdd The javaRDD we are bulk loading from + * @param tableName The HBase table we are loading into + * @param mapFunc A Function that will convert a value in JavaRDD + * to Pair(KeyFamilyQualifier, Array[Byte]) + * @param stagingDir The location on the FileSystem to bulk load into + * @param familyHFileWriteOptionsMap Options that will define how the HFile for a + * column family is written + * @param compactionExclude Compaction excluded for the HFiles + * @param maxSize Max size for the HFiles before they roll + */ + def bulkLoad[T]( + javaRdd: JavaRDD[T], + tableName: TableName, + mapFunc: Function[T, Pair[KeyFamilyQualifier, Array[Byte]]], + stagingDir: String, + familyHFileWriteOptionsMap: Map[Array[Byte], FamilyHFileWriteOptions], + compactionExclude: Boolean, + maxSize: Long): Unit = { + hbaseContext.bulkLoad[Pair[KeyFamilyQualifier, Array[Byte]]]( + javaRdd.map(mapFunc).rdd, + tableName, + t => { + val keyFamilyQualifier = t.getFirst + val value = t.getSecond + Seq((keyFamilyQualifier, value)).iterator + }, + stagingDir, + familyHFileWriteOptionsMap, + compactionExclude, + maxSize) + } + + /** + * A simple abstraction over the HBaseContext.bulkLoadThinRows method. 
+ * It allow addition support for a user to take a JavaRDD and + * convert into new JavaRDD[Pair] based on MapFunction, + * and HFiles will be generated in stagingDir for bulk load + * + * @param javaRdd The javaRDD we are bulk loading from + * @param tableName The HBase table we are loading into + * @param mapFunc A Function that will convert a value in JavaRDD + * to Pair(ByteArrayWrapper, FamiliesQualifiersValues) + * @param stagingDir The location on the FileSystem to bulk load into + * @param familyHFileWriteOptionsMap Options that will define how the HFile for a + * column family is written + * @param compactionExclude Compaction excluded for the HFiles + * @param maxSize Max size for the HFiles before they roll + */ + def bulkLoadThinRows[T]( + javaRdd: JavaRDD[T], + tableName: TableName, + mapFunc: Function[T, Pair[ByteArrayWrapper, FamiliesQualifiersValues]], + stagingDir: String, + familyHFileWriteOptionsMap: Map[Array[Byte], FamilyHFileWriteOptions], + compactionExclude: Boolean, + maxSize: Long): Unit = { + hbaseContext.bulkLoadThinRows[Pair[ByteArrayWrapper, FamiliesQualifiersValues]]( + javaRdd.map(mapFunc).rdd, + tableName, + t => { + (t.getFirst, t.getSecond) + }, + stagingDir, + familyHFileWriteOptionsMap, + compactionExclude, + maxSize) + } + + /** + * This function will use the native HBase TableInputFormat with the + * given scan object to generate a new JavaRDD + * + * @param tableName The name of the table to scan + * @param scans The HBase scan object to use to read data from HBase + * @param f Function to convert a Result object from HBase into + * What the user wants in the final generated JavaRDD + * @return New JavaRDD with results from scan + */ + def hbaseRDD[U]( + tableName: TableName, + scans: Scan, + f: Function[(ImmutableBytesWritable, Result), U]): JavaRDD[U] = { + JavaRDD.fromRDD( + hbaseContext + .hbaseRDD[U](tableName, scans, (v: (ImmutableBytesWritable, Result)) => f.call(v._1, v._2))( + fakeClassTag[U]))(fakeClassTag[U]) + } + + /** + * A overloaded version of HBaseContext hbaseRDD that define the + * type of the resulting JavaRDD + * + * @param tableName The name of the table to scan + * @param scans The HBase scan object to use to read data from HBase + * @return New JavaRDD with results from scan + */ + def hbaseRDD(tableName: TableName, scans: Scan): JavaRDD[(ImmutableBytesWritable, Result)] = { + JavaRDD.fromRDD(hbaseContext.hbaseRDD(tableName, scans)) + } + + /** + * Produces a ClassTag[T], which is actually just a casted ClassTag[AnyRef]. + * + * This method is used to keep ClassTags out of the external Java API, as the Java compiler + * cannot produce them automatically. While this ClassTag-faking does please the compiler, + * it can cause problems at runtime if the Scala API relies on ClassTags for correctness. + * + * Often, though, a ClassTag[AnyRef] will not lead to incorrect behavior, + * just worse performance or security issues. + * For instance, an Array[AnyRef] can hold any type T, + * but may lose primitive + * specialization. 
+ */ + private[spark] def fakeClassTag[T]: ClassTag[T] = ClassTag.AnyRef.asInstanceOf[ClassTag[T]] + +} diff --git a/spark4/hbase-spark4/src/main/scala/org/apache/hadoop/hbase/spark/KeyFamilyQualifier.scala b/spark4/hbase-spark4/src/main/scala/org/apache/hadoop/hbase/spark/KeyFamilyQualifier.scala new file mode 100644 index 00000000..d71e31a5 --- /dev/null +++ b/spark4/hbase-spark4/src/main/scala/org/apache/hadoop/hbase/spark/KeyFamilyQualifier.scala @@ -0,0 +1,51 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.hadoop.hbase.spark + +import java.io.Serializable +import org.apache.hadoop.hbase.util.Bytes +import org.apache.yetus.audience.InterfaceAudience + +/** + * This is the key to be used for sorting and shuffling. + * + * We will only partition on the rowKey but we will sort on all three + * + * @param rowKey Record RowKey + * @param family Record ColumnFamily + * @param qualifier Cell Qualifier + */ +@InterfaceAudience.Public +class KeyFamilyQualifier( + val rowKey: Array[Byte], + val family: Array[Byte], + val qualifier: Array[Byte]) + extends Comparable[KeyFamilyQualifier] + with Serializable { + override def compareTo(o: KeyFamilyQualifier): Int = { + var result = Bytes.compareTo(rowKey, o.rowKey) + if (result == 0) { + result = Bytes.compareTo(family, o.family) + if (result == 0) result = Bytes.compareTo(qualifier, o.qualifier) + } + result + } + override def toString: String = { + Bytes.toString(rowKey) + ":" + Bytes.toString(family) + ":" + Bytes.toString(qualifier) + } +} diff --git a/spark4/hbase-spark4/src/main/scala/org/apache/hadoop/hbase/spark/Logging.scala b/spark4/hbase-spark4/src/main/scala/org/apache/hadoop/hbase/spark/Logging.scala new file mode 100644 index 00000000..18636313 --- /dev/null +++ b/spark4/hbase-spark4/src/main/scala/org/apache/hadoop/hbase/spark/Logging.scala @@ -0,0 +1,121 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.hadoop.hbase.spark + +import org.apache.yetus.audience.InterfaceAudience +import org.slf4j.Logger +import org.slf4j.LoggerFactory +import org.slf4j.impl.StaticLoggerBinder + +/** + * Utility trait for classes that want to log data. Creates a SLF4J logger for the class and allows + * logging messages at different levels using methods that only evaluate parameters lazily if the + * log level is enabled. + * Logging is private in Spark 2.0 + * This is to isolate incompatibilties across Spark releases. + */ +@InterfaceAudience.Private +trait Logging { + + // Make the log field transient so that objects with Logging can + // be serialized and used on another machine + @transient private var log_ : Logger = null + + // Method to get the logger name for this object + protected def logName = { + // Ignore trailing $'s in the class names for Scala objects + this.getClass.getName.stripSuffix("$") + } + + // Method to get or create the logger for this object + protected def log: Logger = { + if (log_ == null) { + initializeLogIfNecessary(false) + log_ = LoggerFactory.getLogger(logName) + } + log_ + } + + // Log methods that take only a String + protected def logInfo(msg: => String) { + if (log.isInfoEnabled) log.info(msg) + } + + protected def logDebug(msg: => String) { + if (log.isDebugEnabled) log.debug(msg) + } + + protected def logTrace(msg: => String) { + if (log.isTraceEnabled) log.trace(msg) + } + + protected def logWarning(msg: => String) { + if (log.isWarnEnabled) log.warn(msg) + } + + protected def logError(msg: => String) { + if (log.isErrorEnabled) log.error(msg) + } + + // Log methods that take Throwables (Exceptions/Errors) too + protected def logInfo(msg: => String, throwable: Throwable) { + if (log.isInfoEnabled) log.info(msg, throwable) + } + + protected def logDebug(msg: => String, throwable: Throwable) { + if (log.isDebugEnabled) log.debug(msg, throwable) + } + + protected def logTrace(msg: => String, throwable: Throwable) { + if (log.isTraceEnabled) log.trace(msg, throwable) + } + + protected def logWarning(msg: => String, throwable: Throwable) { + if (log.isWarnEnabled) log.warn(msg, throwable) + } + + protected def logError(msg: => String, throwable: Throwable) { + if (log.isErrorEnabled) log.error(msg, throwable) + } + + protected def initializeLogIfNecessary(isInterpreter: Boolean): Unit = { + if (!Logging.initialized) { + Logging.initLock.synchronized { + if (!Logging.initialized) { + initializeLogging(isInterpreter) + } + } + } + } + + private def initializeLogging(isInterpreter: Boolean): Unit = { + // Don't use a logger in here, as this is itself occurring during initialization of a logger + // If Log4j 1.2 is being used, but is not initialized, load a default properties file + val binderClass = StaticLoggerBinder.getSingleton.getLoggerFactoryClassStr + Logging.initialized = true + + // Force a call into slf4j to initialize it. 
Avoids this happening from multiple threads + // and triggering this: http://mailman.qos.ch/pipermail/slf4j-dev/2010-April/002956.html + log + } +} + +private object Logging { + @volatile private var initialized = false + val initLock = new Object() +} diff --git a/spark4/hbase-spark4/src/main/scala/org/apache/hadoop/hbase/spark/NewHBaseRDD.scala b/spark4/hbase-spark4/src/main/scala/org/apache/hadoop/hbase/spark/NewHBaseRDD.scala new file mode 100644 index 00000000..aeb502d9 --- /dev/null +++ b/spark4/hbase-spark4/src/main/scala/org/apache/hadoop/hbase/spark/NewHBaseRDD.scala @@ -0,0 +1,40 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.hadoop.hbase.spark + +import org.apache.hadoop.conf.Configuration +import org.apache.hadoop.mapreduce.InputFormat +import org.apache.spark.{InterruptibleIterator, Partition, SparkContext, TaskContext} +import org.apache.spark.rdd.NewHadoopRDD +import org.apache.yetus.audience.InterfaceAudience + +@InterfaceAudience.Public +class NewHBaseRDD[K, V]( + @transient val sc: SparkContext, + @transient val inputFormatClass: Class[_ <: InputFormat[K, V]], + @transient val keyClass: Class[K], + @transient val valueClass: Class[V], + @transient private val __conf: Configuration, + val hBaseContext: HBaseContext) + extends NewHadoopRDD(sc, inputFormatClass, keyClass, valueClass, __conf) { + + override def compute(theSplit: Partition, context: TaskContext): InterruptibleIterator[(K, V)] = { + hBaseContext.applyCreds() + super.compute(theSplit, context) + } +} diff --git a/spark4/hbase-spark4/src/main/scala/org/apache/hadoop/hbase/spark/datasources/Bound.scala b/spark4/hbase-spark4/src/main/scala/org/apache/hadoop/hbase/spark/datasources/Bound.scala new file mode 100644 index 00000000..ef0d287c --- /dev/null +++ b/spark4/hbase-spark4/src/main/scala/org/apache/hadoop/hbase/spark/datasources/Bound.scala @@ -0,0 +1,142 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
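As a small illustration of the `Logging` trait above: messages are passed by name, so the string below is only built when DEBUG is actually enabled (the class and message are hypothetical):

```
import org.apache.hadoop.hbase.spark.Logging

class PartitionWriter extends Logging {
  def write(row: Array[Byte]): Unit = {
    // The interpolated string is evaluated lazily because logDebug takes msg: => String
    logDebug(s"writing row of length ${row.length}")
  }
}
```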
+ */ +package org.apache.hadoop.hbase.spark.datasources + +import org.apache.hadoop.hbase.spark.hbase._ +import org.apache.yetus.audience.InterfaceAudience + +/** + * The Bound represent the boudary for the scan + * + * @param b The byte array of the bound + * @param inc inclusive or not. + */ +@InterfaceAudience.Private +case class Bound(b: Array[Byte], inc: Boolean) +// The non-overlapping ranges we need to scan, if lower is equal to upper, it is a get request + +@InterfaceAudience.Private +case class Range(lower: Option[Bound], upper: Option[Bound]) + +@InterfaceAudience.Private +object Range { + def apply(region: HBaseRegion): Range = { + Range( + region.start.map(Bound(_, true)), + if (region.end.get.size == 0) { + None + } else { + region.end.map((Bound(_, false))) + }) + } +} + +@InterfaceAudience.Private +object Ranges { + // We assume that + // 1. r.lower.inc is true, and r.upper.inc is false + // 2. for each range in rs, its upper.inc is false + def and(r: Range, rs: Seq[Range]): Seq[Range] = { + rs.flatMap { + s => + val lower = s.lower + .map { + x => + // the scan has lower bound + r.lower + .map { + y => + // the region has lower bound + if (ord.compare(x.b, y.b) < 0) { + // scan lower bound is smaller than region server lower bound + Some(y) + } else { + // scan low bound is greater or equal to region server lower bound + Some(x) + } + } + .getOrElse(Some(x)) + } + .getOrElse(r.lower) + + val upper = s.upper + .map { + x => + // the scan has upper bound + r.upper + .map { + y => + // the region has upper bound + if (ord.compare(x.b, y.b) >= 0) { + // scan upper bound is larger than server upper bound + // but region server scan stop is exclusive. It is OK here. + Some(y) + } else { + // scan upper bound is less or equal to region server upper bound + Some(x) + } + } + .getOrElse(Some(x)) + } + .getOrElse(r.upper) + + val c = lower + .map { + case x => + upper + .map { + case y => + ord.compare(x.b, y.b) + } + .getOrElse(-1) + } + .getOrElse(-1) + if (c < 0) { + Some(Range(lower, upper)) + } else { + None + } + }.seq + } +} + +@InterfaceAudience.Private +object Points { + def and(r: Range, ps: Seq[Array[Byte]]): Seq[Array[Byte]] = { + ps.flatMap { + p => + if (ord.compare(r.lower.get.b, p) <= 0) { + // if region lower bound is less or equal to the point + if (r.upper.isDefined) { + // if region upper bound is defined + if (ord.compare(r.upper.get.b, p) > 0) { + // if the upper bound is greater than the point (because upper bound is exclusive) + Some(p) + } else { + None + } + } else { + // if the region upper bound is not defined (infinity) + Some(p) + } + } else { + None + } + } + } +} diff --git a/spark4/hbase-spark4/src/main/scala/org/apache/hadoop/hbase/spark/datasources/DataTypeParserWrapper.scala b/spark4/hbase-spark4/src/main/scala/org/apache/hadoop/hbase/spark/datasources/DataTypeParserWrapper.scala new file mode 100644 index 00000000..936276da --- /dev/null +++ b/spark4/hbase-spark4/src/main/scala/org/apache/hadoop/hbase/spark/datasources/DataTypeParserWrapper.scala @@ -0,0 +1,32 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
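For the `Range` arithmetic defined in Bound.scala above, `Ranges.and` clips each requested scan range to a region's boundaries and drops empty intersections. A small hypothetical illustration (byte ordering is plain lexicographic here):

```
import org.apache.hadoop.hbase.spark.datasources.{Bound, Range, Ranges}
import org.apache.hadoop.hbase.util.Bytes

// Region covers [b, f); the query asks for [a, c) and [d, g)
val region = Range(Some(Bound(Bytes.toBytes("b"), true)), Some(Bound(Bytes.toBytes("f"), false)))
val requested = Seq(
  Range(Some(Bound(Bytes.toBytes("a"), true)), Some(Bound(Bytes.toBytes("c"), false))),
  Range(Some(Bound(Bytes.toBytes("d"), true)), Some(Bound(Bytes.toBytes("g"), false))))

// Expected result: the ranges clipped to the region, i.e. [b, c) and [d, f)
val clipped = Ranges.and(region, requested)
```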
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.hadoop.hbase.spark.datasources + +import org.apache.spark.sql.catalyst.parser.CatalystSqlParser +import org.apache.spark.sql.types.DataType +import org.apache.yetus.audience.InterfaceAudience + +@InterfaceAudience.Private +trait DataTypeParser { + def parse(dataTypeString: String): DataType +} + +@InterfaceAudience.Private +object DataTypeParserWrapper extends DataTypeParser { + def parse(dataTypeString: String): DataType = CatalystSqlParser.parseDataType(dataTypeString) +} diff --git a/spark4/hbase-spark4/src/main/scala/org/apache/hadoop/hbase/spark/datasources/HBaseResources.scala b/spark4/hbase-spark4/src/main/scala/org/apache/hadoop/hbase/spark/datasources/HBaseResources.scala new file mode 100644 index 00000000..a179d514 --- /dev/null +++ b/spark4/hbase-spark4/src/main/scala/org/apache/hadoop/hbase/spark/datasources/HBaseResources.scala @@ -0,0 +1,173 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.hadoop.hbase.spark.datasources + +import org.apache.hadoop.hbase.TableName +import org.apache.hadoop.hbase.client._ +import org.apache.hadoop.hbase.spark.{HBaseConnectionCache, HBaseConnectionKey, HBaseRelation, SmartConnection} +import org.apache.yetus.audience.InterfaceAudience +import scala.language.implicitConversions + +// Resource and ReferencedResources are defined for extensibility, +// e.g., consolidate scan and bulkGet in the future work. 
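A quick illustration of the `DataTypeParserWrapper` defined just above; it simply delegates to Catalyst's SQL type parser, so simple and nested type strings are both handled:

```
import org.apache.hadoop.hbase.spark.datasources.DataTypeParserWrapper
import org.apache.spark.sql.types.{ArrayType, IntegerType, StringType}

// Both forms are parsed by CatalystSqlParser under the hood
assert(DataTypeParserWrapper.parse("int") == IntegerType)
assert(DataTypeParserWrapper.parse("array<string>") == ArrayType(StringType))
```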
+ +// User has to invoke release explicitly to release the resource, +// and potentially parent resources +@InterfaceAudience.Private +trait Resource { + def release(): Unit +} + +@InterfaceAudience.Private +case class ScanResource(tbr: TableResource, rs: ResultScanner) extends Resource { + def release() { + rs.close() + tbr.release() + } +} + +@InterfaceAudience.Private +case class GetResource(tbr: TableResource, rs: Array[Result]) extends Resource { + def release() { + tbr.release() + } +} + +@InterfaceAudience.Private +trait ReferencedResource { + var count: Int = 0 + def init(): Unit + def destroy(): Unit + def acquire() = synchronized { + try { + count += 1 + if (count == 1) { + init() + } + } catch { + case e: Throwable => + release() + throw e + } + } + + def release() = synchronized { + count -= 1 + if (count == 0) { + destroy() + } + } + + def releaseOnException[T](func: => T): T = { + acquire() + val ret = { + try { + func + } catch { + case e: Throwable => + release() + throw e + } + } + ret + } +} + +@InterfaceAudience.Private +case class TableResource(relation: HBaseRelation) extends ReferencedResource { + var connection: SmartConnection = _ + var table: Table = _ + + override def init(): Unit = { + connection = HBaseConnectionCache.getConnection(relation.hbaseConf) + table = connection.getTable(TableName.valueOf(relation.tableName)) + } + + override def destroy(): Unit = { + if (table != null) { + table.close() + table = null + } + if (connection != null) { + connection.close() + connection = null + } + } + + def getScanner(scan: Scan): ScanResource = releaseOnException { + ScanResource(this, table.getScanner(scan)) + } + + def get(list: java.util.List[org.apache.hadoop.hbase.client.Get]) = releaseOnException { + GetResource(this, table.get(list)) + } +} + +@InterfaceAudience.Private +case class RegionResource(relation: HBaseRelation) extends ReferencedResource { + var connection: SmartConnection = _ + var rl: RegionLocator = _ + val regions = releaseOnException { + val keys = rl.getStartEndKeys + keys.getFirst + .zip(keys.getSecond) + .zipWithIndex + .map( + x => + HBaseRegion( + x._2, + Some(x._1._1), + Some(x._1._2), + Some(rl.getRegionLocation(x._1._1).getHostname))) + } + + override def init(): Unit = { + connection = HBaseConnectionCache.getConnection(relation.hbaseConf) + rl = connection.getRegionLocator(TableName.valueOf(relation.tableName)) + } + + override def destroy(): Unit = { + if (rl != null) { + rl.close() + rl = null + } + if (connection != null) { + connection.close() + connection = null + } + } +} + +@InterfaceAudience.Private +object HBaseResources { + implicit def ScanResToScan(sr: ScanResource): ResultScanner = { + sr.rs + } + + implicit def GetResToResult(gr: GetResource): Array[Result] = { + gr.rs + } + + implicit def TableResToTable(tr: TableResource): Table = { + tr.table + } + + implicit def RegionResToRegions(rr: RegionResource): Seq[HBaseRegion] = { + rr.regions + } +} diff --git a/spark4/hbase-spark4/src/main/scala/org/apache/hadoop/hbase/spark/datasources/HBaseSparkConf.scala b/spark4/hbase-spark4/src/main/scala/org/apache/hadoop/hbase/spark/datasources/HBaseSparkConf.scala new file mode 100644 index 00000000..77c15316 --- /dev/null +++ b/spark4/hbase-spark4/src/main/scala/org/apache/hadoop/hbase/spark/datasources/HBaseSparkConf.scala @@ -0,0 +1,76 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. 
See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.hadoop.hbase.spark.datasources + +import org.apache.yetus.audience.InterfaceAudience; + +/** + * This is the hbase configuration. User can either set them in SparkConf, which + * will take effect globally, or configure it per table, which will overwrite the value + * set in SparkConf. If not set, the default value will take effect. + */ +@InterfaceAudience.Public +object HBaseSparkConf { + + /** + * Set to false to disable server-side caching of blocks for this scan, + * false by default, since full table scans generate too much BC churn. + */ + val QUERY_CACHEBLOCKS = "hbase.spark.query.cacheblocks" + val DEFAULT_QUERY_CACHEBLOCKS = false + + /** The number of rows for caching that will be passed to scan. */ + val QUERY_CACHEDROWS = "hbase.spark.query.cachedrows" + + /** Set the maximum number of values to return for each call to next() in scan. */ + val QUERY_BATCHSIZE = "hbase.spark.query.batchsize" + + /** The number of BulkGets send to HBase. */ + val BULKGET_SIZE = "hbase.spark.bulkget.size" + val DEFAULT_BULKGET_SIZE = 1000 + + /** Set to specify the location of hbase configuration file. */ + val HBASE_CONFIG_LOCATION = "hbase.spark.config.location" + + /** Set to specify whether create or use latest cached HBaseContext */ + val USE_HBASECONTEXT = "hbase.spark.use.hbasecontext" + val DEFAULT_USE_HBASECONTEXT = true + + /** Pushdown the filter to data source engine to increase the performance of queries. */ + val PUSHDOWN_COLUMN_FILTER = "hbase.spark.pushdown.columnfilter" + val DEFAULT_PUSHDOWN_COLUMN_FILTER = true + + /** Class name of the encoder, which encode data types from Spark to HBase bytes. */ + val QUERY_ENCODER = "hbase.spark.query.encoder" + val DEFAULT_QUERY_ENCODER = classOf[NaiveEncoder].getCanonicalName + + /** The timestamp used to filter columns with a specific timestamp. */ + val TIMESTAMP = "hbase.spark.query.timestamp" + + /** The starting timestamp used to filter columns with a specific range of versions. */ + val TIMERANGE_START = "hbase.spark.query.timerange.start" + + /** The ending timestamp used to filter columns with a specific range of versions. */ + val TIMERANGE_END = "hbase.spark.query.timerange.end" + + /** The maximum number of version to return. */ + val MAX_VERSIONS = "hbase.spark.query.maxVersions" + + /** Delayed time to close hbase-spark connection when no reference to this connection, in milliseconds. 
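A hedged usage sketch (editor's addition): the keys defined in this object are typically supplied per read through the DataFrame reader, overriding anything set globally in SparkConf. The `spark` session, the `catalog` JSON string, and the `"org.apache.hadoop.hbase.spark"` format name are assumptions here, not definitions from this file.

```scala
// Sketch only: per-table options override SparkConf-level defaults.
val df = spark.read
  .format("org.apache.hadoop.hbase.spark")                     // assumed connector format name
  .option(HBaseTableCatalog.tableCatalog, catalog)             // catalog JSON, defined elsewhere
  .option(HBaseSparkConf.QUERY_CACHEBLOCKS, "true")            // overrides DEFAULT_QUERY_CACHEBLOCKS
  .option(HBaseSparkConf.PUSHDOWN_COLUMN_FILTER, "false")      // disable server-side filter pushdown
  .option(HBaseSparkConf.TIMERANGE_START, "0")
  .option(HBaseSparkConf.TIMERANGE_END, System.currentTimeMillis().toString)
  .load()
df.show()
```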
*/ + val DEFAULT_CONNECTION_CLOSE_DELAY = 10 * 60 * 1000 +} diff --git a/spark4/hbase-spark4/src/main/scala/org/apache/hadoop/hbase/spark/datasources/HBaseTableCatalog.scala b/spark4/hbase-spark4/src/main/scala/org/apache/hadoop/hbase/spark/datasources/HBaseTableCatalog.scala new file mode 100644 index 00000000..d88d306c --- /dev/null +++ b/spark4/hbase-spark4/src/main/scala/org/apache/hadoop/hbase/spark/datasources/HBaseTableCatalog.scala @@ -0,0 +1,409 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.hadoop.hbase.spark.datasources + +import org.apache.avro.Schema +import org.apache.hadoop.hbase.spark.{Logging, SchemaConverters} +import org.apache.hadoop.hbase.util.Bytes +import org.apache.spark.sql.types._ +import org.apache.yetus.audience.InterfaceAudience +import org.json4s.DefaultFormats +import org.json4s.Formats +import org.json4s.jackson.JsonMethods +import scala.collection.mutable + +// The definition of each column cell, which may be composite type +// TODO: add avro support +@InterfaceAudience.Private +case class Field( + colName: String, + cf: String, + col: String, + sType: Option[String] = None, + avroSchema: Option[String] = None, + serdes: Option[SerDes] = None, + len: Int = -1) + extends Logging { + override def toString = s"$colName $cf $col" + val isRowKey = cf == HBaseTableCatalog.rowKey + var start: Int = _ + def schema: Option[Schema] = avroSchema.map { + x => + logDebug(s"avro: $x") + val p = new Schema.Parser + p.parse(x) + } + + lazy val exeSchema = schema + + // converter from avro to catalyst structure + lazy val avroToCatalyst: Option[Any => Any] = { + schema.map(SchemaConverters.createConverterToSQL(_)) + } + + // converter from catalyst to avro + lazy val catalystToAvro: (Any) => Any = { + SchemaConverters.createConverterToAvro(dt, colName, "recordNamespace") + } + + def cfBytes: Array[Byte] = { + if (isRowKey) { + Bytes.toBytes("") + } else { + Bytes.toBytes(cf) + } + } + def colBytes: Array[Byte] = { + if (isRowKey) { + Bytes.toBytes("key") + } else { + Bytes.toBytes(col) + } + } + + val dt = { + sType.map(DataTypeParserWrapper.parse(_)).getOrElse { + schema.map { x => SchemaConverters.toSqlType(x).dataType }.get + } + } + + var length: Int = { + if (len == -1) { + dt match { + case BinaryType | StringType => -1 + case BooleanType => Bytes.SIZEOF_BOOLEAN + case ByteType => 1 + case DoubleType => Bytes.SIZEOF_DOUBLE + case FloatType => Bytes.SIZEOF_FLOAT + case IntegerType => Bytes.SIZEOF_INT + case LongType => Bytes.SIZEOF_LONG + case ShortType => Bytes.SIZEOF_SHORT + case _ => -1 + } + } else { + len + } + + } + + override def equals(other: Any): Boolean = other match { + case that: Field => + colName == that.colName && cf == that.cf && col == that.col + case _ => false + } +} + +// 
The row key definition, with each key refer to the col defined in Field, e.g., +// key1:key2:key3 +@InterfaceAudience.Private +case class RowKey(k: String) { + val keys = k.split(":") + var fields: Seq[Field] = _ + var varLength = false + def length = { + if (varLength) { + -1 + } else { + fields.foldLeft(0) { + case (x, y) => + x + y.length + } + } + } +} +// The map between the column presented to Spark and the HBase field +@InterfaceAudience.Private +case class SchemaMap(map: mutable.HashMap[String, Field]) { + def toFields = map.map { + case (name, field) => + StructField(name, field.dt) + }.toSeq + + def fields = map.values + + def getField(name: String) = map(name) +} + +// The definition of HBase and Relation relation schema +@InterfaceAudience.Private +case class HBaseTableCatalog( + namespace: String, + name: String, + row: RowKey, + sMap: SchemaMap, + @transient params: Map[String, String]) + extends Logging { + def toDataType = StructType(sMap.toFields) + def getField(name: String) = sMap.getField(name) + def getRowKey: Seq[Field] = row.fields + def getPrimaryKey = row.keys(0) + def getColumnFamilies = { + sMap.fields.map(_.cf).filter(_ != HBaseTableCatalog.rowKey).toSeq.distinct + } + + def get(key: String) = params.get(key) + + // Setup the start and length for each dimension of row key at runtime. + def dynSetupRowKey(rowKey: Array[Byte]) { + logDebug(s"length: ${rowKey.length}") + if (row.varLength) { + var start = 0 + row.fields.foreach { + f => + logDebug(s"start: $start") + f.start = start + f.length = { + // If the length is not defined + if (f.length == -1) { + f.dt match { + case StringType => + var pos = rowKey.indexOf(HBaseTableCatalog.delimiter, start) + if (pos == -1 || pos > rowKey.length) { + // this is at the last dimension + pos = rowKey.length + } + pos - start + // We don't know the length, assume it extend to the end of the rowkey. + case _ => rowKey.length - start + } + } else { + f.length + } + } + start += f.length + } + } + } + + def initRowKey = { + val fields = sMap.fields.filter(_.cf == HBaseTableCatalog.rowKey) + row.fields = row.keys.flatMap(n => fields.find(_.col == n)) + // The length is determined at run time if it is string or binary and the length is undefined. + if (row.fields.filter(_.length == -1).isEmpty) { + var start = 0 + row.fields.foreach { + f => + f.start = start + start += f.length + } + } else { + row.varLength = true + } + } + initRowKey +} + +@InterfaceAudience.Public +object HBaseTableCatalog { + // If defined and larger than 3, a new table will be created with the nubmer of region specified. 
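For reference, a minimal catalog built from the keys defined in this object (an illustrative sketch; the table and column names are invented). The `rowkey` column family marks row-key components, and `apply` below turns the JSON into an `HBaseTableCatalog`:

```scala
// Illustrative only: a catalog for a hypothetical "person" table.
val catalog =
  """{
    |"table":{"namespace":"default", "name":"person"},
    |"rowkey":"key",
    |"columns":{
    |"id":{"cf":"rowkey", "col":"key", "type":"string"},
    |"name":{"cf":"info", "col":"name", "type":"string"},
    |"age":{"cf":"info", "col":"age", "type":"int"}
    |}
    |}""".stripMargin

val parsed = HBaseTableCatalog(Map(HBaseTableCatalog.tableCatalog -> catalog))
assert(parsed.getPrimaryKey == "key")
assert(parsed.getColumnFamilies == Seq("info"))
```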
+ val newTable = "newtable" + // The json string specifying hbase catalog information + val regionStart = "regionStart" + val defaultRegionStart = "aaaaaaa" + val regionEnd = "regionEnd" + val defaultRegionEnd = "zzzzzzz" + val tableCatalog = "catalog" + // The row key with format key1:key2 specifying table row key + val rowKey = "rowkey" + // The key for hbase table whose value specify namespace and table name + val table = "table" + // The namespace of hbase table + val nameSpace = "namespace" + // The name of hbase table + val tableName = "name" + // The name of columns in hbase catalog + val columns = "columns" + val cf = "cf" + val col = "col" + val `type` = "type" + // the name of avro schema json string + val avro = "avro" + val delimiter: Byte = 0 + val serdes = "serdes" + val length = "length" + + /** + * User provide table schema definition + * {"tablename":"name", "rowkey":"key1:key2", + * "columns":{"col1":{"cf":"cf1", "col":"col1", "type":"type1"}, + * "col2":{"cf":"cf2", "col":"col2", "type":"type2"}}} + * Note that any col in the rowKey, there has to be one corresponding col defined in columns + */ + def apply(params: Map[String, String]): HBaseTableCatalog = { + val parameters = convert(params) + // println(jString) + val jString = parameters(tableCatalog) + implicit val formats: Formats = DefaultFormats + val map = JsonMethods.parse(jString) + val tableMeta = map \ table + val nSpace = (tableMeta \ nameSpace).extractOrElse("default") + val tName = (tableMeta \ tableName).extract[String] + val cIter = (map \ columns).extract[Map[String, Map[String, String]]] + val schemaMap = mutable.HashMap.empty[String, Field] + cIter.foreach { + case (name, column) => + val sd = { + column + .get(serdes) + .asInstanceOf[Option[String]] + .map(n => Class.forName(n).newInstance().asInstanceOf[SerDes]) + } + val len = column.get(length).map(_.toInt).getOrElse(-1) + val sAvro = column.get(avro).map(parameters(_)) + val f = Field( + name, + column.getOrElse(cf, rowKey), + column.get(col).get, + column.get(`type`), + sAvro, + sd, + len) + schemaMap.+=((name, f)) + } + val rKey = RowKey((map \ rowKey).extract[String]) + HBaseTableCatalog(nSpace, tName, rKey, SchemaMap(schemaMap), parameters) + } + + val TABLE_KEY: String = "hbase.table" + val SCHEMA_COLUMNS_MAPPING_KEY: String = "hbase.columns.mapping" + + /* for backward compatibility. Convert the old definition to new json based definition formated as below + val catalog = s"""{ + |"table":{"namespace":"default", "name":"htable"}, + |"rowkey":"key1:key2", + |"columns":{ + |"col1":{"cf":"rowkey", "col":"key1", "type":"string"}, + |"col2":{"cf":"rowkey", "col":"key2", "type":"double"}, + |"col3":{"cf":"cf1", "col":"col2", "type":"binary"}, + |"col4":{"cf":"cf1", "col":"col3", "type":"timestamp"}, + |"col5":{"cf":"cf1", "col":"col4", "type":"double", "serdes":"${classOf[DoubleSerDes].getName}"}, + |"col6":{"cf":"cf1", "col":"col5", "type":"$map"}, + |"col7":{"cf":"cf1", "col":"col6", "type":"$array"}, + |"col8":{"cf":"cf1", "col":"col7", "type":"$arrayMap"} + |} + |}""".stripMargin + */ + @deprecated("Please use new json format to define HBaseCatalog") + // TODO: There is no need to deprecate since this is the first release. + def convert(parameters: Map[String, String]): Map[String, String] = { + val nsTableName = parameters.get(TABLE_KEY).getOrElse(null) + // if the hbase.table is not defined, we assume it is json format already. 
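An editor's sketch of the legacy inputs this `convert` path rewrites (all names illustrative): `hbase.table` plus a space-separated `hbase.columns.mapping`, where a qualifier starting with `:` marks a row-key column.

```scala
// Sketch only: legacy options are translated into the JSON catalog format shown above.
val legacy = Map(
  HBaseTableCatalog.TABLE_KEY -> "default:person",
  HBaseTableCatalog.SCHEMA_COLUMNS_MAPPING_KEY ->
    "id STRING :key, name STRING info:name, age INT info:age")

// The returned map additionally contains HBaseTableCatalog.tableCatalog -> generated JSON.
val withCatalog = HBaseTableCatalog.convert(legacy)
```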
+ if (nsTableName == null) return parameters + val tableParts = nsTableName.trim.split(':') + val tableNamespace = if (tableParts.length == 1) { + "default" + } else if (tableParts.length == 2) { + tableParts(0) + } else { + throw new IllegalArgumentException( + "Invalid table name '" + nsTableName + + "' should be ':' or '' ") + } + val tableName = tableParts(tableParts.length - 1) + val schemaMappingString = parameters.getOrElse(SCHEMA_COLUMNS_MAPPING_KEY, "") + import scala.collection.JavaConverters._ + val schemaMap = generateSchemaMappingMap(schemaMappingString).asScala.map( + _._2.asInstanceOf[SchemaQualifierDefinition]) + + val rowkey = schemaMap + .filter { + _.columnFamily == "rowkey" + } + .map(_.columnName) + val cols = schemaMap.map { + x => + s""""${x.columnName}":{"cf":"${x.columnFamily}", "col":"${x.qualifier}", "type":"${x.colType}"}""".stripMargin + } + val jsonCatalog = + s"""{ + |"table":{"namespace":"${tableNamespace}", "name":"${tableName}"}, + |"rowkey":"${rowkey.mkString(":")}", + |"columns":{ + |${cols.mkString(",")} + |} + |} + """.stripMargin + parameters ++ Map(HBaseTableCatalog.tableCatalog -> jsonCatalog) + } + + /** + * Reads the SCHEMA_COLUMNS_MAPPING_KEY and converts it to a map of + * SchemaQualifierDefinitions with the original sql column name as the key + * + * @param schemaMappingString The schema mapping string from the SparkSQL map + * @return A map of definitions keyed by the SparkSQL column name + */ + @InterfaceAudience.Private + def generateSchemaMappingMap( + schemaMappingString: String): java.util.HashMap[String, SchemaQualifierDefinition] = { + println(schemaMappingString) + try { + val columnDefinitions = schemaMappingString.split(',') + val resultingMap = new java.util.HashMap[String, SchemaQualifierDefinition]() + columnDefinitions.map( + cd => { + val parts = cd.trim.split(' ') + + // Make sure we get three parts + // + if (parts.length == 3) { + val hbaseDefinitionParts = if (parts(2).charAt(0) == ':') { + Array[String]("rowkey", parts(0)) + } else { + parts(2).split(':') + } + resultingMap.put( + parts(0), + new SchemaQualifierDefinition( + parts(0), + parts(1), + hbaseDefinitionParts(0), + hbaseDefinitionParts(1))) + } else { + throw new IllegalArgumentException( + "Invalid value for schema mapping '" + cd + + "' should be ' :' " + + "for columns and ' :' for rowKeys") + } + }) + resultingMap + } catch { + case e: Exception => + throw new IllegalArgumentException( + "Invalid value for " + SCHEMA_COLUMNS_MAPPING_KEY + + " '" + + schemaMappingString + "'", + e) + } + } +} + +/** + * Construct to contains column data that spend SparkSQL and HBase + * + * @param columnName SparkSQL column name + * @param colType SparkSQL column type + * @param columnFamily HBase column family + * @param qualifier HBase qualifier name + */ +@InterfaceAudience.Private +case class SchemaQualifierDefinition( + columnName: String, + colType: String, + columnFamily: String, + qualifier: String) diff --git a/spark4/hbase-spark4/src/main/scala/org/apache/hadoop/hbase/spark/datasources/HBaseTableScanRDD.scala b/spark4/hbase-spark4/src/main/scala/org/apache/hadoop/hbase/spark/datasources/HBaseTableScanRDD.scala new file mode 100644 index 00000000..82f7e8c4 --- /dev/null +++ b/spark4/hbase-spark4/src/main/scala/org/apache/hadoop/hbase/spark/datasources/HBaseTableScanRDD.scala @@ -0,0 +1,345 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. 
See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.hadoop.hbase.spark.datasources + +import java.util.ArrayList +import org.apache.hadoop.hbase.client._ +import org.apache.hadoop.hbase.spark._ +import org.apache.hadoop.hbase.spark.datasources.HBaseResources._ +import org.apache.hadoop.hbase.spark.hbase._ +import org.apache.hadoop.hbase.util.ShutdownHookManager +import org.apache.spark.{Partition, SparkEnv, TaskContext} +import org.apache.spark.rdd.RDD +import org.apache.yetus.audience.InterfaceAudience +import scala.collection.mutable + +@InterfaceAudience.Private +class HBaseTableScanRDD( + relation: HBaseRelation, + val hbaseContext: HBaseContext, + @transient val filter: Option[SparkSQLPushDownFilter] = None, + val columns: Seq[Field] = Seq.empty) + extends RDD[Result](relation.sqlContext.sparkContext, Nil) { + private def sparkConf = SparkEnv.get.conf + @transient var ranges = Seq.empty[Range] + @transient var points = Seq.empty[Array[Byte]] + def addPoint(p: Array[Byte]) { + points :+= p + } + + def addRange(r: ScanRange) = { + val lower = if (r.lowerBound != null && r.lowerBound.length > 0) { + Some(Bound(r.lowerBound, r.isLowerBoundEqualTo)) + } else { + None + } + val upper = if (r.upperBound != null && r.upperBound.length > 0) { + if (!r.isUpperBoundEqualTo) { + Some(Bound(r.upperBound, false)) + } else { + + // HBase stopRow is exclusive: therefore it DOESN'T act like isUpperBoundEqualTo + // by default. 
So we need to add a new max byte to the stopRow key + val newArray = new Array[Byte](r.upperBound.length + 1) + System.arraycopy(r.upperBound, 0, newArray, 0, r.upperBound.length) + + // New Max Bytes + newArray(r.upperBound.length) = ByteMin + Some(Bound(newArray, false)) + } + } else { + None + } + ranges :+= Range(lower, upper) + } + + override def getPartitions: Array[Partition] = { + val regions = RegionResource(relation) + var idx = 0 + logDebug(s"There are ${regions.size} regions") + val ps = regions.flatMap { + x => + val rs = Ranges.and(Range(x), ranges) + val ps = Points.and(Range(x), points) + if (rs.size > 0 || ps.size > 0) { + if (log.isDebugEnabled) { + rs.foreach(x => logDebug(x.toString)) + } + idx += 1 + Some( + HBaseScanPartition( + idx - 1, + x, + rs, + ps, + SerializedFilter.toSerializedTypedFilter(filter))) + } else { + None + } + }.toArray + if (log.isDebugEnabled) { + logDebug(s"Partitions: ${ps.size}"); + ps.foreach(x => logDebug(x.toString)) + } + regions.release() + ShutdownHookManager.affixShutdownHook( + new Thread() { + override def run() { + HBaseConnectionCache.close() + } + }, + 0) + ps.asInstanceOf[Array[Partition]] + } + + override def getPreferredLocations(split: Partition): Seq[String] = { + split + .asInstanceOf[HBaseScanPartition] + .regions + .server + .map { + identity + } + .toSeq + } + + private def buildGets( + tbr: TableResource, + g: Seq[Array[Byte]], + filter: Option[SparkSQLPushDownFilter], + columns: Seq[Field], + hbaseContext: HBaseContext): Iterator[Result] = { + g.grouped(relation.bulkGetSize).flatMap { + x => + val gets = new ArrayList[Get](x.size) + val rowkeySet = new mutable.HashSet[String]() + x.foreach { + y => + if (!rowkeySet.contains(y.mkString("Array(", ", ", ")"))) { + val g = new Get(y) + handleTimeSemantics(g) + columns.foreach { + d => + if (!d.isRowKey) { + g.addColumn(d.cfBytes, d.colBytes) + } + } + filter.foreach(g.setFilter(_)) + gets.add(g) + rowkeySet.add(y.mkString("Array(", ", ", ")")) + } + } + hbaseContext.applyCreds() + val tmp = tbr.get(gets) + rddResources.addResource(tmp) + toResultIterator(tmp) + } + } + + private def toResultIterator(result: GetResource): Iterator[Result] = { + val iterator = new Iterator[Result] { + var idx = 0 + var cur: Option[Result] = None + override def hasNext: Boolean = { + while (idx < result.length && cur.isEmpty) { + val r = result(idx) + idx += 1 + if (!r.isEmpty) { + cur = Some(r) + } + } + if (cur.isEmpty) { + rddResources.release(result) + } + cur.isDefined + } + override def next(): Result = { + hasNext + val ret = cur.get + cur = None + ret + } + } + iterator + } + + private def buildScan( + range: Range, + filter: Option[SparkSQLPushDownFilter], + columns: Seq[Field]): Scan = { + val scan = (range.lower, range.upper) match { + case (Some(Bound(a, b)), Some(Bound(c, d))) => new Scan(a, c) + case (None, Some(Bound(c, d))) => new Scan(Array[Byte](), c) + case (Some(Bound(a, b)), None) => new Scan(a) + case (None, None) => new Scan() + } + handleTimeSemantics(scan) + + columns.foreach { + d => + if (!d.isRowKey) { + scan.addColumn(d.cfBytes, d.colBytes) + } + } + scan.setCacheBlocks(relation.blockCacheEnable) + scan.setBatch(relation.batchNum) + scan.setCaching(relation.cacheSize) + filter.foreach(scan.setFilter(_)) + scan + } + private def toResultIterator(scanner: ScanResource): Iterator[Result] = { + val iterator = new Iterator[Result] { + var cur: Option[Result] = None + override def hasNext: Boolean = { + if (cur.isEmpty) { + val r = scanner.next() + if (r == null) { + 
rddResources.release(scanner) + } else { + cur = Some(r) + } + } + cur.isDefined + } + override def next(): Result = { + hasNext + val ret = cur.get + cur = None + ret + } + } + iterator + } + + lazy val rddResources = RDDResources(new mutable.HashSet[Resource]()) + + private def close() { + rddResources.release() + } + + override def compute(split: Partition, context: TaskContext): Iterator[Result] = { + val partition = split.asInstanceOf[HBaseScanPartition] + val filter = SerializedFilter.fromSerializedFilter(partition.sf) + val scans = partition.scanRanges + .map(buildScan(_, filter, columns)) + val tableResource = TableResource(relation) + context.addTaskCompletionListener[Unit](context => close()) + val points = partition.points + val gIt: Iterator[Result] = { + if (points.isEmpty) { + Iterator.empty: Iterator[Result] + } else { + buildGets(tableResource, points, filter, columns, hbaseContext) + } + } + val rIts = scans + .map { + scan => + hbaseContext.applyCreds() + val scanner = tableResource.getScanner(scan) + rddResources.addResource(scanner) + scanner + } + .map(toResultIterator(_)) + .fold(Iterator.empty: Iterator[Result]) { + case (x, y) => + x ++ y + } ++ gIt + ShutdownHookManager.affixShutdownHook( + new Thread() { + override def run() { + HBaseConnectionCache.close() + } + }, + 0) + rIts + } + + private def handleTimeSemantics(query: Query): Unit = { + // Set timestamp related values if present + (query, relation.timestamp, relation.minTimestamp, relation.maxTimestamp) match { + case (q: Scan, Some(ts), None, None) => q.setTimeStamp(ts) + case (q: Get, Some(ts), None, None) => q.setTimeStamp(ts) + + case (q: Scan, None, Some(minStamp), Some(maxStamp)) => q.setTimeRange(minStamp, maxStamp) + case (q: Get, None, Some(minStamp), Some(maxStamp)) => q.setTimeRange(minStamp, maxStamp) + + case (q, None, None, None) => + + case _ => + throw new IllegalArgumentException( + s"Invalid combination of query/timestamp/time range provided. 
" + + s"timeStamp is: ${relation.timestamp.get}, minTimeStamp is: ${relation.minTimestamp.get}, " + + s"maxTimeStamp is: ${relation.maxTimestamp.get}") + } + if (relation.maxVersions.isDefined) { + query match { + case q: Scan => q.setMaxVersions(relation.maxVersions.get) + case q: Get => q.setMaxVersions(relation.maxVersions.get) + case _ => throw new IllegalArgumentException("Invalid query provided with maxVersions") + } + } + } +} + +@InterfaceAudience.Private +case class SerializedFilter(b: Option[Array[Byte]]) + +object SerializedFilter { + def toSerializedTypedFilter(f: Option[SparkSQLPushDownFilter]): SerializedFilter = { + SerializedFilter(f.map(_.toByteArray)) + } + + def fromSerializedFilter(sf: SerializedFilter): Option[SparkSQLPushDownFilter] = { + sf.b.map(SparkSQLPushDownFilter.parseFrom(_)) + } +} + +@InterfaceAudience.Private +private[hbase] case class HBaseRegion( + override val index: Int, + val start: Option[HBaseType] = None, + val end: Option[HBaseType] = None, + val server: Option[String] = None) + extends Partition + +@InterfaceAudience.Private +private[hbase] case class HBaseScanPartition( + override val index: Int, + val regions: HBaseRegion, + val scanRanges: Seq[Range], + val points: Seq[Array[Byte]], + val sf: SerializedFilter) + extends Partition + +@InterfaceAudience.Private +case class RDDResources(set: mutable.HashSet[Resource]) { + def addResource(s: Resource) { + set += s + } + def release() { + set.foreach(release(_)) + } + def release(rs: Resource) { + try { + rs.release() + } finally { + set.remove(rs) + } + } +} diff --git a/spark4/hbase-spark4/src/main/scala/org/apache/hadoop/hbase/spark/datasources/JavaBytesEncoder.scala b/spark4/hbase-spark4/src/main/scala/org/apache/hadoop/hbase/spark/datasources/JavaBytesEncoder.scala new file mode 100644 index 00000000..eac4feb6 --- /dev/null +++ b/spark4/hbase-spark4/src/main/scala/org/apache/hadoop/hbase/spark/datasources/JavaBytesEncoder.scala @@ -0,0 +1,120 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.hadoop.hbase.spark.datasources + +import org.apache.hadoop.hbase.HBaseInterfaceAudience +import org.apache.hadoop.hbase.spark.Logging +import org.apache.hadoop.hbase.spark.datasources.JavaBytesEncoder.JavaBytesEncoder +import org.apache.spark.sql.types._ +import org.apache.yetus.audience.InterfaceAudience +import org.apache.yetus.audience.InterfaceStability + +/** + * The ranges for the data type whose size is known. Whether the bound is inclusive + * or exclusive is undefind, and upper to the caller to decide. + * + * @param low: the lower bound of the range. + * @param upper: the upper bound of the range. 
+ */ +@InterfaceAudience.LimitedPrivate(Array(HBaseInterfaceAudience.SPARK)) +@InterfaceStability.Evolving +case class BoundRange(low: Array[Byte], upper: Array[Byte]) + +/** + * The class identifies the ranges for a java primitive type. The caller needs + * to decide the bound is either inclusive or exclusive on its own. + * information + * + * @param less: the set of ranges for LessThan/LessOrEqualThan + * @param greater: the set of ranges for GreaterThan/GreaterThanOrEqualTo + * @param value: the byte array of the original value + */ +@InterfaceAudience.LimitedPrivate(Array(HBaseInterfaceAudience.SPARK)) +@InterfaceStability.Evolving +case class BoundRanges(less: Array[BoundRange], greater: Array[BoundRange], value: Array[Byte]) + +/** + * The trait to support plugin architecture for different encoder/decoder. + * encode is used for serializing the data type to byte array and the filter is + * used to filter out the unnecessary records. + */ +@InterfaceAudience.LimitedPrivate(Array(HBaseInterfaceAudience.SPARK)) +@InterfaceStability.Evolving +trait BytesEncoder { + def encode(dt: DataType, value: Any): Array[Byte] + + /** + * The function performing real filtering operations. The format of filterBytes depends on the + * implementation of the BytesEncoder. + * + * @param input: the current input byte array that needs to be filtered out + * @param offset1: the starting offset of the input byte array. + * @param length1: the length of the input byte array. + * @param filterBytes: the byte array provided by query condition. + * @param offset2: the starting offset in the filterBytes. + * @param length2: the length of the bytes in the filterBytes + * @param ops: The operation of the filter operator. + * @return true: the record satisfies the predicates + * false: the record does not satisfy the predicates. + */ + def filter( + input: Array[Byte], + offset1: Int, + length1: Int, + filterBytes: Array[Byte], + offset2: Int, + length2: Int, + ops: JavaBytesEncoder): Boolean + + /** + * Currently, it is used for partition pruning. + * As for some codec, the order may be inconsistent between java primitive + * type and its byte array. We may have to split the predicates on some + * of the java primitive type into multiple predicates. + * + * For example in naive codec, some of the java primitive types have to be + * split into multiple predicates, and union these predicates together to + * make the predicates be performed correctly. + * For example, if we have "COLUMN < 2", we will transform it into + * "0 <= COLUMN < 2 OR Integer.MIN_VALUE <= COLUMN <= -1" + */ + def ranges(in: Any): Option[BoundRanges] +} + +@InterfaceAudience.LimitedPrivate(Array(HBaseInterfaceAudience.SPARK)) +@InterfaceStability.Evolving +object JavaBytesEncoder extends Enumeration with Logging { + type JavaBytesEncoder = Value + val Greater, GreaterEqual, Less, LessEqual, Equal, Unknown = Value + + /** + * create the encoder/decoder + * + * @param clsName: the class name of the encoder/decoder class + * @return the instance of the encoder plugin. 
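A brief sketch of the factory described here (editor's addition; the value is invented): `create` instantiates the configured `BytesEncoder` by class name and quietly falls back to `NaiveEncoder` on failure.

```scala
import org.apache.spark.sql.types.IntegerType

// Sketch only: building an encoder and serializing a value with a leading type tag.
val encoder: BytesEncoder = JavaBytesEncoder.create(classOf[NaiveEncoder].getCanonicalName)
val tagged: Array[Byte] = encoder.encode(IntegerType, 42)
// tagged(0) carries the NaiveEncoder type code for Int; the remaining 4 bytes hold the value.
```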
+ */ + def create(clsName: String): BytesEncoder = { + try { + Class.forName(clsName).newInstance.asInstanceOf[BytesEncoder] + } catch { + case _: Throwable => + logWarning(s"$clsName cannot be initiated, falling back to naive encoder") + new NaiveEncoder() + } + } +} diff --git a/spark4/hbase-spark4/src/main/scala/org/apache/hadoop/hbase/spark/datasources/NaiveEncoder.scala b/spark4/hbase-spark4/src/main/scala/org/apache/hadoop/hbase/spark/datasources/NaiveEncoder.scala new file mode 100644 index 00000000..b54d2797 --- /dev/null +++ b/spark4/hbase-spark4/src/main/scala/org/apache/hadoop/hbase/spark/datasources/NaiveEncoder.scala @@ -0,0 +1,303 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.hadoop.hbase.spark.datasources +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +import org.apache.hadoop.hbase.spark.Logging +import org.apache.hadoop.hbase.spark.datasources.JavaBytesEncoder.JavaBytesEncoder +import org.apache.hadoop.hbase.spark.hbase._ +import org.apache.hadoop.hbase.util.Bytes +import org.apache.spark.sql.types._ +import org.apache.spark.unsafe.types.UTF8String +import org.apache.yetus.audience.InterfaceAudience + +/** + * This is the naive non-order preserving encoder/decoder. + * Due to the inconsistency of the order between java primitive types + * and their bytearray. The data type has to be passed in so that the filter + * can work correctly, which is done by wrapping the type into the first byte + * of the serialized array. + */ +@InterfaceAudience.Private +class NaiveEncoder extends BytesEncoder with Logging { + var code = 0 + def nextCode: Byte = { + code += 1 + (code - 1).asInstanceOf[Byte] + } + val BooleanEnc = nextCode + val ShortEnc = nextCode + val IntEnc = nextCode + val LongEnc = nextCode + val FloatEnc = nextCode + val DoubleEnc = nextCode + val StringEnc = nextCode + val BinaryEnc = nextCode + val TimestampEnc = nextCode + val UnknownEnc = nextCode + + /** + * Evaluate the java primitive type and return the BoundRanges. 
For one value, it may have + * multiple output ranges because of the inconsistency of order between java primitive type + * and its byte array order. + * + * For short, integer, and long, the order of number is consistent with byte array order + * if two number has the same sign bit. But the negative number is larger than positive + * number in byte array. + * + * For double and float, the order of positive number is consistent with its byte array order. + * But the order of negative number is the reverse order of byte array. Please refer to IEEE-754 + * and https://en.wikipedia.org/wiki/Single-precision_floating-point_format + */ + def ranges(in: Any): Option[BoundRanges] = in match { + case a: Integer => + val b = Bytes.toBytes(a) + if (a >= 0) { + logDebug(s"range is 0 to $a and ${Integer.MIN_VALUE} to -1") + Some( + BoundRanges( + Array( + BoundRange(Bytes.toBytes(0: Int), b), + BoundRange(Bytes.toBytes(Integer.MIN_VALUE), Bytes.toBytes(-1: Int))), + Array(BoundRange(b, Bytes.toBytes(Integer.MAX_VALUE))), + b)) + } else { + Some( + BoundRanges( + Array(BoundRange(Bytes.toBytes(Integer.MIN_VALUE), b)), + Array( + BoundRange(b, Bytes.toBytes(-1: Integer)), + BoundRange(Bytes.toBytes(0: Int), Bytes.toBytes(Integer.MAX_VALUE))), + b)) + } + case a: Long => + val b = Bytes.toBytes(a) + if (a >= 0) { + Some( + BoundRanges( + Array( + BoundRange(Bytes.toBytes(0: Long), b), + BoundRange(Bytes.toBytes(Long.MinValue), Bytes.toBytes(-1: Long))), + Array(BoundRange(b, Bytes.toBytes(Long.MaxValue))), + b)) + } else { + Some( + BoundRanges( + Array(BoundRange(Bytes.toBytes(Long.MinValue), b)), + Array( + BoundRange(b, Bytes.toBytes(-1: Long)), + BoundRange(Bytes.toBytes(0: Long), Bytes.toBytes(Long.MaxValue))), + b)) + } + case a: Short => + val b = Bytes.toBytes(a) + if (a >= 0) { + Some( + BoundRanges( + Array( + BoundRange(Bytes.toBytes(0: Short), b), + BoundRange(Bytes.toBytes(Short.MinValue), Bytes.toBytes(-1: Short))), + Array(BoundRange(b, Bytes.toBytes(Short.MaxValue))), + b)) + } else { + Some( + BoundRanges( + Array(BoundRange(Bytes.toBytes(Short.MinValue), b)), + Array( + BoundRange(b, Bytes.toBytes(-1: Short)), + BoundRange(Bytes.toBytes(0: Short), Bytes.toBytes(Short.MaxValue))), + b)) + } + case a: Double => + val b = Bytes.toBytes(a) + if (a >= 0.0f) { + Some( + BoundRanges( + Array( + BoundRange(Bytes.toBytes(0.0d), b), + BoundRange(Bytes.toBytes(-0.0d), Bytes.toBytes(Double.MinValue))), + Array(BoundRange(b, Bytes.toBytes(Double.MaxValue))), + b)) + } else { + Some( + BoundRanges( + Array(BoundRange(b, Bytes.toBytes(Double.MinValue))), + Array( + BoundRange(Bytes.toBytes(-0.0d), b), + BoundRange(Bytes.toBytes(0.0d), Bytes.toBytes(Double.MaxValue))), + b)) + } + case a: Float => + val b = Bytes.toBytes(a) + if (a >= 0.0f) { + Some( + BoundRanges( + Array( + BoundRange(Bytes.toBytes(0.0f), b), + BoundRange(Bytes.toBytes(-0.0f), Bytes.toBytes(Float.MinValue))), + Array(BoundRange(b, Bytes.toBytes(Float.MaxValue))), + b)) + } else { + Some( + BoundRanges( + Array(BoundRange(b, Bytes.toBytes(Float.MinValue))), + Array( + BoundRange(Bytes.toBytes(-0.0f), b), + BoundRange(Bytes.toBytes(0.0f), Bytes.toBytes(Float.MaxValue))), + b)) + } + case a: Array[Byte] => + Some(BoundRanges(Array(BoundRange(bytesMin, a)), Array(BoundRange(a, bytesMax)), a)) + case a: Byte => + val b = Array(a) + Some(BoundRanges(Array(BoundRange(bytesMin, b)), Array(BoundRange(b, bytesMax)), b)) + case a: String => + val b = Bytes.toBytes(a) + Some(BoundRanges(Array(BoundRange(bytesMin, b)), Array(BoundRange(b, 
bytesMax)), b)) + case a: UTF8String => + val b = a.getBytes + Some(BoundRanges(Array(BoundRange(bytesMin, b)), Array(BoundRange(b, bytesMax)), b)) + case _ => None + } + + def compare(c: Int, ops: JavaBytesEncoder): Boolean = { + ops match { + case JavaBytesEncoder.Greater => c > 0 + case JavaBytesEncoder.GreaterEqual => c >= 0 + case JavaBytesEncoder.Less => c < 0 + case JavaBytesEncoder.LessEqual => c <= 0 + } + } + + /** + * encode the data type into byte array. Note that it is a naive implementation with the + * data type byte appending to the head of the serialized byte array. + * + * @param dt: The data type of the input + * @param value: the value of the input + * @return the byte array with the first byte indicating the data type. + */ + override def encode(dt: DataType, value: Any): Array[Byte] = { + dt match { + case BooleanType => + val result = new Array[Byte](Bytes.SIZEOF_BOOLEAN + 1) + result(0) = BooleanEnc + value.asInstanceOf[Boolean] match { + case true => result(1) = -1: Byte + case false => result(1) = 0: Byte + } + result + case ShortType => + val result = new Array[Byte](Bytes.SIZEOF_SHORT + 1) + result(0) = ShortEnc + Bytes.putShort(result, 1, value.asInstanceOf[Short]) + result + case IntegerType => + val result = new Array[Byte](Bytes.SIZEOF_INT + 1) + result(0) = IntEnc + Bytes.putInt(result, 1, value.asInstanceOf[Int]) + result + case LongType | TimestampType => + val result = new Array[Byte](Bytes.SIZEOF_LONG + 1) + result(0) = LongEnc + Bytes.putLong(result, 1, value.asInstanceOf[Long]) + result + case FloatType => + val result = new Array[Byte](Bytes.SIZEOF_FLOAT + 1) + result(0) = FloatEnc + Bytes.putFloat(result, 1, value.asInstanceOf[Float]) + result + case DoubleType => + val result = new Array[Byte](Bytes.SIZEOF_DOUBLE + 1) + result(0) = DoubleEnc + Bytes.putDouble(result, 1, value.asInstanceOf[Double]) + result + case BinaryType => + val v = value.asInstanceOf[Array[Bytes]] + val result = new Array[Byte](v.length + 1) + result(0) = BinaryEnc + System.arraycopy(v, 0, result, 1, v.length) + result + case StringType => + val bytes = Bytes.toBytes(value.asInstanceOf[String]) + val result = new Array[Byte](bytes.length + 1) + result(0) = StringEnc + System.arraycopy(bytes, 0, result, 1, bytes.length) + result + case _ => + val bytes = Bytes.toBytes(value.toString) + val result = new Array[Byte](bytes.length + 1) + result(0) = UnknownEnc + System.arraycopy(bytes, 0, result, 1, bytes.length) + result + } + } + + override def filter( + input: Array[Byte], + offset1: Int, + length1: Int, + filterBytes: Array[Byte], + offset2: Int, + length2: Int, + ops: JavaBytesEncoder): Boolean = { + filterBytes(offset2) match { + case ShortEnc => + val in = Bytes.toShort(input, offset1) + val value = Bytes.toShort(filterBytes, offset2 + 1) + compare(in.compareTo(value), ops) + case IntEnc => + val in = Bytes.toInt(input, offset1) + val value = Bytes.toInt(filterBytes, offset2 + 1) + compare(in.compareTo(value), ops) + case LongEnc | TimestampEnc => + val in = Bytes.toLong(input, offset1) + val value = Bytes.toLong(filterBytes, offset2 + 1) + compare(in.compareTo(value), ops) + case FloatEnc => + val in = Bytes.toFloat(input, offset1) + val value = Bytes.toFloat(filterBytes, offset2 + 1) + compare(in.compareTo(value), ops) + case DoubleEnc => + val in = Bytes.toDouble(input, offset1) + val value = Bytes.toDouble(filterBytes, offset2 + 1) + compare(in.compareTo(value), ops) + case _ => + // for String, Byte, Binary, Boolean and other types + // we can use the order of byte 
array directly. + compare( + Bytes.compareTo(input, offset1, length1, filterBytes, offset2 + 1, length2 - 1), + ops) + } + } +} diff --git a/spark4/hbase-spark4/src/main/scala/org/apache/hadoop/hbase/spark/datasources/SchemaConverters.scala b/spark4/hbase-spark4/src/main/scala/org/apache/hadoop/hbase/spark/datasources/SchemaConverters.scala new file mode 100644 index 00000000..3eafe172 --- /dev/null +++ b/spark4/hbase-spark4/src/main/scala/org/apache/hadoop/hbase/spark/datasources/SchemaConverters.scala @@ -0,0 +1,450 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.hadoop.hbase.spark + +import java.io.ByteArrayInputStream +import java.nio.ByteBuffer +import java.sql.Timestamp +import java.util +import java.util.HashMap +import org.apache.avro.{Schema, SchemaBuilder} +import org.apache.avro.Schema.Type._ +import org.apache.avro.SchemaBuilder.BaseFieldTypeBuilder +import org.apache.avro.SchemaBuilder.BaseTypeBuilder +import org.apache.avro.SchemaBuilder.FieldAssembler +import org.apache.avro.SchemaBuilder.FieldDefault +import org.apache.avro.SchemaBuilder.RecordBuilder +import org.apache.avro.generic.{GenericData, GenericDatumReader, GenericDatumWriter, GenericRecord} +import org.apache.avro.generic.GenericData.{Fixed, Record} +import org.apache.avro.io._ +import org.apache.commons.io.output.ByteArrayOutputStream +import org.apache.hadoop.hbase.util.Bytes +import org.apache.spark.sql.Row +import org.apache.spark.sql.types._ +import org.apache.yetus.audience.InterfaceAudience +import scala.jdk.CollectionConverters._ +import scala.collection.immutable.Map + +@InterfaceAudience.Private +abstract class AvroException(msg: String) extends Exception(msg) + +@InterfaceAudience.Private +case class SchemaConversionException(msg: String) extends AvroException(msg) + +/** + * * + * On top level, the converters provide three high level interface. + * 1. toSqlType: This function takes an avro schema and returns a sql schema. + * 2. createConverterToSQL: Returns a function that is used to convert avro types to their + * corresponding sparkSQL representations. + * 3. convertTypeToAvro: This function constructs converter function for a given sparkSQL + * datatype. This is used in writing Avro records out to disk + */ +@InterfaceAudience.Private +object SchemaConverters { + + case class SchemaType(dataType: DataType, nullable: Boolean) + + /** + * This function takes an avro schema and returns a sql schema. 
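A small illustration (editor's sketch, schema invented): a record with a nullable union field maps to a `StructType` whose corresponding `StructField` is nullable.

```scala
// Sketch only: Avro record schema -> Spark SQL schema via toSqlType.
val avroSchema = new Schema.Parser().parse(
  """{"type":"record","name":"User","fields":[
    |  {"name":"name","type":"string"},
    |  {"name":"favorite_number","type":["null","int"]}
    |]}""".stripMargin)

val sqlType = SchemaConverters.toSqlType(avroSchema)
// Roughly: StructType(StructField(name,StringType,false), StructField(favorite_number,IntegerType,true))
println(sqlType.dataType)
```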
+ */ + def toSqlType(avroSchema: Schema): SchemaType = { + avroSchema.getType match { + case INT => SchemaType(IntegerType, nullable = false) + case STRING => SchemaType(StringType, nullable = false) + case BOOLEAN => SchemaType(BooleanType, nullable = false) + case BYTES => SchemaType(BinaryType, nullable = false) + case DOUBLE => SchemaType(DoubleType, nullable = false) + case FLOAT => SchemaType(FloatType, nullable = false) + case LONG => SchemaType(LongType, nullable = false) + case FIXED => SchemaType(BinaryType, nullable = false) + case ENUM => SchemaType(StringType, nullable = false) + + case RECORD => + val fields = avroSchema.getFields.asScala.map { + f => + val schemaType = toSqlType(f.schema()) + StructField(f.name, schemaType.dataType, schemaType.nullable) + } + + SchemaType(StructType(fields.toList.asJava), nullable = false) + + case ARRAY => + val schemaType = toSqlType(avroSchema.getElementType) + SchemaType( + ArrayType(schemaType.dataType, containsNull = schemaType.nullable), + nullable = false) + + case MAP => + val schemaType = toSqlType(avroSchema.getValueType) + SchemaType( + MapType(StringType, schemaType.dataType, valueContainsNull = schemaType.nullable), + nullable = false) + + case UNION => + if (avroSchema.getTypes.asScala.exists(_.getType == NULL)) { + // In case of a union with null, eliminate it and make a recursive call + val remainingUnionTypes = avroSchema.getTypes.asScala.filterNot(_.getType == NULL) + if (remainingUnionTypes.size == 1) { + toSqlType(remainingUnionTypes.head).copy(nullable = true) + } else { + toSqlType(Schema.createUnion(remainingUnionTypes.toList.asJava)).copy(nullable = true) + } + } else + avroSchema.getTypes.asScala.map(_.getType) match { + case Seq(t1, t2) if Set(t1, t2) == Set(INT, LONG) => + SchemaType(LongType, nullable = false) + case Seq(t1, t2) if Set(t1, t2) == Set(FLOAT, DOUBLE) => + SchemaType(DoubleType, nullable = false) + case other => + throw new SchemaConversionException( + s"This mix of union types is not supported: $other") + } + + case other => throw new SchemaConversionException(s"Unsupported type $other") + } + } + + /** + * This function converts sparkSQL StructType into avro schema. This method uses two other + * converter methods in order to do the conversion. + */ + private def convertStructToAvro[T]( + structType: StructType, + schemaBuilder: RecordBuilder[T], + recordNamespace: String): T = { + val fieldsAssembler: FieldAssembler[T] = schemaBuilder.fields() + structType.fields.foreach { + field => + val newField = fieldsAssembler.name(field.name).`type`() + + if (field.nullable) { + convertFieldTypeToAvro( + field.dataType, + newField.nullable(), + field.name, + recordNamespace).noDefault + } else { + convertFieldTypeToAvro(field.dataType, newField, field.name, recordNamespace).noDefault + } + } + fieldsAssembler.endRecord() + } + + /** + * Returns a function that is used to convert avro types to their + * corresponding sparkSQL representations. + */ + def createConverterToSQL(schema: Schema): Any => Any = { + schema.getType match { + // Avro strings are in Utf8, so we have to call toString on them + case STRING | ENUM => (item: Any) => if (item == null) null else item.toString + case INT | BOOLEAN | DOUBLE | FLOAT | LONG => identity + // Byte arrays are reused by avro, so we have to make a copy of them. 
+ case FIXED => + (item: Any) => + if (item == null) { + null + } else { + item.asInstanceOf[Fixed].bytes().clone() + } + case BYTES => + (item: Any) => + if (item == null) { + null + } else { + val bytes = item.asInstanceOf[ByteBuffer] + val javaBytes = new Array[Byte](bytes.remaining) + bytes.get(javaBytes) + javaBytes + } + case RECORD => + val fieldConverters = schema.getFields.asScala.map(f => createConverterToSQL(f.schema)) + (item: Any) => + if (item == null) { + null + } else { + val record = item.asInstanceOf[GenericRecord] + val converted = new Array[Any](fieldConverters.size) + var idx = 0 + while (idx < fieldConverters.size) { + converted(idx) = fieldConverters.apply(idx)(record.get(idx)) + idx += 1 + } + Row.fromSeq(converted.toSeq) + } + case ARRAY => + val elementConverter = createConverterToSQL(schema.getElementType) + (item: Any) => + if (item == null) { + null + } else { + try { + item.asInstanceOf[GenericData.Array[Any]].asScala.map(elementConverter) + } catch { + case e: Throwable => + item.asInstanceOf[util.ArrayList[Any]].asScala.map(elementConverter) + } + } + case MAP => + val valueConverter = createConverterToSQL(schema.getValueType) + (item: Any) => + if (item == null) { + null + } else { + item + .asInstanceOf[HashMap[Any, Any]].asScala + .map(x => (x._1.toString, valueConverter(x._2))) + .toMap + } + case UNION => + if (schema.getTypes.asScala.exists(_.getType == NULL)) { + val remainingUnionTypes = schema.getTypes.asScala.filterNot(_.getType == NULL) + if (remainingUnionTypes.size == 1) { + createConverterToSQL(remainingUnionTypes.head) + } else { + createConverterToSQL(Schema.createUnion(remainingUnionTypes.toList.asJava)) + } + } else + schema.getTypes.asScala.map(_.getType) match { + case Seq(t1, t2) if Set(t1, t2) == Set(INT, LONG) => + (item: Any) => { + item match { + case l: Long => l + case i: Int => i.toLong + case null => null + } + } + case Seq(t1, t2) if Set(t1, t2) == Set(FLOAT, DOUBLE) => + (item: Any) => { + item match { + case d: Double => d + case f: Float => f.toDouble + case null => null + } + } + case other => + throw new SchemaConversionException( + s"This mix of union types is not supported (see README): $other") + } + case other => throw new SchemaConversionException(s"invalid avro type: $other") + } + } + + /** + * This function is used to convert some sparkSQL type to avro type. Note that this function won't + * be used to construct fields of avro record (convertFieldTypeToAvro is used for that). 
+ */ + private def convertTypeToAvro[T]( + dataType: DataType, + schemaBuilder: BaseTypeBuilder[T], + structName: String, + recordNamespace: String): T = { + dataType match { + case ByteType => schemaBuilder.intType() + case ShortType => schemaBuilder.intType() + case IntegerType => schemaBuilder.intType() + case LongType => schemaBuilder.longType() + case FloatType => schemaBuilder.floatType() + case DoubleType => schemaBuilder.doubleType() + case _: DecimalType => schemaBuilder.stringType() + case StringType => schemaBuilder.stringType() + case BinaryType => schemaBuilder.bytesType() + case BooleanType => schemaBuilder.booleanType() + case TimestampType => schemaBuilder.longType() + + case ArrayType(elementType, _) => + val builder = getSchemaBuilder(dataType.asInstanceOf[ArrayType].containsNull) + val elementSchema = convertTypeToAvro(elementType, builder, structName, recordNamespace) + schemaBuilder.array().items(elementSchema) + + case MapType(StringType, valueType, _) => + val builder = getSchemaBuilder(dataType.asInstanceOf[MapType].valueContainsNull) + val valueSchema = convertTypeToAvro(valueType, builder, structName, recordNamespace) + schemaBuilder.map().values(valueSchema) + + case structType: StructType => + convertStructToAvro( + structType, + schemaBuilder.record(structName).namespace(recordNamespace), + recordNamespace) + + case other => throw new IllegalArgumentException(s"Unexpected type $dataType.") + } + } + + /** + * This function is used to construct fields of the avro record, where schema of the field is + * specified by avro representation of dataType. Since builders for record fields are different + * from those for everything else, we have to use a separate method. + */ + private def convertFieldTypeToAvro[T]( + dataType: DataType, + newFieldBuilder: BaseFieldTypeBuilder[T], + structName: String, + recordNamespace: String): FieldDefault[T, _] = { + dataType match { + case ByteType => newFieldBuilder.intType() + case ShortType => newFieldBuilder.intType() + case IntegerType => newFieldBuilder.intType() + case LongType => newFieldBuilder.longType() + case FloatType => newFieldBuilder.floatType() + case DoubleType => newFieldBuilder.doubleType() + case _: DecimalType => newFieldBuilder.stringType() + case StringType => newFieldBuilder.stringType() + case BinaryType => newFieldBuilder.bytesType() + case BooleanType => newFieldBuilder.booleanType() + case TimestampType => newFieldBuilder.longType() + + case ArrayType(elementType, _) => + val builder = getSchemaBuilder(dataType.asInstanceOf[ArrayType].containsNull) + val elementSchema = convertTypeToAvro(elementType, builder, structName, recordNamespace) + newFieldBuilder.array().items(elementSchema) + + case MapType(StringType, valueType, _) => + val builder = getSchemaBuilder(dataType.asInstanceOf[MapType].valueContainsNull) + val valueSchema = convertTypeToAvro(valueType, builder, structName, recordNamespace) + newFieldBuilder.map().values(valueSchema) + + case structType: StructType => + convertStructToAvro( + structType, + newFieldBuilder.record(structName).namespace(recordNamespace), + recordNamespace) + + case other => throw new IllegalArgumentException(s"Unexpected type $dataType.") + } + } + + private def getSchemaBuilder(isNullable: Boolean): BaseTypeBuilder[Schema] = { + if (isNullable) { + SchemaBuilder.builder().nullable() + } else { + SchemaBuilder.builder() + } + } + + /** + * This function constructs converter function for a given sparkSQL datatype. 
This is used in + * writing Avro records out to disk + */ + def createConverterToAvro( + dataType: DataType, + structName: String, + recordNamespace: String): (Any) => Any = { + dataType match { + case BinaryType => + (item: Any) => + item match { + case null => null + case bytes: Array[Byte] => ByteBuffer.wrap(bytes) + } + case ByteType | ShortType | IntegerType | LongType | FloatType | DoubleType | StringType | + BooleanType => + identity + case _: DecimalType => (item: Any) => if (item == null) null else item.toString + case TimestampType => + (item: Any) => if (item == null) null else item.asInstanceOf[Timestamp].getTime + case ArrayType(elementType, _) => + val elementConverter = createConverterToAvro(elementType, structName, recordNamespace) + (item: Any) => { + if (item == null) { + null + } else { + val sourceArray = item.asInstanceOf[Seq[Any]] + val sourceArraySize = sourceArray.size + val targetArray = new util.ArrayList[Any](sourceArraySize) + var idx = 0 + while (idx < sourceArraySize) { + targetArray.add(elementConverter(sourceArray(idx))) + idx += 1 + } + targetArray + } + } + case MapType(StringType, valueType, _) => + val valueConverter = createConverterToAvro(valueType, structName, recordNamespace) + (item: Any) => { + if (item == null) { + null + } else { + val javaMap = new HashMap[String, Any]() + item.asInstanceOf[Map[String, Any]].foreach { + case (key, value) => + javaMap.put(key, valueConverter(value)) + } + javaMap + } + } + case structType: StructType => + val builder = SchemaBuilder.record(structName).namespace(recordNamespace) + val schema: Schema = + SchemaConverters.convertStructToAvro(structType, builder, recordNamespace) + val fieldConverters = structType.fields.map( + field => createConverterToAvro(field.dataType, field.name, recordNamespace)) + (item: Any) => { + if (item == null) { + null + } else { + val record = new Record(schema) + val convertersIterator = fieldConverters.iterator + val fieldNamesIterator = dataType.asInstanceOf[StructType].fieldNames.iterator + val rowIterator = item.asInstanceOf[Row].toSeq.iterator + + while (convertersIterator.hasNext) { + val converter = convertersIterator.next() + record.put(fieldNamesIterator.next(), converter(rowIterator.next())) + } + record + } + } + } + } +} + +@InterfaceAudience.Private +object AvroSerdes { + // We only handle top level is record or primary type now + def serialize(input: Any, schema: Schema): Array[Byte] = { + schema.getType match { + case BOOLEAN => Bytes.toBytes(input.asInstanceOf[Boolean]) + case BYTES | FIXED => input.asInstanceOf[Array[Byte]] + case DOUBLE => Bytes.toBytes(input.asInstanceOf[Double]) + case FLOAT => Bytes.toBytes(input.asInstanceOf[Float]) + case INT => Bytes.toBytes(input.asInstanceOf[Int]) + case LONG => Bytes.toBytes(input.asInstanceOf[Long]) + case STRING => Bytes.toBytes(input.asInstanceOf[String]) + case RECORD => + val gr = input.asInstanceOf[GenericRecord] + val writer2 = new GenericDatumWriter[GenericRecord](schema) + val bao2 = new ByteArrayOutputStream() + val encoder2: BinaryEncoder = EncoderFactory.get().directBinaryEncoder(bao2, null) + writer2.write(gr, encoder2) + bao2.toByteArray() + case _ => throw new Exception(s"unsupported data type ${schema.getType}") // TODO + } + } + + def deserialize(input: Array[Byte], schema: Schema): GenericRecord = { + val reader2: DatumReader[GenericRecord] = new GenericDatumReader[GenericRecord](schema) + val bai2 = new ByteArrayInputStream(input) + val decoder2: BinaryDecoder = 
DecoderFactory.get().directBinaryDecoder(bai2, null) + val gr2: GenericRecord = reader2.read(null, decoder2) + gr2 + } +} diff --git a/spark4/hbase-spark4/src/main/scala/org/apache/hadoop/hbase/spark/datasources/SerDes.scala b/spark4/hbase-spark4/src/main/scala/org/apache/hadoop/hbase/spark/datasources/SerDes.scala new file mode 100644 index 00000000..c3d5c634 --- /dev/null +++ b/spark4/hbase-spark4/src/main/scala/org/apache/hadoop/hbase/spark/datasources/SerDes.scala @@ -0,0 +1,37 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.hadoop.hbase.spark.datasources + +import org.apache.hadoop.hbase.util.Bytes +import org.apache.yetus.audience.InterfaceAudience + +// TODO: This is not really used in code. +@InterfaceAudience.Public +trait SerDes { + def serialize(value: Any): Array[Byte] + def deserialize(bytes: Array[Byte], start: Int, end: Int): Any +} + +// TODO: This is not really used in code. +@InterfaceAudience.Private +class DoubleSerDes extends SerDes { + override def serialize(value: Any): Array[Byte] = Bytes.toBytes(value.asInstanceOf[Double]) + override def deserialize(bytes: Array[Byte], start: Int, end: Int): Any = { + Bytes.toDouble(bytes, start) + } +} diff --git a/spark4/hbase-spark4/src/main/scala/org/apache/hadoop/hbase/spark/datasources/SerializableConfiguration.scala b/spark4/hbase-spark4/src/main/scala/org/apache/hadoop/hbase/spark/datasources/SerializableConfiguration.scala new file mode 100644 index 00000000..f5b4c5a2 --- /dev/null +++ b/spark4/hbase-spark4/src/main/scala/org/apache/hadoop/hbase/spark/datasources/SerializableConfiguration.scala @@ -0,0 +1,45 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.hadoop.hbase.spark.datasources + +import java.io.{IOException, ObjectInputStream, ObjectOutputStream} +import org.apache.hadoop.conf.Configuration +import org.apache.yetus.audience.InterfaceAudience +import scala.util.control.NonFatal + +@InterfaceAudience.Private +class SerializableConfiguration(@transient var value: Configuration) extends Serializable { + private def writeObject(out: ObjectOutputStream): Unit = tryOrIOException { + out.defaultWriteObject() + value.write(out) + } + + private def readObject(in: ObjectInputStream): Unit = tryOrIOException { + value = new Configuration(false) + value.readFields(in) + } + + def tryOrIOException(block: => Unit): Unit = { + try { + block + } catch { + case e: IOException => throw e + case NonFatal(t) => throw new IOException(t) + } + } +} diff --git a/spark4/hbase-spark4/src/main/scala/org/apache/hadoop/hbase/spark/datasources/Utils.scala b/spark4/hbase-spark4/src/main/scala/org/apache/hadoop/hbase/spark/datasources/Utils.scala new file mode 100644 index 00000000..de7eec9a --- /dev/null +++ b/spark4/hbase-spark4/src/main/scala/org/apache/hadoop/hbase/spark/datasources/Utils.scala @@ -0,0 +1,118 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.hadoop.hbase.spark.datasources + +import java.sql.{Date, Timestamp} +import org.apache.hadoop.hbase.spark.AvroSerdes +import org.apache.hadoop.hbase.util.Bytes +import org.apache.spark.sql.types._ +import org.apache.spark.unsafe.types.UTF8String +import org.apache.yetus.audience.InterfaceAudience + +@InterfaceAudience.Private +object Utils { + + /** + * Parses the HBase field to its corresponding + * Scala type which can then be put into a Spark GenericRow + * which is then automatically converted by Spark.
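+ * Fields that declare an Avro schema are deserialized through AvroSerdes and mapped to Catalyst values; plain fields are decoded from the raw bytes with the HBase Bytes utilities at the given offset and length.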
+ */ + def hbaseFieldToScalaType(f: Field, src: Array[Byte], offset: Int, length: Int): Any = { + if (f.exeSchema.isDefined) { + // If we have avro schema defined, use it to get record, and then convert them to catalyst data type + val m = AvroSerdes.deserialize(src, f.exeSchema.get) + val n = f.avroToCatalyst.map(_(m)) + n.get + } else { + // Fall back to atomic type + f.dt match { + case BooleanType => src(offset) != 0 + case ByteType => src(offset) + case ShortType => Bytes.toShort(src, offset) + case IntegerType => Bytes.toInt(src, offset) + case LongType => Bytes.toLong(src, offset) + case FloatType => Bytes.toFloat(src, offset) + case DoubleType => Bytes.toDouble(src, offset) + case DateType => new Date(Bytes.toLong(src, offset)) + case TimestampType => new Timestamp(Bytes.toLong(src, offset)) + case StringType => Bytes.toString(src, offset, length) + case BinaryType => + val newArray = new Array[Byte](length) + System.arraycopy(src, offset, newArray, 0, length) + newArray + // TODO: SparkSqlSerializer.deserialize[Any](src) + case _ => throw new Exception(s"unsupported data type ${f.dt}") + } + } + } + + // convert input to data type + def toBytes(input: Any, field: Field): Array[Byte] = { + if (field.schema.isDefined) { + // Here we assume the top level type is structType + val record = field.catalystToAvro(input) + AvroSerdes.serialize(record, field.schema.get) + } else { + field.dt match { + case BooleanType => Bytes.toBytes(input.asInstanceOf[Boolean]) + case ByteType => Array(input.asInstanceOf[Number].byteValue) + case ShortType => Bytes.toBytes(input.asInstanceOf[Number].shortValue) + case IntegerType => Bytes.toBytes(input.asInstanceOf[Number].intValue) + case LongType => Bytes.toBytes(input.asInstanceOf[Number].longValue) + case FloatType => Bytes.toBytes(input.asInstanceOf[Number].floatValue) + case DoubleType => Bytes.toBytes(input.asInstanceOf[Number].doubleValue) + case DateType | TimestampType => Bytes.toBytes(input.asInstanceOf[java.util.Date].getTime) + case StringType => Bytes.toBytes(input.toString) + case BinaryType => input.asInstanceOf[Array[Byte]] + case _ => throw new Exception(s"unsupported data type ${field.dt}") + } + } + } + + // increment Byte array's value by 1 + def incrementByteArray(array: Array[Byte]): Array[Byte] = { + if (array.length == 0) { + return null + } + var index = -1 // index of the byte we have to increment + var a = array.length - 1 + + while (a >= 0) { + if (array(a) != (-1).toByte) { + index = a + a = -1 // break from the loop because we found a non -1 element + } + a = a - 1 + } + + if (index < 0) { + return null + } + val returnArray = new Array[Byte](array.length) + + for (a <- 0 until index) { + returnArray(a) = array(a) + } + returnArray(index) = (array(index) + 1).toByte + for (a <- index + 1 until array.length) { + returnArray(a) = 0.toByte + } + + returnArray + } +} diff --git a/spark4/hbase-spark4/src/main/scala/org/apache/hadoop/hbase/spark/datasources/package.scala b/spark4/hbase-spark4/src/main/scala/org/apache/hadoop/hbase/spark/datasources/package.scala new file mode 100644 index 00000000..66e3897e --- /dev/null +++ b/spark4/hbase-spark4/src/main/scala/org/apache/hadoop/hbase/spark/datasources/package.scala @@ -0,0 +1,38 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.hadoop.hbase.spark + +import org.apache.hadoop.hbase.util.Bytes +import scala.math.Ordering + +// TODO: add @InterfaceAudience.Private if https://issues.scala-lang.org/browse/SI-3600 is resolved +package object hbase { + type HBaseType = Array[Byte] + def bytesMin = new Array[Byte](0) + def bytesMax = null + val ByteMax = -1.asInstanceOf[Byte] + val ByteMin = 0.asInstanceOf[Byte] + val ord: Ordering[HBaseType] = new Ordering[HBaseType] { + def compare(x: Array[Byte], y: Array[Byte]): Int = { + return Bytes.compareTo(x, y) + } + } + // Do not use BinaryType.ordering + implicit val order: Ordering[HBaseType] = ord + +} diff --git a/spark4/hbase-spark4/src/main/scala/org/apache/hadoop/hbase/spark/example/datasources/AvroSource.scala b/spark4/hbase-spark4/src/main/scala/org/apache/hadoop/hbase/spark/example/datasources/AvroSource.scala new file mode 100644 index 00000000..b0d50290 --- /dev/null +++ b/spark4/hbase-spark4/src/main/scala/org/apache/hadoop/hbase/spark/example/datasources/AvroSource.scala @@ -0,0 +1,168 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.hadoop.hbase.spark.example.datasources + +import org.apache.avro.Schema +import org.apache.avro.generic.GenericData +import org.apache.hadoop.hbase.spark.AvroSerdes +import org.apache.hadoop.hbase.spark.datasources.HBaseTableCatalog +import org.apache.spark.SparkConf +import org.apache.spark.SparkContext +import org.apache.spark.sql.DataFrame +import org.apache.spark.sql.SQLContext +import org.apache.yetus.audience.InterfaceAudience + +/** + * @param col0 Column #0, Type is String + * @param col1 Column #1, Type is Array[Byte] + */ +@InterfaceAudience.Private +case class AvroHBaseRecord(col0: String, col1: Array[Byte]) +@InterfaceAudience.Private +object AvroHBaseRecord { + val schemaString = + s"""{"namespace": "example.avro", + | "type": "record", "name": "User", + | "fields": [ + | {"name": "name", "type": "string"}, + | {"name": "favorite_number", "type": ["int", "null"]}, + | {"name": "favorite_color", "type": ["string", "null"]}, + | {"name": "favorite_array", "type": {"type": "array", "items": "string"}}, + | {"name": "favorite_map", "type": {"type": "map", "values": "int"}} + | ] }""".stripMargin + + val avroSchema: Schema = { + val p = new Schema.Parser + p.parse(schemaString) + } + + def apply(i: Int): AvroHBaseRecord = { + + val user = new GenericData.Record(avroSchema); + user.put("name", s"name${"%03d".format(i)}") + user.put("favorite_number", i) + user.put("favorite_color", s"color${"%03d".format(i)}") + val favoriteArray = + new GenericData.Array[String](2, avroSchema.getField("favorite_array").schema()) + favoriteArray.add(s"number${i}") + favoriteArray.add(s"number${i + 1}") + user.put("favorite_array", favoriteArray) + import scala.collection.JavaConverters._ + val favoriteMap = Map[String, Int](("key1" -> i), ("key2" -> (i + 1))).asJava + user.put("favorite_map", favoriteMap) + val avroByte = AvroSerdes.serialize(user, avroSchema) + AvroHBaseRecord(s"name${"%03d".format(i)}", avroByte) + } +} + +@InterfaceAudience.Private +object AvroSource { + def catalog = s"""{ + |"table":{"namespace":"default", "name":"ExampleAvrotable"}, + |"rowkey":"key", + |"columns":{ + |"col0":{"cf":"rowkey", "col":"key", "type":"string"}, + |"col1":{"cf":"cf1", "col":"col1", "type":"binary"} + |} + |}""".stripMargin + + def avroCatalog = s"""{ + |"table":{"namespace":"default", "name":"ExampleAvrotable"}, + |"rowkey":"key", + |"columns":{ + |"col0":{"cf":"rowkey", "col":"key", "type":"string"}, + |"col1":{"cf":"cf1", "col":"col1", "avro":"avroSchema"} + |} + |}""".stripMargin + + def avroCatalogInsert = s"""{ + |"table":{"namespace":"default", "name":"ExampleAvrotableInsert"}, + |"rowkey":"key", + |"columns":{ + |"col0":{"cf":"rowkey", "col":"key", "type":"string"}, + |"col1":{"cf":"cf1", "col":"col1", "avro":"avroSchema"} + |} + |}""".stripMargin + + def main(args: Array[String]) { + val sparkConf = new SparkConf().setAppName("AvroSourceExample") + val sc = new SparkContext(sparkConf) + val sqlContext = new SQLContext(sc) + + import sqlContext.implicits._ + + def withCatalog(cat: String): DataFrame = { + sqlContext.read + .options( + Map( + "avroSchema" -> AvroHBaseRecord.schemaString, + HBaseTableCatalog.tableCatalog -> avroCatalog)) + .format("org.apache.hadoop.hbase.spark") + .load() + } + + val data = (0 to 255).map { i => AvroHBaseRecord(i) } + + sc.parallelize(data) + .toDF + .write + .options(Map(HBaseTableCatalog.tableCatalog -> catalog, HBaseTableCatalog.newTable -> "5")) + .format("org.apache.hadoop.hbase.spark") + .save() + + val df = withCatalog(catalog) + 
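// Note: withCatalog reads through avroCatalog with the avroSchema option, so col1 comes back as a nested struct decoded from Avro (e.g. col1.favorite_array) rather than as raw bytes +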
df.show() + df.printSchema() + df.registerTempTable("ExampleAvrotable") + val c = sqlContext.sql("select count(1) from ExampleAvrotable") + c.show() + + val filtered = df.select($"col0", $"col1.favorite_array").where($"col0" === "name001") + filtered.show() + val collected = filtered.collect() + if (collected(0).getSeq[String](1)(0) != "number1") { + throw new UserCustomizedSampleException("value invalid") + } + if (collected(0).getSeq[String](1)(1) != "number2") { + throw new UserCustomizedSampleException("value invalid") + } + + df.write + .options( + Map( + "avroSchema" -> AvroHBaseRecord.schemaString, + HBaseTableCatalog.tableCatalog -> avroCatalogInsert, + HBaseTableCatalog.newTable -> "5")) + .format("org.apache.hadoop.hbase.spark") + .save() + val newDF = withCatalog(avroCatalogInsert) + newDF.show() + newDF.printSchema() + if (newDF.count() != 256) { + throw new UserCustomizedSampleException("value invalid") + } + + df.filter($"col1.name" === "name005" || $"col1.name" <= "name005") + .select("col0", "col1.favorite_color", "col1.favorite_number") + .show() + + df.filter($"col1.name" <= "name005" || $"col1.name".contains("name007")) + .select("col0", "col1.favorite_color", "col1.favorite_number") + .show() + } +} diff --git a/spark4/hbase-spark4/src/main/scala/org/apache/hadoop/hbase/spark/example/datasources/DataType.scala b/spark4/hbase-spark4/src/main/scala/org/apache/hadoop/hbase/spark/example/datasources/DataType.scala new file mode 100644 index 00000000..d314dc8c --- /dev/null +++ b/spark4/hbase-spark4/src/main/scala/org/apache/hadoop/hbase/spark/example/datasources/DataType.scala @@ -0,0 +1,172 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.hadoop.hbase.spark.example.datasources + +import org.apache.hadoop.hbase.spark.datasources.HBaseTableCatalog +import org.apache.spark.SparkConf +import org.apache.spark.SparkContext +import org.apache.spark.sql.DataFrame +import org.apache.spark.sql.SQLContext +import org.apache.yetus.audience.InterfaceAudience + +@InterfaceAudience.Private +class UserCustomizedSampleException(message: String = null, cause: Throwable = null) + extends RuntimeException(UserCustomizedSampleException.message(message, cause), cause) + +@InterfaceAudience.Private +object UserCustomizedSampleException { + def message(message: String, cause: Throwable) = + if (message != null) message + else if (cause != null) cause.toString() + else null +} + +@InterfaceAudience.Private +case class IntKeyRecord( + col0: Integer, + col1: Boolean, + col2: Double, + col3: Float, + col4: Int, + col5: Long, + col6: Short, + col7: String, + col8: Byte) + +object IntKeyRecord { + def apply(i: Int): IntKeyRecord = { + IntKeyRecord( + if (i % 2 == 0) i else -i, + i % 2 == 0, + i.toDouble, + i.toFloat, + i, + i.toLong, + i.toShort, + s"String$i extra", + i.toByte) + } +} + +@InterfaceAudience.Private +object DataType { + val cat = s"""{ + |"table":{"namespace":"default", "name":"DataTypeExampleTable"}, + |"rowkey":"key", + |"columns":{ + |"col0":{"cf":"rowkey", "col":"key", "type":"int"}, + |"col1":{"cf":"cf1", "col":"col1", "type":"boolean"}, + |"col2":{"cf":"cf2", "col":"col2", "type":"double"}, + |"col3":{"cf":"cf3", "col":"col3", "type":"float"}, + |"col4":{"cf":"cf4", "col":"col4", "type":"int"}, + |"col5":{"cf":"cf5", "col":"col5", "type":"bigint"}, + |"col6":{"cf":"cf6", "col":"col6", "type":"smallint"}, + |"col7":{"cf":"cf7", "col":"col7", "type":"string"}, + |"col8":{"cf":"cf8", "col":"col8", "type":"tinyint"} + |} + |}""".stripMargin + + def main(args: Array[String]) { + val sparkConf = new SparkConf().setAppName("DataTypeExample") + val sc = new SparkContext(sparkConf) + val sqlContext = new SQLContext(sc) + + import sqlContext.implicits._ + + def withCatalog(cat: String): DataFrame = { + sqlContext.read + .options(Map(HBaseTableCatalog.tableCatalog -> cat)) + .format("org.apache.hadoop.hbase.spark") + .load() + } + + // test populate table + val data = (0 until 32).map { i => IntKeyRecord(i) } + sc.parallelize(data) + .toDF + .write + .options(Map(HBaseTableCatalog.tableCatalog -> cat, HBaseTableCatalog.newTable -> "5")) + .format("org.apache.hadoop.hbase.spark") + .save() + + // test less than 0 + val df = withCatalog(cat) + val s = df.filter($"col0" < 0) + s.show() + if (s.count() != 16) { + throw new UserCustomizedSampleException("value invalid") + } + + // test less or equal than -10. The number of results is 11 + val num1 = df.filter($"col0" <= -10) + num1.show() + val c1 = num1.count() + println(s"test result count should be 11: $c1") + + // test less or equal than -9. The number of results is 12 + val num2 = df.filter($"col0" <= -9) + num2.show() + val c2 = num2.count() + println(s"test result count should be 12: $c2") + + // test greater or equal than -9". The number of results is 21 + val num3 = df.filter($"col0" >= -9) + num3.show() + val c3 = num3.count() + println(s"test result count should be 21: $c3") + + // test greater or equal than 0. The number of results is 16 + val num4 = df.filter($"col0" >= 0) + num4.show() + val c4 = num4.count() + println(s"test result count should be 16: $c4") + + // test greater than 10. 
The number of results is 10 + val num5 = df.filter($"col0" > 10) + num5.show() + val c5 = num5.count() + println(s"test result count should be 10: $c5") + + // test "and". The number of results is 11 + val num6 = df.filter($"col0" > -10 && $"col0" <= 10) + num6.show() + val c6 = num6.count() + println(s"test result count should be 11: $c6") + + // test "or". The number of results is 21 + val num7 = df.filter($"col0" <= -10 || $"col0" > 10) + num7.show() + val c7 = num7.count() + println(s"test result count should be 21: $c7") + + // test "all". The number of results is 32 + val num8 = df.filter($"col0" >= -100) + num8.show() + val c8 = num8.count() + println(s"test result count should be 32: $c8") + + // test "full query" + val df1 = withCatalog(cat) + df1.show() + val c_df = df1.count() + println(s"df count should be 32: $c_df") + if (c_df != 32) { + throw new UserCustomizedSampleException("value invalid") + } + } +} diff --git a/spark4/hbase-spark4/src/main/scala/org/apache/hadoop/hbase/spark/example/datasources/HBaseSource.scala b/spark4/hbase-spark4/src/main/scala/org/apache/hadoop/hbase/spark/example/datasources/HBaseSource.scala new file mode 100644 index 00000000..d7101e2e --- /dev/null +++ b/spark4/hbase-spark4/src/main/scala/org/apache/hadoop/hbase/spark/example/datasources/HBaseSource.scala @@ -0,0 +1,112 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.hadoop.hbase.spark.example.datasources + +import org.apache.hadoop.hbase.spark.datasources.HBaseTableCatalog +import org.apache.spark.SparkConf +import org.apache.spark.SparkContext +import org.apache.spark.sql.DataFrame +import org.apache.spark.sql.SQLContext +import org.apache.yetus.audience.InterfaceAudience + +@InterfaceAudience.Private +case class HBaseRecord( + col0: String, + col1: Boolean, + col2: Double, + col3: Float, + col4: Int, + col5: Long, + col6: Short, + col7: String, + col8: Byte) + +@InterfaceAudience.Private +object HBaseRecord { + def apply(i: Int): HBaseRecord = { + val s = s"""row${"%03d".format(i)}""" + HBaseRecord( + s, + i % 2 == 0, + i.toDouble, + i.toFloat, + i, + i.toLong, + i.toShort, + s"String$i extra", + i.toByte) + } +} + +@InterfaceAudience.Private +object HBaseSource { + val cat = s"""{ + |"table":{"namespace":"default", "name":"HBaseSourceExampleTable"}, + |"rowkey":"key", + |"columns":{ + |"col0":{"cf":"rowkey", "col":"key", "type":"string"}, + |"col1":{"cf":"cf1", "col":"col1", "type":"boolean"}, + |"col2":{"cf":"cf2", "col":"col2", "type":"double"}, + |"col3":{"cf":"cf3", "col":"col3", "type":"float"}, + |"col4":{"cf":"cf4", "col":"col4", "type":"int"}, + |"col5":{"cf":"cf5", "col":"col5", "type":"bigint"}, + |"col6":{"cf":"cf6", "col":"col6", "type":"smallint"}, + |"col7":{"cf":"cf7", "col":"col7", "type":"string"}, + |"col8":{"cf":"cf8", "col":"col8", "type":"tinyint"} + |} + |}""".stripMargin + + def main(args: Array[String]) { + val sparkConf = new SparkConf().setAppName("HBaseSourceExample") + val sc = new SparkContext(sparkConf) + val sqlContext = new SQLContext(sc) + + import sqlContext.implicits._ + + def withCatalog(cat: String): DataFrame = { + sqlContext.read + .options(Map(HBaseTableCatalog.tableCatalog -> cat)) + .format("org.apache.hadoop.hbase.spark") + .load() + } + + val data = (0 to 255).map { i => HBaseRecord(i) } + + sc.parallelize(data) + .toDF + .write + .options(Map(HBaseTableCatalog.tableCatalog -> cat, HBaseTableCatalog.newTable -> "5")) + .format("org.apache.hadoop.hbase.spark") + .save() + + val df = withCatalog(cat) + df.show() + df.filter($"col0" <= "row005") + .select($"col0", $"col1") + .show + df.filter($"col0" === "row005" || $"col0" <= "row005") + .select($"col0", $"col1") + .show + df.filter($"col0" > "row250") + .select($"col0", $"col1") + .show + df.registerTempTable("table1") + val c = sqlContext.sql("select count(col1) from table1 where col0 < 'row050'") + c.show() + } +} diff --git a/spark4/hbase-spark4/src/main/scala/org/apache/hadoop/hbase/spark/example/hbasecontext/HBaseBulkDeleteExample.scala b/spark4/hbase-spark4/src/main/scala/org/apache/hadoop/hbase/spark/example/hbasecontext/HBaseBulkDeleteExample.scala new file mode 100644 index 00000000..94659779 --- /dev/null +++ b/spark4/hbase-spark4/src/main/scala/org/apache/hadoop/hbase/spark/example/hbasecontext/HBaseBulkDeleteExample.scala @@ -0,0 +1,67 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.hadoop.hbase.spark.example.hbasecontext + +import org.apache.hadoop.hbase.HBaseConfiguration +import org.apache.hadoop.hbase.TableName +import org.apache.hadoop.hbase.client.Delete +import org.apache.hadoop.hbase.spark.HBaseContext +import org.apache.hadoop.hbase.util.Bytes +import org.apache.spark.SparkConf +import org.apache.spark.SparkContext +import org.apache.yetus.audience.InterfaceAudience + +/** + * This is a simple example of deleting records in HBase + * with the bulkDelete function. + */ +@InterfaceAudience.Private +object HBaseBulkDeleteExample { + def main(args: Array[String]) { + if (args.length < 1) { + println("HBaseBulkDeleteExample {tableName} missing an argument") + return + } + + val tableName = args(0) + + val sparkConf = new SparkConf().setAppName("HBaseBulkDeleteExample " + tableName) + val sc = new SparkContext(sparkConf) + try { + // [Array[Byte]] + val rdd = sc.parallelize( + Array[Array[Byte]]( + Bytes.toBytes("1"), + Bytes.toBytes("2"), + Bytes.toBytes("3"), + Bytes.toBytes("4"), + Bytes.toBytes("5"))) + + val conf = HBaseConfiguration.create() + + val hbaseContext = new HBaseContext(sc, conf) + hbaseContext.bulkDelete[Array[Byte]]( + rdd, + TableName.valueOf(tableName), + putRecord => new Delete(putRecord), + 4) + } finally { + sc.stop() + } + } +} diff --git a/spark4/hbase-spark4/src/main/scala/org/apache/hadoop/hbase/spark/example/hbasecontext/HBaseBulkGetExample.scala b/spark4/hbase-spark4/src/main/scala/org/apache/hadoop/hbase/spark/example/hbasecontext/HBaseBulkGetExample.scala new file mode 100644 index 00000000..cae3dc11 --- /dev/null +++ b/spark4/hbase-spark4/src/main/scala/org/apache/hadoop/hbase/spark/example/hbasecontext/HBaseBulkGetExample.scala @@ -0,0 +1,100 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.hadoop.hbase.spark.example.hbasecontext + +import org.apache.hadoop.hbase.CellUtil +import org.apache.hadoop.hbase.HBaseConfiguration +import org.apache.hadoop.hbase.TableName +import org.apache.hadoop.hbase.client.Get +import org.apache.hadoop.hbase.client.Result +import org.apache.hadoop.hbase.spark.HBaseContext +import org.apache.hadoop.hbase.util.Bytes +import org.apache.spark.SparkConf +import org.apache.spark.SparkContext +import org.apache.yetus.audience.InterfaceAudience + +/** + * This is a simple example of getting records from HBase + * with the bulkGet function. + */ +@InterfaceAudience.Private +object HBaseBulkGetExample { + def main(args: Array[String]) { + if (args.length < 1) { + println("HBaseBulkGetExample {tableName} missing an argument") + return + } + + val tableName = args(0) + + val sparkConf = new SparkConf().setAppName("HBaseBulkGetExample " + tableName) + val sc = new SparkContext(sparkConf) + + try { + + // [(Array[Byte])] + val rdd = sc.parallelize( + Array[Array[Byte]]( + Bytes.toBytes("1"), + Bytes.toBytes("2"), + Bytes.toBytes("3"), + Bytes.toBytes("4"), + Bytes.toBytes("5"), + Bytes.toBytes("6"), + Bytes.toBytes("7"))) + + val conf = HBaseConfiguration.create() + + val hbaseContext = new HBaseContext(sc, conf) + + val getRdd = hbaseContext.bulkGet[Array[Byte], String]( + TableName.valueOf(tableName), + 2, + rdd, + record => { + System.out.println("making Get") + new Get(record) + }, + (result: Result) => { + + val it = result.listCells().iterator() + val b = new StringBuilder + + b.append(Bytes.toString(result.getRow) + ":") + + while (it.hasNext) { + val cell = it.next() + val q = Bytes.toString(CellUtil.cloneQualifier(cell)) + if (q.equals("counter")) { + b.append("(" + q + "," + Bytes.toLong(CellUtil.cloneValue(cell)) + ")") + } else { + b.append("(" + q + "," + Bytes.toString(CellUtil.cloneValue(cell)) + ")") + } + } + b.toString() + }) + + getRdd + .collect() + .foreach(v => println(v)) + + } finally { + sc.stop() + } + } +} diff --git a/spark4/hbase-spark4/src/main/scala/org/apache/hadoop/hbase/spark/example/hbasecontext/HBaseBulkPutExample.scala b/spark4/hbase-spark4/src/main/scala/org/apache/hadoop/hbase/spark/example/hbasecontext/HBaseBulkPutExample.scala new file mode 100644 index 00000000..59e2d298 --- /dev/null +++ b/spark4/hbase-spark4/src/main/scala/org/apache/hadoop/hbase/spark/example/hbasecontext/HBaseBulkPutExample.scala @@ -0,0 +1,84 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.hadoop.hbase.spark.example.hbasecontext + +import org.apache.hadoop.hbase.HBaseConfiguration +import org.apache.hadoop.hbase.TableName +import org.apache.hadoop.hbase.client.Put +import org.apache.hadoop.hbase.spark.HBaseContext +import org.apache.hadoop.hbase.util.Bytes +import org.apache.spark.SparkConf +import org.apache.spark.SparkContext +import org.apache.yetus.audience.InterfaceAudience + +/** + * This is a simple example of putting records in HBase + * with the bulkPut function. + */ +@InterfaceAudience.Private +object HBaseBulkPutExample { + def main(args: Array[String]) { + if (args.length < 2) { + println("HBaseBulkPutExample {tableName} {columnFamily} are missing an arguments") + return + } + + val tableName = args(0) + val columnFamily = args(1) + + val sparkConf = new SparkConf().setAppName( + "HBaseBulkPutExample " + + tableName + " " + columnFamily) + val sc = new SparkContext(sparkConf) + + try { + // [(Array[Byte], Array[(Array[Byte], Array[Byte], Array[Byte])])] + val rdd = sc.parallelize( + Array[(Array[Byte], Array[(Array[Byte], Array[Byte], Array[Byte])])]( + ( + Bytes.toBytes("1"), + Array((Bytes.toBytes(columnFamily), Bytes.toBytes("1"), Bytes.toBytes("1")))), + ( + Bytes.toBytes("2"), + Array((Bytes.toBytes(columnFamily), Bytes.toBytes("1"), Bytes.toBytes("2")))), + ( + Bytes.toBytes("3"), + Array((Bytes.toBytes(columnFamily), Bytes.toBytes("1"), Bytes.toBytes("3")))), + ( + Bytes.toBytes("4"), + Array((Bytes.toBytes(columnFamily), Bytes.toBytes("1"), Bytes.toBytes("4")))), + ( + Bytes.toBytes("5"), + Array((Bytes.toBytes(columnFamily), Bytes.toBytes("1"), Bytes.toBytes("5")))))) + + val conf = HBaseConfiguration.create() + + val hbaseContext = new HBaseContext(sc, conf) + hbaseContext.bulkPut[(Array[Byte], Array[(Array[Byte], Array[Byte], Array[Byte])])]( + rdd, + TableName.valueOf(tableName), + (putRecord) => { + val put = new Put(putRecord._1) + putRecord._2.foreach((putValue) => put.addColumn(putValue._1, putValue._2, putValue._3)) + put + }); + } finally { + sc.stop() + } + } +} diff --git a/spark4/hbase-spark4/src/main/scala/org/apache/hadoop/hbase/spark/example/hbasecontext/HBaseBulkPutExampleFromFile.scala b/spark4/hbase-spark4/src/main/scala/org/apache/hadoop/hbase/spark/example/hbasecontext/HBaseBulkPutExampleFromFile.scala new file mode 100644 index 00000000..e7895763 --- /dev/null +++ b/spark4/hbase-spark4/src/main/scala/org/apache/hadoop/hbase/spark/example/hbasecontext/HBaseBulkPutExampleFromFile.scala @@ -0,0 +1,80 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.hadoop.hbase.spark.example.hbasecontext + +import org.apache.hadoop.hbase.HBaseConfiguration +import org.apache.hadoop.hbase.TableName +import org.apache.hadoop.hbase.client.Put +import org.apache.hadoop.hbase.spark.HBaseContext +import org.apache.hadoop.hbase.util.Bytes +import org.apache.hadoop.io.LongWritable +import org.apache.hadoop.io.Text +import org.apache.hadoop.mapred.TextInputFormat +import org.apache.spark.SparkConf +import org.apache.spark.SparkContext +import org.apache.yetus.audience.InterfaceAudience + +/** + * This is a simple example of putting records in HBase + * with the bulkPut function. In this example we are + * getting the put information from a file + */ +@InterfaceAudience.Private +object HBaseBulkPutExampleFromFile { + def main(args: Array[String]) { + if (args.length < 3) { + println( + "HBaseBulkPutExampleFromFile {tableName} {columnFamily} {inputFile} are missing an argument") + return + } + + val tableName = args(0) + val columnFamily = args(1) + val inputFile = args(2) + + val sparkConf = new SparkConf().setAppName( + "HBaseBulkPutExampleFromFile " + + tableName + " " + columnFamily + " " + inputFile) + val sc = new SparkContext(sparkConf) + + try { + var rdd = sc + .hadoopFile(inputFile, classOf[TextInputFormat], classOf[LongWritable], classOf[Text]) + .map( + v => { + System.out.println("reading-" + v._2.toString) + v._2.toString + }) + + val conf = HBaseConfiguration.create() + + val hbaseContext = new HBaseContext(sc, conf) + hbaseContext.bulkPut[String]( + rdd, + TableName.valueOf(tableName), + (putRecord) => { + System.out.println("hbase-" + putRecord) + val put = new Put(Bytes.toBytes("Value- " + putRecord)) + put.addColumn(Bytes.toBytes("c"), Bytes.toBytes("1"), Bytes.toBytes(putRecord.length())) + put + }); + } finally { + sc.stop() + } + } +} diff --git a/spark4/hbase-spark4/src/main/scala/org/apache/hadoop/hbase/spark/example/hbasecontext/HBaseBulkPutTimestampExample.scala b/spark4/hbase-spark4/src/main/scala/org/apache/hadoop/hbase/spark/example/hbasecontext/HBaseBulkPutTimestampExample.scala new file mode 100644 index 00000000..5e84f2e2 --- /dev/null +++ b/spark4/hbase-spark4/src/main/scala/org/apache/hadoop/hbase/spark/example/hbasecontext/HBaseBulkPutTimestampExample.scala @@ -0,0 +1,88 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.hadoop.hbase.spark.example.hbasecontext + +import org.apache.hadoop.hbase.{HBaseConfiguration, TableName} +import org.apache.hadoop.hbase.client.Put +import org.apache.hadoop.hbase.spark.HBaseContext +import org.apache.hadoop.hbase.util.Bytes +import org.apache.spark.SparkConf +import org.apache.spark.SparkContext +import org.apache.yetus.audience.InterfaceAudience + +/** + * This is a simple example of putting records in HBase + * with the bulkPut function. In this example we are + * also setting the timestamp in the put + */ +@InterfaceAudience.Private +object HBaseBulkPutTimestampExample { + def main(args: Array[String]) { + if (args.length < 2) { + System.out.println( + "HBaseBulkPutTimestampExample {tableName} {columnFamily} are missing an argument") + return + } + + val tableName = args(0) + val columnFamily = args(1) + + val sparkConf = new SparkConf().setAppName( + "HBaseBulkPutTimestampExample " + + tableName + " " + columnFamily) + val sc = new SparkContext(sparkConf) + + try { + + val rdd = sc.parallelize( + Array[(Array[Byte], Array[(Array[Byte], Array[Byte], Array[Byte])])]( + ( + Bytes.toBytes("6"), + Array((Bytes.toBytes(columnFamily), Bytes.toBytes("1"), Bytes.toBytes("1")))), + ( + Bytes.toBytes("7"), + Array((Bytes.toBytes(columnFamily), Bytes.toBytes("1"), Bytes.toBytes("2")))), + ( + Bytes.toBytes("8"), + Array((Bytes.toBytes(columnFamily), Bytes.toBytes("1"), Bytes.toBytes("3")))), + ( + Bytes.toBytes("9"), + Array((Bytes.toBytes(columnFamily), Bytes.toBytes("1"), Bytes.toBytes("4")))), + ( + Bytes.toBytes("10"), + Array((Bytes.toBytes(columnFamily), Bytes.toBytes("1"), Bytes.toBytes("5")))))) + + val conf = HBaseConfiguration.create() + + val timeStamp = System.currentTimeMillis() + + val hbaseContext = new HBaseContext(sc, conf) + hbaseContext.bulkPut[(Array[Byte], Array[(Array[Byte], Array[Byte], Array[Byte])])]( + rdd, + TableName.valueOf(tableName), + (putRecord) => { + val put = new Put(putRecord._1) + putRecord._2.foreach( + (putValue) => put.addColumn(putValue._1, putValue._2, timeStamp, putValue._3)) + put + }) + } finally { + sc.stop() + } + } +} diff --git a/spark4/hbase-spark4/src/main/scala/org/apache/hadoop/hbase/spark/example/hbasecontext/HBaseDistributedScanExample.scala b/spark4/hbase-spark4/src/main/scala/org/apache/hadoop/hbase/spark/example/hbasecontext/HBaseDistributedScanExample.scala new file mode 100644 index 00000000..2f3619b2 --- /dev/null +++ b/spark4/hbase-spark4/src/main/scala/org/apache/hadoop/hbase/spark/example/hbasecontext/HBaseDistributedScanExample.scala @@ -0,0 +1,68 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.hadoop.hbase.spark.example.hbasecontext + +import org.apache.hadoop.hbase.HBaseConfiguration +import org.apache.hadoop.hbase.TableName +import org.apache.hadoop.hbase.client.Scan +import org.apache.hadoop.hbase.spark.HBaseContext +import org.apache.hadoop.hbase.util.Bytes +import org.apache.spark.SparkConf +import org.apache.spark.SparkContext +import org.apache.yetus.audience.InterfaceAudience + +/** + * This is a simple example of scanning records from HBase + * with the hbaseRDD function in Distributed fashion. + */ +@InterfaceAudience.Private +object HBaseDistributedScanExample { + def main(args: Array[String]) { + if (args.length < 1) { + println("HBaseDistributedScanExample {tableName} missing an argument") + return + } + + val tableName = args(0) + + val sparkConf = new SparkConf().setAppName("HBaseDistributedScanExample " + tableName) + val sc = new SparkContext(sparkConf) + + try { + val conf = HBaseConfiguration.create() + + val hbaseContext = new HBaseContext(sc, conf) + + val scan = new Scan() + scan.setCaching(100) + + val getRdd = hbaseContext.hbaseRDD(TableName.valueOf(tableName), scan) + + getRdd.foreach(v => println(Bytes.toString(v._1.get()))) + + println( + "Length: " + getRdd + .map(r => r._1.copyBytes()) + .collect() + .length); + } finally { + sc.stop() + } + } + +} diff --git a/spark4/hbase-spark4/src/main/scala/org/apache/hadoop/hbase/spark/example/hbasecontext/HBaseStreamingBulkPutExample.scala b/spark4/hbase-spark4/src/main/scala/org/apache/hadoop/hbase/spark/example/hbasecontext/HBaseStreamingBulkPutExample.scala new file mode 100644 index 00000000..4774a066 --- /dev/null +++ b/spark4/hbase-spark4/src/main/scala/org/apache/hadoop/hbase/spark/example/hbasecontext/HBaseStreamingBulkPutExample.scala @@ -0,0 +1,80 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.hadoop.hbase.spark.example.hbasecontext + +import org.apache.hadoop.hbase.HBaseConfiguration +import org.apache.hadoop.hbase.TableName +import org.apache.hadoop.hbase.client.Put +import org.apache.hadoop.hbase.spark.HBaseContext +import org.apache.hadoop.hbase.util.Bytes +import org.apache.spark.SparkConf +import org.apache.spark.SparkContext +import org.apache.spark.streaming.Seconds +import org.apache.spark.streaming.StreamingContext +import org.apache.yetus.audience.InterfaceAudience + +/** + * This is a simple example of BulkPut with Spark Streaming + */ +@InterfaceAudience.Private +object HBaseStreamingBulkPutExample { + def main(args: Array[String]) { + if (args.length < 4) { + println( + "HBaseStreamingBulkPutExample " + + "{host} {port} {tableName} {columnFamily} are missing an argument") + return + } + + val host = args(0) + val port = args(1) + val tableName = args(2) + val columnFamily = args(3) + + val sparkConf = new SparkConf().setAppName( + "HBaseStreamingBulkPutExample " + + tableName + " " + columnFamily) + val sc = new SparkContext(sparkConf) + try { + val ssc = new StreamingContext(sc, Seconds(1)) + + val lines = ssc.socketTextStream(host, port.toInt) + + val conf = HBaseConfiguration.create() + + val hbaseContext = new HBaseContext(sc, conf) + + hbaseContext.streamBulkPut[String]( + lines, + TableName.valueOf(tableName), + (putRecord) => { + if (putRecord.length() > 0) { + val put = new Put(Bytes.toBytes(putRecord)) + put.addColumn(Bytes.toBytes("c"), Bytes.toBytes("foo"), Bytes.toBytes("bar")) + put + } else { + null + } + }) + ssc.start() + ssc.awaitTerminationOrTimeout(60000) + } finally { + sc.stop() + } + } +} diff --git a/spark4/hbase-spark4/src/main/scala/org/apache/hadoop/hbase/spark/example/rdd/HBaseBulkDeleteExample.scala b/spark4/hbase-spark4/src/main/scala/org/apache/hadoop/hbase/spark/example/rdd/HBaseBulkDeleteExample.scala new file mode 100644 index 00000000..1498df05 --- /dev/null +++ b/spark4/hbase-spark4/src/main/scala/org/apache/hadoop/hbase/spark/example/rdd/HBaseBulkDeleteExample.scala @@ -0,0 +1,70 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.hadoop.hbase.spark.example.rdd + +import org.apache.hadoop.hbase.HBaseConfiguration +import org.apache.hadoop.hbase.TableName +import org.apache.hadoop.hbase.client.Delete +import org.apache.hadoop.hbase.spark.HBaseContext +import org.apache.hadoop.hbase.spark.HBaseRDDFunctions._ +import org.apache.hadoop.hbase.util.Bytes +import org.apache.spark.SparkConf +import org.apache.spark.SparkContext +import org.apache.yetus.audience.InterfaceAudience + +/** + * This is a simple example of deleting records in HBase + * with the bulkDelete function. 
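+ * The row keys are parallelized as an RDD of byte arrays and sent to HBase as Deletes in batches of 4 through the implicit hbaseBulkDelete function from HBaseRDDFunctions.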
+ */ +@InterfaceAudience.Private +object HBaseBulkDeleteExample { + def main(args: Array[String]) { + if (args.length < 1) { + println("HBaseBulkDeleteExample {tableName} are missing an argument") + return + } + + val tableName = args(0) + + val sparkConf = new SparkConf().setAppName("HBaseBulkDeleteExample " + tableName) + val sc = new SparkContext(sparkConf) + try { + // [Array[Byte]] + val rdd = sc.parallelize( + Array[Array[Byte]]( + Bytes.toBytes("1"), + Bytes.toBytes("2"), + Bytes.toBytes("3"), + Bytes.toBytes("4"), + Bytes.toBytes("5"))) + + val conf = HBaseConfiguration.create() + + val hbaseContext = new HBaseContext(sc, conf) + + rdd.hbaseBulkDelete( + hbaseContext, + TableName.valueOf(tableName), + putRecord => new Delete(putRecord), + 4) + + } finally { + sc.stop() + } + } +} diff --git a/spark4/hbase-spark4/src/main/scala/org/apache/hadoop/hbase/spark/example/rdd/HBaseBulkGetExample.scala b/spark4/hbase-spark4/src/main/scala/org/apache/hadoop/hbase/spark/example/rdd/HBaseBulkGetExample.scala new file mode 100644 index 00000000..e57c7690 --- /dev/null +++ b/spark4/hbase-spark4/src/main/scala/org/apache/hadoop/hbase/spark/example/rdd/HBaseBulkGetExample.scala @@ -0,0 +1,101 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.hadoop.hbase.spark.example.rdd + +import org.apache.hadoop.hbase.CellUtil +import org.apache.hadoop.hbase.HBaseConfiguration +import org.apache.hadoop.hbase.TableName +import org.apache.hadoop.hbase.client.Get +import org.apache.hadoop.hbase.client.Result +import org.apache.hadoop.hbase.spark.HBaseContext +import org.apache.hadoop.hbase.spark.HBaseRDDFunctions._ +import org.apache.hadoop.hbase.util.Bytes +import org.apache.spark.SparkConf +import org.apache.spark.SparkContext +import org.apache.yetus.audience.InterfaceAudience + +/** + * This is a simple example of getting records from HBase + * with the bulkGet function. 
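+ * Each row key in the RDD is turned into a Get (batch size 2), and every Result is rendered as a rowKey:(qualifier,value) string by the supplied conversion function.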
+ */ +@InterfaceAudience.Private +object HBaseBulkGetExample { + def main(args: Array[String]) { + if (args.length < 1) { + println("HBaseBulkGetExample {tableName} is missing an argument") + return + } + + val tableName = args(0) + + val sparkConf = new SparkConf().setAppName("HBaseBulkGetExample " + tableName) + val sc = new SparkContext(sparkConf) + + try { + + // [(Array[Byte])] + val rdd = sc.parallelize( + Array[Array[Byte]]( + Bytes.toBytes("1"), + Bytes.toBytes("2"), + Bytes.toBytes("3"), + Bytes.toBytes("4"), + Bytes.toBytes("5"), + Bytes.toBytes("6"), + Bytes.toBytes("7"))) + + val conf = HBaseConfiguration.create() + + val hbaseContext = new HBaseContext(sc, conf) + + val getRdd = rdd.hbaseBulkGet[String]( + hbaseContext, + TableName.valueOf(tableName), + 2, + record => { + System.out.println("making Get") + new Get(record) + }, + (result: Result) => { + + val it = result.listCells().iterator() + val b = new StringBuilder + + b.append(Bytes.toString(result.getRow) + ":") + + while (it.hasNext) { + val cell = it.next() + val q = Bytes.toString(CellUtil.cloneQualifier(cell)) + if (q.equals("counter")) { + b.append("(" + q + "," + Bytes.toLong(CellUtil.cloneValue(cell)) + ")") + } else { + b.append("(" + q + "," + Bytes.toString(CellUtil.cloneValue(cell)) + ")") + } + } + b.toString() + }) + + getRdd + .collect() + .foreach(v => println(v)) + + } finally { + sc.stop() + } + } +} diff --git a/spark4/hbase-spark4/src/main/scala/org/apache/hadoop/hbase/spark/example/rdd/HBaseBulkPutExample.scala b/spark4/hbase-spark4/src/main/scala/org/apache/hadoop/hbase/spark/example/rdd/HBaseBulkPutExample.scala new file mode 100644 index 00000000..6bc7ac2c --- /dev/null +++ b/spark4/hbase-spark4/src/main/scala/org/apache/hadoop/hbase/spark/example/rdd/HBaseBulkPutExample.scala @@ -0,0 +1,87 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.hadoop.hbase.spark.example.rdd + +import org.apache.hadoop.hbase.HBaseConfiguration +import org.apache.hadoop.hbase.TableName +import org.apache.hadoop.hbase.client.Put +import org.apache.hadoop.hbase.spark.HBaseContext +import org.apache.hadoop.hbase.spark.HBaseRDDFunctions._ +import org.apache.hadoop.hbase.util.Bytes +import org.apache.spark.SparkConf +import org.apache.spark.SparkContext +import org.apache.yetus.audience.InterfaceAudience + +/** + * This is a simple example of putting records in HBase + * with the bulkPut function. 
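+ * Each RDD element is a (rowKey, Array[(family, qualifier, value)]) tuple that is converted into a single Put per row key.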
+ */ +@InterfaceAudience.Private +object HBaseBulkPutExample { + def main(args: Array[String]) { + if (args.length < 2) { + println("HBaseBulkPutExample {tableName} {columnFamily} are missing an arguments") + return + } + + val tableName = args(0) + val columnFamily = args(1) + + val sparkConf = new SparkConf().setAppName( + "HBaseBulkPutExample " + + tableName + " " + columnFamily) + val sc = new SparkContext(sparkConf) + + try { + // [(Array[Byte], Array[(Array[Byte], Array[Byte], Array[Byte])])] + val rdd = sc.parallelize( + Array[(Array[Byte], Array[(Array[Byte], Array[Byte], Array[Byte])])]( + ( + Bytes.toBytes("1"), + Array((Bytes.toBytes(columnFamily), Bytes.toBytes("1"), Bytes.toBytes("1")))), + ( + Bytes.toBytes("2"), + Array((Bytes.toBytes(columnFamily), Bytes.toBytes("1"), Bytes.toBytes("2")))), + ( + Bytes.toBytes("3"), + Array((Bytes.toBytes(columnFamily), Bytes.toBytes("1"), Bytes.toBytes("3")))), + ( + Bytes.toBytes("4"), + Array((Bytes.toBytes(columnFamily), Bytes.toBytes("1"), Bytes.toBytes("4")))), + ( + Bytes.toBytes("5"), + Array((Bytes.toBytes(columnFamily), Bytes.toBytes("1"), Bytes.toBytes("5")))))) + + val conf = HBaseConfiguration.create() + + val hbaseContext = new HBaseContext(sc, conf) + + rdd.hbaseBulkPut( + hbaseContext, + TableName.valueOf(tableName), + (putRecord) => { + val put = new Put(putRecord._1) + putRecord._2.foreach((putValue) => put.addColumn(putValue._1, putValue._2, putValue._3)) + put + }) + + } finally { + sc.stop() + } + } +} diff --git a/spark4/hbase-spark4/src/main/scala/org/apache/hadoop/hbase/spark/example/rdd/HBaseForeachPartitionExample.scala b/spark4/hbase-spark4/src/main/scala/org/apache/hadoop/hbase/spark/example/rdd/HBaseForeachPartitionExample.scala new file mode 100644 index 00000000..a952c548 --- /dev/null +++ b/spark4/hbase-spark4/src/main/scala/org/apache/hadoop/hbase/spark/example/rdd/HBaseForeachPartitionExample.scala @@ -0,0 +1,93 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+package org.apache.hadoop.hbase.spark.example.rdd
+
+import org.apache.hadoop.hbase.HBaseConfiguration
+import org.apache.hadoop.hbase.TableName
+import org.apache.hadoop.hbase.client.Put
+import org.apache.hadoop.hbase.spark.HBaseContext
+import org.apache.hadoop.hbase.spark.HBaseRDDFunctions._
+import org.apache.hadoop.hbase.util.Bytes
+import org.apache.spark.SparkConf
+import org.apache.spark.SparkContext
+import org.apache.yetus.audience.InterfaceAudience
+
+/**
+ * This is a simple example of using the foreachPartition
+ * method with an HBase connection
+ */
+@InterfaceAudience.Private
+object HBaseForeachPartitionExample {
+  def main(args: Array[String]) {
+    if (args.length < 2) {
+      println("HBaseForeachPartitionExample {tableName} {columnFamily} are missing arguments")
+      return
+    }
+
+    val tableName = args(0)
+    val columnFamily = args(1)
+
+    val sparkConf = new SparkConf().setAppName(
+      "HBaseForeachPartitionExample " +
+        tableName + " " + columnFamily)
+    val sc = new SparkContext(sparkConf)
+
+    try {
+      // [(Array[Byte], Array[(Array[Byte], Array[Byte], Array[Byte])])]
+      val rdd = sc.parallelize(
+        Array[(Array[Byte], Array[(Array[Byte], Array[Byte], Array[Byte])])](
+          (
+            Bytes.toBytes("1"),
+            Array((Bytes.toBytes(columnFamily), Bytes.toBytes("1"), Bytes.toBytes("1")))),
+          (
+            Bytes.toBytes("2"),
+            Array((Bytes.toBytes(columnFamily), Bytes.toBytes("1"), Bytes.toBytes("2")))),
+          (
+            Bytes.toBytes("3"),
+            Array((Bytes.toBytes(columnFamily), Bytes.toBytes("1"), Bytes.toBytes("3")))),
+          (
+            Bytes.toBytes("4"),
+            Array((Bytes.toBytes(columnFamily), Bytes.toBytes("1"), Bytes.toBytes("4")))),
+          (
+            Bytes.toBytes("5"),
+            Array((Bytes.toBytes(columnFamily), Bytes.toBytes("1"), Bytes.toBytes("5"))))))
+
+      val conf = HBaseConfiguration.create()
+
+      val hbaseContext = new HBaseContext(sc, conf)
+
+      rdd.hbaseForeachPartition(
+        hbaseContext,
+        (it, connection) => {
+          val m = connection.getBufferedMutator(TableName.valueOf(tableName))
+
+          it.foreach(
+            r => {
+              val put = new Put(r._1)
+              r._2.foreach((putValue) => put.addColumn(putValue._1, putValue._2, putValue._3))
+              m.mutate(put)
+            })
+          m.flush()
+          m.close()
+        })
+
+    } finally {
+      sc.stop()
+    }
+  }
+}
diff --git a/spark4/hbase-spark4/src/main/scala/org/apache/hadoop/hbase/spark/example/rdd/HBaseMapPartitionExample.scala b/spark4/hbase-spark4/src/main/scala/org/apache/hadoop/hbase/spark/example/rdd/HBaseMapPartitionExample.scala
new file mode 100644
index 00000000..cac41d21
--- /dev/null
+++ b/spark4/hbase-spark4/src/main/scala/org/apache/hadoop/hbase/spark/example/rdd/HBaseMapPartitionExample.scala
@@ -0,0 +1,99 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.hadoop.hbase.spark.example.rdd
+
+import org.apache.hadoop.hbase.CellUtil
+import org.apache.hadoop.hbase.HBaseConfiguration
+import org.apache.hadoop.hbase.TableName
+import org.apache.hadoop.hbase.client.Get
+import org.apache.hadoop.hbase.spark.HBaseContext
+import org.apache.hadoop.hbase.spark.HBaseRDDFunctions._
+import org.apache.hadoop.hbase.util.Bytes
+import org.apache.spark.SparkConf
+import org.apache.spark.SparkContext
+import org.apache.yetus.audience.InterfaceAudience
+
+/**
+ * This is a simple example of using the mapPartitions
+ * method with an HBase connection
+ */
+@InterfaceAudience.Private
+object HBaseMapPartitionExample {
+  def main(args: Array[String]) {
+    if (args.length < 1) {
+      println("HBaseMapPartitionExample {tableName} is missing an argument")
+      return
+    }
+
+    val tableName = args(0)
+
+    val sparkConf = new SparkConf().setAppName("HBaseMapPartitionExample " + tableName)
+    val sc = new SparkContext(sparkConf)
+
+    try {
+
+      // [(Array[Byte])]
+      val rdd = sc.parallelize(
+        Array[Array[Byte]](
+          Bytes.toBytes("1"),
+          Bytes.toBytes("2"),
+          Bytes.toBytes("3"),
+          Bytes.toBytes("4"),
+          Bytes.toBytes("5"),
+          Bytes.toBytes("6"),
+          Bytes.toBytes("7")))
+
+      val conf = HBaseConfiguration.create()
+
+      val hbaseContext = new HBaseContext(sc, conf)
+
+      val getRdd = rdd.hbaseMapPartitions[String](
+        hbaseContext,
+        (it, connection) => {
+          val table = connection.getTable(TableName.valueOf(tableName))
+          it.map {
+            r =>
+              // batching would be faster. This is just an example
+              val result = table.get(new Get(r))
+
+              val it = result.listCells().iterator()
+              val b = new StringBuilder
+
+              b.append(Bytes.toString(result.getRow) + ":")
+
+              while (it.hasNext) {
+                val cell = it.next()
+                val q = Bytes.toString(CellUtil.cloneQualifier(cell))
+                if (q.equals("counter")) {
+                  b.append("(" + q + "," + Bytes.toLong(CellUtil.cloneValue(cell)) + ")")
+                } else {
+                  b.append("(" + q + "," + Bytes.toString(CellUtil.cloneValue(cell)) + ")")
+                }
+              }
+              b.toString()
+          }
+        })
+
+      getRdd
+        .collect()
+        .foreach(v => println(v))
+
+    } finally {
+      sc.stop()
+    }
+  }
+}
diff --git a/spark4/hbase-spark4/src/test/java/org/apache/hadoop/hbase/spark/TestJavaHBaseContext.java b/spark4/hbase-spark4/src/test/java/org/apache/hadoop/hbase/spark/TestJavaHBaseContext.java
new file mode 100644
index 00000000..0f8c4aa7
--- /dev/null
+++ b/spark4/hbase-spark4/src/test/java/org/apache/hadoop/hbase/spark/TestJavaHBaseContext.java
@@ -0,0 +1,515 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */ +package org.apache.hadoop.hbase.spark; + +import java.io.IOException; +import java.io.Serializable; +import java.util.ArrayList; +import java.util.HashMap; +import java.util.Iterator; +import java.util.List; +import org.apache.hadoop.conf.Configuration; +import org.apache.hadoop.fs.Path; +import org.apache.hadoop.hbase.Cell; +import org.apache.hadoop.hbase.CellUtil; +import org.apache.hadoop.hbase.HBaseClassTestRule; +import org.apache.hadoop.hbase.HBaseTestingUtility; +import org.apache.hadoop.hbase.HConstants; +import org.apache.hadoop.hbase.TableName; +import org.apache.hadoop.hbase.client.Admin; +import org.apache.hadoop.hbase.client.Connection; +import org.apache.hadoop.hbase.client.ConnectionFactory; +import org.apache.hadoop.hbase.client.Delete; +import org.apache.hadoop.hbase.client.Get; +import org.apache.hadoop.hbase.client.Put; +import org.apache.hadoop.hbase.client.Result; +import org.apache.hadoop.hbase.client.Scan; +import org.apache.hadoop.hbase.client.Table; +import org.apache.hadoop.hbase.io.ImmutableBytesWritable; +import org.apache.hadoop.hbase.spark.example.hbasecontext.JavaHBaseBulkDeleteExample; +import org.apache.hadoop.hbase.testclassification.MediumTests; +import org.apache.hadoop.hbase.testclassification.MiscTests; +import org.apache.hadoop.hbase.tool.LoadIncrementalHFiles; +import org.apache.hadoop.hbase.util.Bytes; +import org.apache.hadoop.hbase.util.Pair; +import org.apache.spark.SparkConf; +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.api.java.function.Function; +import org.junit.After; +import org.junit.AfterClass; +import org.junit.Assert; +import org.junit.Before; +import org.junit.BeforeClass; +import org.junit.ClassRule; +import org.junit.Test; +import org.junit.experimental.categories.Category; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import scala.Tuple2; + +@Category({ MiscTests.class, MediumTests.class }) +public class TestJavaHBaseContext implements Serializable { + + @ClassRule + public static final HBaseClassTestRule TIMEOUT = + HBaseClassTestRule.forClass(TestJavaHBaseContext.class); + + protected static transient JavaSparkContext JSC; + private static HBaseTestingUtility TEST_UTIL; + private static JavaHBaseContext HBASE_CONTEXT; + private static final Logger LOG = LoggerFactory.getLogger(TestJavaHBaseContext.class); + + protected byte[] tableName = Bytes.toBytes("t1"); + protected byte[] columnFamily = Bytes.toBytes("c"); + byte[] columnFamily1 = Bytes.toBytes("d"); + String columnFamilyStr = Bytes.toString(columnFamily); + String columnFamilyStr1 = Bytes.toString(columnFamily1); + + @BeforeClass + public static void setUpBeforeClass() throws Exception { + // NOTE: We need to do this due to behaviour change in spark 3.2, where the below conf is true + // by default. 
We will get empty table as result (for small sized tables) for HBase version not + // having HBASE-26340 + SparkConf sparkConf = new SparkConf().set("spark.hadoopRDD.ignoreEmptySplits", "false"); + JSC = new JavaSparkContext("local", "JavaHBaseContextSuite", sparkConf); + + init(); + } + + protected static void init() throws Exception { + TEST_UTIL = new HBaseTestingUtility(); + Configuration conf = TEST_UTIL.getConfiguration(); + + HBASE_CONTEXT = new JavaHBaseContext(JSC, conf); + + LOG.info("cleaning up test dir"); + + TEST_UTIL.cleanupTestDir(); + + LOG.info("starting minicluster"); + + TEST_UTIL.startMiniCluster(); + + LOG.info(" - minicluster started"); + } + + @AfterClass + public static void tearDownAfterClass() throws Exception { + LOG.info("shuting down minicluster"); + TEST_UTIL.shutdownMiniCluster(); + LOG.info(" - minicluster shut down"); + TEST_UTIL.cleanupTestDir(); + + JSC.stop(); + JSC = null; + } + + @Before + public void setUp() throws Exception { + + try { + TEST_UTIL.deleteTable(TableName.valueOf(tableName)); + } catch (Exception e) { + LOG.info(" - no table {} found", Bytes.toString(tableName)); + } + + LOG.info(" - creating table {}", Bytes.toString(tableName)); + TEST_UTIL.createTable(TableName.valueOf(tableName), + new byte[][] { columnFamily, columnFamily1 }); + LOG.info(" - created table"); + } + + @After + public void tearDown() throws Exception { + TEST_UTIL.deleteTable(TableName.valueOf(tableName)); + } + + @Test + public void testBulkPut() throws IOException { + + List list = new ArrayList<>(5); + list.add("1," + columnFamilyStr + ",a,1"); + list.add("2," + columnFamilyStr + ",a,2"); + list.add("3," + columnFamilyStr + ",a,3"); + list.add("4," + columnFamilyStr + ",a,4"); + list.add("5," + columnFamilyStr + ",a,5"); + + JavaRDD rdd = JSC.parallelize(list); + + Configuration conf = TEST_UTIL.getConfiguration(); + + Connection conn = ConnectionFactory.createConnection(conf); + Table table = conn.getTable(TableName.valueOf(tableName)); + + try { + List deletes = new ArrayList<>(5); + for (int i = 1; i < 6; i++) { + deletes.add(new Delete(Bytes.toBytes(Integer.toString(i)))); + } + table.delete(deletes); + } finally { + table.close(); + } + + HBASE_CONTEXT.bulkPut(rdd, TableName.valueOf(tableName), new PutFunction()); + + table = conn.getTable(TableName.valueOf(tableName)); + + try { + Result result1 = table.get(new Get(Bytes.toBytes("1"))); + Assert.assertNotNull("Row 1 should had been deleted", result1.getRow()); + + Result result2 = table.get(new Get(Bytes.toBytes("2"))); + Assert.assertNotNull("Row 2 should had been deleted", result2.getRow()); + + Result result3 = table.get(new Get(Bytes.toBytes("3"))); + Assert.assertNotNull("Row 3 should had been deleted", result3.getRow()); + + Result result4 = table.get(new Get(Bytes.toBytes("4"))); + Assert.assertNotNull("Row 4 should had been deleted", result4.getRow()); + + Result result5 = table.get(new Get(Bytes.toBytes("5"))); + Assert.assertNotNull("Row 5 should had been deleted", result5.getRow()); + } finally { + table.close(); + conn.close(); + } + } + + public static class PutFunction implements Function { + + private static final long serialVersionUID = 1L; + + @Override + public Put call(String v) throws Exception { + String[] cells = v.split(","); + Put put = new Put(Bytes.toBytes(cells[0])); + + put.addColumn(Bytes.toBytes(cells[1]), Bytes.toBytes(cells[2]), Bytes.toBytes(cells[3])); + return put; + } + } + + @Test + public void testBulkDelete() throws IOException { + List list = new ArrayList<>(3); + 
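// Rows 1-3 are deleted below; rows 4 and 5 (written by populateTableWithMockData) should remain. +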
list.add(Bytes.toBytes("1")); + list.add(Bytes.toBytes("2")); + list.add(Bytes.toBytes("3")); + + JavaRDD rdd = JSC.parallelize(list); + + Configuration conf = TEST_UTIL.getConfiguration(); + + populateTableWithMockData(conf, TableName.valueOf(tableName)); + + HBASE_CONTEXT.bulkDelete(rdd, TableName.valueOf(tableName), + new JavaHBaseBulkDeleteExample.DeleteFunction(), 2); + + try (Connection conn = ConnectionFactory.createConnection(conf); + Table table = conn.getTable(TableName.valueOf(tableName))) { + Result result1 = table.get(new Get(Bytes.toBytes("1"))); + Assert.assertNull("Row 1 should had been deleted", result1.getRow()); + + Result result2 = table.get(new Get(Bytes.toBytes("2"))); + Assert.assertNull("Row 2 should had been deleted", result2.getRow()); + + Result result3 = table.get(new Get(Bytes.toBytes("3"))); + Assert.assertNull("Row 3 should had been deleted", result3.getRow()); + + Result result4 = table.get(new Get(Bytes.toBytes("4"))); + Assert.assertNotNull("Row 4 should had been deleted", result4.getRow()); + + Result result5 = table.get(new Get(Bytes.toBytes("5"))); + Assert.assertNotNull("Row 5 should had been deleted", result5.getRow()); + } + } + + @Test + public void testDistributedScan() throws IOException { + Configuration conf = TEST_UTIL.getConfiguration(); + + populateTableWithMockData(conf, TableName.valueOf(tableName)); + + Scan scan = new Scan(); + scan.setCaching(100); + + JavaRDD javaRdd = + HBASE_CONTEXT.hbaseRDD(TableName.valueOf(tableName), scan).map(new ScanConvertFunction()); + + List results = javaRdd.collect(); + + Assert.assertEquals(results.size(), 5); + } + + private static class ScanConvertFunction + implements Function, String> { + @Override + public String call(Tuple2 v1) throws Exception { + return Bytes.toString(v1._1().copyBytes()); + } + } + + @Test + public void testBulkGet() throws IOException { + List list = new ArrayList<>(5); + list.add(Bytes.toBytes("1")); + list.add(Bytes.toBytes("2")); + list.add(Bytes.toBytes("3")); + list.add(Bytes.toBytes("4")); + list.add(Bytes.toBytes("5")); + + JavaRDD rdd = JSC.parallelize(list); + + Configuration conf = TEST_UTIL.getConfiguration(); + + populateTableWithMockData(conf, TableName.valueOf(tableName)); + + final JavaRDD stringJavaRDD = HBASE_CONTEXT.bulkGet(TableName.valueOf(tableName), 2, + rdd, new GetFunction(), new ResultFunction()); + + Assert.assertEquals(stringJavaRDD.count(), 5); + } + + @Test + public void testBulkLoad() throws Exception { + + Path output = TEST_UTIL.getDataTestDir("testBulkLoad"); + // Add cell as String: "row,falmily,qualifier,value" + List list = new ArrayList(); + // row1 + list.add("1," + columnFamilyStr + ",b,1"); + // row3 + list.add("3," + columnFamilyStr + ",a,2"); + list.add("3," + columnFamilyStr + ",b,1"); + list.add("3," + columnFamilyStr1 + ",a,1"); + // row2 + list.add("2," + columnFamilyStr + ",a,3"); + list.add("2," + columnFamilyStr + ",b,3"); + + JavaRDD rdd = JSC.parallelize(list); + + Configuration conf = TEST_UTIL.getConfiguration(); + + HBASE_CONTEXT.bulkLoad(rdd, TableName.valueOf(tableName), new BulkLoadFunction(), + output.toUri().getPath(), new HashMap(), false, + HConstants.DEFAULT_MAX_FILE_SIZE); + + try (Connection conn = ConnectionFactory.createConnection(conf); + Admin admin = conn.getAdmin()) { + Table table = conn.getTable(TableName.valueOf(tableName)); + // Do bulk load + LoadIncrementalHFiles load = new LoadIncrementalHFiles(conf); + load.doBulkLoad(output, admin, table, conn.getRegionLocator(TableName.valueOf(tableName))); + + // 
Check row1 + List cell1 = table.get(new Get(Bytes.toBytes("1"))).listCells(); + Assert.assertEquals(cell1.size(), 1); + Assert.assertEquals(Bytes.toString(CellUtil.cloneFamily(cell1.get(0))), columnFamilyStr); + Assert.assertEquals(Bytes.toString(CellUtil.cloneQualifier(cell1.get(0))), "b"); + Assert.assertEquals(Bytes.toString(CellUtil.cloneValue(cell1.get(0))), "1"); + + // Check row3 + List cell3 = table.get(new Get(Bytes.toBytes("3"))).listCells(); + Assert.assertEquals(cell3.size(), 3); + Assert.assertEquals(Bytes.toString(CellUtil.cloneFamily(cell3.get(0))), columnFamilyStr); + Assert.assertEquals(Bytes.toString(CellUtil.cloneQualifier(cell3.get(0))), "a"); + Assert.assertEquals(Bytes.toString(CellUtil.cloneValue(cell3.get(0))), "2"); + Assert.assertEquals(Bytes.toString(CellUtil.cloneFamily(cell3.get(1))), columnFamilyStr); + Assert.assertEquals(Bytes.toString(CellUtil.cloneQualifier(cell3.get(1))), "b"); + Assert.assertEquals(Bytes.toString(CellUtil.cloneValue(cell3.get(1))), "1"); + Assert.assertEquals(Bytes.toString(CellUtil.cloneFamily(cell3.get(2))), columnFamilyStr1); + Assert.assertEquals(Bytes.toString(CellUtil.cloneQualifier(cell3.get(2))), "a"); + Assert.assertEquals(Bytes.toString(CellUtil.cloneValue(cell3.get(2))), "1"); + + // Check row2 + List cell2 = table.get(new Get(Bytes.toBytes("2"))).listCells(); + Assert.assertEquals(cell2.size(), 2); + Assert.assertEquals(Bytes.toString(CellUtil.cloneFamily(cell2.get(0))), columnFamilyStr); + Assert.assertEquals(Bytes.toString(CellUtil.cloneQualifier(cell2.get(0))), "a"); + Assert.assertEquals(Bytes.toString(CellUtil.cloneValue(cell2.get(0))), "3"); + Assert.assertEquals(Bytes.toString(CellUtil.cloneFamily(cell2.get(1))), columnFamilyStr); + Assert.assertEquals(Bytes.toString(CellUtil.cloneQualifier(cell2.get(1))), "b"); + Assert.assertEquals(Bytes.toString(CellUtil.cloneValue(cell2.get(1))), "3"); + } + } + + @Test + public void testBulkLoadThinRows() throws Exception { + Path output = TEST_UTIL.getDataTestDir("testBulkLoadThinRows"); + // because of the limitation of scala bulkLoadThinRows API + // we need to provide data as + List> list = new ArrayList>(); + // row1 + List list1 = new ArrayList(); + list1.add("1," + columnFamilyStr + ",b,1"); + list.add(list1); + // row3 + List list3 = new ArrayList(); + list3.add("3," + columnFamilyStr + ",a,2"); + list3.add("3," + columnFamilyStr + ",b,1"); + list3.add("3," + columnFamilyStr1 + ",a,1"); + list.add(list3); + // row2 + List list2 = new ArrayList(); + list2.add("2," + columnFamilyStr + ",a,3"); + list2.add("2," + columnFamilyStr + ",b,3"); + list.add(list2); + + JavaRDD> rdd = JSC.parallelize(list); + + Configuration conf = TEST_UTIL.getConfiguration(); + + HBASE_CONTEXT.bulkLoadThinRows(rdd, TableName.valueOf(tableName), + new BulkLoadThinRowsFunction(), output.toString(), new HashMap<>(), false, + HConstants.DEFAULT_MAX_FILE_SIZE); + + try (Connection conn = ConnectionFactory.createConnection(conf); + Admin admin = conn.getAdmin()) { + Table table = conn.getTable(TableName.valueOf(tableName)); + // Do bulk load + LoadIncrementalHFiles load = new LoadIncrementalHFiles(conf); + load.doBulkLoad(output, admin, table, conn.getRegionLocator(TableName.valueOf(tableName))); + + // Check row1 + List cell1 = table.get(new Get(Bytes.toBytes("1"))).listCells(); + Assert.assertEquals(cell1.size(), 1); + Assert.assertEquals(Bytes.toString(CellUtil.cloneFamily(cell1.get(0))), columnFamilyStr); + Assert.assertEquals(Bytes.toString(CellUtil.cloneQualifier(cell1.get(0))), "b"); + 
Assert.assertEquals(Bytes.toString(CellUtil.cloneValue(cell1.get(0))), "1"); + + // Check row3 + List cell3 = table.get(new Get(Bytes.toBytes("3"))).listCells(); + Assert.assertEquals(cell3.size(), 3); + Assert.assertEquals(Bytes.toString(CellUtil.cloneFamily(cell3.get(0))), columnFamilyStr); + Assert.assertEquals(Bytes.toString(CellUtil.cloneQualifier(cell3.get(0))), "a"); + Assert.assertEquals(Bytes.toString(CellUtil.cloneValue(cell3.get(0))), "2"); + Assert.assertEquals(Bytes.toString(CellUtil.cloneFamily(cell3.get(1))), columnFamilyStr); + Assert.assertEquals(Bytes.toString(CellUtil.cloneQualifier(cell3.get(1))), "b"); + Assert.assertEquals(Bytes.toString(CellUtil.cloneValue(cell3.get(1))), "1"); + Assert.assertEquals(Bytes.toString(CellUtil.cloneFamily(cell3.get(2))), columnFamilyStr1); + Assert.assertEquals(Bytes.toString(CellUtil.cloneQualifier(cell3.get(2))), "a"); + Assert.assertEquals(Bytes.toString(CellUtil.cloneValue(cell3.get(2))), "1"); + + // Check row2 + List cell2 = table.get(new Get(Bytes.toBytes("2"))).listCells(); + Assert.assertEquals(cell2.size(), 2); + Assert.assertEquals(Bytes.toString(CellUtil.cloneFamily(cell2.get(0))), columnFamilyStr); + Assert.assertEquals(Bytes.toString(CellUtil.cloneQualifier(cell2.get(0))), "a"); + Assert.assertEquals(Bytes.toString(CellUtil.cloneValue(cell2.get(0))), "3"); + Assert.assertEquals(Bytes.toString(CellUtil.cloneFamily(cell2.get(1))), columnFamilyStr); + Assert.assertEquals(Bytes.toString(CellUtil.cloneQualifier(cell2.get(1))), "b"); + Assert.assertEquals(Bytes.toString(CellUtil.cloneValue(cell2.get(1))), "3"); + } + + } + + public static class BulkLoadFunction + implements Function> { + @Override + public Pair call(String v1) throws Exception { + if (v1 == null) { + return null; + } + + String[] strs = v1.split(","); + if (strs.length != 4) { + return null; + } + + KeyFamilyQualifier kfq = new KeyFamilyQualifier(Bytes.toBytes(strs[0]), + Bytes.toBytes(strs[1]), Bytes.toBytes(strs[2])); + return new Pair(kfq, Bytes.toBytes(strs[3])); + } + } + + public static class BulkLoadThinRowsFunction + implements Function, Pair> { + @Override + public Pair call(List list) { + if (list == null) { + return null; + } + + ByteArrayWrapper rowKey = null; + FamiliesQualifiersValues fqv = new FamiliesQualifiersValues(); + for (String cell : list) { + String[] strs = cell.split(","); + if (rowKey == null) { + rowKey = new ByteArrayWrapper(Bytes.toBytes(strs[0])); + } + fqv.add(Bytes.toBytes(strs[1]), Bytes.toBytes(strs[2]), Bytes.toBytes(strs[3])); + } + return new Pair(rowKey, fqv); + } + } + + public static class GetFunction implements Function { + + private static final long serialVersionUID = 1L; + + @Override + public Get call(byte[] v) throws Exception { + return new Get(v); + } + } + + public static class ResultFunction implements Function { + + private static final long serialVersionUID = 1L; + + @Override + public String call(Result result) throws Exception { + Iterator it = result.listCells().iterator(); + StringBuilder b = new StringBuilder(); + + b.append(Bytes.toString(result.getRow())).append(":"); + + while (it.hasNext()) { + Cell cell = it.next(); + String q = Bytes.toString(CellUtil.cloneQualifier(cell)); + if ("counter".equals(q)) { + b.append("(").append(q).append(",").append(Bytes.toLong(CellUtil.cloneValue(cell))) + .append(")"); + } else { + b.append("(").append(q).append(",").append(Bytes.toString(CellUtil.cloneValue(cell))) + .append(")"); + } + } + return b.toString(); + } + } + + protected void 
populateTableWithMockData(Configuration conf, TableName tableName) + throws IOException { + try (Connection conn = ConnectionFactory.createConnection(conf); + Table table = conn.getTable(tableName); Admin admin = conn.getAdmin()) { + + List puts = new ArrayList<>(5); + + for (int i = 1; i < 6; i++) { + Put put = new Put(Bytes.toBytes(Integer.toString(i))); + put.addColumn(columnFamily, columnFamily, columnFamily); + puts.add(put); + } + table.put(puts); + admin.flush(tableName); + } + } +} diff --git a/spark4/hbase-spark4/src/test/java/org/apache/hadoop/hbase/spark/TestJavaHBaseContextForLargeRows.java b/spark4/hbase-spark4/src/test/java/org/apache/hadoop/hbase/spark/TestJavaHBaseContextForLargeRows.java new file mode 100644 index 00000000..824b53d4 --- /dev/null +++ b/spark4/hbase-spark4/src/test/java/org/apache/hadoop/hbase/spark/TestJavaHBaseContextForLargeRows.java @@ -0,0 +1,74 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.hadoop.hbase.spark; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; +import org.apache.hadoop.conf.Configuration; +import org.apache.hadoop.hbase.HBaseClassTestRule; +import org.apache.hadoop.hbase.TableName; +import org.apache.hadoop.hbase.client.Admin; +import org.apache.hadoop.hbase.client.Connection; +import org.apache.hadoop.hbase.client.ConnectionFactory; +import org.apache.hadoop.hbase.client.Put; +import org.apache.hadoop.hbase.client.Table; +import org.apache.hadoop.hbase.testclassification.MediumTests; +import org.apache.hadoop.hbase.testclassification.MiscTests; +import org.apache.hadoop.hbase.util.Bytes; +import org.apache.spark.api.java.JavaSparkContext; +import org.junit.BeforeClass; +import org.junit.ClassRule; +import org.junit.experimental.categories.Category; + +@Category({ MiscTests.class, MediumTests.class }) +public class TestJavaHBaseContextForLargeRows extends TestJavaHBaseContext { + + @ClassRule + public static final HBaseClassTestRule TIMEOUT = + HBaseClassTestRule.forClass(TestJavaHBaseContextForLargeRows.class); + + @BeforeClass + public static void setUpBeforeClass() throws Exception { + JSC = new JavaSparkContext("local", "JavaHBaseContextSuite"); + + init(); + } + + protected void populateTableWithMockData(Configuration conf, TableName tableName) + throws IOException { + try (Connection conn = ConnectionFactory.createConnection(conf); + Table table = conn.getTable(tableName); Admin admin = conn.getAdmin()) { + + List puts = new ArrayList<>(5); + + for (int i = 1; i < 6; i++) { + Put put = new Put(Bytes.toBytes(Integer.toString(i))); + // We are trying to generate a large row value here + char[] chars = new char[1024 * 1024 * 2]; + // adding '0' to convert int to char + Arrays.fill(chars, (char) (i 
+ '0')); + put.addColumn(columnFamily, columnFamily, Bytes.toBytes(String.valueOf(chars))); + puts.add(put); + } + table.put(puts); + admin.flush(tableName); + } + } +} diff --git a/spark4/hbase-spark4/src/test/resources/hbase-site.xml b/spark4/hbase-spark4/src/test/resources/hbase-site.xml new file mode 100644 index 00000000..b3fb0d90 --- /dev/null +++ b/spark4/hbase-spark4/src/test/resources/hbase-site.xml @@ -0,0 +1,157 @@ + + + + + + hbase.regionserver.msginterval + 1000 + Interval between messages from the RegionServer to HMaster + in milliseconds. Default is 15. Set this value low if you want unit + tests to be responsive. + + + + hbase.defaults.for.version.skip + true + + + hbase.server.thread.wakefrequency + 1000 + Time to sleep in between searches for work (in milliseconds). + Used as sleep interval by service threads such as hbase:meta scanner and log roller. + + + + hbase.master.event.waiting.time + 50 + Time to sleep between checks to see if a table event took place. + + + + hbase.regionserver.handler.count + 5 + + + hbase.regionserver.metahandler.count + 5 + + + hbase.ipc.server.read.threadpool.size + 3 + + + hbase.master.info.port + -1 + The port for the hbase master web UI + Set to -1 if you do not want the info server to run. + + + + hbase.master.port + 0 + Always have masters and regionservers come up on port '0' so we don't clash over + default ports. + + + + hbase.regionserver.port + 0 + Always have masters and regionservers come up on port '0' so we don't clash over + default ports. + + + + hbase.ipc.client.fallback-to-simple-auth-allowed + true + + + + hbase.regionserver.info.port + -1 + The port for the hbase regionserver web UI + Set to -1 if you do not want the info server to run. + + + + hbase.regionserver.info.port.auto + true + Info server auto port bind. Enables automatic port + search if hbase.regionserver.info.port is already in use. + Enabled for testing to run multiple tests on one machine. + + + + hbase.regionserver.safemode + false + + Turn on/off safe mode in region server. Always on for production, always off + for tests. + + + + hbase.hregion.max.filesize + 67108864 + + Maximum desired file size for an HRegion. If filesize exceeds + value + (value / 2), the HRegion is split in two. Default: 256M. + + Keep the maximum filesize small so we split more often in tests. + + + + hadoop.log.dir + ${user.dir}/../logs + + + hbase.zookeeper.property.clientPort + 21818 + Property from ZooKeeper's config zoo.cfg. + The port at which the clients will connect. + + + + hbase.defaults.for.version.skip + true + + Set to true to skip the 'hbase.defaults.for.version'. + Setting this to true can be useful in contexts other than + the other side of a maven generation; i.e. running in an + ide. You'll want to set this boolean to true to avoid + seeing the RuntimeException complaint: "hbase-default.xml file + seems to be for and old version of HBase (@@@VERSION@@@), this + version is X.X.X-SNAPSHOT" + + + + hbase.table.sanity.checks + false + Skip sanity checks in tests + + + + hbase.procedure.fail.on.corruption + true + + Enable replay sanity checks on procedure tests. + + + diff --git a/spark4/hbase-spark4/src/test/resources/log4j.properties b/spark4/hbase-spark4/src/test/resources/log4j.properties new file mode 100644 index 00000000..cd3b8e9d --- /dev/null +++ b/spark4/hbase-spark4/src/test/resources/log4j.properties @@ -0,0 +1,76 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. 
See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Define some default values that can be overridden by system properties +hbase.root.logger=INFO,FA +hbase.log.dir=. +hbase.log.file=hbase.log + +# Define the root logger to the system property "hbase.root.logger". +log4j.rootLogger=${hbase.root.logger} + +# Logging Threshold +log4j.threshold=ALL + +# +# Daily Rolling File Appender +# +log4j.appender.DRFA=org.apache.log4j.DailyRollingFileAppender +log4j.appender.DRFA.File=${hbase.log.dir}/${hbase.log.file} + +# Rollver at midnight +log4j.appender.DRFA.DatePattern=.yyyy-MM-dd + +# 30-day backup +#log4j.appender.DRFA.MaxBackupIndex=30 +log4j.appender.DRFA.layout=org.apache.log4j.PatternLayout +# Debugging Pattern format +log4j.appender.DRFA.layout.ConversionPattern=%d{ISO8601} %-5p [%t] %C{2}(%L): %m%n + + +# +# console +# Add "console" to rootlogger above if you want to use this +# +log4j.appender.console=org.apache.log4j.ConsoleAppender +log4j.appender.console.target=System.err +log4j.appender.console.layout=org.apache.log4j.PatternLayout +log4j.appender.console.layout.ConversionPattern=%d{ISO8601} %-5p [%t] %C{2}(%L): %m%n + +#File Appender +log4j.appender.FA=org.apache.log4j.FileAppender +log4j.appender.FA.append=false +log4j.appender.FA.file=target/log-output.txt +log4j.appender.FA.layout=org.apache.log4j.PatternLayout +log4j.appender.FA.layout.ConversionPattern=%d{ISO8601} %-5p [%t] %C{2}(%L): %m%n +log4j.appender.FA.Threshold = INFO + +# Custom Logging levels + +#log4j.logger.org.apache.hadoop.fs.FSNamesystem=DEBUG + +log4j.logger.org.apache.hadoop=WARN +log4j.logger.org.apache.zookeeper=ERROR +log4j.logger.org.apache.hadoop.hbase=DEBUG + +#These settings are workarounds against spurious logs from the minicluster. +#See HBASE-4709 +log4j.logger.org.apache.hadoop.metrics2.impl.MetricsConfig=WARN +log4j.logger.org.apache.hadoop.metrics2.impl.MetricsSinkAdapter=WARN +log4j.logger.org.apache.hadoop.metrics2.impl.MetricsSystemImpl=WARN +log4j.logger.org.apache.hadoop.metrics2.util.MBeans=WARN +# Enable this to get detailed connection error/retry logging. +# log4j.logger.org.apache.hadoop.hbase.client.ConnectionImplementation=TRACE diff --git a/spark4/hbase-spark4/src/test/scala/org/apache/hadoop/hbase/spark/BulkLoadSuite.scala b/spark4/hbase-spark4/src/test/scala/org/apache/hadoop/hbase/spark/BulkLoadSuite.scala new file mode 100644 index 00000000..2b03c0f9 --- /dev/null +++ b/spark4/hbase-spark4/src/test/scala/org/apache/hadoop/hbase/spark/BulkLoadSuite.scala @@ -0,0 +1,1055 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.hadoop.hbase.spark + +import java.io.File +import java.net.URI +import java.nio.file.Files +import org.apache.hadoop.fs.{FileSystem, Path} +import org.apache.hadoop.hbase.{CellUtil, HBaseTestingUtility, HConstants, TableName} +import org.apache.hadoop.hbase.client.{ConnectionFactory, Get} +import org.apache.hadoop.hbase.io.hfile.{CacheConfig, HFile} +import org.apache.hadoop.hbase.spark.HBaseRDDFunctions._ +import org.apache.hadoop.hbase.tool.LoadIncrementalHFiles +import org.apache.hadoop.hbase.util.Bytes +import org.apache.spark.SparkContext +import org.junit.rules.TemporaryFolder +import org.scalatest.funsuite.AnyFunSuite +import org.scalatest.{BeforeAndAfterAll, BeforeAndAfterEach} + +class BulkLoadSuite extends AnyFunSuite with BeforeAndAfterEach with BeforeAndAfterAll with Logging { + @transient var sc: SparkContext = null + var TEST_UTIL = new HBaseTestingUtility + + val tableName = "t1" + val columnFamily1 = "f1" + val columnFamily2 = "f2" + val testFolder = new TemporaryFolder() + + override def beforeAll() { + TEST_UTIL.startMiniCluster() + logInfo(" - minicluster started") + + try { + TEST_UTIL.deleteTable(TableName.valueOf(tableName)) + } catch { + case e: Exception => + logInfo(" - no table " + tableName + " found") + } + + logInfo(" - created table") + + val envMap = Map[String, String](("Xmx", "512m")) + + sc = new SparkContext("local", "test", null, Nil, envMap) + } + + override def afterAll() { + logInfo("shuting down minicluster") + TEST_UTIL.shutdownMiniCluster() + logInfo(" - minicluster shut down") + TEST_UTIL.cleanupTestDir() + sc.stop() + } + + test("Staging dir: Test usage of staging dir on a separate filesystem") { + val config = TEST_UTIL.getConfiguration + + logInfo(" - creating table " + tableName) + TEST_UTIL.createTable( + TableName.valueOf(tableName), + Array(Bytes.toBytes(columnFamily1), Bytes.toBytes(columnFamily2))) + + // Test creates rdd with 2 column families and + // write those to hfiles on local filesystem + // using bulkLoad functionality. 
We don't check the load functionality + // due the limitations of the HBase Minicluster + + val rdd = sc.parallelize( + Array[(Array[Byte], (Array[Byte], Array[Byte], Array[Byte]))]( + ( + Bytes.toBytes("1"), + (Bytes.toBytes(columnFamily1), Bytes.toBytes("a"), Bytes.toBytes("foo1"))), + ( + Bytes.toBytes("2"), + (Bytes.toBytes(columnFamily2), Bytes.toBytes("b"), Bytes.toBytes("bar.2"))))) + + val hbaseContext = new HBaseContext(sc, config) + val uri = Files.createTempDirectory("tmpDirPrefix").toUri + val stagingUri = new URI(uri + "staging_dir") + val stagingFolder = new File(stagingUri) + val fs = new Path(stagingUri.toString).getFileSystem(config) + try { + hbaseContext.bulkLoad[(Array[Byte], (Array[Byte], Array[Byte], Array[Byte]))]( + rdd, + TableName.valueOf(tableName), + t => { + val rowKey = t._1 + val family: Array[Byte] = t._2._1 + val qualifier = t._2._2 + val value: Array[Byte] = t._2._3 + + val keyFamilyQualifier = new KeyFamilyQualifier(rowKey, family, qualifier) + + Seq((keyFamilyQualifier, value)).iterator + }, + stagingUri.toString) + + assert(fs.listStatus(new Path(stagingFolder.getPath)).length == 2) + + } finally { + val admin = ConnectionFactory.createConnection(config).getAdmin + try { + admin.disableTable(TableName.valueOf(tableName)) + admin.deleteTable(TableName.valueOf(tableName)) + } finally { + admin.close() + } + fs.delete(new Path(stagingFolder.getPath), true) + + testFolder.delete() + + } + } + + test( + "Wide Row Bulk Load: Test multi family and multi column tests " + + "with all default HFile Configs.") { + val config = TEST_UTIL.getConfiguration + + logInfo(" - creating table " + tableName) + TEST_UTIL.createTable( + TableName.valueOf(tableName), + Array(Bytes.toBytes(columnFamily1), Bytes.toBytes(columnFamily2))) + + // There are a number of tests in here. + // 1. Row keys are not in order + // 2. Qualifiers are not in order + // 3. Column Families are not in order + // 4. There are tests for records in one column family and some in two column families + // 5. 
There are records will a single qualifier and some with two + val rdd = sc.parallelize( + Array[(Array[Byte], (Array[Byte], Array[Byte], Array[Byte]))]( + ( + Bytes.toBytes("1"), + (Bytes.toBytes(columnFamily1), Bytes.toBytes("a"), Bytes.toBytes("foo1"))), + ( + Bytes.toBytes("3"), + (Bytes.toBytes(columnFamily2), Bytes.toBytes("b"), Bytes.toBytes("foo2.a"))), + ( + Bytes.toBytes("3"), + (Bytes.toBytes(columnFamily2), Bytes.toBytes("a"), Bytes.toBytes("foo2.b"))), + ( + Bytes.toBytes("3"), + (Bytes.toBytes(columnFamily1), Bytes.toBytes("a"), Bytes.toBytes("foo2.c"))), + ( + Bytes.toBytes("5"), + (Bytes.toBytes(columnFamily1), Bytes.toBytes("a"), Bytes.toBytes("foo3"))), + ( + Bytes.toBytes("4"), + (Bytes.toBytes(columnFamily1), Bytes.toBytes("a"), Bytes.toBytes("foo.1"))), + ( + Bytes.toBytes("4"), + (Bytes.toBytes(columnFamily2), Bytes.toBytes("b"), Bytes.toBytes("foo.2"))), + ( + Bytes.toBytes("2"), + (Bytes.toBytes(columnFamily1), Bytes.toBytes("a"), Bytes.toBytes("bar.1"))), + ( + Bytes.toBytes("2"), + (Bytes.toBytes(columnFamily1), Bytes.toBytes("b"), Bytes.toBytes("bar.2"))))) + + val hbaseContext = new HBaseContext(sc, config) + + testFolder.create() + val stagingFolder = testFolder.newFolder() + + hbaseContext.bulkLoad[(Array[Byte], (Array[Byte], Array[Byte], Array[Byte]))]( + rdd, + TableName.valueOf(tableName), + t => { + val rowKey = t._1 + val family: Array[Byte] = t._2._1 + val qualifier = t._2._2 + val value: Array[Byte] = t._2._3 + + val keyFamilyQualifier = new KeyFamilyQualifier(rowKey, family, qualifier) + + Seq((keyFamilyQualifier, value)).iterator + }, + stagingFolder.getPath) + + val fs = FileSystem.get(config) + assert(fs.listStatus(new Path(stagingFolder.getPath)).length == 2) + + val conn = ConnectionFactory.createConnection(config) + + val load = new LoadIncrementalHFiles(config) + val table = conn.getTable(TableName.valueOf(tableName)) + try { + load.doBulkLoad( + new Path(stagingFolder.getPath), + conn.getAdmin, + table, + conn.getRegionLocator(TableName.valueOf(tableName))) + + val cells5 = table.get(new Get(Bytes.toBytes("5"))).listCells() + assert(cells5.size == 1) + assert(Bytes.toString(CellUtil.cloneValue(cells5.get(0))).equals("foo3")) + assert(Bytes.toString(CellUtil.cloneFamily(cells5.get(0))).equals("f1")) + assert(Bytes.toString(CellUtil.cloneQualifier(cells5.get(0))).equals("a")) + + val cells4 = table.get(new Get(Bytes.toBytes("4"))).listCells() + assert(cells4.size == 2) + assert(Bytes.toString(CellUtil.cloneValue(cells4.get(0))).equals("foo.1")) + assert(Bytes.toString(CellUtil.cloneFamily(cells4.get(0))).equals("f1")) + assert(Bytes.toString(CellUtil.cloneQualifier(cells4.get(0))).equals("a")) + assert(Bytes.toString(CellUtil.cloneValue(cells4.get(1))).equals("foo.2")) + assert(Bytes.toString(CellUtil.cloneFamily(cells4.get(1))).equals("f2")) + assert(Bytes.toString(CellUtil.cloneQualifier(cells4.get(1))).equals("b")) + + val cells3 = table.get(new Get(Bytes.toBytes("3"))).listCells() + assert(cells3.size == 3) + assert(Bytes.toString(CellUtil.cloneValue(cells3.get(0))).equals("foo2.c")) + assert(Bytes.toString(CellUtil.cloneFamily(cells3.get(0))).equals("f1")) + assert(Bytes.toString(CellUtil.cloneQualifier(cells3.get(0))).equals("a")) + assert(Bytes.toString(CellUtil.cloneValue(cells3.get(1))).equals("foo2.b")) + assert(Bytes.toString(CellUtil.cloneFamily(cells3.get(1))).equals("f2")) + assert(Bytes.toString(CellUtil.cloneQualifier(cells3.get(1))).equals("a")) + assert(Bytes.toString(CellUtil.cloneValue(cells3.get(2))).equals("foo2.a")) + 
assert(Bytes.toString(CellUtil.cloneFamily(cells3.get(2))).equals("f2")) + assert(Bytes.toString(CellUtil.cloneQualifier(cells3.get(2))).equals("b")) + + val cells2 = table.get(new Get(Bytes.toBytes("2"))).listCells() + assert(cells2.size == 2) + assert(Bytes.toString(CellUtil.cloneValue(cells2.get(0))).equals("bar.1")) + assert(Bytes.toString(CellUtil.cloneFamily(cells2.get(0))).equals("f1")) + assert(Bytes.toString(CellUtil.cloneQualifier(cells2.get(0))).equals("a")) + assert(Bytes.toString(CellUtil.cloneValue(cells2.get(1))).equals("bar.2")) + assert(Bytes.toString(CellUtil.cloneFamily(cells2.get(1))).equals("f1")) + assert(Bytes.toString(CellUtil.cloneQualifier(cells2.get(1))).equals("b")) + + val cells1 = table.get(new Get(Bytes.toBytes("1"))).listCells() + assert(cells1.size == 1) + assert(Bytes.toString(CellUtil.cloneValue(cells1.get(0))).equals("foo1")) + assert(Bytes.toString(CellUtil.cloneFamily(cells1.get(0))).equals("f1")) + assert(Bytes.toString(CellUtil.cloneQualifier(cells1.get(0))).equals("a")) + + } finally { + table.close() + val admin = ConnectionFactory.createConnection(config).getAdmin + try { + admin.disableTable(TableName.valueOf(tableName)) + admin.deleteTable(TableName.valueOf(tableName)) + } finally { + admin.close() + } + fs.delete(new Path(stagingFolder.getPath), true) + + testFolder.delete() + + } + } + + test( + "Wide Row Bulk Load: Test HBase client: Test Roll Over and " + + "using an implicit call to bulk load") { + val config = TEST_UTIL.getConfiguration + + logInfo(" - creating table " + tableName) + TEST_UTIL.createTable( + TableName.valueOf(tableName), + Array(Bytes.toBytes(columnFamily1), Bytes.toBytes(columnFamily2))) + + // There are a number of tests in here. + // 1. Row keys are not in order + // 2. Qualifiers are not in order + // 3. Column Families are not in order + // 4. There are tests for records in one column family and some in two column families + // 5. 
There are records will a single qualifier and some with two + val rdd = sc.parallelize( + Array[(Array[Byte], (Array[Byte], Array[Byte], Array[Byte]))]( + ( + Bytes.toBytes("1"), + (Bytes.toBytes(columnFamily1), Bytes.toBytes("a"), Bytes.toBytes("foo1"))), + ( + Bytes.toBytes("3"), + (Bytes.toBytes(columnFamily1), Bytes.toBytes("b"), Bytes.toBytes("foo2.b"))), + ( + Bytes.toBytes("3"), + (Bytes.toBytes(columnFamily1), Bytes.toBytes("a"), Bytes.toBytes("foo2.a"))), + ( + Bytes.toBytes("3"), + (Bytes.toBytes(columnFamily1), Bytes.toBytes("c"), Bytes.toBytes("foo2.c"))), + ( + Bytes.toBytes("5"), + (Bytes.toBytes(columnFamily1), Bytes.toBytes("a"), Bytes.toBytes("foo3"))), + ( + Bytes.toBytes("4"), + (Bytes.toBytes(columnFamily1), Bytes.toBytes("a"), Bytes.toBytes("foo.1"))), + ( + Bytes.toBytes("4"), + (Bytes.toBytes(columnFamily1), Bytes.toBytes("b"), Bytes.toBytes("foo.2"))), + ( + Bytes.toBytes("2"), + (Bytes.toBytes(columnFamily1), Bytes.toBytes("a"), Bytes.toBytes("bar.1"))), + ( + Bytes.toBytes("2"), + (Bytes.toBytes(columnFamily1), Bytes.toBytes("b"), Bytes.toBytes("bar.2"))))) + + val hbaseContext = new HBaseContext(sc, config) + + testFolder.create() + val stagingFolder = testFolder.newFolder() + + rdd.hbaseBulkLoad( + hbaseContext, + TableName.valueOf(tableName), + t => { + val rowKey = t._1 + val family: Array[Byte] = t._2._1 + val qualifier = t._2._2 + val value = t._2._3 + + val keyFamilyQualifier = new KeyFamilyQualifier(rowKey, family, qualifier) + + Seq((keyFamilyQualifier, value)).iterator + }, + stagingFolder.getPath, + new java.util.HashMap[Array[Byte], FamilyHFileWriteOptions], + compactionExclude = false, + 20) + + val fs = FileSystem.get(config) + assert(fs.listStatus(new Path(stagingFolder.getPath)).length == 1) + + assert(fs.listStatus(new Path(stagingFolder.getPath + "/f1")).length == 5) + + val conn = ConnectionFactory.createConnection(config) + + val load = new LoadIncrementalHFiles(config) + val table = conn.getTable(TableName.valueOf(tableName)) + try { + load.doBulkLoad( + new Path(stagingFolder.getPath), + conn.getAdmin, + table, + conn.getRegionLocator(TableName.valueOf(tableName))) + + val cells5 = table.get(new Get(Bytes.toBytes("5"))).listCells() + assert(cells5.size == 1) + assert(Bytes.toString(CellUtil.cloneValue(cells5.get(0))).equals("foo3")) + assert(Bytes.toString(CellUtil.cloneFamily(cells5.get(0))).equals("f1")) + assert(Bytes.toString(CellUtil.cloneQualifier(cells5.get(0))).equals("a")) + + val cells4 = table.get(new Get(Bytes.toBytes("4"))).listCells() + assert(cells4.size == 2) + assert(Bytes.toString(CellUtil.cloneValue(cells4.get(0))).equals("foo.1")) + assert(Bytes.toString(CellUtil.cloneFamily(cells4.get(0))).equals("f1")) + assert(Bytes.toString(CellUtil.cloneQualifier(cells4.get(0))).equals("a")) + assert(Bytes.toString(CellUtil.cloneValue(cells4.get(1))).equals("foo.2")) + assert(Bytes.toString(CellUtil.cloneFamily(cells4.get(1))).equals("f1")) + assert(Bytes.toString(CellUtil.cloneQualifier(cells4.get(1))).equals("b")) + + val cells3 = table.get(new Get(Bytes.toBytes("3"))).listCells() + assert(cells3.size == 3) + assert(Bytes.toString(CellUtil.cloneValue(cells3.get(0))).equals("foo2.a")) + assert(Bytes.toString(CellUtil.cloneFamily(cells3.get(0))).equals("f1")) + assert(Bytes.toString(CellUtil.cloneQualifier(cells3.get(0))).equals("a")) + assert(Bytes.toString(CellUtil.cloneValue(cells3.get(1))).equals("foo2.b")) + assert(Bytes.toString(CellUtil.cloneFamily(cells3.get(1))).equals("f1")) + 
assert(Bytes.toString(CellUtil.cloneQualifier(cells3.get(1))).equals("b")) + assert(Bytes.toString(CellUtil.cloneValue(cells3.get(2))).equals("foo2.c")) + assert(Bytes.toString(CellUtil.cloneFamily(cells3.get(2))).equals("f1")) + assert(Bytes.toString(CellUtil.cloneQualifier(cells3.get(2))).equals("c")) + + val cells2 = table.get(new Get(Bytes.toBytes("2"))).listCells() + assert(cells2.size == 2) + assert(Bytes.toString(CellUtil.cloneValue(cells2.get(0))).equals("bar.1")) + assert(Bytes.toString(CellUtil.cloneFamily(cells2.get(0))).equals("f1")) + assert(Bytes.toString(CellUtil.cloneQualifier(cells2.get(0))).equals("a")) + assert(Bytes.toString(CellUtil.cloneValue(cells2.get(1))).equals("bar.2")) + assert(Bytes.toString(CellUtil.cloneFamily(cells2.get(1))).equals("f1")) + assert(Bytes.toString(CellUtil.cloneQualifier(cells2.get(1))).equals("b")) + + val cells1 = table.get(new Get(Bytes.toBytes("1"))).listCells() + assert(cells1.size == 1) + assert(Bytes.toString(CellUtil.cloneValue(cells1.get(0))).equals("foo1")) + assert(Bytes.toString(CellUtil.cloneFamily(cells1.get(0))).equals("f1")) + assert(Bytes.toString(CellUtil.cloneQualifier(cells1.get(0))).equals("a")) + + } finally { + table.close() + val admin = ConnectionFactory.createConnection(config).getAdmin + try { + admin.disableTable(TableName.valueOf(tableName)) + admin.deleteTable(TableName.valueOf(tableName)) + } finally { + admin.close() + } + fs.delete(new Path(stagingFolder.getPath), true) + + testFolder.delete() + } + } + + test( + "Wide Row Bulk Load: Test multi family and multi column tests" + + " with one column family with custom configs plus multi region") { + val config = TEST_UTIL.getConfiguration + + val splitKeys: Array[Array[Byte]] = new Array[Array[Byte]](2) + splitKeys(0) = Bytes.toBytes("2") + splitKeys(1) = Bytes.toBytes("4") + + logInfo(" - creating table " + tableName) + TEST_UTIL.createTable( + TableName.valueOf(tableName), + Array(Bytes.toBytes(columnFamily1), Bytes.toBytes(columnFamily2)), + splitKeys) + + // There are a number of tests in here. + // 1. Row keys are not in order + // 2. Qualifiers are not in order + // 3. Column Families are not in order + // 4. There are tests for records in one column family and some in two column families + // 5. 
There are records will a single qualifier and some with two + val rdd = sc.parallelize( + Array[(Array[Byte], (Array[Byte], Array[Byte], Array[Byte]))]( + ( + Bytes.toBytes("1"), + (Bytes.toBytes(columnFamily1), Bytes.toBytes("a"), Bytes.toBytes("foo1"))), + ( + Bytes.toBytes("3"), + (Bytes.toBytes(columnFamily2), Bytes.toBytes("b"), Bytes.toBytes("foo2.a"))), + ( + Bytes.toBytes("3"), + (Bytes.toBytes(columnFamily2), Bytes.toBytes("a"), Bytes.toBytes("foo2.b"))), + ( + Bytes.toBytes("3"), + (Bytes.toBytes(columnFamily1), Bytes.toBytes("a"), Bytes.toBytes("foo2.c"))), + ( + Bytes.toBytes("5"), + (Bytes.toBytes(columnFamily1), Bytes.toBytes("a"), Bytes.toBytes("foo3"))), + ( + Bytes.toBytes("4"), + (Bytes.toBytes(columnFamily1), Bytes.toBytes("a"), Bytes.toBytes("foo.1"))), + ( + Bytes.toBytes("4"), + (Bytes.toBytes(columnFamily2), Bytes.toBytes("b"), Bytes.toBytes("foo.2"))), + ( + Bytes.toBytes("2"), + (Bytes.toBytes(columnFamily1), Bytes.toBytes("a"), Bytes.toBytes("bar.1"))), + ( + Bytes.toBytes("2"), + (Bytes.toBytes(columnFamily1), Bytes.toBytes("b"), Bytes.toBytes("bar.2"))))) + + val hbaseContext = new HBaseContext(sc, config) + + testFolder.create() + val stagingFolder = testFolder.newFolder() + + val familyHBaseWriterOptions = new java.util.HashMap[Array[Byte], FamilyHFileWriteOptions] + + val f1Options = new FamilyHFileWriteOptions("GZ", "ROW", 128, "PREFIX") + + familyHBaseWriterOptions.put(Bytes.toBytes(columnFamily1), f1Options) + + hbaseContext.bulkLoad[(Array[Byte], (Array[Byte], Array[Byte], Array[Byte]))]( + rdd, + TableName.valueOf(tableName), + t => { + val rowKey = t._1 + val family: Array[Byte] = t._2._1 + val qualifier = t._2._2 + val value = t._2._3 + + val keyFamilyQualifier = new KeyFamilyQualifier(rowKey, family, qualifier) + + Seq((keyFamilyQualifier, value)).iterator + }, + stagingFolder.getPath, + familyHBaseWriterOptions, + compactionExclude = false, + HConstants.DEFAULT_MAX_FILE_SIZE) + + val fs = FileSystem.get(config) + assert(fs.listStatus(new Path(stagingFolder.getPath)).length == 2) + + val f1FileList = fs.listStatus(new Path(stagingFolder.getPath + "/f1")) + for (i <- 0 until f1FileList.length) { + val reader = + HFile.createReader(fs, f1FileList(i).getPath, new CacheConfig(config), true, config) + assert(reader.getTrailer.getCompressionCodec().getName.equals("gz")) + assert(reader.getDataBlockEncoding.name().equals("PREFIX")) + } + + assert(3 == f1FileList.length) + + val f2FileList = fs.listStatus(new Path(stagingFolder.getPath + "/f2")) + for (i <- 0 until f2FileList.length) { + val reader = + HFile.createReader(fs, f2FileList(i).getPath, new CacheConfig(config), true, config) + assert(reader.getTrailer.getCompressionCodec().getName.equals("none")) + assert(reader.getDataBlockEncoding.name().equals("NONE")) + } + + assert(2 == f2FileList.length) + + val conn = ConnectionFactory.createConnection(config) + + val load = new LoadIncrementalHFiles(config) + val table = conn.getTable(TableName.valueOf(tableName)) + try { + load.doBulkLoad( + new Path(stagingFolder.getPath), + conn.getAdmin, + table, + conn.getRegionLocator(TableName.valueOf(tableName))) + + val cells5 = table.get(new Get(Bytes.toBytes("5"))).listCells() + assert(cells5.size == 1) + assert(Bytes.toString(CellUtil.cloneValue(cells5.get(0))).equals("foo3")) + assert(Bytes.toString(CellUtil.cloneFamily(cells5.get(0))).equals("f1")) + assert(Bytes.toString(CellUtil.cloneQualifier(cells5.get(0))).equals("a")) + + val cells4 = table.get(new Get(Bytes.toBytes("4"))).listCells() + 
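// Row "4" was written with one cell in each family in this test: f1:a -> "foo.1" and f2:b -> "foo.2" +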
assert(cells4.size == 2) + assert(Bytes.toString(CellUtil.cloneValue(cells4.get(0))).equals("foo.1")) + assert(Bytes.toString(CellUtil.cloneFamily(cells4.get(0))).equals("f1")) + assert(Bytes.toString(CellUtil.cloneQualifier(cells4.get(0))).equals("a")) + assert(Bytes.toString(CellUtil.cloneValue(cells4.get(1))).equals("foo.2")) + assert(Bytes.toString(CellUtil.cloneFamily(cells4.get(1))).equals("f2")) + assert(Bytes.toString(CellUtil.cloneQualifier(cells4.get(1))).equals("b")) + + val cells3 = table.get(new Get(Bytes.toBytes("3"))).listCells() + assert(cells3.size == 3) + assert(Bytes.toString(CellUtil.cloneValue(cells3.get(0))).equals("foo2.c")) + assert(Bytes.toString(CellUtil.cloneFamily(cells3.get(0))).equals("f1")) + assert(Bytes.toString(CellUtil.cloneQualifier(cells3.get(0))).equals("a")) + assert(Bytes.toString(CellUtil.cloneValue(cells3.get(1))).equals("foo2.b")) + assert(Bytes.toString(CellUtil.cloneFamily(cells3.get(1))).equals("f2")) + assert(Bytes.toString(CellUtil.cloneQualifier(cells3.get(1))).equals("a")) + assert(Bytes.toString(CellUtil.cloneValue(cells3.get(2))).equals("foo2.a")) + assert(Bytes.toString(CellUtil.cloneFamily(cells3.get(2))).equals("f2")) + assert(Bytes.toString(CellUtil.cloneQualifier(cells3.get(2))).equals("b")) + + val cells2 = table.get(new Get(Bytes.toBytes("2"))).listCells() + assert(cells2.size == 2) + assert(Bytes.toString(CellUtil.cloneValue(cells2.get(0))).equals("bar.1")) + assert(Bytes.toString(CellUtil.cloneFamily(cells2.get(0))).equals("f1")) + assert(Bytes.toString(CellUtil.cloneQualifier(cells2.get(0))).equals("a")) + assert(Bytes.toString(CellUtil.cloneValue(cells2.get(1))).equals("bar.2")) + assert(Bytes.toString(CellUtil.cloneFamily(cells2.get(1))).equals("f1")) + assert(Bytes.toString(CellUtil.cloneQualifier(cells2.get(1))).equals("b")) + + val cells1 = table.get(new Get(Bytes.toBytes("1"))).listCells() + assert(cells1.size == 1) + assert(Bytes.toString(CellUtil.cloneValue(cells1.get(0))).equals("foo1")) + assert(Bytes.toString(CellUtil.cloneFamily(cells1.get(0))).equals("f1")) + assert(Bytes.toString(CellUtil.cloneQualifier(cells1.get(0))).equals("a")) + + } finally { + table.close() + val admin = ConnectionFactory.createConnection(config).getAdmin + try { + admin.disableTable(TableName.valueOf(tableName)) + admin.deleteTable(TableName.valueOf(tableName)) + } finally { + admin.close() + } + fs.delete(new Path(stagingFolder.getPath), true) + + testFolder.delete() + + } + } + + test("Test partitioner") { + + var splitKeys: Array[Array[Byte]] = new Array[Array[Byte]](3) + splitKeys(0) = Bytes.toBytes("") + splitKeys(1) = Bytes.toBytes("3") + splitKeys(2) = Bytes.toBytes("7") + + var partitioner = new BulkLoadPartitioner(splitKeys) + + assert(0 == partitioner.getPartition(Bytes.toBytes(""))) + assert(0 == partitioner.getPartition(Bytes.toBytes("1"))) + assert(0 == partitioner.getPartition(Bytes.toBytes("2"))) + assert(1 == partitioner.getPartition(Bytes.toBytes("3"))) + assert(1 == partitioner.getPartition(Bytes.toBytes("4"))) + assert(1 == partitioner.getPartition(Bytes.toBytes("6"))) + assert(2 == partitioner.getPartition(Bytes.toBytes("7"))) + assert(2 == partitioner.getPartition(Bytes.toBytes("8"))) + + splitKeys = new Array[Array[Byte]](1) + splitKeys(0) = Bytes.toBytes("") + + partitioner = new BulkLoadPartitioner(splitKeys) + + assert(0 == partitioner.getPartition(Bytes.toBytes(""))) + assert(0 == partitioner.getPartition(Bytes.toBytes("1"))) + assert(0 == partitioner.getPartition(Bytes.toBytes("2"))) + assert(0 == 
partitioner.getPartition(Bytes.toBytes("3"))) + assert(0 == partitioner.getPartition(Bytes.toBytes("4"))) + assert(0 == partitioner.getPartition(Bytes.toBytes("6"))) + assert(0 == partitioner.getPartition(Bytes.toBytes("7"))) + + splitKeys = new Array[Array[Byte]](7) + splitKeys(0) = Bytes.toBytes("") + splitKeys(1) = Bytes.toBytes("02") + splitKeys(2) = Bytes.toBytes("04") + splitKeys(3) = Bytes.toBytes("06") + splitKeys(4) = Bytes.toBytes("08") + splitKeys(5) = Bytes.toBytes("10") + splitKeys(6) = Bytes.toBytes("12") + + partitioner = new BulkLoadPartitioner(splitKeys) + + assert(0 == partitioner.getPartition(Bytes.toBytes(""))) + assert(0 == partitioner.getPartition(Bytes.toBytes("01"))) + assert(1 == partitioner.getPartition(Bytes.toBytes("02"))) + assert(1 == partitioner.getPartition(Bytes.toBytes("03"))) + assert(2 == partitioner.getPartition(Bytes.toBytes("04"))) + assert(2 == partitioner.getPartition(Bytes.toBytes("05"))) + assert(3 == partitioner.getPartition(Bytes.toBytes("06"))) + assert(3 == partitioner.getPartition(Bytes.toBytes("07"))) + assert(4 == partitioner.getPartition(Bytes.toBytes("08"))) + assert(4 == partitioner.getPartition(Bytes.toBytes("09"))) + assert(5 == partitioner.getPartition(Bytes.toBytes("10"))) + assert(5 == partitioner.getPartition(Bytes.toBytes("11"))) + assert(6 == partitioner.getPartition(Bytes.toBytes("12"))) + assert(6 == partitioner.getPartition(Bytes.toBytes("13"))) + } + + test( + "Thin Row Bulk Load: Test multi family and multi column tests " + + "with all default HFile Configs") { + val config = TEST_UTIL.getConfiguration + + logInfo(" - creating table " + tableName) + TEST_UTIL.createTable( + TableName.valueOf(tableName), + Array(Bytes.toBytes(columnFamily1), Bytes.toBytes(columnFamily2))) + + // There are a number of tests in here. + // 1. Row keys are not in order + // 2. Qualifiers are not in order + // 3. Column Families are not in order + // 4. There are tests for records in one column family and some in two column families + // 5. 
There are records will a single qualifier and some with two + val rdd = sc + .parallelize( + Array[(String,(Array[Byte], Array[Byte], Array[Byte]))]( + ("1", (Bytes.toBytes(columnFamily1), Bytes.toBytes("a"), Bytes.toBytes("foo1"))), + ("3", (Bytes.toBytes(columnFamily2), Bytes.toBytes("b"), Bytes.toBytes("foo2.a"))), + ("3", (Bytes.toBytes(columnFamily2), Bytes.toBytes("a"), Bytes.toBytes("foo2.b"))), + ("3", (Bytes.toBytes(columnFamily1), Bytes.toBytes("a"), Bytes.toBytes("foo2.c"))), + ("5", (Bytes.toBytes(columnFamily1), Bytes.toBytes("a"), Bytes.toBytes("foo3"))), + ("4", (Bytes.toBytes(columnFamily1), Bytes.toBytes("a"), Bytes.toBytes("foo.1"))), + ("4", (Bytes.toBytes(columnFamily2), Bytes.toBytes("b"), Bytes.toBytes("foo.2"))), + ("2", (Bytes.toBytes(columnFamily1), Bytes.toBytes("a"), Bytes.toBytes("bar.1"))), + ("2", (Bytes.toBytes(columnFamily1), Bytes.toBytes("b"), Bytes.toBytes("bar.2"))))) + .groupByKey() + + val hbaseContext = new HBaseContext(sc, config) + + testFolder.create() + val stagingFolder = testFolder.newFolder() + + hbaseContext.bulkLoadThinRows[(String, Iterable[(Array[Byte], Array[Byte], Array[Byte])])]( + rdd, + TableName.valueOf(tableName), + t => { + val rowKey = Bytes.toBytes(t._1) + + val familyQualifiersValues = new FamiliesQualifiersValues + t._2.foreach( + f => { + val family: Array[Byte] = f._1 + val qualifier = f._2 + val value: Array[Byte] = f._3 + + familyQualifiersValues += (family, qualifier, value) + }) + (new ByteArrayWrapper(rowKey), familyQualifiersValues) + }, + stagingFolder.getPath) + + val fs = FileSystem.get(config) + assert(fs.listStatus(new Path(stagingFolder.getPath)).length == 2) + + val conn = ConnectionFactory.createConnection(config) + + val load = new LoadIncrementalHFiles(config) + val table = conn.getTable(TableName.valueOf(tableName)) + try { + load.doBulkLoad( + new Path(stagingFolder.getPath), + conn.getAdmin, + table, + conn.getRegionLocator(TableName.valueOf(tableName))) + + val cells5 = table.get(new Get(Bytes.toBytes("5"))).listCells() + assert(cells5.size == 1) + assert(Bytes.toString(CellUtil.cloneValue(cells5.get(0))).equals("foo3")) + assert(Bytes.toString(CellUtil.cloneFamily(cells5.get(0))).equals("f1")) + assert(Bytes.toString(CellUtil.cloneQualifier(cells5.get(0))).equals("a")) + + val cells4 = table.get(new Get(Bytes.toBytes("4"))).listCells() + assert(cells4.size == 2) + assert(Bytes.toString(CellUtil.cloneValue(cells4.get(0))).equals("foo.1")) + assert(Bytes.toString(CellUtil.cloneFamily(cells4.get(0))).equals("f1")) + assert(Bytes.toString(CellUtil.cloneQualifier(cells4.get(0))).equals("a")) + assert(Bytes.toString(CellUtil.cloneValue(cells4.get(1))).equals("foo.2")) + assert(Bytes.toString(CellUtil.cloneFamily(cells4.get(1))).equals("f2")) + assert(Bytes.toString(CellUtil.cloneQualifier(cells4.get(1))).equals("b")) + + val cells3 = table.get(new Get(Bytes.toBytes("3"))).listCells() + assert(cells3.size == 3) + assert(Bytes.toString(CellUtil.cloneValue(cells3.get(0))).equals("foo2.c")) + assert(Bytes.toString(CellUtil.cloneFamily(cells3.get(0))).equals("f1")) + assert(Bytes.toString(CellUtil.cloneQualifier(cells3.get(0))).equals("a")) + assert(Bytes.toString(CellUtil.cloneValue(cells3.get(1))).equals("foo2.b")) + assert(Bytes.toString(CellUtil.cloneFamily(cells3.get(1))).equals("f2")) + assert(Bytes.toString(CellUtil.cloneQualifier(cells3.get(1))).equals("a")) + assert(Bytes.toString(CellUtil.cloneValue(cells3.get(2))).equals("foo2.a")) + assert(Bytes.toString(CellUtil.cloneFamily(cells3.get(2))).equals("f2")) 
+ assert(Bytes.toString(CellUtil.cloneQualifier(cells3.get(2))).equals("b")) + + val cells2 = table.get(new Get(Bytes.toBytes("2"))).listCells() + assert(cells2.size == 2) + assert(Bytes.toString(CellUtil.cloneValue(cells2.get(0))).equals("bar.1")) + assert(Bytes.toString(CellUtil.cloneFamily(cells2.get(0))).equals("f1")) + assert(Bytes.toString(CellUtil.cloneQualifier(cells2.get(0))).equals("a")) + assert(Bytes.toString(CellUtil.cloneValue(cells2.get(1))).equals("bar.2")) + assert(Bytes.toString(CellUtil.cloneFamily(cells2.get(1))).equals("f1")) + assert(Bytes.toString(CellUtil.cloneQualifier(cells2.get(1))).equals("b")) + + val cells1 = table.get(new Get(Bytes.toBytes("1"))).listCells() + assert(cells1.size == 1) + assert(Bytes.toString(CellUtil.cloneValue(cells1.get(0))).equals("foo1")) + assert(Bytes.toString(CellUtil.cloneFamily(cells1.get(0))).equals("f1")) + assert(Bytes.toString(CellUtil.cloneQualifier(cells1.get(0))).equals("a")) + + } finally { + table.close() + val admin = ConnectionFactory.createConnection(config).getAdmin + try { + admin.disableTable(TableName.valueOf(tableName)) + admin.deleteTable(TableName.valueOf(tableName)) + } finally { + admin.close() + } + fs.delete(new Path(stagingFolder.getPath), true) + + testFolder.delete() + + } + } + + test( + "Thin Row Bulk Load: Test HBase client: Test Roll Over and " + + "using an implicit call to bulk load") { + val config = TEST_UTIL.getConfiguration + + logInfo(" - creating table " + tableName) + TEST_UTIL.createTable( + TableName.valueOf(tableName), + Array(Bytes.toBytes(columnFamily1), Bytes.toBytes(columnFamily2))) + + // There are a number of tests in here. + // 1. Row keys are not in order + // 2. Qualifiers are not in order + // 3. Column Families are not in order + // 4. There are tests for records in one column family and some in two column families + // 5. 
There are records will a single qualifier and some with two + val rdd = sc + .parallelize( + Array[(String,(Array[Byte], Array[Byte], Array[Byte]))]( + ("1", (Bytes.toBytes(columnFamily1), Bytes.toBytes("a"), Bytes.toBytes("foo1"))), + ("3", (Bytes.toBytes(columnFamily1), Bytes.toBytes("b"), Bytes.toBytes("foo2.b"))), + ("3", (Bytes.toBytes(columnFamily1), Bytes.toBytes("a"), Bytes.toBytes("foo2.a"))), + ("3", (Bytes.toBytes(columnFamily1), Bytes.toBytes("c"), Bytes.toBytes("foo2.c"))), + ("5", (Bytes.toBytes(columnFamily1), Bytes.toBytes("a"), Bytes.toBytes("foo3"))), + ("4", (Bytes.toBytes(columnFamily1), Bytes.toBytes("a"), Bytes.toBytes("foo.1"))), + ("4", (Bytes.toBytes(columnFamily1), Bytes.toBytes("b"), Bytes.toBytes("foo.2"))), + ("2", (Bytes.toBytes(columnFamily1), Bytes.toBytes("a"), Bytes.toBytes("bar.1"))), + ("2", (Bytes.toBytes(columnFamily1), Bytes.toBytes("b"), Bytes.toBytes("bar.2"))))) + .groupByKey() + + val hbaseContext = new HBaseContext(sc, config) + + testFolder.create() + val stagingFolder = testFolder.newFolder() + + rdd.hbaseBulkLoadThinRows( + hbaseContext, + TableName.valueOf(tableName), + t => { + val rowKey = t._1 + + val familyQualifiersValues = new FamiliesQualifiersValues + t._2.foreach( + f => { + val family: Array[Byte] = f._1 + val qualifier = f._2 + val value: Array[Byte] = f._3 + + familyQualifiersValues += (family, qualifier, value) + }) + (new ByteArrayWrapper(Bytes.toBytes(rowKey)), familyQualifiersValues) + }, + stagingFolder.getPath, + new java.util.HashMap[Array[Byte], FamilyHFileWriteOptions], + compactionExclude = false, + 20) + + val fs = FileSystem.get(config) + assert(fs.listStatus(new Path(stagingFolder.getPath)).length == 1) + + assert(fs.listStatus(new Path(stagingFolder.getPath + "/f1")).length == 5) + + val conn = ConnectionFactory.createConnection(config) + + val load = new LoadIncrementalHFiles(config) + val table = conn.getTable(TableName.valueOf(tableName)) + try { + load.doBulkLoad( + new Path(stagingFolder.getPath), + conn.getAdmin, + table, + conn.getRegionLocator(TableName.valueOf(tableName))) + + val cells5 = table.get(new Get(Bytes.toBytes("5"))).listCells() + assert(cells5.size == 1) + assert(Bytes.toString(CellUtil.cloneValue(cells5.get(0))).equals("foo3")) + assert(Bytes.toString(CellUtil.cloneFamily(cells5.get(0))).equals("f1")) + assert(Bytes.toString(CellUtil.cloneQualifier(cells5.get(0))).equals("a")) + + val cells4 = table.get(new Get(Bytes.toBytes("4"))).listCells() + assert(cells4.size == 2) + assert(Bytes.toString(CellUtil.cloneValue(cells4.get(0))).equals("foo.1")) + assert(Bytes.toString(CellUtil.cloneFamily(cells4.get(0))).equals("f1")) + assert(Bytes.toString(CellUtil.cloneQualifier(cells4.get(0))).equals("a")) + assert(Bytes.toString(CellUtil.cloneValue(cells4.get(1))).equals("foo.2")) + assert(Bytes.toString(CellUtil.cloneFamily(cells4.get(1))).equals("f1")) + assert(Bytes.toString(CellUtil.cloneQualifier(cells4.get(1))).equals("b")) + + val cells3 = table.get(new Get(Bytes.toBytes("3"))).listCells() + assert(cells3.size == 3) + assert(Bytes.toString(CellUtil.cloneValue(cells3.get(0))).equals("foo2.a")) + assert(Bytes.toString(CellUtil.cloneFamily(cells3.get(0))).equals("f1")) + assert(Bytes.toString(CellUtil.cloneQualifier(cells3.get(0))).equals("a")) + assert(Bytes.toString(CellUtil.cloneValue(cells3.get(1))).equals("foo2.b")) + assert(Bytes.toString(CellUtil.cloneFamily(cells3.get(1))).equals("f1")) + assert(Bytes.toString(CellUtil.cloneQualifier(cells3.get(1))).equals("b")) + 
assert(Bytes.toString(CellUtil.cloneValue(cells3.get(2))).equals("foo2.c")) + assert(Bytes.toString(CellUtil.cloneFamily(cells3.get(2))).equals("f1")) + assert(Bytes.toString(CellUtil.cloneQualifier(cells3.get(2))).equals("c")) + + val cells2 = table.get(new Get(Bytes.toBytes("2"))).listCells() + assert(cells2.size == 2) + assert(Bytes.toString(CellUtil.cloneValue(cells2.get(0))).equals("bar.1")) + assert(Bytes.toString(CellUtil.cloneFamily(cells2.get(0))).equals("f1")) + assert(Bytes.toString(CellUtil.cloneQualifier(cells2.get(0))).equals("a")) + assert(Bytes.toString(CellUtil.cloneValue(cells2.get(1))).equals("bar.2")) + assert(Bytes.toString(CellUtil.cloneFamily(cells2.get(1))).equals("f1")) + assert(Bytes.toString(CellUtil.cloneQualifier(cells2.get(1))).equals("b")) + + val cells1 = table.get(new Get(Bytes.toBytes("1"))).listCells() + assert(cells1.size == 1) + assert(Bytes.toString(CellUtil.cloneValue(cells1.get(0))).equals("foo1")) + assert(Bytes.toString(CellUtil.cloneFamily(cells1.get(0))).equals("f1")) + assert(Bytes.toString(CellUtil.cloneQualifier(cells1.get(0))).equals("a")) + + } finally { + table.close() + val admin = ConnectionFactory.createConnection(config).getAdmin + try { + admin.disableTable(TableName.valueOf(tableName)) + admin.deleteTable(TableName.valueOf(tableName)) + } finally { + admin.close() + } + fs.delete(new Path(stagingFolder.getPath), true) + + testFolder.delete() + } + } + + test( + "Thin Row Bulk Load: Test multi family and multi column tests" + + " with one column family with custom configs plus multi region") { + val config = TEST_UTIL.getConfiguration + + val splitKeys: Array[Array[Byte]] = new Array[Array[Byte]](2) + splitKeys(0) = Bytes.toBytes("2") + splitKeys(1) = Bytes.toBytes("4") + + logInfo(" - creating table " + tableName) + TEST_UTIL.createTable( + TableName.valueOf(tableName), + Array(Bytes.toBytes(columnFamily1), Bytes.toBytes(columnFamily2)), + splitKeys) + + // There are a number of tests in here. + // 1. Row keys are not in order + // 2. Qualifiers are not in order + // 3. Column Families are not in order + // 4. There are tests for records in one column family and some in two column families + // 5. 
There are records will a single qualifier and some with two + val rdd = sc + .parallelize( + Array[(String, (Array[Byte], Array[Byte], Array[Byte]))]( + ("1", (Bytes.toBytes(columnFamily1), Bytes.toBytes("a"), Bytes.toBytes("foo1"))), + ("3", (Bytes.toBytes(columnFamily2), Bytes.toBytes("b"), Bytes.toBytes("foo2.a"))), + ("3", (Bytes.toBytes(columnFamily2), Bytes.toBytes("a"), Bytes.toBytes("foo2.b"))), + ("3", (Bytes.toBytes(columnFamily1), Bytes.toBytes("a"), Bytes.toBytes("foo2.c"))), + ("5", (Bytes.toBytes(columnFamily1), Bytes.toBytes("a"), Bytes.toBytes("foo3"))), + ("4", (Bytes.toBytes(columnFamily1), Bytes.toBytes("a"), Bytes.toBytes("foo.1"))), + ("4", (Bytes.toBytes(columnFamily2), Bytes.toBytes("b"), Bytes.toBytes("foo.2"))), + ("2", (Bytes.toBytes(columnFamily1), Bytes.toBytes("a"), Bytes.toBytes("bar.1"))), + ("2", (Bytes.toBytes(columnFamily1), Bytes.toBytes("b"), Bytes.toBytes("bar.2"))))) + .groupByKey() + + val hbaseContext = new HBaseContext(sc, config) + + testFolder.create() + val stagingFolder = testFolder.newFolder() + + val familyHBaseWriterOptions = new java.util.HashMap[Array[Byte], FamilyHFileWriteOptions] + + val f1Options = new FamilyHFileWriteOptions("GZ", "ROW", 128, "PREFIX") + + familyHBaseWriterOptions.put(Bytes.toBytes(columnFamily1), f1Options) + + hbaseContext.bulkLoadThinRows[(String, Iterable[(Array[Byte], Array[Byte], Array[Byte])])]( + rdd, + TableName.valueOf(tableName), + t => { + val rowKey = t._1 + + val familyQualifiersValues = new FamiliesQualifiersValues + t._2.foreach( + f => { + val family: Array[Byte] = f._1 + val qualifier = f._2 + val value: Array[Byte] = f._3 + + familyQualifiersValues += (family, qualifier, value) + }) + (new ByteArrayWrapper(Bytes.toBytes(rowKey)), familyQualifiersValues) + }, + stagingFolder.getPath, + familyHBaseWriterOptions, + compactionExclude = false, + HConstants.DEFAULT_MAX_FILE_SIZE) + + val fs = FileSystem.get(config) + assert(fs.listStatus(new Path(stagingFolder.getPath)).length == 2) + + val f1FileList = fs.listStatus(new Path(stagingFolder.getPath + "/f1")) + for (i <- 0 until f1FileList.length) { + val reader = + HFile.createReader(fs, f1FileList(i).getPath, new CacheConfig(config), true, config) + assert(reader.getTrailer.getCompressionCodec().getName.equals("gz")) + assert(reader.getDataBlockEncoding.name().equals("PREFIX")) + } + + assert(3 == f1FileList.length) + + val f2FileList = fs.listStatus(new Path(stagingFolder.getPath + "/f2")) + for (i <- 0 until f2FileList.length) { + val reader = + HFile.createReader(fs, f2FileList(i).getPath, new CacheConfig(config), true, config) + assert(reader.getTrailer.getCompressionCodec().getName.equals("none")) + assert(reader.getDataBlockEncoding.name().equals("NONE")) + } + + assert(2 == f2FileList.length) + + val conn = ConnectionFactory.createConnection(config) + + val load = new LoadIncrementalHFiles(config) + val table = conn.getTable(TableName.valueOf(tableName)) + try { + load.doBulkLoad( + new Path(stagingFolder.getPath), + conn.getAdmin, + table, + conn.getRegionLocator(TableName.valueOf(tableName))) + + val cells5 = table.get(new Get(Bytes.toBytes("5"))).listCells() + assert(cells5.size == 1) + assert(Bytes.toString(CellUtil.cloneValue(cells5.get(0))).equals("foo3")) + assert(Bytes.toString(CellUtil.cloneFamily(cells5.get(0))).equals("f1")) + assert(Bytes.toString(CellUtil.cloneQualifier(cells5.get(0))).equals("a")) + + val cells4 = table.get(new Get(Bytes.toBytes("4"))).listCells() + assert(cells4.size == 2) + 
assert(Bytes.toString(CellUtil.cloneValue(cells4.get(0))).equals("foo.1")) + assert(Bytes.toString(CellUtil.cloneFamily(cells4.get(0))).equals("f1")) + assert(Bytes.toString(CellUtil.cloneQualifier(cells4.get(0))).equals("a")) + assert(Bytes.toString(CellUtil.cloneValue(cells4.get(1))).equals("foo.2")) + assert(Bytes.toString(CellUtil.cloneFamily(cells4.get(1))).equals("f2")) + assert(Bytes.toString(CellUtil.cloneQualifier(cells4.get(1))).equals("b")) + + val cells3 = table.get(new Get(Bytes.toBytes("3"))).listCells() + assert(cells3.size == 3) + assert(Bytes.toString(CellUtil.cloneValue(cells3.get(0))).equals("foo2.c")) + assert(Bytes.toString(CellUtil.cloneFamily(cells3.get(0))).equals("f1")) + assert(Bytes.toString(CellUtil.cloneQualifier(cells3.get(0))).equals("a")) + assert(Bytes.toString(CellUtil.cloneValue(cells3.get(1))).equals("foo2.b")) + assert(Bytes.toString(CellUtil.cloneFamily(cells3.get(1))).equals("f2")) + assert(Bytes.toString(CellUtil.cloneQualifier(cells3.get(1))).equals("a")) + assert(Bytes.toString(CellUtil.cloneValue(cells3.get(2))).equals("foo2.a")) + assert(Bytes.toString(CellUtil.cloneFamily(cells3.get(2))).equals("f2")) + assert(Bytes.toString(CellUtil.cloneQualifier(cells3.get(2))).equals("b")) + + val cells2 = table.get(new Get(Bytes.toBytes("2"))).listCells() + assert(cells2.size == 2) + assert(Bytes.toString(CellUtil.cloneValue(cells2.get(0))).equals("bar.1")) + assert(Bytes.toString(CellUtil.cloneFamily(cells2.get(0))).equals("f1")) + assert(Bytes.toString(CellUtil.cloneQualifier(cells2.get(0))).equals("a")) + assert(Bytes.toString(CellUtil.cloneValue(cells2.get(1))).equals("bar.2")) + assert(Bytes.toString(CellUtil.cloneFamily(cells2.get(1))).equals("f1")) + assert(Bytes.toString(CellUtil.cloneQualifier(cells2.get(1))).equals("b")) + + val cells1 = table.get(new Get(Bytes.toBytes("1"))).listCells() + assert(cells1.size == 1) + assert(Bytes.toString(CellUtil.cloneValue(cells1.get(0))).equals("foo1")) + assert(Bytes.toString(CellUtil.cloneFamily(cells1.get(0))).equals("f1")) + assert(Bytes.toString(CellUtil.cloneQualifier(cells1.get(0))).equals("a")) + + } finally { + table.close() + val admin = ConnectionFactory.createConnection(config).getAdmin + try { + admin.disableTable(TableName.valueOf(tableName)) + admin.deleteTable(TableName.valueOf(tableName)) + } finally { + admin.close() + } + fs.delete(new Path(stagingFolder.getPath), true) + + testFolder.delete() + + } + } +} diff --git a/spark4/hbase-spark4/src/test/scala/org/apache/hadoop/hbase/spark/DefaultSourceSuite.scala b/spark4/hbase-spark4/src/test/scala/org/apache/hadoop/hbase/spark/DefaultSourceSuite.scala new file mode 100644 index 00000000..58e4d37f --- /dev/null +++ b/spark4/hbase-spark4/src/test/scala/org/apache/hadoop/hbase/spark/DefaultSourceSuite.scala @@ -0,0 +1,1365 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.hadoop.hbase.spark + +import java.sql.{Date, Timestamp} +import org.apache.avro.Schema +import org.apache.avro.generic.GenericData +import org.apache.hadoop.hbase.{HBaseTestingUtility, TableName} +import org.apache.hadoop.hbase.client.{ConnectionFactory, Put} +import org.apache.hadoop.hbase.spark.datasources.{HBaseSparkConf, HBaseTableCatalog} +import org.apache.hadoop.hbase.util.Bytes +import org.apache.spark.{SparkConf, SparkContext} +import org.apache.spark.sql.{DataFrame, SQLContext} +import org.apache.spark.sql.functions._ +import org.scalatest.funsuite.AnyFunSuite +import org.scalatest.{BeforeAndAfterAll, BeforeAndAfterEach} +import org.xml.sax.SAXParseException + +import scala.collection.IterableOnce.iterableOnceExtensionMethods + +case class HBaseRecord( + col0: String, + col1: Boolean, + col2: Double, + col3: Float, + col4: Int, + col5: Long, + col6: Short, + col7: String, + col8: Byte) + +object HBaseRecord { + def apply(i: Int, t: String): HBaseRecord = { + val s = s"""row${"%03d".format(i)}""" + HBaseRecord( + s, + i % 2 == 0, + i.toDouble, + i.toFloat, + i, + i.toLong, + i.toShort, + s"String$i: $t", + i.toByte) + } +} + +case class AvroHBaseKeyRecord(col0: Array[Byte], col1: Array[Byte]) + +object AvroHBaseKeyRecord { + val schemaString = + s"""{"namespace": "example.avro", + | "type": "record", "name": "User", + | "fields": [ {"name": "name", "type": "string"}, + | {"name": "favorite_number", "type": ["int", "null"]}, + | {"name": "favorite_color", "type": ["string", "null"]} ] }""".stripMargin + + val avroSchema: Schema = { + val p = new Schema.Parser + p.parse(schemaString) + } + + def apply(i: Int): AvroHBaseKeyRecord = { + val user = new GenericData.Record(avroSchema); + user.put("name", s"name${"%03d".format(i)}") + user.put("favorite_number", i) + user.put("favorite_color", s"color${"%03d".format(i)}") + val avroByte = AvroSerdes.serialize(user, avroSchema) + AvroHBaseKeyRecord(avroByte, avroByte) + } +} + +class DefaultSourceSuite + extends AnyFunSuite + with BeforeAndAfterEach + with BeforeAndAfterAll + with Logging { + @transient var sc: SparkContext = null + var TEST_UTIL: HBaseTestingUtility = new HBaseTestingUtility + + val t1TableName = "t1" + val t2TableName = "t2" + val t3TableName = "t3" + val columnFamily = "c" + + val timestamp = 1234567890000L + + var sqlContext: SQLContext = null + var df: DataFrame = null + + override def beforeAll() { + + TEST_UTIL.startMiniCluster + + logInfo(" - minicluster started") + try + TEST_UTIL.deleteTable(TableName.valueOf(t1TableName)) + catch { + case e: Exception => logInfo(" - no table " + t1TableName + " found") + } + try + TEST_UTIL.deleteTable(TableName.valueOf(t2TableName)) + catch { + case e: Exception => logInfo(" - no table " + t2TableName + " found") + } + try + TEST_UTIL.deleteTable(TableName.valueOf(t3TableName)) + catch { + case e: Exception => logInfo(" - no table " + t3TableName + " found") + } + + logInfo(" - creating table " + t1TableName) + TEST_UTIL.createTable(TableName.valueOf(t1TableName), Bytes.toBytes(columnFamily)) + logInfo(" - created table") + logInfo(" - creating table " + t2TableName) + TEST_UTIL.createTable(TableName.valueOf(t2TableName), Bytes.toBytes(columnFamily)) + logInfo(" - created table") + logInfo(" - creating table " + t3TableName) + TEST_UTIL.createTable(TableName.valueOf(t3TableName), Bytes.toBytes(columnFamily)) + logInfo(" - created table") + + val 
sparkConf = new SparkConf + sparkConf.set(HBaseSparkConf.QUERY_CACHEBLOCKS, "true") + sparkConf.set(HBaseSparkConf.QUERY_BATCHSIZE, "100") + sparkConf.set(HBaseSparkConf.QUERY_CACHEDROWS, "100") + + sc = new SparkContext("local", "test", sparkConf) + + val connection = ConnectionFactory.createConnection(TEST_UTIL.getConfiguration) + try { + val t1Table = connection.getTable(TableName.valueOf(t1TableName)) + + try { + var put = new Put(Bytes.toBytes("get1")) + put.addColumn(Bytes.toBytes(columnFamily), Bytes.toBytes("a"), Bytes.toBytes("foo1")) + put.addColumn(Bytes.toBytes(columnFamily), Bytes.toBytes("b"), Bytes.toBytes("1")) + put.addColumn(Bytes.toBytes(columnFamily), Bytes.toBytes("i"), Bytes.toBytes(1)) + t1Table.put(put) + put = new Put(Bytes.toBytes("get2")) + put.addColumn(Bytes.toBytes(columnFamily), Bytes.toBytes("a"), Bytes.toBytes("foo2")) + put.addColumn(Bytes.toBytes(columnFamily), Bytes.toBytes("b"), Bytes.toBytes("4")) + put.addColumn(Bytes.toBytes(columnFamily), Bytes.toBytes("i"), Bytes.toBytes(4)) + put.addColumn(Bytes.toBytes(columnFamily), Bytes.toBytes("z"), Bytes.toBytes("FOO")) + t1Table.put(put) + put = new Put(Bytes.toBytes("get3")) + put.addColumn(Bytes.toBytes(columnFamily), Bytes.toBytes("a"), Bytes.toBytes("foo3")) + put.addColumn(Bytes.toBytes(columnFamily), Bytes.toBytes("b"), Bytes.toBytes("8")) + put.addColumn(Bytes.toBytes(columnFamily), Bytes.toBytes("i"), Bytes.toBytes(8)) + t1Table.put(put) + put = new Put(Bytes.toBytes("get4")) + put.addColumn(Bytes.toBytes(columnFamily), Bytes.toBytes("a"), Bytes.toBytes("foo4")) + put.addColumn(Bytes.toBytes(columnFamily), Bytes.toBytes("b"), Bytes.toBytes("10")) + put.addColumn(Bytes.toBytes(columnFamily), Bytes.toBytes("i"), Bytes.toBytes(10)) + put.addColumn(Bytes.toBytes(columnFamily), Bytes.toBytes("z"), Bytes.toBytes("BAR")) + t1Table.put(put) + put = new Put(Bytes.toBytes("get5")) + put.addColumn(Bytes.toBytes(columnFamily), Bytes.toBytes("a"), Bytes.toBytes("foo5")) + put.addColumn(Bytes.toBytes(columnFamily), Bytes.toBytes("b"), Bytes.toBytes("8")) + put.addColumn(Bytes.toBytes(columnFamily), Bytes.toBytes("i"), Bytes.toBytes(8)) + t1Table.put(put) + } finally { + t1Table.close() + } + + val t2Table = connection.getTable(TableName.valueOf(t2TableName)) + + try { + var put = new Put(Bytes.toBytes(1)) + put.addColumn(Bytes.toBytes(columnFamily), Bytes.toBytes("a"), Bytes.toBytes("foo1")) + put.addColumn(Bytes.toBytes(columnFamily), Bytes.toBytes("b"), Bytes.toBytes("1")) + put.addColumn(Bytes.toBytes(columnFamily), Bytes.toBytes("i"), Bytes.toBytes(1)) + t2Table.put(put) + put = new Put(Bytes.toBytes(2)) + put.addColumn(Bytes.toBytes(columnFamily), Bytes.toBytes("a"), Bytes.toBytes("foo2")) + put.addColumn(Bytes.toBytes(columnFamily), Bytes.toBytes("b"), Bytes.toBytes("4")) + put.addColumn(Bytes.toBytes(columnFamily), Bytes.toBytes("i"), Bytes.toBytes(4)) + put.addColumn(Bytes.toBytes(columnFamily), Bytes.toBytes("z"), Bytes.toBytes("FOO")) + t2Table.put(put) + put = new Put(Bytes.toBytes(3)) + put.addColumn(Bytes.toBytes(columnFamily), Bytes.toBytes("a"), Bytes.toBytes("foo3")) + put.addColumn(Bytes.toBytes(columnFamily), Bytes.toBytes("b"), Bytes.toBytes("8")) + put.addColumn(Bytes.toBytes(columnFamily), Bytes.toBytes("i"), Bytes.toBytes(8)) + t2Table.put(put) + put = new Put(Bytes.toBytes(4)) + put.addColumn(Bytes.toBytes(columnFamily), Bytes.toBytes("a"), Bytes.toBytes("foo4")) + put.addColumn(Bytes.toBytes(columnFamily), Bytes.toBytes("b"), Bytes.toBytes("10")) + 
put.addColumn(Bytes.toBytes(columnFamily), Bytes.toBytes("i"), Bytes.toBytes(10)) + put.addColumn(Bytes.toBytes(columnFamily), Bytes.toBytes("z"), Bytes.toBytes("BAR")) + t2Table.put(put) + put = new Put(Bytes.toBytes(5)) + put.addColumn(Bytes.toBytes(columnFamily), Bytes.toBytes("a"), Bytes.toBytes("foo5")) + put.addColumn(Bytes.toBytes(columnFamily), Bytes.toBytes("b"), Bytes.toBytes("8")) + put.addColumn(Bytes.toBytes(columnFamily), Bytes.toBytes("i"), Bytes.toBytes(8)) + t2Table.put(put) + } finally { + t2Table.close() + } + + val t3Table = connection.getTable(TableName.valueOf(t3TableName)) + + try { + val put = new Put(Bytes.toBytes("row")) + put.addColumn( + Bytes.toBytes(columnFamily), + Bytes.toBytes("binary"), + Array(1.toByte, 2.toByte, 3.toByte)) + put.addColumn(Bytes.toBytes(columnFamily), Bytes.toBytes("boolean"), Bytes.toBytes(true)) + put.addColumn(Bytes.toBytes(columnFamily), Bytes.toBytes("byte"), Array(127.toByte)) + put.addColumn( + Bytes.toBytes(columnFamily), + Bytes.toBytes("short"), + Bytes.toBytes(32767.toShort)) + put.addColumn(Bytes.toBytes(columnFamily), Bytes.toBytes("int"), Bytes.toBytes(1000000)) + put.addColumn( + Bytes.toBytes(columnFamily), + Bytes.toBytes("long"), + Bytes.toBytes(10000000000L)) + put.addColumn(Bytes.toBytes(columnFamily), Bytes.toBytes("float"), Bytes.toBytes(0.5f)) + put.addColumn(Bytes.toBytes(columnFamily), Bytes.toBytes("double"), Bytes.toBytes(0.125)) + put.addColumn(Bytes.toBytes(columnFamily), Bytes.toBytes("date"), Bytes.toBytes(timestamp)) + put.addColumn( + Bytes.toBytes(columnFamily), + Bytes.toBytes("timestamp"), + Bytes.toBytes(timestamp)) + put.addColumn(Bytes.toBytes(columnFamily), Bytes.toBytes("string"), Bytes.toBytes("string")) + t3Table.put(put) + } finally { + t3Table.close() + } + } finally { + connection.close() + } + + def hbaseTable1Catalog = s"""{ + |"table":{"namespace":"default", "name":"t1"}, + |"rowkey":"key", + |"columns":{ + |"KEY_FIELD":{"cf":"rowkey", "col":"key", "type":"string"}, + |"A_FIELD":{"cf":"c", "col":"a", "type":"string"}, + |"B_FIELD":{"cf":"c", "col":"b", "type":"string"} + |} + |}""".stripMargin + + new HBaseContext(sc, TEST_UTIL.getConfiguration) + sqlContext = new SQLContext(sc) + + df = sqlContext.load( + "org.apache.hadoop.hbase.spark", + Map(HBaseTableCatalog.tableCatalog -> hbaseTable1Catalog)) + + df.registerTempTable("hbaseTable1") + + def hbaseTable2Catalog = s"""{ + |"table":{"namespace":"default", "name":"t2"}, + |"rowkey":"key", + |"columns":{ + |"KEY_FIELD":{"cf":"rowkey", "col":"key", "type":"int"}, + |"A_FIELD":{"cf":"c", "col":"a", "type":"string"}, + |"B_FIELD":{"cf":"c", "col":"b", "type":"string"} + |} + |}""".stripMargin + + df = sqlContext.load( + "org.apache.hadoop.hbase.spark", + Map(HBaseTableCatalog.tableCatalog -> hbaseTable2Catalog)) + + df.registerTempTable("hbaseTable2") + } + + override def afterAll() { + TEST_UTIL.deleteTable(TableName.valueOf(t1TableName)) + logInfo("shuting down minicluster") + TEST_UTIL.shutdownMiniCluster() + + sc.stop() + } + + override def beforeEach(): Unit = { + DefaultSourceStaticUtils.lastFiveExecutionRules.clear() + } + + /** + * A example of query three fields and also only using rowkey points for the filter + */ + test("Test rowKey point only rowKey query") { + val results = sqlContext + .sql( + "SELECT KEY_FIELD, B_FIELD, A_FIELD FROM hbaseTable1 " + + "WHERE " + + "(KEY_FIELD = 'get1' or KEY_FIELD = 'get2' or KEY_FIELD = 'get3')") + .take(10) + + val executionRules = DefaultSourceStaticUtils.lastFiveExecutionRules.poll() + + 
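+    // DefaultSourceStaticUtils.lastFiveExecutionRules is a small queue the connector fills,
+    // for test purposes, with the plan it pushed down for each scan (the dynamic logic
+    // expression plus the row-key points and ranges); the tests poll it to verify what
+    // actually reached HBase.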
assert(results.length == 3) + + assert( + executionRules.dynamicLogicExpression.toExpressionString.equals( + "( ( KEY_FIELD == 0 OR KEY_FIELD == 1 ) OR KEY_FIELD == 2 )")) + + assert(executionRules.rowKeyFilter.points.size == 3) + assert(executionRules.rowKeyFilter.ranges.size == 0) + } + + /** + * A example of query three fields and also only using rowkey points for the filter, + * some rowkey points are duplicate. + */ + test("Test rowKey point only rowKey query, which contains duplicate rowkey") { + val results = sqlContext + .sql( + "SELECT KEY_FIELD, B_FIELD, A_FIELD FROM hbaseTable1 " + + "WHERE " + + "(KEY_FIELD = 'get1' or KEY_FIELD = 'get2' or KEY_FIELD = 'get1')") + .take(10) + val executionRules = DefaultSourceStaticUtils.lastFiveExecutionRules.poll() + assert(results.length == 2) + assert( + executionRules.dynamicLogicExpression.toExpressionString.equals( + "( KEY_FIELD == 0 OR KEY_FIELD == 1 )")) + assert(executionRules.rowKeyFilter.points.size == 2) + assert(executionRules.rowKeyFilter.ranges.size == 0) + } + + /** + * A example of query three fields and also only using cell points for the filter + */ + test("Test cell point only rowKey query") { + val results = sqlContext + .sql( + "SELECT KEY_FIELD, B_FIELD, A_FIELD FROM hbaseTable1 " + + "WHERE " + + "(B_FIELD = '4' or B_FIELD = '10' or A_FIELD = 'foo1')") + .take(10) + + val executionRules = DefaultSourceStaticUtils.lastFiveExecutionRules.poll() + + assert(results.length == 3) + + assert( + executionRules.dynamicLogicExpression.toExpressionString.equals( + "( ( B_FIELD == 0 OR B_FIELD == 1 ) OR A_FIELD == 2 )")) + } + + /** + * A example of a OR merge between to ranges the result is one range + * Also an example of less then and greater then + */ + test("Test two range rowKey query") { + val results = sqlContext + .sql( + "SELECT KEY_FIELD, B_FIELD, A_FIELD FROM hbaseTable1 " + + "WHERE " + + "( KEY_FIELD < 'get2' or KEY_FIELD > 'get3')") + .take(10) + + val executionRules = DefaultSourceStaticUtils.lastFiveExecutionRules.poll() + + assert(results.length == 3) + + assert( + executionRules.dynamicLogicExpression.toExpressionString.equals( + "( KEY_FIELD < 0 OR KEY_FIELD > 1 )")) + + assert(executionRules.rowKeyFilter.points.size == 0) + assert(executionRules.rowKeyFilter.ranges.size == 2) + + val scanRange1 = executionRules.rowKeyFilter.ranges.toList(0) + assert(Bytes.equals(scanRange1.lowerBound, Bytes.toBytes(""))) + assert(Bytes.equals(scanRange1.upperBound, Bytes.toBytes("get2"))) + assert(scanRange1.isLowerBoundEqualTo) + assert(!scanRange1.isUpperBoundEqualTo) + + val scanRange2 = executionRules.rowKeyFilter.ranges.toList(1) + assert(Bytes.equals(scanRange2.lowerBound, Bytes.toBytes("get3"))) + assert(scanRange2.upperBound == null) + assert(!scanRange2.isLowerBoundEqualTo) + assert(scanRange2.isUpperBoundEqualTo) + } + + /** + * A example of a OR merge between to ranges the result is one range + * Also an example of less then and greater then + * + * This example makes sure the code works for a int rowKey + */ + test("Test two range rowKey query where the rowKey is Int and there is a range over lap") { + val results = sqlContext + .sql( + "SELECT KEY_FIELD, B_FIELD, A_FIELD FROM hbaseTable2 " + + "WHERE " + + "( KEY_FIELD < 4 or KEY_FIELD > 2)") + .take(10) + + val executionRules = DefaultSourceStaticUtils.lastFiveExecutionRules.poll() + + assert( + executionRules.dynamicLogicExpression.toExpressionString.equals( + "( KEY_FIELD < 0 OR KEY_FIELD > 1 )")) + + assert(executionRules.rowKeyFilter.points.size == 0) + 
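+    // Even though KEY_FIELD < 4 OR KEY_FIELD > 2 logically covers every row, the pushdown
+    // still records the two underlying ranges; the result count of 5 below confirms that
+    // every row of t2 comes back.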
assert(executionRules.rowKeyFilter.ranges.size == 2) + assert(results.length == 5) + } + + /** + * A example of a OR merge between to ranges the result is two ranges + * Also an example of less then and greater then + * + * This example makes sure the code works for a int rowKey + */ + test("Test two range rowKey query where the rowKey is Int and the ranges don't over lap") { + val results = sqlContext + .sql( + "SELECT KEY_FIELD, B_FIELD, A_FIELD FROM hbaseTable2 " + + "WHERE " + + "( KEY_FIELD < 2 or KEY_FIELD > 4)") + .take(10) + + val executionRules = DefaultSourceStaticUtils.lastFiveExecutionRules.poll() + + assert( + executionRules.dynamicLogicExpression.toExpressionString.equals( + "( KEY_FIELD < 0 OR KEY_FIELD > 1 )")) + + assert(executionRules.rowKeyFilter.points.size == 0) + + assert(executionRules.rowKeyFilter.ranges.size == 3) + + val scanRange1 = executionRules.rowKeyFilter.ranges.toList(0) + assert(Bytes.equals(scanRange1.upperBound, Bytes.toBytes(2))) + assert(scanRange1.isLowerBoundEqualTo) + assert(!scanRange1.isUpperBoundEqualTo) + + val scanRange2 = executionRules.rowKeyFilter.ranges.toList(1) + assert(scanRange2.isUpperBoundEqualTo) + + assert(results.length == 2) + } + + /** + * A example of a AND merge between to ranges the result is one range + * Also an example of less then and equal to and greater then and equal to + */ + test("Test one combined range rowKey query") { + val results = sqlContext + .sql( + "SELECT KEY_FIELD, B_FIELD, A_FIELD FROM hbaseTable1 " + + "WHERE " + + "(KEY_FIELD <= 'get3' and KEY_FIELD >= 'get2')") + .take(10) + + val executionRules = DefaultSourceStaticUtils.lastFiveExecutionRules.poll() + + assert(results.length == 2) + + val expr = executionRules.dynamicLogicExpression.toExpressionString + assert(expr.equals("( ( KEY_FIELD isNotNull AND KEY_FIELD <= 0 ) AND KEY_FIELD >= 1 )"), expr) + + assert(executionRules.rowKeyFilter.points.size == 0) + assert(executionRules.rowKeyFilter.ranges.size == 1) + + val scanRange1 = executionRules.rowKeyFilter.ranges.toList(0) + assert(Bytes.equals(scanRange1.lowerBound, Bytes.toBytes("get2"))) + assert(Bytes.equals(scanRange1.upperBound, Bytes.toBytes("get3"))) + assert(scanRange1.isLowerBoundEqualTo) + assert(scanRange1.isUpperBoundEqualTo) + + } + + /** + * Do a select with no filters + */ + test("Test select only query") { + + val results = df.select("KEY_FIELD").take(10) + assert(results.length == 5) + + val executionRules = DefaultSourceStaticUtils.lastFiveExecutionRules.poll() + + assert(executionRules.dynamicLogicExpression == null) + + } + + /** + * A complex query with one point and one range for both the + * rowKey and the a column + */ + test("Test SQL point and range combo") { + val results = sqlContext + .sql( + "SELECT KEY_FIELD FROM hbaseTable1 " + + "WHERE " + + "(KEY_FIELD = 'get1' and B_FIELD < '3') or " + + "(KEY_FIELD >= 'get3' and B_FIELD = '8')") + .take(5) + + val executionRules = DefaultSourceStaticUtils.lastFiveExecutionRules.poll() + + assert( + executionRules.dynamicLogicExpression.toExpressionString.equals( + "( ( KEY_FIELD == 0 AND B_FIELD < 1 ) OR " + + "( KEY_FIELD >= 2 AND B_FIELD == 3 ) )")) + + assert(executionRules.rowKeyFilter.points.size == 1) + assert(executionRules.rowKeyFilter.ranges.size == 1) + + val scanRange1 = executionRules.rowKeyFilter.ranges.toList(0) + assert(Bytes.equals(scanRange1.lowerBound, Bytes.toBytes("get3"))) + assert(scanRange1.upperBound == null) + assert(scanRange1.isLowerBoundEqualTo) + assert(scanRange1.isUpperBoundEqualTo) + + 
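+    // A null upperBound marks an open-ended range: the scan runs from 'get3' (inclusive,
+    // from KEY_FIELD >= 'get3') to the end of the table, while the single point comes from
+    // the KEY_FIELD = 'get1' predicate.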
assert(results.length == 3) + } + + /** + * A complex query with two complex ranges that doesn't merge into one + */ + test("Test two complete range non merge rowKey query") { + + val results = sqlContext + .sql( + "SELECT KEY_FIELD, B_FIELD, A_FIELD FROM hbaseTable2 " + + "WHERE " + + "( KEY_FIELD >= 1 and KEY_FIELD <= 2) or" + + "( KEY_FIELD > 3 and KEY_FIELD <= 5)") + .take(10) + + assert(results.length == 4) + val executionRules = DefaultSourceStaticUtils.lastFiveExecutionRules.poll() + assert( + executionRules.dynamicLogicExpression.toExpressionString.equals( + "( ( KEY_FIELD >= 0 AND KEY_FIELD <= 1 ) OR " + + "( KEY_FIELD > 2 AND KEY_FIELD <= 3 ) )")) + + assert(executionRules.rowKeyFilter.points.size == 0) + assert(executionRules.rowKeyFilter.ranges.size == 2) + + val scanRange1 = executionRules.rowKeyFilter.ranges.toList(0) + assert(Bytes.equals(scanRange1.lowerBound, Bytes.toBytes(1))) + assert(Bytes.equals(scanRange1.upperBound, Bytes.toBytes(2))) + assert(scanRange1.isLowerBoundEqualTo) + assert(scanRange1.isUpperBoundEqualTo) + + val scanRange2 = executionRules.rowKeyFilter.ranges.toList(1) + assert(Bytes.equals(scanRange2.lowerBound, Bytes.toBytes(3))) + assert(Bytes.equals(scanRange2.upperBound, Bytes.toBytes(5))) + assert(!scanRange2.isLowerBoundEqualTo) + assert(scanRange2.isUpperBoundEqualTo) + + } + + /** + * A complex query with two complex ranges that does merge into one + */ + test("Test two complete range merge rowKey query") { + val results = sqlContext + .sql( + "SELECT KEY_FIELD, B_FIELD, A_FIELD FROM hbaseTable1 " + + "WHERE " + + "( KEY_FIELD >= 'get1' and KEY_FIELD <= 'get2') or" + + "( KEY_FIELD > 'get3' and KEY_FIELD <= 'get5')") + .take(10) + + val executionRules = DefaultSourceStaticUtils.lastFiveExecutionRules.poll() + + assert(results.length == 4) + + assert( + executionRules.dynamicLogicExpression.toExpressionString.equals( + "( ( KEY_FIELD >= 0 AND KEY_FIELD <= 1 ) OR " + + "( KEY_FIELD > 2 AND KEY_FIELD <= 3 ) )")) + + assert(executionRules.rowKeyFilter.points.size == 0) + assert(executionRules.rowKeyFilter.ranges.size == 2) + + val scanRange1 = executionRules.rowKeyFilter.ranges.toList(0) + assert(Bytes.equals(scanRange1.lowerBound, Bytes.toBytes("get1"))) + assert(Bytes.equals(scanRange1.upperBound, Bytes.toBytes("get2"))) + assert(scanRange1.isLowerBoundEqualTo) + assert(scanRange1.isUpperBoundEqualTo) + + val scanRange2 = executionRules.rowKeyFilter.ranges.toList(1) + assert(Bytes.equals(scanRange2.lowerBound, Bytes.toBytes("get3"))) + assert(Bytes.equals(scanRange2.upperBound, Bytes.toBytes("get5"))) + assert(!scanRange2.isLowerBoundEqualTo) + assert(scanRange2.isUpperBoundEqualTo) + } + + test("Test OR logic with a one RowKey and One column") { + + val results = sqlContext + .sql( + "SELECT KEY_FIELD, B_FIELD, A_FIELD FROM hbaseTable1 " + + "WHERE " + + "( KEY_FIELD >= 'get1' or A_FIELD <= 'foo2') or" + + "( KEY_FIELD > 'get3' or B_FIELD <= '4')") + .take(10) + + val executionRules = DefaultSourceStaticUtils.lastFiveExecutionRules.poll() + + assert(results.length == 5) + + assert( + executionRules.dynamicLogicExpression.toExpressionString.equals( + "( ( KEY_FIELD >= 0 OR A_FIELD <= 1 ) OR " + + "( KEY_FIELD > 2 OR B_FIELD <= 3 ) )")) + + assert(executionRules.rowKeyFilter.points.size == 0) + assert(executionRules.rowKeyFilter.ranges.size == 1) + + val scanRange1 = executionRules.rowKeyFilter.ranges.toList(0) + // This is the main test for 14406 + // Because the key is joined through a or with a qualifier + // There is no filter on the rowKey + 
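+    // A predicate on a non-key column that is OR'd with the row-key predicate can match
+    // rows outside any key range, so the scan has to cover the whole table and the
+    // filtering is left to the dynamic logic expression.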
assert(Bytes.equals(scanRange1.lowerBound, Bytes.toBytes(""))) + assert(scanRange1.upperBound == null) + assert(scanRange1.isLowerBoundEqualTo) + assert(scanRange1.isUpperBoundEqualTo) + } + + test("Test OR logic with a two columns") { + val results = sqlContext + .sql( + "SELECT KEY_FIELD, B_FIELD, A_FIELD FROM hbaseTable1 " + + "WHERE " + + "( B_FIELD > '4' or A_FIELD <= 'foo2') or" + + "( A_FIELD > 'foo2' or B_FIELD < '4')") + .take(10) + + val executionRules = DefaultSourceStaticUtils.lastFiveExecutionRules.poll() + + assert(results.length == 5) + + assert( + executionRules.dynamicLogicExpression.toExpressionString.equals( + "( ( B_FIELD > 0 OR A_FIELD <= 1 ) OR " + + "( A_FIELD > 2 OR B_FIELD < 3 ) )")) + + assert(executionRules.rowKeyFilter.points.size == 0) + assert(executionRules.rowKeyFilter.ranges.size == 1) + + val scanRange1 = executionRules.rowKeyFilter.ranges.toList(0) + assert(Bytes.equals(scanRange1.lowerBound, Bytes.toBytes(""))) + assert(scanRange1.upperBound == null) + assert(scanRange1.isLowerBoundEqualTo) + assert(scanRange1.isUpperBoundEqualTo) + + } + + test("Test single RowKey Or Column logic") { + val results = sqlContext + .sql( + "SELECT KEY_FIELD, B_FIELD, A_FIELD FROM hbaseTable1 " + + "WHERE " + + "( KEY_FIELD >= 'get4' or A_FIELD <= 'foo2' )") + .take(10) + + val executionRules = DefaultSourceStaticUtils.lastFiveExecutionRules.poll() + + assert(results.length == 4) + + assert( + executionRules.dynamicLogicExpression.toExpressionString.equals( + "( KEY_FIELD >= 0 OR A_FIELD <= 1 )")) + + assert(executionRules.rowKeyFilter.points.size == 0) + assert(executionRules.rowKeyFilter.ranges.size == 1) + + val scanRange1 = executionRules.rowKeyFilter.ranges.toList(0) + assert(Bytes.equals(scanRange1.lowerBound, Bytes.toBytes(""))) + assert(scanRange1.upperBound == null) + assert(scanRange1.isLowerBoundEqualTo) + assert(scanRange1.isUpperBoundEqualTo) + } + + test("Test Rowkey And with complex logic (HBASE-26863)") { + val results = sqlContext + .sql( + "SELECT KEY_FIELD, B_FIELD, A_FIELD FROM hbaseTable1 " + + "WHERE " + + "( KEY_FIELD >= 'get1' AND KEY_FIELD <= 'get3' ) AND (A_FIELD = 'foo1' OR B_FIELD = '8')") + .take(10) + val executionRules = DefaultSourceStaticUtils.lastFiveExecutionRules.poll() + assert(results.length == 2) + + assert( + executionRules.dynamicLogicExpression.toExpressionString + == "( ( ( KEY_FIELD isNotNull AND KEY_FIELD >= 0 ) AND KEY_FIELD <= 1 ) AND ( A_FIELD == 2 OR B_FIELD == 3 ) )") + + assert(executionRules.rowKeyFilter.points.size == 0) + assert(executionRules.rowKeyFilter.ranges.size == 1) + + val scanRange1 = executionRules.rowKeyFilter.ranges.toList(0) + assert(Bytes.equals(scanRange1.lowerBound, Bytes.toBytes("get1"))) + assert(Bytes.equals(scanRange1.upperBound, Bytes.toBytes("get3"))) + assert(scanRange1.isLowerBoundEqualTo) + assert(scanRange1.isUpperBoundEqualTo) + } + + test("Test table that doesn't exist") { + val catalog = s"""{ + |"table":{"namespace":"default", "name":"t1NotThere"}, + |"rowkey":"key", + |"columns":{ + |"KEY_FIELD":{"cf":"rowkey", "col":"key", "type":"string"}, + |"A_FIELD":{"cf":"c", "col":"a", "type":"string"}, + |"B_FIELD":{"cf":"c", "col":"c", "type":"string"} + |} + |}""".stripMargin + + intercept[Exception] { + df = sqlContext.load( + "org.apache.hadoop.hbase.spark", + Map(HBaseTableCatalog.tableCatalog -> catalog)) + + df.registerTempTable("hbaseNonExistingTmp") + + sqlContext + .sql( + "SELECT KEY_FIELD, B_FIELD, A_FIELD FROM hbaseNonExistingTmp " + + "WHERE " + + "( KEY_FIELD >= 'get1' and KEY_FIELD 
<= 'get3') or" + + "( KEY_FIELD > 'get3' and KEY_FIELD <= 'get5')") + .count() + } + DefaultSourceStaticUtils.lastFiveExecutionRules.poll() + } + + test("Test table with column that doesn't exist") { + val catalog = s"""{ + |"table":{"namespace":"default", "name":"t1"}, + |"rowkey":"key", + |"columns":{ + |"KEY_FIELD":{"cf":"rowkey", "col":"key", "type":"string"}, + |"A_FIELD":{"cf":"c", "col":"a", "type":"string"}, + |"B_FIELD":{"cf":"c", "col":"b", "type":"string"}, + |"C_FIELD":{"cf":"c", "col":"c", "type":"string"} + |} + |}""".stripMargin + df = sqlContext.load( + "org.apache.hadoop.hbase.spark", + Map(HBaseTableCatalog.tableCatalog -> catalog)) + + df.registerTempTable("hbaseFactColumnTmp") + + val result = sqlContext.sql( + "SELECT KEY_FIELD, " + + "B_FIELD, A_FIELD FROM hbaseFactColumnTmp") + + assert(result.count() == 5) + + val executionRules = DefaultSourceStaticUtils.lastFiveExecutionRules.poll() + assert(executionRules.dynamicLogicExpression == null) + + } + + test("Test table with INT column") { + val catalog = s"""{ + |"table":{"namespace":"default", "name":"t1"}, + |"rowkey":"key", + |"columns":{ + |"KEY_FIELD":{"cf":"rowkey", "col":"key", "type":"string"}, + |"A_FIELD":{"cf":"c", "col":"a", "type":"string"}, + |"B_FIELD":{"cf":"c", "col":"b", "type":"string"}, + |"I_FIELD":{"cf":"c", "col":"i", "type":"int"} + |} + |}""".stripMargin + df = sqlContext.load( + "org.apache.hadoop.hbase.spark", + Map(HBaseTableCatalog.tableCatalog -> catalog)) + + df.registerTempTable("hbaseIntTmp") + + val result = sqlContext.sql( + "SELECT KEY_FIELD, B_FIELD, I_FIELD FROM hbaseIntTmp" + + " where I_FIELD > 4 and I_FIELD < 10") + + val localResult = result.take(5) + + assert(localResult.length == 2) + assert(localResult(0).getInt(2) == 8) + + val executionRules = DefaultSourceStaticUtils.lastFiveExecutionRules.poll() + val expr = executionRules.dynamicLogicExpression.toExpressionString + logInfo(expr) + assert(expr.equals("( ( I_FIELD isNotNull AND I_FIELD > 0 ) AND I_FIELD < 1 )"), expr) + + } + + test("Test table with INT column defined at wrong type") { + val catalog = s"""{ + |"table":{"namespace":"default", "name":"t1"}, + |"rowkey":"key", + |"columns":{ + |"KEY_FIELD":{"cf":"rowkey", "col":"key", "type":"string"}, + |"A_FIELD":{"cf":"c", "col":"a", "type":"string"}, + |"B_FIELD":{"cf":"c", "col":"b", "type":"string"}, + |"I_FIELD":{"cf":"c", "col":"i", "type":"string"} + |} + |}""".stripMargin + df = sqlContext.load( + "org.apache.hadoop.hbase.spark", + Map(HBaseTableCatalog.tableCatalog -> catalog)) + + df.registerTempTable("hbaseIntWrongTypeTmp") + + val result = sqlContext.sql( + "SELECT KEY_FIELD, " + + "B_FIELD, I_FIELD FROM hbaseIntWrongTypeTmp") + + val localResult = result.take(10) + assert(localResult.length == 5) + + val executionRules = DefaultSourceStaticUtils.lastFiveExecutionRules.poll() + assert(executionRules.dynamicLogicExpression == null) + + assert(localResult(0).getString(2).length == 4) + assert(localResult(0).getString(2).charAt(0).toByte == 0) + assert(localResult(0).getString(2).charAt(1).toByte == 0) + assert(localResult(0).getString(2).charAt(2).toByte == 0) + assert(localResult(0).getString(2).charAt(3).toByte == 1) + } + + test("Test bad column type") { + val catalog = s"""{ + |"table":{"namespace":"default", "name":"t1"}, + |"rowkey":"key", + |"columns":{ + |"KEY_FIELD":{"cf":"rowkey", "col":"key", "type":"FOOBAR"}, + |"A_FIELD":{"cf":"c", "col":"a", "type":"string"}, + |"I_FIELD":{"cf":"c", "col":"i", "type":"string"} + |} + |}""".stripMargin + 
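+    // "FOOBAR" is not a supported catalog data type, so the intercept below expects an
+    // exception once the relation built from this catalog is loaded and queried.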
intercept[Exception] { + df = sqlContext.load( + "org.apache.hadoop.hbase.spark", + Map(HBaseTableCatalog.tableCatalog -> catalog)) + + df.registerTempTable("hbaseIntWrongTypeTmp") + + val result = sqlContext.sql( + "SELECT KEY_FIELD, " + + "B_FIELD, I_FIELD FROM hbaseIntWrongTypeTmp") + + val localResult = result.take(10) + assert(localResult.length == 5) + + val executionRules = DefaultSourceStaticUtils.lastFiveExecutionRules.poll() + assert(executionRules.dynamicLogicExpression == null) + + } + } + + test("Test HBaseSparkConf matching") { + val df = sqlContext.load( + "org.apache.hadoop.hbase.spark.HBaseTestSource", + Map( + "cacheSize" -> "100", + "batchNum" -> "100", + "blockCacheingEnable" -> "true", + "rowNum" -> "10")) + assert(df.count() == 10) + + val df1 = sqlContext.load( + "org.apache.hadoop.hbase.spark.HBaseTestSource", + Map( + "cacheSize" -> "1000", + "batchNum" -> "100", + "blockCacheingEnable" -> "true", + "rowNum" -> "10")) + intercept[Exception] { + assert(df1.count() == 10) + } + + val df2 = sqlContext.load( + "org.apache.hadoop.hbase.spark.HBaseTestSource", + Map( + "cacheSize" -> "100", + "batchNum" -> "1000", + "blockCacheingEnable" -> "true", + "rowNum" -> "10")) + intercept[Exception] { + assert(df2.count() == 10) + } + + val df3 = sqlContext.load( + "org.apache.hadoop.hbase.spark.HBaseTestSource", + Map( + "cacheSize" -> "100", + "batchNum" -> "100", + "blockCacheingEnable" -> "false", + "rowNum" -> "10")) + intercept[Exception] { + assert(df3.count() == 10) + } + } + + test("Test table with sparse column") { + val catalog = s"""{ + |"table":{"namespace":"default", "name":"t1"}, + |"rowkey":"key", + |"columns":{ + |"KEY_FIELD":{"cf":"rowkey", "col":"key", "type":"string"}, + |"A_FIELD":{"cf":"c", "col":"a", "type":"string"}, + |"B_FIELD":{"cf":"c", "col":"b", "type":"string"}, + |"Z_FIELD":{"cf":"c", "col":"z", "type":"string"} + |} + |}""".stripMargin + df = sqlContext.load( + "org.apache.hadoop.hbase.spark", + Map(HBaseTableCatalog.tableCatalog -> catalog)) + + df.registerTempTable("hbaseZTmp") + + val result = sqlContext.sql("SELECT KEY_FIELD, B_FIELD, Z_FIELD FROM hbaseZTmp") + + val localResult = result.take(10) + assert(localResult.length == 5) + + assert(localResult(0).getString(2) == null) + assert(localResult(1).getString(2) == "FOO") + assert(localResult(2).getString(2) == null) + assert(localResult(3).getString(2) == "BAR") + assert(localResult(4).getString(2) == null) + + val executionRules = DefaultSourceStaticUtils.lastFiveExecutionRules.poll() + assert(executionRules.dynamicLogicExpression == null) + } + + test("Test with column logic disabled") { + val catalog = s"""{ + |"table":{"namespace":"default", "name":"t1"}, + |"rowkey":"key", + |"columns":{ + |"KEY_FIELD":{"cf":"rowkey", "col":"key", "type":"string"}, + |"A_FIELD":{"cf":"c", "col":"a", "type":"string"}, + |"B_FIELD":{"cf":"c", "col":"b", "type":"string"}, + |"Z_FIELD":{"cf":"c", "col":"z", "type":"string"} + |} + |}""".stripMargin + df = sqlContext.load( + "org.apache.hadoop.hbase.spark", + Map( + HBaseTableCatalog.tableCatalog -> catalog, + HBaseSparkConf.PUSHDOWN_COLUMN_FILTER -> "false")) + + df.registerTempTable("hbaseNoPushDownTmp") + + val results = sqlContext + .sql( + "SELECT KEY_FIELD, B_FIELD, A_FIELD FROM hbaseNoPushDownTmp " + + "WHERE " + + "(KEY_FIELD <= 'get3' and KEY_FIELD >= 'get2')") + .take(10) + + val executionRules = DefaultSourceStaticUtils.lastFiveExecutionRules.poll() + + assert(results.length == 2) + + assert(executionRules.dynamicLogicExpression == null) + } + + 
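+  // Illustrative sketch (not part of the tests above; the helper name is made up): the same
+  // pushdown toggle that the previous test passes through the load options can also be set
+  // with the DataFrameReader API, using the option keys from HBaseTableCatalog and
+  // HBaseSparkConf already used in this suite.
+  def loadWithoutColumnPushdown(catalog: String): DataFrame =
+    sqlContext.read
+      .options(
+        Map(
+          HBaseTableCatalog.tableCatalog -> catalog,
+          HBaseSparkConf.PUSHDOWN_COLUMN_FILTER -> "false"))
+      .format("org.apache.hadoop.hbase.spark")
+      .load()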
test("Test mapping") { + val catalog = s"""{ + |"table":{"namespace":"default", "name":"t3"}, + |"rowkey":"key", + |"columns":{ + |"KEY_FIELD":{"cf":"rowkey", "col":"key", "type":"string"}, + |"BINARY_FIELD":{"cf":"c", "col":"binary", "type":"binary"}, + |"BOOLEAN_FIELD":{"cf":"c", "col":"boolean", "type":"boolean"}, + |"BYTE_FIELD":{"cf":"c", "col":"byte", "type":"byte"}, + |"SHORT_FIELD":{"cf":"c", "col":"short", "type":"short"}, + |"INT_FIELD":{"cf":"c", "col":"int", "type":"int"}, + |"LONG_FIELD":{"cf":"c", "col":"long", "type":"long"}, + |"FLOAT_FIELD":{"cf":"c", "col":"float", "type":"float"}, + |"DOUBLE_FIELD":{"cf":"c", "col":"double", "type":"double"}, + |"DATE_FIELD":{"cf":"c", "col":"date", "type":"date"}, + |"TIMESTAMP_FIELD":{"cf":"c", "col":"timestamp", "type":"timestamp"}, + |"STRING_FIELD":{"cf":"c", "col":"string", "type":"string"} + |} + |}""".stripMargin + df = sqlContext.load( + "org.apache.hadoop.hbase.spark", + Map(HBaseTableCatalog.tableCatalog -> catalog)) + + df.registerTempTable("hbaseTestMapping") + + val results = sqlContext + .sql( + "SELECT binary_field, boolean_field, " + + "byte_field, short_field, int_field, long_field, " + + "float_field, double_field, date_field, timestamp_field, " + + "string_field FROM hbaseTestMapping") + .collect() + + assert(results.length == 1) + + val result = results(0) + + System.out.println("row: " + result) + System.out.println("0: " + result.get(0)) + System.out.println("1: " + result.get(1)) + System.out.println("2: " + result.get(2)) + System.out.println("3: " + result.get(3)) + + assert( + result.get(0).asInstanceOf[Array[Byte]].sameElements(Array(1.toByte, 2.toByte, 3.toByte))) + assert(result.get(1) == true) + assert(result.get(2) == 127) + assert(result.get(3) == 32767) + assert(result.get(4) == 1000000) + assert(result.get(5) == 10000000000L) + assert(result.get(6) == 0.5) + assert(result.get(7) == 0.125) + // sql date stores only year, month and day, so checking it is within a day + assert(Math.abs(result.get(8).asInstanceOf[Date].getTime - timestamp) <= 86400000) + assert(result.get(9).asInstanceOf[Timestamp].getTime == timestamp) + assert(result.get(10) == "string") + } + + def writeCatalog = s"""{ + |"table":{"namespace":"default", "name":"table1"}, + |"rowkey":"key", + |"columns":{ + |"col0":{"cf":"rowkey", "col":"key", "type":"string"}, + |"col1":{"cf":"cf1", "col":"col1", "type":"boolean"}, + |"col2":{"cf":"cf1", "col":"col2", "type":"double"}, + |"col3":{"cf":"cf3", "col":"col3", "type":"float"}, + |"col4":{"cf":"cf3", "col":"col4", "type":"int"}, + |"col5":{"cf":"cf5", "col":"col5", "type":"bigint"}, + |"col6":{"cf":"cf6", "col":"col6", "type":"smallint"}, + |"col7":{"cf":"cf7", "col":"col7", "type":"string"}, + |"col8":{"cf":"cf8", "col":"col8", "type":"tinyint"} + |} + |}""".stripMargin + + def withCatalog(cat: String): DataFrame = { + sqlContext.read + .options(Map(HBaseTableCatalog.tableCatalog -> cat)) + .format("org.apache.hadoop.hbase.spark") + .load() + } + + test("populate table") { + val sql = sqlContext + import sql.implicits._ + val data = (0 to 255).map { i => HBaseRecord(i, "extra") } + sc.parallelize(data) + .toDF + .write + .options( + Map(HBaseTableCatalog.tableCatalog -> writeCatalog, HBaseTableCatalog.newTable -> "5")) + .format("org.apache.hadoop.hbase.spark") + .save() + } + + test("empty column") { + val df = withCatalog(writeCatalog) + df.registerTempTable("table0") + val c = sqlContext.sql("select count(1) from table0").rdd.collect()(0)(0).asInstanceOf[Long] + assert(c == 256) + } + + 
test("full query") { + val df = withCatalog(writeCatalog) + df.show() + assert(df.count() == 256) + } + + test("filtered query0") { + val sql = sqlContext + import sql.implicits._ + val df = withCatalog(writeCatalog) + val s = df + .filter($"col0" <= "row005") + .select("col0", "col1") + s.show() + assert(s.count() == 6) + } + + test("filtered query01") { + val sql = sqlContext + import sql.implicits._ + val df = withCatalog(writeCatalog) + val s = df + .filter(col("col0").startsWith("row00")) + .select("col0", "col1") + s.show() + assert(s.count() == 10) + } + + test("startsWith filtered query 1") { + val sql = sqlContext + import sql.implicits._ + val df = withCatalog(writeCatalog) + val s = df + .filter(col("col0").startsWith("row005")) + .select("col0", "col1") + s.show() + assert(s.count() == 1) + } + + test("startsWith filtered query 2") { + val sql = sqlContext + import sql.implicits._ + val df = withCatalog(writeCatalog) + val s = df + .filter(col("col0").startsWith("row")) + .select("col0", "col1") + s.show() + assert(s.count() == 256) + } + + test("startsWith filtered query 3") { + val sql = sqlContext + import sql.implicits._ + val df = withCatalog(writeCatalog) + val s = df + .filter(col("col0").startsWith("row19")) + .select("col0", "col1") + s.show() + assert(s.count() == 10) + } + + test("startsWith filtered query 4") { + val sql = sqlContext + import sql.implicits._ + val df = withCatalog(writeCatalog) + val s = df + .filter(col("col0").startsWith("")) + .select("col0", "col1") + s.show() + assert(s.count() == 256) + } + + test("Timestamp semantics") { + val sql = sqlContext + import sql.implicits._ + + // There's already some data in here from recently. Let's throw something in + // from 1993 which we can include/exclude and add some data with the implicit (now) timestamp. + // Then we should be able to cross-section it and only get points in between, get the most recent view + // and get an old view. 
+ val oldMs = 754869600000L + val startMs = System.currentTimeMillis() + val oldData = (0 to 100).map { i => HBaseRecord(i, "old") } + val newData = (200 to 255).map { i => HBaseRecord(i, "new") } + + sc.parallelize(oldData) + .toDF + .write + .options( + Map( + HBaseTableCatalog.tableCatalog -> writeCatalog, + HBaseTableCatalog.tableName -> "5", + HBaseSparkConf.TIMESTAMP -> oldMs.toString)) + .format("org.apache.hadoop.hbase.spark") + .save() + sc.parallelize(newData) + .toDF + .write + .options( + Map(HBaseTableCatalog.tableCatalog -> writeCatalog, HBaseTableCatalog.tableName -> "5")) + .format("org.apache.hadoop.hbase.spark") + .save() + + // Test specific timestamp -- Full scan, Timestamp + val individualTimestamp = sqlContext.read + .options( + Map( + HBaseTableCatalog.tableCatalog -> writeCatalog, + HBaseSparkConf.TIMESTAMP -> oldMs.toString)) + .format("org.apache.hadoop.hbase.spark") + .load() + assert(individualTimestamp.count() == 101) + + // Test getting everything -- Full Scan, No range + val everything = sqlContext.read + .options(Map(HBaseTableCatalog.tableCatalog -> writeCatalog)) + .format("org.apache.hadoop.hbase.spark") + .load() + assert(everything.count() == 256) + // Test getting everything -- Pruned Scan, TimeRange + val element50 = everything.where(col("col0") === lit("row050")).select("col7").collect()(0)(0) + assert(element50 == "String50: extra") + val element200 = everything.where(col("col0") === lit("row200")).select("col7").collect()(0)(0) + assert(element200 == "String200: new") + + // Test Getting old stuff -- Full Scan, TimeRange + val oldRange = sqlContext.read + .options( + Map( + HBaseTableCatalog.tableCatalog -> writeCatalog, + HBaseSparkConf.TIMERANGE_START -> "0", + HBaseSparkConf.TIMERANGE_END -> (oldMs + 100).toString)) + .format("org.apache.hadoop.hbase.spark") + .load() + assert(oldRange.count() == 101) + // Test Getting old stuff -- Pruned Scan, TimeRange + val oldElement50 = oldRange.where(col("col0") === lit("row050")).select("col7").collect()(0)(0) + assert(oldElement50 == "String50: old") + + // Test Getting middle stuff -- Full Scan, TimeRange + val middleRange = sqlContext.read + .options( + Map( + HBaseTableCatalog.tableCatalog -> writeCatalog, + HBaseSparkConf.TIMERANGE_START -> "0", + HBaseSparkConf.TIMERANGE_END -> (startMs + 100).toString)) + .format("org.apache.hadoop.hbase.spark") + .load() + assert(middleRange.count() == 256) + // Test Getting middle stuff -- Pruned Scan, TimeRange + val middleElement200 = + middleRange.where(col("col0") === lit("row200")).select("col7").collect()(0)(0) + assert(middleElement200 == "String200: extra") + } + + // catalog for insertion + def avroWriteCatalog = s"""{ + |"table":{"namespace":"default", "name":"avrotable"}, + |"rowkey":"key", + |"columns":{ + |"col0":{"cf":"rowkey", "col":"key", "type":"binary"}, + |"col1":{"cf":"cf1", "col":"col1", "type":"binary"} + |} + |}""".stripMargin + + // catalog for read + def avroCatalog = s"""{ + |"table":{"namespace":"default", "name":"avrotable"}, + |"rowkey":"key", + |"columns":{ + |"col0":{"cf":"rowkey", "col":"key", "avro":"avroSchema"}, + |"col1":{"cf":"cf1", "col":"col1", "avro":"avroSchema"} + |} + |}""".stripMargin + + // for insert to another table + def avroCatalogInsert = s"""{ + |"table":{"namespace":"default", "name":"avrotableInsert"}, + |"rowkey":"key", + |"columns":{ + |"col0":{"cf":"rowkey", "col":"key", "avro":"avroSchema"}, + |"col1":{"cf":"cf1", "col":"col1", "avro":"avroSchema"} + |} + |}""".stripMargin + + def withAvroCatalog(cat: 
String): DataFrame = { + sqlContext.read + .options( + Map( + "avroSchema" -> AvroHBaseKeyRecord.schemaString, + HBaseTableCatalog.tableCatalog -> avroCatalog)) + .format("org.apache.hadoop.hbase.spark") + .load() + } + + test("populate avro table") { + val sql = sqlContext + import sql.implicits._ + + val data = (0 to 255).map { i => AvroHBaseKeyRecord(i) } + sc.parallelize(data) + .toDF + .write + .options( + Map(HBaseTableCatalog.tableCatalog -> avroWriteCatalog, HBaseTableCatalog.newTable -> "5")) + .format("org.apache.hadoop.hbase.spark") + .save() + } + + test("avro empty column") { + val df = withAvroCatalog(avroCatalog) + df.registerTempTable("avrotable") + val c = sqlContext.sql("select count(1) from avrotable").rdd.collect()(0)(0).asInstanceOf[Long] + assert(c == 256) + } + + test("avro full query") { + val df = withAvroCatalog(avroCatalog) + df.show() + df.printSchema() + assert(df.count() == 256) + } + + test("avro serialization and deserialization query") { + val df = withAvroCatalog(avroCatalog) + df.write + .options( + Map( + "avroSchema" -> AvroHBaseKeyRecord.schemaString, + HBaseTableCatalog.tableCatalog -> avroCatalogInsert, + HBaseTableCatalog.newTable -> "5")) + .format("org.apache.hadoop.hbase.spark") + .save() + val newDF = withAvroCatalog(avroCatalogInsert) + newDF.show() + newDF.printSchema() + assert(newDF.count() == 256) + } + + test("avro filtered query") { + val sql = sqlContext + import sql.implicits._ + val df = withAvroCatalog(avroCatalog) + val r = df + .filter($"col1.name" === "name005" || $"col1.name" <= "name005") + .select("col0", "col1.favorite_color", "col1.favorite_number") + r.show() + assert(r.count() == 6) + } + + test("avro Or filter") { + val sql = sqlContext + import sql.implicits._ + val df = withAvroCatalog(avroCatalog) + val s = df + .filter($"col1.name" <= "name005" || $"col1.name".contains("name007")) + .select("col0", "col1.favorite_color", "col1.favorite_number") + s.show() + assert(s.count() == 7) + } + + test("test create HBaseRelation with new context throws SAXParseException") { + val catalog = s"""{ + |"table":{"namespace":"default", "name":"t1NotThere"}, + |"rowkey":"key", + |"columns":{ + |"KEY_FIELD":{"cf":"rowkey", "col":"key", "type":"string"}, + |"A_FIELD":{"cf":"c", "col":"a", "type":"string"}, + |"B_FIELD":{"cf":"c", "col":"c", "type":"string"} + |} + |}""".stripMargin + try { + HBaseRelation( + Map(HBaseTableCatalog.tableCatalog -> catalog, HBaseSparkConf.USE_HBASECONTEXT -> "false"), + None)(sqlContext) + } catch { + case e: Throwable => + if (e.getCause.isInstanceOf[SAXParseException]) { + fail("SAXParseException due to configuration loading empty resource") + } else { + println("Failed due to some other exception, ignore " + e.getMessage) + } + } + } +} diff --git a/spark4/hbase-spark4/src/test/scala/org/apache/hadoop/hbase/spark/DynamicLogicExpressionSuite.scala b/spark4/hbase-spark4/src/test/scala/org/apache/hadoop/hbase/spark/DynamicLogicExpressionSuite.scala new file mode 100644 index 00000000..7e913cec --- /dev/null +++ b/spark4/hbase-spark4/src/test/scala/org/apache/hadoop/hbase/spark/DynamicLogicExpressionSuite.scala @@ -0,0 +1,402 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.hadoop.hbase.spark + +import java.util +import org.apache.hadoop.hbase.spark.datasources.{HBaseSparkConf, JavaBytesEncoder} +import org.apache.hadoop.hbase.util.Bytes +import org.apache.spark.sql.types._ +import org.scalatest.funsuite.AnyFunSuite +import org.scalatest.{BeforeAndAfterAll, BeforeAndAfterEach} + +class DynamicLogicExpressionSuite + extends AnyFunSuite + with BeforeAndAfterEach + with BeforeAndAfterAll + with Logging { + + val encoder = JavaBytesEncoder.create(HBaseSparkConf.DEFAULT_QUERY_ENCODER) + + test("Basic And Test") { + val leftLogic = new LessThanLogicExpression("Col1", 0) + leftLogic.setEncoder(encoder) + val rightLogic = new GreaterThanLogicExpression("Col1", 1) + rightLogic.setEncoder(encoder) + val andLogic = new AndLogicExpression(leftLogic, rightLogic) + + val columnToCurrentRowValueMap = new util.HashMap[String, ByteArrayComparable]() + + columnToCurrentRowValueMap.put("Col1", new ByteArrayComparable(Bytes.toBytes(10))) + val valueFromQueryValueArray = new Array[Array[Byte]](2) + valueFromQueryValueArray(0) = encoder.encode(IntegerType, 15) + valueFromQueryValueArray(1) = encoder.encode(IntegerType, 5) + assert(andLogic.execute(columnToCurrentRowValueMap, valueFromQueryValueArray)) + + valueFromQueryValueArray(0) = encoder.encode(IntegerType, 10) + valueFromQueryValueArray(1) = encoder.encode(IntegerType, 5) + assert(!andLogic.execute(columnToCurrentRowValueMap, valueFromQueryValueArray)) + + valueFromQueryValueArray(0) = encoder.encode(IntegerType, 15) + valueFromQueryValueArray(1) = encoder.encode(IntegerType, 10) + assert(!andLogic.execute(columnToCurrentRowValueMap, valueFromQueryValueArray)) + + val expressionString = andLogic.toExpressionString + + assert(expressionString.equals("( Col1 < 0 AND Col1 > 1 )")) + + val builtExpression = DynamicLogicExpressionBuilder.build(expressionString, encoder) + valueFromQueryValueArray(0) = encoder.encode(IntegerType, 15) + valueFromQueryValueArray(1) = encoder.encode(IntegerType, 5) + assert(builtExpression.execute(columnToCurrentRowValueMap, valueFromQueryValueArray)) + + valueFromQueryValueArray(0) = encoder.encode(IntegerType, 10) + valueFromQueryValueArray(1) = encoder.encode(IntegerType, 5) + assert(!builtExpression.execute(columnToCurrentRowValueMap, valueFromQueryValueArray)) + + valueFromQueryValueArray(0) = encoder.encode(IntegerType, 15) + valueFromQueryValueArray(1) = encoder.encode(IntegerType, 10) + assert(!builtExpression.execute(columnToCurrentRowValueMap, valueFromQueryValueArray)) + + } + + test("Basic OR Test") { + val leftLogic = new LessThanLogicExpression("Col1", 0) + leftLogic.setEncoder(encoder) + val rightLogic = new GreaterThanLogicExpression("Col1", 1) + rightLogic.setEncoder(encoder) + val OrLogic = new OrLogicExpression(leftLogic, rightLogic) + + val columnToCurrentRowValueMap = new util.HashMap[String, ByteArrayComparable]() + + columnToCurrentRowValueMap.put("Col1", new 
ByteArrayComparable(Bytes.toBytes(10))) + val valueFromQueryValueArray = new Array[Array[Byte]](2) + valueFromQueryValueArray(0) = encoder.encode(IntegerType, 15) + valueFromQueryValueArray(1) = encoder.encode(IntegerType, 5) + assert(OrLogic.execute(columnToCurrentRowValueMap, valueFromQueryValueArray)) + + valueFromQueryValueArray(0) = encoder.encode(IntegerType, 10) + valueFromQueryValueArray(1) = encoder.encode(IntegerType, 5) + assert(OrLogic.execute(columnToCurrentRowValueMap, valueFromQueryValueArray)) + + valueFromQueryValueArray(0) = encoder.encode(IntegerType, 15) + valueFromQueryValueArray(1) = encoder.encode(IntegerType, 10) + assert(OrLogic.execute(columnToCurrentRowValueMap, valueFromQueryValueArray)) + + valueFromQueryValueArray(0) = encoder.encode(IntegerType, 10) + valueFromQueryValueArray(1) = encoder.encode(IntegerType, 10) + assert(!OrLogic.execute(columnToCurrentRowValueMap, valueFromQueryValueArray)) + + val expressionString = OrLogic.toExpressionString + + assert(expressionString.equals("( Col1 < 0 OR Col1 > 1 )")) + + val builtExpression = DynamicLogicExpressionBuilder.build(expressionString, encoder) + valueFromQueryValueArray(0) = encoder.encode(IntegerType, 15) + valueFromQueryValueArray(1) = encoder.encode(IntegerType, 5) + assert(builtExpression.execute(columnToCurrentRowValueMap, valueFromQueryValueArray)) + + valueFromQueryValueArray(0) = encoder.encode(IntegerType, 10) + valueFromQueryValueArray(1) = encoder.encode(IntegerType, 5) + assert(builtExpression.execute(columnToCurrentRowValueMap, valueFromQueryValueArray)) + + valueFromQueryValueArray(0) = encoder.encode(IntegerType, 15) + valueFromQueryValueArray(1) = encoder.encode(IntegerType, 10) + assert(builtExpression.execute(columnToCurrentRowValueMap, valueFromQueryValueArray)) + + valueFromQueryValueArray(0) = encoder.encode(IntegerType, 10) + valueFromQueryValueArray(1) = encoder.encode(IntegerType, 10) + assert(!builtExpression.execute(columnToCurrentRowValueMap, valueFromQueryValueArray)) + } + + test("Basic Command Test") { + val greaterLogic = new GreaterThanLogicExpression("Col1", 0) + greaterLogic.setEncoder(encoder) + val greaterAndEqualLogic = new GreaterThanOrEqualLogicExpression("Col1", 0) + greaterAndEqualLogic.setEncoder(encoder) + val lessLogic = new LessThanLogicExpression("Col1", 0) + lessLogic.setEncoder(encoder) + val lessAndEqualLogic = new LessThanOrEqualLogicExpression("Col1", 0) + lessAndEqualLogic.setEncoder(encoder) + val equalLogic = new EqualLogicExpression("Col1", 0, false) + val notEqualLogic = new EqualLogicExpression("Col1", 0, true) + val passThrough = new PassThroughLogicExpression + + val columnToCurrentRowValueMap = new util.HashMap[String, ByteArrayComparable]() + columnToCurrentRowValueMap.put("Col1", new ByteArrayComparable(Bytes.toBytes(10))) + val valueFromQueryValueArray = new Array[Array[Byte]](1) + + // great than + valueFromQueryValueArray(0) = encoder.encode(IntegerType, 10) + assert(!greaterLogic.execute(columnToCurrentRowValueMap, valueFromQueryValueArray)) + + valueFromQueryValueArray(0) = encoder.encode(IntegerType, 20) + assert(!greaterLogic.execute(columnToCurrentRowValueMap, valueFromQueryValueArray)) + + // great than and equal + valueFromQueryValueArray(0) = encoder.encode(IntegerType, 5) + assert(greaterAndEqualLogic.execute(columnToCurrentRowValueMap, valueFromQueryValueArray)) + + valueFromQueryValueArray(0) = encoder.encode(IntegerType, 10) + assert(greaterAndEqualLogic.execute(columnToCurrentRowValueMap, valueFromQueryValueArray)) + + 
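    // Each probe value goes through the JavaBytesEncoder first because the logic expressions
    // compare encoded byte arrays rather than JVM primitives; this is the encoder created from
    // HBaseSparkConf.DEFAULT_QUERY_ENCODER at the top of the suite.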
valueFromQueryValueArray(0) = encoder.encode(IntegerType, 20) + assert(!greaterAndEqualLogic.execute(columnToCurrentRowValueMap, valueFromQueryValueArray)) + + // less than + valueFromQueryValueArray(0) = encoder.encode(IntegerType, 10) + assert(!lessLogic.execute(columnToCurrentRowValueMap, valueFromQueryValueArray)) + + valueFromQueryValueArray(0) = encoder.encode(IntegerType, 5) + assert(!lessLogic.execute(columnToCurrentRowValueMap, valueFromQueryValueArray)) + + // less than and equal + valueFromQueryValueArray(0) = encoder.encode(IntegerType, 20) + assert(lessAndEqualLogic.execute(columnToCurrentRowValueMap, valueFromQueryValueArray)) + + valueFromQueryValueArray(0) = encoder.encode(IntegerType, 20) + assert(lessAndEqualLogic.execute(columnToCurrentRowValueMap, valueFromQueryValueArray)) + + valueFromQueryValueArray(0) = encoder.encode(IntegerType, 10) + assert(lessAndEqualLogic.execute(columnToCurrentRowValueMap, valueFromQueryValueArray)) + + // equal too + valueFromQueryValueArray(0) = Bytes.toBytes(10) + assert(equalLogic.execute(columnToCurrentRowValueMap, valueFromQueryValueArray)) + + valueFromQueryValueArray(0) = Bytes.toBytes(5) + assert(!equalLogic.execute(columnToCurrentRowValueMap, valueFromQueryValueArray)) + + // not equal too + valueFromQueryValueArray(0) = Bytes.toBytes(10) + assert(!notEqualLogic.execute(columnToCurrentRowValueMap, valueFromQueryValueArray)) + + valueFromQueryValueArray(0) = Bytes.toBytes(5) + assert(notEqualLogic.execute(columnToCurrentRowValueMap, valueFromQueryValueArray)) + + // pass through + valueFromQueryValueArray(0) = Bytes.toBytes(10) + assert(passThrough.execute(columnToCurrentRowValueMap, valueFromQueryValueArray)) + + valueFromQueryValueArray(0) = Bytes.toBytes(5) + assert(passThrough.execute(columnToCurrentRowValueMap, valueFromQueryValueArray)) + } + + test("Double Type") { + val leftLogic = new LessThanLogicExpression("Col1", 0) + leftLogic.setEncoder(encoder) + val rightLogic = new GreaterThanLogicExpression("Col1", 1) + rightLogic.setEncoder(encoder) + val andLogic = new AndLogicExpression(leftLogic, rightLogic) + + val columnToCurrentRowValueMap = new util.HashMap[String, ByteArrayComparable]() + + columnToCurrentRowValueMap.put("Col1", new ByteArrayComparable(Bytes.toBytes(-4.0d))) + val valueFromQueryValueArray = new Array[Array[Byte]](2) + valueFromQueryValueArray(0) = encoder.encode(DoubleType, 15.0d) + valueFromQueryValueArray(1) = encoder.encode(DoubleType, -5.0d) + assert(andLogic.execute(columnToCurrentRowValueMap, valueFromQueryValueArray)) + + valueFromQueryValueArray(0) = encoder.encode(DoubleType, 10.0d) + valueFromQueryValueArray(1) = encoder.encode(DoubleType, -1.0d) + assert(!andLogic.execute(columnToCurrentRowValueMap, valueFromQueryValueArray)) + + valueFromQueryValueArray(0) = encoder.encode(DoubleType, -10.0d) + valueFromQueryValueArray(1) = encoder.encode(DoubleType, -20.0d) + assert(!andLogic.execute(columnToCurrentRowValueMap, valueFromQueryValueArray)) + + val expressionString = andLogic.toExpressionString + // Note that here 0 and 1 is index, instead of value. 
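    // In other words, "Col1 < 0" reads as "Col1 is less than the value stored at index 0 of
    // valueFromQueryValueArray": the expression string carries only column names, operators and
    // value indexes, while the concrete values are supplied separately at evaluation time.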
+ assert(expressionString.equals("( Col1 < 0 AND Col1 > 1 )")) + + val builtExpression = DynamicLogicExpressionBuilder.build(expressionString, encoder) + valueFromQueryValueArray(0) = encoder.encode(DoubleType, 15.0d) + valueFromQueryValueArray(1) = encoder.encode(DoubleType, -5.0d) + assert(builtExpression.execute(columnToCurrentRowValueMap, valueFromQueryValueArray)) + + valueFromQueryValueArray(0) = encoder.encode(DoubleType, 10.0d) + valueFromQueryValueArray(1) = encoder.encode(DoubleType, -1.0d) + assert(!builtExpression.execute(columnToCurrentRowValueMap, valueFromQueryValueArray)) + + valueFromQueryValueArray(0) = encoder.encode(DoubleType, -10.0d) + valueFromQueryValueArray(1) = encoder.encode(DoubleType, -20.0d) + assert(!builtExpression.execute(columnToCurrentRowValueMap, valueFromQueryValueArray)) + } + + test("Float Type") { + val leftLogic = new LessThanLogicExpression("Col1", 0) + leftLogic.setEncoder(encoder) + val rightLogic = new GreaterThanLogicExpression("Col1", 1) + rightLogic.setEncoder(encoder) + val andLogic = new AndLogicExpression(leftLogic, rightLogic) + + val columnToCurrentRowValueMap = new util.HashMap[String, ByteArrayComparable]() + + columnToCurrentRowValueMap.put("Col1", new ByteArrayComparable(Bytes.toBytes(-4.0f))) + val valueFromQueryValueArray = new Array[Array[Byte]](2) + valueFromQueryValueArray(0) = encoder.encode(FloatType, 15.0f) + valueFromQueryValueArray(1) = encoder.encode(FloatType, -5.0f) + assert(andLogic.execute(columnToCurrentRowValueMap, valueFromQueryValueArray)) + + valueFromQueryValueArray(0) = encoder.encode(FloatType, 10.0f) + valueFromQueryValueArray(1) = encoder.encode(FloatType, -1.0f) + assert(!andLogic.execute(columnToCurrentRowValueMap, valueFromQueryValueArray)) + + valueFromQueryValueArray(0) = encoder.encode(FloatType, -10.0f) + valueFromQueryValueArray(1) = encoder.encode(FloatType, -20.0f) + assert(!andLogic.execute(columnToCurrentRowValueMap, valueFromQueryValueArray)) + + val expressionString = andLogic.toExpressionString + // Note that here 0 and 1 is index, instead of value. 
+ assert(expressionString.equals("( Col1 < 0 AND Col1 > 1 )")) + + val builtExpression = DynamicLogicExpressionBuilder.build(expressionString, encoder) + valueFromQueryValueArray(0) = encoder.encode(FloatType, 15.0f) + valueFromQueryValueArray(1) = encoder.encode(FloatType, -5.0f) + assert(builtExpression.execute(columnToCurrentRowValueMap, valueFromQueryValueArray)) + + valueFromQueryValueArray(0) = encoder.encode(FloatType, 10.0f) + valueFromQueryValueArray(1) = encoder.encode(FloatType, -1.0f) + assert(!builtExpression.execute(columnToCurrentRowValueMap, valueFromQueryValueArray)) + + valueFromQueryValueArray(0) = encoder.encode(FloatType, -10.0f) + valueFromQueryValueArray(1) = encoder.encode(FloatType, -20.0f) + assert(!builtExpression.execute(columnToCurrentRowValueMap, valueFromQueryValueArray)) + } + + test("Long Type") { + val greaterLogic = new GreaterThanLogicExpression("Col1", 0) + greaterLogic.setEncoder(encoder) + val greaterAndEqualLogic = new GreaterThanOrEqualLogicExpression("Col1", 0) + greaterAndEqualLogic.setEncoder(encoder) + val lessLogic = new LessThanLogicExpression("Col1", 0) + lessLogic.setEncoder(encoder) + val lessAndEqualLogic = new LessThanOrEqualLogicExpression("Col1", 0) + lessAndEqualLogic.setEncoder(encoder) + val equalLogic = new EqualLogicExpression("Col1", 0, false) + val notEqualLogic = new EqualLogicExpression("Col1", 0, true) + + val columnToCurrentRowValueMap = new util.HashMap[String, ByteArrayComparable]() + columnToCurrentRowValueMap.put("Col1", new ByteArrayComparable(Bytes.toBytes(10L))) + val valueFromQueryValueArray = new Array[Array[Byte]](1) + + // great than + valueFromQueryValueArray(0) = encoder.encode(LongType, 10L) + assert(!greaterLogic.execute(columnToCurrentRowValueMap, valueFromQueryValueArray)) + + valueFromQueryValueArray(0) = encoder.encode(LongType, 20L) + assert(!greaterLogic.execute(columnToCurrentRowValueMap, valueFromQueryValueArray)) + + // great than and equal + valueFromQueryValueArray(0) = encoder.encode(LongType, 5L) + assert(greaterAndEqualLogic.execute(columnToCurrentRowValueMap, valueFromQueryValueArray)) + + valueFromQueryValueArray(0) = encoder.encode(LongType, 10L) + assert(greaterAndEqualLogic.execute(columnToCurrentRowValueMap, valueFromQueryValueArray)) + + valueFromQueryValueArray(0) = encoder.encode(LongType, 20L) + assert(!greaterAndEqualLogic.execute(columnToCurrentRowValueMap, valueFromQueryValueArray)) + + // less than + valueFromQueryValueArray(0) = encoder.encode(LongType, 10L) + assert(!lessLogic.execute(columnToCurrentRowValueMap, valueFromQueryValueArray)) + + valueFromQueryValueArray(0) = encoder.encode(LongType, 5L) + assert(!lessLogic.execute(columnToCurrentRowValueMap, valueFromQueryValueArray)) + + // less than and equal + valueFromQueryValueArray(0) = encoder.encode(LongType, 20L) + assert(lessAndEqualLogic.execute(columnToCurrentRowValueMap, valueFromQueryValueArray)) + + valueFromQueryValueArray(0) = encoder.encode(LongType, 20L) + assert(lessAndEqualLogic.execute(columnToCurrentRowValueMap, valueFromQueryValueArray)) + + valueFromQueryValueArray(0) = encoder.encode(LongType, 10L) + assert(lessAndEqualLogic.execute(columnToCurrentRowValueMap, valueFromQueryValueArray)) + + // equal too + valueFromQueryValueArray(0) = Bytes.toBytes(10L) + assert(equalLogic.execute(columnToCurrentRowValueMap, valueFromQueryValueArray)) + + valueFromQueryValueArray(0) = Bytes.toBytes(5L) + assert(!equalLogic.execute(columnToCurrentRowValueMap, valueFromQueryValueArray)) + + // not equal too + 
valueFromQueryValueArray(0) = Bytes.toBytes(10L) + assert(!notEqualLogic.execute(columnToCurrentRowValueMap, valueFromQueryValueArray)) + + valueFromQueryValueArray(0) = Bytes.toBytes(5L) + assert(notEqualLogic.execute(columnToCurrentRowValueMap, valueFromQueryValueArray)) + } + + test("String Type") { + val leftLogic = new LessThanLogicExpression("Col1", 0) + leftLogic.setEncoder(encoder) + val rightLogic = new GreaterThanLogicExpression("Col1", 1) + rightLogic.setEncoder(encoder) + val andLogic = new AndLogicExpression(leftLogic, rightLogic) + + val columnToCurrentRowValueMap = new util.HashMap[String, ByteArrayComparable]() + + columnToCurrentRowValueMap.put("Col1", new ByteArrayComparable(Bytes.toBytes("row005"))) + val valueFromQueryValueArray = new Array[Array[Byte]](2) + valueFromQueryValueArray(0) = encoder.encode(StringType, "row015") + valueFromQueryValueArray(1) = encoder.encode(StringType, "row000") + assert(andLogic.execute(columnToCurrentRowValueMap, valueFromQueryValueArray)) + + valueFromQueryValueArray(0) = encoder.encode(StringType, "row004") + valueFromQueryValueArray(1) = encoder.encode(StringType, "row000") + assert(!andLogic.execute(columnToCurrentRowValueMap, valueFromQueryValueArray)) + + valueFromQueryValueArray(0) = encoder.encode(StringType, "row020") + valueFromQueryValueArray(1) = encoder.encode(StringType, "row010") + assert(!andLogic.execute(columnToCurrentRowValueMap, valueFromQueryValueArray)) + + val expressionString = andLogic.toExpressionString + // Note that here 0 and 1 is index, instead of value. + assert(expressionString.equals("( Col1 < 0 AND Col1 > 1 )")) + + val builtExpression = DynamicLogicExpressionBuilder.build(expressionString, encoder) + valueFromQueryValueArray(0) = encoder.encode(StringType, "row015") + valueFromQueryValueArray(1) = encoder.encode(StringType, "row000") + assert(builtExpression.execute(columnToCurrentRowValueMap, valueFromQueryValueArray)) + + valueFromQueryValueArray(0) = encoder.encode(StringType, "row004") + valueFromQueryValueArray(1) = encoder.encode(StringType, "row000") + assert(!builtExpression.execute(columnToCurrentRowValueMap, valueFromQueryValueArray)) + + valueFromQueryValueArray(0) = encoder.encode(StringType, "row020") + valueFromQueryValueArray(1) = encoder.encode(StringType, "row010") + assert(!builtExpression.execute(columnToCurrentRowValueMap, valueFromQueryValueArray)) + } + + test("Boolean Type") { + val leftLogic = new LessThanLogicExpression("Col1", 0) + leftLogic.setEncoder(encoder) + val rightLogic = new GreaterThanLogicExpression("Col1", 1) + rightLogic.setEncoder(encoder) + + val columnToCurrentRowValueMap = new util.HashMap[String, ByteArrayComparable]() + + columnToCurrentRowValueMap.put("Col1", new ByteArrayComparable(Bytes.toBytes(false))) + val valueFromQueryValueArray = new Array[Array[Byte]](2) + valueFromQueryValueArray(0) = encoder.encode(BooleanType, true) + valueFromQueryValueArray(1) = encoder.encode(BooleanType, false) + assert(leftLogic.execute(columnToCurrentRowValueMap, valueFromQueryValueArray)) + assert(!rightLogic.execute(columnToCurrentRowValueMap, valueFromQueryValueArray)) + } +} diff --git a/spark4/hbase-spark4/src/test/scala/org/apache/hadoop/hbase/spark/HBaseCatalogSuite.scala b/spark4/hbase-spark4/src/test/scala/org/apache/hadoop/hbase/spark/HBaseCatalogSuite.scala new file mode 100644 index 00000000..f0a4d238 --- /dev/null +++ b/spark4/hbase-spark4/src/test/scala/org/apache/hadoop/hbase/spark/HBaseCatalogSuite.scala @@ -0,0 +1,116 @@ +/* + * Licensed to the Apache 
Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.hadoop.hbase.spark + +import org.apache.hadoop.hbase.spark.datasources.{DataTypeParserWrapper, DoubleSerDes, HBaseTableCatalog} +import org.apache.hadoop.hbase.util.Bytes +import org.apache.spark.sql.types._ +import org.scalatest.funsuite.AnyFunSuite +import org.scalatest.{BeforeAndAfterAll, BeforeAndAfterEach} + +class HBaseCatalogSuite + extends AnyFunSuite + with BeforeAndAfterEach + with BeforeAndAfterAll + with Logging { + + val map = s"""MAP<int, struct<varchar:string>>""" + val array = s"""array<struct<tinYint:tinyint>>""" + val arrayMap = s"""MAp<int, ARRAY<double>>""" + val catalog = s"""{ + |"table":{"namespace":"default", "name":"htable"}, + |"rowkey":"key1:key2", + |"columns":{ + |"col1":{"cf":"rowkey", "col":"key1", "type":"string"}, + |"col2":{"cf":"rowkey", "col":"key2", "type":"double"}, + |"col3":{"cf":"cf1", "col":"col2", "type":"binary"}, + |"col4":{"cf":"cf1", "col":"col3", "type":"timestamp"}, + |"col5":{"cf":"cf1", "col":"col4", "type":"double", "serdes":"${classOf[ + DoubleSerDes].getName}"}, + |"col6":{"cf":"cf1", "col":"col5", "type":"$map"}, + |"col7":{"cf":"cf1", "col":"col6", "type":"$array"}, + |"col8":{"cf":"cf1", "col":"col7", "type":"$arrayMap"}, + |"col9":{"cf":"cf1", "col":"col8", "type":"date"}, + |"col10":{"cf":"cf1", "col":"col9", "type":"timestamp"} + |} + |}""".stripMargin + val parameters = Map(HBaseTableCatalog.tableCatalog -> catalog) + val t = HBaseTableCatalog(parameters) + + def checkDataType(dataTypeString: String, expectedDataType: DataType): Unit = { + test(s"parse ${dataTypeString.replace("\n", "")}") { + assert(DataTypeParserWrapper.parse(dataTypeString) == expectedDataType) + } + } + test("basic") { + assert(t.getField("col1").isRowKey == true) + assert(t.getPrimaryKey == "key1") + assert(t.getField("col3").dt == BinaryType) + assert(t.getField("col4").dt == TimestampType) + assert(t.getField("col5").dt == DoubleType) + assert(t.getField("col5").serdes != None) + assert(t.getField("col4").serdes == None) + assert(t.getField("col1").isRowKey) + assert(t.getField("col2").isRowKey) + assert(!t.getField("col3").isRowKey) + assert(t.getField("col2").length == Bytes.SIZEOF_DOUBLE) + assert(t.getField("col1").length == -1) + assert(t.getField("col8").length == -1) + assert(t.getField("col9").dt == DateType) + assert(t.getField("col10").dt == TimestampType) + } + + checkDataType(map, t.getField("col6").dt) + + checkDataType(array, t.getField("col7").dt) + + checkDataType(arrayMap, t.getField("col8").dt) + + test("convert") { + val m = Map( + "hbase.columns.mapping" -> + "KEY_FIELD STRING :key, A_FIELD STRING c:a, B_FIELD DOUBLE c:b, C_FIELD BINARY c:c,", + "hbase.table" -> "NAMESPACE:TABLE") + val map = HBaseTableCatalog.convert(m) + val json = map.get(HBaseTableCatalog.tableCatalog).get + val parameters =
Map(HBaseTableCatalog.tableCatalog -> json) + val t = HBaseTableCatalog(parameters) + assert(t.namespace == "NAMESPACE") + assert(t.name == "TABLE") + assert(t.getField("KEY_FIELD").isRowKey) + assert(DataTypeParserWrapper.parse("STRING") == t.getField("A_FIELD").dt) + assert(!t.getField("A_FIELD").isRowKey) + assert(DataTypeParserWrapper.parse("DOUBLE") == t.getField("B_FIELD").dt) + assert(DataTypeParserWrapper.parse("BINARY") == t.getField("C_FIELD").dt) + } + + test("compatibility") { + val m = Map( + "hbase.columns.mapping" -> + "KEY_FIELD STRING :key, A_FIELD STRING c:a, B_FIELD DOUBLE c:b, C_FIELD BINARY c:c,", + "hbase.table" -> "t1") + val t = HBaseTableCatalog(m) + assert(t.namespace == "default") + assert(t.name == "t1") + assert(t.getField("KEY_FIELD").isRowKey) + assert(DataTypeParserWrapper.parse("STRING") == t.getField("A_FIELD").dt) + assert(!t.getField("A_FIELD").isRowKey) + assert(DataTypeParserWrapper.parse("DOUBLE") == t.getField("B_FIELD").dt) + assert(DataTypeParserWrapper.parse("BINARY") == t.getField("C_FIELD").dt) + } +} diff --git a/spark4/hbase-spark4/src/test/scala/org/apache/hadoop/hbase/spark/HBaseConnectionCacheSuite.scala b/spark4/hbase-spark4/src/test/scala/org/apache/hadoop/hbase/spark/HBaseConnectionCacheSuite.scala new file mode 100644 index 00000000..309880f7 --- /dev/null +++ b/spark4/hbase-spark4/src/test/scala/org/apache/hadoop/hbase/spark/HBaseConnectionCacheSuite.scala @@ -0,0 +1,245 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.hadoop.hbase.spark + +import java.util.concurrent.ExecutorService +import org.apache.hadoop.conf.Configuration +import org.apache.hadoop.hbase.TableName +import org.apache.hadoop.hbase.client.{Admin, BufferedMutator, BufferedMutatorParams, Connection, RegionLocator, Table, TableBuilder} + +import org.scalatest.funsuite.AnyFunSuite + +import scala.util.Random + +case class HBaseConnectionKeyMocker(confId: Int) extends HBaseConnectionKey(null) { + override def hashCode: Int = { + confId + } + + override def equals(obj: Any): Boolean = { + if (!obj.isInstanceOf[HBaseConnectionKeyMocker]) + false + else + confId == obj.asInstanceOf[HBaseConnectionKeyMocker].confId + } +} + +class ConnectionMocker extends Connection { + var isClosed: Boolean = false + + def getRegionLocator(tableName: TableName): RegionLocator = null + def getConfiguration: Configuration = null + override def getTable(tableName: TableName): Table = null + override def getTable(tableName: TableName, pool: ExecutorService): Table = null + def getBufferedMutator(params: BufferedMutatorParams): BufferedMutator = null + def getBufferedMutator(tableName: TableName): BufferedMutator = null + def getAdmin: Admin = null + def getTableBuilder(tableName: TableName, pool: ExecutorService): TableBuilder = null + + def close(): Unit = { + if (isClosed) + throw new IllegalStateException() + isClosed = true + } + + def isAborted: Boolean = true + def abort(why: String, e: Throwable) = {} + + /* Without override, we can also compile it against HBase 2.1. */ + /* override */ + def clearRegionLocationCache(): Unit = {} +} + +class HBaseConnectionCacheSuite extends AnyFunSuite with Logging { + /* + * These tests must be performed sequentially as they operate with an + * unique running thread and resource. + * + * It looks there's no way to tell FunSuite to do so, so making those + * test cases normal functions which are called sequentially in a single + * test case. 
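   *
   * For the same reason each helper starts with cleanEnv(), which clears the shared
   * connectionMap and resets the cacheStat counters before its own assertions run.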
+ */ + test("all test cases") { + testBasic() + testWithPressureWithoutClose() + testWithPressureWithClose() + } + + def cleanEnv() { + HBaseConnectionCache.connectionMap.synchronized { + HBaseConnectionCache.connectionMap.clear() + HBaseConnectionCache.cacheStat.numActiveConnections = 0 + HBaseConnectionCache.cacheStat.numActualConnectionsCreated = 0 + HBaseConnectionCache.cacheStat.numTotalRequests = 0 + } + } + + def testBasic() { + cleanEnv() + HBaseConnectionCache.setTimeout(1 * 1000) + + val connKeyMocker1 = new HBaseConnectionKeyMocker(1) + val connKeyMocker1a = new HBaseConnectionKeyMocker(1) + val connKeyMocker2 = new HBaseConnectionKeyMocker(2) + + val c1 = HBaseConnectionCache + .getConnection(connKeyMocker1, new ConnectionMocker) + + assert(HBaseConnectionCache.connectionMap.size == 1) + assert(HBaseConnectionCache.getStat.numTotalRequests == 1) + assert(HBaseConnectionCache.getStat.numActualConnectionsCreated == 1) + assert(HBaseConnectionCache.getStat.numActiveConnections == 1) + + val c1a = HBaseConnectionCache + .getConnection(connKeyMocker1a, new ConnectionMocker) + + HBaseConnectionCache.connectionMap.synchronized { + assert(HBaseConnectionCache.connectionMap.size == 1) + assert(HBaseConnectionCache.getStat.numTotalRequests == 2) + assert(HBaseConnectionCache.getStat.numActualConnectionsCreated == 1) + assert(HBaseConnectionCache.getStat.numActiveConnections == 1) + } + + val c2 = HBaseConnectionCache + .getConnection(connKeyMocker2, new ConnectionMocker) + + HBaseConnectionCache.connectionMap.synchronized { + assert(HBaseConnectionCache.connectionMap.size == 2) + assert(HBaseConnectionCache.getStat.numTotalRequests == 3) + assert(HBaseConnectionCache.getStat.numActualConnectionsCreated == 2) + assert(HBaseConnectionCache.getStat.numActiveConnections == 2) + } + + c1.close() + HBaseConnectionCache.connectionMap.synchronized { + assert(HBaseConnectionCache.connectionMap.size == 2) + assert(HBaseConnectionCache.getStat.numActiveConnections == 2) + } + + c1a.close() + HBaseConnectionCache.connectionMap.synchronized { + assert(HBaseConnectionCache.connectionMap.size == 2) + assert(HBaseConnectionCache.getStat.numActiveConnections == 2) + } + + Thread.sleep(3 * 1000) // Leave housekeeping thread enough time + HBaseConnectionCache.connectionMap.synchronized { + assert(HBaseConnectionCache.connectionMap.size == 1) + assert( + HBaseConnectionCache.connectionMap.iterator + .next() + ._1 + .asInstanceOf[HBaseConnectionKeyMocker] + .confId == 2) + assert(HBaseConnectionCache.getStat.numActiveConnections == 1) + } + + c2.close() + } + + def testWithPressureWithoutClose() { + cleanEnv() + + class TestThread extends Runnable { + override def run() { + for (i <- 0 to 999) { + val c = HBaseConnectionCache.getConnection( + new HBaseConnectionKeyMocker(Random.nextInt(10)), + new ConnectionMocker) + } + } + } + + HBaseConnectionCache.setTimeout(500) + val threads: Array[Thread] = new Array[Thread](100) + for (i <- 0 to 99) { + threads.update(i, new Thread(new TestThread())) + threads(i).run() + } + try { + threads.foreach { x => x.join() } + } catch { + case e: InterruptedException => println(e.getMessage) + } + + Thread.sleep(1000) + HBaseConnectionCache.connectionMap.synchronized { + assert(HBaseConnectionCache.connectionMap.size == 10) + assert(HBaseConnectionCache.getStat.numTotalRequests == 100 * 1000) + assert(HBaseConnectionCache.getStat.numActualConnectionsCreated == 10) + assert(HBaseConnectionCache.getStat.numActiveConnections == 10) + + var totalRc: Int = 0 + 
HBaseConnectionCache.connectionMap.foreach { x => totalRc += x._2.refCount } + assert(totalRc == 100 * 1000) + HBaseConnectionCache.connectionMap.foreach { + x => + { + x._2.refCount = 0 + x._2.timestamp = System.currentTimeMillis() - 1000 + } + } + } + Thread.sleep(1000) + assert(HBaseConnectionCache.connectionMap.size == 0) + assert(HBaseConnectionCache.getStat.numActualConnectionsCreated == 10) + assert(HBaseConnectionCache.getStat.numActiveConnections == 0) + } + + def testWithPressureWithClose() { + cleanEnv() + + class TestThread extends Runnable { + override def run() { + for (i <- 0 to 999) { + val c = HBaseConnectionCache.getConnection( + new HBaseConnectionKeyMocker(Random.nextInt(10)), + new ConnectionMocker) + Thread.`yield`() + c.close() + } + } + } + + HBaseConnectionCache.setTimeout(3 * 1000) + val threads: Array[Thread] = new Array[Thread](100) + for (i <- threads.indices) { + threads.update(i, new Thread(new TestThread())) + threads(i).run() + } + try { + threads.foreach { x => x.join() } + } catch { + case e: InterruptedException => println(e.getMessage) + } + + HBaseConnectionCache.connectionMap.synchronized { + assert(HBaseConnectionCache.connectionMap.size == 10) + assert(HBaseConnectionCache.getStat.numTotalRequests == 100 * 1000) + assert(HBaseConnectionCache.getStat.numActualConnectionsCreated == 10) + assert(HBaseConnectionCache.getStat.numActiveConnections == 10) + } + + Thread.sleep(6 * 1000) + HBaseConnectionCache.connectionMap.synchronized { + assert(HBaseConnectionCache.connectionMap.size == 0) + assert(HBaseConnectionCache.getStat.numActualConnectionsCreated == 10) + assert(HBaseConnectionCache.getStat.numActiveConnections == 0) + } + } +} diff --git a/spark4/hbase-spark4/src/test/scala/org/apache/hadoop/hbase/spark/HBaseContextSuite.scala b/spark4/hbase-spark4/src/test/scala/org/apache/hadoop/hbase/spark/HBaseContextSuite.scala new file mode 100644 index 00000000..a45988be --- /dev/null +++ b/spark4/hbase-spark4/src/test/scala/org/apache/hadoop/hbase/spark/HBaseContextSuite.scala @@ -0,0 +1,400 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.hadoop.hbase.spark + +import org.apache.hadoop.hbase.{CellUtil, HBaseTestingUtility, TableName} +import org.apache.hadoop.hbase.client._ +import org.apache.hadoop.hbase.filter.FirstKeyOnlyFilter +import org.apache.hadoop.hbase.util.Bytes +import org.apache.spark.{SparkContext, SparkException} +import org.scalatest.funsuite.AnyFunSuite +import org.scalatest.{BeforeAndAfterAll, BeforeAndAfterEach} + +class HBaseContextSuite + extends AnyFunSuite + with BeforeAndAfterEach + with BeforeAndAfterAll + with Logging { + + @transient var sc: SparkContext = null + var hbaseContext: HBaseContext = null + var TEST_UTIL = new HBaseTestingUtility + + val tableName = "t1" + val columnFamily = "c" + + override def beforeAll() { + TEST_UTIL.startMiniCluster() + logInfo(" - minicluster started") + + try { + TEST_UTIL.deleteTable(TableName.valueOf(tableName)) + } catch { + case e: Exception => + logInfo(" - no table " + tableName + " found") + } + logInfo(" - creating table " + tableName) + TEST_UTIL.createTable(TableName.valueOf(tableName), Bytes.toBytes(columnFamily)) + logInfo(" - created table") + + val envMap = Map[String, String](("Xmx", "512m")) + + sc = new SparkContext("local", "test", null, Nil, envMap) + + val config = TEST_UTIL.getConfiguration + hbaseContext = new HBaseContext(sc, config) + } + + override def afterAll() { + logInfo("shuting down minicluster") + TEST_UTIL.shutdownMiniCluster() + logInfo(" - minicluster shut down") + TEST_UTIL.cleanupTestDir() + sc.stop() + } + + test("bulkput to test HBase client") { + val config = TEST_UTIL.getConfiguration + val rdd = sc.parallelize( + Array[(Array[Byte], Array[(Array[Byte], Array[Byte], Array[Byte])])]( + ( + Bytes.toBytes("1"), + Array((Bytes.toBytes(columnFamily), Bytes.toBytes("a"), Bytes.toBytes("foo1")))), + ( + Bytes.toBytes("2"), + Array((Bytes.toBytes(columnFamily), Bytes.toBytes("b"), Bytes.toBytes("foo2")))), + ( + Bytes.toBytes("3"), + Array((Bytes.toBytes(columnFamily), Bytes.toBytes("c"), Bytes.toBytes("foo3")))), + ( + Bytes.toBytes("4"), + Array((Bytes.toBytes(columnFamily), Bytes.toBytes("d"), Bytes.toBytes("foo")))), + ( + Bytes.toBytes("5"), + Array((Bytes.toBytes(columnFamily), Bytes.toBytes("e"), Bytes.toBytes("bar")))))) + + hbaseContext.bulkPut[(Array[Byte], Array[(Array[Byte], Array[Byte], Array[Byte])])]( + rdd, + TableName.valueOf(tableName), + (putRecord) => { + val put = new Put(putRecord._1) + putRecord._2.foreach((putValue) => put.addColumn(putValue._1, putValue._2, putValue._3)) + put + }) + + val connection = ConnectionFactory.createConnection(config) + val table = connection.getTable(TableName.valueOf("t1")) + + try { + val foo1 = Bytes.toString( + CellUtil.cloneValue( + table + .get(new Get(Bytes.toBytes("1"))) + .getColumnLatestCell(Bytes.toBytes(columnFamily), Bytes.toBytes("a")))) + assert(foo1 == "foo1") + + val foo2 = Bytes.toString( + CellUtil.cloneValue( + table + .get(new Get(Bytes.toBytes("2"))) + .getColumnLatestCell(Bytes.toBytes(columnFamily), Bytes.toBytes("b")))) + assert(foo2 == "foo2") + + val foo3 = Bytes.toString( + CellUtil.cloneValue( + table + .get(new Get(Bytes.toBytes("3"))) + .getColumnLatestCell(Bytes.toBytes(columnFamily), Bytes.toBytes("c")))) + assert(foo3 == "foo3") + + val foo4 = Bytes.toString( + CellUtil.cloneValue( + table + .get(new Get(Bytes.toBytes("4"))) + .getColumnLatestCell(Bytes.toBytes(columnFamily), Bytes.toBytes("d")))) + assert(foo4 == "foo") + + val foo5 = Bytes.toString( + CellUtil.cloneValue( + table + .get(new Get(Bytes.toBytes("5"))) + 
.getColumnLatestCell(Bytes.toBytes(columnFamily), Bytes.toBytes("e")))) + assert(foo5 == "bar") + + } finally { + table.close() + connection.close() + } + } + + test("bulkDelete to test HBase client") { + val config = TEST_UTIL.getConfiguration + val connection = ConnectionFactory.createConnection(config) + val table = connection.getTable(TableName.valueOf("t1")) + + try { + var put = new Put(Bytes.toBytes("delete1")) + put.addColumn(Bytes.toBytes(columnFamily), Bytes.toBytes("a"), Bytes.toBytes("foo1")) + table.put(put) + put = new Put(Bytes.toBytes("delete2")) + put.addColumn(Bytes.toBytes(columnFamily), Bytes.toBytes("a"), Bytes.toBytes("foo2")) + table.put(put) + put = new Put(Bytes.toBytes("delete3")) + put.addColumn(Bytes.toBytes(columnFamily), Bytes.toBytes("a"), Bytes.toBytes("foo3")) + table.put(put) + + val rdd = sc.parallelize(Array[Array[Byte]](Bytes.toBytes("delete1"), Bytes.toBytes("delete3"))) + + hbaseContext.bulkDelete[Array[Byte]]( + rdd, + TableName.valueOf(tableName), + putRecord => new Delete(putRecord), + 4) + + assert( + table + .get(new Get(Bytes.toBytes("delete1"))) + .getColumnLatestCell(Bytes.toBytes(columnFamily), Bytes.toBytes("a")) == null) + assert( + table + .get(new Get(Bytes.toBytes("delete3"))) + .getColumnLatestCell(Bytes.toBytes(columnFamily), Bytes.toBytes("a")) == null) + assert( + Bytes + .toString( + CellUtil.cloneValue(table + .get(new Get(Bytes.toBytes("delete2"))) + .getColumnLatestCell(Bytes.toBytes(columnFamily), Bytes.toBytes("a")))) + .equals("foo2")) + } finally { + table.close() + connection.close() + } + } + + test("bulkGet to test HBase client") { + val config = TEST_UTIL.getConfiguration + val connection = ConnectionFactory.createConnection(config) + val table = connection.getTable(TableName.valueOf("t1")) + + try { + var put = new Put(Bytes.toBytes("get1")) + put.addColumn(Bytes.toBytes(columnFamily), Bytes.toBytes("a"), Bytes.toBytes("foo1")) + table.put(put) + put = new Put(Bytes.toBytes("get2")) + put.addColumn(Bytes.toBytes(columnFamily), Bytes.toBytes("a"), Bytes.toBytes("foo2")) + table.put(put) + put = new Put(Bytes.toBytes("get3")) + put.addColumn(Bytes.toBytes(columnFamily), Bytes.toBytes("a"), Bytes.toBytes("foo3")) + table.put(put) + } finally { + table.close() + connection.close() + } + val rdd = sc.parallelize( + Array[Array[Byte]]( + Bytes.toBytes("get1"), + Bytes.toBytes("get2"), + Bytes.toBytes("get3"), + Bytes.toBytes("get4"))) + + val getRdd = hbaseContext.bulkGet[Array[Byte], String]( + TableName.valueOf(tableName), + 2, + rdd, + record => { + new Get(record) + }, + (result: Result) => { + if (result.listCells() != null) { + val it = result.listCells().iterator() + val B = new StringBuilder + + B.append(Bytes.toString(result.getRow) + ":") + + while (it.hasNext) { + val cell = it.next() + val q = Bytes.toString(CellUtil.cloneQualifier(cell)) + if (q.equals("counter")) { + B.append("(" + q + "," + Bytes.toLong(CellUtil.cloneValue(cell)) + ")") + } else { + B.append("(" + q + "," + Bytes.toString(CellUtil.cloneValue(cell)) + ")") + } + } + "" + B.toString + } else { + "" + } + }) + val getArray = getRdd.collect() + + assert(getArray.length == 4) + assert(getArray.contains("get1:(a,foo1)")) + assert(getArray.contains("get2:(a,foo2)")) + assert(getArray.contains("get3:(a,foo3)")) + + } + + test("BulkGet failure test: bad table") { + val config = TEST_UTIL.getConfiguration + + val rdd = sc.parallelize( + Array[Array[Byte]]( + Bytes.toBytes("get1"), + Bytes.toBytes("get2"), + Bytes.toBytes("get3"), + Bytes.toBytes("get4"))) 
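    // The Gets execute inside Spark tasks, so the missing-table failure only surfaces when
    // collect() is called, as a SparkException whose message names
    // RetriesExhaustedWithDetailsException; that is what the intercept below checks.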
+ + intercept[SparkException] { + try { + val getRdd = hbaseContext.bulkGet[Array[Byte], String]( + TableName.valueOf("badTableName"), + 2, + rdd, + record => { + new Get(record) + }, + (result: Result) => "1") + + getRdd.collect() + + fail("We should have failed and not reached this line") + } catch { + case ex: SparkException => { + assert( + ex.getMessage.contains( + "org.apache.hadoop.hbase.client.RetriesExhaustedWithDetailsException")) + throw ex + } + } + } + } + + test("BulkGet failure test: bad column") { + + val config = TEST_UTIL.getConfiguration + val connection = ConnectionFactory.createConnection(config) + val table = connection.getTable(TableName.valueOf("t1")) + + try { + var put = new Put(Bytes.toBytes("get1")) + put.addColumn(Bytes.toBytes(columnFamily), Bytes.toBytes("a"), Bytes.toBytes("foo1")) + table.put(put) + put = new Put(Bytes.toBytes("get2")) + put.addColumn(Bytes.toBytes(columnFamily), Bytes.toBytes("a"), Bytes.toBytes("foo2")) + table.put(put) + put = new Put(Bytes.toBytes("get3")) + put.addColumn(Bytes.toBytes(columnFamily), Bytes.toBytes("a"), Bytes.toBytes("foo3")) + table.put(put) + } finally { + table.close() + connection.close() + } + + val rdd = sc.parallelize( + Array[Array[Byte]]( + Bytes.toBytes("get1"), + Bytes.toBytes("get2"), + Bytes.toBytes("get3"), + Bytes.toBytes("get4"))) + + val getRdd = hbaseContext.bulkGet[Array[Byte], String]( + TableName.valueOf(tableName), + 2, + rdd, + record => { + new Get(record) + }, + (result: Result) => { + if (result.listCells() != null) { + val cellValue = + result.getColumnLatestCell(Bytes.toBytes("c"), Bytes.toBytes("bad_column")) + if (cellValue == null) "null" else "bad" + } else "noValue" + }) + var nullCounter = 0 + var noValueCounter = 0 + getRdd + .collect() + .foreach( + r => { + if ("null".equals(r)) nullCounter += 1 + else if ("noValue".equals(r)) noValueCounter += 1 + }) + assert(nullCounter == 3) + assert(noValueCounter == 1) + } + + test("distributedScan to test HBase client") { + val config = TEST_UTIL.getConfiguration + val connection = ConnectionFactory.createConnection(config) + val table = connection.getTable(TableName.valueOf("t1")) + + try { + var put = new Put(Bytes.toBytes("scan1")) + put.addColumn(Bytes.toBytes(columnFamily), Bytes.toBytes("a"), Bytes.toBytes("foo1")) + table.put(put) + put = new Put(Bytes.toBytes("scan2")) + put.addColumn(Bytes.toBytes(columnFamily), Bytes.toBytes("a"), Bytes.toBytes("foo2")) + table.put(put) + put = new Put(Bytes.toBytes("scan2")) + put.addColumn(Bytes.toBytes(columnFamily), Bytes.toBytes("b"), Bytes.toBytes("foo-2")) + table.put(put) + put = new Put(Bytes.toBytes("scan3")) + put.addColumn(Bytes.toBytes(columnFamily), Bytes.toBytes("a"), Bytes.toBytes("foo3")) + table.put(put) + put = new Put(Bytes.toBytes("scan4")) + put.addColumn(Bytes.toBytes(columnFamily), Bytes.toBytes("a"), Bytes.toBytes("foo3")) + table.put(put) + put = new Put(Bytes.toBytes("scan5")) + put.addColumn(Bytes.toBytes(columnFamily), Bytes.toBytes("a"), Bytes.toBytes("foo3")) + table.put(put) + } finally { + table.close() + connection.close() + } + + val scan = new Scan() + val filter = new FirstKeyOnlyFilter() + scan.setCaching(100) + scan.setStartRow(Bytes.toBytes("scan2")) + scan.setStopRow(Bytes.toBytes("scan4_")) + scan.setFilter(filter) + + val scanRdd = hbaseContext.hbaseRDD(TableName.valueOf(tableName), scan) + + try { + val scanList = scanRdd + .map(r => r._1.copyBytes()) + .collect() + assert(scanList.length == 3) + var cnt = 0 + scanRdd + .map(r => r._2.listCells().size()) + 
.collect() + .foreach( + l => { + cnt += l + }) + // the number of cells returned would be 4 without the Filter + assert(cnt == 3); + } catch { + case ex: Exception => ex.printStackTrace() + } + } +} diff --git a/spark4/hbase-spark4/src/test/scala/org/apache/hadoop/hbase/spark/HBaseDStreamFunctionsSuite.scala b/spark4/hbase-spark4/src/test/scala/org/apache/hadoop/hbase/spark/HBaseDStreamFunctionsSuite.scala new file mode 100644 index 00000000..d7e99267 --- /dev/null +++ b/spark4/hbase-spark4/src/test/scala/org/apache/hadoop/hbase/spark/HBaseDStreamFunctionsSuite.scala @@ -0,0 +1,169 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.hadoop.hbase.spark + +import org.apache.hadoop.hbase.{CellUtil, HBaseTestingUtility, TableName} +import org.apache.hadoop.hbase.client._ +import org.apache.hadoop.hbase.spark.HBaseDStreamFunctions._ +import org.apache.hadoop.hbase.util.Bytes +import org.apache.spark.SparkContext +import org.apache.spark.rdd.RDD +import org.apache.spark.streaming.{Milliseconds, StreamingContext} +import org.scalatest.funsuite.AnyFunSuite +import org.scalatest.{BeforeAndAfterAll, BeforeAndAfterEach} + +import scala.collection.mutable + +class HBaseDStreamFunctionsSuite + extends AnyFunSuite + with BeforeAndAfterEach + with BeforeAndAfterAll + with Logging { + @transient var sc: SparkContext = null + + var TEST_UTIL: HBaseTestingUtility = new HBaseTestingUtility + + val tableName = "t1" + val columnFamily = "c" + + override def beforeAll() { + + TEST_UTIL.startMiniCluster() + + logInfo(" - minicluster started") + try + TEST_UTIL.deleteTable(TableName.valueOf(tableName)) + catch { + case e: Exception => logInfo(" - no table " + tableName + " found") + + } + logInfo(" - creating table " + tableName) + TEST_UTIL.createTable(TableName.valueOf(tableName), Bytes.toBytes(columnFamily)) + logInfo(" - created table") + + sc = new SparkContext("local", "test") + } + + override def afterAll() { + TEST_UTIL.deleteTable(TableName.valueOf(tableName)) + TEST_UTIL.shutdownMiniCluster() + sc.stop() + } + + test("bulkput to test HBase client") { + val config = TEST_UTIL.getConfiguration + val rdd1 = sc.parallelize( + Array[(Array[Byte], Array[(Array[Byte], Array[Byte], Array[Byte])])]( + ( + Bytes.toBytes("1"), + Array((Bytes.toBytes(columnFamily), Bytes.toBytes("a"), Bytes.toBytes("foo1")))), + ( + Bytes.toBytes("2"), + Array((Bytes.toBytes(columnFamily), Bytes.toBytes("b"), Bytes.toBytes("foo2")))), + ( + Bytes.toBytes("3"), + Array((Bytes.toBytes(columnFamily), Bytes.toBytes("c"), Bytes.toBytes("foo3")))))) + + val rdd2 = sc.parallelize( + Array[(Array[Byte], Array[(Array[Byte], Array[Byte], Array[Byte])])]( + ( + Bytes.toBytes("4"), + Array((Bytes.toBytes(columnFamily), Bytes.toBytes("d"), Bytes.toBytes("foo")))), + ( + 
Bytes.toBytes("5"), + Array((Bytes.toBytes(columnFamily), Bytes.toBytes("e"), Bytes.toBytes("bar")))))) + + var isFinished = false + + val hbaseContext = new HBaseContext(sc, config) + val ssc = new StreamingContext(sc, Milliseconds(200)) + + val queue = mutable.Queue[RDD[(Array[Byte], Array[(Array[Byte], Array[Byte], Array[Byte])])]]() + queue += rdd1 + queue += rdd2 + val dStream = ssc.queueStream(queue) + + dStream.hbaseBulkPut( + hbaseContext, + TableName.valueOf(tableName), + (putRecord) => { + val put = new Put(putRecord._1) + putRecord._2.foreach((putValue) => put.addColumn(putValue._1, putValue._2, putValue._3)) + put + }) + + dStream.foreachRDD( + rdd => { + if (rdd.count() == 0) { + isFinished = true + } + }) + + ssc.start() + + while (!isFinished) { + Thread.sleep(100) + } + + ssc.stop(true, true) + + val connection = ConnectionFactory.createConnection(config) + val table = connection.getTable(TableName.valueOf("t1")) + + try { + val foo1 = Bytes.toString( + CellUtil.cloneValue( + table + .get(new Get(Bytes.toBytes("1"))) + .getColumnLatestCell(Bytes.toBytes(columnFamily), Bytes.toBytes("a")))) + assert(foo1 == "foo1") + + val foo2 = Bytes.toString( + CellUtil.cloneValue( + table + .get(new Get(Bytes.toBytes("2"))) + .getColumnLatestCell(Bytes.toBytes(columnFamily), Bytes.toBytes("b")))) + assert(foo2 == "foo2") + + val foo3 = Bytes.toString( + CellUtil.cloneValue( + table + .get(new Get(Bytes.toBytes("3"))) + .getColumnLatestCell(Bytes.toBytes(columnFamily), Bytes.toBytes("c")))) + assert(foo3 == "foo3") + + val foo4 = Bytes.toString( + CellUtil.cloneValue( + table + .get(new Get(Bytes.toBytes("4"))) + .getColumnLatestCell(Bytes.toBytes(columnFamily), Bytes.toBytes("d")))) + assert(foo4 == "foo") + + val foo5 = Bytes.toString( + CellUtil.cloneValue( + table + .get(new Get(Bytes.toBytes("5"))) + .getColumnLatestCell(Bytes.toBytes(columnFamily), Bytes.toBytes("e")))) + assert(foo5 == "bar") + } finally { + table.close() + connection.close() + } + } + +} diff --git a/spark4/hbase-spark4/src/test/scala/org/apache/hadoop/hbase/spark/HBaseRDDFunctionsSuite.scala b/spark4/hbase-spark4/src/test/scala/org/apache/hadoop/hbase/spark/HBaseRDDFunctionsSuite.scala new file mode 100644 index 00000000..900a003d --- /dev/null +++ b/spark4/hbase-spark4/src/test/scala/org/apache/hadoop/hbase/spark/HBaseRDDFunctionsSuite.scala @@ -0,0 +1,472 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.hadoop.hbase.spark + +import org.apache.hadoop.hbase.{CellUtil, HBaseTestingUtility, TableName} +import org.apache.hadoop.hbase.client._ +import org.apache.hadoop.hbase.spark.HBaseRDDFunctions._ +import org.apache.hadoop.hbase.util.Bytes +import org.apache.spark.SparkContext +import org.scalatest.funsuite.AnyFunSuite +import org.scalatest.{BeforeAndAfterAll, BeforeAndAfterEach} + +import scala.collection.mutable + +class HBaseRDDFunctionsSuite + extends AnyFunSuite + with BeforeAndAfterEach + with BeforeAndAfterAll + with Logging { + @transient var sc: SparkContext = null + var TEST_UTIL: HBaseTestingUtility = new HBaseTestingUtility + + val tableName = "t1" + val columnFamily = "c" + + override def beforeAll() { + + TEST_UTIL.startMiniCluster + + logInfo(" - minicluster started") + try + TEST_UTIL.deleteTable(TableName.valueOf(tableName)) + catch { + case e: Exception => logInfo(" - no table " + tableName + " found") + + } + logInfo(" - creating table " + tableName) + TEST_UTIL.createTable(TableName.valueOf(tableName), Bytes.toBytes(columnFamily)) + logInfo(" - created table") + + sc = new SparkContext("local", "test") + } + + override def afterAll() { + TEST_UTIL.deleteTable(TableName.valueOf(tableName)) + logInfo("shutting down minicluster") + TEST_UTIL.shutdownMiniCluster() + + sc.stop() + } + + test("bulkput to test HBase client") { + val config = TEST_UTIL.getConfiguration + val rdd = sc.parallelize( + Array[(Array[Byte], Array[(Array[Byte], Array[Byte], Array[Byte])])]( + ( + Bytes.toBytes("1"), + Array((Bytes.toBytes(columnFamily), Bytes.toBytes("a"), Bytes.toBytes("foo1")))), + ( + Bytes.toBytes("2"), + Array((Bytes.toBytes(columnFamily), Bytes.toBytes("b"), Bytes.toBytes("foo2")))), + ( + Bytes.toBytes("3"), + Array((Bytes.toBytes(columnFamily), Bytes.toBytes("c"), Bytes.toBytes("foo3")))), + ( + Bytes.toBytes("4"), + Array((Bytes.toBytes(columnFamily), Bytes.toBytes("d"), Bytes.toBytes("foo")))), + ( + Bytes.toBytes("5"), + Array((Bytes.toBytes(columnFamily), Bytes.toBytes("e"), Bytes.toBytes("bar")))))) + + val hbaseContext = new HBaseContext(sc, config) + + rdd.hbaseBulkPut( + hbaseContext, + TableName.valueOf(tableName), + (putRecord) => { + val put = new Put(putRecord._1) + putRecord._2.foreach((putValue) => put.addColumn(putValue._1, putValue._2, putValue._3)) + put + }) + + val connection = ConnectionFactory.createConnection(config) + val table = connection.getTable(TableName.valueOf("t1")) + + try { + val foo1 = Bytes.toString( + CellUtil.cloneValue( + table + .get(new Get(Bytes.toBytes("1"))) + .getColumnLatestCell(Bytes.toBytes(columnFamily), Bytes.toBytes("a")))) + assert(foo1 == "foo1") + + val foo2 = Bytes.toString( + CellUtil.cloneValue( + table + .get(new Get(Bytes.toBytes("2"))) + .getColumnLatestCell(Bytes.toBytes(columnFamily), Bytes.toBytes("b")))) + assert(foo2 == "foo2") + + val foo3 = Bytes.toString( + CellUtil.cloneValue( + table + .get(new Get(Bytes.toBytes("3"))) + .getColumnLatestCell(Bytes.toBytes(columnFamily), Bytes.toBytes("c")))) + assert(foo3 == "foo3") + + val foo4 = Bytes.toString( + CellUtil.cloneValue( + table + .get(new Get(Bytes.toBytes("4"))) + .getColumnLatestCell(Bytes.toBytes(columnFamily), Bytes.toBytes("d")))) + assert(foo4 == "foo") + + val foo5 = Bytes.toString( + CellUtil.cloneValue( + table + .get(new Get(Bytes.toBytes("5"))) + .getColumnLatestCell(Bytes.toBytes(columnFamily), Bytes.toBytes("e")))) + assert(foo5 == "bar") + } finally { + table.close() + connection.close() + } + } + + test("bulkDelete to test 
HBase client") { + val config = TEST_UTIL.getConfiguration + val connection = ConnectionFactory.createConnection(config) + val table = connection.getTable(TableName.valueOf("t1")) + + try { + var put = new Put(Bytes.toBytes("delete1")) + put.addColumn(Bytes.toBytes(columnFamily), Bytes.toBytes("a"), Bytes.toBytes("foo1")) + table.put(put) + put = new Put(Bytes.toBytes("delete2")) + put.addColumn(Bytes.toBytes(columnFamily), Bytes.toBytes("a"), Bytes.toBytes("foo2")) + table.put(put) + put = new Put(Bytes.toBytes("delete3")) + put.addColumn(Bytes.toBytes(columnFamily), Bytes.toBytes("a"), Bytes.toBytes("foo3")) + table.put(put) + + val rdd = sc.parallelize(Array[Array[Byte]](Bytes.toBytes("delete1"), Bytes.toBytes("delete3"))) + + val hbaseContext = new HBaseContext(sc, config) + + rdd.hbaseBulkDelete( + hbaseContext, + TableName.valueOf(tableName), + putRecord => new Delete(putRecord), + 4) + + assert( + table + .get(new Get(Bytes.toBytes("delete1"))) + .getColumnLatestCell(Bytes.toBytes(columnFamily), Bytes.toBytes("a")) == null) + assert( + table + .get(new Get(Bytes.toBytes("delete3"))) + .getColumnLatestCell(Bytes.toBytes(columnFamily), Bytes.toBytes("a")) == null) + assert( + Bytes + .toString( + CellUtil.cloneValue(table + .get(new Get(Bytes.toBytes("delete2"))) + .getColumnLatestCell(Bytes.toBytes(columnFamily), Bytes.toBytes("a")))) + .equals("foo2")) + } finally { + table.close() + connection.close() + } + + } + + test("bulkGet to test HBase client") { + val config = TEST_UTIL.getConfiguration + val connection = ConnectionFactory.createConnection(config) + val table = connection.getTable(TableName.valueOf("t1")) + + try { + var put = new Put(Bytes.toBytes("get1")) + put.addColumn(Bytes.toBytes(columnFamily), Bytes.toBytes("a"), Bytes.toBytes("foo1")) + table.put(put) + put = new Put(Bytes.toBytes("get2")) + put.addColumn(Bytes.toBytes(columnFamily), Bytes.toBytes("a"), Bytes.toBytes("foo2")) + table.put(put) + put = new Put(Bytes.toBytes("get3")) + put.addColumn(Bytes.toBytes(columnFamily), Bytes.toBytes("a"), Bytes.toBytes("foo3")) + table.put(put) + } finally { + table.close() + connection.close() + } + + val rdd = sc.parallelize( + Array[Array[Byte]]( + Bytes.toBytes("get1"), + Bytes.toBytes("get2"), + Bytes.toBytes("get3"), + Bytes.toBytes("get4"))) + val hbaseContext = new HBaseContext(sc, config) + + // Get with custom convert logic + val getRdd = rdd.hbaseBulkGet[String]( + hbaseContext, + TableName.valueOf(tableName), + 2, + record => { + new Get(record) + }, + (result: Result) => { + if (result.listCells() != null) { + val it = result.listCells().iterator() + val B = new StringBuilder + + B.append(Bytes.toString(result.getRow) + ":") + + while (it.hasNext) { + val cell = it.next + val q = Bytes.toString(CellUtil.cloneQualifier(cell)) + if (q.equals("counter")) { + B.append("(" + q + "," + Bytes.toLong(CellUtil.cloneValue(cell)) + ")") + } else { + B.append("(" + q + "," + Bytes.toString(CellUtil.cloneValue(cell)) + ")") + } + } + "" + B.toString + } else { + "" + } + }) + + val getArray = getRdd.collect() + + assert(getArray.length == 4) + assert(getArray.contains("get1:(a,foo1)")) + assert(getArray.contains("get2:(a,foo2)")) + assert(getArray.contains("get3:(a,foo3)")) + } + + test("bulkGet default converter to test HBase client") { + val config = TEST_UTIL.getConfiguration + val connection = ConnectionFactory.createConnection(config) + val table = connection.getTable(TableName.valueOf("t1")) + + try { + var put = new Put(Bytes.toBytes("get1")) + 
put.addColumn(Bytes.toBytes(columnFamily), Bytes.toBytes("a"), Bytes.toBytes("foo1")) + table.put(put) + put = new Put(Bytes.toBytes("get2")) + put.addColumn(Bytes.toBytes(columnFamily), Bytes.toBytes("a"), Bytes.toBytes("foo2")) + table.put(put) + put = new Put(Bytes.toBytes("get3")) + put.addColumn(Bytes.toBytes(columnFamily), Bytes.toBytes("a"), Bytes.toBytes("foo3")) + table.put(put) + } finally { + table.close() + connection.close() + } + + val rdd = sc.parallelize( + Array[Array[Byte]]( + Bytes.toBytes("get1"), + Bytes.toBytes("get2"), + Bytes.toBytes("get3"), + Bytes.toBytes("get4"))) + val hbaseContext = new HBaseContext(sc, config) + + val getRdd = rdd + .hbaseBulkGet( + hbaseContext, + TableName.valueOf("t1"), + 2, + record => { + new Get(record) + }) + .map( + (row) => { + if (row != null && row._2.listCells() != null) { + val it = row._2.listCells().iterator() + val B = new StringBuilder + + B.append(Bytes.toString(row._2.getRow) + ":") + + while (it.hasNext) { + val cell = it.next + val q = Bytes.toString(CellUtil.cloneQualifier(cell)) + if (q.equals("counter")) { + B.append("(" + q + "," + Bytes.toLong(CellUtil.cloneValue(cell)) + ")") + } else { + B.append("(" + q + "," + Bytes.toString(CellUtil.cloneValue(cell)) + ")") + } + } + "" + B.toString + } else { + "" + } + }) + + val getArray = getRdd.collect() + + assert(getArray.length == 4) + assert(getArray.contains("get1:(a,foo1)")) + assert(getArray.contains("get2:(a,foo2)")) + assert(getArray.contains("get3:(a,foo3)")) + } + + test("foreachPartition with puts to test HBase client") { + val config = TEST_UTIL.getConfiguration + val rdd = sc.parallelize( + Array[(Array[Byte], Array[(Array[Byte], Array[Byte], Array[Byte])])]( + ( + Bytes.toBytes("1foreach"), + Array((Bytes.toBytes(columnFamily), Bytes.toBytes("a"), Bytes.toBytes("foo1")))), + ( + Bytes.toBytes("2foreach"), + Array((Bytes.toBytes(columnFamily), Bytes.toBytes("b"), Bytes.toBytes("foo2")))), + ( + Bytes.toBytes("3foreach"), + Array((Bytes.toBytes(columnFamily), Bytes.toBytes("c"), Bytes.toBytes("foo3")))), + ( + Bytes.toBytes("4foreach"), + Array((Bytes.toBytes(columnFamily), Bytes.toBytes("d"), Bytes.toBytes("foo")))), + ( + Bytes.toBytes("5foreach"), + Array((Bytes.toBytes(columnFamily), Bytes.toBytes("e"), Bytes.toBytes("bar")))))) + + val hbaseContext = new HBaseContext(sc, config) + + rdd.hbaseForeachPartition( + hbaseContext, + (it, conn) => { + val bufferedMutator = conn.getBufferedMutator(TableName.valueOf("t1")) + it.foreach( + (putRecord) => { + val put = new Put(putRecord._1) + putRecord._2.foreach((putValue) => put.addColumn(putValue._1, putValue._2, putValue._3)) + bufferedMutator.mutate(put) + }) + bufferedMutator.flush() + bufferedMutator.close() + }) + + val connection = ConnectionFactory.createConnection(config) + val table = connection.getTable(TableName.valueOf("t1")) + + try { + val foo1 = Bytes.toString( + CellUtil.cloneValue( + table + .get(new Get(Bytes.toBytes("1foreach"))) + .getColumnLatestCell(Bytes.toBytes(columnFamily), Bytes.toBytes("a")))) + assert(foo1 == "foo1") + + val foo2 = Bytes.toString( + CellUtil.cloneValue( + table + .get(new Get(Bytes.toBytes("2foreach"))) + .getColumnLatestCell(Bytes.toBytes(columnFamily), Bytes.toBytes("b")))) + assert(foo2 == "foo2") + + val foo3 = Bytes.toString( + CellUtil.cloneValue( + table + .get(new Get(Bytes.toBytes("3foreach"))) + .getColumnLatestCell(Bytes.toBytes(columnFamily), Bytes.toBytes("c")))) + assert(foo3 == "foo3") + + val foo4 = Bytes.toString( + CellUtil.cloneValue( + table + 
.get(new Get(Bytes.toBytes("4foreach"))) + .getColumnLatestCell(Bytes.toBytes(columnFamily), Bytes.toBytes("d")))) + assert(foo4 == "foo") + + val foo5 = Bytes.toString( + CellUtil.cloneValue( + table + .get(new Get(Bytes.toBytes("5"))) + .getColumnLatestCell(Bytes.toBytes(columnFamily), Bytes.toBytes("e")))) + assert(foo5 == "bar") + } finally { + table.close() + connection.close() + } + } + + test("mapPartitions with Get from test HBase client") { + val config = TEST_UTIL.getConfiguration + val connection = ConnectionFactory.createConnection(config) + val table = connection.getTable(TableName.valueOf("t1")) + + try { + var put = new Put(Bytes.toBytes("get1")) + put.addColumn(Bytes.toBytes(columnFamily), Bytes.toBytes("a"), Bytes.toBytes("foo1")) + table.put(put) + put = new Put(Bytes.toBytes("get2")) + put.addColumn(Bytes.toBytes(columnFamily), Bytes.toBytes("a"), Bytes.toBytes("foo2")) + table.put(put) + put = new Put(Bytes.toBytes("get3")) + put.addColumn(Bytes.toBytes(columnFamily), Bytes.toBytes("a"), Bytes.toBytes("foo3")) + table.put(put) + } finally { + table.close() + connection.close() + } + + val rdd = sc.parallelize( + Array[Array[Byte]]( + Bytes.toBytes("get1"), + Bytes.toBytes("get2"), + Bytes.toBytes("get3"), + Bytes.toBytes("get4"))) + val hbaseContext = new HBaseContext(sc, config) + + // Get with custom convert logic + val getRdd = rdd.hbaseMapPartitions( + hbaseContext, + (it, conn) => { + val table = conn.getTable(TableName.valueOf("t1")) + var res = mutable.ListBuffer[String]() + + it.foreach( + r => { + val get = new Get(r) + val result = table.get(get) + if (result.listCells != null) { + val it = result.listCells().iterator() + val B = new StringBuilder + + B.append(Bytes.toString(result.getRow) + ":") + + while (it.hasNext) { + val cell = it.next() + val q = Bytes.toString(CellUtil.cloneQualifier(cell)) + if (q.equals("counter")) { + B.append("(" + q + "," + Bytes.toLong(CellUtil.cloneValue(cell)) + ")") + } else { + B.append("(" + q + "," + Bytes.toString(CellUtil.cloneValue(cell)) + ")") + } + } + res += "" + B.toString + } else { + res += "" + } + }) + res.iterator + }) + + val getArray = getRdd.collect() + + assert(getArray.length == 4) + assert(getArray.contains("get1:(a,foo1)")) + assert(getArray.contains("get2:(a,foo2)")) + assert(getArray.contains("get3:(a,foo3)")) + } +} diff --git a/spark4/hbase-spark4/src/test/scala/org/apache/hadoop/hbase/spark/HBaseTestSource.scala b/spark4/hbase-spark4/src/test/scala/org/apache/hadoop/hbase/spark/HBaseTestSource.scala new file mode 100644 index 00000000..a27876f1 --- /dev/null +++ b/spark4/hbase-spark4/src/test/scala/org/apache/hadoop/hbase/spark/HBaseTestSource.scala @@ -0,0 +1,59 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.hadoop.hbase.spark + +import org.apache.hadoop.hbase.spark.datasources.HBaseSparkConf +import org.apache.spark.SparkEnv +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.{Row, SQLContext} +import org.apache.spark.sql.sources._ +import org.apache.spark.sql.types._ + +class HBaseTestSource extends RelationProvider { + override def createRelation( + sqlContext: SQLContext, + parameters: Map[String, String]): BaseRelation = { + DummyScan( + parameters("cacheSize").toInt, + parameters("batchNum").toInt, + parameters("blockCacheingEnable").toBoolean, + parameters("rowNum").toInt)(sqlContext) + } +} + +case class DummyScan(cacheSize: Int, batchNum: Int, blockCachingEnable: Boolean, rowNum: Int)( + @transient val sqlContext: SQLContext) + extends BaseRelation + with TableScan { + private def sparkConf = SparkEnv.get.conf + override def schema: StructType = + StructType(StructField("i", IntegerType, nullable = false) :: Nil) + + override def buildScan(): RDD[Row] = sqlContext.sparkContext + .parallelize(0 until rowNum) + .map(Row(_)) + .map { + x => + if (sparkConf.getInt(HBaseSparkConf.QUERY_BATCHSIZE, -1) != batchNum || + sparkConf.getInt(HBaseSparkConf.QUERY_CACHEDROWS, -1) != cacheSize || + sparkConf.getBoolean(HBaseSparkConf.QUERY_CACHEBLOCKS, false) != blockCachingEnable) { + throw new Exception("HBase Spark configuration cannot be set properly") + } + x + } +} diff --git a/spark4/hbase-spark4/src/test/scala/org/apache/hadoop/hbase/spark/PartitionFilterSuite.scala b/spark4/hbase-spark4/src/test/scala/org/apache/hadoop/hbase/spark/PartitionFilterSuite.scala new file mode 100644 index 00000000..7718fa40 --- /dev/null +++ b/spark4/hbase-spark4/src/test/scala/org/apache/hadoop/hbase/spark/PartitionFilterSuite.scala @@ -0,0 +1,533 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.hadoop.hbase.spark + +import org.apache.hadoop.hbase.{HBaseTestingUtility, TableName} +import org.apache.hadoop.hbase.spark.datasources.{HBaseSparkConf, HBaseTableCatalog} +import org.apache.spark.{SparkConf, SparkContext} +import org.apache.spark.sql.{DataFrame, SQLContext} +import org.scalatest.funsuite.AnyFunSuite +import org.scalatest.{BeforeAndAfterAll, BeforeAndAfterEach} + +import scala.reflect.ClassTag + +case class FilterRangeRecord( + intCol0: Int, + boolCol1: Boolean, + doubleCol2: Double, + floatCol3: Float, + intCol4: Int, + longCol5: Long, + shortCol6: Short, + stringCol7: String, + byteCol8: Byte) + +object FilterRangeRecord { + def apply(i: Int): FilterRangeRecord = { + FilterRangeRecord( + if (i % 2 == 0) i else -i, + i % 2 == 0, + if (i % 2 == 0) i.toDouble else -i.toDouble, + i.toFloat, + if (i % 2 == 0) i else -i, + i.toLong, + i.toShort, + s"String$i extra", + i.toByte) + } +} + +class PartitionFilterSuite + extends AnyFunSuite + with BeforeAndAfterEach + with BeforeAndAfterAll + with Logging { + @transient var sc: SparkContext = null + var TEST_UTIL: HBaseTestingUtility = new HBaseTestingUtility + + var sqlContext: SQLContext = null + var df: DataFrame = null + + def withCatalog(cat: String): DataFrame = { + sqlContext.read + .options(Map(HBaseTableCatalog.tableCatalog -> cat)) + .format("org.apache.hadoop.hbase.spark") + .load() + } + + override def beforeAll() { + + TEST_UTIL.startMiniCluster + val sparkConf = new SparkConf + sparkConf.set(HBaseSparkConf.QUERY_CACHEBLOCKS, "true") + sparkConf.set(HBaseSparkConf.QUERY_BATCHSIZE, "100") + sparkConf.set(HBaseSparkConf.QUERY_CACHEDROWS, "100") + + sc = new SparkContext("local", "test", sparkConf) + new HBaseContext(sc, TEST_UTIL.getConfiguration) + sqlContext = new SQLContext(sc) + } + + override def afterAll() { + logInfo("shutting down minicluster") + TEST_UTIL.shutdownMiniCluster() + + sc.stop() + } + + override def beforeEach(): Unit = { + DefaultSourceStaticUtils.lastFiveExecutionRules.clear() + } + + // The original raw data used for construct result set without going through + // data frame logic. It is used to verify the result set retrieved from data frame logic. 
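+ // Illustrative sketch only (not part of the suite): each filter test below derives its expected values directly from the rawResult sequence defined just below and compares them with what comes back through the DataFrame, roughly: + //   val expected = rawResult.filter(_.intCol0 < 0).map(_.intCol0).toSet + //   val result = collectToSet[Int](withCatalog(catalog).filter($"intCol0" < 0).select($"intCol0"))  // requires import sql.implicits._ + //   assert(expected === result)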
+ val rawResult = (0 until 32).map { i => FilterRangeRecord(i) } + + def collectToSet[T: ClassTag](df: DataFrame): Set[T] = { + df.collect().map(_.getAs[T](0)).toSet + } + val catalog = s"""{ + |"table":{"namespace":"default", "name":"rangeTable"}, + |"rowkey":"key", + |"columns":{ + |"intCol0":{"cf":"rowkey", "col":"key", "type":"int"}, + |"boolCol1":{"cf":"cf1", "col":"boolCol1", "type":"boolean"}, + |"doubleCol2":{"cf":"cf2", "col":"doubleCol2", "type":"double"}, + |"floatCol3":{"cf":"cf3", "col":"floatCol3", "type":"float"}, + |"intCol4":{"cf":"cf4", "col":"intCol4", "type":"int"}, + |"longCol5":{"cf":"cf5", "col":"longCol5", "type":"bigint"}, + |"shortCol6":{"cf":"cf6", "col":"shortCol6", "type":"smallint"}, + |"stringCol7":{"cf":"cf7", "col":"stringCol7", "type":"string"}, + |"byteCol8":{"cf":"cf8", "col":"byteCol8", "type":"tinyint"} + |} + |}""".stripMargin + + test("populate rangeTable") { + val sql = sqlContext + import sql.implicits._ + + sc.parallelize(rawResult) + .toDF + .write + .options(Map(HBaseTableCatalog.tableCatalog -> catalog, HBaseTableCatalog.newTable -> "5")) + .format("org.apache.hadoop.hbase.spark") + .save() + } + test("rangeTable full query") { + val df = withCatalog(catalog) + df.show + assert(df.count() === 32) + } + + /** + * expected result: only showing top 20 rows + * +-------+ + * |intCol0| + * +-------+ + * | -31 | + * | -29 | + * | -27 | + * | -25 | + * | -23 | + * | -21 | + * | -19 | + * | -17 | + * | -15 | + * | -13 | + * | -11 | + * | -9 | + * | -7 | + * | -5 | + * | -3 | + * | -1 | + * +---- + + */ + test("rangeTable rowkey less than 0") { + val sql = sqlContext + import sql.implicits._ + val df = withCatalog(catalog) + val s = df.filter($"intCol0" < 0).select($"intCol0") + s.show + // filter results without going through dataframe + val expected = rawResult.filter(_.intCol0 < 0).map(_.intCol0).toSet + // filter results going through dataframe + val result = collectToSet[Int](s) + assert(expected === result) + } + + /** + * expected result: only showing top 20 rows + * +-------+ + * |intCol4| + * +-------+ + * | -31 | + * | -29 | + * | -27 | + * | -25 | + * | -23 | + * | -21 | + * | -19 | + * | -17 | + * | -15 | + * | -13 | + * | -11 | + * | -9 | + * | -7 | + * | -5 | + * | -3 | + * | -1 | + * +-------+ + */ + test("rangeTable int col less than 0") { + val sql = sqlContext + import sql.implicits._ + val df = withCatalog(catalog) + val s = df.filter($"intCol4" < 0).select($"intCol4") + s.show + // filter results without going through dataframe + val expected = rawResult.filter(_.intCol4 < 0).map(_.intCol4).toSet + // filter results going through dataframe + val result = collectToSet[Int](s) + assert(expected === result) + } + + /** + * expected result: only showing top 20 rows + * +-----------+ + * | doubleCol2| + * +-----------+ + * | 0.0 | + * | 2.0 | + * |-31.0 | + * |-29.0 | + * |-27.0 | + * |-25.0 | + * |-23.0 | + * |-21.0 | + * |-19.0 | + * |-17.0 | + * |-15.0 | + * |-13.0 | + * |-11.0 | + * | -9.0 | + * | -7.0 | + * | -5.0 | + * | -3.0 | + * | -1.0 | + * +-----------+ + */ + test("rangeTable double col less than 0") { + val sql = sqlContext + import sql.implicits._ + val df = withCatalog(catalog) + val s = df.filter($"doubleCol2" < 3.0).select($"doubleCol2") + s.show + // filter results without going through dataframe + val expected = rawResult.filter(_.doubleCol2 < 3.0).map(_.doubleCol2).toSet + // filter results going through dataframe + val result = collectToSet[Double](s) + assert(expected === result) + } + + /** + * expected result: only 
showing top 20 rows + * +-------+ + * |intCol0| + * +-------+ + * | -31 | + * | -29 | + * | -27 | + * | -25 | + * | -23 | + * | -21 | + * | -19 | + * | -17 | + * | -15 | + * | -13 | + * | -11 | + * +-------+ + */ + test("rangeTable lessequal than -10") { + val sql = sqlContext + import sql.implicits._ + val df = withCatalog(catalog) + val s = df.filter($"intCol0" <= -10).select($"intCol0") + s.show + // filter results without going through dataframe + val expected = rawResult.filter(_.intCol0 <= -10).map(_.intCol0).toSet + // filter results going through dataframe + val result = collectToSet[Int](s) + assert(expected === result) + } + + /** + * expected result: only showing top 20 rows + * +-------+ + * |intCol0| + * +----+ + * | -31 | + * | -29 | + * | -27 | + * | -25 | + * | -23 | + * | -21 | + * | -19 | + * | -17 | + * | -15 | + * | -13 | + * | -11 | + * | -9 | + * +-------+ + */ + test("rangeTable lessequal than -9") { + val sql = sqlContext + import sql.implicits._ + val df = withCatalog(catalog) + val s = df.filter($"intCol0" <= -9).select($"intCol0") + s.show + // filter results without going through dataframe + val expected = rawResult.filter(_.intCol0 <= -9).map(_.intCol0).toSet + // filter results going through dataframe + val result = collectToSet[Int](s) + assert(expected === result) + } + + /** + * expected result: only showing top 20 rows + * +-------+ + * |intCol0| + * +-------+ + * | 0 | + * | 2 | + * | 4 | + * | 6 | + * | 8 | + * | 10 | + * | 12 | + * | 14 | + * | 16 | + * | 18 | + * | 20 | + * | 22 | + * | 24 | + * | 26 | + * | 28 | + * | 30 | + * | -9 | + * | -7 | + * | -5 | + * | -3 | + * +-------+ + */ + test("rangeTable greaterequal than -9") { + val sql = sqlContext + import sql.implicits._ + val df = withCatalog(catalog) + val s = df.filter($"intCol0" >= -9).select($"intCol0") + s.show + // filter results without going through dataframe + val expected = rawResult.filter(_.intCol0 >= -9).map(_.intCol0).toSet + // filter results going through dataframe + val result = collectToSet[Int](s) + assert(expected === result) + } + + /** + * expected result: only showing top 20 rows + * +-------+ + * |intCol0| + * +-------+ + * | 0 | + * | 2 | + * | 4 | + * | 6 | + * | 8 | + * | 10 | + * | 12 | + * | 14 | + * | 16 | + * | 18 | + * | 20 | + * | 22 | + * | 24 | + * | 26 | + * | 28 | + * | 30 | + * +-------+ + */ + test("rangeTable greaterequal than 0") { + val sql = sqlContext + import sql.implicits._ + val df = withCatalog(catalog) + val s = df.filter($"intCol0" >= 0).select($"intCol0") + s.show + // filter results without going through dataframe + val expected = rawResult.filter(_.intCol0 >= 0).map(_.intCol0).toSet + // filter results going through dataframe + val result = collectToSet[Int](s) + assert(expected === result) + } + + /** + * expected result: only showing top 20 rows + * +-------+ + * |intCol0| + * +-------+ + * | 12 | + * | 14 | + * | 16 | + * | 18 | + * | 20 | + * | 22 | + * | 24 | + * | 26 | + * | 28 | + * | 30 | + * +-------+ + */ + test("rangeTable greater than 10") { + val sql = sqlContext + import sql.implicits._ + val df = withCatalog(catalog) + val s = df.filter($"intCol0" > 10).select($"intCol0") + s.show + // filter results without going through dataframe + val expected = rawResult.filter(_.intCol0 > 10).map(_.intCol0).toSet + // filter results going through dataframe + val result = collectToSet[Int](s) + assert(expected === result) + } + + /** + * expected result: only showing top 20 rows + * +-------+ + * |intCol0| + * +-------+ + * | 0 | + * | 2 | + 
* | 4 | + * | 6 | + * | 8 | + * | 10 | + * | -9 | + * | -7 | + * | -5 | + * | -3 | + * | -1 | + * +-------+ + */ + test("rangeTable and") { + val sql = sqlContext + import sql.implicits._ + val df = withCatalog(catalog) + val s = df.filter($"intCol0" > -10 && $"intCol0" <= 10).select($"intCol0") + s.show + // filter results without going through dataframe + val expected = rawResult + .filter(x => x.intCol0 > -10 && x.intCol0 <= 10) + .map(_.intCol0) + .toSet + // filter results going through dataframe + val result = collectToSet[Int](s) + assert(expected === result) + } + + /** + * expected result: only showing top 20 rows + * +-------+ + * |intCol0| + * +-------+ + * | 12 | + * | 14 | + * | 16 | + * | 18 | + * | 20 | + * | 22 | + * | 24 | + * | 26 | + * | 28 | + * | 30 | + * | -31 | + * | -29 | + * | -27 | + * | -25 | + * | -23 | + * | -21 | + * | -19 | + * | -17 | + * | -15 | + * | -13 | + * +-------+ + */ + + test("or") { + val sql = sqlContext + import sql.implicits._ + val df = withCatalog(catalog) + val s = df.filter($"intCol0" <= -10 || $"intCol0" > 10).select($"intCol0") + s.show + // filter results without going through dataframe + val expected = rawResult + .filter(x => x.intCol0 <= -10 || x.intCol0 > 10) + .map(_.intCol0) + .toSet + // filter results going through dataframe + val result = collectToSet[Int](s) + assert(expected === result) + } + + /** + * expected result: only showing top 20 rows + * +-------+ + * |intCol0| + * +-------+ + * | 0 | + * | 2 | + * | 4 | + * | 6 | + * | 8 | + * | 10 | + * | 12 | + * | 14 | + * | 16 | + * | 18 | + * | 20 | + * | 22 | + * | 24 | + * | 26 | + * | 28 | + * | 30 | + * | -31 | + * | -29 | + * | -27 | + * | -25 | + * +-------+ + */ + test("rangeTable all") { + val sql = sqlContext + import sql.implicits._ + val df = withCatalog(catalog) + val s = df.filter($"intCol0" >= -100).select($"intCol0") + s.show + // filter results without going through dataframe + val expected = rawResult.filter(_.intCol0 >= -100).map(_.intCol0).toSet + // filter results going through dataframe + val result = collectToSet[Int](s) + assert(expected === result) + } +} diff --git a/spark4/hbase-spark4/src/test/scala/org/apache/hadoop/hbase/spark/StartsWithSuite.scala b/spark4/hbase-spark4/src/test/scala/org/apache/hadoop/hbase/spark/StartsWithSuite.scala new file mode 100644 index 00000000..9a0d5675 --- /dev/null +++ b/spark4/hbase-spark4/src/test/scala/org/apache/hadoop/hbase/spark/StartsWithSuite.scala @@ -0,0 +1,101 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.hadoop.hbase.spark + +import org.apache.hadoop.hbase.spark.datasources.Utils +import org.apache.hadoop.hbase.util.Bytes +import org.scalatest.funsuite.AnyFunSuite +import org.scalatest.{BeforeAndAfterAll, BeforeAndAfterEach} + +class StartsWithSuite extends AnyFunSuite with BeforeAndAfterEach with BeforeAndAfterAll with Logging { + + test("simple1") { + val t = new Array[Byte](2) + t(0) = 1.toByte + t(1) = 2.toByte + + val expected = new Array[Byte](2) + expected(0) = 1.toByte + expected(1) = 3.toByte + + val res = Utils.incrementByteArray(t) + assert(res.sameElements(expected)) + } + + test("simple2") { + val t = new Array[Byte](1) + t(0) = 87.toByte + + val expected = new Array[Byte](1) + expected(0) = 88.toByte + + val res = Utils.incrementByteArray(t) + assert(res.sameElements(expected)) + } + + test("overflow1") { + val t = new Array[Byte](2) + t(0) = 1.toByte + t(1) = (-1).toByte + + val expected = new Array[Byte](2) + expected(0) = 2.toByte + expected(1) = 0.toByte + + val res = Utils.incrementByteArray(t) + + assert(res.sameElements(expected)) + } + + test("overflow2") { + val t = new Array[Byte](2) + t(0) = (-1).toByte + t(1) = (-1).toByte + + val expected = null + + val res = Utils.incrementByteArray(t) + + assert(res == expected) + } + + test("max-min-value") { + val t = new Array[Byte](2) + t(0) = 1.toByte + t(1) = (127).toByte + + val expected = new Array[Byte](2) + expected(0) = 1.toByte + expected(1) = (-128).toByte + + val res = Utils.incrementByteArray(t) + assert(res.sameElements(expected)) + } + + test("complicated") { + val input = "row005" + val expectedOutput = "row006" + + val t = Bytes.toBytes(input) + val expected = Bytes.toBytes(expectedOutput) + + val res = Utils.incrementByteArray(t) + assert(res.sameElements(expected)) + } + +} diff --git a/spark4/hbase-spark4/src/test/scala/org/apache/hadoop/hbase/spark/TableOutputFormatSuite.scala b/spark4/hbase-spark4/src/test/scala/org/apache/hadoop/hbase/spark/TableOutputFormatSuite.scala new file mode 100644 index 00000000..a180dc97 --- /dev/null +++ b/spark4/hbase-spark4/src/test/scala/org/apache/hadoop/hbase/spark/TableOutputFormatSuite.scala @@ -0,0 +1,133 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.hadoop.hbase.spark + +import java.text.SimpleDateFormat +import java.util.{Date, Locale} +import org.apache.hadoop.hbase.{HBaseTestingUtility, TableName, TableNotFoundException} +import org.apache.hadoop.hbase.mapreduce.TableOutputFormat +import org.apache.hadoop.hbase.util.Bytes +import org.apache.hadoop.mapreduce.{Job, TaskAttemptID, TaskType} +import org.apache.hadoop.mapreduce.task.TaskAttemptContextImpl +import org.apache.spark.{SparkConf, SparkContext} +import org.scalatest.funsuite.AnyFunSuite +import org.scalatest.{BeforeAndAfterAll, BeforeAndAfterEach} + +import scala.util.{Failure, Success, Try} + +// Unit tests for HBASE-20521: change get configuration(TableOutputFormat.conf) object first sequence from jobContext to getConf +// this suite contains two tests, one for normal case(getConf return null, use jobContext), create new TableOutputformat object without init TableOutputFormat.conf object, +// configuration object inside checkOutputSpecs came from jobContext. +// The other one(getConf return conf object) we manually call "setConf" to init TableOutputFormat.conf, for making it more straight forward, we specify a nonexistent table +// name in conf object, checkOutputSpecs will then throw TableNotFoundException exception +class TableOutputFormatSuite + extends AnyFunSuite + with BeforeAndAfterEach + with BeforeAndAfterAll + with Logging { + @transient var sc: SparkContext = null + var TEST_UTIL = new HBaseTestingUtility + + val tableName = "TableOutputFormatTest" + val tableNameTest = "NonExistentTable" + val columnFamily = "cf" + + override protected def beforeAll(): Unit = { + TEST_UTIL.startMiniCluster + + logInfo(" - minicluster started") + try { + TEST_UTIL.deleteTable(TableName.valueOf(tableName)) + } catch { + case e: Exception => logInfo(" - no table " + tableName + " found") + } + + TEST_UTIL.createTable(TableName.valueOf(tableName), Bytes.toBytes(columnFamily)) + logInfo(" - created table") + + // set "validateOutputSpecs" true anyway, force to validate output spec + val sparkConf = new SparkConf() + .setMaster("local") + .setAppName("test") + + sc = new SparkContext(sparkConf) + } + + override protected def afterAll(): Unit = { + logInfo(" - delete table: " + tableName) + TEST_UTIL.deleteTable(TableName.valueOf(tableName)) + logInfo(" - shutting down minicluster") + TEST_UTIL.shutdownMiniCluster() + + TEST_UTIL.cleanupTestDir() + sc.stop() + } + + def getJobContext() = { + val hConf = TEST_UTIL.getConfiguration + hConf.set(TableOutputFormat.OUTPUT_TABLE, tableName) + val job = Job.getInstance(hConf) + job.setOutputFormatClass(classOf[TableOutputFormat[String]]) + + val jobTrackerId = new SimpleDateFormat("yyyyMMddHHmmss", Locale.US).format(new Date()) + val jobAttemptId = new TaskAttemptID(jobTrackerId, 1, TaskType.MAP, 0, 0) + new TaskAttemptContextImpl(job.getConfiguration, jobAttemptId) + } + + // Mock up jobContext object and execute actions in "write" function + // from "org.apache.spark.internal.io.SparkHadoopMapReduceWriter" + // this case should run normally without any exceptions + test( + "TableOutputFormat.checkOutputSpecs test without setConf called, should return true and without exceptions") { + val jobContext = getJobContext() + val format = jobContext.getOutputFormatClass + val jobFormat = format.newInstance + Try { + jobFormat.checkOutputSpecs(jobContext) + } match { + case Success(_) => assert(true) + case Failure(_) => assert(false) + } + } + + // Set configuration externally, checkOutputSpec should use configuration object 
set by "SetConf" method + // rather than jobContext, this case should throw "TableNotFoundException" exception + test( + "TableOutputFormat.checkOutputSpecs test without setConf called, should throw TableNotFoundException") { + val jobContext = getJobContext() + val format = jobContext.getOutputFormatClass + val jobFormat = format.newInstance + + val hConf = TEST_UTIL.getConfiguration + hConf.set(TableOutputFormat.OUTPUT_TABLE, tableNameTest) + jobFormat.asInstanceOf[TableOutputFormat[String]].setConf(hConf) + Try { + jobFormat.checkOutputSpecs(jobContext) + } match { + case Success(_) => assert(false) + case Failure(e: Exception) => { + if (e.isInstanceOf[TableNotFoundException]) + assert(true) + else + assert(false) + } + case _ => None + } + } + +} diff --git a/spark4/pom.xml b/spark4/pom.xml new file mode 100644 index 00000000..f6bfe2b5 --- /dev/null +++ b/spark4/pom.xml @@ -0,0 +1,103 @@ + + + + 4.0.0 + + + org.apache.hbase.connectors + hbase-connectors + ${revision} + ../pom.xml + + + spark4 + ${revision} + pom + Apache HBase - Spark4 + Spark4 Connectors for Apache HBase + + + hbase-spark4-protocol + hbase-spark4-protocol-shaded + hbase-spark4 + hbase-spark4-it + + + + 0.6.1 + 2.12.5 + 4.0.0-preview1 + + 2.13.14 + 2.13 + 3.2.3 + + + + + + org.apache.hbase.connectors.spark + hbase-spark4 + ${revision} + + + org.apache.hbase.connectors.spark + hbase-spark4-protocol + ${revision} + + + org.apache.hbase.connectors.spark + hbase-spark4-protocol-shaded + ${revision} + + + org.apache.hbase.thirdparty + hbase-shaded-miscellaneous + ${hbase-thirdparty.version} + + + + + + + + + org.xolstice.maven.plugins + protobuf-maven-plugin + ${protobuf.plugin.version} + + com.google.protobuf:protoc:${external.protobuf.version}:exe:${os.detected.classifier} + ${basedir}/src/main/protobuf/ + false + true + + + + net.revelc.code + warbucks-maven-plugin + 1.1.0 + + + + + +