@@ -13,25 +13,32 @@ export interface ConcatAttributes extends AttributeWithCacheKey {
13
13
readonly axis : number ;
14
14
}
15
15
16
- const validateInputs = ( inputs : readonly TensorView [ ] ) : void => {
16
+ const validateInputs = ( inputs : readonly TensorView [ ] , axis : number ) : void => {
17
17
if ( ! inputs || inputs . length < 1 ) {
18
18
throw new Error ( 'too few inputs' ) ;
19
19
}
20
-
21
- const inputType = inputs [ 0 ] . dataType ;
22
- const inputDimensionality = inputs [ 0 ] . dims . length ;
23
-
24
- for ( const input of inputs ) {
20
+ const referenceIndex = 0 ;
21
+ const referenceInput = inputs [ referenceIndex ] ;
22
+ const inputType = referenceInput . dataType ;
23
+ const inputRank = referenceInput . dims . length ;
24
+ inputs . forEach ( ( input , i ) => {
25
+ if ( i === referenceIndex ) {
26
+ return ;
27
+ }
25
28
// make sure types of all inputs match
26
29
if ( input . dataType !== inputType ) {
27
30
throw new Error ( 'input tensors should be one type' ) ;
28
31
}
29
-
30
32
// make sure the dimensionality of all inputs are the same
31
- if ( input . dims . length !== inputDimensionality ) {
33
+ if ( input . dims . length !== inputRank ) {
32
34
throw new Error ( 'input tensors should have the same shape' ) ;
33
35
}
34
- }
36
+ input . dims . forEach ( ( dim , i ) => {
37
+ if ( i !== axis && dim !== referenceInput . dims [ i ] ) {
38
+ throw new Error ( 'non concat dimensions must match' ) ;
39
+ }
40
+ } ) ;
41
+ } ) ;
35
42
} ;
36
43
37
44
const calculateInputIndexImpl = ( numberOfTensors : number , sizeInConcatAxisStr : string ) : string => `
@@ -64,65 +71,43 @@ const assignOutputData = (inputs: readonly IndicesHelper[], output: IndicesHelpe
64
71
return codeLines . join ( '\n' ) ;
65
72
} ;
66
73
67
- const createConcatProgramInfo = ( inputs : readonly TensorView [ ] , axis : number ) : ProgramInfo => {
68
- const inputShape = inputs [ 0 ] . dims . slice ( ) ;
69
- if ( axis >= inputShape . length || axis < ( - 1 * inputShape . length ) ) {
70
- throw new Error ( 'axis specified for concat doesn\'t match input dimensionality' ) ;
71
- }
72
- const adjustedAxis = ( axis < 0 ) ? inputShape . length + axis : axis ;
73
- // ensure all of the non-concatenated axes match each other
74
- // calculate the shape of the output tensor while we do that
75
- const outputShape = inputShape . slice ( 0 ) ;
76
- for ( let i = 1 ; i < inputs . length ; i ++ ) {
77
- const dataNShape = inputs [ i ] . dims . slice ( ) ;
78
- for ( let axisIndex = 0 ; axisIndex < inputShape . length ; axisIndex ++ ) {
79
- // add to the placeholder for computing output shape
80
- if ( axisIndex === adjustedAxis ) {
81
- outputShape [ adjustedAxis ] += dataNShape [ axisIndex ] ;
74
+ const createConcatProgramInfo =
75
+ ( inputs : readonly TensorView [ ] , adjustedAxis : number , outputShape : number [ ] , dataType : DataType ) : ProgramInfo => {
76
+ const outputSize = ShapeUtil . size ( outputShape ) ;
77
+
78
+ const sizeInConcatAxis = new Array < number > ( inputs . length ) ;
79
+ const inputVars = new Array < IndicesHelper > ( inputs . length ) ;
80
+
81
+ let previousSum = 0 ;
82
+ const inputDependencies : ProgramInputTensorInfoDependency [ ] = [ ] ;
83
+ const inputRanks = [ ] ;
84
+ const programUniforms : ProgramUniform [ ] = [ { type : DataType . uint32 , data : outputSize } ] ;
85
+ for ( let i = 0 ; i < inputs . length ; ++ i ) {
86
+ previousSum += inputs [ i ] . dims [ adjustedAxis ] ;
87
+ sizeInConcatAxis [ i ] = previousSum ;
88
+ inputRanks . push ( inputs [ i ] . dims . length ) ;
89
+ inputVars [ i ] = inputVariable ( `input${ i } ` , dataType , inputRanks [ i ] ) ;
90
+ inputDependencies . push ( 'rank' ) ;
91
+ programUniforms . push ( { type : DataType . uint32 , data : sizeInConcatAxis [ i ] } ) ;
82
92
}
83
- // ensure all non-cancatenated axes match each other
84
- else if ( inputShape [ axisIndex ] !== dataNShape [ axisIndex ] ) {
85
- throw new Error ( 'non concat dimensions must match' ) ;
93
+ for ( let i = 0 ; i < inputs . length ; ++ i ) {
94
+ programUniforms . push ( ...createTensorShapeVariables ( inputs [ i ] . dims ) ) ;
86
95
}
87
- }
88
- }
89
-
90
- const outputSize = ShapeUtil . size ( outputShape ) ;
91
-
92
- const sizeInConcatAxis = new Array < number > ( inputs . length ) ;
93
- const inputVars = new Array < IndicesHelper > ( inputs . length ) ;
94
- const dataType = inputs [ 0 ] . dataType ;
95
-
96
- let previousSum = 0 ;
97
- const inputDependencies : ProgramInputTensorInfoDependency [ ] = [ ] ;
98
- const inputRanks = [ ] ;
99
- const programUniforms : ProgramUniform [ ] = [ { type : DataType . uint32 , data : outputSize } ] ;
100
- for ( let i = 0 ; i < inputs . length ; ++ i ) {
101
- previousSum += inputs [ i ] . dims [ adjustedAxis ] ;
102
- sizeInConcatAxis [ i ] = previousSum ;
103
- inputRanks . push ( inputs [ i ] . dims . length ) ;
104
- inputVars [ i ] = inputVariable ( `input${ i } ` , dataType , inputRanks [ i ] ) ;
105
- inputDependencies . push ( 'rank' ) ;
106
- programUniforms . push ( { type : DataType . uint32 , data : sizeInConcatAxis [ i ] } ) ;
107
- }
108
- for ( let i = 0 ; i < inputs . length ; ++ i ) {
109
- programUniforms . push ( ...createTensorShapeVariables ( inputs [ i ] . dims ) ) ;
110
- }
111
- programUniforms . push ( ...createTensorShapeVariables ( outputShape ) ) ;
96
+ programUniforms . push ( ...createTensorShapeVariables ( outputShape ) ) ;
112
97
113
- const output = outputVariable ( 'output' , dataType , outputShape . length ) ;
114
- const indicesAxis = output . indicesGet ( 'indices' , adjustedAxis ) ;
115
- const sizeInConcatAxisStr =
116
- Array . from ( Array ( sizeInConcatAxis . length ) . keys ( ) ) . map ( i => `uniforms.sizeInConcatAxis${ i } ` ) . join ( ',' ) ;
117
- const getShaderSource = ( shaderHelper : ShaderHelper ) => `
98
+ const output = outputVariable ( 'output' , dataType , outputShape . length ) ;
99
+ const indicesAxis = output . indicesGet ( 'indices' , adjustedAxis ) ;
100
+ const sizeInConcatAxisStr =
101
+ Array . from ( Array ( sizeInConcatAxis . length ) . keys ( ) ) . map ( i => `uniforms.sizeInConcatAxis${ i } ` ) . join ( ',' ) ;
102
+ const getShaderSource = ( shaderHelper : ShaderHelper ) => `
118
103
119
104
${ ( ( ) => {
120
- shaderHelper . registerUniform ( 'outputSize' , 'u32' ) ;
121
- for ( let i = 0 ; i < inputs . length ; i ++ ) {
122
- shaderHelper . registerUniform ( `sizeInConcatAxis${ i } ` , 'u32' ) ;
123
- }
124
- return shaderHelper . declareVariables ( ...inputVars , output ) ;
125
- } ) ( ) }
105
+ shaderHelper . registerUniform ( 'outputSize' , 'u32' ) ;
106
+ for ( let i = 0 ; i < inputs . length ; i ++ ) {
107
+ shaderHelper . registerUniform ( `sizeInConcatAxis${ i } ` , 'u32' ) ;
108
+ }
109
+ return shaderHelper . declareVariables ( ...inputVars , output ) ;
110
+ } ) ( ) }
126
111
127
112
${ calculateInputIndexImpl ( sizeInConcatAxis . length , sizeInConcatAxisStr ) }
128
113
@@ -140,23 +125,30 @@ const createConcatProgramInfo = (inputs: readonly TensorView[], axis: number): P
140
125
${ assignOutputData ( inputVars , output ) }
141
126
}` ;
142
127
143
- return {
144
- name : 'Concat' ,
145
- shaderCache : { hint : `${ axis } ` , inputDependencies} ,
146
- getRunData : ( ) => ( {
147
- outputs : [ { dims : outputShape , dataType : inputs [ 0 ] . dataType } ] ,
148
- dispatchGroup : { x : Math . ceil ( outputSize / 64 /* workgroup size */ ) } ,
149
- programUniforms,
150
- } ) ,
151
- getShaderSource,
152
- } ;
153
- } ;
128
+ return {
129
+ name : 'Concat' ,
130
+ shaderCache : { hint : `${ adjustedAxis } ` , inputDependencies} ,
131
+ getRunData : ( ) => ( {
132
+ outputs : [ { dims : outputShape , dataType} ] ,
133
+ dispatchGroup : { x : Math . ceil ( outputSize / 64 /* workgroup size */ ) } ,
134
+ programUniforms,
135
+ } ) ,
136
+ getShaderSource,
137
+ } ;
138
+ } ;
154
139
155
140
export const concat = ( context : ComputeContext , attributes : ConcatAttributes ) : void => {
156
- validateInputs ( context . inputs ) ;
141
+ const inputs = context . inputs ;
142
+ const inputShape = inputs [ 0 ] . dims ;
143
+ const adjustedAxis = ShapeUtil . normalizeAxis ( attributes . axis , inputShape . length ) ;
144
+ validateInputs ( inputs , adjustedAxis ) ;
145
+ const outputShape = inputShape . slice ( ) ;
146
+ outputShape [ adjustedAxis ] =
147
+ inputs . reduce ( ( sum , input ) => sum + ( input . dims . length > adjustedAxis ? input . dims [ adjustedAxis ] : 0 ) , 0 ) ;
157
148
// 0 length tensors are valid for concat, remove them
158
- const nonEmptyInputs = context . inputs . filter ( input => ShapeUtil . size ( input . dims ) > 0 ) ;
159
- context . compute ( createConcatProgramInfo ( nonEmptyInputs , attributes . axis ) , { inputs : nonEmptyInputs } ) ;
149
+ const nonEmptyInputs = inputs . filter ( input => ShapeUtil . size ( input . dims ) > 0 ) ;
150
+ context . compute (
151
+ createConcatProgramInfo ( nonEmptyInputs , adjustedAxis , outputShape , inputs [ 0 ] . dataType ) , { inputs : nonEmptyInputs } ) ;
160
152
} ;
161
153
162
154
export const parseConcatAttributes = ( attributes : Record < string , unknown > ) : ConcatAttributes =>
0 commit comments