Skip to content

Commit

Permalink
Allow infer types in type argument positions
Browse files Browse the repository at this point in the history
  • Loading branch information
weswigham committed Mar 7, 2018
1 parent a138985 commit f373993
Show file tree
Hide file tree
Showing 14 changed files with 298 additions and 44 deletions.
2 changes: 2 additions & 0 deletions src/compiler/binder.ts
Original file line number Diff line number Diff line change
Expand Up @@ -1434,6 +1434,8 @@ namespace ts {
return ContainerFlags.IsContainer | ContainerFlags.HasLocals;

case SyntaxKind.ConditionalType:
case SyntaxKind.CallExpression:
case SyntaxKind.NewExpression:
return ContainerFlags.IsInferenceContainer;

case SyntaxKind.SourceFile:
Expand Down
77 changes: 66 additions & 11 deletions src/compiler/checker.ts
Original file line number Diff line number Diff line change
Expand Up @@ -3021,6 +3021,11 @@ namespace ts {
if (type.flags & TypeFlags.Substitution) {
return typeToTypeNodeHelper((<SubstitutionType>type).typeParameter, context);
}
if (type.flags & TypeFlags.InferType) {
// Infer types only parse as identifiers, so the target should always be a TypeParameter that becomes a TypeReferenceNode
const ref = typeToTypeNodeHelper((<InferType>type).target, context) as TypeReferenceNode;
return createInferTypeNode(createTypeParameterDeclaration(ref.typeName as Identifier));
}

Debug.fail("Should be unreachable.");

Expand Down Expand Up @@ -3517,7 +3522,7 @@ namespace ts {
const params = getTypeParametersOfClassOrInterface(
parentSymbol.flags & SymbolFlags.Alias ? resolveAlias(parentSymbol) : parentSymbol
);
typeParameterNodes = mapToTypeNodes(map(params, (nextSymbol as TransientSymbol).mapper), context);
typeParameterNodes = mapToTypeNodes(mapIndexless(params, (nextSymbol as TransientSymbol).mapper), context);
}
else {
typeParameterNodes = typeParametersToTypeParameterDeclarations(symbol, context);
Expand Down Expand Up @@ -4736,12 +4741,14 @@ namespace ts {
case SyntaxKind.JSDocTemplateTag:
case SyntaxKind.MappedType:
case SyntaxKind.ConditionalType:
case SyntaxKind.CallExpression:
case SyntaxKind.NewExpression:
const outerTypeParameters = getOuterTypeParameters(node, includeThisTypes);
if (node.kind === SyntaxKind.MappedType) {
return append(outerTypeParameters, getDeclaredTypeOfTypeParameter(getSymbolOfNode((<MappedTypeNode>node).typeParameter)));
}
else if (node.kind === SyntaxKind.ConditionalType) {
return concatenate(outerTypeParameters, getInferTypeParameters(<ConditionalTypeNode>node));
else if (node.kind === SyntaxKind.ConditionalType || node.kind === SyntaxKind.NewExpression || node.kind === SyntaxKind.CallExpression) {
return concatenate(outerTypeParameters, getInferTypeParameters(<ConditionalTypeNode | CallLikeExpression>node));
}
const outerAndOwnTypeParameters = appendTypeParameters(outerTypeParameters, getEffectiveTypeParameterDeclarations(<DeclarationWithTypeParameters>node) || emptyArray);
const thisType = includeThisTypes &&
Expand Down Expand Up @@ -8334,7 +8341,7 @@ namespace ts {
return type.resolvedFalseType || (type.resolvedFalseType = instantiateType(type.root.falseType, type.mapper));
}

function getInferTypeParameters(node: ConditionalTypeNode): TypeParameter[] {
function getInferTypeParameters(node: ConditionalTypeNode | CallLikeExpression): TypeParameter[] {
let result: TypeParameter[];
if (node.locals) {
node.locals.forEach(symbol => {
Expand Down Expand Up @@ -8375,10 +8382,16 @@ namespace ts {
return links.resolvedType;
}

function createInferType(target: TypeParameter): InferType {
const type = createType(TypeFlags.InferType) as InferType;
type.target = target;
return type;
}

function getTypeFromInferTypeNode(node: InferTypeNode): Type {
const links = getNodeLinks(node);
if (!links.resolvedType) {
links.resolvedType = getDeclaredTypeOfTypeParameter(getSymbolOfNode(node.typeParameter));
links.resolvedType = createInferType(getDeclaredTypeOfTypeParameter(getSymbolOfNode(node.typeParameter)));
}
return links.resolvedType;
}
Expand Down Expand Up @@ -8882,7 +8895,7 @@ namespace ts {
// mapper to the type parameters to produce the effective list of type arguments, and compute the
// instantiation cache key from the type IDs of the type arguments.
const combinedMapper = type.objectFlags & ObjectFlags.Instantiated ? combineTypeMappers(type.mapper, mapper) : mapper;
const typeArguments = map(typeParameters, combinedMapper);
const typeArguments = mapIndexless(typeParameters, combinedMapper);
const id = getTypeListId(typeArguments);
let result = links.instantiations.get(id);
if (!result) {
Expand Down Expand Up @@ -8965,7 +8978,7 @@ namespace ts {
// We are instantiating a conditional type that has one or more type parameters in scope. Apply the
// mapper to the type parameters to produce the effective list of type arguments, and compute the
// instantiation cache key from the type IDs of the type arguments.
const typeArguments = map(root.outerTypeParameters, mapper);
const typeArguments = mapIndexless(root.outerTypeParameters, mapper);
const id = getTypeListId(typeArguments);
let result = root.instantiations.get(id);
if (!result) {
Expand Down Expand Up @@ -9036,6 +9049,9 @@ namespace ts {
if (type.flags & TypeFlags.Substitution) {
return mapper((<SubstitutionType>type).typeParameter);
}
if (type.flags & TypeFlags.InferType) {
return mapper((<InferType>type).target, /*isInferDeclaration*/ true);
}
}
return type;
}
Expand Down Expand Up @@ -9642,9 +9658,15 @@ namespace ts {
if (source.flags & TypeFlags.Substitution) {
source = relation === definitelyAssignableRelation ? (<SubstitutionType>source).typeParameter : (<SubstitutionType>source).substitute;
}
if (source.flags & TypeFlags.InferType) {
source = (<InferType>source).target;
}
if (target.flags & TypeFlags.Substitution) {
target = (<SubstitutionType>target).typeParameter;
}
if (target.flags & TypeFlags.InferType) {
target = (<InferType>target).target;
}

// both types are the same - covers 'they are the same primitive type or both are Any' or the same type parameter cases
if (source === target) return Ternary.True;
Expand Down Expand Up @@ -11587,6 +11609,12 @@ namespace ts {
if (!couldContainTypeVariables(target)) {
return;
}
if (source.flags & TypeFlags.InferType) {
source = (source as InferType).target;
}
if (target.flags & TypeFlags.InferType) {
target = (target as InferType).target;
}
if (source.flags & TypeFlags.Any) {
// We are inferring from an 'any' type. We want to infer this type for every type parameter
// referenced in the target type, so we record it as the propagation type and infer from the
Expand Down Expand Up @@ -17529,10 +17557,29 @@ namespace ts {
candidate = originalCandidate;
if (candidate.typeParameters) {
let typeArgumentTypes: Type[];
const isJavascript = isInJavaScriptFile(candidate.declaration);
if (typeArguments) {
const typeArgumentResult = checkTypeArguments(candidate, typeArguments, /*reportErrors*/ false);
if (typeArgumentResult) {
typeArgumentTypes = typeArgumentResult;
if (node.locals) {
// Call has `infer` arguments that still need to be inferred and instantiated
const inferParams = getInferTypeParameters(node);
// Mapper replaces references to infered type parameters with emptyObjectType
// Causing the original location to be the _only_ inference site
const preprocessMapper = (p: TypeParameter, isInferDecl: boolean) => {
if (!isInferDecl && contains(inferParams, p)) return emptyObjectType;
return p;
};
const resultsWithNonInferInferredVarsDefaulted = map(typeArgumentResult, t => instantiateType(t, preprocessMapper));
const partialCandidate = getSignatureInstantiation(candidate, resultsWithNonInferInferredVarsDefaulted, isJavascript);
const context = createInferenceContext(inferParams, partialCandidate, InferenceFlags.None);
const inferences = inferTypeArguments(node, partialCandidate, args, excludeArgument, context);
const mapper = createTypeMapper(inferParams, inferences);
typeArgumentTypes = map(typeArgumentResult, t => instantiateType(t, mapper));
}
else {
typeArgumentTypes = typeArgumentResult;
}
}
else {
candidateForTypeArgumentError = originalCandidate;
Expand All @@ -17542,7 +17589,6 @@ namespace ts {
else {
typeArgumentTypes = inferTypeArguments(node, candidate, args, excludeArgument, inferenceContext);
}
const isJavascript = isInJavaScriptFile(candidate.declaration);
candidate = getSignatureInstantiation(candidate, typeArgumentTypes, isJavascript);
}
if (!checkApplicableSignature(node, args, candidate, relation, excludeArgument, /*reportErrors*/ false)) {
Expand Down Expand Up @@ -20544,9 +20590,18 @@ namespace ts {
forEachChild(node, checkSourceElement);
}

function isConditionalTypeExtendsClause(n: Node) {
return n.parent && n.parent.kind === SyntaxKind.ConditionalType && (<ConditionalTypeNode>n.parent).extendsType === n;
}

function isCallOrNewExpressionTypeArgument(n: Node) {
return n.parent && (n.parent.kind === SyntaxKind.CallExpression || n.parent.kind === SyntaxKind.NewExpression)
&& contains((<CallExpression | NewExpression>n.parent).typeArguments, n);
}

function checkInferType(node: InferTypeNode) {
if (!findAncestor(node, n => n.parent && n.parent.kind === SyntaxKind.ConditionalType && (<ConditionalTypeNode>n.parent).extendsType === n)) {
grammarErrorOnNode(node, Diagnostics.infer_declarations_are_only_permitted_in_the_extends_clause_of_a_conditional_type);
if (!findAncestor(node, n => isConditionalTypeExtendsClause(n) || isCallOrNewExpressionTypeArgument(n))) {
grammarErrorOnNode(node, Diagnostics.infer_declarations_are_only_permitted_in_the_extends_clause_of_a_conditional_type_or_in_call_or_new_expression_type_argument_lists);
}
checkSourceElement(node.typeParameter);
}
Expand Down
11 changes: 11 additions & 0 deletions src/compiler/core.ts
Original file line number Diff line number Diff line change
Expand Up @@ -414,6 +414,17 @@ namespace ts {
array.length = 0;
}

export function mapIndexless<T, U>(array: ReadonlyArray<T>, f: (x: T) => U): U[] {
let result: U[];
if (array) {
result = [];
for (const elem of array) {
result.push(f(elem));
}
}
return result;
}

export function map<T, U>(array: ReadonlyArray<T>, f: (x: T, i: number) => U): U[] {
let result: U[];
if (array) {
Expand Down
2 changes: 1 addition & 1 deletion src/compiler/diagnosticMessages.json
Original file line number Diff line number Diff line change
Expand Up @@ -947,7 +947,7 @@
"category": "Error",
"code": 1337
},
"'infer' declarations are only permitted in the 'extends' clause of a conditional type.": {
"'infer' declarations are only permitted in the 'extends' clause of a conditional type or in call or new expression type argument lists.": {
"category": "Error",
"code": 1338
},
Expand Down
10 changes: 8 additions & 2 deletions src/compiler/types.ts
Original file line number Diff line number Diff line change
Expand Up @@ -3538,6 +3538,7 @@ namespace ts {
/* @internal */
ContainsAnyFunctionType = 1 << 26, // Type is or contains the anyFunctionType
NonPrimitive = 1 << 27, // intrinsic object type
InferType = 1 << 28, // A type whose concrete value upon instantiation will be inferred at a given site
/* @internal */
GenericMappedType = 1 << 29, // Flag used by maybeTypeOfKind

Expand All @@ -3562,7 +3563,7 @@ namespace ts {
ESSymbolLike = ESSymbol | UniqueESSymbol,
UnionOrIntersection = Union | Intersection,
StructuredType = Object | Union | Intersection,
TypeVariable = TypeParameter | IndexedAccess,
TypeVariable = TypeParameter | IndexedAccess | InferType,
InstantiableNonPrimitive = TypeVariable | Conditional | Substitution,
InstantiablePrimitive = Index,
Instantiable = InstantiableNonPrimitive | InstantiablePrimitive,
Expand Down Expand Up @@ -3817,6 +3818,11 @@ namespace ts {
resolvedDefaultType?: Type;
}

// Infer Types (TypeFlags.InferType)
export interface InferType extends Type {
target: TypeParameter;
}

// Indexed access types (TypeFlags.IndexedAccess)
// Possible forms are T[xxx], xxx[T], or xxx[keyof T], where T is a type variable
export interface IndexedAccessType extends InstantiableType {
Expand Down Expand Up @@ -3919,7 +3925,7 @@ namespace ts {
}

/* @internal */
export type TypeMapper = (t: TypeParameter) => Type;
export type TypeMapper = (t: TypeParameter, isInferDeclaration?: boolean) => Type;

export const enum InferencePriority {
NakedTypeVariable = 1 << 0, // Naked type variable in union or intersection type
Expand Down
14 changes: 9 additions & 5 deletions tests/baselines/reference/api/tsserverlibrary.d.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2074,6 +2074,7 @@ declare namespace ts {
Conditional = 2097152,
Substitution = 4194304,
NonPrimitive = 134217728,
InferType = 268435456,
Literal = 224,
Unit = 13536,
StringOrNumberLiteral = 96,
Expand All @@ -2085,12 +2086,12 @@ declare namespace ts {
ESSymbolLike = 1536,
UnionOrIntersection = 393216,
StructuredType = 458752,
TypeVariable = 1081344,
InstantiableNonPrimitive = 7372800,
TypeVariable = 269516800,
InstantiableNonPrimitive = 275808256,
InstantiablePrimitive = 524288,
Instantiable = 7897088,
StructuredOrInstantiable = 8355840,
Narrowable = 142575359,
Instantiable = 276332544,
StructuredOrInstantiable = 276791296,
Narrowable = 411010815,
NotUnionOrUnit = 134283777,
}
type DestructuringPattern = BindingPattern | ObjectLiteralExpression | ArrayLiteralExpression;
Expand Down Expand Up @@ -2184,6 +2185,9 @@ declare namespace ts {
}
interface TypeParameter extends InstantiableType {
}
interface InferType extends Type {
target: TypeParameter;
}
interface IndexedAccessType extends InstantiableType {
objectType: Type;
indexType: Type;
Expand Down
14 changes: 9 additions & 5 deletions tests/baselines/reference/api/typescript.d.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2074,6 +2074,7 @@ declare namespace ts {
Conditional = 2097152,
Substitution = 4194304,
NonPrimitive = 134217728,
InferType = 268435456,
Literal = 224,
Unit = 13536,
StringOrNumberLiteral = 96,
Expand All @@ -2085,12 +2086,12 @@ declare namespace ts {
ESSymbolLike = 1536,
UnionOrIntersection = 393216,
StructuredType = 458752,
TypeVariable = 1081344,
InstantiableNonPrimitive = 7372800,
TypeVariable = 269516800,
InstantiableNonPrimitive = 275808256,
InstantiablePrimitive = 524288,
Instantiable = 7897088,
StructuredOrInstantiable = 8355840,
Narrowable = 142575359,
Instantiable = 276332544,
StructuredOrInstantiable = 276791296,
Narrowable = 411010815,
NotUnionOrUnit = 134283777,
}
type DestructuringPattern = BindingPattern | ObjectLiteralExpression | ArrayLiteralExpression;
Expand Down Expand Up @@ -2184,6 +2185,9 @@ declare namespace ts {
}
interface TypeParameter extends InstantiableType {
}
interface InferType extends Type {
target: TypeParameter;
}
interface IndexedAccessType extends InstantiableType {
objectType: Type;
indexType: Type;
Expand Down
24 changes: 24 additions & 0 deletions tests/baselines/reference/inferTypeArgumentKeyword.errors.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
tests/cases/conformance/types/typeParameters/typeArgumentLists/inferTypeArgumentKeyword.ts(7,51): error TS2345: Argument of type '{ z: number; }' is not assignable to parameter of type '{ z: { y: number; }; }'.
Types of property 'z' are incompatible.
Type 'number' is not assignable to type '{ y: number; }'.
tests/cases/conformance/types/typeParameters/typeArgumentLists/inferTypeArgumentKeyword.ts(10,30): error TS2345: Argument of type '{ y: number; }' is not assignable to parameter of type 'number'.


==== tests/cases/conformance/types/typeParameters/typeArgumentLists/inferTypeArgumentKeyword.ts (2 errors) ====
declare function foo<A, B, C>(x: A, y: B, z: { z: C }): A & B & C;

// good
foo<infer A, {x: string}, A>({y: 12}, {x: "yes"}, {z: {y: 12}});

// error on 3rd arg
foo<infer A, {x: string}, A>({y: 12}, {x: "yes"}, {z: 12});
~~~~~~~
!!! error TS2345: Argument of type '{ z: number; }' is not assignable to parameter of type '{ z: { y: number; }; }'.
!!! error TS2345: Types of property 'z' are incompatible.
!!! error TS2345: Type 'number' is not assignable to type '{ y: number; }'.

// error on first arg
foo<A, {x: string}, infer A>({y: 12}, {x: "yes"}, {z: 12});
~~~~~~~
!!! error TS2345: Argument of type '{ y: number; }' is not assignable to parameter of type 'number'.

20 changes: 20 additions & 0 deletions tests/baselines/reference/inferTypeArgumentKeyword.js
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
//// [inferTypeArgumentKeyword.ts]
declare function foo<A, B, C>(x: A, y: B, z: { z: C }): A & B & C;

// good
foo<infer A, {x: string}, A>({y: 12}, {x: "yes"}, {z: {y: 12}});

// error on 3rd arg
foo<infer A, {x: string}, A>({y: 12}, {x: "yes"}, {z: 12});

// error on first arg
foo<A, {x: string}, infer A>({y: 12}, {x: "yes"}, {z: 12});


//// [inferTypeArgumentKeyword.js]
// good
foo({ y: 12 }, { x: "yes" }, { z: { y: 12 } });
// error on 3rd arg
foo({ y: 12 }, { x: "yes" }, { z: 12 });
// error on first arg
foo({ y: 12 }, { x: "yes" }, { z: 12 });
Loading

0 comments on commit f373993

Please sign in to comment.