diff --git a/crates/ty_python_semantic/resources/mdtest/generics/legacy/functions.md b/crates/ty_python_semantic/resources/mdtest/generics/legacy/functions.md index 8132de6ae6bfc9..02586698be0e52 100644 --- a/crates/ty_python_semantic/resources/mdtest/generics/legacy/functions.md +++ b/crates/ty_python_semantic/resources/mdtest/generics/legacy/functions.md @@ -880,3 +880,31 @@ def f(x: T) -> T: def g(x: S) -> S: return f(x) # error: [invalid-argument-type] ``` + +## Inferring typevars in iterable parameters from literal string and bytes arguments + +```py +from typing import Iterable, TypeVar +from typing_extensions import LiteralString + +FlatT = TypeVar("FlatT") + +def flatten(*iterables: Iterable[FlatT]) -> list[FlatT]: + return [x for iterable in iterables for x in iterable] + +def flatten_covariant(*iterables: Iterable[FlatT]) -> tuple[FlatT, ...]: + return tuple(x for iterable in iterables for x in iterable) + +reveal_type(flatten("abc", (1, 2, 3))) # revealed: list[str | int] +# TODO: we could have `Literal["a", "b", "c"]` instead of `str` here +reveal_type(flatten_covariant("abc", (1, 2, 3))) # revealed: tuple[str | Literal[1, 2, 3], ...] + +def literal_string_case(literal_string: LiteralString): + reveal_type(flatten(literal_string, (1, 2, 3))) # revealed: list[str | int] + +reveal_type(flatten(b"abc")) # revealed: list[int] +reveal_type(flatten(b"abc", ("x",))) # revealed: list[int | str] +# TODO: we could have `Literal[97, 98, 99]` instead of `int` in the next two lines +reveal_type(flatten_covariant(b"abc")) # revealed: tuple[int, ...] +reveal_type(flatten_covariant(b"abc", ("x",))) # revealed: tuple[int | Literal["x"], ...] +``` diff --git a/crates/ty_python_semantic/src/types/generics.rs b/crates/ty_python_semantic/src/types/generics.rs index 23504924248ec6..04167b7fedc311 100644 --- a/crates/ty_python_semantic/src/types/generics.rs +++ b/crates/ty_python_semantic/src/types/generics.rs @@ -2254,6 +2254,19 @@ impl<'db> SpecializationBuilder<'db> { } } + ( + formal @ (Type::NominalInstance(_) | Type::ProtocolInstance(_)), + actual_literal @ (Type::StringLiteral(_) + | Type::LiteralString + | Type::BytesLiteral(_)), + ) => { + // Retry specialization with the literal's fallback instance (`str` / `bytes`) + // so literal iterables can contribute to generic inference. + if let Some(actual_instance) = actual_literal.literal_fallback_instance(self.db) { + return self.infer_map_impl(formal, actual_instance, polarity, f, seen); + } + } + (formal, Type::ProtocolInstance(actual_protocol)) => { // TODO: This will only handle protocol classes that explicit inherit // from other generic protocol classes by listing it as a base class.