OH client-side spec to accept replication config #43

Open · wants to merge 1 commit into base: main
@@ -0,0 +1,158 @@
package com.linkedin.openhouse.spark.statementtest;

import com.google.gson.Gson;
import com.google.gson.JsonArray;
import com.google.gson.JsonElement;
import com.google.gson.JsonObject;
import com.linkedin.openhouse.spark.sql.catalyst.parser.extensions.OpenhouseParseException;
import java.nio.file.Files;
import lombok.SneakyThrows;
import org.apache.hadoop.fs.Path;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SparkSession;
import org.apache.spark.sql.execution.ExplainMode;
import org.junit.jupiter.api.AfterAll;
import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.BeforeAll;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.TestInstance;

@TestInstance(TestInstance.Lifecycle.PER_CLASS)
public class SetTableReplicationPolicyStatementTest {
  private static SparkSession spark = null;

  @SneakyThrows
  @BeforeAll
  public void setupSpark() {
    Path unittest = new Path(Files.createTempDirectory("unittest_settablepolicy").toString());
    spark =
        SparkSession.builder()
            .master("local[2]")
            .config(
                "spark.sql.extensions",
                ("org.apache.iceberg.spark.extensions.IcebergSparkSessionExtensions,"
                    + "com.linkedin.openhouse.spark.extensions.OpenhouseSparkSessionExtensions"))
            .config("spark.sql.catalog.openhouse", "org.apache.iceberg.spark.SparkCatalog")
            .config("spark.sql.catalog.openhouse.type", "hadoop")
            .config("spark.sql.catalog.openhouse.warehouse", unittest.toString())
            .getOrCreate();
  }

  @Test
  public void testSimpleSetReplicationPolicy() {
    String replicationConfigJson =
        "{\"cluster\":\"a\", \"schedule\":\"b\"}, {\"cluster\": \"aa\", \"schedule\": \"bb\"}";
    Dataset<Row> ds =
        spark.sql(
            "ALTER TABLE openhouse.db.table SET POLICY (REPLICATION = "
                + "({cluster:'a', schedule:'b'}, {cluster: 'aa', schedule: 'bb'}))");
    assert isPlanValid(ds, replicationConfigJson);

    replicationConfigJson = "{\"cluster\":\"a\", \"schedule\":\"b\"}";
    ds =
        spark.sql(
            "ALTER TABLE openhouse.db.table SET POLICY (REPLICATION = ({cluster:'a', schedule:'b'}))");
    assert isPlanValid(ds, replicationConfigJson);
  }

  @Test
  public void testReplicationPolicyWithoutProperSyntax() {
    // Missing schedule keyword
    Assertions.assertThrows(
        OpenhouseParseException.class,
        () ->
            spark
                .sql("ALTER TABLE openhouse.db.table SET POLICY (REPLICATION = ({cluster: 'aa'}))")
                .show());

    // Missing cluster keyword
    Assertions.assertThrows(
        OpenhouseParseException.class,
        () ->
            spark
                .sql("ALTER TABLE openhouse.db.table SET POLICY (REPLICATION = ({schedule: 'ss'}))")
                .show());

    // Typo in keyword schedule
    Assertions.assertThrows(
        OpenhouseParseException.class,
        () ->
            spark
                .sql(
                    "ALTER TABLE openhouse.db.table SET POLICY (REPLICATION = ({cluster: 'aa', schedul: 'ss'}))")
                .show());

    // Typo in keyword cluster
    Assertions.assertThrows(
        OpenhouseParseException.class,
        () ->
            spark
                .sql(
                    "ALTER TABLE openhouse.db.table SET POLICY (REPLICATION = ({clustr: 'aa', schedule: 'ss'}))")
                .show());

    // Missing quote in cluster value
    Assertions.assertThrows(
        OpenhouseParseException.class,
        () ->
            spark
                .sql(
                    "ALTER TABLE openhouse.db.table SET POLICY (REPLICATION = ({cluster: aa', schedule: 'ss}))")
                .show());

    // Typo in REPLICATION keyword
    Assertions.assertThrows(
        OpenhouseParseException.class,
        () ->
            spark
                .sql(
                    "ALTER TABLE openhouse.db.table SET POLICY (REPLICAT = ({cluster: 'aa', schedule: 'ss}))")
                .show());

    // Missing cluster and schedule value
    Assertions.assertThrows(
        OpenhouseParseException.class,
        () -> spark.sql("ALTER TABLE openhouse.db.table SET POLICY (REPLICATION = ({}))").show());
  }

  @BeforeEach
  public void setup() {
    spark.sql("CREATE TABLE openhouse.db.table (id bigint, data string) USING iceberg").show();
    spark.sql("CREATE TABLE openhouse.0_.0_ (id bigint, data string) USING iceberg").show();
    spark
        .sql("ALTER TABLE openhouse.db.table SET TBLPROPERTIES ('openhouse.tableId' = 'tableid')")
        .show();
    spark
        .sql("ALTER TABLE openhouse.0_.0_ SET TBLPROPERTIES ('openhouse.tableId' = 'tableid')")
        .show();
  }

  @AfterEach
  public void tearDown() {
    spark.sql("DROP TABLE openhouse.db.table").show();
    spark.sql("DROP TABLE openhouse.0_.0_").show();
  }

  @AfterAll
  public void tearDownSpark() {
    spark.close();
  }

  @SneakyThrows
  private boolean isPlanValid(Dataset<Row> dataframe, String replicationConfigJson) {
    replicationConfigJson = "[" + replicationConfigJson + "]";
    String queryStr = dataframe.queryExecution().explainString(ExplainMode.fromString("simple"));
    JsonArray jsonArray = new Gson().fromJson(replicationConfigJson, JsonArray.class);
    // Require every cluster/schedule pair to appear in the explained plan,
    // not just the last one (a lone trailing match must not mask earlier misses).
    boolean isValid = jsonArray.size() > 0;
    for (JsonElement element : jsonArray) {
      JsonObject entry = element.getAsJsonObject();
      String cluster = entry.get("cluster").getAsString();
      String schedule = entry.get("schedule").getAsString();
      isValid = isValid && queryStr.contains(cluster) && queryStr.contains(schedule);
    }
    return isValid;
  }
}
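
For orientation, this is the statement shape the tests above exercise, issued end to end. A minimal Scala sketch, assuming a session configured as in setupSpark(); the table name, cluster name, and cron schedule are illustrative:

    // Sketch only: table, cluster, and schedule values are made up.
    spark.sql(
      "ALTER TABLE openhouse.db.table SET POLICY " +
        "(REPLICATION = ({cluster: 'clusterA', schedule: '0 0 * * *'}))")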
@@ -24,6 +24,7 @@ singleStatement

statement
: ALTER TABLE multipartIdentifier SET POLICY '(' retentionPolicy (columnRetentionPolicy)? ')' #setRetentionPolicy
| ALTER TABLE multipartIdentifier SET POLICY '(' replicationPolicy ')' #setReplicationPolicy
| ALTER TABLE multipartIdentifier SET POLICY '(' sharingPolicy ')' #setSharingPolicy
| ALTER TABLE multipartIdentifier MODIFY columnNameClause SET columnPolicy #setColumnPolicyTag
| GRANT privilege ON grantableResource TO principal #grantStatement
@@ -64,7 +65,7 @@ quotedIdentifier
;

nonReserved
- : ALTER | TABLE | SET | POLICY | RETENTION | SHARING
+ : ALTER | TABLE | SET | POLICY | RETENTION | SHARING | REPLICATION
| GRANT | REVOKE | ON | TO | SHOW | GRANTS | PATTERN | WHERE | COLUMN
;

@@ -83,6 +84,18 @@ columnRetentionPolicy
: ON columnNameClause (columnRetentionPolicyPatternClause)?
;

replicationPolicy
: REPLICATION '=' tableReplicationPolicy
;

tableReplicationPolicy
: '(' replicationPolicyClause (',' replicationPolicyClause)* ')'
;

replicationPolicyClause
: '{' CLUSTER ':' STRING ',' SCHEDULE ':' STRING '}'
;

columnRetentionPolicyPatternClause
: WHERE retentionColumnPatternClause
;
@@ -136,6 +149,7 @@ TABLE: 'TABLE';
SET: 'SET';
POLICY: 'POLICY';
RETENTION: 'RETENTION';
REPLICATION: 'REPLICATION';
SHARING: 'SHARING';
GRANT: 'GRANT';
REVOKE: 'REVOKE';
@@ -150,6 +164,8 @@ DATABASE: 'DATABASE';
SHOW: 'SHOW';
GRANTS: 'GRANTS';
PATTERN: 'PATTERN';
CLUSTER: 'CLUSTER';
SCHEDULE: 'SCHEDULE';
WHERE: 'WHERE';
COLUMN: 'COLUMN';
PII: 'PII';
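
Read together, the new rules require every replication clause to carry both a CLUSTER and a SCHEDULE string, which is exactly what the negative parser tests above exercise. A sketch of the accept/reject behavior (Scala, against a session with the extensions enabled; values are illustrative):

    // Accepted: each clause carries both keys.
    spark.sql("ALTER TABLE openhouse.db.table SET POLICY " +
      "(REPLICATION = ({cluster: 'a', schedule: 'b'}, {cluster: 'aa', schedule: 'bb'}))")
    // Rejected with OpenhouseParseException: the schedule clause is mandatory.
    // spark.sql("ALTER TABLE openhouse.db.table SET POLICY (REPLICATION = ({cluster: 'a'}))")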
@@ -2,13 +2,14 @@ package com.linkedin.openhouse.spark.sql.catalyst.parser.extensions

import com.linkedin.openhouse.spark.sql.catalyst.enums.GrantableResourceTypes
import com.linkedin.openhouse.spark.sql.catalyst.parser.extensions.OpenhouseSqlExtensionsParser._
- import com.linkedin.openhouse.spark.sql.catalyst.plans.logical.{GrantRevokeStatement, SetRetentionPolicy, SetSharingPolicy, SetColumnPolicyTag, ShowGrantsStatement}
+ import com.linkedin.openhouse.spark.sql.catalyst.plans.logical.{GrantRevokeStatement, SetColumnPolicyTag, SetReplicationPolicy, SetRetentionPolicy, SetSharingPolicy, ShowGrantsStatement}
import com.linkedin.openhouse.spark.sql.catalyst.enums.GrantableResourceTypes.GrantableResourceType
import com.linkedin.openhouse.gen.tables.client.model.TimePartitionSpec
import org.antlr.v4.runtime.tree.ParseTree
import org.apache.spark.sql.catalyst.parser.ParserInterface
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan

import scala.collection.JavaConversions.iterableAsScalaIterable
import scala.collection.JavaConverters._

class OpenhouseSqlExtensionsAstBuilder (delegate: ParserInterface) extends OpenhouseSqlExtensionsBaseVisitor[AnyRef] {
@@ -26,6 +27,12 @@ class OpenhouseSqlExtensionsAstBuilder (delegate: ParserInterface) extends Openh
    SetRetentionPolicy(tableName, granularity, count, Option(colName), Option(colPattern))
  }

  override def visitSetReplicationPolicy(ctx: SetReplicationPolicyContext): SetReplicationPolicy = {
    val tableName = typedVisit[Seq[String]](ctx.multipartIdentifier)
    val replicationPolicies = typedVisit[String](ctx.replicationPolicy())
    SetReplicationPolicy(tableName, replicationPolicies)
  }

  override def visitSetSharingPolicy(ctx: SetSharingPolicyContext): SetSharingPolicy = {
    val tableName = typedVisit[Seq[String]](ctx.multipartIdentifier)
    val sharing = typedVisit[String](ctx.sharingPolicy())
@@ -86,6 +93,20 @@ class OpenhouseSqlExtensionsAstBuilder (delegate: ParserInterface) extends Openh
    typedVisit[(String, Int)](ctx.duration())
  }

  override def visitReplicationPolicy(ctx: ReplicationPolicyContext): String = {
    typedVisit[String](ctx.tableReplicationPolicy())
  }

  override def visitTableReplicationPolicy(ctx: TableReplicationPolicyContext): String = {
    // Flatten each clause to its source text and join the clauses with commas.
    val policy = ctx.replicationPolicyClause().map(ele => typedVisit[String](ele))
    policy.mkString(",")
  }

  override def visitReplicationPolicyClause(ctx: ReplicationPolicyClauseContext): AnyRef = {
    ctx.getText
  }

  override def visitColumnRetentionPolicy(ctx: ColumnRetentionPolicyContext): (String, String) = {
    if (ctx.columnRetentionPolicyPatternClause() != null) {
      (ctx.columnNameClause().identifier().getText(), ctx.columnRetentionPolicyPatternClause().retentionColumnPatternClause().STRING().getText)
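
Because visitReplicationPolicyClause returns ctx.getText, each clause survives as its verbatim source text (minus whitespace, which the lexer discards), and visitTableReplicationPolicy simply joins the clauses. A sketch of what the visitor chain should produce for the two-clause example from the tests:

    // Input:  REPLICATION = ({cluster: 'a', schedule: 'b'}, {cluster: 'aa', schedule: 'bb'})
    // getText flattens each clause, so joining with "," yields:
    val policies = Seq("{cluster:'a',schedule:'b'}", "{cluster:'aa',schedule:'bb'}").mkString(",")
    // policies == "{cluster:'a',schedule:'b'},{cluster:'aa',schedule:'bb'}"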
@@ -0,0 +1,9 @@
package com.linkedin.openhouse.spark.sql.catalyst.plans.logical

import org.apache.spark.sql.catalyst.plans.logical.Command

case class SetReplicationPolicy(tableName: Seq[String], replicationPolicies: String) extends Command {
  override def simpleString(maxFields: Int): String = {
    s"SetReplicationPolicy: ${tableName} ${replicationPolicies}"
  }
}
@@ -1,6 +1,6 @@
package com.linkedin.openhouse.spark.sql.execution.datasources.v2

- import com.linkedin.openhouse.spark.sql.catalyst.plans.logical.{GrantRevokeStatement, SetRetentionPolicy, SetSharingPolicy, SetColumnPolicyTag, ShowGrantsStatement}
+ import com.linkedin.openhouse.spark.sql.catalyst.plans.logical.{GrantRevokeStatement, SetColumnPolicyTag, SetReplicationPolicy, SetRetentionPolicy, SetSharingPolicy, ShowGrantsStatement}
import org.apache.iceberg.spark.{Spark3Util, SparkCatalog, SparkSessionCatalog}
import org.apache.spark.sql.{SparkSession, Strategy}
import org.apache.spark.sql.catalyst.expressions.PredicateHelper
@@ -15,6 +15,8 @@ case class OpenhouseDataSourceV2Strategy(spark: SparkSession) extends Strategy w
  override def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
    case SetRetentionPolicy(CatalogAndIdentifierExtractor(catalog, ident), granularity, count, colName, colPattern) =>
      SetRetentionPolicyExec(catalog, ident, granularity, count, colName, colPattern) :: Nil
    case SetReplicationPolicy(CatalogAndIdentifierExtractor(catalog, ident), replicationPolicies) =>
      SetReplicationPolicyExec(catalog, ident, replicationPolicies) :: Nil
    case SetSharingPolicy(CatalogAndIdentifierExtractor(catalog, ident), sharing) =>
      SetSharingPolicyExec(catalog, ident, sharing) :: Nil
    case SetColumnPolicyTag(CatalogAndIdentifierExtractor(catalog, ident), policyTag, cols) =>
@@ -0,0 +1,26 @@
package com.linkedin.openhouse.spark.sql.execution.datasources.v2

import org.apache.iceberg.spark.source.SparkTable
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.Attribute
import org.apache.spark.sql.connector.catalog.{Identifier, TableCatalog}
import org.apache.spark.sql.execution.datasources.v2.V2CommandExec

case class SetReplicationPolicyExec(catalog: TableCatalog, ident: Identifier, replicationPolicies: String) extends V2CommandExec {
  override protected def run(): Seq[InternalRow] = {
    catalog.loadTable(ident) match {
      case iceberg: SparkTable if iceberg.table().properties().containsKey("openhouse.tableId") =>
        val key = "updated.openhouse.policy"
        val value = s"""{"replication": [${replicationPolicies}]}"""
        iceberg.table().updateProperties()
          .set(key, value)
          .commit()

      case table =>
        throw new UnsupportedOperationException(s"Cannot set replication policy for non-Openhouse table: $table")
    }
    Nil
  }

  override def output: Seq[Attribute] = Nil
}
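
The exec node wraps the joined clause string in a JSON-style array under a single table property. A sketch of the committed key/value for the two-clause example; note that the entries keep the unquoted keys and single quotes produced by getText, so the value is JSON-like rather than strict JSON:

    // Property written by SetReplicationPolicyExec for the two-clause example (sketch).
    val key = "updated.openhouse.policy"
    val value = """{"replication": [{cluster:'a',schedule:'b'},{cluster:'aa',schedule:'bb'}]}"""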