Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 35 additions & 12 deletions crates/red_knot_python_semantic/src/ast_node_ref.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,17 +31,13 @@ use ruff_db::parsed::ParsedModule;
/// This means that changes to expressions in other scopes don't invalidate the expression's id, giving
/// us some form of scope-stable identity for expressions. Only queries accessing the node field
/// run on every AST change. All other queries only run when the expression's identity changes.
///
/// The one exception to this is if it is known that all queries tacking the tracked struct
/// as argument or returning it as part of their result are known to access the node field.
/// Marking the field tracked is then unnecessary.
#[derive(Clone)]
pub struct AstNodeRef<T> {
/// Owned reference to the node's [`ParsedModule`].
///
/// The node's reference is guaranteed to remain valid as long as it's enclosing
/// [`ParsedModule`] is alive.
_parsed: ParsedModule,
parsed: ParsedModule,

/// Pointer to the referenced node.
node: std::ptr::NonNull<T>,
Expand All @@ -59,7 +55,7 @@ impl<T> AstNodeRef<T> {
/// the invariant `node belongs to parsed` is upheld.
pub(super) unsafe fn new(parsed: ParsedModule, node: &T) -> Self {
Self {
_parsed: parsed,
parsed,
node: std::ptr::NonNull::from(node),
}
}
Expand Down Expand Up @@ -89,17 +85,44 @@ where
}
}

impl<T> PartialEq for AstNodeRef<T> {
impl<T> PartialEq for AstNodeRef<T>
where
T: PartialEq,
{
fn eq(&self, other: &Self) -> bool {
self.node.eq(&other.node)
if self.parsed == other.parsed {
// Comparing the pointer addresses is sufficient to determine equality
// if the parsed are the same.
self.node.eq(&other.node)
} else {
// Otherwise perform a deep comparison.
self.node().eq(other.node())
}
}
}

impl<T> Eq for AstNodeRef<T> {}
impl<T> Eq for AstNodeRef<T> where T: Eq {}

impl<T> Hash for AstNodeRef<T> {
impl<T> Hash for AstNodeRef<T>
where
T: Hash,
{
fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
self.node.hash(state);
self.node().hash(state);
}
}

#[allow(unsafe_code)]
unsafe impl<T> salsa::Update for AstNodeRef<T> {
unsafe fn maybe_update(old_pointer: *mut Self, new_value: Self) -> bool {
let old_ref = &mut (*old_pointer);

if old_ref.parsed == new_value.parsed && old_ref.node.eq(&new_value.node) {
false
} else {
*old_ref = new_value;
true
}
}
}

Expand Down Expand Up @@ -133,7 +156,7 @@ mod tests {
let stmt_cloned = &cloned.syntax().body[0];
let cloned_node = unsafe { AstNodeRef::new(cloned.clone(), stmt_cloned) };

assert_ne!(node1, cloned_node);
assert_eq!(node1, cloned_node);

let other_raw = parse_unchecked_source("2 + 2", PySourceType::Python);
let other = ParsedModule::new(other_raw);
Expand Down
16 changes: 8 additions & 8 deletions crates/red_knot_python_semantic/src/semantic_index/definition.rs
Original file line number Diff line number Diff line change
Expand Up @@ -439,7 +439,7 @@ impl DefinitionCategory {
/// [`DefinitionKind`] fields in salsa tracked structs should be tracked (attributed with `#[tracked]`)
/// because the kind is a thin wrapper around [`AstNodeRef`]. See the [`AstNodeRef`] documentation
/// for an in-depth explanation of why this is necessary.
#[derive(Clone, Debug, Hash)]
#[derive(Clone, Debug)]
pub enum DefinitionKind<'db> {
Import(AstNodeRef<ast::Alias>),
ImportFrom(ImportFromDefinitionKind),
Expand Down Expand Up @@ -559,7 +559,7 @@ impl<'db> From<Option<Unpack<'db>>> for TargetKind<'db> {
}
}

#[derive(Clone, Debug, Hash)]
#[derive(Clone, Debug)]
#[allow(dead_code)]
pub struct MatchPatternDefinitionKind {
pattern: AstNodeRef<ast::Pattern>,
Expand All @@ -577,7 +577,7 @@ impl MatchPatternDefinitionKind {
}
}

#[derive(Clone, Debug, Hash)]
#[derive(Clone, Debug)]
pub struct ComprehensionDefinitionKind {
iterable: AstNodeRef<ast::Expr>,
target: AstNodeRef<ast::ExprName>,
Expand All @@ -603,7 +603,7 @@ impl ComprehensionDefinitionKind {
}
}

#[derive(Clone, Debug, Hash)]
#[derive(Clone, Debug)]
pub struct ImportFromDefinitionKind {
node: AstNodeRef<ast::StmtImportFrom>,
alias_index: usize,
Expand All @@ -619,7 +619,7 @@ impl ImportFromDefinitionKind {
}
}

#[derive(Clone, Debug, Hash)]
#[derive(Clone, Debug)]
pub struct AssignmentDefinitionKind<'db> {
target: TargetKind<'db>,
value: AstNodeRef<ast::Expr>,
Expand All @@ -645,7 +645,7 @@ impl<'db> AssignmentDefinitionKind<'db> {
}
}

#[derive(Clone, Debug, Hash)]
#[derive(Clone, Debug)]
pub struct WithItemDefinitionKind {
node: AstNodeRef<ast::WithItem>,
target: AstNodeRef<ast::ExprName>,
Expand All @@ -666,7 +666,7 @@ impl WithItemDefinitionKind {
}
}

#[derive(Clone, Debug, Hash)]
#[derive(Clone, Debug)]
pub struct ForStmtDefinitionKind<'db> {
target: TargetKind<'db>,
iterable: AstNodeRef<ast::Expr>,
Expand Down Expand Up @@ -697,7 +697,7 @@ impl<'db> ForStmtDefinitionKind<'db> {
}
}

#[derive(Clone, Debug, Hash)]
#[derive(Clone, Debug)]
pub struct ExceptHandlerDefinitionKind {
handler: AstNodeRef<ast::ExceptHandlerExceptHandler>,
is_star: bool,
Expand Down
1 change: 1 addition & 0 deletions crates/red_knot_python_semantic/src/unpack.rs
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ pub(crate) struct Unpack<'db> {
/// expression is `(a, b)`.
#[no_eq]
#[return_ref]
#[tracked]
pub(crate) target: AstNodeRef<ast::Expr>,

/// The ingredient representing the value expression of the unpacking. For example, in
Expand Down
Loading