202 changes: 202 additions & 0 deletions fbgemm_gpu/src/dram_kv_embedding_cache/dram_kv_embedding_cache.h
@@ -22,6 +22,7 @@
#include <thrift/lib/cpp2/protocol/CompactProtocol.h>
#include <thrift/lib/cpp2/protocol/Serializer.h>
#include <torch/script.h>
#include <cstring>

#include "../ssd_split_embeddings_cache/initializer.h"
#include "../ssd_split_embeddings_cache/kv_db_table_batched_embeddings.h"
@@ -483,6 +484,182 @@ class DramKVEmbeddingCache : public DramKVEmbeddingBase {
return result;
};

/// Get embeddings and metaheaders from the kvstore.
///
/// @param indices The 1D embedding index tensor; negative values are
/// skipped.
/// @param weights_with_metaheader The 2D tensor in which each row
/// (embedding plus metaheader) is paired with the corresponding element
/// in <indices>. It is filled with the embeddings and metaheaders
/// returned from the kvstore.
/// @param count A single-element tensor containing the number of indices
/// to be processed.
///
/// @return A SemiFuture that completes once all shards have been read.
folly::SemiFuture<std::vector<folly::Unit>>
get_kv_db_with_metaheader_async_impl(
const at::Tensor& indices,
const at::Tensor& weights_with_metaheader,
const at::Tensor& count) {
std::vector<folly::Future<folly::Unit>> futures;
auto row_width = weights_with_metaheader.size(1) *
weights_with_metaheader.element_size();
CHECK_EQ(row_width, block_size_);
auto shardid_to_indexes = shard_input(indices, count);

for (auto iter = shardid_to_indexes.begin();
iter != shardid_to_indexes.end();
iter++) {
const auto shard_id = iter->first;
const auto indexes = iter->second;
auto f =
folly::via(executor_.get())
.thenValue([this,
shard_id,
indexes,
&indices,
&weights_with_metaheader,
row_width](folly::Unit) {
FBGEMM_DISPATCH_INTEGRAL_TYPES(
indices.scalar_type(),
"dram_kvstore_set",
[this,
shard_id,
indexes,
&indices,
&weights_with_metaheader,
row_width] {
using index_t = scalar_t;
CHECK(indices.is_contiguous());
CHECK(weights_with_metaheader.is_contiguous());
CHECK_EQ(indices.size(0),
weights_with_metaheader.size(0));
auto wlmap = kv_store_.by(shard_id).wlock();
auto indices_data_ptr = indices.data_ptr<index_t>();
void* weights_data_ptr =
    weights_with_metaheader.mutable_data_ptr();
for (auto index_iter = indexes.begin();
     index_iter != indexes.end();
     index_iter++) {
  const auto weights_row_index = *index_iter;
  auto weight_idx =
      int64_t(indices_data_ptr[weights_row_index]);
  const auto cached_iter = wlmap->find(weight_idx);
  // Defensive: a missing key shouldn't occur under
  // normal circumstances; zero-fill the row so no
  // stale data is returned.
  if (cached_iter == wlmap->end()) {
    std::memset(static_cast<char*>(weights_data_ptr) +
                    weights_row_index * row_width,
                0,
                row_width);
    continue;
  }
  std::memcpy(static_cast<char*>(weights_data_ptr) +
                  weights_row_index * row_width,
              cached_iter->second,
              row_width);
}
});
});
futures.push_back(std::move(f));
}
return folly::collect(futures);
};
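/// Illustrative usage sketch, not part of this diff: `cache` is a
/// hypothetical DramKVEmbeddingCache instance. A byte tensor makes each
/// row exactly block_size_ bytes wide, which satisfies the
/// CHECK_EQ(row_width, block_size_) above.
///
///   at::Tensor ids = at::tensor({3, 7, 42}, at::kLong);
///   at::Tensor rows = at::empty(
///       {ids.size(0), static_cast<int64_t>(cache.get_block_size())},
///       at::kByte);
///   at::Tensor count = at::tensor({ids.size(0)}, at::kLong);
///   cache.get_kv_db_with_metaheader_async_impl(ids, rows, count).wait();
///   // Rows whose id is missing from the store come back zero-filled.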

/// Insert embeddings and metaheaders into the kvstore.
/// The underlying memory management is done through an F14FastMap.
/// Key-value pairs are sharded across multiple shards to increase
/// parallelism.
///
/// @param indices The 1D embedding index tensor; negative values are
/// skipped.
/// @param weights_with_metaheader The 2D tensor in which each row
/// (embedding plus metaheader) is paired with the corresponding element
/// in <indices>.
/// @param count A single-element tensor containing the number of indices
/// to be processed.
///
/// @return A SemiFuture that completes once all shards have been written.
folly::SemiFuture<std::vector<folly::Unit>>
set_kv_db_with_metaheader_async_impl(
const at::Tensor& indices,
const at::Tensor& weights_with_metaheader,
const at::Tensor& count) {
std::vector<folly::Future<folly::Unit>> futures;
auto shardid_to_indexes = shard_input(indices, count);
for (auto iter = shardid_to_indexes.begin();
iter != shardid_to_indexes.end();
iter++) {
const auto shard_id = iter->first;
const auto indexes = iter->second;
auto f =
folly::via(executor_.get())
.thenValue(
[this, shard_id, indexes, &indices, &weights_with_metaheader](
folly::Unit) {
FBGEMM_DISPATCH_INTEGRAL_TYPES(
indices.scalar_type(),
"dram_kv_set",
[this,
shard_id,
indexes,
&indices,
&weights_with_metaheader] {
using index_t = scalar_t;
CHECK(indices.is_contiguous());
CHECK(weights_with_metaheader.is_contiguous());
CHECK_EQ(indices.size(0),
weights_with_metaheader.size(0));
int64_t stride =
weights_with_metaheader.size(1) *
weights_with_metaheader.element_size();
CHECK_EQ(stride, block_size_);
auto indices_data_ptr = indices.data_ptr<index_t>();
void* weights_data_ptr =
weights_with_metaheader.data_ptr();
{
auto wlmap = kv_store_.by(shard_id).wlock();
auto* pool = kv_store_.pool_by(shard_id);

for (auto index_iter = indexes.begin();
index_iter != indexes.end();
index_iter++) {
const auto& id_index = *index_iter;
auto id = int64_t(indices_data_ptr[id_index]);
auto used = FixedBlockPool::get_used(
    static_cast<char*>(weights_data_ptr) +
    id_index * stride);
// Defensive: rows not marked "used" shouldn't occur
// under normal circumstances; skip them rather than
// inserting garbage into the store.
if (!used) {
  continue;
}
// Reuse the existing block if the key is already
// present; otherwise allocate a new block from the
// shard-local memory pool and insert it.
weight_type* block = nullptr;
auto it = wlmap->find(id);
if (it != wlmap->end()) {
  block = it->second;
} else {
  block = pool->template allocate_t<weight_type>();
  wlmap->insert({id, block});
}
std::memcpy(block,
static_cast<char*>(weights_data_ptr) +
id_index * stride,
block_size_);
}
}
});
});
futures.push_back(std::move(f));
}
return folly::collect(futures);
}
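/// Illustrative usage sketch, not part of this diff: rows passed to the
/// set path must already carry a valid metaheader, since the "used" flag
/// (and, in the wrapper below, the key) is read out of each row via
/// FixedBlockPool. `cache`, `ids`, and `rows` are hypothetical.
///
///   // rows: [n, block_size_] bytes per row, metaheader + payload
///   at::Tensor count = at::tensor({rows.size(0)}, at::ScalarType::Long);
///   cache.set_kv_db_with_metaheader_async_impl(ids, rows, count).wait();
///   // Rows whose metaheader is not marked "used" are skipped.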

folly::SemiFuture<std::vector<folly::Unit>> get_kv_db_async(
const at::Tensor& indices,
const at::Tensor& weights,
@@ -528,6 +705,31 @@ class DramKVEmbeddingCache : public DramKVEmbeddingBase {
.wait();
}

// Used for checkpointing: read key-value pairs (with metaheaders) from
// storage into <weights_with_metaheader>.
void get_kv_with_metaheader_from_storage(
const at::Tensor& ids, const at::Tensor& weights_with_metaheader) {
const auto count = at::tensor({ids.size(0)}, at::ScalarType::Long);
get_kv_db_with_metaheader_async_impl(ids, weights_with_metaheader, count)
.wait();
}

// Used for checkpointing: write rows (with metaheaders) back to storage;
// keys are recovered from each row's metaheader via FixedBlockPool.
void set_kv_with_metaheader_to_storage(
    const at::Tensor& weights_with_metaheader) {
std::vector<int64_t> keys(weights_with_metaheader.size(0), 0);
for (int64_t i = 0; i < weights_with_metaheader.size(0); ++i) {
keys[i] = FixedBlockPool::get_key(weights_with_metaheader[i].data_ptr());
}
auto indices =
torch::from_blob(keys.data(), {int64_t(keys.size())}, torch::kInt64);
const auto count =
at::tensor({weights_with_metaheader.size(0)}, at::ScalarType::Long);
set_kv_db_with_metaheader_async_impl(
indices, weights_with_metaheader, count)
.wait();
}
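/// Illustrative checkpoint round-trip through the two wrappers above,
/// not part of this diff (`cache` is hypothetical):
///
///   at::Tensor ids = at::tensor({1, 2, 3}, at::kLong);
///   at::Tensor rows = at::empty(
///       {ids.size(0), static_cast<int64_t>(cache.get_block_size())},
///       at::kByte);
///   cache.get_kv_with_metaheader_from_storage(ids, rows);  // dump
///   // ... persist `rows`, restore later ...
///   cache.set_kv_with_metaheader_to_storage(rows);
///   // set_* recovers each key from the row's metaheader, so no separate
///   // id tensor is needed on restore.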

size_t get_block_size() const { return block_size_; }

void compact() override {}

void trigger_feature_evict() {