Skip to content

Commit

Permalink
Consider type in predicates:
Browse files Browse the repository at this point in the history
- Additional predicate to check string type
- Return NULL in wrong predicate type
- Fix bug when string predicates were not recognized as `WHERE` conditions
- Merge changes from #191

TCK +3

Signed-off-by: Dwitry [email protected]
  • Loading branch information
dwitry committed Jan 18, 2019
1 parent 890ad64 commit 2d43f50
Show file tree
Hide file tree
Showing 13 changed files with 100 additions and 19 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@
import org.junit.Before;
import org.junit.ClassRule;
import org.junit.Test;
import org.junit.experimental.categories.Category;
import org.opencypher.gremlin.groups.SkipExtensions;
import org.opencypher.gremlin.rules.GremlinServerExternalResource;

public class NullTest {
Expand Down Expand Up @@ -64,4 +66,34 @@ public void ignoreNullWhenRemovingProperty() {
.extracting("a")
.containsExactly((Object) null);
}

@Test
@Category(SkipExtensions.CustomPredicates.class)
public void predicateOnNull() {
submitAndGet("CREATE (a)");

List<Map<String, Object>> results = submitAndGet(
"MATCH (a)\n" +
"WHERE a.name CONTAINS 'b'\n" +
"RETURN count(a) as cnt"
);

assertThat(results)
.extracting("cnt")
.containsExactly(0L);
}

@Test
@Category(SkipExtensions.CustomPredicates.class)
public void nullOnIncompatibleTypes() {
submitAndGet(" CREATE ({val: 1})");

List<Map<String, Object>> results = submitAndGet(
"MATCH (n) RETURN 'a' STARTS WITH n.val as r"
);

assertThat(results)
.extracting("r")
.containsExactly((Object) null);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -25,21 +25,21 @@ public enum CustomPredicate implements BiPredicate<Object, Object> {
cypherStartsWith {
@Override
public boolean test(Object a, Object b) {
return a != null && b != null && a.toString().startsWith(b.toString());
return a.toString().startsWith(b.toString());
}
},

cypherEndsWith {
@Override
public boolean test(Object a, Object b) {
return a != null && b != null && a.toString().endsWith(b.toString());
return a.toString().endsWith(b.toString());
}
},

cypherContains {
@Override
public boolean test(Object a, Object b) {
return a != null && b != null && a.toString().contains(b.toString());
return a.toString().contains(b.toString());
}
},

Expand All @@ -55,6 +55,13 @@ public boolean test(Object a, Object b) {
public boolean test(Object a, Object b) {
return a instanceof Edge;
}
},

cypherIsString {
@Override
public boolean test(Object a, Object b) {
return a instanceof String;
}
};

public static P<Object> cypherStartsWith(final Object prefix) {
Expand All @@ -76,4 +83,8 @@ public static P<Object> cypherIsNode() {
public static P<Object> cypherIsRelationship() {
return new P<>(CustomPredicate.cypherIsRelationship, null);
}

public static P<Object> cypherIsString() {
return new P<>(CustomPredicate.cypherIsString, null);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -55,4 +55,6 @@ public interface GremlinPredicates<P> {
P isNode();

P isRelationship();

P isString();
}
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,11 @@ public P isRelationship() {
return CustomPredicate.cypherIsRelationship();
}

@Override
public P isString() {
return CustomPredicate.cypherIsString();
}

private static Object[] inlineParameters(Object... values) {
return Stream.of(values)
.map(BytecodeGremlinPredicates::inlineParameter)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -92,4 +92,9 @@ public GroovyPredicate isNode() {
public GroovyPredicate isRelationship() {
return new GroovyPredicate("cypherIsRelationship");
}

@Override
public GroovyPredicate isString() {
return new GroovyPredicate("cypherIsString");
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -90,4 +90,9 @@ public P isNode() {
public P isRelationship() {
return CustomPredicate.cypherIsRelationship();
}

@Override
public P isString() {
return CustomPredicate.cypherIsString();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -271,6 +271,7 @@ sealed class TranslationWriter[T, P] private (translator: Translator[T, P], para
case Contains(value) => p.contains(writeValue(value))
case IsNode() => p.isNode
case IsRelationship() => p.isRelationship
case IsString() => p.isString
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -46,4 +46,6 @@ class IRGremlinPredicates extends GremlinPredicates[GremlinPredicate] {
override def isNode: GremlinPredicate = IsNode()

override def isRelationship: GremlinPredicate = IsRelationship()

override def isString: GremlinPredicate = IsString()
}
Original file line number Diff line number Diff line change
Expand Up @@ -30,4 +30,5 @@ case class StartsWith(value: Any) extends GremlinPredicate
case class EndsWith(value: Any) extends GremlinPredicate
case class Contains(value: Any) extends GremlinPredicate
case class IsNode() extends GremlinPredicate
case class IsString() extends GremlinPredicate
case class IsRelationship() extends GremlinPredicate
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,7 @@ object RemoveUnusedAliases extends GremlinRewriter {
case Contains(value) => strings(value)
case IsNode() => Seq()
case IsRelationship() => Seq()
case IsString() => Seq()
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -100,9 +100,9 @@ private class ExpressionWalker[T, P](context: WalkerContext[T, P], g: GremlinSte
case LessThanOrEqual(lhs, rhs) => comparison(lhs, rhs, p.lte)
case GreaterThan(lhs, rhs) => comparison(lhs, rhs, p.gt)
case GreaterThanOrEqual(lhs, rhs) => comparison(lhs, rhs, p.gte)
case StartsWith(lhs, rhs) => comparison(lhs, rhs, p.startsWith)
case EndsWith(lhs, rhs) => comparison(lhs, rhs, p.endsWith)
case Contains(lhs, rhs) => comparison(lhs, rhs, p.contains)
case StartsWith(lhs, rhs) => comparison(lhs, rhs, p.isString, p.startsWith)
case EndsWith(lhs, rhs) => comparison(lhs, rhs, p.isString, p.endsWith)
case Contains(lhs, rhs) => comparison(lhs, rhs, p.isString, p.contains)

case In(lhs, rhs) =>
membership(lhs, rhs)
Expand Down Expand Up @@ -373,6 +373,32 @@ private class ExpressionWalker[T, P](context: WalkerContext[T, P], g: GremlinSte
bothNotNull(lhs, rhs, traversal, rhsName)
}

private def comparison(
lhs: Expression,
rhs: Expression,
typePredicate: P,
predicate: String => P): GremlinSteps[T, P] = {
val rhsName = context.generateName()
val ifTrue = anyMatch(__.where(predicate(rhsName)))
val p = context.dsl.predicates()

val lhsT = walkLocal(lhs)
val rhsT = walkLocal(rhs)

rhsT
.as(rhsName)
.flatMap(lhsT)
.choose(
__.or(
__.is(p.isEq(NULL)),
__.not(__.is(typePredicate)),
__.select(rhsName).is(p.isEq(NULL)),
__.not(__.select(rhsName).is(typePredicate))),
__.constant(NULL),
ifTrue
)
}

private def membership(lhs: Expression, rhs: Expression): GremlinSteps[T, P] = {
val lhsT = walkLocal(lhs)
val rhsT = walkLocal(rhs)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,8 @@ private class ProjectionWalker[T, P](context: WalkerContext[T, P], g: GremlinSte
case object Expression extends ReturnFunctionType
case object Pivot extends ReturnFunctionType

private val p = context.dsl.predicates()

def walk(
distinct: Boolean,
items: Seq[ReturnItem],
Expand Down Expand Up @@ -212,7 +214,7 @@ private class ProjectionWalker[T, P](context: WalkerContext[T, P], g: GremlinSte
expression match {
case _: Add | _: ContainerIndex | _: CountStar | _: Divide | _: FunctionInvocation | _: ListLiteral | _: Literal |
_: MapExpression | _: Modulo | _: Multiply | _: Null | _: Parameter | _: PatternComprehension | _: Pow |
_: Property | _: Subtract | _: Variable =>
_: Property | _: Subtract | _: Variable | _: StartsWith | _: Contains | _: EndsWith =>
false
case _ =>
true
Expand Down Expand Up @@ -279,7 +281,6 @@ private class ProjectionWalker[T, P](context: WalkerContext[T, P], g: GremlinSte
subTraversal: GremlinSteps[T, P],
variable: String,
expression: Expression): GremlinSteps[T, P] = {
val p = context.dsl.predicates()

lazy val finalizeNode =
__.valueMap(true)
Expand Down Expand Up @@ -334,8 +335,6 @@ private class ProjectionWalker[T, P](context: WalkerContext[T, P], g: GremlinSte
}

private def aggregation(alias: String, expression: Expression): (ReturnFunctionType, GremlinSteps[T, P]) = {
val p = context.dsl.predicates()

expression match {
case FunctionInvocation(_, FunctionName(fnName), distinct, args) =>
if (args.flatMap(n => n +: n.subExpressions).exists {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,27 +24,18 @@ public class CustomPredicateTest {
public void startsWith() throws Exception {
assertThat(CustomPredicate.cypherStartsWith("a").test("abcd")).isTrue();
assertThat(CustomPredicate.cypherStartsWith("x").test("abcd")).isFalse();
assertThat(CustomPredicate.cypherStartsWith("x").test(null)).isFalse();
assertThat(CustomPredicate.cypherStartsWith(null).test("abcd")).isFalse();
assertThat(CustomPredicate.cypherStartsWith(null).test(null)).isFalse();
}

@Test
public void endsWith() throws Exception {
assertThat(CustomPredicate.cypherEndsWith("d").test("abcd")).isTrue();
assertThat(CustomPredicate.cypherEndsWith("x").test("abcd")).isFalse();
assertThat(CustomPredicate.cypherEndsWith("x").test(null)).isFalse();
assertThat(CustomPredicate.cypherEndsWith(null).test("abcd")).isFalse();
assertThat(CustomPredicate.cypherEndsWith(null).test(null)).isFalse();
}

@Test
public void contains() throws Exception {
assertThat(CustomPredicate.cypherContains("bc").test("abcd")).isTrue();
assertThat(CustomPredicate.cypherContains("x").test("abcd")).isFalse();
assertThat(CustomPredicate.cypherContains("x").test(null)).isFalse();
assertThat(CustomPredicate.cypherContains(null).test("abcd")).isFalse();
assertThat(CustomPredicate.cypherContains(null).test(null)).isFalse();
}

}

0 comments on commit 2d43f50

Please sign in to comment.