Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -548,4 +548,27 @@ trait FlintSparkSuite extends QueryTest with FlintSuite with OpenSearchSuite wit
| )
|)""".stripMargin)
}

/** Creates a small HTTP access-log table and seeds it with six fixed rows.
 *
 * Columns: id, status_code (nullable INT), request_path, timestamp (kept as
 * STRING). Row id=2 deliberately carries a null status_code so that case()/
 * ispresent() tests can exercise their null and else branches.
 *
 * @param testTable fully qualified name of the table to create
 */
protected def createTableHttpLog(testTable: String): Unit = {
  // DDL: storage format comes from the suite-level tableType/tableOptions.
  val createTableDdl = s"""
    | CREATE TABLE $testTable
    |(
    | id INT,
    | status_code INT,
    | request_path STRING,
    | timestamp STRING
    |)
    | USING $tableType $tableOptions
    |""".stripMargin
  sql(createTableDdl)

  // Seed data spanning success (200), redirect (301), client error (403),
  // server error (500) and one null status code.
  val seedRowsDml = s"""
    | INSERT INTO $testTable
    | VALUES (1, 200, '/home', '2023-10-01 10:00:00'),
    | (2, null, '/about', '2023-10-01 10:05:00'),
    | (3, 500, '/contact', '2023-10-01 10:10:00'),
    | (4, 301, '/home', '2023-10-01 10:15:00'),
    | (5, 200, '/services', '2023-10-01 10:20:00'),
    | (6, 403, '/home', '2023-10-01 10:25:00')
    | """.stripMargin
  sql(seedRowsDml)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ import org.opensearch.sql.ppl.utils.DataTypeTransformer.seq

import org.apache.spark.sql.{AnalysisException, QueryTest, Row}
import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedFunction, UnresolvedRelation, UnresolvedStar}
import org.apache.spark.sql.catalyst.expressions.{Alias, Ascending, Descending, LessThan, Literal, SortOrder}
import org.apache.spark.sql.catalyst.expressions.{Alias, And, Ascending, CaseWhen, Descending, EqualTo, GreaterThanOrEqual, LessThan, Literal, SortOrder}
import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Filter, LogicalPlan, Project, Sort}
import org.apache.spark.sql.streaming.StreamTest

Expand All @@ -21,12 +21,14 @@ class FlintSparkPPLEvalITSuite

/** Test table and index name */
private val testTable = "spark_catalog.default.flint_ppl_test"
private val testTableHttpLog = "spark_catalog.default.flint_ppl_test_http_log"

override def beforeAll(): Unit = {
super.beforeAll()

// Create test table
createPartitionedStateCountryTable(testTable)
createTableHttpLog(testTableHttpLog)
}

protected override def afterEach(): Unit = {
Expand Down Expand Up @@ -504,7 +506,134 @@ class FlintSparkPPLEvalITSuite
comparePlans(logicalPlan, expectedPlan, checkAnalysis = false)
}

test("eval case function") {
  // case() maps HTTP status codes to a human-readable category; the else
  // branch covers the row whose status_code is null.
  val queryResult = sql(s"""
       | source = $testTableHttpLog |
       | eval status_category =
       | case(status_code >= 200 AND status_code < 300, 'Success',
       | status_code >= 300 AND status_code < 400, 'Redirection',
       | status_code >= 400 AND status_code < 500, 'Client Error',
       | status_code >= 500, 'Server Error'
       | else concat('Incorrect HTTP status code for request ', request_path)
       | )
       | """.stripMargin)

  val actualRows: Array[Row] = queryResult.collect()
  val expectedRows: Array[Row] =
    Array(
      Row(1, 200, "/home", "2023-10-01 10:00:00", "Success"),
      Row(
        2,
        null,
        "/about",
        "2023-10-01 10:05:00",
        "Incorrect HTTP status code for request /about"),
      Row(3, 500, "/contact", "2023-10-01 10:10:00", "Server Error"),
      Row(4, 301, "/home", "2023-10-01 10:15:00", "Redirection"),
      Row(5, 200, "/services", "2023-10-01 10:20:00", "Success"),
      Row(6, 403, "/home", "2023-10-01 10:25:00", "Client Error"))
  // Sort both sides by id so the comparison is order-insensitive.
  implicit val byId: Ordering[Row] = Ordering.by[Row, Int](row => row.getInt(0))
  assert(actualRows.sorted.sameElements(expectedRows.sorted))
  assert(
    queryResult.columns.sameElements(
      Array("id", "status_code", "request_path", "timestamp", "status_category")))

  val logicalPlan: LogicalPlan = queryResult.queryExecution.logical

  // Expected plan: each case() branch becomes (EqualTo(true, cond), value)
  // inside a CaseWhen aliased as status_category, added next to the star.
  val relation = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test_http_log"))
  val branches = Seq(
    graterOrEqualAndLessThan("status_code", 200, 300) -> Literal("Success"),
    graterOrEqualAndLessThan("status_code", 300, 400) -> Literal("Redirection"),
    graterOrEqualAndLessThan("status_code", 400, 500) -> Literal("Client Error"),
    EqualTo(
      Literal(true),
      GreaterThanOrEqual(UnresolvedAttribute("status_code"), Literal(500))) ->
      Literal("Server Error"))
  val fallback = UnresolvedFunction(
    "concat",
    Seq(
      Literal("Incorrect HTTP status code for request "),
      UnresolvedAttribute("request_path")),
    isDistinct = false)
  val evalPlan = Project(
    Seq(UnresolvedStar(None), Alias(CaseWhen(branches, fallback), "status_category")()),
    relation)
  val expectedPlan = Project(Seq(UnresolvedStar(None)), evalPlan)
  comparePlans(logicalPlan, expectedPlan, checkAnalysis = false)
}

test("eval case function in complex pipeline") {
  // Pipeline under test: drop null status codes, categorize the rest with
  // case(), then count rows per category. Because nulls are filtered up
  // front, the else branch ('Unknown') is never taken.
  val frame = sql(s"""
       | source = $testTableHttpLog
       | | where ispresent(status_code)
       | | eval status_category =
       | case(status_code >= 200 AND status_code < 300, 'Success',
       | status_code >= 300 AND status_code < 400, 'Redirection',
       | status_code >= 400 AND status_code < 500, 'Client Error',
       | status_code >= 500, 'Server Error'
       | else 'Unknown'
       | )
       | | stats count() by status_category
       | """.stripMargin)

  val results: Array[Row] = frame.collect()
  // One row per category; only the two 200 rows collapse into a count of 2.
  val expectedResults: Array[Row] =
    Array(
      Row(1L, "Redirection"),
      Row(1L, "Client Error"),
      Row(1L, "Server Error"),
      Row(2L, "Success"))
  // Sort both sides by category name so the comparison is order-insensitive.
  implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, String](_.getString(1))
  assert(results.sorted.sameElements(expectedResults.sorted))
  val expectedColumns = Array[String]("count()", "status_category")
  assert(frame.columns.sameElements(expectedColumns))

  val logicalPlan: LogicalPlan = frame.queryExecution.logical

  val table = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test_http_log"))
  // PPL `ispresent(x)` is planned as the isnotnull(x) function.
  val filter = Filter(
    UnresolvedFunction(
      "isnotnull",
      Seq(UnresolvedAttribute("status_code")),
      isDistinct = false),
    table)
  // Each case() branch becomes (EqualTo(true, condition), value) in CaseWhen.
  val conditionValueSequence = Seq(
    (graterOrEqualAndLessThan("status_code", 200, 300), Literal("Success")),
    (graterOrEqualAndLessThan("status_code", 300, 400), Literal("Redirection")),
    (graterOrEqualAndLessThan("status_code", 400, 500), Literal("Client Error")),
    (
      EqualTo(
        Literal(true),
        GreaterThanOrEqual(UnresolvedAttribute("status_code"), Literal(500))),
      Literal("Server Error")))
  val elseValue = Literal("Unknown")
  val caseFunction = CaseWhen(conditionValueSequence, elseValue)
  val aliasStatusCategory = Alias(caseFunction, "status_category")()
  // eval appends the derived column next to the star projection.
  val evalProjectList = Seq(UnresolvedStar(None), aliasStatusCategory)
  val evalProject = Project(evalProjectList, filter)
  // stats count() by status_category: group on the alias, project count + key.
  val aggregation = Aggregate(
    Seq(Alias(UnresolvedAttribute("status_category"), "status_category")()),
    Seq(
      Alias(
        UnresolvedFunction("COUNT", Seq(UnresolvedStar(None)), isDistinct = false),
        "count()")(),
      Alias(UnresolvedAttribute("status_category"), "status_category")()),
    evalProject)
  val expectedPlan = Project(Seq(UnresolvedStar(None)), aggregation)
  comparePlans(logicalPlan, expectedPlan, checkAnalysis = false)
}

/** Builds the Catalyst predicate the PPL planner emits for
 * `fieldName >= min AND fieldName < max` inside a case() condition; the
 * planner wraps each branch condition in `EqualTo(true, ...)`.
 *
 * Fix: the lower bound previously hard-coded `UnresolvedAttribute("status_code")`
 * and ignored `fieldName`; both bounds now honor the parameter. All existing
 * callers pass "status_code", so their behavior is unchanged.
 *
 * @param fieldName attribute the range predicate applies to
 * @param min inclusive lower bound
 * @param max exclusive upper bound
 */
private def graterOrEqualAndLessThan(fieldName: String, min: Int, max: Int) = {
  val and = And(
    GreaterThanOrEqual(UnresolvedAttribute(fieldName), Literal(min)),
    LessThan(UnresolvedAttribute(fieldName), Literal(max)))
  EqualTo(Literal(true), and)
}

// Todo excluded fields not support yet

ignore("test single eval expression with excluded fields") {
val frame = sql(s"""
| source = $testTable | eval new_field = "New Field" | fields - age
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ package org.opensearch.flint.spark.ppl

