Skip to content

Commit

Permalink
red-knot(Salsa): Types without refinements (#11899)
Browse files Browse the repository at this point in the history
  • Loading branch information
MichaReiser authored Jun 20, 2024
1 parent a26bd01 commit 22733cb
Show file tree
Hide file tree
Showing 13 changed files with 2,168 additions and 146 deletions.
1 change: 1 addition & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

3 changes: 2 additions & 1 deletion crates/ruff_python_semantic/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ ruff_text_size = { workspace = true }

bitflags = { workspace = true }
is-macro = { workspace = true }
indexmap = { workspace = true, optional = true }
salsa = { workspace = true, optional = true }
smallvec = { workspace = true, optional = true }
smol_str = { workspace = true }
Expand All @@ -36,4 +37,4 @@ tempfile = { workspace = true }
workspace = true

[features]
red_knot = ["dep:salsa", "dep:tracing", "dep:hashbrown", "dep:smallvec"]
red_knot = ["dep:salsa", "dep:tracing", "dep:hashbrown", "dep:smallvec", "dep:indexmap"]
125 changes: 119 additions & 6 deletions crates/ruff_python_semantic/src/db.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,30 +6,41 @@ use crate::module::resolver::{
file_to_module, internal::ModuleNameIngredient, internal::ModuleResolverSearchPaths,
resolve_module_query,
};

use crate::red_knot::semantic_index::symbol::ScopeId;
use crate::red_knot::semantic_index::{scopes_map, semantic_index, symbol_table};
use crate::red_knot::semantic_index::symbol::{
public_symbols_map, scopes_map, PublicSymbolId, ScopeId,
};
use crate::red_knot::semantic_index::{root_scope, semantic_index, symbol_table};
use crate::red_knot::types::{infer_types, public_symbol_ty};

#[salsa::jar(db=Db)]
pub struct Jar(
ModuleNameIngredient,
ModuleResolverSearchPaths,
ScopeId,
PublicSymbolId,
symbol_table,
resolve_module_query,
file_to_module,
scopes_map,
root_scope,
semantic_index,
infer_types,
public_symbol_ty,
public_symbols_map,
);

/// Database giving access to semantic information about a Python program.
pub trait Db: SourceDb + DbWithJar<Jar> + Upcast<dyn SourceDb> {}

#[cfg(test)]
pub(crate) mod tests {
use std::fmt::Formatter;
use std::marker::PhantomData;
use std::sync::Arc;

use salsa::DebugWithDb;
use salsa::ingredient::Ingredient;
use salsa::storage::HasIngredientsFor;
use salsa::{AsId, DebugWithDb};

use ruff_db::file_system::{FileSystem, MemoryFileSystem, OsFileSystem};
use ruff_db::vfs::Vfs;
Expand Down Expand Up @@ -86,7 +97,7 @@ pub(crate) mod tests {
///
/// ## Panics
/// If there are any pending salsa snapshots.
pub(crate) fn take_sale_events(&mut self) -> Vec<salsa::Event> {
pub(crate) fn take_salsa_events(&mut self) -> Vec<salsa::Event> {
let inner = Arc::get_mut(&mut self.events).expect("no pending salsa snapshots");

let events = inner.get_mut().unwrap();
Expand All @@ -98,7 +109,7 @@ pub(crate) mod tests {
/// ## Panics
/// If there are any pending salsa snapshots.
pub(crate) fn clear_salsa_events(&mut self) {
self.take_sale_events();
self.take_salsa_events();
}
}

Expand Down Expand Up @@ -150,4 +161,106 @@ pub(crate) mod tests {
#[allow(unused)]
Os(OsFileSystem),
}

pub(crate) fn assert_will_run_function_query<C, Db, Jar>(
db: &Db,
to_function: impl FnOnce(&C) -> &salsa::function::FunctionIngredient<C>,
key: C::Key,
events: &[salsa::Event],
) where
C: salsa::function::Configuration<Jar = Jar>
+ salsa::storage::IngredientsFor<Jar = Jar, Ingredients = C>,
Jar: HasIngredientsFor<C>,
Db: salsa::DbWithJar<Jar>,
C::Key: AsId,
{
will_run_function_query(db, to_function, key, events, true);
}

pub(crate) fn assert_will_not_run_function_query<C, Db, Jar>(
db: &Db,
to_function: impl FnOnce(&C) -> &salsa::function::FunctionIngredient<C>,
key: C::Key,
events: &[salsa::Event],
) where
C: salsa::function::Configuration<Jar = Jar>
+ salsa::storage::IngredientsFor<Jar = Jar, Ingredients = C>,
Jar: HasIngredientsFor<C>,
Db: salsa::DbWithJar<Jar>,
C::Key: AsId,
{
will_run_function_query(db, to_function, key, events, false);
}

fn will_run_function_query<C, Db, Jar>(
db: &Db,
to_function: impl FnOnce(&C) -> &salsa::function::FunctionIngredient<C>,
key: C::Key,
events: &[salsa::Event],
should_run: bool,
) where
C: salsa::function::Configuration<Jar = Jar>
+ salsa::storage::IngredientsFor<Jar = Jar, Ingredients = C>,
Jar: HasIngredientsFor<C>,
Db: salsa::DbWithJar<Jar>,
C::Key: AsId,
{
let (jar, _) =
<_ as salsa::storage::HasJar<<C as salsa::storage::IngredientsFor>::Jar>>::jar(db);
let ingredient = jar.ingredient();

let function_ingredient = to_function(ingredient);

let ingredient_index =
<salsa::function::FunctionIngredient<C> as Ingredient<Db>>::ingredient_index(
function_ingredient,
);

let did_run = events.iter().any(|event| {
if let salsa::EventKind::WillExecute { database_key } = event.kind {
database_key.ingredient_index() == ingredient_index
&& database_key.key_index() == key.as_id()
} else {
false
}
});

if should_run && !did_run {
panic!(
"Expected query {:?} to run but it didn't",
DebugIdx {
db: PhantomData::<Db>,
value_id: key.as_id(),
ingredient: function_ingredient,
}
);
} else if !should_run && did_run {
panic!(
"Expected query {:?} not to run but it did",
DebugIdx {
db: PhantomData::<Db>,
value_id: key.as_id(),
ingredient: function_ingredient,
}
);
}
}

struct DebugIdx<'a, I, Db>
where
I: Ingredient<Db>,
{
value_id: salsa::Id,
ingredient: &'a I,
db: PhantomData<Db>,
}

impl<'a, I, Db> std::fmt::Debug for DebugIdx<'a, I, Db>
where
I: Ingredient<Db>,
{
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
self.ingredient.fmt_index(Some(self.value_id), f)
}
}
}
2 changes: 1 addition & 1 deletion crates/ruff_python_semantic/src/module/resolver.rs
Original file line number Diff line number Diff line change
Expand Up @@ -886,7 +886,7 @@ mod tests {
let foo_module2 = resolve_module(&db, foo_module_name);

assert!(!db
.take_sale_events()
.take_salsa_events()
.iter()
.any(|event| { matches!(event.kind, salsa::EventKind::WillExecute { .. }) }));

Expand Down
5 changes: 5 additions & 0 deletions crates/ruff_python_semantic/src/red_knot/mod.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,8 @@
use rustc_hash::FxHasher;
use std::hash::BuildHasherDefault;

pub mod ast_node_ref;
mod node_key;
pub mod semantic_index;
pub mod types;
pub(crate) type FxIndexSet<V> = indexmap::set::IndexSet<V, BuildHasherDefault<FxHasher>>;
69 changes: 41 additions & 28 deletions crates/ruff_python_semantic/src/red_knot/semantic_index.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,10 @@ use ruff_index::{IndexSlice, IndexVec};
use ruff_python_ast as ast;

use crate::red_knot::node_key::NodeKey;
use crate::red_knot::semantic_index::ast_ids::AstIds;
use crate::red_knot::semantic_index::ast_ids::{AstId, AstIds, ScopeClassId, ScopeFunctionId};
use crate::red_knot::semantic_index::builder::SemanticIndexBuilder;
use crate::red_knot::semantic_index::symbol::{
FileScopeId, PublicSymbolId, Scope, ScopeId, ScopeSymbolId, ScopesMap, SymbolTable,
FileScopeId, PublicSymbolId, Scope, ScopeId, ScopeKind, ScopedSymbolId, SymbolTable,
};
use crate::Db;

Expand All @@ -21,7 +21,7 @@ mod builder;
pub mod definition;
pub mod symbol;

type SymbolMap = hashbrown::HashMap<ScopeSymbolId, (), ()>;
type SymbolMap = hashbrown::HashMap<ScopedSymbolId, (), ()>;

/// Returns the semantic index for `file`.
///
Expand All @@ -42,33 +42,22 @@ pub(crate) fn semantic_index(db: &dyn Db, file: VfsFile) -> SemanticIndex {
pub(crate) fn symbol_table(db: &dyn Db, scope: ScopeId) -> Arc<SymbolTable> {
let index = semantic_index(db, scope.file(db));

index.symbol_table(scope.scope_id(db))
}

/// Returns a mapping from file specific [`FileScopeId`] to a program-wide unique [`ScopeId`].
#[salsa::tracked(return_ref)]
pub(crate) fn scopes_map(db: &dyn Db, file: VfsFile) -> ScopesMap {
let index = semantic_index(db, file);

let scopes: IndexVec<_, _> = index
.scopes
.indices()
.map(|id| ScopeId::new(db, file, id))
.collect();

ScopesMap::new(scopes)
index.symbol_table(scope.file_scope_id(db))
}

/// Returns the root scope of `file`.
pub fn root_scope(db: &dyn Db, file: VfsFile) -> ScopeId {
#[salsa::tracked]
pub(crate) fn root_scope(db: &dyn Db, file: VfsFile) -> ScopeId {
FileScopeId::root().to_scope_id(db, file)
}

/// Returns the symbol with the given name in `file`'s public scope or `None` if
/// no symbol with the given name exists.
pub fn global_symbol(db: &dyn Db, file: VfsFile, name: &str) -> Option<PublicSymbolId> {
pub fn public_symbol(db: &dyn Db, file: VfsFile, name: &str) -> Option<PublicSymbolId> {
let root_scope = root_scope(db, file);
root_scope.symbol(db, name)
let symbol_table = symbol_table(db, root_scope);
let local = symbol_table.symbol_id_by_name(name)?;
Some(local.to_public_symbol(db, file))
}

/// The symbol tables for an entire file.
Expand All @@ -90,14 +79,17 @@ pub struct SemanticIndex {
/// Note: We should not depend on this map when analysing other files or
/// changing a file invalidates all dependents.
ast_ids: IndexVec<FileScopeId, AstIds>,

/// Map from scope to the node that introduces the scope.
scope_nodes: IndexVec<FileScopeId, NodeWithScopeId>,
}

impl SemanticIndex {
/// Returns the symbol table for a specific scope.
///
/// Use the Salsa cached [`symbol_table`] query if you only need the
/// symbol table for a single scope.
fn symbol_table(&self, scope_id: FileScopeId) -> Arc<SymbolTable> {
pub(super) fn symbol_table(&self, scope_id: FileScopeId) -> Arc<SymbolTable> {
self.symbol_tables[scope_id].clone()
}

Expand Down Expand Up @@ -152,6 +144,10 @@ impl SemanticIndex {
pub(crate) fn ancestor_scopes(&self, scope: FileScopeId) -> AncestorsIter {
AncestorsIter::new(self, scope)
}

pub(crate) fn scope_node(&self, scope_id: FileScopeId) -> NodeWithScopeId {
self.scope_nodes[scope_id]
}
}

/// ID that uniquely identifies an expression inside a [`Scope`].
Expand Down Expand Up @@ -246,6 +242,28 @@ impl<'a> Iterator for ChildrenIter<'a> {
}
}

#[derive(Copy, Clone, Debug, Eq, PartialEq)]
pub(crate) enum NodeWithScopeId {
Module,
Class(AstId<ScopeClassId>),
ClassTypeParams(AstId<ScopeClassId>),
Function(AstId<ScopeFunctionId>),
FunctionTypeParams(AstId<ScopeFunctionId>),
}

impl NodeWithScopeId {
fn scope_kind(self) -> ScopeKind {
match self {
NodeWithScopeId::Module => ScopeKind::Module,
NodeWithScopeId::Class(_) => ScopeKind::Class,
NodeWithScopeId::Function(_) => ScopeKind::Function,
NodeWithScopeId::ClassTypeParams(_) | NodeWithScopeId::FunctionTypeParams(_) => {
ScopeKind::Annotation
}
}
}
}

impl FusedIterator for ChildrenIter<'_> {}

#[cfg(test)]
Expand Down Expand Up @@ -583,19 +601,14 @@ class C[T]:
let TestCase { db, file } = test_case("x = 1;\ndef test():\n y = 4");

let index = semantic_index(&db, file);
let root_table = index.symbol_table(FileScopeId::root());
let parsed = parsed_module(&db, file);
let ast = parsed.syntax();

let x_sym = root_table
.symbol_by_name("x")
.expect("x symbol should exist");

let x_stmt = ast.body[0].as_assign_stmt().unwrap();
let x = &x_stmt.targets[0];

assert_eq!(index.expression_scope(x).kind(), ScopeKind::Module);
assert_eq!(index.expression_scope_id(x), x_sym.scope());
assert_eq!(index.expression_scope_id(x), FileScopeId::root());

let def = ast.body[1].as_function_def_stmt().unwrap();
let y_stmt = def.body[0].as_assign_stmt().unwrap();
Expand Down
Loading

0 comments on commit 22733cb

Please sign in to comment.