Skip to content

Commit 98a0ce0

Browse files
feat(ui): custom field types connection validation
In the initial commit, a custom field's original type was added to the *field templates* only as `originalType`. Custom fields' `type` property was `"Custom"`*. This allowed for type safety throughout the UI logic. *Actually, it was `"Unknown"`, but I changed it to custom for clarity. Connection validation logic, however, uses the *field instance* of the node/field. Like the templates, *field instances* with custom types have their `type` set to `"Custom"`, but they didn't have an `originalType` property. As a result, all custom fields could be connected to all other custom fields. To resolve this, we need to add `originalType` to the *field instances*, then switch the validation logic to use this instead of `type`. This ended up needing a bit of fanagling: - If we make `originalType` a required property on field instances, existing workflows will break during connection validation, because they won't have this property. We'd need a new layer of logic to migrate the workflows, adding the new `originalType` property. While this layer is probably needed anyways, typing `originalType` as optional is much simpler. Workflow migration logic can come layer. (Technically, we could remove all references to field types from the workflow files, and let the templates hold all this information. This feels like a significant change and I'm reluctant to do it now.) - Because `originalType` is optional, anywhere we care about the type of a field, we need to use it over `type`. So there are a number of `field.originalType ?? field.type` expressions. This is a bit of a gotcha, we'll need to remember this in the future. - We use `Array.prototype.includes()` often in the workflow editor, e.g. `COLLECTION_TYPES.includes(type)`. In these cases, the const array is of type `FieldType[]`, and `type` is is `FieldType`. Because we now support custom types, the arg `type` is now widened from `FieldType` to `string`. This causes a TS error. This behaviour is somewhat controversial (see microsoft/TypeScript#14520). These expressions are now rewritten as `COLLECTION_TYPES.some((t) => t === type)` to satisfy TS. It's logically equivalent.
1 parent 7b93b5e commit 98a0ce0

17 files changed

+98
-76
lines changed

invokeai/frontend/web/src/features/nodes/components/flow/AddNodePopover/AddNodePopover.tsx

+7-3
Original file line numberDiff line numberDiff line change
@@ -74,9 +74,13 @@ const AddNodePopover = () => {
7474

7575
return some(handles, (handle) => {
7676
const sourceType =
77-
handleFilter == 'source' ? fieldFilter : handle.type;
77+
handleFilter == 'source'
78+
? fieldFilter
79+
: handle.originalType ?? handle.type;
7880
const targetType =
79-
handleFilter == 'target' ? fieldFilter : handle.type;
81+
handleFilter == 'target'
82+
? fieldFilter
83+
: handle.originalType ?? handle.type;
8084

8185
return validateSourceAndTargetTypes(sourceType, targetType);
8286
});
@@ -111,7 +115,7 @@ const AddNodePopover = () => {
111115

112116
data.sort((a, b) => a.label.localeCompare(b.label));
113117

114-
return { data, t };
118+
return { data };
115119
},
116120
defaultSelectorOptions
117121
);

invokeai/frontend/web/src/features/nodes/components/flow/connectionLines/CustomConnectionLine.tsx

+4-5
Original file line numberDiff line numberDiff line change
@@ -2,18 +2,17 @@ import { createSelector } from '@reduxjs/toolkit';
22
import { stateSelector } from 'app/store/store';
33
import { useAppSelector } from 'app/store/storeHooks';
44
import { colorTokenToCssVar } from 'common/util/colorTokenToCssVar';
5-
import { FIELDS } from 'features/nodes/types/constants';
65
import { memo } from 'react';
76
import { ConnectionLineComponentProps, getBezierPath } from 'reactflow';
7+
import { getFieldColor } from '../edges/util/getEdgeColor';
88

