diff --git a/src/main/java/org/opensearch/knn/index/KNNSettings.java b/src/main/java/org/opensearch/knn/index/KNNSettings.java index 7992640a28..2468030325 100644 --- a/src/main/java/org/opensearch/knn/index/KNNSettings.java +++ b/src/main/java/org/opensearch/knn/index/KNNSettings.java @@ -90,6 +90,7 @@ public class KNNSettings { public static final String KNN_INDEX = "index.knn"; public static final String MODEL_INDEX_NUMBER_OF_SHARDS = "knn.model_index_number_of_shards"; public static final String MODEL_INDEX_NUMBER_OF_REPLICAS = "knn.model_index_number_of_replicas"; + public static final String MODEL_CACHE_SIZE_IN_BYTES = "knn.model_cache.size_in_bytes"; /** * Default setting values @@ -100,6 +101,10 @@ public class KNNSettings { public static final Integer INDEX_KNN_DEFAULT_ALGO_PARAM_EF_CONSTRUCTION = 512; public static final Integer KNN_DEFAULT_ALGO_PARAM_INDEX_THREAD_QTY = 1; public static final Integer KNN_DEFAULT_CIRCUIT_BREAKER_UNSET_PERCENTAGE = 75; + public static final Integer KNN_DEFAULT_MODEL_CACHE_SIZE_IN_BYTES = 50000000; // 50 MB + public static final Integer KNN_MAX_MODEL_CACHE_SIZE_IN_BYTES = 80000000; // 80 MB + public static final Integer KNN_MIN_MODEL_CACHE_SIZE_IN_BYTES = 0; + /** * Settings Definition @@ -158,6 +163,13 @@ public class KNNSettings { Setting.Property.NodeScope, Setting.Property.Dynamic); + public static final Setting<Long> MODEL_CACHE_SIZE_IN_BYTES_SETTING = Setting.longSetting( + MODEL_CACHE_SIZE_IN_BYTES, + KNN_DEFAULT_MODEL_CACHE_SIZE_IN_BYTES, + KNN_MIN_MODEL_CACHE_SIZE_IN_BYTES, + KNN_MAX_MODEL_CACHE_SIZE_IN_BYTES, + Setting.Property.NodeScope, + Setting.Property.Dynamic); /** * This setting identifies KNN index. 
@@ -310,7 +322,8 @@ public List> getSettings() { KNN_CIRCUIT_BREAKER_UNSET_PERCENTAGE_SETTING, IS_KNN_INDEX_SETTING, MODEL_INDEX_NUMBER_OF_SHARDS_SETTING, - MODEL_INDEX_NUMBER_OF_REPLICAS_SETTING); + MODEL_INDEX_NUMBER_OF_REPLICAS_SETTING, + MODEL_CACHE_SIZE_IN_BYTES_SETTING); return Stream.concat(settings.stream(), dynamicCacheSettings.values().stream()) .collect(Collectors.toList()); } diff --git a/src/main/java/org/opensearch/knn/indices/ModelCache.java b/src/main/java/org/opensearch/knn/indices/ModelCache.java new file mode 100644 index 0000000000..90c2dd39ab --- /dev/null +++ b/src/main/java/org/opensearch/knn/indices/ModelCache.java @@ -0,0 +1,134 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.knn.indices; + +import com.google.common.cache.Cache; +import com.google.common.cache.CacheBuilder; +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.opensearch.cluster.service.ClusterService; + +import java.util.concurrent.ExecutionException; + +import static org.opensearch.knn.index.KNNSettings.MODEL_CACHE_SIZE_IN_BYTES_SETTING; + +/** Singleton cache of model blobs keyed by model id, bounded in total bytes by MODEL_CACHE_SIZE_IN_BYTES_SETTING. */ +public final class ModelCache { + + private static final Logger logger = LogManager.getLogger(ModelCache.class); + + private static ModelCache instance; + private static ModelDao modelDao; + private static ClusterService clusterService; + + private Cache<String, byte[]> cache; + private long cacheSizeInBytes; + + /** + * Get instance of cache + * + * @return singleton instance of cache + */ + public static synchronized ModelCache getInstance() { + if (instance == null) { + instance = new ModelCache(); + } + return instance; + } + + /** + * Initialize the cache + * + * @param modelDao modelDao used to read persistence layer for 
models + * @param clusterService used to update settings + */ + public static void initialize(ModelDao modelDao, ClusterService clusterService) { + ModelCache.modelDao = modelDao; + ModelCache.clusterService = clusterService; + } + + /** + * Evict all entries and rebuild the cache + */ + public synchronized void rebuild() { + cache.invalidateAll(); + initCache(); + } + + protected ModelCache() { + cacheSizeInBytes = MODEL_CACHE_SIZE_IN_BYTES_SETTING.get(clusterService.getSettings()); + clusterService.getClusterSettings().addSettingsUpdateConsumer(MODEL_CACHE_SIZE_IN_BYTES_SETTING, it -> { + cacheSizeInBytes = it; + rebuild(); + }); + initCache(); + } + + private void initCache() { + CacheBuilder<String, byte[]> cacheBuilder = CacheBuilder.newBuilder() + .recordStats() + .concurrencyLevel(1) + .maximumWeight(cacheSizeInBytes) + .weigher((k, v) -> v.length); + + cache = cacheBuilder.build(); + } + + /** + * Get the model from modelId, loading it through the ModelDao on a cache miss + * + * @param modelId model identifier + * @return byte array representing model + */ + public byte[] get(String modelId) { + try { + return cache.get(modelId, () -> modelDao.get(modelId)); + } catch (ExecutionException ee) { + throw new IllegalStateException("Unable to retrieve model blob for \"" + modelId + "\": " + ee, ee); // keep the cause + } + } + + /** + * Get total weight of cache + * + * @return total weight + */ + public long getTotalWeight() { + return cache.asMap().values().stream().map(bytes -> (long) bytes.length).reduce(0L, Long::sum); + } + + /** + * Remove modelId from cache + * + * @param modelId to be removed + */ + public void remove(String modelId) { + cache.invalidate(modelId); + } + + /** + * Check if modelId is in the cache + * + * @param modelId model id to be checked + * @return true if model id is in the cache; false otherwise + */ + public boolean contains(String modelId) { + return cache.asMap().containsKey(modelId); + } + + /** + * Remove all elements from the cache + */ + public void removeAll() { + cache.invalidateAll(); + } +} diff --git 
a/src/main/java/org/opensearch/knn/indices/ModelDao.java b/src/main/java/org/opensearch/knn/indices/ModelDao.java new file mode 100644 index 0000000000..4658a89ecf --- /dev/null +++ b/src/main/java/org/opensearch/knn/indices/ModelDao.java @@ -0,0 +1,271 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.knn.indices; + +import com.google.common.base.Charsets; +import com.google.common.collect.ImmutableMap; +import com.google.common.io.Resources; +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.opensearch.action.ActionListener; +import org.opensearch.action.DocWriteRequest; +import org.opensearch.action.admin.indices.create.CreateIndexRequest; +import org.opensearch.action.admin.indices.create.CreateIndexResponse; +import org.opensearch.action.delete.DeleteAction; +import org.opensearch.action.delete.DeleteRequestBuilder; +import org.opensearch.action.delete.DeleteResponse; +import org.opensearch.action.get.GetAction; +import org.opensearch.action.get.GetRequestBuilder; +import org.opensearch.action.get.GetResponse; +import org.opensearch.action.index.IndexRequestBuilder; +import org.opensearch.action.index.IndexResponse; +import org.opensearch.action.support.WriteRequest; +import org.opensearch.client.Client; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.settings.Settings; +import org.opensearch.common.xcontent.XContentType; +import org.opensearch.knn.common.KNNConstants; +import org.opensearch.knn.index.util.KNNEngine; + +import java.io.IOException; +import java.net.URL; +import java.util.Base64; +import java.util.Map; +import java.util.concurrent.ExecutionException; + +import static 
org.opensearch.knn.common.KNNConstants.MODEL_INDEX_MAPPING_PATH; +import static org.opensearch.knn.common.KNNConstants.MODEL_INDEX_NAME; +import static org.opensearch.knn.index.KNNSettings.MODEL_INDEX_NUMBER_OF_REPLICAS_SETTING; +import static org.opensearch.knn.index.KNNSettings.MODEL_INDEX_NUMBER_OF_SHARDS_SETTING; + +/** + * ModelDao is used to interface with the model persistence layer + */ +public interface ModelDao { + + /** + * Creates model index. It is possible that 2 threads call this function simultaneously. In this case, one + * thread will throw a ResourceAlreadyExistsException. This should be caught and handled. + * + * @param actionListener CreateIndexResponse listener + * @throws IOException thrown when get mapping fails + */ + void create(ActionListener<CreateIndexResponse> actionListener) throws IOException; + + /** + * Checks if the model index exists + * + * @return true if the model index exists; false otherwise + */ + boolean isCreated(); + + /** + * Put a model into the system index. Non-blocking + * + * @param modelId Id of model to create + * @param modelBlob byte array of model + * @param listener handles index response + */ + void put(String modelId, KNNEngine knnEngine, byte[] modelBlob, ActionListener<IndexResponse> listener) throws IOException; + + /** + * Put a model into the system index. Non-blocking. When no id is passed in, OpenSearch will generate the id + * automatically. The id can be retrieved in the IndexResponse. + * + * @param modelBlob byte array of model + * @param listener handles index response + */ + void put(KNNEngine knnEngine, byte[] modelBlob, ActionListener<IndexResponse> listener) throws IOException; + + /** + * Get a model from the system index. Call blocks. 
+ * + * @param modelId to retrieve + * @return byte array representing the model + * @throws ExecutionException thrown on search + * @throws InterruptedException thrown on search + */ + byte[] get(String modelId) throws ExecutionException, InterruptedException; + + /** + * Delete model from index + * + * @param modelId to delete + * @param listener handles delete response + */ + void delete(String modelId, ActionListener<DeleteResponse> listener); + + /** + * Implementation of ModelDao for k-NN model index + */ + final class OpenSearchKNNModelDao implements ModelDao { + + public static Logger logger = LogManager.getLogger(ModelDao.class); + + private int numberOfShards; + private int numberOfReplicas; + + private static OpenSearchKNNModelDao INSTANCE; + private static Client client; + private static ClusterService clusterService; + private static Settings settings; + + /** + * Make sure we just have one instance of the model dao + * + * @return ModelDao instance + */ + public static synchronized OpenSearchKNNModelDao getInstance() { + if (INSTANCE == null) { + INSTANCE = new OpenSearchKNNModelDao(); + } + return INSTANCE; + } + + public static void initialize(Client client, ClusterService clusterService, Settings settings) { + OpenSearchKNNModelDao.client = client; + OpenSearchKNNModelDao.clusterService = clusterService; + OpenSearchKNNModelDao.settings = settings; + } + + private OpenSearchKNNModelDao() { + numberOfShards = MODEL_INDEX_NUMBER_OF_SHARDS_SETTING.get(settings); + numberOfReplicas = MODEL_INDEX_NUMBER_OF_REPLICAS_SETTING.get(settings); + + clusterService.getClusterSettings().addSettingsUpdateConsumer(MODEL_INDEX_NUMBER_OF_SHARDS_SETTING, + it -> numberOfShards = it); + clusterService.getClusterSettings().addSettingsUpdateConsumer(MODEL_INDEX_NUMBER_OF_REPLICAS_SETTING, + it -> numberOfReplicas = it); + } + + @Override + public void create(ActionListener<CreateIndexResponse> actionListener) throws IOException { + if (isCreated()) { + return; + } + + CreateIndexRequest request = new 
CreateIndexRequest(MODEL_INDEX_NAME) + .mapping("_doc", getMapping(), XContentType.JSON) + .settings(Settings.builder() + .put("index.hidden", true) + .put("index.number_of_shards", this.numberOfShards) + .put("index.number_of_replicas", this.numberOfReplicas) + ); + client.admin().indices().create(request, actionListener); + } + + @Override + public boolean isCreated() { + return clusterService.state().getRoutingTable().hasIndex(MODEL_INDEX_NAME); + } + + @Override + public void put(String modelId, KNNEngine knnEngine, byte[] modelBlob, ActionListener<IndexResponse> listener) + throws IOException { + String base64Model = Base64.getEncoder().encodeToString(modelBlob); + + Map<String, Object> parameters = ImmutableMap.of( + KNNConstants.KNN_ENGINE, knnEngine.getName(), + KNNConstants.MODEL_BLOB_PARAMETER, base64Model + ); + + IndexRequestBuilder indexRequestBuilder = client.prepareIndex(MODEL_INDEX_NAME, "_doc"); + indexRequestBuilder.setId(modelId); + indexRequestBuilder.setSource(parameters); + + // Fail if the id already exists. Models are not updateable + indexRequestBuilder.setOpType(DocWriteRequest.OpType.CREATE); + indexRequestBuilder.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE); + + if (!isCreated()) { + create(ActionListener.wrap(createIndexResponse -> indexRequestBuilder.execute(listener), + listener::onFailure)); + return; + } + + indexRequestBuilder.execute(listener); + } + + @Override + public void put(KNNEngine knnEngine, byte[] modelBlob, ActionListener<IndexResponse> listener) + throws IOException { + String base64Model = Base64.getEncoder().encodeToString(modelBlob); + + Map<String, Object> parameters = ImmutableMap.of( + KNNConstants.KNN_ENGINE, knnEngine.getName(), + KNNConstants.MODEL_BLOB_PARAMETER, base64Model + ); + + IndexRequestBuilder indexRequestBuilder = client.prepareIndex(MODEL_INDEX_NAME, "_doc"); + indexRequestBuilder.setSource(parameters); + + // Fail if the id already exists. 
Models are not updateable + indexRequestBuilder.setOpType(DocWriteRequest.OpType.CREATE); + indexRequestBuilder.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE); + + if (!isCreated()) { + create(ActionListener.wrap(createIndexResponse -> indexRequestBuilder.execute(listener), + listener::onFailure)); + return; + } + + indexRequestBuilder.execute(listener); + } + + @Override + public byte[] get(String modelId) throws ExecutionException, InterruptedException { + /* + GET /MODEL_INDEX_NAME/_doc/modelId?_source_includes=MODEL_BLOB_PARAMETER&preference=_local + */ + GetRequestBuilder getRequestBuilder = new GetRequestBuilder(client, GetAction.INSTANCE, MODEL_INDEX_NAME) + .setId(modelId) + .setFetchSource(KNNConstants.MODEL_BLOB_PARAMETER, null) + .setPreference("_local"); + GetResponse getResponse = getRequestBuilder.execute().get(); + + Map<String, Object> source = getResponse.getSourceAsMap(); // null when the response carries no source + Object blob = source == null ? null : source.get(KNNConstants.MODEL_BLOB_PARAMETER); + if (blob == null) { + throw new IllegalArgumentException("No model available in \"" + MODEL_INDEX_NAME + "\" index with id \"" + + modelId + "\"."); + } + + return Base64.getDecoder().decode((String) blob); + } + + private String getMapping() throws IOException { + URL url = ModelDao.class.getClassLoader().getResource(MODEL_INDEX_MAPPING_PATH); + if (url == null) { + throw new IllegalStateException("Unable to retrieve mapping for \"" + MODEL_INDEX_NAME + "\""); + } + + return Resources.toString(url, Charsets.UTF_8); + } + + @Override + public void delete(String modelId, ActionListener<DeleteResponse> listener) { + if (!isCreated()) { + logger.info("Cannot delete model \"" + modelId + "\". 
Model index does not exist."); + return; + } + + DeleteRequestBuilder deleteRequestBuilder = new DeleteRequestBuilder(client, DeleteAction.INSTANCE, + MODEL_INDEX_NAME); + deleteRequestBuilder.setId(modelId); + deleteRequestBuilder.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE); + deleteRequestBuilder.execute(ActionListener.wrap(deleteResponse -> { + ModelCache.getInstance().remove(modelId); + listener.onResponse(deleteResponse); + }, listener::onFailure)); + } + } +} diff --git a/src/main/java/org/opensearch/knn/indices/ModelIndex.java b/src/main/java/org/opensearch/knn/indices/ModelIndex.java deleted file mode 100644 index 90daa4e975..0000000000 --- a/src/main/java/org/opensearch/knn/indices/ModelIndex.java +++ /dev/null @@ -1,235 +0,0 @@ -/* - * SPDX-License-Identifier: Apache-2.0 - * - * The OpenSearch Contributors require contributions made to - * this file be licensed under the Apache-2.0 license or a - * compatible open source license. - * - * Modifications Copyright OpenSearch Contributors. See - * GitHub history for details. 
- */ - -package org.opensearch.knn.indices; - -import com.google.common.base.Charsets; -import com.google.common.collect.ImmutableMap; -import com.google.common.io.Resources; -import org.apache.logging.log4j.LogManager; -import org.apache.logging.log4j.Logger; -import org.opensearch.action.ActionListener; -import org.opensearch.action.DocWriteRequest; -import org.opensearch.action.admin.indices.create.CreateIndexRequest; -import org.opensearch.action.admin.indices.create.CreateIndexResponse; -import org.opensearch.action.delete.DeleteAction; -import org.opensearch.action.delete.DeleteRequestBuilder; -import org.opensearch.action.delete.DeleteResponse; -import org.opensearch.action.get.GetAction; -import org.opensearch.action.get.GetRequestBuilder; -import org.opensearch.action.get.GetResponse; -import org.opensearch.action.index.IndexRequestBuilder; -import org.opensearch.action.index.IndexResponse; -import org.opensearch.action.support.WriteRequest; -import org.opensearch.client.Client; -import org.opensearch.cluster.service.ClusterService; -import org.opensearch.common.settings.Settings; -import org.opensearch.common.xcontent.XContentType; -import org.opensearch.knn.common.KNNConstants; -import org.opensearch.knn.index.util.KNNEngine; - -import java.io.IOException; -import java.net.URL; -import java.util.Base64; -import java.util.Map; -import java.util.concurrent.ExecutionException; - -import static org.opensearch.knn.common.KNNConstants.MODEL_INDEX_MAPPING_PATH; -import static org.opensearch.knn.common.KNNConstants.MODEL_INDEX_NAME; -import static org.opensearch.knn.index.KNNSettings.MODEL_INDEX_NUMBER_OF_REPLICAS_SETTING; -import static org.opensearch.knn.index.KNNSettings.MODEL_INDEX_NUMBER_OF_SHARDS_SETTING; - -/** - * ModelIndex is a singleton class that controls operations on the model system index - */ -public final class ModelIndex { - public static Logger logger = LogManager.getLogger(ModelIndex.class); - - private int numberOfShards; - private int 
numberOfReplicas; - - private static ModelIndex INSTANCE; - private static Client client; - private static ClusterService clusterService; - private static Settings settings; - - /** - * Make sure we just have one instance of model index - * - * @return ModelIndex instance - */ - public static synchronized ModelIndex getInstance() { - if (INSTANCE == null) { - INSTANCE = new ModelIndex(); - } - return INSTANCE; - } - - public static void initialize(Client client, ClusterService clusterService, Settings settings) { - ModelIndex.client = client; - ModelIndex.clusterService = clusterService; - ModelIndex.settings = settings; - } - - private ModelIndex() { - numberOfShards = MODEL_INDEX_NUMBER_OF_SHARDS_SETTING.get(settings); - numberOfReplicas = MODEL_INDEX_NUMBER_OF_REPLICAS_SETTING.get(settings); - - clusterService.getClusterSettings().addSettingsUpdateConsumer(MODEL_INDEX_NUMBER_OF_SHARDS_SETTING, - it -> numberOfShards = it); - clusterService.getClusterSettings().addSettingsUpdateConsumer(MODEL_INDEX_NUMBER_OF_REPLICAS_SETTING, - it -> numberOfReplicas = it); - } - - /** - * Creates model index. It is possible that the 2 threads call this function simulateously. In this case, one - * thread will throw a ResourceAlreadyExistsException. This should be caught and handled. 
- * - * @param actionListener CreateIndexResponse listener - * @throws IOException thrown when get mapping fails - */ - public void create(ActionListener actionListener) throws IOException { - if (isCreated()) { - return; - } - - CreateIndexRequest request = new CreateIndexRequest(MODEL_INDEX_NAME) - .mapping("_doc", getMapping(), XContentType.JSON) - .settings(Settings.builder() - .put("index.hidden", true) - .put("index.number_of_shards", this.numberOfShards) - .put("index.number_of_replicas", this.numberOfReplicas) - ); - client.admin().indices().create(request, actionListener); - } - - /** - * Checks if the model index exists - * - * @return true if the model index exists; false otherwise - */ - public boolean isCreated() { - return clusterService.state().getRoutingTable().hasIndex(MODEL_INDEX_NAME); - } - - /** - * Put a model into the system index. Non-blocking - * - * @param modelId Id of model to create - * @param modelBlob byte array of model - * @param listener handles index response - */ - public void put(String modelId, KNNEngine knnEngine, byte[] modelBlob, ActionListener listener) { - String base64Model = Base64.getEncoder().encodeToString(modelBlob); - - Map parameters = ImmutableMap.of( - KNNConstants.KNN_ENGINE, knnEngine.getName(), - KNNConstants.MODEL_BLOB_PARAMETER, base64Model - ); - - IndexRequestBuilder indexRequestBuilder = client.prepareIndex(MODEL_INDEX_NAME, "_doc"); - indexRequestBuilder.setId(modelId); - indexRequestBuilder.setSource(parameters); - - put(indexRequestBuilder, listener); - } - - /** - * Put a model into the system index. Non-blocking. When no id is passed in, OpenSearch will generate the id - * automatically. The id can be retrieved in the IndexResponse. 
- * - * @param modelBlob byte array of model - * @param listener handles index response - */ - public void put(KNNEngine knnEngine, byte[] modelBlob, ActionListener listener) { - String base64Model = Base64.getEncoder().encodeToString(modelBlob); - - Map parameters = ImmutableMap.of( - KNNConstants.KNN_ENGINE, knnEngine.getName(), - KNNConstants.MODEL_BLOB_PARAMETER, base64Model - ); - - IndexRequestBuilder indexRequestBuilder = client.prepareIndex(MODEL_INDEX_NAME, "_doc"); - indexRequestBuilder.setSource(parameters); - - put(indexRequestBuilder, listener); - } - - private void put(IndexRequestBuilder indexRequestBuilder, ActionListener listener) { - if (!isCreated()) { - throw new IllegalStateException("Cannot put model in index before index has been initialized"); - } - - // Fail if the id already exists. Models are not updateable - indexRequestBuilder.setOpType(DocWriteRequest.OpType.CREATE); - indexRequestBuilder.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE); - indexRequestBuilder.execute(listener); - } - - /** - * Get a model from the system index. Call blocks. - * - * @param modelId to retrieve - * @return byte array representing the model - * @throws ExecutionException thrown on search - * @throws InterruptedException thrown on search - */ - public byte[] get(String modelId) throws ExecutionException, InterruptedException { - if (!isCreated()) { - throw new IllegalStateException("Cannot get model \"" + modelId + "\". 
Model index does not exist."); - } - - /* - GET //?source_includes=&_local - */ - GetRequestBuilder getRequestBuilder = new GetRequestBuilder(client, GetAction.INSTANCE, MODEL_INDEX_NAME) - .setId(modelId) - .setFetchSource(KNNConstants.MODEL_BLOB_PARAMETER, null) - .setPreference("_local"); - GetResponse getResponse = getRequestBuilder.execute().get(); - - Object blob = getResponse.getSourceAsMap().get(KNNConstants.MODEL_BLOB_PARAMETER); - - if (blob == null) { - throw new IllegalArgumentException("No model available in \"" + MODEL_INDEX_NAME + "\" index with id \"" - + modelId + "\"."); - } - - return Base64.getDecoder().decode((String) blob); - } - - private String getMapping() throws IOException { - URL url = ModelIndex.class.getClassLoader().getResource(MODEL_INDEX_MAPPING_PATH); - if (url == null) { - throw new IllegalStateException("Unable to retrieve mapping for \"" + MODEL_INDEX_NAME + "\""); - } - - return Resources.toString(url, Charsets.UTF_8); - } - - /** - * Delete model from index - * - * @param modelId to delete - * @param listener handles delete response - */ - public void delete(String modelId, ActionListener listener) { - if (!isCreated()) { - throw new IllegalStateException("Cannot delete model \"" + modelId + "\". 
Model index does not exist."); - } - - DeleteRequestBuilder deleteRequestBuilder = new DeleteRequestBuilder(client, DeleteAction.INSTANCE, - MODEL_INDEX_NAME); - deleteRequestBuilder.setId(modelId); - deleteRequestBuilder.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE); - deleteRequestBuilder.execute(listener); - } -} \ No newline at end of file diff --git a/src/main/java/org/opensearch/knn/plugin/KNNPlugin.java b/src/main/java/org/opensearch/knn/plugin/KNNPlugin.java index f88832ad7d..70520c4b98 100644 --- a/src/main/java/org/opensearch/knn/plugin/KNNPlugin.java +++ b/src/main/java/org/opensearch/knn/plugin/KNNPlugin.java @@ -31,7 +31,8 @@ import org.opensearch.knn.index.KNNSettings; import org.opensearch.knn.index.KNNVectorFieldMapper; -import org.opensearch.knn.indices.ModelIndex; +import org.opensearch.knn.indices.ModelCache; +import org.opensearch.knn.indices.ModelDao; import org.opensearch.knn.plugin.rest.RestKNNStatsHandler; import org.opensearch.knn.plugin.rest.RestKNNWarmupHandler; import org.opensearch.knn.plugin.script.KNNScoringScriptEngine; @@ -145,7 +146,8 @@ public Collection createComponents(Client client, ClusterService cluster this.clusterService = clusterService; KNNIndexCache.setResourceWatcherService(resourceWatcherService); KNNSettings.state().initialize(client, clusterService); - ModelIndex.initialize(client, clusterService, environment.settings()); + ModelDao.OpenSearchKNNModelDao.initialize(client, clusterService, environment.settings()); + ModelCache.initialize(ModelDao.OpenSearchKNNModelDao.getInstance(), clusterService); KNNCircuitBreaker.getInstance().initialize(threadPool, clusterService, client); knnStats = new KNNStats(KNNStatsConfig.KNN_STATS); return ImmutableList.of(knnStats); diff --git a/src/test/java/org/opensearch/knn/indices/ModelCacheTests.java b/src/test/java/org/opensearch/knn/indices/ModelCacheTests.java new file mode 100644 index 0000000000..329bfb116a --- /dev/null +++ 
b/src/test/java/org/opensearch/knn/indices/ModelCacheTests.java @@ -0,0 +1,313 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.knn.indices; + +import com.google.common.collect.ImmutableSet; +import com.google.common.util.concurrent.UncheckedExecutionException; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.settings.ClusterSettings; +import org.opensearch.common.settings.Settings; +import org.opensearch.knn.KNNTestCase; + +import java.util.concurrent.ExecutionException; + +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; +import static org.opensearch.knn.index.KNNSettings.MODEL_CACHE_SIZE_IN_BYTES_SETTING; + +public class ModelCacheTests extends KNNTestCase { + + public void testGet_normal() throws ExecutionException, InterruptedException { + String modelId = "test-model-id"; + byte[] mockModel = "hello".getBytes(); + long cacheSize = 100L; + + ModelDao modelDao = mock(ModelDao.class); + when(modelDao.get(modelId)).thenReturn(mockModel); + + Settings settings = Settings.builder().put(MODEL_CACHE_SIZE_IN_BYTES_SETTING.getKey(), cacheSize).build(); + ClusterSettings clusterSettings = new ClusterSettings(settings, + ImmutableSet.of(MODEL_CACHE_SIZE_IN_BYTES_SETTING)); + ClusterService clusterService = mock(ClusterService.class); + when(clusterService.getClusterSettings()).thenReturn(clusterSettings); + when(clusterService.getSettings()).thenReturn(settings); + + ModelCache.initialize(modelDao, clusterService); + ModelCache modelCache = new ModelCache(); + + assertArrayEquals(mockModel, modelCache.get(modelId)); + } + + public void testGet_modelDoesNotFitInCache() throws ExecutionException, InterruptedException { + String 
modelId = "test-model-id"; + long cacheSize = 500; + byte[] mockModel = new byte[Long.valueOf(cacheSize).intValue() + 1]; + + ModelDao modelDao = mock(ModelDao.class); + when(modelDao.get(modelId)).thenReturn(mockModel); + + Settings settings = Settings.builder().put(MODEL_CACHE_SIZE_IN_BYTES_SETTING.getKey(), cacheSize).build(); + ClusterSettings clusterSettings = new ClusterSettings(settings, + ImmutableSet.of(MODEL_CACHE_SIZE_IN_BYTES_SETTING)); + ClusterService clusterService = mock(ClusterService.class); + when(clusterService.getClusterSettings()).thenReturn(clusterSettings); + when(clusterService.getSettings()).thenReturn(settings); + + ModelCache.initialize(modelDao, clusterService); + ModelCache modelCache = new ModelCache(); + + assertArrayEquals(mockModel, modelCache.get(modelId)); + assertFalse(modelCache.contains(modelId)); + } + + public void testGet_modelDoesNotExist() throws ExecutionException, InterruptedException { + String modelId = "test-model-id"; + long cacheSize = 100L; + + ModelDao modelDao = mock(ModelDao.class); + when(modelDao.get(modelId)).thenThrow(new IllegalArgumentException()); + + Settings settings = Settings.builder().put(MODEL_CACHE_SIZE_IN_BYTES_SETTING.getKey(), cacheSize).build(); + ClusterSettings clusterSettings = new ClusterSettings(settings, + ImmutableSet.of(MODEL_CACHE_SIZE_IN_BYTES_SETTING)); + ClusterService clusterService = mock(ClusterService.class); + when(clusterService.getClusterSettings()).thenReturn(clusterSettings); + when(clusterService.getSettings()).thenReturn(settings); + + ModelCache.initialize(modelDao, clusterService); + ModelCache modelCache = new ModelCache(); + + expectThrows(UncheckedExecutionException.class, () -> modelCache.get(modelId)); + } + + public void testGetTotalWeight() throws ExecutionException, InterruptedException { + String modelId1 = "test-model-id-1"; + String modelId2 = "test-model-id-2"; + long cacheSize = 500L; + + int size1 = 100; + byte[] mockModel1 = new byte[size1]; + int size2 
= 300; + byte[] mockModel2 = new byte[size2]; + + ModelDao modelDao = mock(ModelDao.class); + when(modelDao.get(modelId1)).thenReturn(mockModel1); + when(modelDao.get(modelId2)).thenReturn(mockModel2); + + + Settings settings = Settings.builder().put(MODEL_CACHE_SIZE_IN_BYTES_SETTING.getKey(), cacheSize).build(); + ClusterSettings clusterSettings = new ClusterSettings(settings, + ImmutableSet.of(MODEL_CACHE_SIZE_IN_BYTES_SETTING)); + ClusterService clusterService = mock(ClusterService.class); + when(clusterService.getClusterSettings()).thenReturn(clusterSettings); + when(clusterService.getSettings()).thenReturn(settings); + + ModelCache.initialize(modelDao, clusterService); + ModelCache modelCache = new ModelCache(); + + modelCache.get(modelId1); + modelCache.get(modelId2); + modelCache.get(modelId1); + modelCache.get(modelId2); + + assertEquals(size1 + size2, modelCache.getTotalWeight()); + } + + public void testRemove_normal() throws ExecutionException, InterruptedException { + String modelId1 = "test-model-id-1"; + String modelId2 = "test-model-id-2"; + long cacheSize = 500L; + + int size1 = 100; + byte[] mockModel1 = new byte[size1]; + int size2 = 300; + byte[] mockModel2 = new byte[size2]; + + ModelDao modelDao = mock(ModelDao.class); + when(modelDao.get(modelId1)).thenReturn(mockModel1); + when(modelDao.get(modelId2)).thenReturn(mockModel2); + + Settings settings = Settings.builder().put(MODEL_CACHE_SIZE_IN_BYTES_SETTING.getKey(), cacheSize).build(); + ClusterSettings clusterSettings = new ClusterSettings(settings, + ImmutableSet.of(MODEL_CACHE_SIZE_IN_BYTES_SETTING)); + ClusterService clusterService = mock(ClusterService.class); + when(clusterService.getClusterSettings()).thenReturn(clusterSettings); + when(clusterService.getSettings()).thenReturn(settings); + + ModelCache.initialize(modelDao, clusterService); + ModelCache modelCache = new ModelCache(); + + modelCache.get(modelId1); + modelCache.get(modelId2); + modelCache.get(modelId1); + 
modelCache.get(modelId2); + + assertEquals(size1 + size2, modelCache.getTotalWeight()); + + modelCache.remove(modelId1); + + assertEquals( size2, modelCache.getTotalWeight()); + + modelCache.remove(modelId2); + + assertEquals( 0, modelCache.getTotalWeight()); + } + + public void testRebuild_normal() throws ExecutionException, InterruptedException { + String modelId = "test-model-id"; + long cacheSize = 100L; + byte[] mockModel = "hello".getBytes(); + + ModelDao modelDao = mock(ModelDao.class); + when(modelDao.get(modelId)).thenReturn(mockModel); + + Settings settings = Settings.builder().put(MODEL_CACHE_SIZE_IN_BYTES_SETTING.getKey(), cacheSize).build(); + ClusterSettings clusterSettings = new ClusterSettings(settings, + ImmutableSet.of(MODEL_CACHE_SIZE_IN_BYTES_SETTING)); + ClusterService clusterService = mock(ClusterService.class); + when(clusterService.getClusterSettings()).thenReturn(clusterSettings); + when(clusterService.getSettings()).thenReturn(settings); + + ModelCache.initialize(modelDao, clusterService); + ModelCache modelCache = new ModelCache(); + + // Add element to cache - nothing should be kept + modelCache.get(modelId); + assertEquals(mockModel.length, modelCache.getTotalWeight()); + + // Rebuild and make sure cache is empty + modelCache.rebuild(); + assertEquals(0, modelCache.getTotalWeight()); + + // Add element again + modelCache.get(modelId); + assertEquals(mockModel.length, modelCache.getTotalWeight()); + } + + public void testRebuild_afterSettingUpdate() throws ExecutionException, InterruptedException { + String modelId = "test-model-id"; + + int modelSize = 101; + byte[] mockModel = new byte[modelSize]; + + long cacheSize1 = 100L; + long cacheSize2 = 200L; + + ModelDao modelDao = mock(ModelDao.class); + when(modelDao.get(modelId)).thenReturn(mockModel); + + Settings settings = Settings.builder().put(MODEL_CACHE_SIZE_IN_BYTES_SETTING.getKey(), cacheSize1).build(); + ClusterSettings clusterSettings = new ClusterSettings(settings, + 
ImmutableSet.of(MODEL_CACHE_SIZE_IN_BYTES_SETTING)); + ClusterService clusterService = mock(ClusterService.class); + when(clusterService.getClusterSettings()).thenReturn(clusterSettings); + when(clusterService.getSettings()).thenReturn(settings); + + ModelCache.initialize(modelDao, clusterService); + ModelCache modelCache = new ModelCache(); + + // Add element to cache - element should not remain in cache + modelCache.get(modelId); + assertEquals(0, modelCache.getTotalWeight()); + + // Apply the larger cache size setting and make sure cache is still empty + Settings newSettings = Settings.builder().put(MODEL_CACHE_SIZE_IN_BYTES_SETTING.getKey(), cacheSize2).build(); + clusterService.getClusterSettings().applySettings(newSettings); + assertEquals(0, modelCache.getTotalWeight()); + + // Add element again - element should remain in cache + modelCache.get(modelId); + assertEquals(modelSize, modelCache.getTotalWeight()); + } + + public void testRemove_modelNotInCache() { + String modelId1 = "test-model-id-1"; + long cacheSize = 100L; + + ModelDao modelDao = mock(ModelDao.class); + + Settings settings = Settings.builder().put(MODEL_CACHE_SIZE_IN_BYTES_SETTING.getKey(), cacheSize).build(); + ClusterSettings clusterSettings = new ClusterSettings(settings, + ImmutableSet.of(MODEL_CACHE_SIZE_IN_BYTES_SETTING)); + ClusterService clusterService = mock(ClusterService.class); + when(clusterService.getClusterSettings()).thenReturn(clusterSettings); + when(clusterService.getSettings()).thenReturn(settings); + + ModelCache.initialize(modelDao, clusterService); + ModelCache modelCache = new ModelCache(); + + assertEquals( 0, modelCache.getTotalWeight()); + modelCache.remove(modelId1); + assertEquals( 0, modelCache.getTotalWeight()); + } + + public void testContains() throws ExecutionException, InterruptedException { + String modelId1 = "test-model-id-1"; + int modelSize1 = 100; + byte[] mockModel1 = new byte[modelSize1]; + + String modelId2 = "test-model-id-2"; + + long cacheSize = 500L; + + ModelDao modelDao =
mock(ModelDao.class); + when(modelDao.get(modelId1)).thenReturn(mockModel1); + + Settings settings = Settings.builder().put(MODEL_CACHE_SIZE_IN_BYTES_SETTING.getKey(), cacheSize).build(); + ClusterSettings clusterSettings = new ClusterSettings(settings, + ImmutableSet.of(MODEL_CACHE_SIZE_IN_BYTES_SETTING)); + ClusterService clusterService = mock(ClusterService.class); + when(clusterService.getClusterSettings()).thenReturn(clusterSettings); + when(clusterService.getSettings()).thenReturn(settings); + + ModelCache.initialize(modelDao, clusterService); + ModelCache modelCache = new ModelCache(); + + assertFalse(modelCache.contains(modelId1)); + modelCache.get(modelId1); + assertTrue(modelCache.contains(modelId1)); + assertFalse(modelCache.contains(modelId2)); + } + + public void testRemoveAll() throws ExecutionException, InterruptedException { + String modelId1 = "test-model-id-1"; + int modelSize1 = 100; + byte[] mockModel1 = new byte[modelSize1]; + + String modelId2 = "test-model-id-2"; + int modelSize2 = 100; + byte[] mockModel2 = new byte[modelSize2]; + + long cacheSize = 500L; + + ModelDao modelDao = mock(ModelDao.class); + when(modelDao.get(modelId1)).thenReturn(mockModel1); + when(modelDao.get(modelId2)).thenReturn(mockModel2); + + Settings settings = Settings.builder().put(MODEL_CACHE_SIZE_IN_BYTES_SETTING.getKey(), cacheSize).build(); + ClusterSettings clusterSettings = new ClusterSettings(settings, + ImmutableSet.of(MODEL_CACHE_SIZE_IN_BYTES_SETTING)); + ClusterService clusterService = mock(ClusterService.class); + when(clusterService.getClusterSettings()).thenReturn(clusterSettings); + when(clusterService.getSettings()).thenReturn(settings); + + ModelCache.initialize(modelDao, clusterService); + ModelCache modelCache = new ModelCache(); + + modelCache.get(modelId1); + modelCache.get(modelId2); + + assertEquals( modelSize1 + modelSize2, modelCache.getTotalWeight()); + modelCache.removeAll(); + assertEquals( 0, modelCache.getTotalWeight()); + } +} diff --git 
a/src/test/java/org/opensearch/knn/indices/ModelIndexTests.java b/src/test/java/org/opensearch/knn/indices/ModelDaoTests.java similarity index 80% rename from src/test/java/org/opensearch/knn/indices/ModelIndexTests.java rename to src/test/java/org/opensearch/knn/indices/ModelDaoTests.java index 13661d4eb1..8720ef09f0 100644 --- a/src/test/java/org/opensearch/knn/indices/ModelIndexTests.java +++ b/src/test/java/org/opensearch/knn/indices/ModelDaoTests.java @@ -35,10 +35,10 @@ import static org.opensearch.knn.common.KNNConstants.MODEL_BLOB_PARAMETER; import static org.opensearch.knn.common.KNNConstants.MODEL_INDEX_NAME; -public class ModelIndexTests extends KNNSingleNodeTestCase { +public class ModelDaoTests extends KNNSingleNodeTestCase { public void testCreate() throws IOException, InterruptedException { - int attempts = 20; + int attempts = 3; final CountDownLatch inProgressLatch = new CountDownLatch(attempts); ActionListener indexCreationListener = ActionListener.wrap(response -> { @@ -51,24 +51,24 @@ public void testCreate() throws IOException, InterruptedException { inProgressLatch.countDown(); }); - ModelIndex modelIndex = ModelIndex.getInstance(); + ModelDao modelDao = ModelDao.OpenSearchKNNModelDao.getInstance(); for (int i = 0; i < attempts; i++) { - modelIndex.create(indexCreationListener); + modelDao.create(indexCreationListener); } assertTrue(inProgressLatch.await(100, TimeUnit.SECONDS)); } public void testExists() { - ModelIndex modelIndex = ModelIndex.getInstance(); - assertFalse(modelIndex.isCreated()); + ModelDao modelDao = ModelDao.OpenSearchKNNModelDao.getInstance(); + assertFalse(modelDao.isCreated()); createIndex(MODEL_INDEX_NAME); - assertTrue(modelIndex.isCreated()); + assertTrue(modelDao.isCreated()); } - public void testPut_withId() throws InterruptedException { - ModelIndex modelIndex = ModelIndex.getInstance(); + public void testPut_withId() throws InterruptedException, IOException { + ModelDao modelDao = 
ModelDao.OpenSearchKNNModelDao.getInstance(); String modelId = "efbsdhcvbsd"; byte [] modelBlob = "hello".getBytes(); @@ -82,7 +82,7 @@ public void testPut_withId() throws InterruptedException { inProgressLatch1.countDown(); }, exception -> fail("Unable to put the model: " + exception)); - modelIndex.put(modelId, KNNEngine.DEFAULT, modelBlob, docCreationListener); + modelDao.put(modelId, KNNEngine.DEFAULT, modelBlob, docCreationListener); assertTrue(inProgressLatch1.await(100, TimeUnit.SECONDS)); @@ -97,12 +97,12 @@ public void testPut_withId() throws InterruptedException { inProgressLatch2.countDown(); }); - modelIndex.put(modelId, KNNEngine.DEFAULT, modelBlob, docCreationListenerDuplicateId); + modelDao.put(modelId, KNNEngine.DEFAULT, modelBlob, docCreationListenerDuplicateId); assertTrue(inProgressLatch2.await(100, TimeUnit.SECONDS)); } - public void testPut_withoutId() throws InterruptedException { - ModelIndex modelIndex = ModelIndex.getInstance(); + public void testPut_withoutId() throws InterruptedException, IOException { + ModelDao modelDao = ModelDao.OpenSearchKNNModelDao.getInstance(); byte [] modelBlob = "hello".getBytes(); createIndex(MODEL_INDEX_NAME); @@ -115,34 +115,34 @@ public void testPut_withoutId() throws InterruptedException { }, exception -> fail("Unable to put the model: " + exception)); - modelIndex.put(KNNEngine.DEFAULT, modelBlob, docCreationListenerNoModelId); + modelDao.put(KNNEngine.DEFAULT, modelBlob, docCreationListenerNoModelId); assertTrue(inProgressLatch.await(100, TimeUnit.SECONDS)); } public void testGet() throws IOException, InterruptedException, ExecutionException { - ModelIndex modelIndex = ModelIndex.getInstance(); + ModelDao modelDao = ModelDao.OpenSearchKNNModelDao.getInstance(); String modelId = "efbsdhcvbsd"; byte[] modelBlob = "hello".getBytes(); // model index doesnt exist - expectThrows(IllegalStateException.class, () -> modelIndex.get(modelId)); + expectThrows(ExecutionException.class, () -> modelDao.get(modelId)); // 
model id doesnt exist createIndex(MODEL_INDEX_NAME); - expectThrows(Exception.class, () -> modelIndex.get(modelId)); + expectThrows(Exception.class, () -> modelDao.get(modelId)); // model id exists addDoc(modelId, modelBlob); - assertArrayEquals(modelBlob, modelIndex.get(modelId)); + assertArrayEquals(modelBlob, modelDao.get(modelId)); } public void testDelete() throws IOException, InterruptedException, ExecutionException { - ModelIndex modelIndex = ModelIndex.getInstance(); + ModelDao modelDao = ModelDao.OpenSearchKNNModelDao.getInstance(); String modelId = "efbsdhcvbsd"; byte[] modelBlob = "hello".getBytes(); - // model index doesnt exist - expectThrows(IllegalStateException.class, () -> modelIndex.delete(modelId, null)); + // model index doesn't exist --> nothing should happen + modelDao.delete(modelId, null); // model id doesnt exist createIndex(MODEL_INDEX_NAME); @@ -153,7 +153,7 @@ public void testDelete() throws IOException, InterruptedException, ExecutionExce inProgressLatch1.countDown(); }, exception -> fail("Unable to delete the model: " + exception)); - modelIndex.delete(modelId, deleteModelDoesNotExistListener); + modelDao.delete(modelId, deleteModelDoesNotExistListener); assertTrue(inProgressLatch1.await(100, TimeUnit.SECONDS)); // model id exists @@ -165,7 +165,7 @@ public void testDelete() throws IOException, InterruptedException, ExecutionExce inProgressLatch2.countDown(); }, exception -> fail("Unable to delete model: " + exception)); - modelIndex.delete(modelId, deleteModelExistsListener); + modelDao.delete(modelId, deleteModelExistsListener); assertTrue(inProgressLatch2.await(100, TimeUnit.SECONDS)); }