Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -455,6 +455,82 @@ m.name = "new" # No error
reveal_type(Mutable(name="A") < Mutable(name="B")) # revealed: bool
```

## Other `dataclass` parameters

Other parameters from normal dataclasses can also be set on models created using
`dataclass_transform`.

### Using function-based transformers

```py
from typing_extensions import dataclass_transform, TypeVar, Callable

T = TypeVar("T", bound=type)

@dataclass_transform()
def fancy_model(*, slots: bool = False) -> Callable[[T], T]:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

isn't slots=False the default default? Wouldn't a more interesting test be this?

Suggested change
def fancy_model(*, slots: bool = False) -> Callable[[T], T]:
def fancy_model(*, slots: bool = True) -> Callable[[T], T]:

(The same applies to your other examples below!)

Copy link
Contributor Author

@sharkdp sharkdp Nov 16, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The default parameter here has no actual effect on anything. slots=False is the default behavior and there is no way to override this for a dataclass-transformer. There is no slots_default parameter for the dataclass_transform call.

For the parameters that can actually change their defaults (like kw_only), we already have tests elsewhere in this file. Here, we just test that we can also modify something like slots on an actual model/dataclass (not on the transformer / the template).

raise NotImplementedError

@fancy_model()
class NoSlots:
name: str

NoSlots.__slots__ # error: [unresolved-attribute]

@fancy_model(slots=True)
class WithSlots:
name: str

reveal_type(WithSlots.__slots__) # revealed: tuple[Literal["name"]]
```

### Using metaclass-based transformers

```py
from typing_extensions import dataclass_transform

@dataclass_transform()
class FancyMeta(type):
def __new__(cls, name, bases, namespace, *, slots: bool = False):
...
return super().__new__(cls, name, bases, namespace)

class FancyBase(metaclass=FancyMeta): ...

class NoSlots(FancyBase):
name: str

# error: [unresolved-attribute]
NoSlots.__slots__

class WithSlots(FancyBase, slots=True):
name: str

reveal_type(WithSlots.__slots__) # revealed: tuple[Literal["name"]]
```

### Using base-class-based transformers

```py
from typing_extensions import dataclass_transform

@dataclass_transform()
class FancyBase:
def __init_subclass__(cls, *, slots: bool = False):
...
super().__init_subclass__()

class NoSlots(FancyBase):
name: str

NoSlots.__slots__ # error: [unresolved-attribute]

class WithSlots(FancyBase, slots=True):
name: str