99
const selector = createSelector(stateSelector, ({ nodes }) => {
1010
const { shouldAnimateEdges, currentConnectionFieldType, shouldColorEdges } =
1111
nodes;
1212

13-
const stroke =
14-
currentConnectionFieldType && shouldColorEdges
15-
? colorTokenToCssVar(FIELDS[currentConnectionFieldType].color)
16-
: colorTokenToCssVar('base.500');
13+
const stroke = shouldColorEdges
14+
? getFieldColor(currentConnectionFieldType)
15+
: colorTokenToCssVar('base.500');
1716

1817
let className = 'react-flow__custom_connection-path';
1918

Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
import { colorTokenToCssVar } from 'common/util/colorTokenToCssVar';
2+
import { FIELDS } from 'features/nodes/types/constants';
3+
import { FieldType } from 'features/nodes/types/types';
4+
5+
export const getFieldColor = (fieldType: FieldType | string | null): string => {
6+
if (!fieldType) {
7+
return colorTokenToCssVar('base.500');
8+
}
9+
const color = FIELDS[fieldType]?.color;
10+
11+
return color ? colorTokenToCssVar(color) : colorTokenToCssVar('base.500');
12+
};

invokeai/frontend/web/src/features/nodes/components/flow/edges/util/makeEdgeSelector.ts

+2-2
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,8 @@ import { createSelector } from '@reduxjs/toolkit';
22
import { stateSelector } from 'app/store/store';
33
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
44
import { colorTokenToCssVar } from 'common/util/colorTokenToCssVar';
5-
import { FIELDS } from 'features/nodes/types/constants';
65
import { isInvocationNode } from 'features/nodes/types/types';
6+
import { getFieldColor } from './getEdgeColor';
77

88
export const makeEdgeSelector = (
99
source: string,
@@ -29,7 +29,7 @@ export const makeEdgeSelector = (
2929

3030
const stroke =
3131
sourceType && nodes.shouldColorEdges
32-
? colorTokenToCssVar(FIELDS[sourceType].color)
32+
? getFieldColor(sourceType)
3333
: colorTokenToCssVar('base.500');
3434

3535
return {

invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/FieldHandle.tsx

+10-20
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,6 @@
11
import { Tooltip } from '@chakra-ui/react';
2-
import { colorTokenToCssVar } from 'common/util/colorTokenToCssVar';
32
import {
43
COLLECTION_TYPES,
5-
FIELDS,
64
HANDLE_TOOLTIP_OPEN_DELAY,
75
MODEL_TYPES,
86
POLYMORPHIC_TYPES,
@@ -13,6 +11,7 @@ import {
1311
} from 'features/nodes/types/types';
1412
import { CSSProperties, memo, useMemo } from 'react';
1513
import { Handle, HandleType, Position } from 'reactflow';
14+
import { getFieldColor } from '../../../edges/util/getEdgeColor';
1615

1716
export const handleBaseStyles: CSSProperties = {
1817
position: 'absolute',
@@ -47,14 +46,14 @@ const FieldHandle = (props: FieldHandleProps) => {
4746
isConnectionStartField,
4847
connectionError,
4948
} = props;
50-
const { name, type, originalType } = fieldTemplate;
51-
const { color: typeColor } = FIELDS[type];
49+
const { name } = fieldTemplate;
50+
const type = fieldTemplate.originalType ?? fieldTemplate.type;
5251

5352
const styles: CSSProperties = useMemo(() => {
54-
const isCollectionType = COLLECTION_TYPES.includes(type);
55-
const isPolymorphicType = POLYMORPHIC_TYPES.includes(type);
56-
const isModelType = MODEL_TYPES.includes(type);
57-
const color = colorTokenToCssVar(typeColor);
53+
const isCollectionType = COLLECTION_TYPES.some((t) => t === type);
54+
const isPolymorphicType = POLYMORPHIC_TYPES.some((t) => t === type);
55+
const isModelType = MODEL_TYPES.some((t) => t === type);
56+
const color = getFieldColor(type);
5857
const s: CSSProperties = {
5958
backgroundColor:
6059
isCollectionType || isPolymorphicType
@@ -97,23 +96,14 @@ const FieldHandle = (props: FieldHandleProps) => {
9796
isConnectionInProgress,
9897
isConnectionStartField,
9998
type,
100-
typeColor,
10199
]);
102100

103101
const tooltip = useMemo(() => {
104-
if (isConnectionInProgress && isConnectionStartField) {
105-
return originalType;
106-
}
107102
if (isConnectionInProgress && connectionError) {
108-
return connectionError ?? originalType;
103+
return connectionError;
109104
}
110-
return originalType;
111-
}, [
112-
connectionError,
113-
isConnectionInProgress,
114-
isConnectionStartField,
115-
originalType,
116-
]);
105+
return type;
106+
}, [connectionError, isConnectionInProgress, type]);
117107

118108
return (
119109
<Tooltip

invokeai/frontend/web/src/features/nodes/hooks/useFieldType.ts.ts

+2-1
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,8 @@ export const useFieldType = (
2020
if (!isInvocationNode(node)) {
2121
return;
2222
}
23-
return node?.data[KIND_MAP[kind]][fieldName]?.type;
23+
const field = node.data[KIND_MAP[kind]][fieldName];
24+
return field?.originalType ?? field?.type;
2425
},
2526
defaultSelectorOptions
2627
),

invokeai/frontend/web/src/features/nodes/store/nodesSlice.ts

+2-1
Original file line numberDiff line numberDiff line change
@@ -258,7 +258,8 @@ const nodesSlice = createSlice({
258258
handleType === 'source'
259259
? node.data.outputs[handleId]
260260
: node.data.inputs[handleId];
261-
state.currentConnectionFieldType = field?.type ?? null;
261+
state.currentConnectionFieldType =
262+
field?.originalType ?? field?.type ?? null;
262263
},
263264
connectionMade: (state, action: PayloadAction<Connection>) => {
264265
const fieldType = state.currentConnectionFieldType;

invokeai/frontend/web/src/features/nodes/store/types.ts

+1-1
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ export type NodesState = {
2121
edges: Edge<InvocationEdgeExtra>[];
2222
nodeTemplates: Record<string, InvocationTemplate>;
2323
connectionStartParams: OnConnectStartParams | null;
24-
currentConnectionFieldType: FieldType | null;
24+
currentConnectionFieldType: FieldType | string | null;
2525
connectionMade: boolean;
2626
modifyingEdge: boolean;
2727
shouldShowFieldTypeLegend: boolean;

invokeai/frontend/web/src/features/nodes/store/util/buildNodeData.ts

+1
Original file line numberDiff line numberDiff line change
@@ -94,6 +94,7 @@ export const buildNodeData = (
9494
name: outputName,
9595
type: outputTemplate.type,
9696
fieldKind: 'output',
97+
originalType: outputTemplate.originalType,
9798
};
9899

99100
outputsAccumulator[outputName] = outputFieldValue;

invokeai/frontend/web/src/features/nodes/store/util/findConnectionToValidHandle.ts

+8-3
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ import { getIsGraphAcyclic } from './getIsGraphAcyclic';
1212
const isValidConnection = (
1313
edges: Edge[],
1414
handleCurrentType: HandleType,
15-
handleCurrentFieldType: FieldType,
15+
handleCurrentFieldType: FieldType | string,
1616
node: Node,
1717
handle: InputFieldValue | OutputFieldValue
1818
) => {
@@ -35,7 +35,12 @@ const isValidConnection = (
3535
}
3636
}
3737

38-
if (!validateSourceAndTargetTypes(handleCurrentFieldType, handle.type)) {
38+
if (
39+
!validateSourceAndTargetTypes(
40+
handleCurrentFieldType,
41+
handle.originalType ?? handle.type
42+
)
43+
) {
3944
isValidConnection = false;
4045
}
4146

@@ -49,7 +54,7 @@ export const findConnectionToValidHandle = (
4954
handleCurrentNodeId: string,
5055
handleCurrentName: string,
5156
handleCurrentType: HandleType,
52-
handleCurrentFieldType: FieldType
57+
handleCurrentFieldType: FieldType | string
5358
): Connection | null => {
5459
if (node.id === handleCurrentNodeId) {
5560
return null;

invokeai/frontend/web/src/features/nodes/store/util/makeIsConnectionValidSelector.ts

+2-2
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,9 @@
11
import { createSelector } from '@reduxjs/toolkit';
22
import { stateSelector } from 'app/store/store';
3-
import { getIsGraphAcyclic } from './getIsGraphAcyclic';
43
import { FieldType } from 'features/nodes/types/types';
54
import i18n from 'i18next';
65
import { HandleType } from 'reactflow';
6+
import { getIsGraphAcyclic } from './getIsGraphAcyclic';
77
import { validateSourceAndTargetTypes } from './validateSourceAndTargetTypes';
88

99
/**
@@ -15,7 +15,7 @@ export const makeConnectionErrorSelector = (
1515
nodeId: string,
1616
fieldName: string,
1717
handleType: HandleType,
18-
fieldType?: FieldType
18+
fieldType?: FieldType | string
1919
) => {
2020
return createSelector(stateSelector, (state) => {
2121
if (!fieldType) {

invokeai/frontend/web/src/features/nodes/store/util/validateSourceAndTargetTypes.ts

+12-10
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,8 @@ import {
77
import { FieldType } from 'features/nodes/types/types';
88

99
export const validateSourceAndTargetTypes = (
10-
sourceType: FieldType,
11-
targetType: FieldType
10+
sourceType: FieldType | string,
11+
targetType: FieldType | string
1212
) => {
1313
// TODO: There's a bug with Collect -> Iterate nodes:
1414
// https://github.com/invoke-ai/InvokeAI/issues/3956
@@ -31,17 +31,18 @@ export const validateSourceAndTargetTypes = (
3131
*/
3232

3333
const isCollectionItemToNonCollection =
34-
sourceType === 'CollectionItem' && !COLLECTION_TYPES.includes(targetType);
34+
sourceType === 'CollectionItem' &&
35+
!COLLECTION_TYPES.some((t) => t === targetType);
3536

3637
const isNonCollectionToCollectionItem =
3738
targetType === 'CollectionItem' &&
38-
!COLLECTION_TYPES.includes(sourceType) &&
39-
!POLYMORPHIC_TYPES.includes(sourceType);
39+
!COLLECTION_TYPES.some((t) => t === sourceType) &&
40+
!POLYMORPHIC_TYPES.some((t) => t === sourceType);
4041

4142
const isAnythingToPolymorphicOfSameBaseType =
42-
POLYMORPHIC_TYPES.includes(targetType) &&
43+
POLYMORPHIC_TYPES.some((t) => t === targetType) &&
4344
(() => {
44-
if (!POLYMORPHIC_TYPES.includes(targetType)) {
45+
if (!POLYMORPHIC_TYPES.some((t) => t === targetType)) {
4546
return false;
4647
}
4748
const baseType =
@@ -57,11 +58,12 @@ export const validateSourceAndTargetTypes = (
5758

5859
const isGenericCollectionToAnyCollectionOrPolymorphic =
5960
sourceType === 'Collection' &&
60-
(COLLECTION_TYPES.includes(targetType) ||
61-
POLYMORPHIC_TYPES.includes(targetType));
61+
(COLLECTION_TYPES.some((t) => t === targetType) ||
62+
POLYMORPHIC_TYPES.some((t) => t === targetType));
6263

6364
const isCollectionToGenericCollection =
64-
targetType === 'Collection' && COLLECTION_TYPES.includes(sourceType);
65+
targetType === 'Collection' &&
66+
COLLECTION_TYPES.some((t) => t === sourceType);
6567

6668
const isIntToFloat = sourceType === 'integer' && targetType === 'float';
6769

invokeai/frontend/web/src/features/nodes/types/constants.ts

+4-4
Original file line numberDiff line numberDiff line change
@@ -150,16 +150,16 @@ export const isPolymorphicItemType = (
150150
): itemType is keyof typeof SINGLE_TO_POLYMORPHIC_MAP =>
151151
Boolean(itemType && itemType in SINGLE_TO_POLYMORPHIC_MAP);
152152

153-
export const FIELDS: Record<FieldType, FieldUIConfig> = {
153+
export const FIELDS: Record<FieldType | string, FieldUIConfig> = {
154154
Any: {
155155
color: 'gray.500',
156156
description: 'Any field type is accepted.',
157157
title: 'Any',
158158
},
159-
Unknown: {
159+
Custom: {
160160
color: 'gray.500',
161-
description: 'Unknown field type is accepted.',
162-
title: 'Unknown',
161+
description: 'A custom field, provided by an external node.',
162+
title: 'Custom',
163163
},
164164
MetadataField: {
165165
color: 'gray.500',

invokeai/frontend/web/src/features/nodes/types/types.ts

+10-9
Original file line numberDiff line numberDiff line change
@@ -133,7 +133,7 @@ export const zFieldType = z.enum([
133133
'UNetField',
134134
'VaeField',
135135
'VaeModelField',
136-
'Unknown',
136+
'Custom',
137137
]);
138138

139139
export type FieldType = z.infer<typeof zFieldType>;
@@ -164,6 +164,7 @@ export const zFieldValueBase = z.object({
164164
id: z.string().trim().min(1),
165165
name: z.string().trim().min(1),
166166
type: zFieldType,
167+
originalType: z.string().optional(),
167168
});
168169
export type FieldValueBase = z.infer<typeof zFieldValueBase>;
169170

@@ -191,7 +192,7 @@ export type OutputFieldTemplate = {
191192
type: FieldType;
192193
title: string;
193194
description: string;
194-
originalType: string; // used for custom types
195+
originalType?: string; // used for custom types
195196
} & _OutputField;
196197

197198
export const zInputFieldValueBase = zFieldValueBase.extend({
@@ -791,8 +792,8 @@ export const zAnyInputFieldValue = zInputFieldValueBase.extend({
791792
value: z.any().optional(),
792793
});
793794

794-
export const zUnknownInputFieldValue = zInputFieldValueBase.extend({
795-
type: z.literal('Unknown'),
795+
export const zCustomInputFieldValue = zInputFieldValueBase.extend({
796+
type: z.literal('Custom'),
796797
value: z.any().optional(),
797798
});
798799

@@ -853,7 +854,7 @@ export const zInputFieldValue = z.discriminatedUnion('type', [
853854
zMetadataItemPolymorphicInputFieldValue,
854855
zMetadataInputFieldValue,
855856
zMetadataCollectionInputFieldValue,
856-
zUnknownInputFieldValue,
857+
zCustomInputFieldValue,
857858
]);
858859

859860
export type InputFieldValue = z.infer<typeof zInputFieldValue>;
@@ -864,16 +865,16 @@ export type InputFieldTemplateBase = {
864865
description: string;
865866
required: boolean;
866867
fieldKind: 'input';
867-
originalType: string; // used for custom types
868+
originalType?: string; // used for custom types
868869
} & _InputField;
869870

870871
export type AnyInputFieldTemplate = InputFieldTemplateBase & {
871872
type: 'Any';
872873
default: undefined;
873874
};
874875

875-
export type UnknownInputFieldTemplate = InputFieldTemplateBase & {
876-
type: 'Unknown';
876+
export type CustomInputFieldTemplate = InputFieldTemplateBase & {
877+
type: 'Custom';
877878
default: undefined;
878879
};
879880

@@ -1274,7 +1275,7 @@ export type InputFieldTemplate =
12741275
| MetadataInputFieldTemplate
12751276
| MetadataItemPolymorphicInputFieldTemplate
12761277
| MetadataCollectionInputFieldTemplate
1277-
| UnknownInputFieldTemplate;
1278+
| CustomInputFieldTemplate;
12781279

12791280
export const isInputFieldValue = (
12801281
field?: InputFieldValue | OutputFieldValue

0 commit comments

Comments
 (0)