From 38990228efec06792dd11dc4914361748174ed1c Mon Sep 17 00:00:00 2001 From: irexyc Date: Wed, 26 Mar 2025 10:36:31 +0000 Subject: [PATCH 01/10] init gloo support --- src/turbomind/comm/CMakeLists.txt | 13 + src/turbomind/comm/gloo/CMakeLists.txt | 38 +++ src/turbomind/comm/gloo/gloo_comm.cc | 309 +++++++++++++++++++++++ src/turbomind/comm/host_comm.cc | 8 + src/turbomind/comm/host_comm.h | 47 +++- src/turbomind/comm/serialize.cc | 174 +++++++++++++ src/turbomind/comm/serialize.h | 84 ++++++ src/turbomind/engine/model_request.cc | 29 +-- src/turbomind/engine/request.h | 4 + src/turbomind/models/llama/LlamaBatch.cc | 69 +++++ src/turbomind/utils/Tensor.cc | 28 ++ src/turbomind/utils/Tensor.h | 2 + 12 files changed, 774 insertions(+), 31 deletions(-) create mode 100644 src/turbomind/comm/gloo/CMakeLists.txt create mode 100644 src/turbomind/comm/gloo/gloo_comm.cc create mode 100644 src/turbomind/comm/serialize.cc create mode 100644 src/turbomind/comm/serialize.h diff --git a/src/turbomind/comm/CMakeLists.txt b/src/turbomind/comm/CMakeLists.txt index 43a2dacf21..56906be91d 100644 --- a/src/turbomind/comm/CMakeLists.txt +++ b/src/turbomind/comm/CMakeLists.txt @@ -3,6 +3,7 @@ cmake_minimum_required(VERSION 3.8) add_library(host_comm STATIC host_comm.cc thread_comm.cc) +target_link_libraries(host_comm PRIVATE logger) set_property(TARGET host_comm PROPERTY POSITION_INDEPENDENT_CODE ON) add_library(device_comm STATIC device_comm.cc) @@ -19,9 +20,21 @@ if (BUILD_MULTI_GPU) target_link_libraries(device_comm INTERFACE nccl_comm) endif () + add_subdirectory(gloo) + target_link_libraries(host_comm INTERFACE gloo_comm) + + add_library(serialize STATIC serialize.cc) + target_link_libraries(serialize INTERFACE tensor) + set_property(TARGET serialize PROPERTY POSITION_INDEPENDENT_CODE ON) + target_link_libraries(host_comm INTERFACE serialize) + if (BUILD_TEST) add_executable(test_comm test_comm.cu) target_link_libraries(test_comm PRIVATE device_comm host_comm pthread nvtx_utils) target_compile_options(test_comm PRIVATE -O3 -march=native -mtune=native) + + # add_executable(test_gloo test_gloo.cc) + # target_link_libraries(test_gloo PRIVATE host_comm) + # target_compile_options(test_gloo PRIVATE -O3 -march=native -mtune=native) endif () endif () diff --git a/src/turbomind/comm/gloo/CMakeLists.txt b/src/turbomind/comm/gloo/CMakeLists.txt new file mode 100644 index 0000000000..946849945b --- /dev/null +++ b/src/turbomind/comm/gloo/CMakeLists.txt @@ -0,0 +1,38 @@ +# Copyright (c) OpenMMLab. All rights reserved. +cmake_minimum_required(VERSION 3.8) + +include(FetchContent) +FetchContent_Declare( + gloo + GIT_REPOSITORY https://github.com/facebookincubator/gloo.git + GIT_TAG cbe963b5a43cd75e6eca4f74d2bb38ec8dcfdbc8 +) + +# some settings of gloo, +set(GLOO_INSTALL OFF CACHE BOOL "" FORCE) +set(GLOO_STATIC_OR_SHARED STATIC CACHE STRING "" FORCE) +set(__USE_NCCL ${USE_NCCL}) +set(__BUILD_TEST ${BUILD_TEST}) +set(USE_NCCL OFF) +set(BUILD_TEST OFF) +set(USE_REDIS ON) # TODO remove, use tcp_store instead +FetchContent_MakeAvailable(gloo) + +# gloo build doesn't add include directories as a target property... 
+target_include_directories(gloo PUBLIC + $ + $ # config.h generated at cmake config time +) +set(USE_NCCL ${__USE_NCCL}) +set(BUILD_TEST ${__BUILD_TEST}) + +add_library(gloo_comm STATIC + gloo_comm.cc + # tcp_store.cc +) +set_property(TARGET gloo_comm PROPERTY POSITION_INDEPENDENT_CODE ON) +target_link_libraries(gloo_comm PRIVATE gloo logger) + +# TODO remove, use tcp_store instead +include_directories(SYSTEM ${HIREDIS_INCLUDE_DIRS}) +target_link_libraries(gloo_comm PRIVATE ${HIREDIS_LIBRARIES}) diff --git a/src/turbomind/comm/gloo/gloo_comm.cc b/src/turbomind/comm/gloo/gloo_comm.cc new file mode 100644 index 0000000000..4445e6f6f4 --- /dev/null +++ b/src/turbomind/comm/gloo/gloo_comm.cc @@ -0,0 +1,309 @@ +// Copyright (c) OpenMMLab. All rights reserved. + +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "src/turbomind/comm/host_comm.h" +#include "src/turbomind/utils/logger.h" + +namespace turbomind::comm { + +const char* GLOO_SOCKET_IFNAME_ENV = "GLOO_SOCKET_IFNAME"; +const char STORE_INFO_DELIM = ','; + +std::shared_ptr<::gloo::transport::Device> createGlooDevice() +{ + ::gloo::transport::tcp::attr attr; + if (auto ifname = std::getenv(GLOO_SOCKET_IFNAME_ENV); ifname) { + attr.iface = ifname; + } + else { + attr.hostname = ::gloo::getHostname(); + } + return ::gloo::transport::tcp::CreateDevice(attr); +} + +class Store: public ::gloo::rendezvous::PrefixStore { +public: + explicit Store(const std::string& host, int port, const std::string& prefix): + host_(host), port_(port), redis_store_(host_, port_), ::gloo::rendezvous::PrefixStore(prefix, redis_store_){}; + + ~Store() = default; + + std::shared_ptr New(const std::string& prefix) + { + std::string new_prefix = prefix + "/" + prefix_; + return std::make_shared(host_, port_, new_prefix); + } + +public: + std::string host_; + int port_; + ::gloo::rendezvous::RedisStore redis_store_; + using ::gloo::rendezvous::PrefixStore::prefix_; +}; + +class GlobalStoreFactory { +public: + static GlobalStoreFactory& Instance() + { + static GlobalStoreFactory instance; + return instance; + } + + std::string New() + { + std::lock_guard lock(mutex_); + + // TODO: use param instead of env + auto get = [](const std::string& key, const std::string& default_value) { + const char* value = std::getenv(key.c_str()); + return value ? 
std::string(value) : default_value; + }; + + std::string host = get("STORE_ADDR", "127.0.0.1"); + int port = std::stoi(get("STORE_PORT", "6800")); + + std::stringstream ss; + ss << host << STORE_INFO_DELIM << port << STORE_INFO_DELIM << prefix_++; + return ss.str(); + } + + std::shared_ptr Load(const std::string& info) + { + std::stringstream ss(info); + std::vector keys; + std::string local; + while (getline(ss, local, STORE_INFO_DELIM)) { + keys.push_back(std::move(local)); + } + FT_CHECK(keys.size() == 3); + + std::string host = keys[0]; + int port = stoi(keys[1]); + std::string prefix = keys[2]; + + return std::make_shared(host, port, prefix); + } + +private: + GlobalStoreFactory() {} + + std::mutex mutex_; + int prefix_{0}; +}; + +typedef void (*ReduceFunc)(void*, const void*, const void*, size_t); + +struct GlooCommImpl: public HostCommImpl { + + struct SplitInfo { + int color; + int rank; + + bool operator<(const SplitInfo& other) const + { + return (color < other.color) || (color == other.color && rank < other.rank); + } + + bool operator==(const SplitInfo& other) const + { + return (color == other.color) && (rank == other.rank); + } + }; + + GlooCommImpl(std::shared_ptr store, int n_ranks, int rank): + store_{std::move(store)}, rank_{rank}, n_ranks_{n_ranks} + { + // TM_LOG_INFO("[GlooCommImpl] rank=%d, n_ranks=%d, prefix=%s", rank_, n_ranks_, store_->prefix_.c_str()); + device_ = createGlooDevice(); + context_ = std::make_shared<::gloo::rendezvous::Context>(rank_, n_ranks_); + context_->connectFullMesh(*store_, device_); + } + + ~GlooCommImpl() {} + + int rank() const override + { + return rank_; + } + + int n_ranks() const override + { + return n_ranks_; + } + + bool is_same_process() const override + { + return false; + } + + std::shared_ptr Split(int color, int key) override + { + // don't know why key was set to 0 + auto vec = comm::AllGather(this, SplitInfo{color, rank_}); + auto last = std::stable_partition(vec.begin(), vec.end(), [&](auto x) { // + return x.color == color; + }); + vec.erase(last, vec.end()); + std::stable_sort(vec.begin(), vec.end(), [](auto& a, auto& b) { // + return a < b; + }); + + auto new_prefix = std::to_string(color) + ":" + std::to_string(n_split_++); + auto new_store = store_->New(new_prefix); + int new_n_ranks = vec.size(); + int new_rank = std::find(vec.begin(), vec.end(), SplitInfo{color, rank_}) - vec.begin(); + return std::make_shared(new_store, new_n_ranks, new_rank); + } + + void Sync() override + { + ::gloo::BarrierOptions opts(context_); + opts.setTimeout(std::chrono::milliseconds(1000 * 60 * 30)); + ::gloo::barrier(opts); + } + + void Broadcast(void* data, int count, DataType dtype, int root, copy_fn copy) override + { + ::gloo::BroadcastOptions opts(context_); + opts.setRoot(root); + opts.setTimeout(std::chrono::milliseconds(1000 * 60 * 30)); + opts.setOutput((char*)data, count); + ::gloo::broadcast(opts); + } + + void AllGather(void* data, int count, DataType dtype, copy_fn copy) override + { + ::gloo::AllgatherOptions opts(context_); + opts.setTimeout(std::chrono::milliseconds(1000 * 60 * 30)); + opts.setOutput((char*)data, count * n_ranks_); + ::gloo::allgather(opts); + } + + static ReduceFunc getReduceFunc(DataType dtype, RedOp red_op) + { + + auto dispatch_op = [&](auto t) -> ReduceFunc { + using T = decltype(t); + switch (red_op) { + case RedOp::kSum: + return ::gloo::sum; + case RedOp::kMax: + return ::gloo::max; + case RedOp::kMin: + return ::gloo::min; + default: + return {}; + } + }; + + auto dispatch = [&]() -> ReduceFunc { + 
switch (dtype) { + case DataType::TYPE_INT32: + return dispatch_op(int32_t{}); + case DataType::TYPE_INT64: + return dispatch_op(int64_t{}); + case DataType::TYPE_UINT32: + return dispatch_op(uint32_t{}); + case DataType::TYPE_UINT64: + return dispatch_op(uint64_t{}); + default: + return {}; + } + }; + + if (auto fn = dispatch()) { + return fn; + } + else { + throw std::runtime_error("not implemented"); + return {}; + } + } + + void AllReduce(void* data, int count, DataType dtype, RedOp red_op) override + { + ::gloo::AllreduceOptions opts(context_); + opts.setTimeout(std::chrono::milliseconds(1000 * 60 * 30)); + opts.setReduceFunction(getReduceFunc(dtype, red_op)); + switch (dtype) { + case DataType::TYPE_INT32: + opts.setOutput((int32_t*)data, count); + break; + case DataType::TYPE_INT64: + opts.setOutput((int64_t*)data, count); + break; + case DataType::TYPE_UINT32: + opts.setOutput((uint32_t*)data, count); + break; + case DataType::TYPE_UINT64: + opts.setOutput((uint64_t*)data, count); + break; + default: + throw std::runtime_error("not implemented"); + } + ::gloo::allreduce(opts); + } + + int n_split_{}; + std::shared_ptr<::gloo::transport::Device> device_; + std::shared_ptr<::gloo::rendezvous::Context> context_; + std::shared_ptr store_; + int rank_; + int n_ranks_; + uint32_t tag_{}; +}; + +class GlooGroupId: public HostGroupId { + + void Initialize() override + { + info_ = GlobalStoreFactory::Instance().New(); + // TM_LOG_ERROR("[GlooGroupId][Initialize] info=%s", info_.c_str()); + } + + void Export(std::ostream& os) override + { + os << info_; + } + + void Import(std::istream& is) override + { + std::stringstream ss; + ss << is.rdbuf(); + info_ = ss.str(); + } + + HostComm CreateCommunicator(int n_ranks, int rank) override + { + FT_CHECK(info_ != ""); + auto impl = std::make_shared(GlobalStoreFactory::Instance().Load(info_), n_ranks, rank); + return std::static_pointer_cast(impl); + } + +private: + std::string info_; // ip,port,prefix + std::shared_ptr<::gloo::rendezvous::Store> store_; +}; + +std::unique_ptr CreateGlooGroupId() +{ + return std::make_unique(); +} + +} // namespace turbomind::comm diff --git a/src/turbomind/comm/host_comm.cc b/src/turbomind/comm/host_comm.cc index 0d3cf367e2..37c45732df 100644 --- a/src/turbomind/comm/host_comm.cc +++ b/src/turbomind/comm/host_comm.cc @@ -8,8 +8,16 @@ HostCommImpl::~HostCommImpl() = default; std::unique_ptr CreateThreadGroupId(); +std::unique_ptr CreateGlooGroupId(); + std::unique_ptr CreateHostGroupId(const std::string& backend) { +#ifdef BUILD_MULTI_GPU + if (backend == "gloo") { + return CreateGlooGroupId(); + } +#endif + return CreateThreadGroupId(); } diff --git a/src/turbomind/comm/host_comm.h b/src/turbomind/comm/host_comm.h index 5cf35d7b28..aa11d06b71 100644 --- a/src/turbomind/comm/host_comm.h +++ b/src/turbomind/comm/host_comm.h @@ -7,7 +7,9 @@ #include #include +#include "src/turbomind/comm/serialize.h" #include "src/turbomind/utils/Tensor.h" +#include "src/turbomind/utils/logger.h" namespace turbomind::comm { @@ -87,7 +89,23 @@ void Broadcast(HostCommImpl* comm, T* data, int n, int root) comm->Broadcast(data, n, TYPE_INVALID, root, detail::copy_fn); } else { - throw std::runtime_error("not implemented"); + try { + // buf may have different size on different ranks + std::vector buf; + serialize(data, n, buf); + size_t size = buf.size(); + Broadcast(comm, &size, 1, root); + buf.resize(size); + comm->Broadcast(buf.data(), buf.size(), TYPE_INT8, root, detail::copy_fn); + if (comm->rank() != root) { + // some field in data 
may be not shared by all rank + deserialize(data, n, buf); + } + } + catch (const std::invalid_argument& e) { + TM_LOG_ERROR("Broadcast failed: %s", e.what()); + throw; + } } } } @@ -104,8 +122,31 @@ void AllGather(HostCommImpl* comm, T* data, int n) comm->AllGather(data, n, TYPE_INVALID, detail::copy_fn); } else { - /// serialize data - throw std::runtime_error("not implemented"); + try { + // buf may have different size on different ranks + std::vector rbuf; + for (int i = 0; i < n; ++i) { + std::vector ibuf; + serialize(data + n * comm->rank() + i, 1, ibuf); + rbuf.insert(rbuf.end(), ibuf.begin(), ibuf.end()); + } + int size = rbuf.size(); + comm->AllReduce(&size, 1, TYPE_INT32, RedOp::kMax); + std::vector buf(size * comm->n_ranks()); + memcpy(buf.data() + comm->rank() * size, rbuf.data(), rbuf.size()); + comm->AllGather(buf.data(), size, TYPE_INT8, detail::copy_fn); + for (int i = 0; i < comm->n_ranks(); ++i) { + if (i != comm->rank()) { + // some field in data may be not shared by all rank + deserialize( + data + n * i, n, std::vector(buf.begin() + i * size, buf.begin() + (i + 1) * size)); + } + } + } + catch (const std::invalid_argument& e) { + TM_LOG_ERROR("AllGather failed: %s", e.what()); + throw; + } } } } diff --git a/src/turbomind/comm/serialize.cc b/src/turbomind/comm/serialize.cc new file mode 100644 index 0000000000..a4ae66300e --- /dev/null +++ b/src/turbomind/comm/serialize.cc @@ -0,0 +1,174 @@ +// Copyright (c) OpenMMLab. All rights reserved. + +#include +#include + +#include "src/turbomind/comm/serialize.h" +#include "src/turbomind/engine/request.h" +#include "src/turbomind/utils/Tensor.h" + +namespace turbomind::comm { + +std::vector streambuf_to_vector(std::streambuf* sb) +{ + auto start = sb->pubseekoff(0, std::ios::beg, std::ios::in); + auto end = sb->pubseekoff(0, std::ios::end, std::ios::in); + auto size = end - start; + + std::vector buffer(size); + sb->pubseekpos(start); + sb->sgetn(buffer.data(), size); + return buffer; +} + +void serialize(std::ostream& os, const std::string& s) +{ + int size = s.length(); + serialize(os, size); + os << s; +} + +void deserialize(std::istream& is, std::string& s) +{ + int size; + deserialize(is, size); + s.resize(size); + is.read(s.data(), size); +} + +void serialize(std::ostream& os, const GenerationConfig& gen) +{ + serialize(os, gen.max_new_tokens); + serialize(os, gen.min_new_tokens); + serialize(os, gen.eos_ids); + serialize(os, gen.stop_ids[0]); + serialize(os, gen.stop_ids[1]); + serialize(os, gen.bad_ids[0]); + serialize(os, gen.bad_ids[1]); + serialize(os, gen.top_k); + serialize(os, gen.top_p); + serialize(os, gen.min_p); + serialize(os, gen.temperature); + serialize(os, gen.repetition_penalty); + serialize(os, gen.random_seed); + serialize(os, gen.output_logprobs); + serialize(os, gen.output_last_hidden_state); + serialize(os, gen.output_logits); +} + +void deserialize(std::istream& is, GenerationConfig& gen) +{ + deserialize(is, gen.max_new_tokens); + deserialize(is, gen.min_new_tokens); + deserialize(is, gen.eos_ids); + deserialize(is, gen.stop_ids[0]); + deserialize(is, gen.stop_ids[1]); + deserialize(is, gen.bad_ids[0]); + deserialize(is, gen.bad_ids[1]); + deserialize(is, gen.top_k); + deserialize(is, gen.top_p); + deserialize(is, gen.min_p); + deserialize(is, gen.temperature); + deserialize(is, gen.repetition_penalty); + deserialize(is, gen.random_seed); + deserialize(is, gen.output_logprobs); + deserialize(is, gen.output_last_hidden_state); + deserialize(is, gen.output_logits); +} + +void 
serialize(std::ostream& os, const SessionParam& sess) +{ + serialize(os, sess.id); + serialize(os, sess.step); + serialize(os, sess.start_flag); + serialize(os, sess.end_flag); + serialize(os, sess.kill_flag); +} + +void deserialize(std::istream& is, SessionParam& sess) +{ + deserialize(is, sess.id); + deserialize(is, sess.step); + deserialize(is, sess.start_flag); + deserialize(is, sess.end_flag); + deserialize(is, sess.kill_flag); +} + +//--------------------------- below may not right --------------------------- + +void serialize(std::ostream& os, const Tensor& tensor) +{ + serialize(os, tensor.where); + serialize(os, tensor.type); + serialize(os, tensor.shape); + serialize(os, tensor.offsets); + os.write(tensor.getPtr(), tensor.sizeBytes()); +} + +void deserialize(std::istream& is, ManagedTensor& holder) +{ + Tensor tensor{}; + deserialize(is, tensor.where); + deserialize(is, tensor.type); + deserialize(is, tensor.shape); + deserialize(is, tensor.offsets); // not used + int64_t byte_size{}; + holder = ManagedTensor::create( + tensor.type, tensor.where, std::vector(tensor.shape.begin(), tensor.shape.end()), byte_size); + is.read(holder->getPtr(), byte_size); +} + +void serialize(std::ostream& os, const TensorMap& map) +{ + int size = map.size(); + serialize(os, size); + for (const auto& [key, tensor] : map) { + serialize(os, key); + serialize(os, tensor); + } +} + +void deserialize(std::istream& is, TensorMap& map, Request::TensorMap_& map_) +{ + int size; + deserialize(is, size); + for (int i = 0; i < size; ++i) { + std::string key; + deserialize(is, key); + + ManagedTensor tensor; + deserialize(is, tensor); + map_.emplace(key, tensor); + + map.insert(key, *tensor); + } +} + +void serialize(std::ostream& os, const Request& req) +{ + serialize(os, req.id); + serialize(os, req.unique_id); + serialize(os, req.session); + serialize(os, req.gen_cfg); + serialize(os, req.stream_output); + serialize(os, req.inputs); + serialize(os, req.outputs); + serialize(os, req.ec); +} + +void deserialize(std::istream& is, Request& req) +{ + deserialize(is, req.id); + deserialize(is, req.unique_id); + deserialize(is, req.session); + deserialize(is, req.gen_cfg); + deserialize(is, req.stream_output); + deserialize(is, req.inputs, req.ipc_buffer); + deserialize(is, req.outputs, req.ipc_buffer); + deserialize(is, req.ec); + + req.output_ids = req.outputs.at("output_ids"); + req.sequence_length = req.outputs.at("sequence_length"); +} + +} // namespace turbomind::comm diff --git a/src/turbomind/comm/serialize.h b/src/turbomind/comm/serialize.h new file mode 100644 index 0000000000..f935a15798 --- /dev/null +++ b/src/turbomind/comm/serialize.h @@ -0,0 +1,84 @@ +// Copyright (c) OpenMMLab. All rights reserved. 
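+// Lightweight binary (de)serialization helpers used to pass Request/Tensor
+// payloads between ranks through the host communicator.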
+ +#pragma once + +#include +#include +#include +#include +#include +#include + +#include "src/turbomind/engine/request.h" +#include "src/turbomind/utils/Tensor.h" + +namespace turbomind::comm { + +std::vector streambuf_to_vector(std::streambuf* sb); + +template +inline void serialize(const T*, int n, std::vector&) +{ + throw std::invalid_argument("not implemented"); +} + +template +inline void deserialize(T*, int n, const std::vector&) +{ + throw std::invalid_argument("not implemented"); +} + +template>> +inline void serialize(std::ostream& os, const T& v) +{ + os.write((char*)&v, sizeof(v)); +} + +template>> +inline void deserialize(std::istream& is, T& v) +{ + is.read((char*)&v, sizeof(v)); +} + +void serialize(std::ostream& os, const std::string& s); + +void deserialize(std::istream& is, std::string& s); + +template>> +inline void serialize(std::ostream& os, const std::vector& vec) +{ + int size = vec.size(); + os.write((char*)&size, sizeof(int)); + os.write((char*)vec.data(), sizeof(T) * size); +} + +template>> +inline void deserialize(std::istream& is, std::vector& vec) +{ + int size; + is.read((char*)&size, sizeof(int)); + vec.resize(size); + is.read((char*)vec.data(), sizeof(T) * size); +} + +void serialize(std::ostream& os, const GenerationConfig& gen); + +void deserialize(std::istream& is, GenerationConfig& gen); + +void serialize(std::ostream& os, const SessionParam& sess); + +void deserialize(std::istream& is, SessionParam& sess); + +void serialize(std::ostream& os, const Tensor& tensor); + +void deserialize(std::istream& is, ManagedTensor& holder); + +void serialize(std::ostream& os, const TensorMap& map); + +void deserialize(std::istream& is, TensorMap& map, Request::TensorMap_& map_); + +void serialize(std::ostream& os, const Request& req); + +void deserialize(std::istream& is, Request& req); + +} // namespace turbomind::comm diff --git a/src/turbomind/engine/model_request.cc b/src/turbomind/engine/model_request.cc index 6ba355e896..2017dac660 100644 --- a/src/turbomind/engine/model_request.cc +++ b/src/turbomind/engine/model_request.cc @@ -17,33 +17,6 @@ namespace turbomind { -static ManagedTensor create(DataType dtype, MemoryType where, const std::vector& size, int64_t& byte_size) -{ - byte_size = std::accumulate(size.begin(), size.end(), Tensor::getTypeSize(dtype), std::multiplies<>{}); - void* data{}; - - if (where == MEMORY_GPU) { - check_cuda_error(cudaMallocAsync(&data, byte_size, nullptr)); - } - else { - data = std::malloc(byte_size); - } - - ManagedTensor ret; - ret.tensor = Tensor{where, dtype, std::vector(size.begin(), size.end()), data}; - ret.data_holder.reset((void*)nullptr, [data, where](auto) { - // std::cerr << "turbomind tensor deallocate" << std::endl; - if (where == MEMORY_GPU) { - /// TODO: guard device id - check_cuda_error(cudaFreeAsync(data, nullptr)); - } - else { - std::free(data); - } - }); - return ret; -} - template static T get(const std::unordered_map& m, const std::string& key, T fallback = {}) { @@ -97,7 +70,7 @@ auto ModelRequest::Forward(InputParam param, std::function cb) -> Output shape_ = {shape.cbegin(), shape.cend()}; } int64_t byte_size{}; - auto it = dest->emplace(key, create(dtype, where, shape_, byte_size)).first; + auto it = dest->emplace(key, ManagedTensor::create(dtype, where, shape_, byte_size)).first; return std::make_pair(it->second->data, byte_size); }; diff --git a/src/turbomind/engine/request.h b/src/turbomind/engine/request.h index 28f2943b54..eb3aa95c9a 100644 --- a/src/turbomind/engine/request.h +++ 
b/src/turbomind/engine/request.h @@ -147,6 +147,10 @@ struct Request { kFinish = 7, kCancel = 8, }; + + // data holder(tensor) for inter-process + using TensorMap_ = std::unordered_map; + TensorMap_ ipc_buffer; }; inline void UpdateState(Request& r, int status, int seq_len) diff --git a/src/turbomind/models/llama/LlamaBatch.cc b/src/turbomind/models/llama/LlamaBatch.cc index 51180b2b2f..a3220cd99a 100644 --- a/src/turbomind/models/llama/LlamaBatch.cc +++ b/src/turbomind/models/llama/LlamaBatch.cc @@ -39,6 +39,7 @@ #include "src/turbomind/models/llama/llama_kernels.h" #include "src/turbomind/models/llama/llama_utils.h" +#include "src/turbomind/comm/serialize.h" #include "src/turbomind/utils/Tensor.h" #include "src/turbomind/utils/anomaly_handler.h" #include "src/turbomind/utils/constant.h" @@ -1342,6 +1343,9 @@ void LlamaBatch::OutputLogits(const float* logits, int first, int last, Gener template void LlamaBatch::OutputLastHiddenState(const T* hidden_states, int first, int last) { + if (tp_rank_ != 0) { + return; + } for (int i = first; i < last; ++i) { const int input_len = h_input_length_buf_[i]; // input lenght for this iter @@ -1626,6 +1630,71 @@ struct RequestData { } // namespace +namespace comm { + +void serialize(std::ostream& os, const RequestData& req) +{ + // std::vector> infer; + serialize(os, (int)req.infer.size()); + for (const auto& r : req.infer) { + serialize(os, *r); + } + // std::vector> kill; + serialize(os, (int)req.kill.size()); + for (const auto& r : req.kill) { + serialize(os, *r); + } + + serialize(os, req.cancel); // std::vector cancel; + serialize(os, req.abort); // bool abort; +} + +template<> +void serialize(const std::shared_ptr* req, int n, std::vector& vec) +{ + std::stringstream ss; + for (int i = 0; i < n; ++i) { + const auto& r = req[i]; + if (r != nullptr) { + serialize(ss, *r); + } + } + vec = streambuf_to_vector(ss.rdbuf()); +} + +void deserialize(std::istream& is, RequestData& req) +{ + auto process = [](std::istream& is, std::vector>& vec) { + int size; + deserialize(is, size); + vec.resize(size); + for (auto& r : vec) { + r = std::make_shared(); + deserialize(is, *r); + } + }; + process(is, req.infer); + process(is, req.kill); + deserialize(is, req.cancel); + deserialize(is, req.abort); +} + +template<> +void deserialize(std::shared_ptr* req, int n, const std::vector& vec) +{ + std::stringstream ss; + ss.write(vec.data(), vec.size()); + for (int i = 0; i < n; ++i) { + auto& r = req[i]; + if (r == nullptr) { + r = std::make_shared(); + } + deserialize(ss, *r); + } +} + +} // namespace comm + template void LlamaBatch::InternalThreadEntry() { diff --git a/src/turbomind/utils/Tensor.cc b/src/turbomind/utils/Tensor.cc index 7a2cedac13..c8211e69f8 100644 --- a/src/turbomind/utils/Tensor.cc +++ b/src/turbomind/utils/Tensor.cc @@ -438,4 +438,32 @@ void TensorMap::saveNpy(const std::string& base_folder) } } +ManagedTensor +ManagedTensor::create(DataType dtype, MemoryType where, const std::vector& size, int64_t& byte_size) +{ + byte_size = std::accumulate(size.begin(), size.end(), Tensor::getTypeSize(dtype), std::multiplies<>{}); + void* data{}; + + if (where == MEMORY_GPU) { + check_cuda_error(cudaMallocAsync(&data, byte_size, nullptr)); + } + else { + data = std::malloc(byte_size); + } + + ManagedTensor ret; + ret.tensor = Tensor{where, dtype, std::vector(size.begin(), size.end()), data}; + ret.data_holder.reset((void*)nullptr, [data, where](auto) { + // std::cerr << "turbomind tensor deallocate" << std::endl; + if (where == MEMORY_GPU) { + /// TODO: guard 
device id + check_cuda_error(cudaFreeAsync(data, nullptr)); + } + else { + std::free(data); + } + }); + return ret; +} + } // namespace turbomind diff --git a/src/turbomind/utils/Tensor.h b/src/turbomind/utils/Tensor.h index bf9840314c..2c2be38424 100644 --- a/src/turbomind/utils/Tensor.h +++ b/src/turbomind/utils/Tensor.h @@ -577,6 +577,8 @@ struct ManagedTensor { { return tensor; } + + static ManagedTensor create(DataType dtype, MemoryType where, const std::vector& size, int64_t& byte_size); }; } // namespace turbomind From 916e44b7500d45556e9ef0a24bad245e48aada47 Mon Sep 17 00:00:00 2001 From: irexyc Date: Wed, 2 Apr 2025 13:12:28 +0000 Subject: [PATCH 02/10] use pytorch tcpstore --- src/turbomind/comm/gloo/CMakeLists.txt | 13 +- src/turbomind/comm/gloo/gloo_comm.cc | 15 +- src/turbomind/comm/gloo/tcp_store.cc | 217 +++++++++++++++++++++++++ src/turbomind/comm/gloo/tcp_store.h | 37 +++++ 4 files changed, 272 insertions(+), 10 deletions(-) create mode 100644 src/turbomind/comm/gloo/tcp_store.cc create mode 100644 src/turbomind/comm/gloo/tcp_store.h diff --git a/src/turbomind/comm/gloo/CMakeLists.txt b/src/turbomind/comm/gloo/CMakeLists.txt index 946849945b..52dc579324 100644 --- a/src/turbomind/comm/gloo/CMakeLists.txt +++ b/src/turbomind/comm/gloo/CMakeLists.txt @@ -15,7 +15,7 @@ set(__USE_NCCL ${USE_NCCL}) set(__BUILD_TEST ${BUILD_TEST}) set(USE_NCCL OFF) set(BUILD_TEST OFF) -set(USE_REDIS ON) # TODO remove, use tcp_store instead +# set(USE_REDIS ON) # TODO remove, use tcp_store instead FetchContent_MakeAvailable(gloo) # gloo build doesn't add include directories as a target property... @@ -28,11 +28,12 @@ set(BUILD_TEST ${__BUILD_TEST}) add_library(gloo_comm STATIC gloo_comm.cc - # tcp_store.cc + tcp_store.cc ) set_property(TARGET gloo_comm PROPERTY POSITION_INDEPENDENT_CODE ON) -target_link_libraries(gloo_comm PRIVATE gloo logger) +target_link_libraries(gloo_comm PUBLIC gloo logger) -# TODO remove, use tcp_store instead -include_directories(SYSTEM ${HIREDIS_INCLUDE_DIRS}) -target_link_libraries(gloo_comm PRIVATE ${HIREDIS_LIBRARIES}) +if(USE_REDIS) + include_directories(SYSTEM ${HIREDIS_INCLUDE_DIRS}) + target_link_libraries(gloo_comm PRIVATE ${HIREDIS_LIBRARIES}) +endif() diff --git a/src/turbomind/comm/gloo/gloo_comm.cc b/src/turbomind/comm/gloo/gloo_comm.cc index 4445e6f6f4..97a520c146 100644 --- a/src/turbomind/comm/gloo/gloo_comm.cc +++ b/src/turbomind/comm/gloo/gloo_comm.cc @@ -12,11 +12,14 @@ #include #include #include +#if GLOO_USE_REDIS #include +#endif #include #include #include +#include "src/turbomind/comm/gloo/tcp_store.h" #include "src/turbomind/comm/host_comm.h" #include "src/turbomind/utils/logger.h" @@ -40,7 +43,7 @@ std::shared_ptr<::gloo::transport::Device> createGlooDevice() class Store: public ::gloo::rendezvous::PrefixStore { public: explicit Store(const std::string& host, int port, const std::string& prefix): - host_(host), port_(port), redis_store_(host_, port_), ::gloo::rendezvous::PrefixStore(prefix, redis_store_){}; + host_(host), port_(port), store_(host_, port_), ::gloo::rendezvous::PrefixStore(prefix, store_){}; ~Store() = default; @@ -51,9 +54,13 @@ class Store: public ::gloo::rendezvous::PrefixStore { } public: - std::string host_; - int port_; - ::gloo::rendezvous::RedisStore redis_store_; + std::string host_; + int port_; + +#if GLOO_USE_REDIS +// ::gloo::rendezvous::RedisStore store_; +#endif + TCPStore store_; using ::gloo::rendezvous::PrefixStore::prefix_; }; diff --git a/src/turbomind/comm/gloo/tcp_store.cc b/src/turbomind/comm/gloo/tcp_store.cc 
new file mode 100644 index 0000000000..9a2e74351e --- /dev/null +++ b/src/turbomind/comm/gloo/tcp_store.cc @@ -0,0 +1,217 @@ +// Copyright (c) OpenMMLab. All rights reserved. + +#include +#include +#include +#include + +#include +#include + +#include "src/turbomind/comm/gloo/tcp_store.h" +#include "src/turbomind/utils/logger.h" + +namespace turbomind::comm { + +namespace { + +// copy from pytorch https://github.com/pytorch/pytorch/blob/v2.5.1/torch/csrc/distributed/c10d/TCPStoreBackend.hpp + +static const uint32_t validationMagicNumber = 0x3C85F7CE; + +enum class CheckResponseType : uint8_t +{ + READY, + NOT_READY +}; + +enum class QueryType : uint8_t +{ + VALIDATE, + SET, + COMPARE_SET, + GET, + ADD, + CHECK, + WAIT, + GETNUMKEYS, + DELETE_KEY, + APPEND, + MULTI_GET, + MULTI_SET, + CANCEL_WAIT, + PING, +}; + +} // namespace + +struct Buffer { + std::vector buffer; + + template>> + void append(T val) + { + char* ptr = (char*)&val; + buffer.insert(buffer.end(), ptr, ptr + sizeof(T)); + } + + void append(const std::vector& vec) + { + append((uint64_t)vec.size()); + buffer.insert(buffer.end(), vec.begin(), vec.end()); + } + + void append(const std::string& str) + { + append((uint64_t)str.size()); + buffer.insert(buffer.end(), str.begin(), str.end()); + } + + const char* data() const + { + return buffer.data(); + } + + size_t count() const + { + return buffer.size(); + } +}; + +void validate(std::shared_ptr<::gloo::transport::tcp::Socket>& socket) +{ + Buffer buffer; + buffer.append(QueryType::VALIDATE); + buffer.append(validationMagicNumber); + socket->write(buffer.data(), buffer.count()); +} + +void ping(std::shared_ptr<::gloo::transport::tcp::Socket>& socket) +{ + Buffer buffer; + buffer.append(QueryType::PING); + uint32_t nonce = getpid(); + uint32_t returnedNonce = -1; + buffer.append(nonce); + socket->write(buffer.data(), buffer.count()); + int r = socket->read(&returnedNonce, sizeof(returnedNonce)); + if (nonce != returnedNonce) { + std::stringstream ss; + ss << "Ping failed, nonce=" << nonce << ", returnedNonce=" << returnedNonce << ", socket read=" << r; + throw std::runtime_error(ss.str()); + } +} + +TCPStore::TCPStore(const std::string& host, int port) +{ + auto retry = 0; + do { + try { + ::addrinfo hints{}, *res{}; + hints.ai_flags = AI_V4MAPPED | AI_ALL | AI_NUMERICSERV; + hints.ai_family = AF_UNSPEC; + hints.ai_socktype = SOCK_STREAM; + + int status = getaddrinfo(host.c_str(), std::to_string(port).c_str(), &hints, &res); + + std::shared_ptr holder(res, [](addrinfo* p) { + if (p != nullptr) { + freeaddrinfo(p); + } + }); + + if (status != 0) { + throw std::runtime_error("getaddrinfo failed: " + std::string(gai_strerror(status))); + } + + for (::addrinfo* addr = res; addr != nullptr; addr = addr->ai_next) { + int fd = ::socket(addr->ai_family, addr->ai_socktype, addr->ai_protocol); + if (fd == -1) { + continue; + } + auto socket = std::make_shared<::gloo::transport::tcp::Socket>(fd); + socket->connect(addr->ai_addr, addr->ai_addrlen); + socket->noDelay(true); + socket->recvTimeout(std::chrono::milliseconds(5000)); + socket->sendTimeout(std::chrono::milliseconds(5000)); + validate(socket); // validate the connection + ping(socket); // check send/recv + socket_ = std::move(socket); + break; + } + + if (socket_ == nullptr) { + throw std::runtime_error("unable to connect to " + host + ":" + std::to_string(port)); + } + } + catch (const std::exception& e) { + TM_LOG_WARNING("[TCPStore] Failed to connect to store after %d retries: %s", retry, e.what()); + 
std::this_thread::sleep_for(std::chrono::seconds(1)); + retry += 1; + } + } while (socket_ == nullptr); +} + +void TCPStore::set(const std::string& key, const std::vector& data) +{ + std::lock_guard lock(mutex_); + Buffer buffer; + buffer.append(QueryType::SET); + buffer.append(key); + buffer.append(data); + socket_->write(buffer.data(), buffer.count()); +} + +std::vector TCPStore::get(const std::string& key) +{ + wait({key}); + std::lock_guard lock(mutex_); + Buffer buffer; + buffer.append(QueryType::GET); + buffer.append(key); + socket_->write(buffer.data(), buffer.count()); + + uint64_t vec_size; + socket_->read(&vec_size, sizeof(vec_size)); + std::vector value(vec_size); + socket_->read(value.data(), value.size()); + return value; +} + +bool TCPStore::check(const std::vector& keys) +{ + std::lock_guard lock(mutex_); + Buffer buffer; + buffer.append(QueryType::CHECK); + buffer.append((uint64_t)keys.size()); + for (const auto& key : keys) { + buffer.append(key); + } + socket_->write(buffer.data(), buffer.count()); + + CheckResponseType response; + socket_->read(&response, sizeof(response)); + return response == CheckResponseType::READY; +} + +void TCPStore::wait(const std::vector& keys, const std::chrono::milliseconds& timeout) +{ + const auto start = std::chrono::steady_clock::now(); + while (!check(keys)) { + const auto elapsed = std::chrono::duration_cast(std::chrono::steady_clock::now() - start); + if (elapsed > timeout) { + std::stringstream ss; + ss << "Wait timeout for key(s): ["; + for (const auto& key : keys) { + ss << key << " "; + } + ss << "]"; + throw std::runtime_error("Wait timeout for key(s): " + ss.str()); + } + std::this_thread::sleep_for(std::chrono::milliseconds(1000)); + } +} + +TCPStore::~TCPStore() = default; + +} // namespace turbomind::comm diff --git a/src/turbomind/comm/gloo/tcp_store.h b/src/turbomind/comm/gloo/tcp_store.h new file mode 100644 index 0000000000..35dd1c05bf --- /dev/null +++ b/src/turbomind/comm/gloo/tcp_store.h @@ -0,0 +1,37 @@ +// Copyright (c) OpenMMLab. All rights reserved. 
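+// Minimal client for the PyTorch TCPStore wire protocol, used as the gloo
+// rendezvous store when ranks span multiple processes/nodes.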
+ +#pragma once + +#include +#include + +#include +#include + +namespace turbomind::comm { + +class TCPStore: public gloo::rendezvous::Store { +public: + explicit TCPStore(const std::string& host, int port); + + ~TCPStore(); + + void set(const std::string& key, const std::vector& data) override; + + std::vector get(const std::string& key) override; + + bool check(const std::vector& keys); + + void wait(const std::vector& keys) override + { + wait(keys, std::chrono::seconds(30)); + } + + void wait(const std::vector& keys, const std::chrono::milliseconds& timeout) override; + +private: + std::shared_ptr<::gloo::transport::tcp::Socket> socket_; + std::mutex mutex_; +}; + +} // namespace turbomind::comm From bd5be0fe2e39b0eaf7c211da933ea9aebcaf1cac Mon Sep 17 00:00:00 2001 From: irexyc Date: Wed, 9 Apr 2025 12:13:18 +0000 Subject: [PATCH 03/10] update gateway and support setting devices --- lmdeploy/messages.py | 4 ++ lmdeploy/turbomind/turbomind.py | 37 +++++++++--- src/turbomind/comm/gloo/gloo_comm.cc | 10 +--- src/turbomind/engine/gateway.cc | 6 +- src/turbomind/engine/gateway.h | 9 ++- src/turbomind/engine/request_queue.h | 8 +-- src/turbomind/models/llama/LlamaBatch.cc | 20 ++++++- src/turbomind/models/llama/llama_params.h | 6 ++ .../triton_backend/llama/LlamaTritonModel.cc | 60 +++++++++++++------ 9 files changed, 114 insertions(+), 46 deletions(-) diff --git a/lmdeploy/messages.py b/lmdeploy/messages.py index 02187ea587..11369e25af 100644 --- a/lmdeploy/messages.py +++ b/lmdeploy/messages.py @@ -223,6 +223,10 @@ class TurbomindEngineConfig: mlp_tp_size: int = None mlp_dp_size: int = None outer_dp_size: int = None + nnodes: int = 1 + node_rank: int = 0 + ngpus_per_node: int = None + devices: List[int] = None session_len: Optional[int] = None max_batch_size: int = None cache_max_entry_count: float = 0.8 diff --git a/lmdeploy/turbomind/turbomind.py b/lmdeploy/turbomind/turbomind.py index 8d43923109..c51d5eeaa3 100644 --- a/lmdeploy/turbomind/turbomind.py +++ b/lmdeploy/turbomind/turbomind.py @@ -4,6 +4,7 @@ import copy import json import math +import os import os.path as osp import sys from collections import defaultdict @@ -81,6 +82,17 @@ def complete_parallel_config(cfg: TurbomindEngineConfig): def update_parallel_config(cfg: TurbomindEngineConfig): + # multi-node, use torchrun environment variables + if cfg.nnodes > 1: + rank = int(os.environ['RANK']) + local_rank = int(os.environ['LOCAL_RANK']) + local_world_size = int(os.environ['LOCAL_WORLD_SIZE']) + assert local_rank == 0, 'only support init engine on local_rank 0' + cfg.node_rank = rank // local_world_size + cfg.ngpus_per_node = cfg.ngpus_per_node or local_world_size + cfg.device_num = cfg.ngpus_per_node * cfg.nnodes + cfg.devices = cfg.devices or list(range(cfg.ngpus_per_node)) + if not complete_parallel_config(cfg): total = cfg.dp * cfg.tp if not cfg.device_num: @@ -101,6 +113,12 @@ def update_parallel_config(cfg: TurbomindEngineConfig): assert cfg.attn_dp_size * cfg.attn_tp_size == cfg.mlp_dp_size * cfg.mlp_tp_size assert cfg.attn_dp_size * cfg.attn_tp_size * cfg.outer_dp_size == cfg.device_num + # update devices + if cfg.nnodes == 1: + cfg.devices = cfg.devices if cfg.devices else list(range(cfg.device_num)) + cfg.ngpus_per_node = cfg.ngpus_per_node or len(cfg.devices) + # for simplicity, each node has dp + assert cfg.outer_dp_size * cfg.attn_dp_size % cfg.nnodes == 0 class TurboMind: """LMDeploy's inference engine. 
@@ -137,8 +155,13 @@ def __init__(self, f' greater than 0, but got {_engine_config.max_batch_size}' update_parallel_config(_engine_config) + if _engine_config.nnodes > 1 and _engine_config.node_rank == 0: + from torch.distributed import TCPStore + master_addr = os.environ.get('LMDEPLOY_DP_MASTER_ADDR') + master_port = os.environ.get('LMDEPLOY_DP_MASTER_PORT') + self.store = TCPStore(host_name=master_addr, port=int(master_port), is_master=True) - self.gpu_count = _engine_config.device_num + self.gpu_count = len(_engine_config.devices) self.tokenizer = tokenizer if model_source == ModelSource.WORKSPACE: @@ -164,9 +187,8 @@ def _create_weight(self, model_comm): """Allocate weight buffer, load params if from_workspace.""" # TODO: support mpi - self.node_id = 0 - self.node_num = 1 - torch.cuda.synchronize() + engine_cfg = self.config_dict['engine_config'] + self.node_id = engine_cfg['node_rank'] # create weight def _create_weight_func(device_id): @@ -321,6 +343,8 @@ def from_pretrained(cls, def close(self): if self.model_comm is not None: self.model_comm = None + if hasattr(self, 'store'): + del self.store def create_instance(self, cuda_stream_id=0): """Create a turbomind instance. @@ -427,11 +451,6 @@ def __init__(self, tm_model: TurboMind, config: TurbomindModelConfig, cuda_strea self.tm_model = tm_model self.cuda_stream_id = cuda_stream_id - self.node_id = tm_model.node_id - self.gpu_count = tm_model.gpu_count - - self.session_len = tm_model.session_len - # create model instances self.model_inst = self._create_model_instance(0) diff --git a/src/turbomind/comm/gloo/gloo_comm.cc b/src/turbomind/comm/gloo/gloo_comm.cc index 97a520c146..d19101383c 100644 --- a/src/turbomind/comm/gloo/gloo_comm.cc +++ b/src/turbomind/comm/gloo/gloo_comm.cc @@ -76,14 +76,8 @@ class GlobalStoreFactory { { std::lock_guard lock(mutex_); - // TODO: use param instead of env - auto get = [](const std::string& key, const std::string& default_value) { - const char* value = std::getenv(key.c_str()); - return value ? 
std::string(value) : default_value; - }; - - std::string host = get("STORE_ADDR", "127.0.0.1"); - int port = std::stoi(get("STORE_PORT", "6800")); + std::string host = std::getenv("LMDEPLOY_DP_MASTER_ADDR"); + int port = std::stoi(std::getenv("LMDEPLOY_DP_MASTER_PORT")); std::stringstream ss; ss << host << STORE_INFO_DELIM << port << STORE_INFO_DELIM << prefix_++; diff --git a/src/turbomind/engine/gateway.cc b/src/turbomind/engine/gateway.cc index 3dd8c4b4cb..ff7846bff7 100644 --- a/src/turbomind/engine/gateway.cc +++ b/src/turbomind/engine/gateway.cc @@ -7,9 +7,13 @@ namespace turbomind { -Gateway::Gateway(int groups, int group_size, std::function()> ctx_factory): +Gateway::Gateway(int groups, + int group_size, + std::vector node_dp_ranks, + std::function()> ctx_factory): size_{groups * group_size}, group_size_{group_size}, + node_dp_ranks_{std::move(node_dp_ranks)}, queues_(size_), flags_(groups), ctx_factory_{ctx_factory}, diff --git a/src/turbomind/engine/gateway.h b/src/turbomind/engine/gateway.h index 8350822046..12a8b6b6dd 100644 --- a/src/turbomind/engine/gateway.h +++ b/src/turbomind/engine/gateway.h @@ -60,7 +60,10 @@ class SeqId2Rank { class Gateway { public: - Gateway(int groups, int group_size, std::function()> ctx_factory); + Gateway(int groups, + int group_size, + std::vector node_dp_ranks, + std::function()> ctx_factory); void shutdown(); @@ -73,7 +76,8 @@ class Gateway { rank = seqid2rank_.find(r->session.id); } else { - rank = next_.fetch_add(1, std::memory_order_relaxed) % size_; + rank = next_.fetch_add(1, std::memory_order_relaxed) % node_dp_ranks_.size(); + rank = node_dp_ranks_[rank]; } if (rank >= 0) { @@ -188,6 +192,7 @@ class Gateway { std::vector> queues_; std::vector>> flags_; + std::vector node_dp_ranks_; std::function()> ctx_factory_; diff --git a/src/turbomind/engine/request_queue.h b/src/turbomind/engine/request_queue.h index 590578bf8a..4d6ee641b9 100644 --- a/src/turbomind/engine/request_queue.h +++ b/src/turbomind/engine/request_queue.h @@ -78,10 +78,10 @@ class RequestQueue { || flag_->load(std::memory_order_relaxed) == expected_ // || closed_; }); - if (closed_) { - abort = true; - return false; - } + } + if (closed_) { + abort = true; + return false; } bool is_first = false; diff --git a/src/turbomind/models/llama/LlamaBatch.cc b/src/turbomind/models/llama/LlamaBatch.cc index a3220cd99a..bffc10d052 100644 --- a/src/turbomind/models/llama/LlamaBatch.cc +++ b/src/turbomind/models/llama/LlamaBatch.cc @@ -1716,8 +1716,17 @@ void LlamaBatch::InternalThreadEntry() NvtxScope _("pop"); const int free_slot_count = max_batch_size_ - state_->size + g.finished_count; const bool is_empty = (free_slot_count == max_batch_size_); - // Block if batch is empty AND no silbings are ready - gateway_->pop(req->infer, req->kill, free_slot_count, is_empty, req->abort, dp_rank_); + // Block if batch is empty AND no silbings are ready AND comm in same node + const bool blocking = is_empty && comm_.h_comm->is_same_process(); + int wait = 0; + do { + gateway_->pop(req->infer, req->kill, free_slot_count, blocking, req->abort, dp_rank_); + if (!comm_.h_comm->is_same_process()) { + bool empty_pop = req->infer.size() == 0 && req->kill.size() == 0 && req->abort == false; + wait = is_empty && empty_pop; + wait = AllReduce(comm_.h_comm, wait, comm::RedOp::kSum) == comm_.h_comm->n_ranks(); + } + } while (wait); } // Mark reqs to the same session_id as invalid (which are dangerous to the engine) DisableInvalidRequests(req->infer, req->kill); @@ -1730,8 +1739,13 @@ void 
LlamaBatch::InternalThreadEntry() // 2. Broadcast `ec` from rank-0 // shared_state_->barrier->wait(); // comm_.h_comm->Sync(comm_.h_comm_tp_group); + if (comm_.h_tp_group->n_ranks() > 1) { + Broadcast(comm_.h_tp_group, req, 0); + } - Broadcast(comm_.h_tp_group, req, 0); + if (!comm_.h_comm->is_same_process()) { + req->abort = AllReduce(comm_.h_comm, (int)req->abort, comm::RedOp::kSum) > 0; + } if (req->abort) { TM_LOG_INFO("[InternalThreadEntry] stop requested."); diff --git a/src/turbomind/models/llama/llama_params.h b/src/turbomind/models/llama/llama_params.h index 3ece4940e5..2af2be5802 100644 --- a/src/turbomind/models/llama/llama_params.h +++ b/src/turbomind/models/llama/llama_params.h @@ -94,6 +94,12 @@ struct EngineParam { int attn_tp_rank; int mlp_tp_size; int mlp_tp_rank; + + // multi-node + int nnodes; + int node_rank; + int ngpus_per_node; + std::vector devices; }; enum class LoraPolicy : int diff --git a/src/turbomind/triton_backend/llama/LlamaTritonModel.cc b/src/turbomind/triton_backend/llama/LlamaTritonModel.cc index 36e55f99c2..096c400bbe 100644 --- a/src/turbomind/triton_backend/llama/LlamaTritonModel.cc +++ b/src/turbomind/triton_backend/llama/LlamaTritonModel.cc @@ -228,10 +228,10 @@ LlamaTritonModel::~LlamaTritonModel() for (int device_id = 0; device_id < (int)engines_.size(); ++device_id) { // Set device id before destructing CUDA resources - check_cuda_error(cudaSetDevice(device_id)); + check_cuda_error(cudaSetDevice(engine_param_.devices[device_id])); engines_[device_id].reset(); weights_[device_id].reset(); - trim_default_mempool(device_id); + trim_default_mempool(engine_param_.devices[device_id]); } } @@ -239,7 +239,7 @@ template LlamaTritonModel::LlamaTritonModel(std::string model_dir, std::string config, std::function()> ffi_ctx_factory): - model_param_{}, attn_param_{}, moe_param_{}, lora_param_{}, engine_param_{}, weights_(getDeviceCount()) + model_param_{}, attn_param_{}, moe_param_{}, lora_param_{}, engine_param_{} { FT_CHECK_WITH_INFO(!(config.empty() && model_dir.empty()), "invalid init options"); @@ -297,6 +297,13 @@ LlamaTritonModel::LlamaTritonModel(std::string mod // rotary embedding parameters parse_rope_param(attention_reader["rope_param"], attn_param_.rope); + // multi-node information + engine_param_.nnodes = engine_reader["nnodes"].as(); + engine_param_.node_rank = engine_reader["node_rank"].as(); + engine_param_.devices = engine_reader["devices"].as>(); + engine_param_.ngpus_per_node = engine_reader["ngpus_per_node"].as(); + FT_CHECK(engine_param_.devices.size() == engine_param_.ngpus_per_node); + engine_param_.max_batch_size = engine_reader["max_batch_size"].as(0); engine_param_.max_prefill_token_num = engine_reader["max_prefill_token_num"].as(0); engine_param_.max_context_token_num = engine_reader["max_context_token_num"].as(0); @@ -347,10 +354,8 @@ LlamaTritonModel::LlamaTritonModel(std::string mod handleMissingParams(); - gateway_ = std::make_shared(engine_param_.outer_dp_size, engine_param_.attn_dp_size, ffi_ctx_factory); - - const auto device_count = getDeviceCount(); - engines_.resize(device_count); + weights_.resize(engine_param_.ngpus_per_node); + engines_.resize(engine_param_.ngpus_per_node); const std::string weight_type_str = model_reader["weight_type"].as(); if (weight_type_str == "fp16" || weight_type_str == "float16") { @@ -383,7 +388,10 @@ LlamaTritonModel::LlamaTritonModel(std::string mod // NOTE: This runs on Python main thread group_ids_.resize(engine_param_.outer_dp_size); for (size_t i = 0; i < group_ids_.size(); ++i) { - 
group_ids_[i] = comm::CreateHostGroupId(""); + // TODO: fine-grained comm control + const std::string group_backend = (comm_size_ <= engine_param_.ngpus_per_node) ? "" : "gloo"; + + group_ids_[i] = comm::CreateHostGroupId(group_backend); group_ids_[i]->Initialize(); } @@ -398,13 +406,26 @@ LlamaTritonModel::LlamaTritonModel(std::string mod e.mlp_tp_rank = i % comm_size_; } + std::vector node_dp_ranks; + for (int local_rank = 0, offset = engine_param_.ngpus_per_node * engine_param_.node_rank; + local_rank < engine_param_.ngpus_per_node; + ++local_rank) { + auto& e = engine_params_[offset + local_rank]; + if (e.attn_tp_rank == 0) { + node_dp_ranks.push_back(e.outer_dp_rank * e.attn_dp_size + e.attn_dp_rank); + } + } + + gateway_ = std::make_shared( + engine_param_.outer_dp_size, engine_param_.attn_dp_size, std::move(node_dp_ranks), ffi_ctx_factory); + TM_LOG_INFO("%s", toString().c_str()); } template std::unique_ptr LlamaTritonModel::createModelInstance(int device_id) { - check_cuda_error(cudaSetDevice(device_id)); + check_cuda_error(cudaSetDevice(engine_param_.devices[device_id])); FT_CHECK(engines_[device_id] != nullptr); @@ -418,8 +439,9 @@ std::unique_ptr LlamaTritonModel::createModelInstance(int devic template void LlamaTritonModel::createSharedWeights(int device_id, int rank) noexcept { - check_cuda_error(cudaSetDevice(device_id)); - weights_[rank] = std::make_shared>(model_param_, engine_params_.at(rank), lora_param_, moe_param_); + check_cuda_error(cudaSetDevice(engine_param_.devices[device_id])); + weights_[rank % engine_param_.ngpus_per_node] = + std::make_shared>(model_param_, engine_params_.at(rank), lora_param_, moe_param_); // model inited with model_dir if (model_dir_ != "") { weights_[device_id]->loadModel(model_dir_); @@ -429,12 +451,12 @@ void LlamaTritonModel::createSharedWeights(int device_id, int rank) noexcept template std::unordered_map LlamaTritonModel::getParams(int device_id, int rank) noexcept { - check_cuda_error(cudaSetDevice(device_id)); + check_cuda_error(cudaSetDevice(engine_param_.devices[device_id])); // shared_weight should be created before getParams - FT_CHECK(weights_[rank] != nullptr); + FT_CHECK(weights_[rank % engine_param_.ngpus_per_node] != nullptr); - TensorMap output = weights_[rank]->getParams(); + TensorMap output = weights_[rank % engine_param_.ngpus_per_node]->getParams(); std::unordered_map result; for (auto [name, tensor] : output) { @@ -447,11 +469,11 @@ std::unordered_map LlamaTritonModel::getParams(int devic template void LlamaTritonModel::processWeights(int device_id, int rank) noexcept { - check_cuda_error(cudaSetDevice(device_id)); + check_cuda_error(cudaSetDevice(engine_param_.devices[device_id])); FT_CHECK(weights_[device_id] != nullptr); cudaDeviceProp props{}; - check_cuda_error(cudaGetDeviceProperties(&props, device_id)); + check_cuda_error(cudaGetDeviceProperties(&props, engine_param_.devices[device_id])); weights_[device_id]->prepare(props); sync_check_cuda_error(); @@ -485,9 +507,9 @@ Communicators LlamaTritonModel::createCommSplits(int rank) template void LlamaTritonModel::createEngine(int device_id, int rank) { - check_cuda_error(cudaSetDevice(device_id)); + check_cuda_error(cudaSetDevice(engine_param_.devices[device_id])); - auto ctx = std::make_unique>(device_id); + auto ctx = std::make_unique>(engine_param_.devices[device_id]); ctx->comm = createCommSplits(rank); @@ -515,7 +537,7 @@ void LlamaTritonModel::createEngine(int device_id, int rank) std::move(model), std::move(ctx), gateway_, - device_id, + 
engine_param_.devices[device_id], dp_rank); } catch (const std::exception& e) { From ecc2623080a781e2b98a0e23746c1bcd027ff625 Mon Sep 17 00:00:00 2001 From: irexyc Date: Wed, 9 Apr 2025 12:32:54 +0000 Subject: [PATCH 04/10] fix build --- lmdeploy/turbomind/turbomind.py | 1 + src/turbomind/models/llama/LlamaBatch.cc | 3 +++ 2 files changed, 4 insertions(+) diff --git a/lmdeploy/turbomind/turbomind.py b/lmdeploy/turbomind/turbomind.py index c51d5eeaa3..f4deee9a60 100644 --- a/lmdeploy/turbomind/turbomind.py +++ b/lmdeploy/turbomind/turbomind.py @@ -120,6 +120,7 @@ def update_parallel_config(cfg: TurbomindEngineConfig): # for simplicity, each node has dp assert cfg.outer_dp_size * cfg.attn_dp_size % cfg.nnodes == 0 + class TurboMind: """LMDeploy's inference engine. diff --git a/src/turbomind/models/llama/LlamaBatch.cc b/src/turbomind/models/llama/LlamaBatch.cc index bffc10d052..3ec37026e6 100644 --- a/src/turbomind/models/llama/LlamaBatch.cc +++ b/src/turbomind/models/llama/LlamaBatch.cc @@ -1630,6 +1630,7 @@ struct RequestData { } // namespace +#ifdef BUILD_MULTI_GPU namespace comm { void serialize(std::ostream& os, const RequestData& req) @@ -1695,6 +1696,8 @@ void deserialize(std::shared_ptr* req, int n, const std::vector void LlamaBatch::InternalThreadEntry() { From edb4dfe9fe04c22836e64b10f917e16847358696 Mon Sep 17 00:00:00 2001 From: irexyc Date: Thu, 10 Apr 2025 12:02:08 +0000 Subject: [PATCH 05/10] use tm cfg instead of env --- lmdeploy/cli/serve.py | 11 ++++++++++- lmdeploy/cli/utils.py | 18 ++++++++++++++++++ lmdeploy/turbomind/turbomind.py | 17 ++++++++++------- 3 files changed, 38 insertions(+), 8 deletions(-) diff --git a/lmdeploy/cli/serve.py b/lmdeploy/cli/serve.py index 76d27b6727..43e3006b11 100644 --- a/lmdeploy/cli/serve.py +++ b/lmdeploy/cli/serve.py @@ -163,14 +163,18 @@ def add_parser_api_server(): prefix_caching_act = ArgumentHelper.enable_prefix_caching(pt_group) max_prefill_token_num_act = ArgumentHelper.max_prefill_token_num(pt_group) quant_policy = ArgumentHelper.quant_policy(pt_group) - ArgumentHelper.dp(pt_group) + dp_act = ArgumentHelper.dp(pt_group) ArgumentHelper.dp_rank(pt_group) + # multi-node serving args + ArgumentHelper.node_rank(parser) + ArgumentHelper.num_nodes(parser) # turbomind args tb_group = parser.add_argument_group('TurboMind engine arguments') # common engine args tb_group._group_actions.append(dtype_act) tb_group._group_actions.append(tp_act) + tb_group._group_actions.append(dp_act) tb_group._group_actions.append(session_len_act) tb_group._group_actions.append(max_batch_size_act) tb_group._group_actions.append(cache_max_entry_act) @@ -183,6 +187,7 @@ def add_parser_api_server(): ArgumentHelper.num_tokens_per_iter(tb_group) ArgumentHelper.max_prefill_iters(tb_group) ArgumentHelper.communicator(tb_group) + ArgumentHelper.ngpus_per_node(tb_group) # vlm args vision_group = parser.add_argument_group('Vision model arguments') @@ -310,6 +315,10 @@ def api_server(args): from lmdeploy.messages import TurbomindEngineConfig backend_config = TurbomindEngineConfig(dtype=args.dtype, tp=args.tp, + dp=args.dp, + nnodes=args.nnodes, + ngpus_per_node=args.ngpus_per_node, + node_rank=args.node_rank, max_batch_size=max_batch_size, session_len=args.session_len, model_format=args.model_format, diff --git a/lmdeploy/cli/utils.py b/lmdeploy/cli/utils.py index 8230676057..075280c7e7 100644 --- a/lmdeploy/cli/utils.py +++ b/lmdeploy/cli/utils.py @@ -170,6 +170,24 @@ def dp_rank(parser): default=0, help='data parallelism rank, all ranks between 0 ~ dp should be 
created.') + @staticmethod + def node_rank(parser): + """add argument node_rank to parser.""" + + return parser.add_argument('--node-rank', type=int, default=0, help='The current node rank.') + + @staticmethod + def num_nodes(parser): + """add argument num_nodes to parser.""" + + return parser.add_argument('--nnodes', type=int, default=1, help='The total node nums') + + @staticmethod + def ngpus_per_node(parser): + """add argument ngpus_per_node to parser.""" + + return parser.add_argument('--ngpus-per-node', type=int, default=None, help='The total gpu nums per node') + @staticmethod def session_id(parser): """Add argument session_id to parser.""" diff --git a/lmdeploy/turbomind/turbomind.py b/lmdeploy/turbomind/turbomind.py index f4deee9a60..f4f9578ee7 100644 --- a/lmdeploy/turbomind/turbomind.py +++ b/lmdeploy/turbomind/turbomind.py @@ -82,14 +82,15 @@ def complete_parallel_config(cfg: TurbomindEngineConfig): def update_parallel_config(cfg: TurbomindEngineConfig): - # multi-node, use torchrun environment variables if cfg.nnodes > 1: - rank = int(os.environ['RANK']) - local_rank = int(os.environ['LOCAL_RANK']) - local_world_size = int(os.environ['LOCAL_WORLD_SIZE']) - assert local_rank == 0, 'only support init engine on local_rank 0' - cfg.node_rank = rank // local_world_size - cfg.ngpus_per_node = cfg.ngpus_per_node or local_world_size + # multi-node, use torchrun environment variables + # rank = int(os.environ['RANK']) + # local_rank = int(os.environ['LOCAL_RANK']) + # local_world_size = int(os.environ['LOCAL_WORLD_SIZE']) + # assert local_rank == 0, 'only support init engine on local_rank 0' + # cfg.node_rank = rank // local_world_size + # cfg.ngpus_per_node = cfg.ngpus_per_node or local_world_size + assert cfg.ngpus_per_node is not None cfg.device_num = cfg.ngpus_per_node * cfg.nnodes cfg.devices = cfg.devices or list(range(cfg.ngpus_per_node)) @@ -160,6 +161,8 @@ def __init__(self, from torch.distributed import TCPStore master_addr = os.environ.get('LMDEPLOY_DP_MASTER_ADDR') master_port = os.environ.get('LMDEPLOY_DP_MASTER_PORT') + assert master_addr is not None and master_port is not None, \ + 'LMDEPLOY_DP_MASTER_ADDR and LMDEPLOY_DP_MASTER_PORT should be set when using multi-node' self.store = TCPStore(host_name=master_addr, port=int(master_port), is_master=True) self.gpu_count = len(_engine_config.devices) From 9b8d8a766d710d9ea62aa2b5e23039dad8a705a2 Mon Sep 17 00:00:00 2001 From: irexyc Date: Wed, 23 Apr 2025 12:38:32 +0000 Subject: [PATCH 06/10] fix dp --- src/turbomind/kernels/norm/rms_norm.cu | 8 ++++---- src/turbomind/models/llama/unified_decoder.cc | 8 ++++---- 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/src/turbomind/kernels/norm/rms_norm.cu b/src/turbomind/kernels/norm/rms_norm.cu index ee826c4105..4b6aaba20f 100644 --- a/src/turbomind/kernels/norm/rms_norm.cu +++ b/src/turbomind/kernels/norm/rms_norm.cu @@ -86,15 +86,15 @@ __global__ void RMSNorm(T* dst, void invokeRMSNorm(Tensor& out, const Tensor& x, const Tensor& w, float eps, cudaStream_t st) { + if (x.size() == 0) { + return; + } + TM_CHECK(x.ndim() == 2); TM_CHECK(out.shape() == x.shape()); TM_CHECK(out.dtype() == x.dtype()); TM_CHECK(w.dtype() == x.dtype() && w.shape(-1) == x.shape(-1)); - if (x.size() == 0) { - return; - } - auto invoke = [&](auto t) { using T = decltype(t); diff --git a/src/turbomind/models/llama/unified_decoder.cc b/src/turbomind/models/llama/unified_decoder.cc index c875c7852f..b771f0f00d 100644 --- a/src/turbomind/models/llama/unified_decoder.cc +++ 
b/src/turbomind/models/llama/unified_decoder.cc @@ -59,7 +59,7 @@ void UnifiedDecoder::AllreduceResidualRMSnorm(Tensor& hidden_states, if (0) {} else if (group0 || group1) { d_comm_->AllreduceResidualBiasRMSnormEx(hidden_states.raw_data(), - residual.raw_data(), + residual.data_or((void*)nullptr), bias.data_or((void*)nullptr), weight.raw_data(), rmsnorm_eps_, @@ -73,7 +73,7 @@ void UnifiedDecoder::AllreduceResidualRMSnorm(Tensor& hidden_states, } else if (d_comm_) { d_comm_->AllreduceResidualBiasRMSnorm(hidden_states.raw_data(), - residual.raw_data(), + residual.data_or((void*)nullptr), bias.data_or((void*)nullptr), weight.raw_data(), rmsnorm_eps_, @@ -86,7 +86,7 @@ void UnifiedDecoder::AllreduceResidualRMSnorm(Tensor& hidden_states, } else { invokeResidualBiasRMSNorm(hidden_states.raw_data(), - residual.raw_data(), + residual.data_or((void*)nullptr), weight.raw_data(), bias.data_or((void*)nullptr), dtype, @@ -132,7 +132,7 @@ void UnifiedDecoder::Forward(TensorMap& args, const std::vector& we Tensor local_hidden_states = global_hidden_states; const auto global_token_num = global_hidden_states.shape(0); - const auto local_token_num = local_residual.shape(0); + const auto local_token_num = local_residual.size() ? local_residual.shape(0) : 0; if (attn_dp_size_ > 1) { // Offset hidden states buffer for mixed DP TM_CHECK_EQ(local_token_nums.size(), attn_dp_size_); From 22569cb252fc8ea3233fde201b5fb057ddff1feb Mon Sep 17 00:00:00 2001 From: irexyc Date: Wed, 23 Apr 2025 12:50:11 +0000 Subject: [PATCH 07/10] fix lint --- src/turbomind/comm/host_comm.h | 4 ++-- src/turbomind/comm/serialize.cc | 8 +++++--- 2 files changed, 7 insertions(+), 5 deletions(-) diff --git a/src/turbomind/comm/host_comm.h b/src/turbomind/comm/host_comm.h index 134707a368..af11ba4861 100644 --- a/src/turbomind/comm/host_comm.h +++ b/src/turbomind/comm/host_comm.h @@ -3,15 +3,15 @@ #pragma once #include +#include #include #include #include #include -#include +#include "src/turbomind/comm/serialize.h" #include "src/turbomind/core/data_type.h" #include "src/turbomind/utils/logger.h" -#include "src/turbomind/comm/serialize.h" namespace turbomind::comm { diff --git a/src/turbomind/comm/serialize.cc b/src/turbomind/comm/serialize.cc index 5e044c5458..8441f4be45 100644 --- a/src/turbomind/comm/serialize.cc +++ b/src/turbomind/comm/serialize.cc @@ -108,15 +108,17 @@ void deserialize(std::istream& is, Layout& layout) layout = Layout(std::move(shape), std::move(stride)); } -void serialize(std::ostream& os, const Buffer& buffer) { +void serialize(std::ostream& os, const Buffer& buffer) +{ FT_CHECK(buffer.device() == turbomind::core::Device(kCPU)); serialize(os, buffer.size()); serialize(os, buffer.dtype()); os.write((char*)buffer.raw_data(), buffer.byte_size()); } -void deserialize(std::istream& is, Buffer& buffer) { - ssize_t size; +void deserialize(std::istream& is, Buffer& buffer) +{ + ssize_t size; DataType dtype; deserialize(is, size); deserialize(is, dtype); From eb03cbf57572d092d66be5d942f93d09ebba4119 Mon Sep 17 00:00:00 2001 From: irexyc Date: Thu, 24 Apr 2025 02:54:51 +0000 Subject: [PATCH 08/10] fix build --- src/turbomind/comm/CMakeLists.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/turbomind/comm/CMakeLists.txt b/src/turbomind/comm/CMakeLists.txt index da1d0ba386..73490146df 100644 --- a/src/turbomind/comm/CMakeLists.txt +++ b/src/turbomind/comm/CMakeLists.txt @@ -24,7 +24,7 @@ if (BUILD_MULTI_GPU) target_link_libraries(host_comm INTERFACE gloo_comm) add_library(serialize STATIC serialize.cc) - 
target_link_libraries(serialize INTERFACE core) + target_link_libraries(serialize PRIVATE core) set_property(TARGET serialize PROPERTY POSITION_INDEPENDENT_CODE ON) target_link_libraries(host_comm INTERFACE serialize) From e78156dd95ebdcfdd2b0bd3a4e0160eaa0bd0411 Mon Sep 17 00:00:00 2001 From: irexyc Date: Fri, 11 Jul 2025 06:23:34 +0000 Subject: [PATCH 09/10] fix ci --- lmdeploy/cli/serve.py | 8 ++++---- lmdeploy/messages.py | 2 +- lmdeploy/turbomind/turbomind.py | 1 - src/turbomind/comm/gloo/CMakeLists.txt | 11 +---------- src/turbomind/comm/gloo/gloo_comm.cc | 6 ------ src/turbomind/models/llama/llama_params.h | 6 ++++-- 6 files changed, 10 insertions(+), 24 deletions(-) diff --git a/lmdeploy/cli/serve.py b/lmdeploy/cli/serve.py index 36ed5cfeb9..8b8f01b292 100644 --- a/lmdeploy/cli/serve.py +++ b/lmdeploy/cli/serve.py @@ -166,6 +166,7 @@ def add_parser_api_server(): quant_policy = ArgumentHelper.quant_policy(pt_group) model_format = ArgumentHelper.model_format(pt_group) dp_act = ArgumentHelper.dp(pt_group) + num_nodes_act = ArgumentHelper.num_nodes(pt_group) ArgumentHelper.ep(pt_group) ArgumentHelper.enable_microbatch(pt_group) ArgumentHelper.enable_eplb(pt_group) @@ -173,8 +174,7 @@ def add_parser_api_server(): ArgumentHelper.role(pt_group) ArgumentHelper.migration_backend(pt_group) # multi-node serving args - ArgumentHelper.node_rank(parser) - ArgumentHelper.num_nodes(parser) + node_rank_act = ArgumentHelper.node_rank(pt_group) # turbomind args tb_group = parser.add_argument_group('TurboMind engine arguments') @@ -190,12 +190,12 @@ def add_parser_api_server(): tb_group._group_actions.append(max_prefill_token_num_act) tb_group._group_actions.append(quant_policy) tb_group._group_actions.append(model_format) + tb_group._group_actions.append(num_nodes_act) + tb_group._group_actions.append(node_rank_act) ArgumentHelper.rope_scaling_factor(tb_group) ArgumentHelper.num_tokens_per_iter(tb_group) ArgumentHelper.max_prefill_iters(tb_group) ArgumentHelper.communicator(tb_group) - ArgumentHelper.num_nodes(tb_group) - ArgumentHelper.node_rank(tb_group) ArgumentHelper.ngpus_per_node(tb_group) # vlm args diff --git a/lmdeploy/messages.py b/lmdeploy/messages.py index 483a325d06..550f3adcd8 100644 --- a/lmdeploy/messages.py +++ b/lmdeploy/messages.py @@ -237,7 +237,7 @@ class TurbomindEngineConfig: outer_dp_size: int = None nnodes: int = 1 node_rank: int = 0 - ngpus_per_node: int = None + ngpus_per_node: Optional[int] = None devices: List[int] = None session_len: Optional[int] = None max_batch_size: int = None diff --git a/lmdeploy/turbomind/turbomind.py b/lmdeploy/turbomind/turbomind.py index 92c7c9a717..18146c4c5c 100644 --- a/lmdeploy/turbomind/turbomind.py +++ b/lmdeploy/turbomind/turbomind.py @@ -217,7 +217,6 @@ def _create_engine(self): def _create_weight(self, model_comm): """Allocate weight buffer, load params if from_workspace.""" - # TODO: support mpi engine_cfg = self.config_dict['engine_config'] self.node_id = engine_cfg['node_rank'] diff --git a/src/turbomind/comm/gloo/CMakeLists.txt b/src/turbomind/comm/gloo/CMakeLists.txt index 52dc579324..3e8f8255c7 100644 --- a/src/turbomind/comm/gloo/CMakeLists.txt +++ b/src/turbomind/comm/gloo/CMakeLists.txt @@ -4,15 +4,13 @@ cmake_minimum_required(VERSION 3.8) include(FetchContent) FetchContent_Declare( gloo - GIT_REPOSITORY https://github.com/facebookincubator/gloo.git + GIT_REPOSITORY https://github.com/pytorch/gloo.git GIT_TAG cbe963b5a43cd75e6eca4f74d2bb38ec8dcfdbc8 ) # some settings of gloo, set(GLOO_INSTALL OFF CACHE BOOL "" FORCE) 
set(GLOO_STATIC_OR_SHARED STATIC CACHE STRING "" FORCE) -set(__USE_NCCL ${USE_NCCL}) -set(__BUILD_TEST ${BUILD_TEST}) set(USE_NCCL OFF) set(BUILD_TEST OFF) # set(USE_REDIS ON) # TODO remove, use tcp_store instead @@ -23,8 +21,6 @@ target_include_directories(gloo PUBLIC $ $ # config.h generated at cmake config time ) -set(USE_NCCL ${__USE_NCCL}) -set(BUILD_TEST ${__BUILD_TEST}) add_library(gloo_comm STATIC gloo_comm.cc @@ -32,8 +28,3 @@ add_library(gloo_comm STATIC ) set_property(TARGET gloo_comm PROPERTY POSITION_INDEPENDENT_CODE ON) target_link_libraries(gloo_comm PUBLIC gloo logger) - -if(USE_REDIS) - include_directories(SYSTEM ${HIREDIS_INCLUDE_DIRS}) - target_link_libraries(gloo_comm PRIVATE ${HIREDIS_LIBRARIES}) -endif() diff --git a/src/turbomind/comm/gloo/gloo_comm.cc b/src/turbomind/comm/gloo/gloo_comm.cc index 31f6b990f2..fb0f61b69f 100644 --- a/src/turbomind/comm/gloo/gloo_comm.cc +++ b/src/turbomind/comm/gloo/gloo_comm.cc @@ -12,9 +12,6 @@ #include #include #include -#if GLOO_USE_REDIS -#include -#endif #include #include #include @@ -57,9 +54,6 @@ class Store: public ::gloo::rendezvous::PrefixStore { std::string host_; int port_; -#if GLOO_USE_REDIS -// ::gloo::rendezvous::RedisStore store_; -#endif TCPStore store_; using ::gloo::rendezvous::PrefixStore::prefix_; }; diff --git a/src/turbomind/models/llama/llama_params.h b/src/turbomind/models/llama/llama_params.h index 6f4071d294..ae2d638ce9 100644 --- a/src/turbomind/models/llama/llama_params.h +++ b/src/turbomind/models/llama/llama_params.h @@ -41,7 +41,8 @@ struct ModelParam { }; struct MoeParam { - enum Method { + enum Method + { kNaive, kFused } method; @@ -104,7 +105,8 @@ struct EngineParam { std::vector devices; }; -enum class LoraPolicy : int { +enum class LoraPolicy : int +{ kNull, kPlora, }; From e06f256bbe53ee695d2209e49faf3e7d00c0fb61 Mon Sep 17 00:00:00 2001 From: irexyc Date: Fri, 11 Jul 2025 09:48:18 +0000 Subject: [PATCH 10/10] update gloo version to match pytroch/v2.8.0-rc4 --- src/turbomind/comm/gloo/CMakeLists.txt | 3 +-- src/turbomind/comm/gloo/gloo_comm.cc | 9 ++++++--- src/turbomind/comm/gloo/tcp_store.cc | 5 ++++- 3 files changed, 11 insertions(+), 6 deletions(-) diff --git a/src/turbomind/comm/gloo/CMakeLists.txt b/src/turbomind/comm/gloo/CMakeLists.txt index 3e8f8255c7..cb3bf80278 100644 --- a/src/turbomind/comm/gloo/CMakeLists.txt +++ b/src/turbomind/comm/gloo/CMakeLists.txt @@ -5,7 +5,7 @@ include(FetchContent) FetchContent_Declare( gloo GIT_REPOSITORY https://github.com/pytorch/gloo.git - GIT_TAG cbe963b5a43cd75e6eca4f74d2bb38ec8dcfdbc8 + GIT_TAG c7b7b022c124d9643957d9bd55f57ac59fce8fa2 # pytorch-v2.8.0-rc4 ) # some settings of gloo, @@ -13,7 +13,6 @@ set(GLOO_INSTALL OFF CACHE BOOL "" FORCE) set(GLOO_STATIC_OR_SHARED STATIC CACHE STRING "" FORCE) set(USE_NCCL OFF) set(BUILD_TEST OFF) -# set(USE_REDIS ON) # TODO remove, use tcp_store instead FetchContent_MakeAvailable(gloo) # gloo build doesn't add include directories as a target property... 
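Note on the gloo_comm.cc hunks below: at the pinned gloo tag, PrefixStore keeps its backing store behind a std::shared_ptr and connectFullMesh takes the rendezvous store as a shared pointer as well, which is why Store now hands a heap-allocated TCPStore to its base class instead of holding one by value. The sketch that follows is illustrative only and is not part of the patch; it assumes gloo's public rendezvous API at that tag, uses gloo's built-in FileStore as a stand-in for the TCPStore added in this series, and the names Bootstrap, "tm", "eth0" and "/tmp/gloo" are made-up placeholders.

    // Illustrative rendezvous flow (sketch, not part of the patch): build a TCP
    // device, wrap a key/value store in a PrefixStore so several communicators can
    // share one backing store, then connect the full mesh of ranks.
    #include <memory>
    #include <gloo/rendezvous/context.h>
    #include <gloo/rendezvous/file_store.h>
    #include <gloo/rendezvous/prefix_store.h>
    #include <gloo/transport/tcp/device.h>

    std::shared_ptr<gloo::Context> Bootstrap(int rank, int n_ranks)
    {
        gloo::transport::tcp::attr attr;
        attr.iface = "eth0";  // placeholder; createGlooDevice() derives this from GLOO_SOCKET_IFNAME
        auto device = gloo::transport::tcp::CreateDevice(attr);

        // Key/value store shared by all ranks; the prefix namespaces the keys.
        auto store  = std::make_shared<gloo::rendezvous::FileStore>("/tmp/gloo");
        auto prefix = std::make_shared<gloo::rendezvous::PrefixStore>("tm", store);

        auto context = std::make_shared<gloo::rendezvous::Context>(rank, n_ranks);
        context->connectFullMesh(prefix, device);  // pairwise TCP connections between all ranks
        return context;
    }

Collectives created on the resulting context reuse those connections, which is what GlooCommImpl relies on after its connectFullMesh call.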
diff --git a/src/turbomind/comm/gloo/gloo_comm.cc b/src/turbomind/comm/gloo/gloo_comm.cc index fb0f61b69f..4f39c9c118 100644 --- a/src/turbomind/comm/gloo/gloo_comm.cc +++ b/src/turbomind/comm/gloo/gloo_comm.cc @@ -40,7 +40,10 @@ std::shared_ptr<::gloo::transport::Device> createGlooDevice() class Store: public ::gloo::rendezvous::PrefixStore { public: explicit Store(const std::string& host, int port, const std::string& prefix): - host_(host), port_(port), store_(host_, port_), ::gloo::rendezvous::PrefixStore(prefix, store_){}; + host_(host), port_(port), ::gloo::rendezvous::PrefixStore(prefix, nullptr) + { + store_ = std::make_shared(host_, port_); + }; ~Store() = default; @@ -54,7 +57,7 @@ class Store: public ::gloo::rendezvous::PrefixStore { std::string host_; int port_; - TCPStore store_; + using ::gloo::rendezvous::PrefixStore::store_; using ::gloo::rendezvous::PrefixStore::prefix_; }; @@ -127,7 +130,7 @@ struct GlooCommImpl: public HostCommImpl { // TM_LOG_INFO("[GlooCommImpl] rank=%d, n_ranks=%d, prefix=%s", rank_, n_ranks_, store_->prefix_.c_str()); device_ = createGlooDevice(); context_ = std::make_shared<::gloo::rendezvous::Context>(rank_, n_ranks_); - context_->connectFullMesh(*store_, device_); + context_->connectFullMesh(store_, device_); } ~GlooCommImpl() {} diff --git a/src/turbomind/comm/gloo/tcp_store.cc b/src/turbomind/comm/gloo/tcp_store.cc index 9a2e74351e..54706aba0f 100644 --- a/src/turbomind/comm/gloo/tcp_store.cc +++ b/src/turbomind/comm/gloo/tcp_store.cc @@ -15,7 +15,7 @@ namespace turbomind::comm { namespace { -// copy from pytorch https://github.com/pytorch/pytorch/blob/v2.5.1/torch/csrc/distributed/c10d/TCPStoreBackend.hpp +// copy from pytorch https://github.com/pytorch/pytorch/blob/v2.8.0-rc4/torch/csrc/distributed/c10d/TCPStoreBackend.hpp static const uint32_t validationMagicNumber = 0x3C85F7CE; @@ -41,6 +41,9 @@ enum class QueryType : uint8_t MULTI_SET, CANCEL_WAIT, PING, + QUEUE_PUSH, + QUEUE_POP, + QUEUE_LEN, }; } // namespace
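The QueryType additions above keep the enum in step with pytorch v2.8.0-rc4's TCPStoreBackend.hpp; the values are exchanged as single bytes, so client and server must agree on the enumerator order. Below is a small defensive decoder in that spirit. It is a sketch only, not part of the patch: the enum is an abridged stand-in for the one in tcp_store.cc (only the tail is reproduced, so the numeric values differ from the real ones), ValidateQueryType is a hypothetical helper, and the check relies only on QUEUE_LEN being the last enumerator.

    // Sketch: map a received command byte back to QueryType, rejecting values
    // outside the known range instead of casting them blindly.
    #include <cstdint>
    #include <optional>

    enum class QueryType : uint8_t
    {
        // ... earlier members as in tcp_store.cc ...
        MULTI_SET,
        CANCEL_WAIT,
        PING,
        QUEUE_PUSH,
        QUEUE_POP,
        QUEUE_LEN,
    };

    std::optional<QueryType> ValidateQueryType(uint8_t raw)
    {
        if (raw > static_cast<uint8_t>(QueryType::QUEUE_LEN)) {
            return std::nullopt;  // unknown command byte; drop or fail the request
        }
        return static_cast<QueryType>(raw);
    }

If the pinned pytorch version changes again, extending the enum at the end (as this patch does) keeps existing byte values stable.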