diff --git a/be/src/exec/CMakeLists.txt b/be/src/exec/CMakeLists.txt index 84387d24733700..58468185156987 100644 --- a/be/src/exec/CMakeLists.txt +++ b/be/src/exec/CMakeLists.txt @@ -312,6 +312,7 @@ set(EXEC_FILES workgroup/pipeline_executor_set.cpp workgroup/pipeline_executor_set_manager.cpp workgroup/work_group.cpp + workgroup/mem_tracker_manager.cpp workgroup/scan_executor.cpp workgroup/scan_task_queue.cpp query_cache/multilane_operator.cpp diff --git a/be/src/exec/workgroup/mem_tracker_manager.cpp b/be/src/exec/workgroup/mem_tracker_manager.cpp new file mode 100644 index 00000000000000..37b5e9d9f100cb --- /dev/null +++ b/be/src/exec/workgroup/mem_tracker_manager.cpp @@ -0,0 +1,54 @@ +// Copyright 2021-present StarRocks, Inc. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "mem_tracker_manager.h" + +#include + +#include "runtime/exec_env.h" +#include "runtime/mem_tracker.h" +#include "work_group.h" + +namespace starrocks::workgroup { +MemTrackerManager::MemTrackerPtr MemTrackerManager::get_parent_mem_tracker(const WorkGroupPtr& wg) { + if (WorkGroup::DEFAULT_MEM_POOL == wg->mem_pool()) { + return GlobalEnv::GetInstance()->query_pool_mem_tracker_shared(); + } + + const double mem_limit_fraction = wg->mem_limit(); + const int64_t memory_limit_bytes = + static_cast(GlobalEnv::GetInstance()->query_pool_mem_tracker()->limit() * mem_limit_fraction); + + // Frontend (FE) validation ensures that active resource groups (RGs) sharing + // the same mem_pool also have the same mem_limit. + if (_shared_mem_trackers.contains(wg->mem_pool())) { + // We must handle an edge case: + // 1. All RGs using a specific mem_pool are deleted. + // 2. The shared tracker for that pool remains cached here. + // 3. A new RG is created with the same mem_pool name but a different mem_limit. + // Therefore, we must verify the cached tracker's limit matches the current RG's limit. + if (auto& shared_mem_tracker = _shared_mem_trackers.at(wg->mem_pool()); + shared_mem_tracker->limit() == memory_limit_bytes) { + return shared_mem_tracker; + } + } + + auto shared_mem_tracker = + std::make_shared(MemTrackerType::RESOURCE_GROUP_SHARED_MEMORY_POOL, memory_limit_bytes, + wg->mem_pool(), GlobalEnv::GetInstance()->query_pool_mem_tracker()); + + _shared_mem_trackers.insert_or_assign(wg->mem_pool(), shared_mem_tracker); + return shared_mem_tracker; +} +} // namespace starrocks::workgroup \ No newline at end of file diff --git a/be/src/exec/workgroup/mem_tracker_manager.h b/be/src/exec/workgroup/mem_tracker_manager.h new file mode 100644 index 00000000000000..ab781a81524aa7 --- /dev/null +++ b/be/src/exec/workgroup/mem_tracker_manager.h @@ -0,0 +1,31 @@ +// Copyright 2021-present StarRocks, Inc. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include + +#include "runtime/mem_tracker.h" +#include "work_group_fwd.h" + +namespace starrocks::workgroup { +struct MemTrackerManager { +public: + using MemTrackerPtr = std::shared_ptr; + MemTrackerPtr get_parent_mem_tracker(const WorkGroupPtr& wg); + +private: + std::unordered_map _shared_mem_trackers{}; +}; +} // namespace starrocks::workgroup diff --git a/be/src/exec/workgroup/work_group.cpp b/be/src/exec/workgroup/work_group.cpp index 27761fe560bac1..610c1515c713e1 100644 --- a/be/src/exec/workgroup/work_group.cpp +++ b/be/src/exec/workgroup/work_group.cpp @@ -18,10 +18,12 @@ #include "common/config.h" #include "exec/pipeline/pipeline_driver_executor.h" +#include "exec/workgroup/mem_tracker_manager.h" #include "exec/workgroup/pipeline_executor_set.h" #include "exec/workgroup/scan_task_queue.h" #include "glog/logging.h" #include "runtime/exec_env.h" +#include "runtime/mem_tracker.h" #include "util/cpu_info.h" #include "util/metrics.h" #include "util/starrocks_metrics.h" @@ -96,7 +98,7 @@ RunningQueryToken::~RunningQueryToken() { } WorkGroup::WorkGroup(std::string name, int64_t id, int64_t version, size_t cpu_limit, double memory_limit, - size_t concurrency, double spill_mem_limit_threshold, WorkGroupType type) + size_t concurrency, double spill_mem_limit_threshold, WorkGroupType type, std::string mem_pool) : _name(std::move(name)), _id(id), _version(version), @@ -105,6 +107,7 @@ WorkGroup::WorkGroup(std::string name, int64_t id, int64_t version, size_t cpu_l _memory_limit(memory_limit), _concurrency_limit(concurrency), _spill_mem_limit_threshold(spill_mem_limit_threshold), + _mem_pool(std::move(mem_pool)), _driver_sched_entity(this), _scan_sched_entity(this), _connector_scan_sched_entity(this) {} @@ -154,6 +157,11 @@ WorkGroup::WorkGroup(const TWorkGroup& twg) if (twg.__isset.spill_mem_limit_threshold) { _spill_mem_limit_threshold = twg.spill_mem_limit_threshold; } + if (twg.__isset.mem_pool) { + _mem_pool = twg.mem_pool; + } else { + _mem_pool = DEFAULT_MEM_POOL; + } } TWorkGroup WorkGroup::to_thrift() const { @@ -182,13 +190,20 @@ TWorkGroup WorkGroup::to_thrift_verbose() const { return twg; } -void WorkGroup::init() { - _memory_limit_bytes = _memory_limit == ABSENT_MEMORY_LIMIT - ? GlobalEnv::GetInstance()->query_pool_mem_tracker()->limit() - : GlobalEnv::GetInstance()->query_pool_mem_tracker()->limit() * _memory_limit; +void WorkGroup::init(std::shared_ptr& parent_mem_tracker) { + if (parent_mem_tracker->type() == MemTrackerType::RESOURCE_GROUP_SHARED_MEMORY_POOL) { + _memory_limit_bytes = parent_mem_tracker->limit(); + _shared_mem_tracker = parent_mem_tracker; + } else { + _memory_limit_bytes = _memory_limit == ABSENT_MEMORY_LIMIT ? parent_mem_tracker->limit() + : parent_mem_tracker->limit() * _memory_limit; + } + _spill_mem_limit_bytes = _spill_mem_limit_threshold * _memory_limit_bytes; + + //todo (m.bogusz) MemTracker can only handle raw ptr parent so we need to add parent_mem_tracker as member to workgroup _mem_tracker = std::make_shared(MemTrackerType::RESOURCE_GROUP, _memory_limit_bytes, _name, - GlobalEnv::GetInstance()->query_pool_mem_tracker()); + parent_mem_tracker.get()); _mem_tracker->set_reserve_limit(_spill_mem_limit_bytes); _driver_sched_entity.set_queue(std::make_unique( @@ -275,7 +290,7 @@ void WorkGroup::copy_metrics(const WorkGroup& rhs) { // ------------------------------------------------------------------------------------ WorkGroupManager::WorkGroupManager(PipelineExecutorSetConfig executors_manager_conf) - : _executors_manager(this, std::move(executors_manager_conf)) {} + : _executors_manager(this, std::move(executors_manager_conf)), _shared_mem_tracker_manager{} {} WorkGroupManager::~WorkGroupManager() = default; @@ -549,7 +564,8 @@ void WorkGroupManager::create_workgroup_unlocked(const WorkGroupPtr& wg, UniqueL return; } - wg->init(); + auto parent_mem_tracker = _shared_mem_tracker_manager.get_parent_mem_tracker(wg); + wg->init(parent_mem_tracker); _workgroups[unique_id] = wg; _sum_cpu_weight += wg->cpu_weight(); @@ -690,7 +706,8 @@ std::shared_ptr DefaultWorkGroupInitialization::create_default_workgr const double memory_limit = 1.0; const double spill_mem_limit_threshold = 1.0; // not enable spill mem limit threshold return std::make_shared("default_wg", WorkGroup::DEFAULT_WG_ID, WorkGroup::DEFAULT_VERSION, cpu_limit, - memory_limit, 0, spill_mem_limit_threshold, WorkGroupType::WG_DEFAULT); + memory_limit, 0, spill_mem_limit_threshold, WorkGroupType::WG_DEFAULT, + WorkGroup::DEFAULT_MEM_POOL); } std::shared_ptr DefaultWorkGroupInitialization::create_default_mv_workgroup() { @@ -700,7 +717,6 @@ std::shared_ptr DefaultWorkGroupInitialization::create_default_mv_wor double mv_spill_mem_limit_threshold = config::default_mv_resource_group_spill_mem_limit_threshold; return std::make_shared("default_mv_wg", WorkGroup::DEFAULT_MV_WG_ID, WorkGroup::DEFAULT_MV_VERSION, mv_cpu_limit, mv_memory_limit, mv_concurrency_limit, - mv_spill_mem_limit_threshold, WorkGroupType::WG_MV); + mv_spill_mem_limit_threshold, WorkGroupType::WG_MV, WorkGroup::DEFAULT_MEM_POOL); } - } // namespace starrocks::workgroup diff --git a/be/src/exec/workgroup/work_group.h b/be/src/exec/workgroup/work_group.h index 06333c17c756b3..7c297ac01d7cbf 100644 --- a/be/src/exec/workgroup/work_group.h +++ b/be/src/exec/workgroup/work_group.h @@ -23,6 +23,7 @@ #include "exec/pipeline/pipeline_driver_queue.h" #include "exec/pipeline/query_context.h" #include "exec/workgroup/work_group_fwd.h" +#include "mem_tracker_manager.h" #include "pipeline_executor_set_manager.h" #include "runtime/mem_tracker.h" #include "storage/olap_define.h" @@ -114,11 +115,11 @@ using RunningQueryTokenPtr = std::unique_ptr; class WorkGroup : public std::enable_shared_from_this { public: WorkGroup(std::string name, int64_t id, int64_t version, size_t cpu_weight, double memory_limit, size_t concurrency, - double spill_mem_limit_threshold, WorkGroupType type); + double spill_mem_limit_threshold, WorkGroupType type, std::string mem_pool); explicit WorkGroup(const TWorkGroup& twg); ~WorkGroup() = default; - void init(); + void init(std::shared_ptr& parent_mem_tracker); TWorkGroup to_thrift() const; TWorkGroup to_thrift_verbose() const; @@ -128,6 +129,7 @@ class WorkGroup : public std::enable_shared_from_this { void copy_metrics(const WorkGroup& rhs); MemTracker* mem_tracker() { return _mem_tracker.get(); } + std::shared_ptr grab_mem_tracker() { return _mem_tracker; } const MemTracker* mem_tracker() const { return _mem_tracker.get(); } MemTracker* connector_scan_mem_tracker() { return _connector_scan_mem_tracker.get(); } @@ -137,7 +139,7 @@ class WorkGroup : public std::enable_shared_from_this { const std::string& name() const { return _name; } size_t cpu_weight() const { return _cpu_weight; } size_t exclusive_cpu_cores() const { return _exclusive_cpu_cores; } - size_t mem_limit() const { return _memory_limit; } + double mem_limit() const { return _memory_limit; } int64_t mem_limit_bytes() const { return _memory_limit_bytes; } int64_t mem_consumption_bytes() const { return _mem_tracker == nullptr ? 0L : _mem_tracker->consumption(); } @@ -198,7 +200,7 @@ class WorkGroup : public std::enable_shared_from_this { int64_t big_query_scan_rows_limit() const { return _big_query_scan_rows_limit; } void incr_cpu_runtime_ns(int64_t delta_ns) { _cpu_runtime_ns += delta_ns; } int64_t cpu_runtime_ns() const { return _cpu_runtime_ns; } - + std::string mem_pool() const { return _mem_pool; } void set_shared_executors(PipelineExecutorSet* executors) { _executors = executors; } void set_exclusive_executors(std::unique_ptr executors) { _exclusive_executors = std::move(executors); @@ -212,6 +214,7 @@ class WorkGroup : public std::enable_shared_from_this { static constexpr int64 DEFAULT_MV_WG_ID = 1; static constexpr int64 DEFAULT_VERSION = 0; static constexpr int64 DEFAULT_MV_VERSION = 1; + inline static std::string DEFAULT_MEM_POOL{"default_mem_pool"}; // Yield scan io task when maximum time in nano-seconds has spent in current execution round. static constexpr int64_t YIELD_MAX_TIME_SPENT = 100'000'000L; @@ -240,7 +243,9 @@ class WorkGroup : public std::enable_shared_from_this { double _spill_mem_limit_threshold = 1.0; int64_t _spill_mem_limit_bytes = -1; + std::string _mem_pool; std::shared_ptr _mem_tracker = nullptr; + std::shared_ptr _shared_mem_tracker = nullptr; std::shared_ptr _connector_scan_mem_tracker = nullptr; WorkGroupDriverSchedEntity _driver_sched_entity; @@ -330,7 +335,7 @@ class WorkGroupManager { std::list _workgroup_expired_versions; std::atomic _sum_cpu_weight = 0; - + MemTrackerManager _shared_mem_tracker_manager; std::once_flag init_metrics_once_flag; std::unordered_map _wg_metrics; }; diff --git a/be/src/runtime/exec_env.h b/be/src/runtime/exec_env.h index f735fb22a25330..42afa8cd95a96a 100644 --- a/be/src/runtime/exec_env.h +++ b/be/src/runtime/exec_env.h @@ -131,6 +131,7 @@ class GlobalEnv { MemTracker* process_mem_tracker() { return _process_mem_tracker.get(); } MemTracker* query_pool_mem_tracker() { return _query_pool_mem_tracker.get(); } + std::shared_ptr query_pool_mem_tracker_shared() { return _query_pool_mem_tracker; } MemTracker* connector_scan_pool_mem_tracker() { return _connector_scan_pool_mem_tracker.get(); } MemTracker* load_mem_tracker() { return _load_mem_tracker.get(); } MemTracker* metadata_mem_tracker() { return _metadata_mem_tracker.get(); } diff --git a/be/src/runtime/mem_tracker.cpp b/be/src/runtime/mem_tracker.cpp index eacf0a9324a420..ed54f909860e35 100644 --- a/be/src/runtime/mem_tracker.cpp +++ b/be/src/runtime/mem_tracker.cpp @@ -272,6 +272,9 @@ std::string MemTracker::err_msg(const std::string& msg, RuntimeState* state) con << "You can change the limit by modifying [mem_limit] of this group"; } break; + case MemTrackerType::RESOURCE_GROUP_SHARED_MEMORY_POOL: + str << "Mem usage has exceed the limit of resource group memory pool [" << label() << "]. "; + break; case MemTrackerType::RESOURCE_GROUP_BIG_QUERY: str << "Mem usage has exceed the big query limit of the resource group [" << label() << "]. " << "You can change the limit by modifying [big_query_mem_limit] of this group"; diff --git a/be/src/runtime/mem_tracker.h b/be/src/runtime/mem_tracker.h index 95335c5c0b21fb..11f0982fa24569 100644 --- a/be/src/runtime/mem_tracker.h +++ b/be/src/runtime/mem_tracker.h @@ -90,6 +90,7 @@ enum class MemTrackerType { SCHEMA_CHANGE_TASK, SCHEMA_CHANGE, RESOURCE_GROUP, + RESOURCE_GROUP_SHARED_MEMORY_POOL, RESOURCE_GROUP_BIG_QUERY, JEMALLOC, PASSTHROUGH, diff --git a/be/test/CMakeLists.txt b/be/test/CMakeLists.txt index 347317880ba6c2..0fd0f7645ede1f 100644 --- a/be/test/CMakeLists.txt +++ b/be/test/CMakeLists.txt @@ -59,6 +59,8 @@ set(EXEC_FILES ./exec/paimon/paimon_delete_file_builder_test.cpp ./exec/workgroup/scan_task_queue_test.cpp ./exec/workgroup/pipeline_executor_set_test.cpp + ./exec/workgroup/work_group_manager_test.cpp + ./exec/workgroup/mem_tracker_manager_test.cpp ./exec/pipeline/schedule/observer_test.cpp ./exec/pipeline/pipeline_control_flow_test.cpp ./exec/pipeline/pipeline_driver_queue_test.cpp diff --git a/be/test/exec/pipeline/mem_limited_chunk_queue_test.cpp b/be/test/exec/pipeline/mem_limited_chunk_queue_test.cpp index d558f966694ab5..7d633a958f4c34 100644 --- a/be/test/exec/pipeline/mem_limited_chunk_queue_test.cpp +++ b/be/test/exec/pipeline/mem_limited_chunk_queue_test.cpp @@ -39,11 +39,11 @@ class MemLimitedChunkQueueTest : public ::testing::Test { auto fs = FileSystem::Default(); ASSERT_OK(fs->create_dir_recursive(path)); LOG(INFO) << "path: " << path; - - dummy_wg = std::make_shared("default_wg", workgroup::WorkGroup::DEFAULT_WG_ID, - workgroup::WorkGroup::DEFAULT_VERSION, 4, 100.0, 0, 1.0, - workgroup::WorkGroupType::WG_DEFAULT); - dummy_wg->init(); + auto parent = GlobalEnv::GetInstance()->query_pool_mem_tracker_shared(); + dummy_wg = std::make_shared( + "default_wg", workgroup::WorkGroup::DEFAULT_WG_ID, workgroup::WorkGroup::DEFAULT_VERSION, 4, 100.0, 0, + 1.0, workgroup::WorkGroupType::WG_DEFAULT, workgroup::WorkGroup::DEFAULT_MEM_POOL); + dummy_wg->init(parent); dummy_wg->set_shared_executors(ExecEnv::GetInstance()->workgroup_manager()->shared_executors()); dummy_dir_mgr = std::make_unique(); diff --git a/be/test/exec/pipeline/pipeline_driver_queue_test.cpp b/be/test/exec/pipeline/pipeline_driver_queue_test.cpp index 142b5bea758bad..6cc007280008ed 100644 --- a/be/test/exec/pipeline/pipeline_driver_queue_test.cpp +++ b/be/test/exec/pipeline/pipeline_driver_queue_test.cpp @@ -174,13 +174,17 @@ class WorkGroupDriverQueueTest : public ::testing::Test { public: void SetUp() override { _wg1 = std::make_shared("wg100", 100, workgroup::WorkGroup::DEFAULT_VERSION, 1, 0.5, 10, - 1.0, workgroup::WorkGroupType::WG_NORMAL); + 1.0, workgroup::WorkGroupType::WG_NORMAL, + workgroup::WorkGroup::DEFAULT_MEM_POOL); _wg2 = std::make_shared("wg200", 200, workgroup::WorkGroup::DEFAULT_VERSION, 2, 0.5, 10, - 1.0, workgroup::WorkGroupType::WG_NORMAL); + 1.0, workgroup::WorkGroupType::WG_NORMAL, + workgroup::WorkGroup::DEFAULT_MEM_POOL); _wg3 = std::make_shared("wg300", 300, workgroup::WorkGroup::DEFAULT_VERSION, 1, 0.5, 10, - 1.0, workgroup::WorkGroupType::WG_NORMAL); + 1.0, workgroup::WorkGroupType::WG_NORMAL, + workgroup::WorkGroup::DEFAULT_MEM_POOL); _wg4 = std::make_shared("wg400", 400, workgroup::WorkGroup::DEFAULT_VERSION, 1, 0.5, 10, - 1.0, workgroup::WorkGroupType::WG_NORMAL); + 1.0, workgroup::WorkGroupType::WG_NORMAL, + workgroup::WorkGroup::DEFAULT_MEM_POOL); _wg1 = ExecEnv::GetInstance()->workgroup_manager()->add_workgroup(_wg1); _wg2 = ExecEnv::GetInstance()->workgroup_manager()->add_workgroup(_wg2); _wg3 = ExecEnv::GetInstance()->workgroup_manager()->add_workgroup(_wg3); diff --git a/be/test/exec/pipeline/query_context_manger_test.cpp b/be/test/exec/pipeline/query_context_manger_test.cpp index 37e7d34637a8fe..243ec5c2b4dab4 100644 --- a/be/test/exec/pipeline/query_context_manger_test.cpp +++ b/be/test/exec/pipeline/query_context_manger_test.cpp @@ -277,9 +277,9 @@ TEST(QueryContextManagerTest, testSetWorkgroup) { auto query_ctx_mgr = std::make_shared(6); ASSERT_TRUE(query_ctx_mgr->init().ok()); - workgroup::WorkGroupPtr wg = std::make_shared("wg1", 1, 1, 1, 1, 1 /* concurrency_limit */, - 1.0 /* spill_mem_limit_threshold */, - workgroup::WorkGroupType::WG_NORMAL); + workgroup::WorkGroupPtr wg = std::make_shared( + "wg1", 1, 1, 1, 1, 1 /* concurrency_limit */, 1.0 /* spill_mem_limit_threshold */, + workgroup::WorkGroupType::WG_NORMAL, workgroup::WorkGroup::DEFAULT_MEM_POOL); auto* query_ctx1 = gen_query_ctx(parent_mem_tracker.get(), query_ctx_mgr.get(), 0, 1, 3, 60, 300); auto* query_ctx_overloaded = gen_query_ctx(parent_mem_tracker.get(), query_ctx_mgr.get(), 1, 2, 3, 60, 300); diff --git a/be/test/exec/workgroup/mem_tracker_manager_test.cpp b/be/test/exec/workgroup/mem_tracker_manager_test.cpp new file mode 100644 index 00000000000000..8624e9980cbeaf --- /dev/null +++ b/be/test/exec/workgroup/mem_tracker_manager_test.cpp @@ -0,0 +1,53 @@ +// Copyright 2021-present StarRocks, Inc. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "exec/workgroup/mem_tracker_manager.h" + +#include + +#include "exec/workgroup/work_group.h" +#include "testutil/parallel_test.h" + +namespace starrocks::workgroup { +PARALLEL_TEST(MemTrackerMangerTest, test_mem_tracker_for_default_mem_pool) { + MemTrackerManager manager; + const auto work_group{std::make_shared("default_wg", 123, WorkGroup::DEFAULT_VERSION, 1, 0.5, 0, 0.9, + WorkGroupType::WG_DEFAULT, WorkGroup::DEFAULT_MEM_POOL)}; + + const auto tracker{manager.get_parent_mem_tracker(work_group)}; + ASSERT_EQ(tracker, GlobalEnv::GetInstance()->query_pool_mem_tracker_shared()); +} + +PARALLEL_TEST(MemTrackerMangerTest, test_mem_tracker_for_custom_mem_pool) { + MemTrackerManager manager; + const auto work_group1{std::make_shared("wg_1", 123, WorkGroup::DEFAULT_VERSION, 1, 0.5, 0, 0.9, + WorkGroupType::WG_DEFAULT, "test_pool")}; + const auto work_group2{std::make_shared("wg_2", 134, WorkGroup::DEFAULT_VERSION, 1, 0.5, 0, 0.9, + WorkGroupType::WG_DEFAULT, "test_pool")}; + const auto work_group3{std::make_shared("wg_2", 134, WorkGroup::DEFAULT_VERSION, 1, 0.5, 0, 0.9, + WorkGroupType::WG_DEFAULT, "other_pool")}; + + ASSERT_EQ(manager.get_parent_mem_tracker(work_group1), manager.get_parent_mem_tracker(work_group2)); + ASSERT_NE(manager.get_parent_mem_tracker(work_group1), manager.get_parent_mem_tracker(work_group3)); +} +PARALLEL_TEST(MemTrackerMangerTest, test_mem_tracker_for_custom_mem_pool_overwrite) { + MemTrackerManager manager; + const auto work_group1{std::make_shared("wg_1", 123, WorkGroup::DEFAULT_VERSION, 1, 0.5, 0, 0.9, + WorkGroupType::WG_DEFAULT, "test_pool")}; + const auto work_group2{std::make_shared("wg_2", 134, WorkGroup::DEFAULT_VERSION, 1, 0.7, 0, 0.9, + WorkGroupType::WG_DEFAULT, "test_pool")}; + + ASSERT_NE(manager.get_parent_mem_tracker(work_group1), manager.get_parent_mem_tracker(work_group2)); +} +} // namespace starrocks::workgroup diff --git a/be/test/exec/workgroup/work_group_manager_test.cpp b/be/test/exec/workgroup/work_group_manager_test.cpp new file mode 100644 index 00000000000000..9d0df8aae67216 --- /dev/null +++ b/be/test/exec/workgroup/work_group_manager_test.cpp @@ -0,0 +1,85 @@ +// Copyright 2021-present StarRocks, Inc. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include + +#include "exec/workgroup/work_group.h" +#include "runtime/mem_tracker.h" +#include "testutil/parallel_test.h" + +namespace starrocks::workgroup { +TWorkGroup create_twg(const int64_t id, const int64_t version, const std::string& name, const std::string& mem_pool, + const double mem_limit) { + TWorkGroup twg; + twg.__set_id(id); + twg.__set_version(version); + twg.__set_name(name); + twg.__set_mem_pool(mem_pool); + twg.__set_mem_limit(mem_limit); + return twg; +} + +PARALLEL_TEST(WorkGroupManagerTest, add_workgroups_different_mem_pools) { + PipelineExecutorSetConfig config{10, 1, 1, 1, CpuUtil::CpuIds{}, false, false, nullptr}; + auto _manager = std::make_unique(config); + + { + auto wg1 = std::make_shared(create_twg(102, 1, "wg", "test_pool", 0.5)); + auto wg2 = std::make_shared(create_twg(103, 1, "wg2", "test_pool1", 0.5)); + auto wg3 = std::make_shared(create_twg(104, 1, "wg3", WorkGroup::DEFAULT_MEM_POOL, 0.5)); + + _manager->add_workgroup(wg1); + _manager->add_workgroup(wg2); + _manager->add_workgroup(wg3); + + auto workgroups = _manager->list_workgroups(); + + ASSERT_EQ(3, workgroups.size()); + EXPECT_NE(wg2->mem_tracker()->parent(), wg1->mem_tracker()->parent()); + EXPECT_NE(wg2->mem_tracker()->parent(), wg3->mem_tracker()->parent()); + EXPECT_NE(wg1->mem_tracker()->parent(), wg3->mem_tracker()->parent()); + + EXPECT_EQ(wg1->mem_tracker()->parent()->type(), MemTrackerType::RESOURCE_GROUP_SHARED_MEMORY_POOL); + EXPECT_EQ(wg2->mem_tracker()->parent()->type(), MemTrackerType::RESOURCE_GROUP_SHARED_MEMORY_POOL); + EXPECT_EQ(wg3->mem_tracker()->parent()->type(), MemTrackerType::QUERY_POOL); + } + _manager->destroy(); +} + +PARALLEL_TEST(WorkGroupManagerTest, add_workgroups_same_mem_pools) { + PipelineExecutorSetConfig config{10, 1, 1, 1, CpuUtil::CpuIds{}, false, false, nullptr}; + auto _manager = std::make_unique(config); + + { + auto wg1 = std::make_shared(create_twg(105, 1, "wg5", "test_pool", 0.5)); + auto wg2 = std::make_shared(create_twg(106, 1, "wg6", "test_pool", 0.5)); + auto wg3 = std::make_shared(create_twg(107, 1, "wg7", WorkGroup::DEFAULT_MEM_POOL, 0.5)); + + _manager->add_workgroup(wg1); + _manager->add_workgroup(wg2); + _manager->add_workgroup(wg3); + + auto workgroups = _manager->list_workgroups(); + + ASSERT_EQ(3, workgroups.size()); + EXPECT_EQ(wg2->mem_tracker()->parent(), wg1->mem_tracker()->parent()); + EXPECT_EQ(wg2->mem_tracker()->parent()->type(), MemTrackerType::RESOURCE_GROUP_SHARED_MEMORY_POOL); + EXPECT_EQ(wg2->mem_limit_bytes(), wg2->mem_tracker()->parent()->limit()); + + EXPECT_NE(wg2->mem_tracker()->parent(), wg3->mem_tracker()->parent()); + EXPECT_EQ(wg3->mem_tracker()->parent()->type(), MemTrackerType::QUERY_POOL); + } + _manager->destroy(); +} +} // namespace starrocks::workgroup