diff --git a/sdk/cosmos/azure-spring-data-cosmos/src/main/java/com/azure/spring/data/cosmos/common/CosmosUtils.java b/sdk/cosmos/azure-spring-data-cosmos/src/main/java/com/azure/spring/data/cosmos/common/CosmosUtils.java index 9678f69d91c1..9f6aaa694794 100644 --- a/sdk/cosmos/azure-spring-data-cosmos/src/main/java/com/azure/spring/data/cosmos/common/CosmosUtils.java +++ b/sdk/cosmos/azure-spring-data-cosmos/src/main/java/com/azure/spring/data/cosmos/common/CosmosUtils.java @@ -6,8 +6,10 @@ import com.azure.cosmos.models.FeedResponse; import com.azure.spring.data.cosmos.core.ResponseDiagnostics; import com.azure.spring.data.cosmos.core.ResponseDiagnosticsProcessor; +import com.azure.spring.data.cosmos.exception.IllegalQueryException; import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import org.springframework.util.Assert; /** * Util class to fill and process response diagnostics @@ -45,4 +47,26 @@ public static void fillAndProcessResponseDiagnostics( // Process response diagnostics responseDiagnosticsProcessor.processResponseDiagnostics(responseDiagnostics); } + + /** + * ID value should be string value, real id type will be String, Integer, Long, + * all of these must be converted to String type. + * @param idValue id value to find + * @throws IllegalArgumentException thrown if id value fail the validation. + * @throws IllegalQueryException thrown if id value fail the validation. + * @return String id value + */ + public static String getStringIDValue(Object idValue) { + Assert.notNull(idValue, "id should not be null"); + if (idValue instanceof String) { + Assert.hasText(idValue.toString(), "id should not be empty or only whitespaces."); + return (String) idValue; + } else if (idValue instanceof Integer) { + return Integer.toString((Integer) idValue); + } else if (idValue instanceof Long) { + return Long.toString((Long) idValue); + } else { + throw new IllegalQueryException("Type of id field must be String or Integer or Long"); + } + } } diff --git a/sdk/cosmos/azure-spring-data-cosmos/src/main/java/com/azure/spring/data/cosmos/core/CosmosTemplate.java b/sdk/cosmos/azure-spring-data-cosmos/src/main/java/com/azure/spring/data/cosmos/core/CosmosTemplate.java index f24b8cb304c0..e0c748aa4197 100644 --- a/sdk/cosmos/azure-spring-data-cosmos/src/main/java/com/azure/spring/data/cosmos/core/CosmosTemplate.java +++ b/sdk/cosmos/azure-spring-data-cosmos/src/main/java/com/azure/spring/data/cosmos/core/CosmosTemplate.java @@ -166,13 +166,12 @@ public T findById(Object id, Class domainType) { public T findById(Object id, Class domainType, PartitionKey partitionKey) { Assert.notNull(domainType, "domainType should not be null"); Assert.notNull(partitionKey, "partitionKey should not be null"); - assertValidId(id); - + String idToQuery = CosmosUtils.getStringIDValue(id); final String containerName = getContainerName(domainType); return cosmosAsyncClient .getDatabase(databaseName) .getContainer(containerName) - .readItem(id.toString(), partitionKey, JsonNode.class) + .readItem(idToQuery, partitionKey, JsonNode.class) .flatMap(cosmosItemResponse -> { CosmosUtils.fillAndProcessResponseDiagnostics(responseDiagnosticsProcessor, cosmosItemResponse.getDiagnostics(), null); @@ -195,10 +194,9 @@ public T findById(Object id, Class domainType, PartitionKey partitionKey) public T findById(String containerName, Object id, Class domainType) { Assert.hasText(containerName, "containerName should not be null, empty or only whitespaces"); Assert.notNull(domainType, "domainType should not be null"); - assertValidId(id); final String query = String.format("select * from root where root.id = '%s'", - id.toString()); + CosmosUtils.getStringIDValue(id)); final CosmosQueryRequestOptions options = new CosmosQueryRequestOptions(); options.setQueryMetricsEnabled(enableQueryMetrics); return cosmosAsyncClient @@ -429,8 +427,7 @@ public CosmosContainerProperties createContainerIfNotExists(CosmosEntityInformat */ public void deleteById(String containerName, Object id, PartitionKey partitionKey) { Assert.hasText(containerName, "containerName should not be null, empty or only whitespaces"); - assertValidId(id); - + String idToDelete = CosmosUtils.getStringIDValue(id); LOGGER.debug("execute deleteById in database {} container {}", this.databaseName, containerName); @@ -439,7 +436,7 @@ public void deleteById(String containerName, Object id, PartitionKey partitionKe } cosmosAsyncClient.getDatabase(this.databaseName) .getContainer(containerName) - .deleteItem(id.toString(), partitionKey) + .deleteItem(idToDelete, partitionKey) .doOnNext(response -> CosmosUtils.fillAndProcessResponseDiagnostics(responseDiagnosticsProcessor, response.getDiagnostics(), null)) @@ -454,9 +451,12 @@ public List findByIds(Iterable ids, Class domainType, String c Assert.notNull(ids, "Id list should not be null"); Assert.notNull(domainType, "domainType should not be null."); Assert.hasText(containerName, "container should not be null, empty or only whitespaces"); - + final List idList = new ArrayList<>(); + for (ID id : ids) { + idList.add(CosmosUtils.getStringIDValue(id)); + } final DocumentQuery query = new DocumentQuery(Criteria.getInstance(CriteriaType.IN, "id", - Collections.singletonList(ids), Part.IgnoreCaseType.NEVER)); + Collections.singletonList(idList), Part.IgnoreCaseType.NEVER)); return find(query, domainType, containerName); } @@ -670,13 +670,6 @@ private List getPartitionKeyNames(Class domainType) { return Collections.singletonList(entityInfo.getPartitionKeyFieldName()); } - private void assertValidId(Object id) { - Assert.notNull(id, "id should not be null"); - if (id instanceof String) { - Assert.hasText(id.toString(), "id should not be empty or only whitespaces."); - } - } - private List findItems(@NonNull DocumentQuery query, @NonNull String containerName) { final SqlQuerySpec sqlQuerySpec = new FindQuerySpecGenerator().generateCosmos(query); diff --git a/sdk/cosmos/azure-spring-data-cosmos/src/main/java/com/azure/spring/data/cosmos/core/ReactiveCosmosTemplate.java b/sdk/cosmos/azure-spring-data-cosmos/src/main/java/com/azure/spring/data/cosmos/core/ReactiveCosmosTemplate.java index b1e7eaf6dd2d..50a7d990d400 100644 --- a/sdk/cosmos/azure-spring-data-cosmos/src/main/java/com/azure/spring/data/cosmos/core/ReactiveCosmosTemplate.java +++ b/sdk/cosmos/azure-spring-data-cosmos/src/main/java/com/azure/spring/data/cosmos/core/ReactiveCosmosTemplate.java @@ -212,10 +212,9 @@ public Mono findById(Object id, Class domainType) { public Mono findById(String containerName, Object id, Class domainType) { Assert.hasText(containerName, "containerName should not be null, empty or only whitespaces"); Assert.notNull(domainType, "domainType should not be null"); - assertValidId(id); final String query = String.format("select * from root where root.id = '%s'", - id.toString()); + CosmosUtils.getStringIDValue(id)); final CosmosQueryRequestOptions options = new CosmosQueryRequestOptions(); options.setQueryMetricsEnabled(isPopulateQueryMetrics); @@ -249,12 +248,12 @@ public Mono findById(String containerName, Object id, Class domainType @Override public Mono findById(Object id, Class domainType, PartitionKey partitionKey) { Assert.notNull(domainType, "domainType should not be null"); - assertValidId(id); + String idToFind = CosmosUtils.getStringIDValue(id); final String containerName = getContainerName(domainType); return cosmosAsyncClient.getDatabase(databaseName) .getContainer(containerName) - .readItem(id.toString(), partitionKey, JsonNode.class) + .readItem(idToFind, partitionKey, JsonNode.class) .flatMap(cosmosItemResponse -> { CosmosUtils.fillAndProcessResponseDiagnostics(responseDiagnosticsProcessor, cosmosItemResponse.getDiagnostics(), null); @@ -390,7 +389,7 @@ public Mono upsert(String containerName, T object) { @Override public Mono deleteById(String containerName, Object id, PartitionKey partitionKey) { Assert.hasText(containerName, "container name should not be null, empty or only whitespaces"); - assertValidId(id); + String idToDelete = CosmosUtils.getStringIDValue(id); if (partitionKey == null) { partitionKey = PartitionKey.NONE; @@ -398,7 +397,7 @@ public Mono deleteById(String containerName, Object id, PartitionKey parti return cosmosAsyncClient.getDatabase(this.databaseName) .getContainer(containerName) - .deleteItem(id.toString(), partitionKey) + .deleteItem(idToDelete, partitionKey) .doOnNext(cosmosItemResponse -> CosmosUtils.fillAndProcessResponseDiagnostics(responseDiagnosticsProcessor, cosmosItemResponse.getDiagnostics(), null)) @@ -589,13 +588,6 @@ private Flux findItems(@NonNull DocumentQuery query, CosmosExceptionUtils.exceptionHandler("Failed to query items", throwable)); } - private void assertValidId(Object id) { - Assert.notNull(id, "id should not be null"); - if (id instanceof String) { - Assert.hasText(id.toString(), "id should not be empty or only whitespaces."); - } - } - private List getPartitionKeyNames(Class domainType) { final CosmosEntityInformation entityInfo = entityInfoCreator.apply(domainType); diff --git a/sdk/cosmos/azure-spring-data-cosmos/src/main/java/com/azure/spring/data/cosmos/repository/support/CosmosEntityInformation.java b/sdk/cosmos/azure-spring-data-cosmos/src/main/java/com/azure/spring/data/cosmos/repository/support/CosmosEntityInformation.java index d7dfbb44ea8b..8243e6f3795a 100644 --- a/sdk/cosmos/azure-spring-data-cosmos/src/main/java/com/azure/spring/data/cosmos/repository/support/CosmosEntityInformation.java +++ b/sdk/cosmos/azure-spring-data-cosmos/src/main/java/com/azure/spring/data/cosmos/repository/support/CosmosEntityInformation.java @@ -201,8 +201,10 @@ private Field getIdField(Class domainType) { throw new IllegalArgumentException("domain should contain @Id field or field named id"); } else if (idField.getType() != String.class && idField.getType() != Integer.class - && idField.getType() != int.class) { - throw new IllegalArgumentException("type of id field must be String or Integer"); + && idField.getType() != int.class + && idField.getType() != Long.class + && idField.getType() != long.class) { + throw new IllegalArgumentException("type of id field must be String, Integer or Long"); } return idField; diff --git a/sdk/cosmos/azure-spring-data-cosmos/src/test/java/com/azure/spring/data/cosmos/domain/LongIdDomain.java b/sdk/cosmos/azure-spring-data-cosmos/src/test/java/com/azure/spring/data/cosmos/domain/LongIdDomain.java new file mode 100644 index 000000000000..5cf3fe2f3e2d --- /dev/null +++ b/sdk/cosmos/azure-spring-data-cosmos/src/test/java/com/azure/spring/data/cosmos/domain/LongIdDomain.java @@ -0,0 +1,70 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. +package com.azure.spring.data.cosmos.domain; + +import com.azure.spring.data.cosmos.core.mapping.Document; +import org.springframework.data.annotation.Id; + +import java.util.Objects; + +@Document +public class LongIdDomain { + + @Id + private Long number; + + private String name; + + public LongIdDomain(Long number, String name) { + this.number = number; + this.name = name; + } + + public LongIdDomain() { + } + + public Long getNumber() { + return number; + } + + public void setNumber(Long number) { + this.number = number; + } + + public String getName() { + return name; + } + + public void setName(String name) { + this.name = name; + } + + @Override + public boolean equals(Object o) { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + LongIdDomain that = (LongIdDomain) o; + return Objects.equals(number, that.number) + && Objects.equals(name, that.name); + } + + @Override + public int hashCode() { + return Objects.hash(number, name); + } + + @Override + public String toString() { + return "LongIdDomain{" + + "number=" + + number + + ", name='" + + name + + '\'' + + '}'; + } +} diff --git a/sdk/cosmos/azure-spring-data-cosmos/src/test/java/com/azure/spring/data/cosmos/domain/LongIdDomainPartition.java b/sdk/cosmos/azure-spring-data-cosmos/src/test/java/com/azure/spring/data/cosmos/domain/LongIdDomainPartition.java new file mode 100644 index 000000000000..17edcbd0da87 --- /dev/null +++ b/sdk/cosmos/azure-spring-data-cosmos/src/test/java/com/azure/spring/data/cosmos/domain/LongIdDomainPartition.java @@ -0,0 +1,72 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. +package com.azure.spring.data.cosmos.domain; + +import com.azure.spring.data.cosmos.core.mapping.Document; +import com.azure.spring.data.cosmos.core.mapping.PartitionKey; +import org.springframework.data.annotation.Id; + +import java.util.Objects; + +@Document +public class LongIdDomainPartition { + + @Id + private Long number; + + @PartitionKey + private String name; + + public LongIdDomainPartition(Long number, String name) { + this.number = number; + this.name = name; + } + + public LongIdDomainPartition() { + } + + public Long getNumber() { + return number; + } + + public void setNumber(Long number) { + this.number = number; + } + + public String getName() { + return name; + } + + public void setName(String name) { + this.name = name; + } + + @Override + public boolean equals(Object o) { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + LongIdDomainPartition that = (LongIdDomainPartition) o; + return Objects.equals(number, that.number) + && Objects.equals(name, that.name); + } + + @Override + public int hashCode() { + return Objects.hash(number, name); + } + + @Override + public String toString() { + return "LongIdDomain{" + + "number=" + + number + + ", name='" + + name + + '\'' + + '}'; + } +} diff --git a/sdk/cosmos/azure-spring-data-cosmos/src/test/java/com/azure/spring/data/cosmos/repository/integration/IntegerIdDomainRepositoryIT.java b/sdk/cosmos/azure-spring-data-cosmos/src/test/java/com/azure/spring/data/cosmos/repository/integration/IntegerIdDomainRepositoryIT.java index f19eaa83834e..66c4f357f171 100644 --- a/sdk/cosmos/azure-spring-data-cosmos/src/test/java/com/azure/spring/data/cosmos/repository/integration/IntegerIdDomainRepositoryIT.java +++ b/sdk/cosmos/azure-spring-data-cosmos/src/test/java/com/azure/spring/data/cosmos/repository/integration/IntegerIdDomainRepositoryIT.java @@ -13,7 +13,6 @@ import org.junit.AfterClass; import org.junit.Assert; import org.junit.Before; -import org.junit.Ignore; import org.junit.Test; import org.junit.runner.RunWith; import org.springframework.beans.factory.annotation.Autowired; @@ -122,7 +121,6 @@ public void testSaveAllAndFindAll() { } @Test - @Ignore // TODO(kuthapar): findById IN clause not working in case of Integer public void testFindAllById() { final Iterable allById = this.repository.findAllById(Collections.singleton(DOMAIN.getNumber())); diff --git a/sdk/cosmos/azure-spring-data-cosmos/src/test/java/com/azure/spring/data/cosmos/repository/integration/LongIdDomainRepositoryIT.java b/sdk/cosmos/azure-spring-data-cosmos/src/test/java/com/azure/spring/data/cosmos/repository/integration/LongIdDomainRepositoryIT.java new file mode 100644 index 000000000000..c14990287557 --- /dev/null +++ b/sdk/cosmos/azure-spring-data-cosmos/src/test/java/com/azure/spring/data/cosmos/repository/integration/LongIdDomainRepositoryIT.java @@ -0,0 +1,295 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. +package com.azure.spring.data.cosmos.repository.integration; + +import com.azure.spring.data.cosmos.core.CosmosTemplate; +import com.azure.spring.data.cosmos.core.query.CosmosPageRequest; +import com.azure.spring.data.cosmos.domain.LongIdDomain; +import com.azure.spring.data.cosmos.exception.CosmosAccessException; +import com.azure.spring.data.cosmos.repository.TestRepositoryConfig; +import com.azure.spring.data.cosmos.repository.repository.LongIdDomainRepository; +import com.azure.spring.data.cosmos.repository.support.CosmosEntityInformation; +import org.junit.After; +import org.junit.AfterClass; +import org.junit.Assert; +import org.junit.Before; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.data.domain.Page; +import org.springframework.data.domain.Sort; +import org.springframework.test.context.ContextConfiguration; +import org.springframework.test.context.junit4.SpringJUnit4ClassRunner; + +import java.util.ArrayList; +import java.util.Iterator; +import java.util.List; +import java.util.Objects; +import java.util.Optional; +import java.util.Set; +import java.util.Arrays; +import java.util.Comparator; +import java.util.stream.Collectors; +import java.util.stream.Stream; +import java.util.stream.StreamSupport; + +@RunWith(SpringJUnit4ClassRunner.class) +@ContextConfiguration(classes = TestRepositoryConfig.class) +public class LongIdDomainRepositoryIT { + + private static final Long ID_1 = 12345L; + private static final String NAME_1 = "moary"; + + private static final Long ID_2 = 67890L; + private static final String NAME_2 = "camille"; + + private static final LongIdDomain DOMAIN_1 = new LongIdDomain(ID_1, NAME_1); + private static final LongIdDomain DOMAIN_2 = new LongIdDomain(ID_2, NAME_2); + + private static final CosmosEntityInformation entityInformation = + new CosmosEntityInformation<>(LongIdDomain.class); + + private static CosmosTemplate staticTemplate; + private static boolean isSetupDone; + + @Autowired + private CosmosTemplate template; + + @Autowired + private LongIdDomainRepository repository; + + @Before + public void setUp() { + if (!isSetupDone) { + staticTemplate = template; + template.createContainerIfNotExists(entityInformation); + } + this.repository.save(DOMAIN_1); + this.repository.save(DOMAIN_2); + isSetupDone = true; + } + + @After + public void cleanup() { + this.repository.deleteAll(); + } + + @AfterClass + public static void afterClassCleanup() { + staticTemplate.deleteContainer(entityInformation.getContainerName()); + } + + @Test + public void testLongIdDomain() { + this.repository.deleteAll(); + Assert.assertFalse(this.repository.findById(ID_1).isPresent()); + + this.repository.save(DOMAIN_1); + final Optional foundOptional = this.repository.findById(ID_1); + + Assert.assertTrue(foundOptional.isPresent()); + Assert.assertEquals(DOMAIN_1.getNumber(), foundOptional.get().getNumber()); + Assert.assertEquals(DOMAIN_1.getName(), foundOptional.get().getName()); + + this.repository.delete(DOMAIN_1); + + Assert.assertFalse(this.repository.findById(ID_1).isPresent()); + } + + @Test(expected = IllegalArgumentException.class) + public void testInvalidDomain() { + new CosmosEntityInformation(InvalidDomain.class); + } + + @Test + public void testBasicQuery() { + final LongIdDomain save = this.repository.save(DOMAIN_1); + Assert.assertNotNull(save); + } + + @Test + public void testSaveAndFindById() { + Assert.assertNotNull(this.repository.save(DOMAIN_1)); + + final Optional savedEntity = this.repository.findById(DOMAIN_1.getNumber()); + Assert.assertTrue(savedEntity.isPresent()); + Assert.assertEquals(DOMAIN_1, savedEntity.get()); + } + + @Test + public void testSaveAllAndFindAll() { + Assert.assertTrue(this.repository.findAll().iterator().hasNext()); + + final Set entitiesToSave = Stream.of(DOMAIN_1, DOMAIN_2).collect(Collectors.toSet()); + this.repository.saveAll(entitiesToSave); + + final Set savedEntities = StreamSupport.stream(this.repository.findAll().spliterator(), false) + .collect(Collectors.toSet()); + + Assert.assertTrue(entitiesToSave.containsAll(savedEntities)); + } + + @Test + public void testFindAllById() { + final Iterable allById = + this.repository.findAllById(Arrays.asList(DOMAIN_1.getNumber(), DOMAIN_2.getNumber())); + Assert.assertTrue(((ArrayList) allById).size() == 2); + Iterator it = allById.iterator(); + assertLongIdDomainEquals(Arrays.asList(it.next(), it.next()), Arrays.asList(DOMAIN_1, DOMAIN_2)); + } + + private void assertLongIdDomainEquals(List cur, List reference) { + cur.sort(Comparator.comparing(LongIdDomain::getNumber)); + reference.sort(Comparator.comparing(LongIdDomain::getNumber)); + Assert.assertEquals(reference, cur); + } + + @Test + public void testCount() { + Assert.assertEquals(2, repository.count()); + } + + @Test + public void testDeleteById() { + this.repository.save(DOMAIN_1); + this.repository.save(DOMAIN_2); + this.repository.deleteById(DOMAIN_1.getNumber()); + this.repository.deleteById(DOMAIN_2.getNumber()); + Assert.assertEquals(0, this.repository.count()); + } + + @Test(expected = CosmosAccessException.class) + public void testDeleteByIdShouldFailIfNothingToDelete() { + this.repository.deleteAll(); + this.repository.deleteById(DOMAIN_1.getNumber()); + } + + @Test + public void testDelete() { + this.repository.save(DOMAIN_1); + this.repository.delete(DOMAIN_1); + Assert.assertEquals(1, this.repository.count()); + } + + @Test(expected = CosmosAccessException.class) + public void testDeleteShouldFailIfNothingToDelete() { + this.repository.deleteAll(); + this.repository.delete(DOMAIN_1); + } + + @Test + public void testDeleteAll() { + this.repository.save(DOMAIN_1); + this.repository.save(DOMAIN_2); + this.repository.deleteAll(Arrays.asList(DOMAIN_1, DOMAIN_2)); + Assert.assertEquals(0, this.repository.count()); + } + + @Test + public void testExistsById() { + this.repository.save(DOMAIN_1); + Assert.assertTrue(this.repository.existsById(DOMAIN_1.getNumber())); + } + + @Test + public void testFindAllSort() { + final LongIdDomain other = new LongIdDomain(DOMAIN_1.getNumber() + 1, "other-name"); + this.repository.save(other); + this.repository.save(DOMAIN_1); + + final Sort ascSort = Sort.by(Sort.Direction.ASC, "number"); + final List ascending = StreamSupport + .stream(this.repository.findAll(ascSort).spliterator(), false) + .collect(Collectors.toList()); + Assert.assertEquals(3, ascending.size()); + Assert.assertEquals(DOMAIN_1, ascending.get(0)); + Assert.assertEquals(other, ascending.get(1)); + Assert.assertEquals(DOMAIN_2, ascending.get(2)); + + final Sort descSort = Sort.by(Sort.Direction.DESC, "number"); + final List descending = StreamSupport + .stream(this.repository.findAll(descSort).spliterator(), false) + .collect(Collectors.toList()); + Assert.assertEquals(3, descending.size()); + Assert.assertEquals(DOMAIN_2, descending.get(0)); + Assert.assertEquals(other, descending.get(1)); + Assert.assertEquals(DOMAIN_1, descending.get(2)); + + } + + @Test + public void testFindAllPageable() { + final LongIdDomain other = new LongIdDomain(DOMAIN_1.getNumber() + 1, "other-name"); + this.repository.save(other); + + final Page page1 = this.repository.findAll(new CosmosPageRequest(0, 1, null)); + final Iterator page1Iterator = page1.iterator(); + Assert.assertTrue(page1Iterator.hasNext()); + Assert.assertEquals(DOMAIN_1, page1Iterator.next()); + + final Page page2 = this.repository.findAll(new CosmosPageRequest(1, 1, null)); + final Iterator page2Iterator = page2.iterator(); + Assert.assertTrue(page2Iterator.hasNext()); + Assert.assertEquals(DOMAIN_1, page2Iterator.next()); + } + + private static class InvalidDomain { + + private long count; + + private String location; + + InvalidDomain() { + } + + InvalidDomain(long count, String location) { + this.count = count; + this.location = location; + } + + public long getCount() { + return count; + } + + public void setCount(long count) { + this.count = count; + } + + public String getLocation() { + return location; + } + + public void setLocation(String location) { + this.location = location; + } + + @Override + public boolean equals(Object o) { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + InvalidDomain that = (InvalidDomain) o; + return count == that.count + && Objects.equals(location, that.location); + } + + @Override + public int hashCode() { + return Objects.hash(count, location); + } + + @Override + public String toString() { + return "InvalidDomain{" + + "count=" + + count + + ", location='" + + location + + '\'' + + '}'; + } + } +} diff --git a/sdk/cosmos/azure-spring-data-cosmos/src/test/java/com/azure/spring/data/cosmos/repository/integration/ReactiveLongIdDomainPartitionPartitionRepositoryIT.java b/sdk/cosmos/azure-spring-data-cosmos/src/test/java/com/azure/spring/data/cosmos/repository/integration/ReactiveLongIdDomainPartitionPartitionRepositoryIT.java new file mode 100644 index 000000000000..54246f2cda58 --- /dev/null +++ b/sdk/cosmos/azure-spring-data-cosmos/src/test/java/com/azure/spring/data/cosmos/repository/integration/ReactiveLongIdDomainPartitionPartitionRepositoryIT.java @@ -0,0 +1,263 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. +package com.azure.spring.data.cosmos.repository.integration; + +import com.azure.cosmos.models.PartitionKey; +import com.azure.spring.data.cosmos.core.CosmosTemplate; +import com.azure.spring.data.cosmos.domain.LongIdDomainPartition; +import com.azure.spring.data.cosmos.exception.CosmosAccessException; +import com.azure.spring.data.cosmos.repository.TestRepositoryConfig; +import com.azure.spring.data.cosmos.repository.repository.ReactiveLongIdDomainPartitionRepository; +import com.azure.spring.data.cosmos.repository.support.CosmosEntityInformation; +import org.junit.After; +import org.junit.AfterClass; +import org.junit.Assert; +import org.junit.Before; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.data.domain.Sort; +import org.springframework.test.context.ContextConfiguration; +import org.springframework.test.context.junit4.SpringJUnit4ClassRunner; +import reactor.core.publisher.Mono; +import reactor.test.StepVerifier; + +import java.time.Duration; +import java.util.*; +import java.util.stream.Collectors; +import java.util.stream.Stream; + +@RunWith(SpringJUnit4ClassRunner.class) +@ContextConfiguration(classes = TestRepositoryConfig.class) +public class ReactiveLongIdDomainPartitionPartitionRepositoryIT { + + private static final Long ID_1 = 12345L; + private static final String NAME_1 = "moary"; + + private static final Long ID_2 = 67890L; + private static final String NAME_2 = "camille"; + + private static final LongIdDomainPartition DOMAIN_1 = new LongIdDomainPartition(ID_1, NAME_1); + private static final LongIdDomainPartition DOMAIN_2 = new LongIdDomainPartition(ID_2, NAME_2); + + private static final CosmosEntityInformation entityInformation = + new CosmosEntityInformation<>(LongIdDomainPartition.class); + + private static CosmosTemplate staticTemplate; + private static boolean isSetupDone; + + private static final Duration DEFAULT_TIME_OUT = Duration.ofSeconds(10); + + @Autowired + private CosmosTemplate template; + + @Autowired + private ReactiveLongIdDomainPartitionRepository repository; + + @Before + public void setUp() { + if (!isSetupDone) { + staticTemplate = template; + template.createContainerIfNotExists(entityInformation); + } + this.repository.save(DOMAIN_1).block(DEFAULT_TIME_OUT); + this.repository.save(DOMAIN_2).block(DEFAULT_TIME_OUT); + isSetupDone = true; + } + + @After + public void cleanup() { + final Mono deletedMono = repository.deleteAll(); + StepVerifier.create(deletedMono).thenAwait().verifyComplete(); + } + + @AfterClass + public static void afterClassCleanup() { + staticTemplate.deleteContainer(entityInformation.getContainerName()); + } + + @Test + public void testLongIdDomainPartition() { + this.repository.deleteAll().block(DEFAULT_TIME_OUT); + Assert.assertFalse(this.repository.findById(ID_1).blockOptional(DEFAULT_TIME_OUT).isPresent()); + + this.repository.save(DOMAIN_1).block(DEFAULT_TIME_OUT); + Optional foundOptional = this.repository.findById(ID_1).blockOptional(DEFAULT_TIME_OUT); + + Assert.assertTrue(foundOptional.isPresent()); + Assert.assertEquals(DOMAIN_1.getNumber(), foundOptional.get().getNumber()); + Assert.assertEquals(DOMAIN_1.getName(), foundOptional.get().getName()); + + this.repository.delete(DOMAIN_1).block(DEFAULT_TIME_OUT); + + Assert.assertFalse(this.repository.findById(ID_1).blockOptional(DEFAULT_TIME_OUT).isPresent()); + } + + @Test(expected = IllegalArgumentException.class) + public void testInvalidDomain() { + new CosmosEntityInformation(InvalidDomain.class); + } + + @Test + public void testBasicQuery() { + LongIdDomainPartition save = this.repository.save(DOMAIN_1).block(DEFAULT_TIME_OUT); + Assert.assertNotNull(save); + } + + @Test + public void testSaveAndFindById() { + Assert.assertNotNull(this.repository.save(DOMAIN_1).block(DEFAULT_TIME_OUT)); + Optional longIdDomainPartitionOptional = this.repository + .findById(DOMAIN_1.getNumber()).blockOptional(DEFAULT_TIME_OUT); + Assert.assertTrue(longIdDomainPartitionOptional.isPresent()); + Assert.assertEquals(DOMAIN_1, longIdDomainPartitionOptional.get()); + } + + @Test + public void testSaveAllAndFindAll() { + this.repository.deleteAll().block(DEFAULT_TIME_OUT); + List savedEntities = Stream.of(DOMAIN_1, DOMAIN_2).collect(Collectors.toList()); + this.repository.saveAll(savedEntities).collectList().block(DEFAULT_TIME_OUT); + List longIdDomainPartitionList = this.repository.findAll().collectList().block(DEFAULT_TIME_OUT); + Assert.assertTrue(longIdDomainPartitionList.containsAll(savedEntities)); + } + + @Test + public void testCount() { + Assert.assertTrue(2 == repository.count().block(DEFAULT_TIME_OUT)); + } + + @Test + public void testDeleteByIdWithoutPartitionKey() { + final Mono deleteMono = repository.deleteById(DOMAIN_1.getNumber()); + StepVerifier.create(deleteMono).expectError(CosmosAccessException.class).verify(); + } + + @Test + public void testDeleteByIdAndPartitionKey() { + final Mono deleteMono = repository.deleteById(DOMAIN_1.getNumber(), + new PartitionKey(entityInformation.getPartitionKeyFieldValue(DOMAIN_1))); + StepVerifier.create(deleteMono).verifyComplete(); + + final Mono byId = repository.findById(DOMAIN_1.getNumber(), + new PartitionKey(entityInformation.getPartitionKeyFieldValue(DOMAIN_1))); + Assert.assertNull(byId.block(DEFAULT_TIME_OUT)); + } + + @Test(expected = CosmosAccessException.class) + public void testDeleteByIdShouldFailIfNothingToDelete() { + this.repository.deleteAll().block(DEFAULT_TIME_OUT); + this.repository.deleteById(DOMAIN_1.getNumber()).block(DEFAULT_TIME_OUT); + } + + @Test + public void testDelete() { + this.repository.save(DOMAIN_1).block(DEFAULT_TIME_OUT); + this.repository.delete(DOMAIN_1).block(DEFAULT_TIME_OUT); + Assert.assertTrue(1 == this.repository.count().block(DEFAULT_TIME_OUT)); + } + + @Test(expected = CosmosAccessException.class) + public void testDeleteShouldFailIfNothingToDelete() { + this.repository.deleteAll().block(DEFAULT_TIME_OUT); + this.repository.delete(DOMAIN_1).block(DEFAULT_TIME_OUT); + } + + @Test + public void testDeleteAll() { + this.repository.save(DOMAIN_1).block(DEFAULT_TIME_OUT); + this.repository.save(DOMAIN_2).block(DEFAULT_TIME_OUT); + this.repository.deleteAll(Arrays.asList(DOMAIN_1, DOMAIN_2)).block(DEFAULT_TIME_OUT); + Assert.assertTrue(0 == this.repository.count().block(DEFAULT_TIME_OUT)); + } + + @Test + public void testExistsById() { + this.repository.save(DOMAIN_1).block(DEFAULT_TIME_OUT); + Assert.assertTrue(this.repository.existsById(DOMAIN_1.getNumber()).block(DEFAULT_TIME_OUT)); + } + + @Test + public void testFindAllSort() { + final LongIdDomainPartition other = new LongIdDomainPartition(DOMAIN_1.getNumber() + 1, "other-name"); + this.repository.save(other).block(DEFAULT_TIME_OUT); + this.repository.save(DOMAIN_1).block(DEFAULT_TIME_OUT); + + final Sort ascSort = Sort.by(Sort.Direction.ASC, "number"); + final List ascending = this.repository.findAll(ascSort) + .collectList().block(DEFAULT_TIME_OUT); + Assert.assertEquals(3, ascending.size()); + Assert.assertEquals(DOMAIN_1, ascending.get(0)); + Assert.assertEquals(other, ascending.get(1)); + Assert.assertEquals(DOMAIN_2, ascending.get(2)); + + final Sort descSort = Sort.by(Sort.Direction.DESC, "number"); + final List descending = this.repository.findAll(descSort) + .collectList().block(DEFAULT_TIME_OUT); + Assert.assertEquals(3, descending.size()); + Assert.assertEquals(DOMAIN_2, descending.get(0)); + Assert.assertEquals(other, descending.get(1)); + Assert.assertEquals(DOMAIN_1, descending.get(2)); + + } + + private static class InvalidDomain { + + private long count; + + private String location; + + InvalidDomain() { + } + + InvalidDomain(long count, String location) { + this.count = count; + this.location = location; + } + + public long getCount() { + return count; + } + + public void setCount(long count) { + this.count = count; + } + + public String getLocation() { + return location; + } + + public void setLocation(String location) { + this.location = location; + } + + @Override + public boolean equals(Object o) { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + InvalidDomain that = (InvalidDomain) o; + return count == that.count + && Objects.equals(location, that.location); + } + + @Override + public int hashCode() { + return Objects.hash(count, location); + } + + @Override + public String toString() { + return "InvalidDomain{" + + "count=" + + count + + ", location='" + + location + + '\'' + + '}'; + } + } +} diff --git a/sdk/cosmos/azure-spring-data-cosmos/src/test/java/com/azure/spring/data/cosmos/repository/repository/LongIdDomainRepository.java b/sdk/cosmos/azure-spring-data-cosmos/src/test/java/com/azure/spring/data/cosmos/repository/repository/LongIdDomainRepository.java new file mode 100644 index 000000000000..c8fb082d6691 --- /dev/null +++ b/sdk/cosmos/azure-spring-data-cosmos/src/test/java/com/azure/spring/data/cosmos/repository/repository/LongIdDomainRepository.java @@ -0,0 +1,12 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. +package com.azure.spring.data.cosmos.repository.repository; + +import com.azure.spring.data.cosmos.domain.LongIdDomain; +import com.azure.spring.data.cosmos.repository.CosmosRepository; +import org.springframework.stereotype.Repository; + +@Repository +public interface LongIdDomainRepository extends CosmosRepository { + +} diff --git a/sdk/cosmos/azure-spring-data-cosmos/src/test/java/com/azure/spring/data/cosmos/repository/repository/ReactiveLongIdDomainPartitionRepository.java b/sdk/cosmos/azure-spring-data-cosmos/src/test/java/com/azure/spring/data/cosmos/repository/repository/ReactiveLongIdDomainPartitionRepository.java new file mode 100644 index 000000000000..a6adc573650f --- /dev/null +++ b/sdk/cosmos/azure-spring-data-cosmos/src/test/java/com/azure/spring/data/cosmos/repository/repository/ReactiveLongIdDomainPartitionRepository.java @@ -0,0 +1,12 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. +package com.azure.spring.data.cosmos.repository.repository; + +import com.azure.spring.data.cosmos.domain.LongIdDomainPartition; +import com.azure.spring.data.cosmos.repository.ReactiveCosmosRepository; +import org.springframework.stereotype.Repository; + +@Repository +public interface ReactiveLongIdDomainPartitionRepository extends ReactiveCosmosRepository { + +} diff --git a/sdk/cosmos/azure-spring-data-cosmos/src/test/java/com/azure/spring/data/cosmos/repository/support/CosmosEntityInformationUnitTest.java b/sdk/cosmos/azure-spring-data-cosmos/src/test/java/com/azure/spring/data/cosmos/repository/support/CosmosEntityInformationUnitTest.java index 55c18653afe3..994d68637225 100644 --- a/sdk/cosmos/azure-spring-data-cosmos/src/test/java/com/azure/spring/data/cosmos/repository/support/CosmosEntityInformationUnitTest.java +++ b/sdk/cosmos/azure-spring-data-cosmos/src/test/java/com/azure/spring/data/cosmos/repository/support/CosmosEntityInformationUnitTest.java @@ -6,9 +6,11 @@ import com.azure.spring.data.cosmos.core.mapping.Document; import com.azure.spring.data.cosmos.core.mapping.PartitionKey; import com.azure.spring.data.cosmos.domain.Address; +import com.azure.spring.data.cosmos.domain.LongIdDomain; import com.azure.spring.data.cosmos.domain.Person; import com.azure.spring.data.cosmos.domain.Student; import org.junit.Test; +import org.springframework.data.annotation.Id; import org.springframework.data.annotation.Version; import java.util.List; @@ -380,4 +382,80 @@ public String toString() { + '}'; } } + + @Test + public void testGetIdFieldWithLongType() { + final CosmosEntityInformation entityInformation = + new CosmosEntityInformation<>(LongIdDomain.class); + assertThat(entityInformation.getIdField().getType().equals(Long.class)).isTrue(); + } + + @Test + public void testGetIdFieldWithBasicType() { + final CosmosEntityInformation entityInformation = + new CosmosEntityInformation<>(BasicLongIdDomain.class); + assertThat(entityInformation.getIdField().getType().equals(long.class)).isTrue(); + } + + @Document + class BasicLongIdDomain { + + @Id + private long number; + + private String name; + + BasicLongIdDomain(long number, String name) { + this.number = number; + this.name = name; + } + + BasicLongIdDomain() { + } + + public long getNumber() { + return number; + } + + public void setNumber(Long number) { + this.number = number; + } + + public String getName() { + return name; + } + + public void setName(String name) { + this.name = name; + } + + @Override + public boolean equals(Object o) { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + BasicLongIdDomain that = (BasicLongIdDomain) o; + return Objects.equals(number, that.number) + && Objects.equals(name, that.name); + } + + @Override + public int hashCode() { + return Objects.hash(number, name); + } + + @Override + public String toString() { + return "BasicLongIdDomain{" + + "number=" + + number + + ", name='" + + name + + '\'' + + '}'; + } + } }