diff --git a/sdk/cosmos/azure-spring-data-cosmos/src/main/java/com/azure/spring/data/cosmos/core/generator/AbstractQueryGenerator.java b/sdk/cosmos/azure-spring-data-cosmos/src/main/java/com/azure/spring/data/cosmos/core/generator/AbstractQueryGenerator.java index e319398c49a0..bb20f8a3ed59 100644 --- a/sdk/cosmos/azure-spring-data-cosmos/src/main/java/com/azure/spring/data/cosmos/core/generator/AbstractQueryGenerator.java +++ b/sdk/cosmos/azure-spring-data-cosmos/src/main/java/com/azure/spring/data/cosmos/core/generator/AbstractQueryGenerator.java @@ -67,6 +67,7 @@ private String generateBinaryQuery(@NonNull Criteria criteria, @NonNull List> parameters) { Assert.isTrue(criteria.getSubjectValues().size() == 1, "Criteria should have only one subject value"); if (!(criteria.getSubjectValues().get(0) instanceof Collection)) { throw new IllegalQueryException("IN keyword requires Collection type in parameters"); } - final List inRangeValues = new ArrayList<>(); + final Collection values = (Collection) criteria.getSubjectValues().get(0); - values.forEach(o -> { - if (o instanceof Integer || o instanceof Long) { - inRangeValues.add(String.format("%d", o)); - } else if (o instanceof String) { - inRangeValues.add(String.format("'%s'", (String) o)); - } else if (o instanceof Boolean) { - inRangeValues.add(String.format("%b", (Boolean) o)); + final List paras = new ArrayList<>(); + for (Object o : values) { + if (o instanceof String || o instanceof Integer || o instanceof Long || o instanceof Boolean) { + String key = "p" + parameters.size(); + paras.add("@" + key); + parameters.add(Pair.with(key, o)); } else { throw new IllegalQueryException("IN keyword Range only support Number and String type."); } - }); + } - final String inRange = String.join(",", inRangeValues); - return String.format("r.%s %s (%s)", criteria.getSubject(), criteria.getType().getSqlKeyword(), inRange); + return String.format("r.%s %s (%s)", criteria.getSubject(), criteria.getType().getSqlKeyword(), + String.join(",", paras)); } private String generateQueryBody(@NonNull Criteria criteria, @NonNull List> parameters) { @@ -157,7 +158,7 @@ private String generateQueryBody(@NonNull Criteria criteria, @NonNull List subjects = sort.stream().map(this::getParameter).collect(Collectors.toList()); return queryTail - + " " - + String.join(",", subjects); + + " " + + String.join(",", subjects); } @NonNull @@ -250,9 +250,9 @@ protected SqlQuerySpec generateCosmosQuery(@NonNull DocumentQuery query, final List> parameters = queryBody.getValue1(); List sqlParameters = parameters.stream() - .map(p -> new SqlParameter("@" + p.getValue0(), - toCosmosDbValue(p.getValue1()))) - .collect(Collectors.toList()); + .map(p -> new SqlParameter("@" + p.getValue0(), + toCosmosDbValue(p.getValue1()))) + .collect(Collectors.toList()); return new SqlQuerySpec(queryString, sqlParameters); } diff --git a/sdk/cosmos/azure-spring-data-cosmos/src/test/java/com/azure/spring/data/cosmos/repository/integration/ProjectRepositoryIT.java b/sdk/cosmos/azure-spring-data-cosmos/src/test/java/com/azure/spring/data/cosmos/repository/integration/ProjectRepositoryIT.java index eeececf58d12..f88e03546b3a 100644 --- a/sdk/cosmos/azure-spring-data-cosmos/src/test/java/com/azure/spring/data/cosmos/repository/integration/ProjectRepositoryIT.java +++ b/sdk/cosmos/azure-spring-data-cosmos/src/test/java/com/azure/spring/data/cosmos/repository/integration/ProjectRepositoryIT.java @@ -27,6 +27,7 @@ import java.util.Optional; import static org.assertj.core.api.Assertions.assertThat; +import static org.junit.Assert.assertTrue; @RunWith(SpringJUnit4ClassRunner.class) @ContextConfiguration(classes = TestRepositoryConfig.class) @@ -73,7 +74,7 @@ public class ProjectRepositoryIT { private static final List PROJECTS = Arrays.asList(PROJECT_0, PROJECT_1, PROJECT_2, PROJECT_3, PROJECT_4); private static final CosmosEntityInformation entityInformation = - new CosmosEntityInformation<>(Project.class); + new CosmosEntityInformation<>(Project.class); private static CosmosTemplate staticTemplate; private static boolean isSetupDone; @@ -217,7 +218,7 @@ public void testFindByWithOrOr() { @Test public void testFindByWithOrAndOr() { List projects = repository.findByNameOrCreatorAndForkCountOrStarCount(NAME_1, CREATOR_0, - FORK_COUNT_2, STAR_COUNT_3); + FORK_COUNT_2, STAR_COUNT_3); assertProjectListEquals(projects, Arrays.asList(PROJECT_1, PROJECT_3)); @@ -226,7 +227,7 @@ public void testFindByWithOrAndOr() { assertProjectListEquals(projects, Arrays.asList(PROJECT_0, PROJECT_1, PROJECT_3, PROJECT_4)); projects = repository.findByNameOrCreatorAndForkCountOrStarCount(FAKE_NAME, CREATOR_1, - FORK_COUNT_0, FAKE_COUNT); + FORK_COUNT_0, FAKE_COUNT); Assert.assertTrue(projects.isEmpty()); } @@ -293,7 +294,7 @@ public void testFindByLessThanAndGreaterThan() { @Test public void testFindByLessThanEqualsAndGreaterThanEquals() { List projects = repository.findByForkCountLessThanEqualAndStarCountGreaterThan( - STAR_COUNT_MIN, FORK_COUNT_0); + STAR_COUNT_MIN, FORK_COUNT_0); Assert.assertTrue(projects.isEmpty()); @@ -393,17 +394,17 @@ public void testFindByIn() { @Test public void testFindByInWithAnd() { List projects = repository.findByCreatorInAndStarCountIn(Arrays.asList(CREATOR_0, CREATOR_1), - Arrays.asList(STAR_COUNT_2, STAR_COUNT_3)); + Arrays.asList(STAR_COUNT_2, STAR_COUNT_3)); Assert.assertTrue(projects.isEmpty()); projects = repository.findByCreatorInAndStarCountIn(Arrays.asList(CREATOR_0, CREATOR_1), - Arrays.asList(STAR_COUNT_0, STAR_COUNT_2)); + Arrays.asList(STAR_COUNT_0, STAR_COUNT_2)); assertProjectListEquals(projects, Arrays.asList(PROJECT_0, PROJECT_4)); projects = repository.findByCreatorInAndStarCountIn(Arrays.asList(CREATOR_0, CREATOR_1, CREATOR_2), - Arrays.asList(STAR_COUNT_0, STAR_COUNT_1, STAR_COUNT_2)); + Arrays.asList(STAR_COUNT_0, STAR_COUNT_1, STAR_COUNT_2)); assertProjectListEquals(projects, Arrays.asList(PROJECT_0, PROJECT_1, PROJECT_2, PROJECT_4)); } @@ -411,7 +412,7 @@ public void testFindByInWithAnd() { @Test public void testFindByNotIn() { List projects = repository.findByCreatorNotIn( - Arrays.asList(CREATOR_0, CREATOR_1, CREATOR_2, CREATOR_3)); + Arrays.asList(CREATOR_0, CREATOR_1, CREATOR_2, CREATOR_3)); Assert.assertTrue(projects.isEmpty()); @@ -427,17 +428,17 @@ public void testFindByNotIn() { @Test public void testFindByInWithNotIn() { List projects = repository.findByCreatorInAndStarCountNotIn(Collections.singletonList(FAKE_CREATOR), - Arrays.asList(STAR_COUNT_2, STAR_COUNT_3)); + Arrays.asList(STAR_COUNT_2, STAR_COUNT_3)); Assert.assertTrue(projects.isEmpty()); projects = repository.findByCreatorInAndStarCountNotIn(Arrays.asList(CREATOR_0, CREATOR_1), - Arrays.asList(STAR_COUNT_0, STAR_COUNT_2)); + Arrays.asList(STAR_COUNT_0, STAR_COUNT_2)); assertProjectListEquals(projects, Collections.singletonList(PROJECT_1)); projects = repository.findByCreatorInAndStarCountNotIn(Arrays.asList(CREATOR_0, CREATOR_1, CREATOR_2), - Arrays.asList(STAR_COUNT_1, STAR_COUNT_2)); + Arrays.asList(STAR_COUNT_1, STAR_COUNT_2)); assertProjectListEquals(projects, Arrays.asList(PROJECT_0, PROJECT_4)); } @@ -449,7 +450,7 @@ public void testFindByNameIsNull() { Assert.assertTrue(projects.isEmpty()); final Project nullNameProject = new Project("id-999", null, CREATOR_0, true, STAR_COUNT_0, - FORK_COUNT_0); + FORK_COUNT_0); this.repository.save(nullNameProject); projects = repository.findByNameIsNull(); @@ -478,7 +479,7 @@ public void testFindByNameIsNullWithAnd() { Assert.assertTrue(projects.isEmpty()); final Project nullNameProject = new Project("id-999", null, CREATOR_0, true, STAR_COUNT_0, - FORK_COUNT_0); + FORK_COUNT_0); this.repository.save(nullNameProject); projects = repository.findByNameIsNullAndForkCount(FORK_COUNT_0); @@ -526,4 +527,10 @@ public void testFindAllByPartitionKey() { assertThat(findAll.size()).isEqualTo(1); assertThat(findAll.contains(PROJECT_3)).isTrue(); } + + @Test + public void testSqlInjection() { + List projects = this.repository.findAllByNameIn(Collections.singleton("sql) or (r.name <> ''")); + assertTrue(projects.isEmpty()); + } } diff --git a/sdk/cosmos/azure-spring-data-cosmos/src/test/java/com/azure/spring/data/cosmos/repository/repository/ProjectRepository.java b/sdk/cosmos/azure-spring-data-cosmos/src/test/java/com/azure/spring/data/cosmos/repository/repository/ProjectRepository.java index 562dab56f24a..3d99657f19d7 100644 --- a/sdk/cosmos/azure-spring-data-cosmos/src/test/java/com/azure/spring/data/cosmos/repository/repository/ProjectRepository.java +++ b/sdk/cosmos/azure-spring-data-cosmos/src/test/java/com/azure/spring/data/cosmos/repository/repository/ProjectRepository.java @@ -83,4 +83,11 @@ List findByNameOrCreatorAndForkCountOrStarCount(String name, String cre List findByNameIsNotNullAndHasReleased(boolean hasReleased); Page findByForkCount(Long forkCount, Pageable pageable); + + + List findAllByNameIn(Collection names); + + List findAllByStarCountIn(Collection startCounts); + + List findAllByHasReleasedIn(Collection releases); }