Skip to content

List slices/ranges #92

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
May 23, 2018
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import static java.util.stream.Collectors.toList;
import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.catchThrowable;
import static org.assertj.core.api.Assertions.tuple;

import com.google.common.collect.ImmutableMap;
import java.util.HashMap;
Expand Down Expand Up @@ -98,6 +99,30 @@ public void listIndexProjection() throws Exception {
.containsExactly(1L);
}

@Test
public void listNegativeIndex() throws Exception {
List<Map<String, Object>> results = submitAndGet(
"WITH [1, 2, 3] AS list " +
"RETURN list[-1] AS i1, list[-3] AS i2"
);

assertThat(results)
.extracting("i1", "i2")
.containsExactly(tuple(3L, 1L));
}

@Test
public void nonExistentListIndex() throws Exception {
List<Map<String, Object>> results = submitAndGet(
"WITH [1, 2, 3] AS list " +
"RETURN list[3] AS n1, list[-4] AS n2"
);

assertThat(results)
.extracting("n1", "n2")
.containsExactly(tuple(null, null));
}

@Test
public void nullList() throws Exception {
List<Map<String, Object>> results = submitAndGet(
Expand Down Expand Up @@ -164,4 +189,16 @@ public void nonExistentMapIndex() throws Exception {
.extracting("i")
.containsExactly((Object) null);
}

@Test
public void mapInMap() throws Exception {
List<Map<String, Object>> results = submitAndGet(
"WITH {foo: {bar: 'baz'}} AS map " +
"RETURN map['foo']['bar'] AS r"
);

assertThat(results)
.extracting("r")
.containsExactly("baz");
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,172 @@
/*
* Copyright (c) 2018 "Neo4j, Inc." [https://neo4j.com]
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.opencypher.gremlin.queries;

import static java.util.Arrays.asList;
import static java.util.Collections.emptyList;
import static java.util.Collections.emptyMap;
import static java.util.Collections.singletonList;
import static org.assertj.core.api.Assertions.assertThat;

import com.google.common.collect.ImmutableMap;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import org.junit.Before;
import org.junit.ClassRule;
import org.junit.Test;
import org.opencypher.gremlin.rules.GremlinServerExternalResource;

public class ListSliceTest {

@ClassRule
public static final GremlinServerExternalResource gremlinServer = new GremlinServerExternalResource();

@Before
public void setUp() {
gremlinServer.gremlinClient().submit("g.V().drop()").all().join();
}

private List<Map<String, Object>> submitAndGet(String cypher) {
return submitAndGet(cypher, emptyMap());
}

private List<Map<String, Object>> submitAndGet(String cypher, Map<String, ?> parameters) {
return gremlinServer.cypherGremlinClient().submit(cypher, parameters).all();
}

@Test
public void listRange() throws Exception {
List<Map<String, Object>> results = submitAndGet(
"WITH [1, 2, 3, 4, 5] AS list " +
"RETURN list[1..3] AS r"
);

assertThat(results)
.extracting("r")
.containsExactly(asList(2L, 3L));
}

@Test
public void listRangeImplicitEnd() throws Exception {
List<Map<String, Object>> results = submitAndGet(
"WITH [1, 2, 3] AS list " +
"RETURN list[1..] AS r"
);

assertThat(results)
.extracting("r")
.containsExactly(asList(2L, 3L));
}

@Test
public void listRangeImplicitStart() throws Exception {
List<Map<String, Object>> results = submitAndGet(
"WITH [1, 2, 3] AS list " +
"RETURN list[..2] AS r"
);

assertThat(results)
.extracting("r")
.containsExactly(asList(1L, 2L));
}

@Test
public void listSingletonRange() throws Exception {
List<Map<String, Object>> results = submitAndGet(
"WITH [1, 2, 3] AS list " +
"RETURN list[0..1] AS r"
);

assertThat(results)
.extracting("r")
.containsExactly(singletonList(1L));
}

@Test
public void listEmptyRange() throws Exception {
List<Map<String, Object>> results = submitAndGet(
"WITH [1, 2, 3] AS list " +
"RETURN list[0..0] AS r"
);

assertThat(results)
.extracting("r")
.containsExactly(emptyList());
}

@Test
public void listNegativeRange() throws Exception {
List<Map<String, Object>> results = submitAndGet(
"WITH [1, 2, 3] AS list " +
"RETURN list[-3..-1] AS r"
);

assertThat(results)
.extracting("r")
.containsExactly(asList(1L, 2L));
}

@Test
public void listInvalidRange() throws Exception {
List<Map<String, Object>> results = submitAndGet(
"WITH [1, 2, 3] AS list " +
"RETURN list[3..1] AS r"
);

assertThat(results)
.extracting("r")
.containsExactly(emptyList());
}

@Test
public void listExceedingRange() throws Exception {
List<Map<String, Object>> results = submitAndGet(
"WITH [1, 2, 3] AS list " +
"RETURN list[-5..5] AS r"
);

assertThat(results)
.extracting("r")
.containsExactly(asList(1L, 2L, 3L));
}

@Test
public void listRangeParametrized() throws Exception {
List<Map<String, Object>> results = submitAndGet(
"WITH [1, 2, 3] AS list " +
"RETURN list[$from..$to] AS r",
new HashMap<>(ImmutableMap.of("from", 1, "to", 3))
);

assertThat(results)
.extracting("r")
.containsExactly(asList(2L, 3L));
}

@Test
public void listParametrizedEmptyRange() throws Exception {
List<Map<String, Object>> results = submitAndGet(
"WITH [1, 2, 3] AS list " +
"RETURN list[$from..$to] AS r",
new HashMap<>(ImmutableMap.of("from", 3, "to", 1))
);

assertThat(results)
.extracting("r")
.containsExactly(emptyList());
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -225,13 +225,10 @@ public static CustomFunction containerIndex() {
}

if (container instanceof List) {
if (!(index instanceof Number)) {
String indexClass = index.getClass().getName();
throw new IllegalArgumentException("List element access by non-integer: " + indexClass);
}
List list = (List) container;
int i = ((Number) index).intValue();
if (i < 0 || i > list.size()) {
int size = list.size();
int i = normalizeContainerIndex(index, size);
if (i < 0 || i > size) {
return Tokens.NULL;
}
return list.get(i);
Expand All @@ -256,6 +253,58 @@ public static CustomFunction containerIndex() {
);
}

public static CustomFunction listSlice() {
return new CustomFunction(
"listSlice",
traverser -> {
List<?> args = (List<?>) traverser.get();
Object container = args.get(0);
Object from = args.get(1);
Object to = args.get(2);

if (container == Tokens.NULL) {
return Tokens.NULL;
}

if (container instanceof List) {
List list = (List) container;
int size = list.size();
int f = normalizeRangeIndex(from, size);
int t = normalizeRangeIndex(to, size);
if (f >= t) {
return new ArrayList<>();
}
return new ArrayList<>(list.subList(f, t));
}

String containerClass = container.getClass().getName();
throw new IllegalArgumentException(
"Invalid element access of " + containerClass + " by range"
);
}
);
}

private static int normalizeContainerIndex(Object index, int containerSize) {
if (!(index instanceof Number)) {
String indexClass = index.getClass().getName();
throw new IllegalArgumentException("List element access by non-integer: " + indexClass);
}
int i = ((Number) index).intValue();
return (i >= 0) ? i : containerSize + i;
}

private static int normalizeRangeIndex(Object index, int size) {
int i = normalizeContainerIndex(index, size);
if (i < 0) {
return 0;
}
if (i > size) {
return size;
}
return i;
}

public static CustomFunction pathComprehension() {
return new CustomFunction(
"pathComprehension",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ import org.opencypher.gremlin.translation.exception.SyntaxException
import org.opencypher.gremlin.translation.walker.NodeUtils._
import org.opencypher.gremlin.traversal.CustomFunction
import org.opencypher.v9_0.expressions._
import org.opencypher.v9_0.util.InputPosition
import org.opencypher.v9_0.util.symbols._

import scala.collection.JavaConverters._
Expand Down Expand Up @@ -170,16 +171,32 @@ private class ExpressionWalker[T, P](context: StatementContext[T, P], g: Gremlin

case ContainerIndex(expr, idx) =>
(typeOf(expr), idx) match {
case (_: ListType, _: IntegerLiteral) =>
val index = inlineExpressionValue(idx, context, classOf[java.lang.Long])
case (_: ListType, l: IntegerLiteral) if l.value >= 0 =>
walkLocal(expr).coalesce(
__.range(Scope.local, index, index + 1),
__.range(Scope.local, l.value, l.value + 1),
__.constant(NULL)
)
case _ =>
asList(expr, idx).map(CustomFunction.containerIndex())
}

case ListSlice(expr, maybeFrom, maybeTo) =>
val fromIdx = maybeFrom.getOrElse(SignedDecimalIntegerLiteral("0")(InputPosition.NONE))
val toIdx = maybeTo.getOrElse(SignedDecimalIntegerLiteral("-1")(InputPosition.NONE))
(fromIdx, toIdx) match {
case (from: IntegerLiteral, to: IntegerLiteral)
if from.value == to.value || (from.value > to.value && to.value >= 0) =>
walkLocal(expr).limit(0).fold()
case (from: IntegerLiteral, to: IntegerLiteral) if from.value >= 0 && (to.value >= 1 || to.value == -1) =>
val rangeT = __.range(Scope.local, from.value, to.value)
if (to.value - from.value == 1) {
rangeT.fold()
}
walkLocal(expr).coalesce(rangeT, __.constant(NULL))
case _ =>
asList(expr, fromIdx, toIdx).map(CustomFunction.listSlice())
}

case FunctionInvocation(_, FunctionName(fnName), distinct, args) =>
val traversals = args.map(walkLocal)
val traversal = fnName.toLowerCase match {
Expand Down