|
31 | 31 | #include <mutex> |
32 | 32 | #include <queue> |
33 | 33 | #include <chrono> |
| 34 | +#include <unordered_set> |
| 35 | +#include <optional> |
34 | 36 |
|
35 | 37 | #include "ggml-impl.h" |
36 | 38 | #include "ggml-backend-impl.h" |
@@ -93,6 +95,26 @@ int32_t ggml_cann_get_device() { |
93 | 95 | return id; |
94 | 96 | } |
95 | 97 |
|
| 98 | +/** |
| 99 | + * @brief Get the value of the specified environment variable (name). |
| 100 | + * if not empty, return a std::string object |
| 101 | + */ |
| 102 | +std::optional<std::string> get_env(const std::string& name) { |
| 103 | + const char* val = std::getenv(name.c_str()); |
| 104 | + if (!val) return std::nullopt; |
| 105 | + std::string res = std::string(val); |
| 106 | + std::transform(res.begin(), res.end(), res.begin(), ::tolower); |
| 107 | + return res; |
| 108 | +} |
| 109 | + |
| 110 | +/** |
| 111 | + * @brief Verify whether the environment variable is a valid value. |
| 112 | + */ |
| 113 | +bool parse_bool(const std::string& value) { |
| 114 | + std::unordered_set<std::string> valid_values = {"on", "1", "yes", "y", "enable", "true"}; |
| 115 | + return valid_values.find(value) != valid_values.end(); |
| 116 | +} |
| 117 | + |
96 | 118 | /** |
97 | 119 | * @brief Initialize the CANN device information. |
98 | 120 | * |
@@ -214,7 +236,7 @@ struct ggml_cann_pool_buf_prio : public ggml_cann_pool { |
214 | 236 | * @param device The device ID to associate with this buffer pool. |
215 | 237 | */ |
216 | 238 | explicit ggml_cann_pool_buf_prio(int device) : device(device) { |
217 | | - disable_clean = getenv("GGML_CANN_DISABLE_BUF_POOL_CLEAN") != nullptr; |
| 239 | + disable_clean = parse_bool(get_env("GGML_CANN_DISABLE_BUF_POOL_CLEAN").value_or("")); |
218 | 240 | } |
219 | 241 |
|
220 | 242 | /** |
@@ -410,7 +432,7 @@ struct ggml_cann_pool_buf : public ggml_cann_pool { |
410 | 432 | * @param device The device ID to associate with this buffer pool. |
411 | 433 | */ |
412 | 434 | explicit ggml_cann_pool_buf(int device) : device(device) { |
413 | | - disable_clean = getenv("GGML_CANN_DISABLE_BUF_POOL_CLEAN") != nullptr; |
| 435 | + disable_clean = parse_bool(get_env("GGML_CANN_DISABLE_BUF_POOL_CLEAN").value_or("")); |
414 | 436 | } |
415 | 437 |
|
416 | 438 | /** |
@@ -731,16 +753,18 @@ struct ggml_cann_pool_vmm : public ggml_cann_pool { |
731 | 753 | */ |
732 | 754 | std::unique_ptr<ggml_cann_pool> ggml_backend_cann_context::new_pool_for_device( |
733 | 755 | int device) { |
734 | | - bool disable_vmm = (getenv("GGML_CANN_DISABLE_VMM_POOL") != nullptr); |
735 | | - if (!disable_vmm && ggml_cann_info().devices[device].vmm) { |
736 | | - GGML_LOG_INFO("%s: device %d use vmm pool\n", __func__, device); |
737 | | - return std::unique_ptr<ggml_cann_pool>(new ggml_cann_pool_vmm(device)); |
738 | | - } |
739 | | - bool enable_buf_prio = (getenv("GGML_CANN_ENABLE_BUF_PRIO_POOL") != nullptr); |
740 | | - if (enable_buf_prio) { |
| 756 | + std::string mem_pool_type = get_env("GGML_CANN_MEM_POOL").value_or(""); |
| 757 | + |
| 758 | + if (mem_pool_type == "prio") { |
741 | 759 | GGML_LOG_INFO("%s: device %d use buffer pool with priority queue\n", __func__, device); |
742 | 760 | return std::unique_ptr<ggml_cann_pool>(new ggml_cann_pool_buf_prio(device)); |
743 | 761 | } |
| 762 | + |
| 763 | + if (ggml_cann_info().devices[device].vmm && mem_pool_type != "leg") { |
| 764 | + GGML_LOG_INFO("%s: device %d use vmm pool\n", __func__, device); |
| 765 | + return std::unique_ptr<ggml_cann_pool>(new ggml_cann_pool_vmm(device)); |
| 766 | + } |
| 767 | + |
744 | 768 | GGML_LOG_INFO("%s: device %d use buffer pool\n", __func__, device); |
745 | 769 | return std::unique_ptr<ggml_cann_pool>(new ggml_cann_pool_buf(device)); |
746 | 770 | } |
|
0 commit comments