diff --git a/crates/ty_ide/src/semantic_tokens.rs b/crates/ty_ide/src/semantic_tokens.rs index 8ddf2ba34de5d..aa74cc29cceb1 100644 --- a/crates/ty_ide/src/semantic_tokens.rs +++ b/crates/ty_ide/src/semantic_tokens.rs @@ -442,6 +442,31 @@ impl<'db> SemanticTokenVisitor<'db> { ty: Type, attr_name: &ast::Identifier, ) -> (SemanticTokenType, SemanticTokenModifier) { + enum UnifiedTokenType { + None, + /// All types have the same semantic token type + Uniform(SemanticTokenType), + /// The elements have different semantic token types + Fallback, + } + + impl UnifiedTokenType { + fn add(&mut self, ty: SemanticTokenType) { + *self = match self { + Self::None => Self::Uniform(ty), + Self::Uniform(current) if *current == ty => Self::Uniform(ty), + Self::Uniform(_) | Self::Fallback => Self::Fallback, + } + } + + fn into_semantic_token_type(self) -> Option { + match self { + UnifiedTokenType::None | UnifiedTokenType::Fallback => None, + UnifiedTokenType::Uniform(ty) => Some(ty), + } + } + } + let attr_name_str = attr_name.id.as_str(); let mut modifiers = SemanticTokenModifier::empty(); @@ -449,35 +474,53 @@ impl<'db> SemanticTokenVisitor<'db> { return classification; } - // Classify based on the inferred type of the attribute - match ty { - Type::ClassLiteral(_) => (SemanticTokenType::Class, modifiers), - Type::FunctionLiteral(_) => { - // This is a function accessed as an attribute, likely a method - (SemanticTokenType::Method, modifiers) - } - Type::BoundMethod(_) => { - // Method bound to an instance - (SemanticTokenType::Method, modifiers) - } - Type::ModuleLiteral(_) => { - // Module accessed as an attribute (e.g., from os import path) - (SemanticTokenType::Namespace, modifiers) - } - _ if ty.is_property_instance() => { - // Actual Python property - (SemanticTokenType::Property, modifiers) - } - _ => { - // Check for constant naming convention - if Self::is_constant_name(attr_name_str) { - modifiers |= SemanticTokenModifier::READONLY; - } + let elements = if let Some(union) = ty.as_union() { + union.elements(self.model.db()) + } else { + std::slice::from_ref(&ty) + }; - // For other types (variables, constants, etc.), classify as variable - (SemanticTokenType::Variable, modifiers) + let mut token_type = UnifiedTokenType::None; + + for element in elements { + // Classify based on the inferred type of the attribute + match element { + Type::ClassLiteral(_) => { + token_type.add(SemanticTokenType::Class); + } + Type::FunctionLiteral(_) => { + // This is a function accessed as an attribute, likely a method + token_type.add(SemanticTokenType::Method); + } + Type::BoundMethod(_) | Type::KnownBoundMethod(_) => { + // Method bound to an instance + token_type.add(SemanticTokenType::Method); + } + Type::ModuleLiteral(_) => { + // Module accessed as an attribute (e.g., from os import path) + token_type.add(SemanticTokenType::Namespace); + } + ty if ty.is_property_instance() => { + token_type.add(SemanticTokenType::Property); + } + _ => { + token_type = UnifiedTokenType::Fallback; + } } } + + if let Some(uniform) = token_type.into_semantic_token_type() { + return (uniform, modifiers); + } + + // Check for constant naming convention + if Self::is_constant_name(attr_name_str) { + modifiers |= SemanticTokenModifier::READONLY; + } + + // For other types (variables, constants, etc.), classify as variable + // Should this always be property? + (SemanticTokenType::Variable, modifiers) } fn classify_parameter( @@ -1819,6 +1862,7 @@ z = obj.CONSTANT # CONSTANT should be variable with readonly modifier w = obj.prop # prop should be property v = MyClass.method # method should be method (function) u = List.__name__ # __name__ should be variable +t = MyClass.prop # prop should be property on the class itself ", ); @@ -1862,6 +1906,9 @@ u = List.__name__ # __name__ should be variable "u" @ 596..597: Variable [definition] "List" @ 600..604: Variable "__name__" @ 605..613: Variable + "t" @ 651..652: Variable [definition] + "MyClass" @ 655..662: Class + "prop" @ 663..667: Property "#); } @@ -1896,6 +1943,264 @@ y = obj.unknown_attr # Should fall back to variable "#); } + #[test] + fn attribute_on_union_1() { + let test = SemanticTokenTest::new( + " +from random import random + +class Foo: + CONSTANT = 42 + + def method(self): + return \"hello\" + + @property + def prop(self) -> str: + return \"hello\" + +class Bar: + CONSTANT = 24 + + def method(self, x: int = 1) -> int: + return 42 + + @property + def prop(self) -> int: + return self.CONSTANT + + +foobar = Foo() if random() else Bar() +y = foobar.method # method should be method (bound method) +z = foobar.CONSTANT # CONSTANT should be variable with readonly modifier +w = foobar.prop # prop should be property +foobar_cls = Foo if random() else Bar +v = foobar_cls.method # method should be method (function) +x = foobar_cls.prop # prop should be property +", + ); + + let tokens = test.highlight_file(); + + assert_snapshot!(test.to_snapshot(&tokens), @r#" + "random" @ 6..12: Namespace + "random" @ 20..26: Method + "Foo" @ 34..37: Class [definition] + "CONSTANT" @ 43..51: Variable [definition, readonly] + "42" @ 54..56: Number + "method" @ 66..72: Method [definition] + "self" @ 73..77: SelfParameter [definition] + "\"hello\"" @ 95..102: String + "property" @ 109..117: Decorator + "prop" @ 126..130: Method [definition] + "self" @ 131..135: SelfParameter [definition] + "str" @ 140..143: Class + "\"hello\"" @ 160..167: String + "Bar" @ 175..178: Class [definition] + "CONSTANT" @ 184..192: Variable [definition, readonly] + "24" @ 195..197: Number + "method" @ 207..213: Method [definition] + "self" @ 214..218: SelfParameter [definition] + "x" @ 220..221: Parameter [definition] + "int" @ 223..226: Class + "1" @ 229..230: Number + "int" @ 235..238: Class + "42" @ 255..257: Number + "property" @ 264..272: Decorator + "prop" @ 281..285: Method [definition] + "self" @ 286..290: SelfParameter [definition] + "int" @ 295..298: Class + "self" @ 315..319: SelfParameter + "CONSTANT" @ 320..328: Variable [readonly] + "foobar" @ 331..337: Variable [definition] + "Foo" @ 340..343: Class + "random" @ 349..355: Variable + "Bar" @ 363..366: Class + "y" @ 369..370: Variable [definition] + "foobar" @ 373..379: Variable + "method" @ 380..386: Method + "z" @ 459..460: Variable [definition] + "foobar" @ 463..469: Variable + "CONSTANT" @ 470..478: Variable [readonly] + "w" @ 561..562: Variable [definition] + "foobar" @ 565..571: Variable + "prop" @ 572..576: Variable + "foobar_cls" @ 636..646: Variable [definition] + "Foo" @ 649..652: Class + "random" @ 656..662: Variable + "Bar" @ 670..673: Class + "v" @ 674..675: Variable [definition] + "foobar_cls" @ 678..688: Variable + "method" @ 689..695: Method + "x" @ 760..761: Variable [definition] + "foobar_cls" @ 764..774: Variable + "prop" @ 775..779: Property + "#); + } + + #[test] + fn attribute_on_union_2() { + let test = SemanticTokenTest::new( + " +from random import random + +# There is also this way to create union types: +class Baz: + if random(): + CONSTANT = 42 + + def method(self) -> int: + return 42 + + @property + def prop(self) -> int: + return 42 + else: + CONSTANT = \"hello\" + + def method(self) -> str: + return \"hello\" + + @property + def prop(self) -> str: + return \"hello\" + +baz = Baz() +s = baz.method # method should be bound method +t = baz.CONSTANT # CONSTANT should be variable with readonly +r = baz.prop # prop should be property +q = Baz.prop # prop should be property on the class as well +", + ); + + let tokens = test.highlight_file(); + + assert_snapshot!(test.to_snapshot(&tokens), @r#" + "random" @ 6..12: Namespace + "random" @ 20..26: Method + "Baz" @ 82..85: Class [definition] + "random" @ 94..100: Variable + "CONSTANT" @ 112..120: Variable [definition, readonly] + "42" @ 123..125: Number + "method" @ 139..145: Method [definition] + "self" @ 146..150: SelfParameter [definition] + "int" @ 155..158: Class + "42" @ 179..181: Number + "property" @ 192..200: Decorator + "prop" @ 213..217: Method [definition] + "self" @ 218..222: SelfParameter [definition] + "int" @ 227..230: Class + "42" @ 251..253: Number + "CONSTANT" @ 272..280: Variable [definition, readonly] + "\"hello\"" @ 283..290: String + "method" @ 304..310: Method [definition] + "self" @ 311..315: SelfParameter [definition] + "str" @ 320..323: Class + "\"hello\"" @ 344..351: String + "property" @ 362..370: Decorator + "prop" @ 383..387: Method [definition] + "self" @ 388..392: SelfParameter [definition] + "str" @ 397..400: Class + "\"hello\"" @ 421..428: String + "baz" @ 430..433: Variable [definition] + "Baz" @ 436..439: Class + "s" @ 442..443: Variable [definition] + "baz" @ 446..449: Variable + "method" @ 450..456: Method + "t" @ 494..495: Variable [definition] + "baz" @ 498..501: Variable + "CONSTANT" @ 502..510: Variable [readonly] + "r" @ 558..559: Variable [definition] + "baz" @ 562..565: Variable + "prop" @ 566..570: Variable + "q" @ 604..605: Variable [definition] + "Baz" @ 608..611: Class + "prop" @ 612..616: Property + "#); + } + + #[test] + fn attribute_on_union_3() { + // This is a test where the unions are not actually composed of the same elements, + // so the regular fallback logic should apply. + let test = SemanticTokenTest::new( + " +from random import random + +class Baz: + if random(): + CONSTANT = 42 + + def method(self) -> int: + return 42 + + @property + def prop(self) -> int: + return 42 + else: + def CONSTANT(self): + return \"hello\" + + @property + def method(self) -> str: + return \"hello\" + + prop: str = \"hello\" + +baz = Baz() +s = baz.method +t = baz.CONSTANT +r = baz.prop +q = Baz.prop +", + ); + + let tokens = test.highlight_file(); + + assert_snapshot!(test.to_snapshot(&tokens), @r#" + "random" @ 6..12: Namespace + "random" @ 20..26: Method + "Baz" @ 34..37: Class [definition] + "random" @ 46..52: Variable + "CONSTANT" @ 64..72: Variable [definition, readonly] + "42" @ 75..77: Number + "method" @ 91..97: Method [definition] + "self" @ 98..102: SelfParameter [definition] + "int" @ 107..110: Class + "42" @ 131..133: Number + "property" @ 144..152: Decorator + "prop" @ 165..169: Method [definition] + "self" @ 170..174: SelfParameter [definition] + "int" @ 179..182: Class + "42" @ 203..205: Number + "CONSTANT" @ 228..236: Method [definition] + "self" @ 237..241: SelfParameter [definition] + "\"hello\"" @ 263..270: String + "property" @ 281..289: Decorator + "method" @ 302..308: Method [definition] + "self" @ 309..313: SelfParameter [definition] + "str" @ 318..321: Class + "\"hello\"" @ 342..349: String + "prop" @ 359..363: Method [definition] + "str" @ 365..368: Class + "\"hello\"" @ 371..378: String + "baz" @ 380..383: Variable [definition] + "Baz" @ 386..389: Class + "s" @ 392..393: Variable [definition] + "baz" @ 396..399: Variable + "method" @ 400..406: Variable + "t" @ 408..409: Variable [definition] + "baz" @ 412..415: Variable + "CONSTANT" @ 416..424: Variable [readonly] + "r" @ 425..426: Variable [definition] + "baz" @ 429..432: Variable + "prop" @ 433..437: Variable + "q" @ 438..439: Variable [definition] + "Baz" @ 442..445: Class + "prop" @ 446..450: Variable + "#); + } + #[test] fn constant_name_detection() { let test = SemanticTokenTest::new( diff --git a/crates/ty_python_semantic/src/types.rs b/crates/ty_python_semantic/src/types.rs index e4846954c09b6..cef6a1ce0db97 100644 --- a/crates/ty_python_semantic/src/types.rs +++ b/crates/ty_python_semantic/src/types.rs @@ -1306,7 +1306,7 @@ impl<'db> Type<'db> { matches!(self, Type::Union(_)) } - pub(crate) const fn as_union(self) -> Option> { + pub const fn as_union(self) -> Option> { match self { Type::Union(union_type) => Some(union_type), _ => None,