diff --git a/src/config/configSchema.ts b/src/config/configSchema.ts index dbc713d63..0cead3e74 100644 --- a/src/config/configSchema.ts +++ b/src/config/configSchema.ts @@ -1,4 +1,4 @@ -import type { TiktokenEncoding } from 'tiktoken'; +import type { TiktokenEncoding } from 'tiktoken/init'; import { z } from 'zod'; // Output style enum diff --git a/src/core/metrics/TokenCounter.ts b/src/core/metrics/TokenCounter.ts index 7ae1dcb46..06cb6a76b 100644 --- a/src/core/metrics/TokenCounter.ts +++ b/src/core/metrics/TokenCounter.ts @@ -1,6 +1,26 @@ -import { get_encoding, type Tiktoken, type TiktokenEncoding } from 'tiktoken'; +import { get_encoding, init, type Tiktoken, type TiktokenEncoding } from 'tiktoken/init'; import { logger } from '../../shared/logger.js'; +/** + * Initialize tiktoken with a pre-compiled WebAssembly module. + * + * When called with a module, tiktoken skips the expensive WASM compilation step + * and only performs instantiation (~6ms vs ~250ms). + * When called without a module, falls back to reading and compiling from disk. + */ +export const initTiktokenWasm = async (wasmModule?: WebAssembly.Module): Promise => { + if (wasmModule) { + await init((imports) => WebAssembly.instantiate(wasmModule, imports)); + } else { + const fs = await import('node:fs/promises'); + const { createRequire } = await import('node:module'); + const require = createRequire(import.meta.url); + const wasmPath = require.resolve('tiktoken/tiktoken_bg.wasm'); + const wasmBinary = await fs.readFile(wasmPath); + await init((imports) => WebAssembly.instantiate(wasmBinary, imports)); + } +}; + export class TokenCounter { private encoding: Tiktoken; diff --git a/src/core/metrics/calculateMetrics.ts b/src/core/metrics/calculateMetrics.ts index 6ec76a3e4..26fd086ed 100644 --- a/src/core/metrics/calculateMetrics.ts +++ b/src/core/metrics/calculateMetrics.ts @@ -9,6 +9,7 @@ import { calculateGitDiffMetrics } from './calculateGitDiffMetrics.js'; import { calculateGitLogMetrics } from './calculateGitLogMetrics.js'; import { calculateOutputMetrics } from './calculateOutputMetrics.js'; import { calculateSelectiveFileMetrics } from './calculateSelectiveFileMetrics.js'; +import { getCompiledTiktokenWasmModule } from './wasmModuleCache.js'; import type { TokenCountTask } from './workers/calculateMetricsWorker.js'; export interface CalculateMetricsResult { @@ -38,6 +39,19 @@ export const calculateMetrics = async ( ): Promise => { progressCallback('Calculating metrics...'); + // Pre-compile tiktoken WASM module once in the main thread and pass it to workers. + // This avoids each worker independently compiling the ~5.3MB WASM binary (~250ms each). + // Only compile when we will create our own task runner; skip if deps.taskRunner is provided. + let tiktokenWasmModule: WebAssembly.Module | undefined; + if (!deps.taskRunner) { + try { + tiktokenWasmModule = await getCompiledTiktokenWasmModule(); + } catch { + // Fall back to per-worker compilation if main thread precompile fails + tiktokenWasmModule = undefined; + } + } + // Initialize a single task runner for all metrics calculations const taskRunner = deps.taskRunner ?? @@ -45,6 +59,7 @@ export const calculateMetrics = async ( numOfTasks: processedFiles.length, workerType: 'calculateMetrics', runtime: 'worker_threads', + extraWorkerData: tiktokenWasmModule ? { tiktokenWasmModule } : undefined, }); try { diff --git a/src/core/metrics/calculateOutputMetrics.ts b/src/core/metrics/calculateOutputMetrics.ts index ad41ae918..240930b2b 100644 --- a/src/core/metrics/calculateOutputMetrics.ts +++ b/src/core/metrics/calculateOutputMetrics.ts @@ -1,4 +1,4 @@ -import type { TiktokenEncoding } from 'tiktoken'; +import type { TiktokenEncoding } from 'tiktoken/init'; import { logger } from '../../shared/logger.js'; import type { TaskRunner } from '../../shared/processConcurrency.js'; import type { TokenCountTask } from './workers/calculateMetricsWorker.js'; diff --git a/src/core/metrics/calculateSelectiveFileMetrics.ts b/src/core/metrics/calculateSelectiveFileMetrics.ts index 02f52726a..356e04887 100644 --- a/src/core/metrics/calculateSelectiveFileMetrics.ts +++ b/src/core/metrics/calculateSelectiveFileMetrics.ts @@ -1,5 +1,5 @@ import pc from 'picocolors'; -import type { TiktokenEncoding } from 'tiktoken'; +import type { TiktokenEncoding } from 'tiktoken/init'; import { logger } from '../../shared/logger.js'; import type { TaskRunner } from '../../shared/processConcurrency.js'; import type { RepomixProgressCallback } from '../../shared/types.js'; diff --git a/src/core/metrics/tokenCounterFactory.ts b/src/core/metrics/tokenCounterFactory.ts index 8f51f0ba5..6fca53c05 100644 --- a/src/core/metrics/tokenCounterFactory.ts +++ b/src/core/metrics/tokenCounterFactory.ts @@ -1,4 +1,4 @@ -import type { TiktokenEncoding } from 'tiktoken'; +import type { TiktokenEncoding } from 'tiktoken/init'; import { logger } from '../../shared/logger.js'; import { TokenCounter } from './TokenCounter.js'; diff --git a/src/core/metrics/wasmModuleCache.ts b/src/core/metrics/wasmModuleCache.ts new file mode 100644 index 000000000..8d98aeee4 --- /dev/null +++ b/src/core/metrics/wasmModuleCache.ts @@ -0,0 +1,37 @@ +import fs from 'node:fs/promises'; +import { createRequire } from 'node:module'; +import { logger } from '../../shared/logger.js'; + +let compiledModule: WebAssembly.Module | null = null; + +/** + * Resolve the file path to the tiktoken WASM binary. + */ +const getTiktokenWasmPath = (): string => { + const require = createRequire(import.meta.url); + return require.resolve('tiktoken/tiktoken_bg.wasm'); +}; + +/** + * Compile the tiktoken WASM binary once and cache the resulting WebAssembly.Module. + * + * The compiled module can be transferred to worker threads via structured clone + * (WebAssembly.Module is transferable), avoiding the ~250ms per-worker compile cost. + */ +export const getCompiledTiktokenWasmModule = async (): Promise => { + if (compiledModule) { + return compiledModule; + } + + const startTime = process.hrtime.bigint(); + + const wasmPath = getTiktokenWasmPath(); + const wasmBinary = await fs.readFile(wasmPath); + compiledModule = await WebAssembly.compile(wasmBinary); + + const endTime = process.hrtime.bigint(); + const compileTime = Number(endTime - startTime) / 1e6; + logger.debug(`Tiktoken WASM compilation took ${compileTime.toFixed(2)}ms`); + + return compiledModule; +}; diff --git a/src/core/metrics/workers/calculateMetricsWorker.ts b/src/core/metrics/workers/calculateMetricsWorker.ts index 241af02e0..b9593819a 100644 --- a/src/core/metrics/workers/calculateMetricsWorker.ts +++ b/src/core/metrics/workers/calculateMetricsWorker.ts @@ -1,5 +1,7 @@ -import type { TiktokenEncoding } from 'tiktoken'; +import { workerData } from 'node:worker_threads'; +import type { TiktokenEncoding } from 'tiktoken/init'; import { logger, setLogLevelByWorkerData } from '../../../shared/logger.js'; +import { initTiktokenWasm } from '../TokenCounter.js'; import { freeTokenCounters, getTokenCounter } from '../tokenCounterFactory.js'; /** @@ -14,6 +16,16 @@ import { freeTokenCounters, getTokenCounter } from '../tokenCounterFactory.js'; // This must be called before any logging operations in the worker setLogLevelByWorkerData(); +// Extract the pre-compiled WASM module from workerData. +// Tinypool wraps workerData as [tinypoolPrivateData, userWorkerData], so we access index [1]. +const userWorkerData = Array.isArray(workerData) ? workerData[1] : workerData; +const wasmModule = userWorkerData?.tiktokenWasmModule; + +// Initialize tiktoken WASM with the pre-compiled module from the main thread. +// If a valid WebAssembly.Module is present, this avoids re-compiling the ~5.3MB +// WASM binary in each worker thread (~6ms instantiation vs ~250ms compile+instantiate). +const wasmInitPromise = initTiktokenWasm(wasmModule instanceof WebAssembly.Module ? wasmModule : undefined); + export interface TokenCountTask { content: string; encoding: TiktokenEncoding; @@ -21,6 +33,9 @@ export interface TokenCountTask { } export const countTokens = async (task: TokenCountTask): Promise => { + // Ensure WASM is initialized before first token count + await wasmInitPromise; + const processStartAt = process.hrtime.bigint(); try { diff --git a/src/shared/processConcurrency.ts b/src/shared/processConcurrency.ts index 0507131a2..d6c428132 100644 --- a/src/shared/processConcurrency.ts +++ b/src/shared/processConcurrency.ts @@ -12,6 +12,7 @@ export interface WorkerOptions { numOfTasks: number; workerType: WorkerType; runtime: WorkerRuntime; + extraWorkerData?: Record; } /** @@ -62,7 +63,7 @@ export const getWorkerThreadCount = (numOfTasks: number): { minThreads: number; }; export const createWorkerPool = (options: WorkerOptions): Tinypool => { - const { numOfTasks, workerType, runtime = 'child_process' } = options; + const { numOfTasks, workerType, runtime = 'child_process', extraWorkerData } = options; const { minThreads, maxThreads } = getWorkerThreadCount(numOfTasks); // Get worker path - uses unified worker in bundled env, individual files otherwise @@ -84,6 +85,7 @@ export const createWorkerPool = (options: WorkerOptions): Tinypool => { workerData: { workerType, logLevel: logger.getLogLevel(), + ...extraWorkerData, }, // Only add env for child_process workers ...(runtime === 'child_process' && { diff --git a/src/types/webassembly.d.ts b/src/types/webassembly.d.ts new file mode 100644 index 000000000..bc190e44d --- /dev/null +++ b/src/types/webassembly.d.ts @@ -0,0 +1,62 @@ +/** + * Minimal WebAssembly type declarations. + * + * The project targets es2022 which does not include WebAssembly types. + * These declarations cover only the subset used by the tiktoken WASM module sharing. + */ + +declare namespace WebAssembly { + class Module { + constructor(bytes: BufferSource); + } + + class Instance { + readonly exports: Record; + constructor(module: Module, importObject?: Imports); + } + + interface WebAssemblyInstantiatedSource { + instance: Instance; + module: Module; + } + + // biome-ignore lint/complexity/noBannedTypes: WebAssembly spec defines ImportValue as accepting any function + type ImportValue = Function | Global | Memory | Table | number; + type Imports = Record>; + + class Global { + constructor(descriptor: GlobalDescriptor, value?: number); + value: number; + } + + interface GlobalDescriptor { + mutable?: boolean; + value: string; + } + + class Memory { + constructor(descriptor: MemoryDescriptor); + readonly buffer: ArrayBuffer; + } + + interface MemoryDescriptor { + initial: number; + maximum?: number; + shared?: boolean; + } + + class Table { + constructor(descriptor: TableDescriptor); + readonly length: number; + } + + interface TableDescriptor { + element: string; + initial: number; + maximum?: number; + } + + function compile(bytes: BufferSource): Promise; + function instantiate(moduleObject: Module, importObject?: Imports): Promise; + function instantiate(bytes: BufferSource, importObject?: Imports): Promise; +} diff --git a/tests/core/metrics/TokenCounter.test.ts b/tests/core/metrics/TokenCounter.test.ts index dedc8dbcf..9cc907f08 100644 --- a/tests/core/metrics/TokenCounter.test.ts +++ b/tests/core/metrics/TokenCounter.test.ts @@ -1,10 +1,11 @@ -import { get_encoding, type Tiktoken } from 'tiktoken'; +import { get_encoding, type Tiktoken } from 'tiktoken/init'; import { afterEach, beforeEach, describe, expect, type Mock, test, vi } from 'vitest'; import { TokenCounter } from '../../../src/core/metrics/TokenCounter.js'; import { logger } from '../../../src/shared/logger.js'; -vi.mock('tiktoken', () => ({ +vi.mock('tiktoken/init', () => ({ get_encoding: vi.fn(), + init: vi.fn(), })); vi.mock('../../../src/shared/logger');