Skip to content

Commit 7ed245b

Browse files
committed
Implemented support for PEP 655: Marking individual TypedDict items as required or potentially-missing.
1 parent e0a9334 commit 7ed245b

File tree

12 files changed

+236
-8
lines changed

12 files changed

+236
-8
lines changed

packages/pyright-internal/src/analyzer/binder.ts

+26-2
Original file line numberDiff line numberDiff line change
@@ -1426,7 +1426,7 @@ export class Binder extends ParseTreeWalker {
14261426
}
14271427

14281428
visitImportFrom(node: ImportFromNode): boolean {
1429-
const typingSymbolsOfInterest = ['Final', 'TypeAlias', 'ClassVar'];
1429+
const typingSymbolsOfInterest = ['Final', 'TypeAlias', 'ClassVar', 'Required', 'NotRequired'];
14301430
const importInfo = AnalyzerNodeInfo.getImportInfo(node.module);
14311431

14321432
let resolvedPath = '';
@@ -2857,6 +2857,8 @@ export class Binder extends ParseTreeWalker {
28572857
node: target,
28582858
isConstant: isConstantName(name.value),
28592859
isFinal: finalInfo.isFinal,
2860+
isRequired: this._isRequiredAnnotation(typeAnnotationNode),
2861+
isNotRequired: this._isNotRequiredAnnotation(typeAnnotationNode),
28602862
typeAliasAnnotation: isExplicitTypeAlias ? typeAnnotation : undefined,
28612863
typeAliasName: isExplicitTypeAlias ? target : undefined,
28622864
path: this._fileInfo.filePath,
@@ -2965,7 +2967,7 @@ export class Binder extends ParseTreeWalker {
29652967
}
29662968

29672969
// Determines if the specified type annotation expression is a "Final".
2968-
// It returns two boolean values indicating if the expression is a "Final"
2970+
// It returns a value indicating whether the expression is a "Final"
29692971
// expression and whether it's a "raw" Final with no type arguments.
29702972
private _isAnnotationFinal(typeAnnotation: ExpressionNode | undefined): FinalInfo {
29712973
let isFinal = false;
@@ -2992,6 +2994,28 @@ export class Binder extends ParseTreeWalker {
29922994
return { isFinal, finalTypeNode };
29932995
}
29942996

2997+
// Determines if the specified type annotation is wrapped in a "Required".
2998+
private _isRequiredAnnotation(typeAnnotation: ExpressionNode | undefined): boolean {
2999+
if (typeAnnotation && typeAnnotation.nodeType === ParseNodeType.Index && typeAnnotation.items.length === 1) {
3000+
if (this._isTypingAnnotation(typeAnnotation.baseExpression, 'Required')) {
3001+
return true;
3002+
}
3003+
}
3004+
3005+
return false;
3006+
}
3007+
3008+
// Determines if the specified type annotation is wrapped in a "NotRequired".
3009+
private _isNotRequiredAnnotation(typeAnnotation: ExpressionNode | undefined): boolean {
3010+
if (typeAnnotation && typeAnnotation.nodeType === ParseNodeType.Index && typeAnnotation.items.length === 1) {
3011+
if (this._isTypingAnnotation(typeAnnotation.baseExpression, 'NotRequired')) {
3012+
return true;
3013+
}
3014+
}
3015+
3016+
return false;
3017+
}
3018+
29953019
private _isAnnotationTypeAlias(typeAnnotation: ExpressionNode | undefined) {
29963020
if (!typeAnnotation) {
29973021
return false;

packages/pyright-internal/src/analyzer/declaration.ts

+6
Original file line numberDiff line numberDiff line change
@@ -112,6 +112,12 @@ export interface VariableDeclaration extends DeclarationBase {
112112
// constant in that reassignment is not permitted)?
113113
isFinal?: boolean;
114114

115+
// Is the declaration annotated with "Required"?
116+
isRequired?: boolean;
117+
118+
// Is the declaration annotated with "NotRequired"?
119+
isNotRequired?: boolean;
120+
115121
// Points to the "TypeAlias" annotation described in PEP 613.
116122
typeAliasAnnotation?: ExpressionNode;
117123

packages/pyright-internal/src/analyzer/symbol.ts

+1-1
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ export const enum SymbolFlags {
4444
// set when accessed through a class instance.
4545
ClassVar = 1 << 7,
4646

47-
// // Indicates that the symbol is in __all__.
47+
// Indicates that the symbol is in __all__.
4848
InDunderAll = 1 << 8,
4949
}
5050

packages/pyright-internal/src/analyzer/symbolUtils.ts

+8
Original file line numberDiff line numberDiff line change
@@ -39,3 +39,11 @@ export function isTypedDictMemberAccessedThroughIndex(symbol: Symbol): boolean {
3939
export function isFinalVariable(symbol: Symbol): boolean {
4040
return symbol.getDeclarations().some((decl) => isFinalVariableDeclaration(decl));
4141
}
42+
43+
export function isRequiredTypedDictVariable(symbol: Symbol) {
44+
return symbol.getDeclarations().some((decl) => decl.type === DeclarationType.Variable && !!decl.isRequired);
45+
}
46+
47+
export function isNotRequiredTypedDictVariable(symbol: Symbol) {
48+
return symbol.getDeclarations().some((decl) => decl.type === DeclarationType.Variable && !!decl.isNotRequired);
49+
}

packages/pyright-internal/src/analyzer/typeEvaluator.ts

+66-5
Original file line numberDiff line numberDiff line change
@@ -115,7 +115,7 @@ import * as ScopeUtils from './scopeUtils';
115115
import { evaluateStaticBoolExpression } from './staticExpressions';
116116
import { indeterminateSymbolId, Symbol, SymbolFlags } from './symbol';
117117
import { isConstantName, isDunderName, isPrivateOrProtectedName } from './symbolNameUtils';
118-
import { getLastTypedDeclaredForSymbol, isFinalVariable } from './symbolUtils';
118+
import { getLastTypedDeclaredForSymbol, isFinalVariable, isNotRequiredTypedDictVariable, isRequiredTypedDictVariable } from './symbolUtils';
119119
import { PrintableType, TracePrinter } from './tracePrinter';
120120
import {
121121
CachedType,
@@ -10220,6 +10220,51 @@ export function createTypeEvaluator(
1022010220
return ClassType.cloneForSpecialization(classType, [convertToInstance(typeArg)], !!typeArgs);
1022110221
}
1022210222

10223+
function createRequiredType(
10224+
classType: ClassType,
10225+
errorNode: ParseNode,
10226+
isRequired: boolean,
10227+
typeArgs: TypeResult[] | undefined
10228+
): Type {
10229+
if (!typeArgs || typeArgs.length !== 1) {
10230+
addError(
10231+
isRequired ? Localizer.Diagnostic.requiredArgCount() : Localizer.Diagnostic.notRequiredArgCount(),
10232+
errorNode
10233+
);
10234+
return classType;
10235+
}
10236+
10237+
const typeArgType = typeArgs[0].type;
10238+
10239+
// Make sure this is used only in a dataclass.
10240+
const containingClassNode = ParseTreeUtils.getEnclosingClass(errorNode, /* stopAtFunction */ true);
10241+
const classTypeInfo = containingClassNode ? getTypeOfClass(containingClassNode) : undefined;
10242+
10243+
let isUsageLegal = false;
10244+
10245+
if (classTypeInfo && isClass(classTypeInfo.classType) && ClassType.isTypedDictClass(classTypeInfo.classType)) {
10246+
// The only legal usage is when used in a type annotation statement.
10247+
if (
10248+
errorNode.parent?.nodeType === ParseNodeType.TypeAnnotation &&
10249+
errorNode.parent.typeAnnotation === errorNode
10250+
) {
10251+
isUsageLegal = true;
10252+
}
10253+
}
10254+
10255+
if (!isUsageLegal) {
10256+
addError(
10257+
isRequired
10258+
? Localizer.Diagnostic.requiredNotInTypedDict()
10259+
: Localizer.Diagnostic.notRequiredNotInTypedDict(),
10260+
errorNode
10261+
);
10262+
return ClassType.cloneForSpecialization(classType, [convertToInstance(typeArgType)], !!typeArgs);
10263+
}
10264+
10265+
return typeArgType;
10266+
}
10267+
1022310268
function createUnpackType(errorNode: ParseNode, typeArgs: TypeResult[] | undefined): Type {
1022410269
if (!typeArgs || typeArgs.length !== 1) {
1022510270
addError(Localizer.Diagnostic.unpackArgCount(), errorNode);
@@ -10435,7 +10480,7 @@ export function createTypeEvaluator(
1043510480
function transformTypeForPossibleEnumClass(node: NameNode, typeOfExpr: Type): Type {
1043610481
// If the node is within a class that derives from the metaclass
1043710482
// "EnumMeta", we need to treat assignments differently.
10438-
const enclosingClassNode = ParseTreeUtils.getEnclosingClass(node, true);
10483+
const enclosingClassNode = ParseTreeUtils.getEnclosingClass(node, /* stopAtFunction */ true);
1043910484
if (enclosingClassNode) {
1044010485
const enumClassInfo = getTypeOfClass(enclosingClassNode);
1044110486

@@ -10630,6 +10675,8 @@ export function createTypeEvaluator(
1063010675
Concatenate: { alias: '', module: 'builtins' },
1063110676
TypeGuard: { alias: '', module: 'builtins' },
1063210677
Unpack: { alias: '', module: 'builtins' },
10678+
Required: { alias: '', module: 'builtins' },
10679+
NotRequired: { alias: '', module: 'builtins' },
1063310680
};
1063410681

1063510682
const aliasMapEntry = specialTypes[assignedName];
@@ -11462,7 +11509,7 @@ export function createTypeEvaluator(
1146211509

1146311510
// There was no cached type, so create a new one.
1146411511
// Retrieve the containing class node if the function is a method.
11465-
const containingClassNode = ParseTreeUtils.getEnclosingClass(node, true);
11512+
const containingClassNode = ParseTreeUtils.getEnclosingClass(node, /* stopAtFunction */ true);
1146611513
let containingClassType: ClassType | undefined;
1146711514
if (containingClassNode) {
1146811515
const classInfo = getTypeOfClass(containingClassNode);
@@ -15808,6 +15855,11 @@ export function createTypeEvaluator(
1580815855
case 'Unpack': {
1580915856
return createUnpackType(errorNode, typeArgs);
1581015857
}
15858+
15859+
case 'Required':
15860+
case 'NotRequired': {
15861+
return createRequiredType(classType, errorNode, aliasedName === 'Required', typeArgs);
15862+
}
1581115863
}
1581215864
}
1581315865

@@ -19924,9 +19976,18 @@ export function createTypeEvaluator(
1992419976
// Only variables (not functions, classes, etc.) are considered.
1992519977
const lastDecl = getLastTypedDeclaredForSymbol(symbol);
1992619978
if (lastDecl && lastDecl.type === DeclarationType.Variable) {
19979+
const valueType = getDeclaredTypeOfSymbol(symbol) || UnknownType.create();
19980+
let isRequired = !ClassType.isCanOmitDictValues(classType);
19981+
19982+
if (isRequiredTypedDictVariable(symbol)) {
19983+
isRequired = true;
19984+
} else if (isNotRequiredTypedDictVariable(symbol)) {
19985+
isRequired = false;
19986+
}
19987+
1992719988
keyMap.set(name, {
19928-
valueType: getDeclaredTypeOfSymbol(symbol) || UnknownType.create(),
19929-
isRequired: !ClassType.isCanOmitDictValues(classType),
19989+
valueType,
19990+
isRequired,
1993019991
isProvided: false,
1993119992
});
1993219993
}

packages/pyright-internal/src/localization/localize.ts

+4
Original file line numberDiff line numberDiff line change
@@ -451,6 +451,8 @@ export namespace Localizer {
451451
export const noReturnContainsReturn = () => getRawString('Diagnostic.noReturnContainsReturn');
452452
export const noReturnContainsYield = () => getRawString('Diagnostic.noReturnContainsYield');
453453
export const noReturnReturnsNone = () => getRawString('Diagnostic.noReturnReturnsNone');
454+
export const notRequiredArgCount = () => getRawString('Diagnostic.notRequiredArgCount');
455+
export const notRequiredNotInTypedDict = () => getRawString('Diagnostic.notRequiredNotInTypedDict');
454456
export const objectNotCallable = () =>
455457
new ParameterizedString<{ type: string }>(getRawString('Diagnostic.objectNotCallable'));
456458
export const obscuredClassDeclaration = () =>
@@ -535,6 +537,8 @@ export namespace Localizer {
535537
export const recursiveDefinition = () =>
536538
new ParameterizedString<{ name: string }>(getRawString('Diagnostic.recursiveDefinition'));
537539
export const relativeImportNotAllowed = () => getRawString('Diagnostic.relativeImportNotAllowed');
540+
export const requiredArgCount = () => getRawString('Diagnostic.requiredArgCount');
541+
export const requiredNotInTypedDict = () => getRawString('Diagnostic.requiredNotInTypedDict');
538542
export const returnMissing = () =>
539543
new ParameterizedString<{ returnType: string }>(getRawString('Diagnostic.returnMissing'));
540544
export const returnOutsideFunction = () => getRawString('Diagnostic.returnOutsideFunction');

packages/pyright-internal/src/localization/package.nls.en-us.json

+4
Original file line numberDiff line numberDiff line change
@@ -218,6 +218,8 @@
218218
"noReturnContainsReturn": "Function with declared return type \"NoReturn\" cannot include a return statement",
219219
"noReturnContainsYield": "Function with declared return type \"NoReturn\" cannot include a yield statement",
220220
"noReturnReturnsNone": "Function with declared type of \"NoReturn\" cannot return \"None\"",
221+
"notRequiredArgCount": "Expected a single type argument after \"NotRequired\"",
222+
"notRequiredNotInTypedDict": "\"NotRequired\" is allowed only within TypedDict",
221223
"objectNotCallable": "Object of type \"{type}\" is not callable",
222224
"obscuredClassDeclaration": "Class declaration \"{name}\" is obscured by a declaration of the same name",
223225
"obscuredFunctionDeclaration": "Function declaration \"{name}\" is obscured by a declaration of the same name",
@@ -264,6 +266,8 @@
264266
"raiseParams": "\"raise\" requires one or more parameters when used outside of except clause",
265267
"relativeImportNotAllowed": "Relative imports cannot be used with \"import .a\" form; use \"from . import a\" instead",
266268
"recursiveDefinition": "Type of \"{name}\" could not be determined because it refers to itself",
269+
"requiredArgCount": "Expected a single type argument after \"Required\"",
270+
"requiredNotInTypedDict": "\"Required\" is allowed only within TypedDict",
267271
"returnOutsideFunction": "\"return\" can be used only within a function",
268272
"returnMissing": "Function with declared type of \"{returnType}\" must return value",
269273
"returnTypeContravariant": "Contravariant type variable cannot be used in return type",
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
# This sample tests the handling of Required and NotRequired
2+
# (PEP 655) in TypedDict definitions.
3+
4+
from typing import Annotated, NotRequired, Required, TypedDict
5+
6+
7+
class TD1(TypedDict):
8+
a: Required[int]
9+
b: NotRequired[int]
10+
11+
# This should generate an error because NotRequired can't be
12+
# used in this context.
13+
c: NotRequired[NotRequired[int]]
14+
15+
# This should generate an error because Required can't be
16+
# used in this context.
17+
d: Required[Required[int]]
18+
19+
e: NotRequired[Annotated[int, "hi"]]
20+
21+
# This should generate an error because it's missing type args.
22+
f: Required
23+
24+
# This should generate an error because it's missing type args.
25+
g: NotRequired
26+
27+
28+
# This should generate an error because Required can't be
29+
# used in this context.
30+
x: Required[int]
31+
32+
# This should generate an error because NotRequired can't be
33+
# used in this context.
34+
y: Required[int]
35+
36+
37+
class Foo:
38+
# This should generate an error because Required can't be
39+
# used in this context.
40+
x: Required[int]
41+
42+
# This should generate an error because NotRequired can't be
43+
# used in this context.
44+
y: Required[int]
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,53 @@
1+
# This sample tests the handling of Required and NotRequired
2+
# (PEP 655) in TypedDict definitions.
3+
4+
from typing import Literal, NotRequired, Optional, Required, Type, TypedDict
5+
6+
7+
class TD1(TypedDict, total=False):
8+
a: Required[int]
9+
b: NotRequired[str]
10+
c: Required[int | str]
11+
d: Required[Optional[str]]
12+
e: Required[Literal[1, 2, 3]]
13+
f: Required[None]
14+
g: Required[Type[int]]
15+
16+
17+
td1_1: TD1 = {"a": 3, "c": "hi", "d": None, "e": 3, "f": None, "g": int}
18+
19+
# This should generate an error because a is missing.
20+
td1_2: TD1 = {"c": "hi", "d": None, "e": 3, "f": None, "g": int}
21+
22+
# This should generate an error because c is missing.
23+
td1_3: TD1 = {"a": 3, "d": None, "e": 3, "f": None, "g": int}
24+
25+
# This should generate an error because d is missing.
26+
td1_4: TD1 = {"a": 3, "c": "hi", "e": 3, "f": None, "g": int}
27+
28+
# This should generate an error because e is missing.
29+
td1_5: TD1 = {"a": 3, "c": "hi", "d": None, "f": None, "g": int}
30+
31+
# This should generate an error because f is missing.
32+
td1_6: TD1 = {"a": 3, "c": "hi", "d": None, "e": 3, "g": int}
33+
34+
# This should generate an error because g is missing.
35+
td1_7: TD1 = {"a": 3, "c": "hi", "d": None, "e": 3, "f": None}
36+
37+
38+
class TD2(TypedDict, total=True):
39+
a: Required[int]
40+
b: NotRequired[str]
41+
c: Required[int | str]
42+
d: NotRequired[Optional[str]]
43+
e: NotRequired[Literal[1, 2, 3]]
44+
f: NotRequired[None]
45+
g: NotRequired[Type[int]]
46+
47+
48+
td2_1: TD2 = {"a": 3, "c": "hi", "d": None, "e": 3, "f": None, "g": int}
49+
50+
td2_2: TD2 = {"a": 3, "c": "hi"}
51+
52+
# This should generate an error because c is missing.
53+
td2_3: TD2 = {"a": 3}

packages/pyright-internal/src/tests/typeEvaluator2.test.ts

+18
Original file line numberDiff line numberDiff line change
@@ -705,6 +705,24 @@ test('TypedDict12', () => {
705705
TestUtils.validateResults(analysisResults, 0);
706706
});
707707

708+
test('Required1', () => {
709+
// Analyze with Python 3.10 settings.
710+
const configOptions = new ConfigOptions('.');
711+
configOptions.defaultPythonVersion = PythonVersion.V3_10;
712+
const analysisResults = TestUtils.typeAnalyzeSampleFiles(['required1.py'], configOptions);
713+
714+
TestUtils.validateResults(analysisResults, 8);
715+
});
716+
717+
test('Required2', () => {
718+
// Analyze with Python 3.10 settings.
719+
const configOptions = new ConfigOptions('.');
720+
configOptions.defaultPythonVersion = PythonVersion.V3_10;
721+
const analysisResults = TestUtils.typeAnalyzeSampleFiles(['required2.py'], configOptions);
722+
723+
TestUtils.validateResults(analysisResults, 7);
724+
});
725+
708726
test('Metaclass1', () => {
709727
const analysisResults = TestUtils.typeAnalyzeSampleFiles(['metaclass1.py']);
710728
TestUtils.validateResults(analysisResults, 0);

packages/pyright-internal/typeshed-fallback/stdlib/typing.pyi

+2
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,8 @@ if sys.version_info >= (3, 10):
6262
Concatenate: _SpecialForm = ...
6363
TypeAlias: _SpecialForm = ...
6464
TypeGuard: _SpecialForm = ...
65+
Required: _SpecialForm = ...
66+
NotRequired: _SpecialForm = ...
6567

6668
class TypeVarTuple:
6769
__name__: str

packages/pyright-internal/typeshed-fallback/stubs/typing-extensions/typing_extensions.pyi

+4
Original file line numberDiff line numberDiff line change
@@ -122,3 +122,7 @@ class TypeVarTuple:
122122

123123
# PEP 647
124124
TypeGuard: _SpecialForm = ...
125+
126+
# PEP 655
127+
Required: _SpecialForm = ...
128+
NotRequired: _SpecialForm = ...

0 commit comments

Comments
 (0)