Skip to content

Commit

Permalink
fix parseDisjointPoolConfig and add tests
Browse files Browse the repository at this point in the history
  • Loading branch information
bratpiorka committed Jan 17, 2025
1 parent ed09541 commit 8ee9060
Show file tree
Hide file tree
Showing 2 changed files with 67 additions and 36 deletions.
57 changes: 21 additions & 36 deletions source/common/umf_pools/disjoint_pool_config_parser.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -175,46 +175,31 @@ DisjointPoolAllConfigs parseDisjointPoolConfig(const std::string &config,
};

size_t MaxSize = (std::numeric_limits<size_t>::max)();
size_t EnableBuffers = 1;

// Update pool settings if specified in environment.
size_t EnableBuffers = 1;
if (config != "") {
std::string Params = config;
size_t Pos = Params.find(';');
if (Pos != std::string::npos) {
if (Pos > 0) {
GetValue(Params, Pos, EnableBuffers);
}
Params.erase(0, Pos + 1);
size_t Pos = Params.find(';');
if (Pos != std::string::npos) {
if (Pos > 0) {
GetValue(Params, Pos, MaxSize);
}
Params.erase(0, Pos + 1);
do {
size_t Pos = Params.find(';');
if (Pos != std::string::npos) {
if (Pos > 0) {
std::string MemParams = Params.substr(0, Pos);
MemTypeParser(MemParams);
}
Params.erase(0, Pos + 1);
if (Params.size() == 0) {
break;
}
} else {
MemTypeParser(Params);
break;
}
} while (true);
} else {
// set MaxPoolSize for all configs
GetValue(Params, Params.size(), MaxSize);
}
bool EnableBuffersSet = false;
bool MaxSizeSet = false;
size_t Start = 0;
size_t End = config.find(';');
while (true) {
std::string Param = config.substr(Start, End - Start);
if (!EnableBuffersSet && isdigit(Param[0])) {
GetValue(Param, Param.size(), EnableBuffers);
EnableBuffersSet = true;
} else if (!MaxSizeSet && isdigit(Param[0])) {
GetValue(Param, Param.size(), MaxSize);
MaxSizeSet = true;
} else {
GetValue(Params, Params.size(), EnableBuffers);
MemTypeParser(Param);
}

if (End == std::string::npos) {
break;
}

Start = End + 1;
End = config.find(';', Start);
}

AllConfigs.EnableBuffers = EnableBuffers;
Expand Down
46 changes: 46 additions & 0 deletions test/usm/usmPoolManager.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
//
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

#include "umf_pools/disjoint_pool_config_parser.hpp"
#include "ur_pool_manager.hpp"

#include <uur/fixtures.h>
Expand All @@ -18,6 +19,26 @@ auto createMockPoolHandle() {
[](umf_memory_pool_t *) {});
}

bool compareConfig(const usm::umf_disjoint_pool_config_t &left,
usm::umf_disjoint_pool_config_t &right) {
return left.MaxPoolableSize == right.MaxPoolableSize &&
left.Capacity == right.Capacity &&
left.SlabMinSize == right.SlabMinSize;
}

bool compareConfigs(const usm::DisjointPoolAllConfigs &left,
usm::DisjointPoolAllConfigs &right) {
return left.EnableBuffers == right.EnableBuffers &&
compareConfig(left.Configs[usm::DisjointPoolMemType::Host],
right.Configs[usm::DisjointPoolMemType::Host]) &&
compareConfig(left.Configs[usm::DisjointPoolMemType::Device],
right.Configs[usm::DisjointPoolMemType::Device]) &&
compareConfig(left.Configs[usm::DisjointPoolMemType::Shared],
right.Configs[usm::DisjointPoolMemType::Shared]) &&
compareConfig(left.Configs[usm::DisjointPoolMemType::SharedReadOnly],
right.Configs[usm::DisjointPoolMemType::SharedReadOnly]);
}

TEST_P(urUsmPoolDescriptorTest, poolIsPerContextTypeAndDevice) {
auto &devices = uur::DevicesEnvironment::instance->devices;

Expand Down Expand Up @@ -111,4 +132,29 @@ TEST_P(urUsmPoolManagerTest, poolManagerGetNonexistant) {
}
}

TEST_P(urUsmPoolManagerTest, config) {
// Check default config
usm::DisjointPoolAllConfigs def;
usm::DisjointPoolAllConfigs parsed1 =
usm::parseDisjointPoolConfig("1;host:2M,4,64K;device:4M,4,64K;"
"shared:0,0,2M;read_only_shared:4M,4,2M",
0);
ASSERT_EQ(compareConfigs(def, parsed1), true);

// Check partially set config
usm::DisjointPoolAllConfigs parsed2 =
usm::parseDisjointPoolConfig("1;device:4M;shared:0,0,2M", 0);
ASSERT_EQ(compareConfigs(def, parsed2), true);

// Check non-default config
usm::DisjointPoolAllConfigs test(def);
test.Configs[usm::DisjointPoolMemType::Shared].MaxPoolableSize = 128 * 1024;
test.Configs[usm::DisjointPoolMemType::Shared].Capacity = 4;
test.Configs[usm::DisjointPoolMemType::Shared].SlabMinSize = 64 * 1024;

usm::DisjointPoolAllConfigs parsed3 =
usm::parseDisjointPoolConfig("1;shared:128K,4,64K", 0);
ASSERT_EQ(compareConfigs(test, parsed3), true);
}

UUR_INSTANTIATE_DEVICE_TEST_SUITE_P(urUsmPoolManagerTest);

0 comments on commit 8ee9060

Please sign in to comment.