diff --git a/crates/prek/src/languages/golang/gomod.rs b/crates/prek/src/languages/golang/gomod.rs index 36efd35a8..bd82fb8fc 100644 --- a/crates/prek/src/languages/golang/gomod.rs +++ b/crates/prek/src/languages/golang/gomod.rs @@ -1,7 +1,271 @@ +use std::io; +use std::path::Path; + use anyhow::Result; +use tracing::trace; +use crate::config::Language; use crate::hook::Hook; +use crate::languages::version::LanguageRequest; + +fn parse_go_mod_directives(contents: &str) -> (Option, Option) { + let mut go_version: Option = None; + let mut toolchain: Option = None; + + for line in contents.lines() { + let mut line = line.trim(); + if line.is_empty() { + continue; + } + + // Strip `//` comments. + if let Some((before, _)) = line.split_once("//") { + line = before.trim(); + if line.is_empty() { + continue; + } + } + + let mut tokens = line.split_whitespace(); + let Some(directive) = tokens.next() else { + continue; + }; + let value = tokens.next(); + + // `go 1.22.0` + if go_version.is_none() && directive == "go" { + if let Some(version) = value { + go_version = Some(version.to_string()); + } + continue; + } + + // `toolchain go1.22.1` + if toolchain.is_none() && directive == "toolchain" { + if let Some(version) = value { + // `toolchain` in go.mod does not accept `default`. + if version != "default" { + toolchain = Some(version.to_string()); + } + } + } + } + + (go_version, toolchain) +} + +fn normalize_go_semver_min(version: &str) -> String { + // `go.mod` commonly uses `1.23` (no patch). The semver range parser is happier when + // we provide a full `MAJOR.MINOR.PATCH` minimum. + let mut parts = version.split('.').collect::>(); + if parts.is_empty() { + return version.to_string(); + } + + // If any part isn't a pure integer (e.g., `1.23rc1`), keep it as-is. + // TODO: support pre-release versions properly. + if parts.iter().any(|p| p.parse::().is_err()) { + return version.to_string(); + } + + match parts.len() { + 1 => { + parts.push("0"); + parts.push("0"); + } + 2 => { + parts.push("0"); + } + _ => {} + } + + parts.join(".") +} + +fn choose_language_version_from_go_mod(contents: &str) -> Option { + let (go_version, toolchain) = parse_go_mod_directives(contents); + + // Prefer `go` to maximize cache reuse: it's typically stable across patch updates. + let go_version = go_version.or(toolchain)?; + let stripped = go_version.strip_prefix("go").unwrap_or(&go_version); + let normalized = normalize_go_semver_min(stripped); + Some(format!(">= {normalized}")) +} + +async fn extract_go_mod_language_request(repo_path: &Path) -> Result> { + let go_mod = repo_path.join("go.mod"); + let contents = match fs_err::tokio::read(&go_mod).await { + Ok(bytes) => bytes, + Err(err) if err.kind() == io::ErrorKind::NotFound => return Ok(None), + Err(err) => return Err(err.into()), + }; + let contents = str::from_utf8(&contents)?; + + Ok(choose_language_version_from_go_mod(contents)) +} + +pub(crate) async fn extract_go_mod_metadata(hook: &mut Hook) -> Result<()> { + // Respect an explicitly configured `language_version`. + if !hook.language_request.is_any() { + trace!(hook = %hook, "Skipping go.mod metadata extraction because language_version is already configured"); + return Ok(()); + } + + let Some(repo_path) = hook.repo_path() else { + return Ok(()); + }; + + let Some(req_str) = extract_go_mod_language_request(repo_path).await? else { + trace!(hook = %hook, "No go or toolchain directive found in go.mod"); + return Ok(()); + }; + + let req = match LanguageRequest::parse(Language::Golang, &req_str) { + Ok(req) => req, + Err(err) => { + trace!(%req_str, error = %err, "Ignoring invalid go.mod-derived language_version"); + return Ok(()); + } + }; + + trace!(hook = %hook, version = %req_str, "Using go.mod-derived language_version"); + hook.language_request = req; -pub(crate) async fn extract_go_mod_metadata(_hook: &mut Hook) -> Result<()> { Ok(()) } + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn go_line_is_used_when_only_go_present() { + let contents = r"module example.com/foo + +go 1.22.0 +"; + assert_eq!( + choose_language_version_from_go_mod(contents).as_deref(), + Some(">= 1.22.0") + ); + } + + #[test] + fn go_is_preferred_over_toolchain() { + let contents = r"module example.com/foo + +go 1.22.0 +toolchain go1.22.3 +"; + assert_eq!( + choose_language_version_from_go_mod(contents).as_deref(), + Some(">= 1.22.0") + ); + } + + #[test] + fn invalid_toolchain_value_is_ignored() { + let contents = r"module example.com/foo + +toolchain default +"; + assert_eq!( + choose_language_version_from_go_mod(contents).as_deref(), + None + ); + } + + #[test] + fn comments_and_whitespace_are_ignored() { + let contents = "// header + +// go 1.22 +go 1.20.4 // ignored +// trailing +"; + assert_eq!( + choose_language_version_from_go_mod(contents).as_deref(), + Some(">= 1.20.4") + ); + } + + #[test] + fn toolchain_is_used_when_no_go_present() { + let contents = r"module example.com/foo + +toolchain go1.23.10 +"; + assert_eq!( + choose_language_version_from_go_mod(contents).as_deref(), + Some(">= 1.23.10") + ); + } + + #[test] + fn go_minor_is_normalized_to_patch() { + let contents = r"module example.com/foo + +go 1.23 +"; + assert_eq!( + choose_language_version_from_go_mod(contents).as_deref(), + Some(">= 1.23.0") + ); + } + + #[tokio::test] + async fn extract_language_request_from_repo_go_line() -> anyhow::Result<()> { + let dir = tempfile::tempdir()?; + fs_err::tokio::write( + dir.path().join("go.mod"), + "module example.com/foo\n\ngo 1.22\n", + ) + .await?; + + let Some(req) = extract_go_mod_language_request(dir.path()).await? else { + anyhow::bail!("Expected a language request"); + }; + assert_eq!(req, ">= 1.22.0"); + + Ok(()) + } + + #[tokio::test] + async fn extract_language_request_from_repo_toolchain_when_no_go() -> anyhow::Result<()> { + let dir = tempfile::tempdir()?; + fs_err::tokio::write( + dir.path().join("go.mod"), + "module example.com/foo\n\ntoolchain go1.23.10\n", + ) + .await?; + + let Some(req) = extract_go_mod_language_request(dir.path()).await? else { + anyhow::bail!("Expected a language request"); + }; + + assert_eq!(req, ">= 1.23.10"); + Ok(()) + } + + #[tokio::test] + async fn extract_language_request_ignores_invalid_toolchain_value() -> anyhow::Result<()> { + let dir = tempfile::tempdir()?; + fs_err::tokio::write( + dir.path().join("go.mod"), + "module example.com/foo\n\ntoolchain default\n", + ) + .await?; + + let req = extract_go_mod_language_request(dir.path()).await?; + assert!(req.is_none()); + Ok(()) + } + + #[tokio::test] + async fn extract_language_request_missing_go_mod_is_none() -> anyhow::Result<()> { + let dir = tempfile::tempdir()?; + let req = extract_go_mod_language_request(dir.path()).await?; + assert!(req.is_none()); + Ok(()) + } +} diff --git a/crates/prek/src/languages/golang/installer.rs b/crates/prek/src/languages/golang/installer.rs index 483aecf12..56af356c1 100644 --- a/crates/prek/src/languages/golang/installer.rs +++ b/crates/prek/src/languages/golang/installer.rs @@ -133,7 +133,7 @@ impl GoInstaller { let resolved_version = self .resolve_version(request) .await - .context("Failed to resolve Go version")?; + .with_context(|| format!("Failed to resolve go version `{request}`"))?; trace!(version = %resolved_version, "Installing go"); self.download(store, &resolved_version).await @@ -195,7 +195,7 @@ impl GoInstaller { let version = versions .into_iter() .find(|version| req.matches(version, None)) - .context("Version not found on remote")?; + .with_context(|| format!("Version `{req}` not found on remote"))?; Ok(version) } diff --git a/crates/prek/src/languages/golang/version.rs b/crates/prek/src/languages/golang/version.rs index 1ebcfc65a..89b87c644 100644 --- a/crates/prek/src/languages/golang/version.rs +++ b/crates/prek/src/languages/golang/version.rs @@ -64,6 +64,21 @@ pub(crate) enum GoRequest { // MajorMinorPrerelease(u64, u64, String), } +impl Display for GoRequest { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + GoRequest::Any => write!(f, "any"), + GoRequest::Major(major) => write!(f, "go{major}"), + GoRequest::MajorMinor(major, minor) => write!(f, "go{major}.{minor}"), + GoRequest::MajorMinorPatch(major, minor, patch) => { + write!(f, "go{major}.{minor}.{patch}") + } + GoRequest::Path(path) => write!(f, "path: {}", path.display()), + GoRequest::Range(_, raw) => write!(f, "{raw}"), + } + } +} + impl FromStr for GoRequest { type Err = Error; @@ -143,3 +158,97 @@ impl GoRequest { } } } + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_go_request_from_str() { + let cases = vec![ + ("", GoRequest::Any), + ("go", GoRequest::Any), + ("go1", GoRequest::Major(1)), + ("1", GoRequest::Major(1)), + ("go1.20", GoRequest::MajorMinor(1, 20)), + ("1.20", GoRequest::MajorMinor(1, 20)), + ("go1.20.3", GoRequest::MajorMinorPatch(1, 20, 3)), + ("1.20.3", GoRequest::MajorMinorPatch(1, 20, 3)), + ( + ">= 1.20, < 1.22", + GoRequest::Range( + semver::VersionReq::parse(">= 1.20, < 1.22").unwrap(), + ">= 1.20, < 1.22".into(), + ), + ), + ]; + + for (input, expected) in cases { + let req = GoRequest::from_str(input).unwrap(); + assert_eq!(req, expected, "Input: {input}"); + } + } + + #[test] + fn test_go_request_invalid() { + let invalid_cases = vec!["go1.20.3.4", "go1.beta", "invalid_version"]; + for input in invalid_cases { + let req = GoRequest::from_str(input); + assert!(req.is_err(), "Input: {input}"); + } + } + + #[test] + fn test_go_request_matches() { + let version = GoVersion(semver::Version::new(1, 20, 3)); + let cases = vec![ + (GoRequest::Any, true), + (GoRequest::Major(1), true), + (GoRequest::Major(2), false), + (GoRequest::MajorMinor(1, 20), true), + (GoRequest::MajorMinor(1, 21), false), + (GoRequest::MajorMinorPatch(1, 20, 3), true), + (GoRequest::MajorMinorPatch(1, 20, 4), false), + ( + GoRequest::Range( + semver::VersionReq::parse(">= 1.19, < 1.21").unwrap(), + ">= 1.19, < 1.21".into(), + ), + true, + ), + ( + GoRequest::Range( + semver::VersionReq::parse(">= 1.21").unwrap(), + ">= 1.21".into(), + ), + false, + ), + ]; + + for (req, expected) in cases { + let result = req.matches(&version, None); + assert_eq!(result, expected, "Request: {req}"); + } + } + + #[test] + fn test_go_request_display() { + let cases = vec![ + (GoRequest::Any, "any"), + (GoRequest::Major(1), "go1"), + (GoRequest::MajorMinor(1, 20), "go1.20"), + (GoRequest::MajorMinorPatch(1, 20, 3), "go1.20.3"), + ( + GoRequest::Range( + semver::VersionReq::parse(">= 1.20, < 1.22").unwrap(), + ">= 1.20, < 1.22".into(), + ), + ">= 1.20, < 1.22", + ), + ]; + for (req, expected) in cases { + let req_str = req.to_string(); + assert_eq!(req_str, expected, "Request: {req:?}"); + } + } +} diff --git a/crates/prek/tests/languages/golang.rs b/crates/prek/tests/languages/golang.rs index 0a1e86448..355edc88d 100644 --- a/crates/prek/tests/languages/golang.rs +++ b/crates/prek/tests/languages/golang.rs @@ -324,3 +324,70 @@ fn local_additional_deps() -> anyhow::Result<()> { Ok(()) } + +/// Ensure `go.mod` metadata (go/toolchain directives) is used to constrain +/// the Go version for remote hooks. +#[test] +fn remote_go_mod_metadata_sets_language_version() -> anyhow::Result<()> { + // Create a remote repo containing a golang hook. + let go_hook = TestContext::new(); + go_hook.init_project(); + go_hook.configure_git_author(); + go_hook.disable_auto_crlf(); + + go_hook + .work_dir() + .child("go.mod") + .write_str(indoc::indoc! {r" + module example.com/go-hook + + go 2.100 // unrealistic version to ensure the downloading fails + "})?; + + go_hook + .work_dir() + .child(MANIFEST_FILE) + .write_str(indoc::indoc! {r" + - id: echo + name: echo + entry: echo + language: golang + verbose: true + "})?; + + go_hook.git_add("."); + go_hook.git_commit("Initial commit"); + Command::new("git") + .args(["tag", "v1.0", "-m", "v1.0"]) + .current_dir(go_hook.work_dir()) + .output()?; + + // Use it as a remote repo in a separate project. + let context = TestContext::new(); + context.init_project(); + + let hook_url = go_hook.work_dir().to_str().unwrap(); + context.write_pre_commit_config(&indoc::formatdoc! {r" + repos: + - repo: {hook_url} + rev: v1.0 + hooks: + - id: echo + verbose: true + ", hook_url = hook_url}); + context.git_add("."); + + cmd_snapshot!(context.filters(), context.run(), @" + success: false + exit_code: 2 + ----- stdout ----- + + ----- stderr ----- + error: Failed to install hook `echo` + caused by: Failed to install go + caused by: Failed to resolve go version `>= 2.100.0` + caused by: Version `>= 2.100.0` not found on remote + "); + + Ok(()) +}