import org.apache.spark.sql.{QueryTest, Row}
import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedFunction, UnresolvedRelation, UnresolvedStar}
import org.apache.spark.sql.catalyst.expressions.{Alias, And, Ascending, Descending, Divide, EqualTo, Floor, GreaterThan, LessThanOrEqual, Literal, Multiply, Not, Or, SortOrder}
import org.apache.spark.sql.catalyst.expressions.{Alias, And, Ascending, CaseWhen, Descending, Divide, EqualTo, Floor, GreaterThan, LessThan, LessThanOrEqual, Literal, Multiply, Not, Or, SortOrder}
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.streaming.StreamTest

Expand All @@ -19,11 +19,13 @@ class FlintSparkPPLFiltersITSuite

/** Test table and index name */
private val testTable = "spark_catalog.default.flint_ppl_test"
private val duplicationTable = "spark_catalog.default.flint_ppl_test_duplication_table"

override def beforeAll(): Unit = {
super.beforeAll()
// Create test table
createPartitionedStateCountryTable(testTable)
createDuplicationNullableTable(duplicationTable)
}

protected override def afterEach(): Unit = {
Expand Down Expand Up @@ -348,4 +350,107 @@ class FlintSparkPPLFiltersITSuite
assert(compareByString(expectedPlan) === compareByString(logicalPlan))
}

