diff --git a/src/pipelines.js b/src/pipelines.js index 231ea704e..af4594b53 100644 --- a/src/pipelines.js +++ b/src/pipelines.js @@ -933,6 +933,8 @@ export class FeatureExtractionPipeline extends Pipeline { // Skip pooling } else if (pooling === 'mean') { result = mean_pooling(result, inputs.attention_mask); + } else if (pooling === 'cls') { + result = result.slice(null, 0); } else { throw Error(`Pooling method '${pooling}' not supported.`); }