diff --git a/Cargo.lock b/Cargo.lock index bb8ef2fba3b17..dcf953a4f0724 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -3259,6 +3259,7 @@ version = "0.0.0" dependencies = [ "insta", "regex", + "ruff_linter", "ruff_python_ast", "ruff_python_formatter", "ruff_python_trivia", diff --git a/crates/ruff_linter/src/settings/types.rs b/crates/ruff_linter/src/settings/types.rs index c81e21d64c53a..ce22da37b1319 100644 --- a/crates/ruff_linter/src/settings/types.rs +++ b/crates/ruff_linter/src/settings/types.rs @@ -500,6 +500,11 @@ impl ExtensionMapping { let ext = path.extension()?.to_str()?; self.0.get(ext).copied() } + + /// Return the [`Language`] for a given file extension. + pub fn get_extension(&self, ext: &str) -> Option { + self.0.get(ext).copied() + } } impl From> for ExtensionMapping { diff --git a/crates/ruff_markdown/Cargo.toml b/crates/ruff_markdown/Cargo.toml index c611deeecd0f4..f40d1e3b8cacb 100644 --- a/crates/ruff_markdown/Cargo.toml +++ b/crates/ruff_markdown/Cargo.toml @@ -17,8 +17,12 @@ ruff_source_file = { workspace = true } ruff_text_size = { workspace = true } ruff_workspace = { workspace = true } -insta = { workspace = true } regex = { workspace = true } +[dev-dependencies] +ruff_linter = { workspace = true } + +insta = { workspace = true } + [lints] workspace = true diff --git a/crates/ruff_markdown/src/lib.rs b/crates/ruff_markdown/src/lib.rs index 7277f0c7bf63e..a0cc4daf9a209 100644 --- a/crates/ruff_markdown/src/lib.rs +++ b/crates/ruff_markdown/src/lib.rs @@ -93,7 +93,10 @@ pub fn format_code_blocks( let end = code_line.start(); let unformatted_code = dedent(&source[TextRange::new(start, end)]); - let py_source_type = PySourceType::from_extension(&language); + let py_source_type = match settings.extension.get_extension(&language) { + None => PySourceType::from_extension(&language), + Some(language) => PySourceType::from(language), + }; let options = settings.to_format_options(py_source_type, &unformatted_code, path); @@ -132,6 +135,7 @@ pub fn format_code_blocks( #[cfg(test)] mod tests { use insta::assert_snapshot; + use ruff_linter::settings::types::{ExtensionMapping, ExtensionPair, Language}; use ruff_workspace::FormatterSettings; use crate::{MarkdownResult, format_code_blocks}; @@ -384,4 +388,27 @@ print( 'hello' ) &FormatterSettings::default() ), @"Unchanged"); } + + #[test] + fn format_code_blocks_extension_mapping() { + // format "py" mapped as "pyi" instead + let code = r#" +```py +def foo(): ... +def bar(): ... +``` + "#; + let mapping = ExtensionMapping::from_iter([ExtensionPair { + extension: "py".to_string(), + language: Language::Pyi, + }]); + assert_snapshot!(format_code_blocks( + code, + None, + &FormatterSettings { + extension: mapping, + ..Default::default() + } + ), @"Unchanged"); + } }