@@ -115,7 +115,7 @@ import * as ScopeUtils from './scopeUtils';
115
115
import { evaluateStaticBoolExpression } from './staticExpressions';
116
116
import { indeterminateSymbolId, Symbol, SymbolFlags } from './symbol';
117
117
import { isConstantName, isDunderName, isPrivateOrProtectedName } from './symbolNameUtils';
118
- import { getLastTypedDeclaredForSymbol, isFinalVariable } from './symbolUtils';
118
+ import { getLastTypedDeclaredForSymbol, isFinalVariable, isNotRequiredTypedDictVariable, isRequiredTypedDictVariable } from './symbolUtils';
119
119
import { PrintableType, TracePrinter } from './tracePrinter';
120
120
import {
121
121
CachedType,
@@ -10220,6 +10220,51 @@ export function createTypeEvaluator(
10220
10220
return ClassType.cloneForSpecialization(classType, [convertToInstance(typeArg)], !!typeArgs);
10221
10221
}
10222
10222
10223
+ function createRequiredType(
10224
+ classType: ClassType,
10225
+ errorNode: ParseNode,
10226
+ isRequired: boolean,
10227
+ typeArgs: TypeResult[] | undefined
10228
+ ): Type {
10229
+ if (!typeArgs || typeArgs.length !== 1) {
10230
+ addError(
10231
+ isRequired ? Localizer.Diagnostic.requiredArgCount() : Localizer.Diagnostic.notRequiredArgCount(),
10232
+ errorNode
10233
+ );
10234
+ return classType;
10235
+ }
10236
+
10237
+ const typeArgType = typeArgs[0].type;
10238
+
10239
+ // Make sure this is used only in a dataclass.
10240
+ const containingClassNode = ParseTreeUtils.getEnclosingClass(errorNode, /* stopAtFunction */ true);
10241
+ const classTypeInfo = containingClassNode ? getTypeOfClass(containingClassNode) : undefined;
10242
+
10243
+ let isUsageLegal = false;
10244
+
10245
+ if (classTypeInfo && isClass(classTypeInfo.classType) && ClassType.isTypedDictClass(classTypeInfo.classType)) {
10246
+ // The only legal usage is when used in a type annotation statement.
10247
+ if (
10248
+ errorNode.parent?.nodeType === ParseNodeType.TypeAnnotation &&
10249
+ errorNode.parent.typeAnnotation === errorNode
10250
+ ) {
10251
+ isUsageLegal = true;
10252
+ }
10253
+ }
10254
+
10255
+ if (!isUsageLegal) {
10256
+ addError(
10257
+ isRequired
10258
+ ? Localizer.Diagnostic.requiredNotInTypedDict()
10259
+ : Localizer.Diagnostic.notRequiredNotInTypedDict(),
10260
+ errorNode
10261
+ );
10262
+ return ClassType.cloneForSpecialization(classType, [convertToInstance(typeArgType)], !!typeArgs);
10263
+ }
10264
+
10265
+ return typeArgType;
10266
+ }
10267
+
10223
10268
function createUnpackType(errorNode: ParseNode, typeArgs: TypeResult[] | undefined): Type {
10224
10269
if (!typeArgs || typeArgs.length !== 1) {
10225
10270
addError(Localizer.Diagnostic.unpackArgCount(), errorNode);
@@ -10435,7 +10480,7 @@ export function createTypeEvaluator(
10435
10480
function transformTypeForPossibleEnumClass(node: NameNode, typeOfExpr: Type): Type {
10436
10481
// If the node is within a class that derives from the metaclass
10437
10482
// "EnumMeta", we need to treat assignments differently.
10438
- const enclosingClassNode = ParseTreeUtils.getEnclosingClass(node, true);
10483
+ const enclosingClassNode = ParseTreeUtils.getEnclosingClass(node, /* stopAtFunction */ true);
10439
10484
if (enclosingClassNode) {
10440
10485
const enumClassInfo = getTypeOfClass(enclosingClassNode);
10441
10486
@@ -10630,6 +10675,8 @@ export function createTypeEvaluator(
10630
10675
Concatenate: { alias: '', module: 'builtins' },
10631
10676
TypeGuard: { alias: '', module: 'builtins' },
10632
10677
Unpack: { alias: '', module: 'builtins' },
10678
+ Required: { alias: '', module: 'builtins' },
10679
+ NotRequired: { alias: '', module: 'builtins' },
10633
10680
};
10634
10681
10635
10682
const aliasMapEntry = specialTypes[assignedName];
@@ -11462,7 +11509,7 @@ export function createTypeEvaluator(
11462
11509
11463
11510
// There was no cached type, so create a new one.
11464
11511
// Retrieve the containing class node if the function is a method.
11465
- const containingClassNode = ParseTreeUtils.getEnclosingClass(node, true);
11512
+ const containingClassNode = ParseTreeUtils.getEnclosingClass(node, /* stopAtFunction */ true);
11466
11513
let containingClassType: ClassType | undefined;
11467
11514
if (containingClassNode) {
11468
11515
const classInfo = getTypeOfClass(containingClassNode);
@@ -15808,6 +15855,11 @@ export function createTypeEvaluator(
15808
15855
case 'Unpack': {
15809
15856
return createUnpackType(errorNode, typeArgs);
15810
15857
}
15858
+
15859
+ case 'Required':
15860
+ case 'NotRequired': {
15861
+ return createRequiredType(classType, errorNode, aliasedName === 'Required', typeArgs);
15862
+ }
15811
15863
}
15812
15864
}
15813
15865
@@ -19924,9 +19976,18 @@ export function createTypeEvaluator(
19924
19976
// Only variables (not functions, classes, etc.) are considered.
19925
19977
const lastDecl = getLastTypedDeclaredForSymbol(symbol);
19926
19978
if (lastDecl && lastDecl.type === DeclarationType.Variable) {
19979
+ const valueType = getDeclaredTypeOfSymbol(symbol) || UnknownType.create();
19980
+ let isRequired = !ClassType.isCanOmitDictValues(classType);
19981
+
19982
+ if (isRequiredTypedDictVariable(symbol)) {
19983
+ isRequired = true;
19984
+ } else if (isNotRequiredTypedDictVariable(symbol)) {
19985
+ isRequired = false;
19986
+ }
19987
+
19927
19988
keyMap.set(name, {
19928
- valueType: getDeclaredTypeOfSymbol(symbol) || UnknownType.create() ,
19929
- isRequired: !ClassType.isCanOmitDictValues(classType) ,
19989
+ valueType,
19990
+ isRequired,
19930
19991
isProvided: false,
19931
19992
});
19932
19993
}
0 commit comments