Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -19,27 +19,24 @@
package org.apache.iceberg.spark.extensions;

import java.util.List;
import java.util.Map;
import org.apache.iceberg.ParameterizedTestExtension;
import org.apache.iceberg.Table;
import org.apache.iceberg.relocated.com.google.common.collect.ImmutableList;
import org.apache.spark.sql.AnalysisException;
import org.assertj.core.api.Assertions;
import org.junit.After;
import org.junit.Test;
import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.TestTemplate;
import org.junit.jupiter.api.extension.ExtendWith;

public class TestAncestorsOfProcedure extends SparkExtensionsTestBase {
@ExtendWith(ParameterizedTestExtension.class)
public class TestAncestorsOfProcedure extends ExtensionsTestBase {

public TestAncestorsOfProcedure(
String catalogName, String implementation, Map<String, String> config) {
super(catalogName, implementation, config);
}

@After
@AfterEach
public void removeTables() {
sql("DROP TABLE IF EXISTS %s", tableName);
}

@Test
@TestTemplate
public void testAncestorOfUsingEmptyArgs() {
sql("CREATE TABLE %s (id bigint NOT NULL, data string) USING iceberg", tableName);
sql("INSERT INTO TABLE %s VALUES (1, 'a')", tableName);
Expand All @@ -60,7 +57,7 @@ public void testAncestorOfUsingEmptyArgs() {
output);
}

@Test
@TestTemplate
public void testAncestorOfUsingSnapshotId() {
sql("CREATE TABLE %s (id bigint NOT NULL, data string) USING iceberg", tableName);
sql("INSERT INTO TABLE %s VALUES (1, 'a')", tableName);
Expand All @@ -84,7 +81,7 @@ public void testAncestorOfUsingSnapshotId() {
sql("CALL %s.system.ancestors_of('%s', %dL)", catalogName, tableIdent, preSnapshotId));
}

@Test
@TestTemplate
public void testAncestorOfWithRollBack() {
sql("CREATE TABLE %s (id bigint NOT NULL, data string) USING iceberg", tableName);
Table table = validationCatalog.loadTable(tableIdent);
Expand Down Expand Up @@ -128,7 +125,7 @@ public void testAncestorOfWithRollBack() {
sql("CALL %s.system.ancestors_of('%s', %dL)", catalogName, tableIdent, thirdSnapshotId));
}

@Test
@TestTemplate
public void testAncestorOfUsingNamedArgs() {
sql("CREATE TABLE %s (id bigint NOT NULL, data string) USING iceberg", tableName);
sql("INSERT INTO TABLE %s VALUES (1, 'a')", tableName);
Expand All @@ -145,7 +142,7 @@ public void testAncestorOfUsingNamedArgs() {
catalogName, firstSnapshotId, tableIdent));
}

@Test
@TestTemplate
public void testInvalidAncestorOfCases() {
Assertions.assertThatThrownBy(() -> sql("CALL %s.system.ancestors_of()", catalogName))
.isInstanceOf(AnalysisException.class)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,13 @@
*/
package org.apache.iceberg.spark.extensions;

import static org.assertj.core.api.Assertions.assertThat;
import static scala.collection.JavaConverters.seqAsJavaList;

import java.math.BigDecimal;
import java.sql.Timestamp;
import java.time.Instant;
import java.util.List;
import org.apache.iceberg.relocated.com.google.common.collect.ImmutableList;
import org.apache.iceberg.relocated.com.google.common.collect.Lists;
import org.apache.spark.sql.SparkSession;
import org.apache.spark.sql.catalyst.expressions.Expression;
Expand All @@ -38,22 +40,16 @@
import org.apache.spark.sql.types.DataType;
import org.apache.spark.sql.types.DataTypes;
import org.assertj.core.api.Assertions;
import org.junit.AfterClass;
import org.junit.Assert;
import org.junit.BeforeClass;
import org.junit.Rule;
import org.junit.Test;
import org.junit.rules.TemporaryFolder;
import scala.collection.JavaConverters;
import org.junit.jupiter.api.AfterAll;
import org.junit.jupiter.api.BeforeAll;
import org.junit.jupiter.api.Test;

public class TestCallStatementParser {

@Rule public TemporaryFolder temp = new TemporaryFolder();

private static SparkSession spark = null;
private static ParserInterface parser = null;

@BeforeClass
@BeforeAll
public static void startSpark() {
TestCallStatementParser.spark =
SparkSession.builder()
Expand All @@ -64,7 +60,7 @@ public static void startSpark() {
TestCallStatementParser.parser = spark.sessionState().sqlParser();
}

@AfterClass
@AfterAll
public static void stopSpark() {
SparkSession currentSpark = TestCallStatementParser.spark;
TestCallStatementParser.spark = null;
Expand All @@ -76,10 +72,9 @@ public static void stopSpark() {
public void testCallWithPositionalArgs() throws ParseException {
CallStatement call =
(CallStatement) parser.parsePlan("CALL c.n.func(1, '2', 3L, true, 1.0D, 9.0e1, 900e-1BD)");
Assert.assertEquals(
ImmutableList.of("c", "n", "func"), JavaConverters.seqAsJavaList(call.name()));
assertThat(seqAsJavaList(call.name())).containsExactly("c", "n", "func");

Assert.assertEquals(7, call.args().size());
assertThat(seqAsJavaList(call.args())).hasSize(7);

checkArg(call, 0, 1, DataTypes.IntegerType);
checkArg(call, 1, "2", DataTypes.StringType);
Expand All @@ -94,10 +89,9 @@ public void testCallWithPositionalArgs() throws ParseException {
public void testCallWithNamedArgs() throws ParseException {
CallStatement call =
(CallStatement) parser.parsePlan("CALL cat.system.func(c1 => 1, c2 => '2', c3 => true)");
Assert.assertEquals(
ImmutableList.of("cat", "system", "func"), JavaConverters.seqAsJavaList(call.name()));
assertThat(seqAsJavaList(call.name())).containsExactly("cat", "system", "func");

Assert.assertEquals(3, call.args().size());
assertThat(seqAsJavaList(call.args())).hasSize(3);

checkArg(call, 0, "c1", 1, DataTypes.IntegerType);
checkArg(call, 1, "c2", "2", DataTypes.StringType);
Expand All @@ -107,10 +101,9 @@ public void testCallWithNamedArgs() throws ParseException {
@Test
public void testCallWithMixedArgs() throws ParseException {
CallStatement call = (CallStatement) parser.parsePlan("CALL cat.system.func(c1 => 1, '2')");
Assert.assertEquals(
ImmutableList.of("cat", "system", "func"), JavaConverters.seqAsJavaList(call.name()));
assertThat(seqAsJavaList(call.name())).containsExactly("cat", "system", "func");

Assert.assertEquals(2, call.args().size());
assertThat(seqAsJavaList(call.args())).hasSize(2);

checkArg(call, 0, "c1", 1, DataTypes.IntegerType);
checkArg(call, 1, "2", DataTypes.StringType);
Expand All @@ -121,10 +114,9 @@ public void testCallWithTimestampArg() throws ParseException {
CallStatement call =
(CallStatement)
parser.parsePlan("CALL cat.system.func(TIMESTAMP '2017-02-03T10:37:30.00Z')");
Assert.assertEquals(
ImmutableList.of("cat", "system", "func"), JavaConverters.seqAsJavaList(call.name()));
assertThat(seqAsJavaList(call.name())).containsExactly("cat", "system", "func");

Assert.assertEquals(1, call.args().size());
assertThat(seqAsJavaList(call.args())).hasSize(1);

checkArg(
call, 0, Timestamp.from(Instant.parse("2017-02-03T10:37:30.00Z")), DataTypes.TimestampType);
Expand All @@ -134,10 +126,9 @@ public void testCallWithTimestampArg() throws ParseException {
public void testCallWithVarSubstitution() throws ParseException {
CallStatement call =
(CallStatement) parser.parsePlan("CALL cat.system.func('${spark.extra.prop}')");
Assert.assertEquals(
ImmutableList.of("cat", "system", "func"), JavaConverters.seqAsJavaList(call.name()));
assertThat(seqAsJavaList(call.name())).containsExactly("cat", "system", "func");

Assert.assertEquals(1, call.args().size());
assertThat(seqAsJavaList(call.args())).hasSize(1);

checkArg(call, 0, "value", DataTypes.StringType);
}
Expand Down Expand Up @@ -165,10 +156,9 @@ public void testCallStripsComments() throws ParseException {
"CALL -- a line ending comment\n" + "cat.system.func('${spark.extra.prop}')");
for (String sqlText : callStatementsWithComments) {
CallStatement call = (CallStatement) parser.parsePlan(sqlText);
Assert.assertEquals(
ImmutableList.of("cat", "system", "func"), JavaConverters.seqAsJavaList(call.name()));
assertThat(seqAsJavaList(call.name())).containsExactly("cat", "system", "func");

Assert.assertEquals(1, call.args().size());
assertThat(seqAsJavaList(call.args())).hasSize(1);

checkArg(call, 0, "value", DataTypes.StringType);
}
Expand All @@ -188,25 +178,24 @@ private void checkArg(

if (expectedName != null) {
NamedArgument arg = checkCast(call.args().apply(index), NamedArgument.class);
Assert.assertEquals(expectedName, arg.name());
assertThat(arg.name()).isEqualTo(expectedName);
} else {
CallArgument arg = call.args().apply(index);
checkCast(arg, PositionalArgument.class);
}

Expression expectedExpr = toSparkLiteral(expectedValue, expectedType);
Expression actualExpr = call.args().apply(index).expr();
Assert.assertEquals("Arg types must match", expectedExpr.dataType(), actualExpr.dataType());
Assert.assertEquals("Arg must match", expectedExpr, actualExpr);
assertThat(actualExpr.dataType()).as("Arg types must match").isEqualTo(expectedExpr.dataType());
assertThat(actualExpr).as("Arg must match").isEqualTo(expectedExpr);
}

private Literal toSparkLiteral(Object value, DataType dataType) {
return Literal$.MODULE$.create(value, dataType);
}

private <T> T checkCast(Object value, Class<T> expectedClass) {
Assert.assertTrue(
"Expected instance of " + expectedClass.getName(), expectedClass.isInstance(value));
assertThat(value).isInstanceOf(expectedClass);
return expectedClass.cast(value);
}
}
Loading