diff --git a/fbgemm_gpu/src/dram_kv_embedding_cache/dram_kv_embedding_cache.h b/fbgemm_gpu/src/dram_kv_embedding_cache/dram_kv_embedding_cache.h
index f1c5fc1bbf..565582abd8 100644
--- a/fbgemm_gpu/src/dram_kv_embedding_cache/dram_kv_embedding_cache.h
+++ b/fbgemm_gpu/src/dram_kv_embedding_cache/dram_kv_embedding_cache.h
@@ -105,6 +105,7 @@ class DramKVEmbeddingCache : public kv_db::EmbeddingKVDB {
       int64_t num_shards = 8,
       int64_t num_threads = 32,
       int64_t row_storage_bitwidth = 32,
+      bool backend_return_whole_row = false,
       bool enable_async_update = false,
       std::optional<at::Tensor> table_dims = std::nullopt,
       std::optional<at::Tensor> hash_size_cumsum = std::nullopt)
@@ -126,6 +127,7 @@ class DramKVEmbeddingCache : public kv_db::EmbeddingKVDB {
             block_alignment_,
             /*blocks_per_chunk=*/8192)),
         elem_size_(row_storage_bitwidth / 8),
+        backend_return_whole_row_(backend_return_whole_row),
         feature_evict_config_(feature_evict_config) {
     executor_ = std::make_unique<folly::CPUThreadPoolExecutor>(std::max(
         num_threads, facebook::Proc::getCpuInfo().numCpuCores));
@@ -608,11 +610,15 @@ class DramKVEmbeddingCache : public kv_db::EmbeddingKVDB {
   void set_range_to_storage(
       const at::Tensor& weights,
       const int64_t start,
-      const int64_t length) {
-    const auto seq_indices =
-        at::arange(start, start + length, at::TensorOptions().dtype(at::kLong));
-    const auto count = at::tensor({length}, at::ScalarType::Long);
-    folly::coro::blockingWait(set_kv_db_async(seq_indices, weights, count));
+      const int64_t length) override {
+    if (backend_return_whole_row_) {
+      set_kv_with_metaheader_to_storage(weights);
+    } else {
+      const auto seq_indices = at::arange(
+          start, start + length, at::TensorOptions().dtype(at::kLong));
+      const auto count = at::tensor({length}, at::ScalarType::Long);
+      folly::coro::blockingWait(set_kv_db_async(seq_indices, weights, count));
+    }
   }
 
   void get_range_from_snapshot(
@@ -625,10 +631,16 @@ class DramKVEmbeddingCache : public kv_db::EmbeddingKVDB {
     CHECK(snapshot_handle == nullptr);
     const auto seq_indices =
        at::arange(start, start + length, at::TensorOptions().dtype(at::kLong));
-    const auto count = at::tensor({length}, at::ScalarType::Long);
-    get_kv_db_async_impl(
-        seq_indices, weights, count, width_offset, width_length)
-        .wait();
+
+    if (backend_return_whole_row_) {
+      get_kv_with_metaheader_from_storage(seq_indices, weights);
+    } else {
+      const auto count = at::tensor({length}, at::ScalarType::Long);
+      get_kv_db_async_impl(
+          seq_indices, weights, count, width_offset, width_length)
+          .wait();
+    }
+
     // this is called by checkpoint mostly, and checkpoint should wait until
     // eviction finishes so that we could reach a consistent state before/after
     // state_dict() calls
@@ -642,8 +654,41 @@ class DramKVEmbeddingCache : public kv_db::EmbeddingKVDB {
       int64_t width_offset = 0,
       std::optional<int64_t> width_length = std::nullopt) override {
     CHECK(snapshot_handle == nullptr);
+
+    if (backend_return_whole_row_) {
+      get_kv_with_metaheader_from_storage(
+          ids, weights, width_offset, width_length);
+    } else {
+      const auto count = at::tensor({ids.size(0)}, at::ScalarType::Long);
+      get_kv_db_async_impl(ids, weights, count, width_offset, width_length)
+          .wait();
+    }
+  }
+
+  // used for ckpt, get kv with metaheader from storage
+  void get_kv_with_metaheader_from_storage(
+      const at::Tensor& ids,
+      const at::Tensor& weights_with_metaheader,
+      int64_t width_offset = 0,
+      std::optional<int64_t> width_length = std::nullopt) {
     const auto count = at::tensor({ids.size(0)}, at::ScalarType::Long);
-    get_kv_db_async_impl(ids, weights, count, width_offset, width_length)
+    get_kv_db_with_metaheader_async_impl(
+        ids, weights_with_metaheader, count, width_offset, width_length)
+        .wait();
+  }
+
+  void set_kv_with_metaheader_to_storage(
+      const at::Tensor& weights_with_metaheader) {
+    std::vector<int64_t> keys(weights_with_metaheader.size(0), 0);
+    for (int64_t i = 0; i < weights_with_metaheader.size(0); ++i) {
+      keys[i] = FixedBlockPool::get_key(weights_with_metaheader[i].data_ptr());
+    }
+    auto indices =
+        torch::from_blob(keys.data(), {int64_t(keys.size())}, torch::kInt64);
+    const auto count =
+        at::tensor({weights_with_metaheader.size(0)}, at::ScalarType::Long);
+    set_kv_db_with_metaheader_async_impl(
+        indices, weights_with_metaheader, count)
         .wait();
     // this is called by checkpoint mostly, and checkpoint should wait until
     // eviction finishes so that we could reach a consistent state before/after
@@ -826,6 +871,16 @@ class DramKVEmbeddingCache : public kv_db::EmbeddingKVDB {
   void flush_or_compact(const int64_t timestep) override {}
 
+  bool get_backend_return_whole_row() override {
+    return backend_return_whole_row_;
+  }
+
+  int64_t get_metaheader_width_in_front() override {
+    return backend_return_whole_row_
+        ? FixedBlockPool::get_metaheader_dim<weight_type>()
+        : 0;
+  }
+
   void resume_ongoing_eviction() override {
     if (feature_evict_) {
       feature_evict_->resume();
@@ -930,6 +985,192 @@ class DramKVEmbeddingCache : public kv_db::EmbeddingKVDB {
     return ret;
   }
 
+  /// Get embeddings and metaheader from kvstore.
+  ///
+  /// @param indices The 1D embedding index tensor, should skip on negative
+  /// value
+  /// @param weights_with_metaheader The 2D tensor that each row (embeddings)
+  /// is paired up with the relative element in <indices>. This tensor will be
+  /// filled up with the returned embeddings from KVstore.
+  /// @param count A single element tensor that contains the number of indices
+  /// to be processed
+  ///
+  /// @return None
+  folly::SemiFuture<std::vector<folly::Unit>>
+  get_kv_db_with_metaheader_async_impl(
+      const at::Tensor& indices,
+      const at::Tensor& weights_with_metaheader,
+      const at::Tensor& count,
+      int64_t width_offset = 0,
+      std::optional<int64_t> width_length = std::nullopt) {
+    std::vector<folly::Future<folly::Unit>> futures;
+    auto row_width = weights_with_metaheader.size(1);
+    auto copy_width = width_length.value_or(row_width);
+    CHECK_LE(row_width, block_size_);
+    CHECK_EQ(copy_width, row_width);
+    auto shardid_to_indexes = shard_input(indices, count);
+
+    for (auto iter = shardid_to_indexes.begin();
+         iter != shardid_to_indexes.end();
+         iter++) {
+      const auto shard_id = iter->first;
+      const auto indexes = iter->second;
+      auto f =
+          folly::via(executor_.get())
+              .thenValue([this,
+                          shard_id,
+                          indexes,
+                          &indices,
+                          &weights_with_metaheader,
+                          width_offset,
+                          row_width](folly::Unit) {
+                FBGEMM_DISPATCH_INTEGRAL_TYPES(
+                    indices.scalar_type(),
+                    "dram_kvstore_get_with_metaheader",
+                    [this,
+                     shard_id,
+                     indexes,
+                     &indices,
+                     &weights_with_metaheader,
+                     width_offset,
+                     row_width] {
+                      using index_t = scalar_t;
+                      CHECK(indices.is_contiguous());
+                      CHECK(weights_with_metaheader.is_contiguous());
+                      CHECK_EQ(
+                          indices.size(0), weights_with_metaheader.size(0));
+                      auto wlmap = kv_store_.by(shard_id).wlock();
+                      auto indices_data_ptr = indices.data_ptr<index_t>();
+                      auto weights_data_ptr =
+                          weights_with_metaheader.data_ptr<weight_type>();
+                      {
+                        for (auto index_iter = indexes.begin();
+                             index_iter != indexes.end();
+                             index_iter++) {
+                          const auto weights_row_index = *index_iter;
+                          auto weight_idx =
+                              int64_t(indices_data_ptr[weights_row_index]);
+                          const auto cached_iter = wlmap->find(weight_idx);
+                          // Defensive programming
+                          // it shouldn't occur under normal circumstances
+                          if (cached_iter == wlmap->end()) {
+                            std::memset(
+                                &(weights_data_ptr
+                                      [weights_row_index * row_width]),
+                                0,
+                                row_width);
+                            continue;
+                          }
+
+                          // For weight KVT, offset=0 and it will read the whole
+                          // row. For optimizer, offset=dim(metaheader) +
+                          // emb_dim so it will only read the optimizer part
+                          const auto* ptr_offset_from_front =
+                              FixedBlockPool::ptr_offset_from_front<
+                                  weight_type>(
+                                  cached_iter->second, width_offset);
+                          std::copy(
+                              ptr_offset_from_front,
+                              ptr_offset_from_front + row_width,
+                              &(weights_data_ptr
+                                    [weights_row_index * row_width]));
+                        }
+                      }
+                    });
+              });
+      futures.push_back(std::move(f));
+    }
+    return folly::collect(futures);
+  }
+
+  /// insert embeddings and metaheader into kvstore.
+  /// current underlying memory management is done through F14FastMap
+  /// key value pair will be sharded into multiple shards to increase
+  /// parallelism.
+  ///
+  /// @param indices The 1D embedding index tensor, should skip on negative
+  /// value
+  /// @param weights_with_metaheader The 2D tensor that each row (embeddings
+  /// with metaheader) is paired up with the relative element in <indices>
+  /// @param count A single element tensor that contains the number of indices
+  /// to be processed
+  ///
+  /// @return None
+  folly::SemiFuture<std::vector<folly::Unit>>
+  set_kv_db_with_metaheader_async_impl(
+      const at::Tensor& indices,
+      const at::Tensor& weights_with_metaheader,
+      const at::Tensor& count) {
+    std::vector<folly::Future<folly::Unit>> futures;
+    auto shardid_to_indexes = shard_input(indices, count);
+    for (auto iter = shardid_to_indexes.begin();
+         iter != shardid_to_indexes.end();
+         iter++) {
+      const auto shard_id = iter->first;
+      const auto indexes = iter->second;
+      auto f =
+          folly::via(executor_.get())
+              .thenValue(
+                  [this, shard_id, indexes, &indices, &weights_with_metaheader](
+                      folly::Unit) {
+                    FBGEMM_DISPATCH_INTEGRAL_TYPES(
+                        indices.scalar_type(),
+                        "dram_kv_set_with_metaheader",
+                        [this,
+                         shard_id,
+                         indexes,
+                         &indices,
+                         &weights_with_metaheader] {
+                          using index_t = scalar_t;
+                          CHECK(indices.is_contiguous());
+                          CHECK(weights_with_metaheader.is_contiguous());
+                          CHECK_EQ(
+                              indices.size(0), weights_with_metaheader.size(0));
+                          {
+                            auto wlmap = kv_store_.by(shard_id).wlock();
+                            auto* pool = kv_store_.pool_by(shard_id);
+                            int64_t stride = weights_with_metaheader.size(1);
+                            auto indices_data_ptr = indices.data_ptr<index_t>();
+                            auto weights_data_ptr =
+                                weights_with_metaheader.data_ptr<weight_type>();
+                            for (auto index_iter = indexes.begin();
+                                 index_iter != indexes.end();
+                                 index_iter++) {
+                              const auto& id_index = *index_iter;
+                              auto id = int64_t(indices_data_ptr[id_index]);
+                              // Defensive programming
+                              // it shouldn't occur under normal circumstances
+                              auto used = FixedBlockPool::get_used(
+                                  weights_data_ptr + id_index * stride);
+                              if (!used) {
+                                continue;
+                              }
+                              // use mempool
+                              weight_type* block = nullptr;
+                              // First check if the key already exists
+                              auto it = wlmap->find(id);
+                              if (it != wlmap->end()) {
+                                block = it->second;
+                              } else {
+                                // Key doesn't exist, allocate new block and
+                                // insert.
+                                block =
+                                    pool->template allocate_t<weight_type>();
+                                wlmap->insert({id, block});
+                              }
+                              std::copy(
+                                  weights_data_ptr + id_index * stride,
+                                  weights_data_ptr + (id_index + 1) * stride,
+                                  block);
+                            }
+                          }
+                        });
+                  });
+      futures.push_back(std::move(f));
+    }
+    return folly::collect(futures);
+  }
+
   std::unique_ptr<folly::CPUThreadPoolExecutor> executor_;
   // background thread
   folly::FunctionScheduler scheduler_;
@@ -942,6 +1183,7 @@ class DramKVEmbeddingCache : public kv_db::EmbeddingKVDB {
   std::atomic_bool is_eviction_ongoing_ = false;
   std::vector<std::unique_ptr<ssd::Initializer>> initializers_;
   int64_t elem_size_;
+  bool backend_return_whole_row_;
   std::vector<int64_t> sub_table_dims_;
   std::vector<int64_t> sub_table_hash_cumsum_;
   std::optional<c10::intrusive_ptr<FeatureEvictConfig>> feature_evict_config_;
diff --git a/fbgemm_gpu/src/dram_kv_embedding_cache/dram_kv_embedding_cache_wrapper.h b/fbgemm_gpu/src/dram_kv_embedding_cache/dram_kv_embedding_cache_wrapper.h
index 77d1832a87..d396ad68e8 100644
--- a/fbgemm_gpu/src/dram_kv_embedding_cache/dram_kv_embedding_cache_wrapper.h
+++ b/fbgemm_gpu/src/dram_kv_embedding_cache/dram_kv_embedding_cache_wrapper.h
@@ -37,6 +37,7 @@ class DramKVEmbeddingCacheWrapper : public torch::jit::CustomClassHolder {
       int64_t row_storage_bitwidth = 32,
       const std::optional<at::Tensor>& table_dims = std::nullopt,
       const std::optional<at::Tensor>& hash_size_cumsum = std::nullopt,
+      bool backend_return_whole_row = false,
       bool enable_async_update = false) {
     if (row_storage_bitwidth == 16) {
       impl_ = std::make_shared<DramKVEmbeddingCache<at::Half>>(
@@ -47,6 +48,7 @@ class DramKVEmbeddingCacheWrapper : public torch::jit::CustomClassHolder {
           num_shards,
           num_threads,
           row_storage_bitwidth,
+          backend_return_whole_row,
           enable_async_update,
           table_dims,
           hash_size_cumsum);
@@ -59,6 +61,7 @@ class DramKVEmbeddingCacheWrapper : public torch::jit::CustomClassHolder {
           num_shards,
           num_threads,
           row_storage_bitwidth,
+          backend_return_whole_row,
           enable_async_update,
           table_dims,
           hash_size_cumsum);
diff --git a/fbgemm_gpu/src/dram_kv_embedding_cache/fixed_block_pool.h b/fbgemm_gpu/src/dram_kv_embedding_cache/fixed_block_pool.h
index 727de1ed01..c19f841292 100644
--- a/fbgemm_gpu/src/dram_kv_embedding_cache/fixed_block_pool.h
+++ b/fbgemm_gpu/src/dram_kv_embedding_cache/fixed_block_pool.h
@@ -100,6 +100,12 @@ class FixedBlockPool : public std::pmr::memory_resource {
     return std::max(alignof(FixedBlockPool::MetaHeader), alignof(scalar_t));
   }
 
+  // Get dimension of Metaheader
+  template <typename scalar_t>
+  static size_t get_metaheader_dim() {
+    return sizeof(FixedBlockPool::MetaHeader) / sizeof(scalar_t);
+  }
+
   // Data pointer retrieval
   template <typename scalar_t>
   static scalar_t* data_ptr(scalar_t* block) {
@@ -114,6 +120,22 @@ class FixedBlockPool : public std::pmr::memory_resource {
         sizeof(FixedBlockPool::MetaHeader));
   }
 
+  template <typename scalar_t>
+  static scalar_t* ptr_offset_from_front(
+      scalar_t* block,
+      const int64_t offset) {
+    return reinterpret_cast<scalar_t*>(
+        reinterpret_cast<char*>(block) + offset * sizeof(scalar_t));
+  }
+
+  template <typename scalar_t>
+  static const scalar_t* ptr_offset_from_front(
+      const scalar_t* block,
+      const int64_t offset) {
+    return reinterpret_cast<const scalar_t*>(
+        reinterpret_cast<const char*>(block) + offset * sizeof(scalar_t));
+  }
+
   template <typename scalar_t>
   static scalar_t get_l2weight(scalar_t* block, size_t dimension) {
     scalar_t* data = FixedBlockPool::data_ptr(block);
diff --git a/fbgemm_gpu/src/ssd_split_embeddings_cache/kv_db_table_batched_embeddings.h b/fbgemm_gpu/src/ssd_split_embeddings_cache/kv_db_table_batched_embeddings.h
index 932bca8ad0..ffd36fd012 100644
--- a/fbgemm_gpu/src/ssd_split_embeddings_cache/kv_db_table_batched_embeddings.h
+++ b/fbgemm_gpu/src/ssd_split_embeddings_cache/kv_db_table_batched_embeddings.h
@@ -402,6 +402,16 @@ class EmbeddingKVDB : public std::enable_shared_from_this<EmbeddingKVDB> {
     return max_D_;
   }
 
+  virtual bool get_backend_return_whole_row() {
+    // only DRAM backend can enable this for now
+    return false;
+  }
+
+  virtual int64_t get_metaheader_width_in_front() {
+    // will return non-zero if DRAM enables backend_return_whole_row
+    return 0;
+  }
+
 #ifdef FBGEMM_FBCODE
   folly::coro::Task<void> tensor_stream(
       const at::Tensor& indices,
diff --git a/fbgemm_gpu/src/ssd_split_embeddings_cache/kv_tensor_wrapper.h b/fbgemm_gpu/src/ssd_split_embeddings_cache/kv_tensor_wrapper.h
index 66d3ef8892..89ef37d02d 100644
--- a/fbgemm_gpu/src/ssd_split_embeddings_cache/kv_tensor_wrapper.h
+++ b/fbgemm_gpu/src/ssd_split_embeddings_cache/kv_tensor_wrapper.h
@@ -64,7 +64,8 @@ class KVTensorWrapper : public torch::jit::CustomClassHolder {
       std::optional<at::Tensor> sorted_indices = std::nullopt,
       int64_t width_offset = 0,
       const std::optional<c10::intrusive_ptr<RocksdbCheckpointHandleWrapper>>
-          checkpoint_handle = std::nullopt);
+          checkpoint_handle = std::nullopt,
+      bool read_only = false);
 
   explicit KVTensorWrapper(const std::string& serialized);
 
@@ -112,7 +113,8 @@ class KVTensorWrapper : public torch::jit::CustomClassHolder {
 
   std::string serialize() const;
 
-  // ONLY FOR DEBUGGING PURPOSES, Please don't use this function in production
+  // ONLY FOR DEBUGGING PURPOSES, Please don't use this function in
+  // production
   std::string logs() const;
 
   void deserialize(const std::string& serialized);
@@ -150,6 +152,7 @@ class KVTensorWrapper : public torch::jit::CustomClassHolder {
   int64_t num_threads{};
   int64_t max_D{};
   std::string checkpoint_uuid;
+  bool read_only_{};
 };
 
 void to_json(json& j, const KVTensorWrapper& kvt);
diff --git a/fbgemm_gpu/src/ssd_split_embeddings_cache/kv_tensor_wrapper_cpu.cpp b/fbgemm_gpu/src/ssd_split_embeddings_cache/kv_tensor_wrapper_cpu.cpp
index 7b2ba600f0..ee10494a88 100644
--- a/fbgemm_gpu/src/ssd_split_embeddings_cache/kv_tensor_wrapper_cpu.cpp
+++ b/fbgemm_gpu/src/ssd_split_embeddings_cache/kv_tensor_wrapper_cpu.cpp
@@ -37,7 +37,8 @@ KVTensorWrapper::KVTensorWrapper(
     [[maybe_unused]] const std::optional<at::Tensor> sorted_indices,
     [[maybe_unused]] int64_t width_offset,
     [[maybe_unused]] const std::optional<
-        c10::intrusive_ptr<RocksdbCheckpointHandleWrapper>>)
+        c10::intrusive_ptr<RocksdbCheckpointHandleWrapper>>,
+    [[maybe_unused]] bool read_only)
     // @lint-ignore CLANGTIDY clang-diagnostic-missing-noreturn
     : shape_(std::move(shape)), row_offset_(row_offset) {
   FBEXCEPTION("Not implemented");
diff --git a/fbgemm_gpu/src/ssd_split_embeddings_cache/ssd_split_table_batched_embeddings.cpp b/fbgemm_gpu/src/ssd_split_embeddings_cache/ssd_split_table_batched_embeddings.cpp
index b04bc3448c..f72c2b5377 100644
--- a/fbgemm_gpu/src/ssd_split_embeddings_cache/ssd_split_table_batched_embeddings.cpp
+++ b/fbgemm_gpu/src/ssd_split_embeddings_cache/ssd_split_table_batched_embeddings.cpp
@@ -356,11 +356,13 @@ KVTensorWrapper::KVTensorWrapper(
     std::optional<at::Tensor> sorted_indices,
     int64_t width_offset_,
     const std::optional<c10::intrusive_ptr<RocksdbCheckpointHandleWrapper>>
-        checkpoint_handle)
+        checkpoint_handle,
+    bool read_only)
     : db_(nullptr),
       shape_(std::move(shape)),
       row_offset_(row_offset),
-      width_offset_(width_offset_) {
+      width_offset_(width_offset_),
+      read_only_(read_only) {
   CHECK_GE(width_offset_, 0);
   CHECK_EQ(shape_.size(), 2) << "Only 2D emb tensors are supported";
   options_ = at::TensorOptions()
@@ -480,7 +482,8 @@ at::Tensor KVTensorWrapper::narrow(int64_t dim, int64_t start, int64_t length) {
   CHECK_EQ(dim, 0) << "Only narrow on dim 0 is supported";
   if (db_) {
     CHECK_TRUE(db_ != nullptr);
-    CHECK_GE(db_->get_max_D(), shape_[1]);
+    CHECK_GE(
+        db_->get_max_D() + db_->get_metaheader_width_in_front(), shape_[1]);
     TORCH_CHECK(
         (snapshot_handle_ == nullptr) ==
            (std::dynamic_pointer_cast<EmbeddingRocksDB>(db_).get() == nullptr),
@@ -521,14 +524,19 @@ void KVTensorWrapper::set_range(
     const int64_t start,
     const int64_t length,
     const at::Tensor& weights) {
+  if (read_only_) {
+    XLOG(INFO) << "KVTensorWrapper is read only, set_range() is no-op";
+    return;
+  }
   // Mutex lock for disabling concurrent writes to the same KVTensor
   std::lock_guard<std::mutex> lock(mtx);
   CHECK_EQ(weights.device(), at::kCPU);
   CHECK(db_) << "EmbeddingRocksDB must be a valid pointer to call set_range";
   CHECK_EQ(dim, 0) << "Only set_range on dim 0 is supported";
   CHECK_TRUE(db_ != nullptr);
-  CHECK_GE(db_->get_max_D(), shape_[1]);
-  int pad_right = db_->get_max_D() - weights.size(1);
+  CHECK_GE(db_->get_max_D() + db_->get_metaheader_width_in_front(), shape_[1]);
+  int pad_right =
+      db_->get_max_D() + db_->get_metaheader_width_in_front() - weights.size(1);
   if (pad_right == 0) {
     db_->set_range_to_storage(weights, start + row_offset_, length);
   } else {
@@ -542,6 +550,11 @@ void KVTensorWrapper::set_range(
 void KVTensorWrapper::set_weights_and_ids(
     const at::Tensor& weights,
     const at::Tensor& ids) {
+  if (read_only_) {
+    XLOG(INFO)
+        << "KVTensorWrapper is read only, set_weights_and_ids() is no-op";
+    return;
+  }
   CHECK_EQ(weights.device(), at::kCPU);
   CHECK_TRUE(db_ != nullptr);
   CHECK_EQ(ids.size(0), weights.size(0))
@@ -614,8 +627,10 @@ void from_json(const ssd::json& j, KVTensorWrapper& kvt) {
 
 at::Tensor KVTensorWrapper::get_weights_by_ids(const at::Tensor& ids) {
   CHECK_TRUE(db_ != nullptr);
-  CHECK_GE(db_->get_max_D(), shape_[1]);
-  CHECK_GE(db_->get_max_D(), shape_[1] + width_offset_);
+  CHECK_GE(db_->get_max_D() + db_->get_metaheader_width_in_front(), shape_[1]);
+  CHECK_GE(
+      db_->get_max_D() + db_->get_metaheader_width_in_front(),
+      shape_[1] + width_offset_);
   TORCH_CHECK(
       (snapshot_handle_ == nullptr) ==
          (std::dynamic_pointer_cast<EmbeddingRocksDB>(db_).get() == nullptr),
@@ -851,6 +866,7 @@ static auto dram_kv_embedding_cache_wrapper =
                 int64_t,
                 std::optional<at::Tensor>,
                 std::optional<at::Tensor>,
+                bool,
                 bool>(),
             "",
             {
@@ -863,6 +879,7 @@ static auto dram_kv_embedding_cache_wrapper =
                 torch::arg("row_storage_bitwidth") = 32,
                 torch::arg("table_dims") = std::nullopt,
                 torch::arg("hash_size_cumsum") = std::nullopt,
+                torch::arg("backend_return_whole_row") = false,
                 torch::arg("enable_async_update") = false,
             })
         .def(
@@ -948,7 +965,8 @@ static auto kv_tensor_wrapper =
                 std::optional<at::Tensor>,
                 int64_t,
                 std::optional<
-                    c10::intrusive_ptr<RocksdbCheckpointHandleWrapper>>>(),
+                    c10::intrusive_ptr<RocksdbCheckpointHandleWrapper>>,
+                bool>(),
             "",
             {torch::arg("shape"),
              torch::arg("dtype"),
@@ -958,7 +976,8 @@ static auto kv_tensor_wrapper =
              torch::arg("snapshot_handle") = std::nullopt,
             torch::arg("sorted_indices") = std::nullopt,
             torch::arg("width_offset") = 0,
-            torch::arg("checkpoint_handle") = std::nullopt})
+            torch::arg("checkpoint_handle") = std::nullopt,
+            torch::arg("read_only") = false})
         .def(
             "set_embedding_rocks_dp_wrapper",
             &KVTensorWrapper::set_embedding_rocks_dp_wrapper,
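
Reviewer note, not part of the patch: the standalone sketch below illustrates the pointer arithmetic behind the new FixedBlockPool::get_metaheader_dim() and ptr_offset_from_front() helpers, and the row layout the DRAM backend appears to hand to checkpointing when backend_return_whole_row is enabled (metaheader first, then embedding, then optimizer state, all measured in weight_type units, which is why get_metaheader_width_in_front() is added to get_max_D() in the KVTensorWrapper checks). The MetaHeader struct, dimensions, and values below are made-up stand-ins; only the two helper functions mirror the ones added in fixed_block_pool.h.

// Standalone sketch; compiles on its own, independent of FBGEMM.
#include <cstdint>
#include <cstring>
#include <iostream>
#include <vector>

namespace sketch {

// Hypothetical, simplified metaheader: key + used flag, padded so that
// sizeof(MetaHeader) is a multiple of sizeof(float). The real layout lives
// in FixedBlockPool::MetaHeader.
struct MetaHeader {
  int64_t key;
  uint32_t used;
  uint32_t pad;
};

// Mirrors FixedBlockPool::get_metaheader_dim<scalar_t>(): header size
// expressed in scalar_t units.
template <typename scalar_t>
size_t get_metaheader_dim() {
  return sizeof(MetaHeader) / sizeof(scalar_t);
}

// Mirrors FixedBlockPool::ptr_offset_from_front<scalar_t>(): jump `offset`
// scalars from the beginning of the block, metaheader included.
template <typename scalar_t>
const scalar_t* ptr_offset_from_front(const scalar_t* block, int64_t offset) {
  return reinterpret_cast<const scalar_t*>(
      reinterpret_cast<const char*>(block) + offset * sizeof(scalar_t));
}

}  // namespace sketch

int main() {
  using namespace sketch;
  constexpr int64_t emb_dim = 4;  // made-up embedding width
  constexpr int64_t opt_dim = 2;  // made-up optimizer-state width

  const size_t meta_dim = get_metaheader_dim<float>();  // 16 bytes -> 4 floats
  const size_t row_width = meta_dim + emb_dim + opt_dim;

  // One whole row as a backend_return_whole_row checkpoint would see it:
  // [metaheader | embedding | optimizer].
  std::vector<float> row(row_width, 0.f);
  MetaHeader header{/*key=*/42, /*used=*/1, /*pad=*/0};
  std::memcpy(row.data(), &header, sizeof(header));
  for (int64_t i = 0; i < emb_dim; ++i) {
    row[meta_dim + i] = 0.1f * float(i + 1);  // embedding values
  }
  row[meta_dim + emb_dim] = 0.9f;       // optimizer state
  row[meta_dim + emb_dim + 1] = 0.99f;  // optimizer state

  // width_offset = 0 -> whole row (what the weight KVT reads).
  const float* whole = ptr_offset_from_front(row.data(), 0);
  // width_offset = dim(metaheader) + emb_dim -> optimizer slice only.
  const float* opt = ptr_offset_from_front(row.data(), meta_dim + emb_dim);

  std::cout << "metaheader dim (floats): " << meta_dim << "\n";
  std::cout << "first embedding value:   " << whole[meta_dim] << "\n";
  std::cout << "first optimizer value:   " << opt[0] << "\n";
  return 0;
}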