reveal_type(WithSlots.__slots__) # revealed: tuple[Literal["name"]]
```

## `field_specifiers`

The `field_specifiers` argument can be used to specify a list of functions that should be treated
Expand Down
13 changes: 13 additions & 0 deletions crates/ty_python_semantic/src/types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -622,6 +622,19 @@ bitflags! {
}
}

pub(crate) const DATACLASS_FLAGS: &[(&str, DataclassFlags)] = &[
("init", DataclassFlags::INIT),
("repr", DataclassFlags::REPR),
("eq", DataclassFlags::EQ),
("order", DataclassFlags::ORDER),
("unsafe_hash", DataclassFlags::UNSAFE_HASH),
("frozen", DataclassFlags::FROZEN),
("match_args", DataclassFlags::MATCH_ARGS),
("kw_only", DataclassFlags::KW_ONLY),
("slots", DataclassFlags::SLOTS),
("weakref_slot", DataclassFlags::WEAKREF_SLOT),
];

impl get_size2::GetSize for DataclassFlags {}

impl Default for DataclassFlags {
Expand Down
38 changes: 11 additions & 27 deletions crates/ty_python_semantic/src/types/call/bind.rs
Original file line number Diff line number Diff line change
Expand Up @@ -35,11 +35,11 @@ use crate::types::generics::{
use crate::types::signatures::{Parameter, ParameterForm, ParameterKind, Parameters};
use crate::types::tuple::{TupleLength, TupleType};
use crate::types::{
BoundMethodType, BoundTypeVarIdentity, ClassLiteral, DataclassFlags, DataclassParams,
FieldInstance, KnownBoundMethodType, KnownClass, KnownInstanceType, MemberLookupPolicy,
NominalInstanceType, PropertyInstanceType, SpecialFormType, TrackedConstraintSet,
TypeAliasType, TypeContext, TypeVarVariance, UnionBuilder, UnionType, WrapperDescriptorKind,
enums, ide_support, infer_isolated_expression, todo_type,
BoundMethodType, BoundTypeVarIdentity, ClassLiteral, DATACLASS_FLAGS, DataclassFlags,
DataclassParams, FieldInstance, KnownBoundMethodType, KnownClass, KnownInstanceType,
MemberLookupPolicy, NominalInstanceType, PropertyInstanceType, SpecialFormType,
TrackedConstraintSet, TypeAliasType, TypeContext, TypeVarVariance, UnionBuilder, UnionType,
WrapperDescriptorKind, enums, ide_support, infer_isolated_expression, todo_type,
};
use ruff_db::diagnostic::{Annotation, Diagnostic, SubDiagnostic, SubDiagnosticSeverity};
use ruff_python_ast::{self as ast, ArgOrKeyword, PythonVersion};
Expand Down Expand Up @@ -1134,28 +1134,12 @@ impl<'db> Bindings<'db> {
);
let mut flags = dataclass_params.flags(db);

if let Ok(Some(Type::BooleanLiteral(order))) =
overload.parameter_type_by_name("order", false)
{
flags.set(DataclassFlags::ORDER, order);
}

if let Ok(Some(Type::BooleanLiteral(eq))) =
overload.parameter_type_by_name("eq", false)
{
flags.set(DataclassFlags::EQ, eq);
}

if let Ok(Some(Type::BooleanLiteral(kw_only))) =
overload.parameter_type_by_name("kw_only", false)
{
flags.set(DataclassFlags::KW_ONLY, kw_only);
}

if let Ok(Some(Type::BooleanLiteral(frozen))) =
overload.parameter_type_by_name("frozen", false)
{
flags.set(DataclassFlags::FROZEN, frozen);
for (param, flag) in DATACLASS_FLAGS {
if let Ok(Some(Type::BooleanLiteral(value))) =
overload.parameter_type_by_name(param, false)
{
flags.set(*flag, value);
}
}

Type::DataclassDecorator(DataclassParams::new(
Expand Down
24 changes: 11 additions & 13 deletions crates/ty_python_semantic/src/types/class.rs
Original file line number Diff line number Diff line change
Expand Up @@ -32,13 +32,13 @@ use crate::types::tuple::{TupleSpec, TupleType};
use crate::types::typed_dict::typed_dict_params_from_class_def;
use crate::types::visitor::{TypeCollector, TypeVisitor, walk_type_with_recursion_guard};
use crate::types::{
ApplyTypeMappingVisitor, Binding, BoundSuperType, CallableType, DataclassFlags,
DataclassParams, DeprecatedInstance, FindLegacyTypeVarsVisitor, HasRelationToVisitor,
IsDisjointVisitor, IsEquivalentVisitor, KnownInstanceType, ManualPEP695TypeAliasType,
MaterializationKind, NormalizedVisitor, PropertyInstanceType, StringLiteralType, TypeAliasType,
TypeContext, TypeMapping, TypeRelation, TypedDictParams, UnionBuilder, VarianceInferable,
declaration_type, determine_upper_bound, exceeds_max_specialization_depth,
infer_definition_types,
ApplyTypeMappingVisitor, Binding, BoundSuperType, CallableType, DATACLASS_FLAGS,
DataclassFlags, DataclassParams, DeprecatedInstance, FindLegacyTypeVarsVisitor,
HasRelationToVisitor, IsDisjointVisitor, IsEquivalentVisitor, KnownInstanceType,
ManualPEP695TypeAliasType, MaterializationKind, NormalizedVisitor, PropertyInstanceType,
StringLiteralType, TypeAliasType, TypeContext, TypeMapping, TypeRelation, TypedDictParams,
UnionBuilder, VarianceInferable, declaration_type, determine_upper_bound,
exceeds_max_specialization_depth, infer_definition_types,
};
use crate::{
Db, FxIndexMap, FxIndexSet, FxOrderSet, Program,
Expand Down Expand Up @@ -2229,12 +2229,10 @@ impl<'db> ClassLiteral<'db> {
if let Some(is_set) =
keyword.value.as_boolean_literal_expr().map(|b| b.value)
{
match arg_name.as_str() {
"eq" => flags.set(DataclassFlags::EQ, is_set),
"order" => flags.set(DataclassFlags::ORDER, is_set),
"kw_only" => flags.set(DataclassFlags::KW_ONLY, is_set),
"frozen" => flags.set(DataclassFlags::FROZEN, is_set),
_ => {}
for (flag_name, flag) in DATACLASS_FLAGS {
if arg_name.as_str() == *flag_name {
flags.set(*flag, is_set);
}
}
}
}
Expand Down
Loading