Skip to content

Commit 5e533bd

Browse files
authored
feat: Improve llama.cpp argument handling and add device parsing tests (#6041)
* feat: Improve llama.cpp argument handling and add device parsing tests This commit refactors how arguments are passed to llama.cpp, specifically by only adding arguments when their values differ from their defaults. This reduces the verbosity of the command and prevents potential conflicts or errors when llama.cpp's default behavior aligns with the desired setting. Additionally, new tests have been added for parsing device output from llama.cpp, ensuring the accurate extraction of GPU information (ID, name, total memory, and free memory). This improves the robustness of device detection. The following changes were made: * **Remove redundant `--ctx-size` argument:** The `--ctx-size` argument is now only explicitly added if `cfg.ctx_size` is greater than 0. * **Conditional argument adding for default values:** * `--split-mode` is only added if `cfg.split_mode` is not empty and not 'layer'. * `--main-gpu` is only added if `cfg.main_gpu` is not undefined and not 0. * `--cache-type-k` is only added if `cfg.cache_type_k` is not 'f16'. * `--cache-type-v` is only added if `cfg.cache_type_v` is not 'f16' (when `flash_attn` is enabled) or not 'f32' (otherwise). This also corrects the `flash_attn` condition. * `--defrag-thold` is only added if `cfg.defrag_thold` is not 0.1. * `--rope-scaling` is only added if `cfg.rope_scaling` is not 'none'. * `--rope-scale` is only added if `cfg.rope_scale` is not 1. * `--rope-freq-base` is only added if `cfg.rope_freq_base` is not 0. * `--rope-freq-scale` is only added if `cfg.rope_freq_scale` is not 1. * **Add `parse_device_output` tests:** Comprehensive unit tests were added to `src-tauri/src/core/utils/extensions/inference_llamacpp_extension/server.rs` to validate the parsing of llama.cpp device output under various scenarios, including multiple devices, single devices, different backends (CUDA, Vulkan, SYCL), complex GPU names, and error conditions. * fixup cache_type_v comparison
1 parent a8613e5 commit 5e533bd

File tree

2 files changed

+210
-16
lines changed
  • extensions/llamacpp-extension/src
  • src-tauri/src/core/utils/extensions/inference_llamacpp_extension

2 files changed

+210
-16
lines changed

extensions/llamacpp-extension/src/index.ts

Lines changed: 18 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1247,11 +1247,6 @@ export default class llamacpp_extension extends AIEngine {
12471247
])
12481248
args.push('--mmproj', mmprojPath)
12491249
}
1250-
1251-
if (cfg.ctx_size !== undefined) {
1252-
args.push('-c', String(cfg.ctx_size))
1253-
}
1254-
12551250
// Add remaining options from the interface
12561251
if (cfg.chat_template) args.push('--chat-template', cfg.chat_template)
12571252
const gpu_layers =
@@ -1263,8 +1258,9 @@ export default class llamacpp_extension extends AIEngine {
12631258
if (cfg.batch_size > 0) args.push('--batch-size', String(cfg.batch_size))
12641259
if (cfg.ubatch_size > 0) args.push('--ubatch-size', String(cfg.ubatch_size))
12651260
if (cfg.device.length > 0) args.push('--device', cfg.device)
1266-
if (cfg.split_mode.length > 0) args.push('--split-mode', cfg.split_mode)
1267-
if (cfg.main_gpu !== undefined)
1261+
if (cfg.split_mode.length > 0 && cfg.split_mode != 'layer')
1262+
args.push('--split-mode', cfg.split_mode)
1263+
if (cfg.main_gpu !== undefined && cfg.main_gpu != 0)
12681264
args.push('--main-gpu', String(cfg.main_gpu))
12691265

12701266
// Boolean flags
@@ -1280,19 +1276,25 @@ export default class llamacpp_extension extends AIEngine {
12801276
} else {
12811277
if (cfg.ctx_size > 0) args.push('--ctx-size', String(cfg.ctx_size))
12821278
if (cfg.n_predict > 0) args.push('--n-predict', String(cfg.n_predict))
1283-
args.push('--cache-type-k', cfg.cache_type_k)
1279+
if (cfg.cache_type_k && cfg.cache_type_k != 'f16')
1280+
args.push('--cache-type-k', cfg.cache_type_k)
12841281
if (
1285-
(cfg.flash_attn && cfg.cache_type_v != 'f16') ||
1286-
cfg.cache_type_v != 'f32'
1282+
cfg.flash_attn &&
1283+
(cfg.cache_type_v != 'f16' && cfg.cache_type_v != 'f32')
12871284
) {
12881285
args.push('--cache-type-v', cfg.cache_type_v)
12891286
}
1290-
args.push('--defrag-thold', String(cfg.defrag_thold))
1291-
1292-
args.push('--rope-scaling', cfg.rope_scaling)
1293-
args.push('--rope-scale', String(cfg.rope_scale))
1294-
args.push('--rope-freq-base', String(cfg.rope_freq_base))
1295-
args.push('--rope-freq-scale', String(cfg.rope_freq_scale))
1287+
if (cfg.defrag_thold && cfg.defrag_thold != 0.1)
1288+
args.push('--defrag-thold', String(cfg.defrag_thold))
1289+
1290+
if (cfg.rope_scaling && cfg.rope_scaling != 'none')
1291+
args.push('--rope-scaling', cfg.rope_scaling)
1292+
if (cfg.rope_scale && cfg.rope_scale != 1)
1293+
args.push('--rope-scale', String(cfg.rope_scale))
1294+
if (cfg.rope_freq_base && cfg.rope_freq_base != 0)
1295+
args.push('--rope-freq-base', String(cfg.rope_freq_base))
1296+
if (cfg.rope_freq_scale && cfg.rope_freq_scale != 1)
1297+
args.push('--rope-freq-scale', String(cfg.rope_freq_scale))
12961298
}
12971299

12981300
logger.info('Calling Tauri command llama_load with args:', args)

src-tauri/src/core/utils/extensions/inference_llamacpp_extension/server.rs

Lines changed: 192 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -708,3 +708,195 @@ pub async fn is_process_running(pid: i32, state: State<'_, AppState>) -> Result<
708708
pub fn is_port_available(port: u16) -> bool {
709709
std::net::TcpListener::bind(("127.0.0.1", port)).is_ok()
710710
}
711+
712+
// tests
713+
//
714+
#[cfg(test)]
715+
mod tests {
716+
use super::*;
717+
718+
#[test]
719+
fn test_parse_multiple_devices() {
720+
let output = r#"ggml_vulkan: Found 2 Vulkan devices:
721+
ggml_vulkan: 0 = NVIDIA GeForce RTX 3090 (NVIDIA) | uma: 0 | fp16: 1 | bf16: 0 | warp size: 32 | shared memory: 49152 | int dot: 0 | matrix cores: KHR_coopmat
722+
ggml_vulkan: 1 = AMD Radeon Graphics (RADV GFX1151) (radv) | uma: 1 | fp16: 1 | bf16: 0 | warp size: 64 | shared memory: 65536 | int dot: 0 | matrix cores: KHR_coopmat
723+
Available devices:
724+
Vulkan0: NVIDIA GeForce RTX 3090 (24576 MiB, 24576 MiB free)
725+
Vulkan1: AMD Radeon Graphics (RADV GFX1151) (87722 MiB, 87722 MiB free)
726+
"#;
727+
728+
let devices = parse_device_output(output).unwrap();
729+
730+
assert_eq!(devices.len(), 2);
731+
732+
// Check first device
733+
assert_eq!(devices[0].id, "Vulkan0");
734+
assert_eq!(devices[0].name, "NVIDIA GeForce RTX 3090");
735+
assert_eq!(devices[0].mem, 24576);
736+
assert_eq!(devices[0].free, 24576);
737+
738+
// Check second device
739+
assert_eq!(devices[1].id, "Vulkan1");
740+
assert_eq!(devices[1].name, "AMD Radeon Graphics (RADV GFX1151)");
741+
assert_eq!(devices[1].mem, 87722);
742+
assert_eq!(devices[1].free, 87722);
743+
}
744+
745+
#[test]
746+
fn test_parse_single_device() {
747+
let output = r#"Available devices:
748+
CUDA0: NVIDIA GeForce RTX 4090 (24576 MiB, 24000 MiB free)"#;
749+
750+
let devices = parse_device_output(output).unwrap();
751+
752+
assert_eq!(devices.len(), 1);
753+
assert_eq!(devices[0].id, "CUDA0");
754+
assert_eq!(devices[0].name, "NVIDIA GeForce RTX 4090");
755+
assert_eq!(devices[0].mem, 24576);
756+
assert_eq!(devices[0].free, 24000);
757+
}
758+
759+
#[test]
760+
fn test_parse_with_extra_whitespace_and_empty_lines() {
761+
let output = r#"
762+
Available devices:
763+
764+
Vulkan0: NVIDIA GeForce RTX 3090 (24576 MiB, 24576 MiB free)
765+
766+
Vulkan1: AMD Radeon Graphics (RADV GFX1151) (87722 MiB, 87722 MiB free)
767+
768+
"#;
769+
770+
let devices = parse_device_output(output).unwrap();
771+
772+
assert_eq!(devices.len(), 2);
773+
assert_eq!(devices[0].id, "Vulkan0");
774+
assert_eq!(devices[1].id, "Vulkan1");
775+
}
776+
777+
#[test]
778+
fn test_parse_different_backends() {
779+
let output = r#"Available devices:
780+
CUDA0: NVIDIA GeForce RTX 4090 (24576 MiB, 24000 MiB free)
781+
Vulkan0: NVIDIA GeForce RTX 3090 (24576 MiB, 24576 MiB free)
782+
SYCL0: Intel(R) Arc(TM) A750 Graphics (8000 MiB, 7721 MiB free)"#;
783+
784+
let devices = parse_device_output(output).unwrap();
785+
786+
assert_eq!(devices.len(), 3);
787+
788+
assert_eq!(devices[0].id, "CUDA0");
789+
assert_eq!(devices[0].name, "NVIDIA GeForce RTX 4090");
790+
791+
assert_eq!(devices[1].id, "Vulkan0");
792+
assert_eq!(devices[1].name, "NVIDIA GeForce RTX 3090");
793+
794+
assert_eq!(devices[2].id, "SYCL0");
795+
assert_eq!(devices[2].name, "Intel(R) Arc(TM) A750 Graphics");
796+
assert_eq!(devices[2].mem, 8000);
797+
assert_eq!(devices[2].free, 7721);
798+
}
799+
800+
#[test]
801+
fn test_parse_complex_gpu_names() {
802+
let output = r#"Available devices:
803+
Vulkan0: Intel(R) Arc(tm) A750 Graphics (DG2) (8128 MiB, 8128 MiB free)
804+
Vulkan1: AMD Radeon RX 7900 XTX (Navi 31) [RDNA 3] (24576 MiB, 24000 MiB free)"#;
805+
806+
let devices = parse_device_output(output).unwrap();
807+
808+
assert_eq!(devices.len(), 2);
809+
810+
assert_eq!(devices[0].id, "Vulkan0");
811+
assert_eq!(devices[0].name, "Intel(R) Arc(tm) A750 Graphics (DG2)");
812+
assert_eq!(devices[0].mem, 8128);
813+
assert_eq!(devices[0].free, 8128);
814+
815+
assert_eq!(devices[1].id, "Vulkan1");
816+
assert_eq!(devices[1].name, "AMD Radeon RX 7900 XTX (Navi 31) [RDNA 3]");
817+
assert_eq!(devices[1].mem, 24576);
818+
assert_eq!(devices[1].free, 24000);
819+
}
820+
821+
#[test]
822+
fn test_parse_no_devices() {
823+
let output = r#"Available devices:"#;
824+
825+
let devices = parse_device_output(output).unwrap();
826+
assert_eq!(devices.len(), 0);
827+
}
828+
829+
#[test]
830+
fn test_parse_missing_header() {
831+
let output = r#"Vulkan0: NVIDIA GeForce RTX 3090 (24576 MiB, 24576 MiB free)"#;
832+
833+
let result = parse_device_output(output);
834+
assert!(result.is_err());
835+
assert!(result
836+
.unwrap_err()
837+
.to_string()
838+
.contains("Could not find 'Available devices:' section"));
839+
}
840+
841+
#[test]
842+
fn test_parse_malformed_device_line() {
843+
let output = r#"Available devices:
844+
Vulkan0: NVIDIA GeForce RTX 3090 (24576 MiB, 24576 MiB free)
845+
Invalid line without colon
846+
Vulkan1: AMD Radeon Graphics (RADV GFX1151) (87722 MiB, 87722 MiB free)"#;
847+
848+
let devices = parse_device_output(output).unwrap();
849+
850+
// Should skip the malformed line and parse the valid ones
851+
assert_eq!(devices.len(), 2);
852+
assert_eq!(devices[0].id, "Vulkan0");
853+
assert_eq!(devices[1].id, "Vulkan1");
854+
}
855+
856+
#[test]
857+
fn test_parse_device_line_individual() {
858+
// Test the individual line parser
859+
let line = "Vulkan0: NVIDIA GeForce RTX 3090 (24576 MiB, 24576 MiB free)";
860+
let device = parse_device_line(line).unwrap().unwrap();
861+
862+
assert_eq!(device.id, "Vulkan0");
863+
assert_eq!(device.name, "NVIDIA GeForce RTX 3090");
864+
assert_eq!(device.mem, 24576);
865+
assert_eq!(device.free, 24576);
866+
}
867+
868+
#[test]
869+
fn test_memory_pattern_detection() {
870+
assert!(is_memory_pattern("24576 MiB, 24576 MiB free"));
871+
assert!(is_memory_pattern("8000 MiB, 7721 MiB free"));
872+
assert!(!is_memory_pattern("just some text"));
873+
assert!(!is_memory_pattern("24576 MiB"));
874+
assert!(!is_memory_pattern("24576, 24576"));
875+
}
876+
877+
#[test]
878+
fn test_parse_memory_value() {
879+
assert_eq!(parse_memory_value("24576 MiB").unwrap(), 24576);
880+
assert_eq!(parse_memory_value("7721 MiB free").unwrap(), 7721);
881+
assert_eq!(parse_memory_value("8000").unwrap(), 8000);
882+
883+
assert!(parse_memory_value("").is_err());
884+
assert!(parse_memory_value("not_a_number MiB").is_err());
885+
}
886+
887+
#[test]
888+
fn test_find_memory_pattern() {
889+
let text = "NVIDIA GeForce RTX 3090 (24576 MiB, 24576 MiB free)";
890+
let result = find_memory_pattern(text);
891+
assert!(result.is_some());
892+
let (_start, content) = result.unwrap();
893+
assert_eq!(content, "24576 MiB, 24576 MiB free");
894+
895+
// Test with multiple parentheses
896+
let text = "Intel(R) Arc(tm) A750 Graphics (DG2) (8128 MiB, 8128 MiB free)";
897+
let result = find_memory_pattern(text);
898+
assert!(result.is_some());
899+
let (_start, content) = result.unwrap();
900+
assert_eq!(content, "8128 MiB, 8128 MiB free");
901+
}
902+
}

0 commit comments

Comments
 (0)