diff --git a/src/kernels/build.rs b/src/kernels/build.rs index e573e8d..c58caa6 100644 --- a/src/kernels/build.rs +++ b/src/kernels/build.rs @@ -86,10 +86,16 @@ fn main() -> Result<()> { println!("cargo:rerun-if-changed=src/flashinfer_adapter.cu"); // DO not change this, this featch custom flashinfer v0.6.2 headers // which is compatible with our code (added more gqa group_size) + let fi_repo = std::env::var("CARGO_FEATURE_FLASHINFER_REPO").unwrap_or( + "https://github.com/guoqingbao/flashinfer.git".to_string() + ); + let fi_commit = std::env::var("CARGO_FEATURE_FLASHINFER_COMMIT").unwrap_or( + "960cb902ce15ec085d42aa1bbe7026979c9a04dd".to_string() // v0.6.2 + ); builder = builder.arg("-DUSE_FLASHINFER").with_git_dependency( "flashinfer", - "https://github.com/guoqingbao/flashinfer.git", - "960cb902ce15ec085d42aa1bbe7026979c9a04dd", // v0.6.2 + fi_repo.as_str(), + fi_commit.as_str(), vec!["include"], false, );