262 changes: 252 additions & 10 deletions fbgemm_gpu/src/dram_kv_embedding_cache/dram_kv_embedding_cache.h
@@ -105,6 +105,7 @@ class DramKVEmbeddingCache : public kv_db::EmbeddingKVDB {
int64_t num_shards = 8,
int64_t num_threads = 32,
int64_t row_storage_bitwidth = 32,
bool backend_return_whole_row = false,
bool enable_async_update = false,
std::optional<at::Tensor> table_dims = std::nullopt,
std::optional<at::Tensor> hash_size_cumsum = std::nullopt)
@@ -126,6 +127,7 @@ class DramKVEmbeddingCache : public kv_db::EmbeddingKVDB {
block_alignment_,
/*blocks_per_chunk=*/8192)),
elem_size_(row_storage_bitwidth / 8),
backend_return_whole_row_(backend_return_whole_row),
feature_evict_config_(feature_evict_config) {
executor_ = std::make_unique<folly::CPUThreadPoolExecutor>(std::max<size_t>(
num_threads, facebook::Proc::getCpuInfo().numCpuCores));
@@ -608,11 +610,15 @@ class DramKVEmbeddingCache : public kv_db::EmbeddingKVDB {
void set_range_to_storage(
const at::Tensor& weights,
const int64_t start,
const int64_t length) override {
if (backend_return_whole_row_) {
set_kv_with_metaheader_to_storage(weights);
} else {
const auto seq_indices = at::arange(
start, start + length, at::TensorOptions().dtype(at::kLong));
const auto count = at::tensor({length}, at::ScalarType::Long);
folly::coro::blockingWait(set_kv_db_async(seq_indices, weights, count));
}
}

void get_range_from_snapshot(
@@ -625,10 +631,16 @@ class DramKVEmbeddingCache : public kv_db::EmbeddingKVDB {
CHECK(snapshot_handle == nullptr);
const auto seq_indices =
at::arange(start, start + length, at::TensorOptions().dtype(at::kLong));

if (backend_return_whole_row_) {
get_kv_with_metaheader_from_storage(seq_indices, weights);
} else {
const auto count = at::tensor({length}, at::ScalarType::Long);
get_kv_db_async_impl(
seq_indices, weights, count, width_offset, width_length)
.wait();
}

// this is called by checkpoint mostly, and checkpoint should wait until
// eviction finishes so that we could reach a consistent state before/after
// state_dict() calls
@@ -642,8 +654,41 @@ class DramKVEmbeddingCache : public kv_db::EmbeddingKVDB {
int64_t width_offset = 0,
std::optional<int64_t> width_length = std::nullopt) override {
CHECK(snapshot_handle == nullptr);

if (backend_return_whole_row_) {
get_kv_with_metaheader_from_storage(
ids, weights, width_offset, width_length);
} else {
const auto count = at::tensor({ids.size(0)}, at::ScalarType::Long);
get_kv_db_async_impl(ids, weights, count, width_offset, width_length)
.wait();
}
}

// Used for checkpointing: read KV rows, including the metaheader, from storage
void get_kv_with_metaheader_from_storage(
const at::Tensor& ids,
const at::Tensor& weights_with_metaheader,
int64_t width_offset = 0,
std::optional<int64_t> width_length = std::nullopt) {
const auto count = at::tensor({ids.size(0)}, at::ScalarType::Long);
get_kv_db_with_metaheader_async_impl(
ids, weights_with_metaheader, count, width_offset, width_length)
.wait();
}
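  // Usage sketch (illustrative only, not part of this PR; buffer sizing and
  // weights_dtype are assumptions): a checkpoint reader that wants whole
  // rows, including the metaheader, might do
  //   auto out = at::zeros(
  //       {ids.size(0), get_metaheader_width_in_front() + get_max_D()},
  //       at::TensorOptions().dtype(weights_dtype));
  //   get_kv_with_metaheader_from_storage(ids, out);
  // where weights_dtype is assumed to match weight_type.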

void set_kv_with_metaheader_to_storage(
const at::Tensor& weights_with_metaheader) {
std::vector<int64_t> keys(weights_with_metaheader.size(0), 0);
for (int64_t i = 0; i < weights_with_metaheader.size(0); ++i) {
keys[i] = FixedBlockPool::get_key(weights_with_metaheader[i].data_ptr());
}
auto indices =
torch::from_blob(keys.data(), {int64_t(keys.size())}, torch::kInt64);
const auto count =
at::tensor({weights_with_metaheader.size(0)}, at::ScalarType::Long);
set_kv_db_with_metaheader_async_impl(
indices, weights_with_metaheader, count)
.wait();
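    // Note: unlike set_range_to_storage, no caller-provided indices are
    // needed; each row's key is recovered from its own metaheader via
    // FixedBlockPool::get_key in the loop above.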
// this is called by checkpoint mostly, and checkpoint should wait until
// eviction finishes so that we could reach a consistent state before/after
@@ -826,6 +871,16 @@ class DramKVEmbeddingCache : public kv_db::EmbeddingKVDB {

void flush_or_compact(const int64_t timestep) override {}

bool get_backend_return_whole_row() override {
return backend_return_whole_row_;
}

int64_t get_metaheader_width_in_front() override {
return backend_return_whole_row_
? FixedBlockPool::get_metaheader_dim<weight_type>()
: 0;
}
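  // Example (a sketch, assuming a hypothetical 16-byte MetaHeader and fp16
  // weights): get_metaheader_width_in_front() == 16 / sizeof(at::Half) == 8,
  // so each returned row is laid out as
  //   [8 metaheader elements | embedding | optimizer state].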

void resume_ongoing_eviction() override {
if (feature_evict_) {
feature_evict_->resume();
@@ -930,6 +985,192 @@ class DramKVEmbeddingCache : public kv_db::EmbeddingKVDB {
return ret;
}

/// Get embeddings, including their metaheaders, from the kvstore.
///
/// @param indices The 1D embedding index tensor; negative values are
/// skipped.
/// @param weights_with_metaheader The 2D tensor in which each row
/// (metaheader plus embedding) is paired with the corresponding element
/// of <indices>. It is filled with the rows returned from the kvstore.
/// @param count A single-element tensor containing the number of indices
/// to be processed.
///
/// @return None
folly::SemiFuture<std::vector<folly::Unit>>
get_kv_db_with_metaheader_async_impl(
const at::Tensor& indices,
const at::Tensor& weights_with_metaheader,
const at::Tensor& count,
int64_t width_offset = 0,
std::optional<int64_t> width_length = std::nullopt) {
std::vector<folly::Future<folly::Unit>> futures;
auto row_width = weights_with_metaheader.size(1);
auto copy_width = width_length.value_or(row_width);
CHECK_LE(row_width, block_size_);
CHECK_EQ(copy_width, row_width);
auto shardid_to_indexes = shard_input(indices, count);

for (auto iter = shardid_to_indexes.begin();
iter != shardid_to_indexes.end();
iter++) {
const auto shard_id = iter->first;
const auto indexes = iter->second;
auto f =
folly::via(executor_.get())
.thenValue([this,
shard_id,
indexes,
&indices,
&weights_with_metaheader,
width_offset,
row_width](folly::Unit) {
FBGEMM_DISPATCH_INTEGRAL_TYPES(
indices.scalar_type(),
"dram_kvstore_get_with_metaheader",
[this,
shard_id,
indexes,
&indices,
&weights_with_metaheader,
width_offset,
row_width] {
using index_t = scalar_t;
CHECK(indices.is_contiguous());
CHECK(weights_with_metaheader.is_contiguous());
CHECK_EQ(
indices.size(0), weights_with_metaheader.size(0));
auto wlmap = kv_store_.by(shard_id).wlock();
auto indices_data_ptr = indices.data_ptr<index_t>();
auto weights_data_ptr =
weights_with_metaheader.data_ptr<weight_type>();
{
for (auto index_iter = indexes.begin();
index_iter != indexes.end();
index_iter++) {
const auto weights_row_index = *index_iter;
auto weight_idx =
int64_t(indices_data_ptr[weights_row_index]);
const auto cached_iter = wlmap->find(weight_idx);
// Defensive programming
// it shouldn't occur under normal circumstances
if (cached_iter == wlmap->end()) {
std::memset(
    &(weights_data_ptr
          [weights_row_index * row_width]),
    0,
    // memset counts bytes, not elements: zero the whole row
    row_width * sizeof(weight_type));
continue;
}

// For weight KVT, offset=0 and it will read the whole
// row. For optimizer, offset=dim(metaheader) +
// emb_dim so it will only read the optimizer part
const auto* ptr_offset_from_front =
FixedBlockPool::ptr_offset_from_front<
weight_type>(
cached_iter->second, width_offset);
std::copy(
ptr_offset_from_front,
ptr_offset_from_front + row_width,
&(weights_data_ptr
[weights_row_index * row_width]));
}
}
});
});
futures.push_back(std::move(f));
}
return folly::collect(futures);
}

/// Insert embeddings, including their metaheaders, into the kvstore.
/// The underlying memory is currently managed through an F14FastMap;
/// key-value pairs are sharded across multiple shards to increase
/// parallelism.
///
/// @param indices The 1D embedding index tensor; negative values are
/// skipped.
/// @param weights_with_metaheader The 2D tensor in which each row
/// (metaheader plus embedding) is paired with the corresponding element
/// of <indices>.
/// @param count A single-element tensor containing the number of indices
/// to be processed.
///
/// @return None
folly::SemiFuture<std::vector<folly::Unit>>
set_kv_db_with_metaheader_async_impl(
const at::Tensor& indices,
const at::Tensor& weights_with_metaheader,
const at::Tensor& count) {
std::vector<folly::Future<folly::Unit>> futures;
auto shardid_to_indexes = shard_input(indices, count);
for (auto iter = shardid_to_indexes.begin();
iter != shardid_to_indexes.end();
iter++) {
const auto shard_id = iter->first;
const auto indexes = iter->second;
auto f =
folly::via(executor_.get())
.thenValue(
[this, shard_id, indexes, &indices, &weights_with_metaheader](
folly::Unit) {
FBGEMM_DISPATCH_INTEGRAL_TYPES(
indices.scalar_type(),
"dram_kv_set_with_metaheader",
[this,
shard_id,
indexes,
&indices,
&weights_with_metaheader] {
using index_t = scalar_t;
CHECK(indices.is_contiguous());
CHECK(weights_with_metaheader.is_contiguous());
CHECK_EQ(
indices.size(0), weights_with_metaheader.size(0));
{
auto wlmap = kv_store_.by(shard_id).wlock();
auto* pool = kv_store_.pool_by(shard_id);
int64_t stride = weights_with_metaheader.size(1);
auto indices_data_ptr = indices.data_ptr<index_t>();
auto weights_data_ptr =
weights_with_metaheader.data_ptr<weight_type>();
for (auto index_iter = indexes.begin();
index_iter != indexes.end();
index_iter++) {
const auto& id_index = *index_iter;
auto id = int64_t(indices_data_ptr[id_index]);
// Defensive programming
// it shouldn't occur under normal circumstances
auto used = FixedBlockPool::get_used(
weights_data_ptr + id_index * stride);
if (!used) {
continue;
}
// use mempool
weight_type* block = nullptr;
// First check if the key already exists
auto it = wlmap->find(id);
if (it != wlmap->end()) {
block = it->second;
} else {
// Key doesn't exist, allocate new block and
// insert.
block =
pool->template allocate_t<weight_type>();
wlmap->insert({id, block});
}
std::copy(
weights_data_ptr + id_index * stride,
weights_data_ptr + (id_index + 1) * stride,
block);
}
}
});
});
futures.push_back(std::move(f));
}
return folly::collect(futures);
}

std::unique_ptr<folly::CPUThreadPoolExecutor> executor_;
// background thread
folly::FunctionScheduler scheduler_;
@@ -942,6 +1183,7 @@ class DramKVEmbeddingCache : public kv_db::EmbeddingKVDB {
std::atomic_bool is_eviction_ongoing_ = false;
std::vector<std::unique_ptr<ssd::Initializer>> initializers_;
int64_t elem_size_;
bool backend_return_whole_row_;
std::vector<int64_t> sub_table_dims_;
std::vector<int64_t> sub_table_hash_cumsum_;
std::optional<c10::intrusive_ptr<FeatureEvictConfig>> feature_evict_config_;
@@ -37,6 +37,7 @@ class DramKVEmbeddingCacheWrapper : public torch::jit::CustomClassHolder {
int64_t row_storage_bitwidth = 32,
const std::optional<at::Tensor>& table_dims = std::nullopt,
const std::optional<at::Tensor>& hash_size_cumsum = std::nullopt,
bool backend_return_whole_row = false,
bool enable_async_update = false) {
if (row_storage_bitwidth == 16) {
impl_ = std::make_shared<kv_mem::DramKVEmbeddingCache<at::Half>>(
@@ -47,6 +48,7 @@ class DramKVEmbeddingCacheWrapper : public torch::jit::CustomClassHolder {
num_shards,
num_threads,
row_storage_bitwidth,
backend_return_whole_row,
enable_async_update,
table_dims,
hash_size_cumsum);
@@ -59,6 +61,7 @@ class DramKVEmbeddingCacheWrapper : public torch::jit::CustomClassHolder {
num_shards,
num_threads,
row_storage_bitwidth,
backend_return_whole_row,
enable_async_update,
table_dims,
hash_size_cumsum);
22 changes: 22 additions & 0 deletions fbgemm_gpu/src/dram_kv_embedding_cache/fixed_block_pool.h
@@ -100,6 +100,12 @@ class FixedBlockPool : public std::pmr::memory_resource {
return std::max(alignof(FixedBlockPool::MetaHeader), alignof(scalar_t));
}

// Get the dimension of the MetaHeader, measured in scalar_t elements
template <typename scalar_t>
static size_t get_metaheader_dim() {
return sizeof(FixedBlockPool::MetaHeader) / sizeof(scalar_t);
}
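  // e.g., with a hypothetical 16-byte MetaHeader, get_metaheader_dim<float>()
  // would return 4 and get_metaheader_dim<at::Half>() would return 8.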

// Data pointer retrieval
template <typename scalar_t>
static scalar_t* data_ptr(scalar_t* block) {
@@ -114,6 +120,22 @@ class FixedBlockPool : public std::pmr::memory_resource {
sizeof(FixedBlockPool::MetaHeader));
}

template <typename scalar_t>
static scalar_t* ptr_offset_from_front(
scalar_t* block,
const int64_t offset) {
return reinterpret_cast<scalar_t*>(
reinterpret_cast<char*>(block) + offset * sizeof(scalar_t));
}

template <typename scalar_t>
static const scalar_t* ptr_offset_from_front(
const scalar_t* block,
const int64_t offset) {
return reinterpret_cast<const scalar_t*>(
reinterpret_cast<const char*>(block) + offset * sizeof(scalar_t));
}
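  // Sketch of intended use (emb_dim is illustrative): to read only the
  // optimizer slice of a whole row, skip past the metaheader and embedding:
  //   const auto* opt = FixedBlockPool::ptr_offset_from_front<at::Half>(
  //       block, FixedBlockPool::get_metaheader_dim<at::Half>() + emb_dim);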

template <typename scalar_t>
static scalar_t get_l2weight(scalar_t* block, size_t dimension) {
scalar_t* data = FixedBlockPool::data_ptr(block);
@@ -402,6 +402,16 @@ class EmbeddingKVDB : public std::enable_shared_from_this<EmbeddingKVDB> {
return max_D_;
}

virtual bool get_backend_return_whole_row() {
// only DRAM backend can enable this for now
return false;
}

virtual int64_t get_metaheader_width_in_front() {
// will return non-zero if DRAM enables backend_return_whole_row
return 0;
}
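  // Together these let generic checkpoint code size its buffers without
  // knowing the backend, e.g. (a sketch; get_max_D() is assumed to be the
  // getter for max_D_ above):
  //   const int64_t row_width =
  //       db.get_metaheader_width_in_front() + db.get_max_D();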

#ifdef FBGEMM_FBCODE
folly::coro::Task<void> tensor_stream(
const at::Tensor& indices,