test("case function used as filter") {
  // A bare case() expression compared against a literal acts as the search
  // filter: only rows whose country is 'USA' survive.
  val frame = sql(s"""
       | source = $testTable case(country = 'USA', 'The United States of America' else 'Other country') = 'The United States of America'
       | """.stripMargin)

  val expectedRows: Array[Row] = Array(
    Row("Jake", 70, "California", "USA", 2023, 4),
    Row("Hello", 30, "New York", "USA", 2023, 4))
  // Sort both sides by name so the comparison is order-insensitive.
  implicit val byName: Ordering[Row] = Ordering.by[Row, String](row => row.getAs[String](0))
  assert(frame.collect().sorted.sameElements(expectedRows.sorted))

  // Expected plan: Filter(case(...) = '...') directly over the relation,
  // with the single branch wrapped in EqualTo(true, ...).
  val relation = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test"))
  val usaBranch =
    EqualTo(Literal(true), EqualTo(UnresolvedAttribute("country"), Literal("USA"))) ->
      Literal("The United States of America")
  val caseExpr = CaseWhen(Seq(usaBranch), Literal("Other country"))
  val filtered =
    Filter(EqualTo(caseExpr, Literal("The United States of America")), relation)
  val expectedPlan = Project(Seq(UnresolvedStar(None)), filtered)
  // Compare the two plans
  assert(compareByString(expectedPlan) === compareByString(frame.queryExecution.logical))
}

