Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,7 @@ private String generateBinaryQuery(@NonNull Criteria criteria, @NonNull List<Pai

/**
* Get condition string with function
*
* @param ignoreCase ignore case flag
* @param sqlKeyword sql key word, operation name
* @param subject sql column name
Expand All @@ -84,6 +85,7 @@ private String getCondition(final Part.IgnoreCaseType ignoreCase, final String s

/**
* Get condition string without function
*
* @param ignoreCase ignore case flag
* @param sqlKeyword sql key word, operation name
* @param subject sql column name
Expand Down Expand Up @@ -118,35 +120,34 @@ private String generateBetween(@NonNull Criteria criteria, @NonNull List<Pair<St
private String generateClosedQuery(@NonNull String left, @NonNull String right, CriteriaType type) {
Assert.isTrue(CriteriaType.isClosed(type)
&& CriteriaType.isBinary(type),
"Criteria type should be binary and closure operation");
"Criteria type should be binary and closure operation");

return String.join(" ", left, type.getSqlKeyword(), right);
}

@SuppressWarnings("unchecked")
private String generateInQuery(Criteria criteria) {
private String generateInQuery(@NonNull Criteria criteria, @NonNull List<Pair<String, Object>> parameters) {
Assert.isTrue(criteria.getSubjectValues().size() == 1,
"Criteria should have only one subject value");
if (!(criteria.getSubjectValues().get(0) instanceof Collection)) {
throw new IllegalQueryException("IN keyword requires Collection type in parameters");
}
final List<String> inRangeValues = new ArrayList<>();

final Collection<Object> values = (Collection<Object>) criteria.getSubjectValues().get(0);

values.forEach(o -> {
if (o instanceof Integer || o instanceof Long) {
inRangeValues.add(String.format("%d", o));
} else if (o instanceof String) {
inRangeValues.add(String.format("'%s'", (String) o));
} else if (o instanceof Boolean) {
inRangeValues.add(String.format("%b", (Boolean) o));
final List<String> paras = new ArrayList<>();
for (Object o : values) {
if (o instanceof String || o instanceof Integer || o instanceof Long || o instanceof Boolean) {
String key = "p" + parameters.size();
paras.add("@" + key);
parameters.add(Pair.with(key, o));
} else {
throw new IllegalQueryException("IN keyword Range only support Number and String type.");
}
});
}

final String inRange = String.join(",", inRangeValues);
return String.format("r.%s %s (%s)", criteria.getSubject(), criteria.getType().getSqlKeyword(), inRange);
return String.format("r.%s %s (%s)", criteria.getSubject(), criteria.getType().getSqlKeyword(),
String.join(",", paras));
}

private String generateQueryBody(@NonNull Criteria criteria, @NonNull List<Pair<String, Object>> parameters) {
Expand All @@ -157,7 +158,7 @@ private String generateQueryBody(@NonNull Criteria criteria, @NonNull List<Pair<
return "";
case IN:
case NOT_IN:
return generateInQuery(criteria);
return generateInQuery(criteria, parameters);
case BETWEEN:
return generateBetween(criteria, parameters);
case IS_NULL:
Expand Down Expand Up @@ -193,9 +194,8 @@ private String generateQueryBody(@NonNull Criteria criteria, @NonNull List<Pair<
}

/**
* Generate a query body for interface QuerySpecGenerator.
* The query body compose of Sql query String and its' parameters.
* The parameters organized as a list of Pair, for each pair compose parameter name and value.
* Generate a query body for interface QuerySpecGenerator. The query body compose of Sql query String and its'
* parameters. The parameters organized as a list of Pair, for each pair compose parameter name and value.
*
* @param query the representation for query method.
* @return A pair tuple compose of Sql query.
Expand Down Expand Up @@ -229,8 +229,8 @@ private String generateQuerySort(@NonNull Sort sort) {
final List<String> subjects = sort.stream().map(this::getParameter).collect(Collectors.toList());

return queryTail
+ " "
+ String.join(",", subjects);
+ " "
+ String.join(",", subjects);
}

@NonNull
Expand All @@ -250,9 +250,9 @@ protected SqlQuerySpec generateCosmosQuery(@NonNull DocumentQuery query,
final List<Pair<String, Object>> parameters = queryBody.getValue1();

List<SqlParameter> sqlParameters = parameters.stream()
.map(p -> new SqlParameter("@" + p.getValue0(),
toCosmosDbValue(p.getValue1())))
.collect(Collectors.toList());
.map(p -> new SqlParameter("@" + p.getValue0(),
toCosmosDbValue(p.getValue1())))
.collect(Collectors.toList());

return new SqlQuerySpec(queryString, sqlParameters);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
import java.util.Optional;

import static org.assertj.core.api.Assertions.assertThat;
import static org.junit.Assert.assertTrue;

@RunWith(SpringJUnit4ClassRunner.class)
@ContextConfiguration(classes = TestRepositoryConfig.class)
Expand Down Expand Up @@ -73,7 +74,7 @@ public class ProjectRepositoryIT {
private static final List<Project> PROJECTS = Arrays.asList(PROJECT_0, PROJECT_1, PROJECT_2, PROJECT_3, PROJECT_4);

private static final CosmosEntityInformation<Project, String> entityInformation =
new CosmosEntityInformation<>(Project.class);
new CosmosEntityInformation<>(Project.class);

private static CosmosTemplate staticTemplate;
private static boolean isSetupDone;
Expand Down Expand Up @@ -217,7 +218,7 @@ public void testFindByWithOrOr() {
@Test
public void testFindByWithOrAndOr() {
List<Project> projects = repository.findByNameOrCreatorAndForkCountOrStarCount(NAME_1, CREATOR_0,
FORK_COUNT_2, STAR_COUNT_3);
FORK_COUNT_2, STAR_COUNT_3);

assertProjectListEquals(projects, Arrays.asList(PROJECT_1, PROJECT_3));

Expand All @@ -226,7 +227,7 @@ public void testFindByWithOrAndOr() {
assertProjectListEquals(projects, Arrays.asList(PROJECT_0, PROJECT_1, PROJECT_3, PROJECT_4));

projects = repository.findByNameOrCreatorAndForkCountOrStarCount(FAKE_NAME, CREATOR_1,
FORK_COUNT_0, FAKE_COUNT);
FORK_COUNT_0, FAKE_COUNT);

Assert.assertTrue(projects.isEmpty());
}
Expand Down Expand Up @@ -293,7 +294,7 @@ public void testFindByLessThanAndGreaterThan() {
@Test
public void testFindByLessThanEqualsAndGreaterThanEquals() {
List<Project> projects = repository.findByForkCountLessThanEqualAndStarCountGreaterThan(
STAR_COUNT_MIN, FORK_COUNT_0);
STAR_COUNT_MIN, FORK_COUNT_0);

Assert.assertTrue(projects.isEmpty());

Expand Down Expand Up @@ -393,25 +394,25 @@ public void testFindByIn() {
@Test
public void testFindByInWithAnd() {
List<Project> projects = repository.findByCreatorInAndStarCountIn(Arrays.asList(CREATOR_0, CREATOR_1),
Arrays.asList(STAR_COUNT_2, STAR_COUNT_3));
Arrays.asList(STAR_COUNT_2, STAR_COUNT_3));

Assert.assertTrue(projects.isEmpty());

projects = repository.findByCreatorInAndStarCountIn(Arrays.asList(CREATOR_0, CREATOR_1),
Arrays.asList(STAR_COUNT_0, STAR_COUNT_2));
Arrays.asList(STAR_COUNT_0, STAR_COUNT_2));

assertProjectListEquals(projects, Arrays.asList(PROJECT_0, PROJECT_4));

projects = repository.findByCreatorInAndStarCountIn(Arrays.asList(CREATOR_0, CREATOR_1, CREATOR_2),
Arrays.asList(STAR_COUNT_0, STAR_COUNT_1, STAR_COUNT_2));
Arrays.asList(STAR_COUNT_0, STAR_COUNT_1, STAR_COUNT_2));

assertProjectListEquals(projects, Arrays.asList(PROJECT_0, PROJECT_1, PROJECT_2, PROJECT_4));
}

@Test
public void testFindByNotIn() {
List<Project> projects = repository.findByCreatorNotIn(
Arrays.asList(CREATOR_0, CREATOR_1, CREATOR_2, CREATOR_3));
Arrays.asList(CREATOR_0, CREATOR_1, CREATOR_2, CREATOR_3));

Assert.assertTrue(projects.isEmpty());

Expand All @@ -427,17 +428,17 @@ public void testFindByNotIn() {
@Test
public void testFindByInWithNotIn() {
List<Project> projects = repository.findByCreatorInAndStarCountNotIn(Collections.singletonList(FAKE_CREATOR),
Arrays.asList(STAR_COUNT_2, STAR_COUNT_3));
Arrays.asList(STAR_COUNT_2, STAR_COUNT_3));

Assert.assertTrue(projects.isEmpty());

projects = repository.findByCreatorInAndStarCountNotIn(Arrays.asList(CREATOR_0, CREATOR_1),
Arrays.asList(STAR_COUNT_0, STAR_COUNT_2));
Arrays.asList(STAR_COUNT_0, STAR_COUNT_2));

assertProjectListEquals(projects, Collections.singletonList(PROJECT_1));

projects = repository.findByCreatorInAndStarCountNotIn(Arrays.asList(CREATOR_0, CREATOR_1, CREATOR_2),
Arrays.asList(STAR_COUNT_1, STAR_COUNT_2));
Arrays.asList(STAR_COUNT_1, STAR_COUNT_2));

assertProjectListEquals(projects, Arrays.asList(PROJECT_0, PROJECT_4));
}
Expand All @@ -449,7 +450,7 @@ public void testFindByNameIsNull() {
Assert.assertTrue(projects.isEmpty());

final Project nullNameProject = new Project("id-999", null, CREATOR_0, true, STAR_COUNT_0,
FORK_COUNT_0);
FORK_COUNT_0);

this.repository.save(nullNameProject);
projects = repository.findByNameIsNull();
Expand Down Expand Up @@ -478,7 +479,7 @@ public void testFindByNameIsNullWithAnd() {
Assert.assertTrue(projects.isEmpty());

final Project nullNameProject = new Project("id-999", null, CREATOR_0, true, STAR_COUNT_0,
FORK_COUNT_0);
FORK_COUNT_0);

this.repository.save(nullNameProject);
projects = repository.findByNameIsNullAndForkCount(FORK_COUNT_0);
Expand Down Expand Up @@ -526,4 +527,10 @@ public void testFindAllByPartitionKey() {
assertThat(findAll.size()).isEqualTo(1);
assertThat(findAll.contains(PROJECT_3)).isTrue();
}

@Test
public void testSqlInjection() {
List<Project> projects = this.repository.findAllByNameIn(Collections.singleton("sql) or (r.name <> ''"));
assertTrue(projects.isEmpty());
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -83,4 +83,11 @@ List<Project> findByNameOrCreatorAndForkCountOrStarCount(String name, String cre
List<Project> findByNameIsNotNullAndHasReleased(boolean hasReleased);

Page<Project> findByForkCount(Long forkCount, Pageable pageable);


List<Project> findAllByNameIn(Collection<String> names);

List<Project> findAllByStarCountIn(Collection<Long> startCounts);

List<Project> findAllByHasReleasedIn(Collection<Boolean> releases);
}