Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -2,25 +2,27 @@
// Licensed under the MIT License.
package com.azure.spring.data.cosmos.core.generator;

import com.azure.cosmos.models.SqlParameter;
import com.azure.cosmos.models.SqlQuerySpec;
import com.azure.spring.data.cosmos.core.query.Criteria;
import com.azure.spring.data.cosmos.core.query.CriteriaType;
import com.azure.spring.data.cosmos.core.query.DocumentQuery;
import com.azure.spring.data.cosmos.exception.IllegalQueryException;
import static com.azure.spring.data.cosmos.core.convert.MappingCosmosConverter.toCosmosDbValue;
Comment thread
jacko9et marked this conversation as resolved.
Outdated

import java.util.ArrayList;
import java.util.Collection;
import java.util.Iterator;
import java.util.List;
import java.util.stream.Collectors;

import org.javatuples.Pair;
import org.springframework.data.domain.Sort;
import org.springframework.data.repository.query.parser.Part;
import org.springframework.lang.NonNull;
import org.springframework.util.Assert;
import org.springframework.util.StringUtils;

import java.util.ArrayList;
import java.util.Collection;
import java.util.List;
import java.util.stream.Collectors;

import static com.azure.spring.data.cosmos.core.convert.MappingCosmosConverter.toCosmosDbValue;
import com.azure.cosmos.models.SqlParameter;
import com.azure.cosmos.models.SqlQuerySpec;
import com.azure.spring.data.cosmos.core.query.Criteria;
import com.azure.spring.data.cosmos.core.query.CriteriaType;
import com.azure.spring.data.cosmos.core.query.DocumentQuery;
import com.azure.spring.data.cosmos.exception.IllegalQueryException;

/**
* Base class for generating sql query
Expand Down Expand Up @@ -124,29 +126,35 @@ private String generateClosedQuery(@NonNull String left, @NonNull String right,
}

@SuppressWarnings("unchecked")
private String generateInQuery(Criteria criteria) {
private String generateInQuery(@NonNull Criteria criteria, @NonNull List<Pair<String, Object>> parameters) {
Assert.isTrue(criteria.getSubjectValues().size() == 1,
"Criteria should have only one subject value");
if (!(criteria.getSubjectValues().get(0) instanceof Collection)) {
throw new IllegalQueryException("IN keyword requires Collection type in parameters");
}
final List<String> inRangeValues = new ArrayList<>();

final Collection<Object> values = (Collection<Object>) criteria.getSubjectValues().get(0);

values.forEach(o -> {
if (o instanceof Integer || o instanceof Long) {
inRangeValues.add(String.format("%d", o));
} else if (o instanceof String) {
inRangeValues.add(String.format("'%s'", (String) o));
} else if (o instanceof Boolean) {
inRangeValues.add(String.format("%b", (Boolean) o));
final StringBuilder builder = new StringBuilder();
int index = 0;
Iterator<Object> iterator = values.iterator();
while (iterator.hasNext()) {
Object o = iterator.next();
if (o instanceof String || o instanceof Integer || o instanceof Long || o instanceof Boolean) {
String key = "p" + index;
if (index == 0) {
builder.append("@").append(key);
} else {
builder.append(",@").append(key);
}
parameters.add(Pair.with(key, o));
index++;
} else {
throw new IllegalQueryException("IN keyword Range only support Number and String type.");
}
});
}

final String inRange = String.join(",", inRangeValues);
return String.format("r.%s %s (%s)", criteria.getSubject(), criteria.getType().getSqlKeyword(), inRange);
return String.format("r.%s %s (%s)", criteria.getSubject(), criteria.getType().getSqlKeyword(), builder.toString());
}

private String generateQueryBody(@NonNull Criteria criteria, @NonNull List<Pair<String, Object>> parameters) {
Expand All @@ -157,7 +165,7 @@ private String generateQueryBody(@NonNull Criteria criteria, @NonNull List<Pair<
return "";
case IN:
case NOT_IN:
return generateInQuery(criteria);
return generateInQuery(criteria, parameters);
case BETWEEN:
return generateBetween(criteria, parameters);
case IS_NULL:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
import java.util.Optional;

import static org.assertj.core.api.Assertions.assertThat;
import static org.junit.Assert.assertTrue;

@RunWith(SpringJUnit4ClassRunner.class)
@ContextConfiguration(classes = TestRepositoryConfig.class)
Expand Down Expand Up @@ -526,4 +527,17 @@ public void testFindAllByPartitionKey() {
assertThat(findAll.size()).isEqualTo(1);
assertThat(findAll.contains(PROJECT_3)).isTrue();
}

@Test
public void testSqlInjection() {
List<Project> projects = null;
projects = this.repository.findAllByNameIn(Arrays.asList(NAME_1, NAME_2));
assertProjectListEquals(projects, Arrays.asList(PROJECT_1, PROJECT_2));
projects = this.repository.findAllByStarCountIn(Arrays.asList(STAR_COUNT_0, STAR_COUNT_1));
assertProjectListEquals(projects, Arrays.asList(PROJECT_0, PROJECT_1, PROJECT_4));
projects = this.repository.findAllByHasReleasedIn(Collections.singleton(false));
assertProjectListEquals(projects, Arrays.asList(PROJECT_4));
projects = this.repository.findAllByNameIn(Collections.singleton("sql) or (r.name <> ''"));
assertTrue(projects.isEmpty());
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -83,4 +83,11 @@ List<Project> findByNameOrCreatorAndForkCountOrStarCount(String name, String cre
List<Project> findByNameIsNotNullAndHasReleased(boolean hasReleased);

Page<Project> findByForkCount(Long forkCount, Pageable pageable);


List<Project> findAllByNameIn(Collection<String> names);

List<Project> findAllByStarCountIn(Collection<Long> startCounts);

List<Project> findAllByHasReleasedIn(Collection<Boolean> releases);
}