test("case function used as filter complex filter") {
  // case() appears twice: once in eval to derive `factor`, once in where to
  // keep only rows whose factor maps to 'even'; results are then counted
  // per factor value.
  val frame = sql(s"""
       | source = $duplicationTable
       | | eval factor = case(id > 15, id - 14, isnull(name), id - 7, id < 3, id + 1 else 1)
       | | where case(factor = 2, 'even', factor = 4, 'even', factor = 6, 'even', factor = 8, 'even' else 'odd') = 'even'
       | | stats count() by factor
       | """.stripMargin)

  // Retrieve the results
  val results: Array[Row] = frame.collect() // count(), factor
  // Define the expected results
  val expectedResults: Array[Row] = Array(Row(1, 4), Row(1, 6), Row(2, 2))

  // Compare the results, sorted by factor so order does not matter.
  implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, Int](_.getAs[Int](1))
  assert(results.sorted.sameElements(expectedResults.sorted))

  // Retrieve the logical plan
  val logicalPlan: LogicalPlan = frame.queryExecution.logical
  // Define the expected logical plan
  val table =
    UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test_duplication_table"))

  // case function used in eval command: each branch condition is wrapped in
  // EqualTo(true, ...) and arithmetic stays unresolved ("+"/"-" functions).
  val conditionValueEval = Seq(
    (
      EqualTo(Literal(true), GreaterThan(UnresolvedAttribute("id"), Literal(15))),
      UnresolvedFunction("-", Seq(UnresolvedAttribute("id"), Literal(14)), isDistinct = false)),
    (
      EqualTo(
        Literal(true),
        UnresolvedFunction("isnull", Seq(UnresolvedAttribute("name")), isDistinct = false)),
      UnresolvedFunction("-", Seq(UnresolvedAttribute("id"), Literal(7)), isDistinct = false)),
    (
      EqualTo(Literal(true), LessThan(UnresolvedAttribute("id"), Literal(3))),
      UnresolvedFunction("+", Seq(UnresolvedAttribute("id"), Literal(1)), isDistinct = false)))
  val aliasCaseFactor = Alias(CaseWhen(conditionValueEval, Literal(1)), "factor")()
  val evalProject = Project(Seq(UnresolvedStar(None), aliasCaseFactor), table)

  // case in where clause: four 'even' branches over the derived factor,
  // falling back to 'odd'.
  val conditionValueWhere = Seq(
    (
      EqualTo(Literal(true), EqualTo(UnresolvedAttribute("factor"), Literal(2))),
      Literal("even")),
    (
      EqualTo(Literal(true), EqualTo(UnresolvedAttribute("factor"), Literal(4))),
      Literal("even")),
    (
      EqualTo(Literal(true), EqualTo(UnresolvedAttribute("factor"), Literal(6))),
      Literal("even")),
    (
      EqualTo(Literal(true), EqualTo(UnresolvedAttribute("factor"), Literal(8))),
      Literal("even")))
  val caseFunctionWhere = CaseWhen(conditionValueWhere, Literal("odd"))
  val filterPlan = Filter(EqualTo(caseFunctionWhere, Literal("even")), evalProject)

  // stats count() by factor: group on factor, project count and the key.
  val aggregation = Aggregate(
    Seq(Alias(UnresolvedAttribute("factor"), "factor")()),
    Seq(
      Alias(
        UnresolvedFunction("COUNT", Seq(UnresolvedStar(None)), isDistinct = false),
        "count()")(),
      Alias(UnresolvedAttribute("factor"), "factor")()),
    filterPlan)
  val expectedPlan = Project(Seq(UnresolvedStar(None)), aggregation)
  // Compare the two plans
  comparePlans(logicalPlan, expectedPlan, checkAnalysis = false)
}
}
48 changes: 48 additions & 0 deletions ppl-spark-integration/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -245,6 +245,29 @@ See the next samples of PPL queries :
- `source = table | where ispresent(b)`
- `source = table | where isnull(coalesce(a, b)) | fields a,b,c | head 3`
- `source = table | where isempty(a)`
- `source = table | where case(length(a) > 6, 'True' else 'False') = 'True'`
-
```
source = table | eval status_category =
case(a >= 200 AND a < 300, 'Success',
a >= 300 AND a < 400, 'Redirection',
a >= 400 AND a < 500, 'Client Error',
a >= 500, 'Server Error'
else 'Incorrect HTTP status code')
| where case(a >= 200 AND a < 300, 'Success',
a >= 300 AND a < 400, 'Redirection',
a >= 400 AND a < 500, 'Client Error',
a >= 500, 'Server Error'
else 'Incorrect HTTP status code'
) = 'Incorrect HTTP status code'
```
-
```
source = table
| eval factor = case(a > 15, a - 14, isnull(b), a - 7, a < 3, a + 1 else 1)
| where case(factor = 2, 'even', factor = 4, 'even', factor = 6, 'even', factor = 8, 'even' else 'odd') = 'even'
| stats count() by factor
```

