Skip to content
15 changes: 14 additions & 1 deletion src/main/java/org/opensearch/knn/index/KNNSettings.java
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,7 @@ public class KNNSettings {
public static final String KNN_INDEX = "index.knn";
public static final String MODEL_INDEX_NUMBER_OF_SHARDS = "knn.model_index_number_of_shards";
public static final String MODEL_INDEX_NUMBER_OF_REPLICAS = "knn.model_index_number_of_replicas";
public static final String MODEL_CACHE_SIZE_IN_BYTES = "knn.model_cache.size_in_bytes";

/**
* Default setting values
Expand All @@ -100,6 +101,10 @@ public class KNNSettings {
public static final Integer INDEX_KNN_DEFAULT_ALGO_PARAM_EF_CONSTRUCTION = 512;
public static final Integer KNN_DEFAULT_ALGO_PARAM_INDEX_THREAD_QTY = 1;
public static final Integer KNN_DEFAULT_CIRCUIT_BREAKER_UNSET_PERCENTAGE = 75;
public static final Integer KNN_DEFAULT_MODEL_CACHE_SIZE_IN_BYTES = 50000000; // 50 Mb
public static final Integer KNN_MAX_MODEL_CACHE_SIZE_IN_BYTES = 80000000; // 80 Mb
public static final Integer KNN_MIN_MODEL_CACHE_SIZE_IN_BYTES = 0;


/**
* Settings Definition
Expand Down Expand Up @@ -158,6 +163,13 @@ public class KNNSettings {
Setting.Property.NodeScope,
Setting.Property.Dynamic);

public static final Setting<Long> MODEL_CACHE_SIZE_IN_BYTES_SETTING = Setting.longSetting(
MODEL_CACHE_SIZE_IN_BYTES,
KNN_DEFAULT_MODEL_CACHE_SIZE_IN_BYTES,
KNN_MIN_MODEL_CACHE_SIZE_IN_BYTES,
KNN_MAX_MODEL_CACHE_SIZE_IN_BYTES,
Setting.Property.NodeScope,
Setting.Property.Dynamic);

/**
* This setting identifies KNN index.
Expand Down Expand Up @@ -310,7 +322,8 @@ public List<Setting<?>> getSettings() {
KNN_CIRCUIT_BREAKER_UNSET_PERCENTAGE_SETTING,
IS_KNN_INDEX_SETTING,
MODEL_INDEX_NUMBER_OF_SHARDS_SETTING,
MODEL_INDEX_NUMBER_OF_REPLICAS_SETTING);
MODEL_INDEX_NUMBER_OF_REPLICAS_SETTING,
MODEL_CACHE_SIZE_IN_BYTES_SETTING);
return Stream.concat(settings.stream(), dynamicCacheSettings.values().stream())
.collect(Collectors.toList());
}
Expand Down
134 changes: 134 additions & 0 deletions src/main/java/org/opensearch/knn/indices/ModelCache.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,134 @@
/*
* SPDX-License-Identifier: Apache-2.0
*
* The OpenSearch Contributors require contributions made to
* this file be licensed under the Apache-2.0 license or a
* compatible open source license.
*
* Modifications Copyright OpenSearch Contributors. See
* GitHub history for details.
*/

package org.opensearch.knn.indices;

import com.google.common.cache.Cache;
import com.google.common.cache.CacheBuilder;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.opensearch.cluster.service.ClusterService;

import java.util.concurrent.ExecutionException;

import static org.opensearch.knn.index.KNNSettings.MODEL_CACHE_SIZE_IN_BYTES_SETTING;


public final class ModelCache {

private static Logger logger = LogManager.getLogger(ModelCache.class);

private static ModelCache instance;
private static ModelDao modelDao;
private static ClusterService clusterService;

private Cache<String, byte[]> cache;
private long cacheSizeInBytes;

/**
* Get instance of cache
*
* @return singleton instance of cache
*/
public static synchronized ModelCache getInstance() {
if (instance == null) {
instance = new ModelCache();
}
return instance;
}

/**
* Initialize the cache
*
* @param modelDao modelDao used to read persistence layer for models
* @param clusterService used to update settings
*/
public static void initialize(ModelDao modelDao, ClusterService clusterService) {
ModelCache.modelDao = modelDao;
ModelCache.clusterService = clusterService;
}

/**
* Evict all entries and rebuild the graph
*/
public synchronized void rebuild() {
cache.invalidateAll();
initCache();
}

protected ModelCache() {
cacheSizeInBytes = MODEL_CACHE_SIZE_IN_BYTES_SETTING.get(clusterService.getSettings());
clusterService.getClusterSettings().addSettingsUpdateConsumer(MODEL_CACHE_SIZE_IN_BYTES_SETTING, it -> {
cacheSizeInBytes = it;
rebuild();
});
initCache();
}

private void initCache() {
CacheBuilder<String, byte[]> cacheBuilder = CacheBuilder.newBuilder()
.recordStats()
.concurrencyLevel(1)
.maximumWeight(cacheSizeInBytes)
.weigher((k, v) -> v.length);

cache = cacheBuilder.build();
}

/**
* Get the model from modelId
*
* @param modelId model identifier
* @return byte array representing model
*/
public byte[] get(String modelId) {
try {
return cache.get(modelId, () -> modelDao.get(modelId));
} catch (ExecutionException ee) {
throw new IllegalStateException("Unable to retrieve model blob for \"" + modelId + "\": " + ee);
}
}

/**
* Get total weight of cache
*
* @return total weight
*/
public long getTotalWeight() {
return cache.asMap().values().stream().map(bytes -> (long) bytes.length).reduce(0L, Long::sum);
}

/**
* Remove modelId from cache
*
* @param modelId to be removed
*/
public void remove(String modelId) {
Comment thread
VijayanB marked this conversation as resolved.
cache.invalidate(modelId);
}

/**
* Check if modelId is in the cache
*
* @param modelId model id to be checked
* @return true if model id is in the cache; false otherwise
*/
public boolean contains(String modelId) {
return cache.asMap().containsKey(modelId);
}

/**
* Remove all elements from the cache
*/
public void removeAll() {
cache.invalidateAll();
}
}
Loading