**Filters With Logical Conditions**
- `source = table | where c = 'test' AND a = 1 | fields a,b,c`
Expand All @@ -265,6 +288,31 @@ Assumptions: `a`, `b`, `c` are existing fields in `table`
- `source = table | eval f = ispresent(a)`
- `source = table | eval r = coalesce(a, b, c) | fields r`
- `source = table | eval e = isempty(a) | fields e`
- `source = table | eval f = case(a = 0, 'zero', a = 1, 'one', a = 2, 'two', a = 3, 'three', a = 4, 'four', a = 5, 'five', a = 6, 'six', a = 7, 'se7en', a = 8, 'eight', a = 9, 'nine')`
- `source = table | eval f = case(a = 0, 'zero', a = 1, 'one' else 'unknown')`
- `source = table | eval f = case(a = 0, 'zero', a = 1, 'one' else concat(a, ' is an incorrect binary digit'))`
-
```
source = table | eval status_category =
case(a >= 200 AND a < 300, 'Success',
a >= 300 AND a < 400, 'Redirection',
a >= 400 AND a < 500, 'Client Error',
a >= 500, 'Server Error'
else 'Unknown'
)
```
-
```
source = table | where ispresent(a) |
eval status_category =
case(a >= 200 AND a < 300, 'Success',
a >= 300 AND a < 400, 'Redirection',
a >= 400 AND a < 500, 'Client Error',
a >= 500, 'Server Error'
else 'Incorrect HTTP status code'
)
| stats count() by status_category
```

Limitation: Overriding existing field is unsupported, following queries throw exceptions with "Reference 'a' is ambiguous"
- `source = table | eval a = 10 | fields a,b,c`
Expand Down
Loading