Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/config/configSchema.ts
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import type { TiktokenEncoding } from 'tiktoken';
import type { TiktokenEncoding } from 'tiktoken/init';
import { z } from 'zod';

// Output style enum
Expand Down
22 changes: 21 additions & 1 deletion src/core/metrics/TokenCounter.ts
Original file line number Diff line number Diff line change
@@ -1,6 +1,26 @@
import { get_encoding, type Tiktoken, type TiktokenEncoding } from 'tiktoken';
import { get_encoding, init, type Tiktoken, type TiktokenEncoding } from 'tiktoken/init';
import { logger } from '../../shared/logger.js';

/**
 * Prepare the tiktoken WASM runtime for the current thread.
 *
 * Two paths are supported:
 * - A pre-compiled `WebAssembly.Module` is supplied: tiktoken only has to
 *   instantiate it, skipping the expensive compilation step (~6ms vs ~250ms).
 * - No module is supplied: the `tiktoken/tiktoken_bg.wasm` binary is resolved,
 *   read from disk, and compiled + instantiated from its raw bytes.
 *
 * @param wasmModule Optional pre-compiled tiktoken WASM module.
 */
export const initTiktokenWasm = async (wasmModule?: WebAssembly.Module): Promise<void> => {
  // Fast path: instantiate the module that was compiled elsewhere.
  if (wasmModule) {
    await init((imports) => WebAssembly.instantiate(wasmModule, imports));
    return;
  }

  // Slow path: locate the WASM binary shipped with tiktoken and load its bytes.
  const { readFile } = await import('node:fs/promises');
  const { createRequire } = await import('node:module');
  const wasmPath = createRequire(import.meta.url).resolve('tiktoken/tiktoken_bg.wasm');
  const wasmBinary = await readFile(wasmPath);

  await init((imports) => WebAssembly.instantiate(wasmBinary, imports));
};

export class TokenCounter {
private encoding: Tiktoken;

Expand Down
15 changes: 15 additions & 0 deletions src/core/metrics/calculateMetrics.ts
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import { calculateGitDiffMetrics } from './calculateGitDiffMetrics.js';
import { calculateGitLogMetrics } from './calculateGitLogMetrics.js';
import { calculateOutputMetrics } from './calculateOutputMetrics.js';
import { calculateSelectiveFileMetrics } from './calculateSelectiveFileMetrics.js';
import { getCompiledTiktokenWasmModule } from './wasmModuleCache.js';
import type { TokenCountTask } from './workers/calculateMetricsWorker.js';

export interface CalculateMetricsResult {
Expand Down Expand Up @@ -38,13 +39,27 @@ export const calculateMetrics = async (
): Promise<CalculateMetricsResult> => {
progressCallback('Calculating metrics...');

// Pre-compile tiktoken WASM module once in the main thread and pass it to workers.
// This avoids each worker independently compiling the ~5.3MB WASM binary (~250ms each).
// Only compile when we will create our own task runner; skip if deps.taskRunner is provided.
let tiktokenWasmModule: WebAssembly.Module | undefined;
if (!deps.taskRunner) {
try {
tiktokenWasmModule = await getCompiledTiktokenWasmModule();
} catch {
// Fall back to per-worker compilation if main thread precompile fails
tiktokenWasmModule = undefined;
}
}

// Initialize a single task runner for all metrics calculations
const taskRunner =
deps.taskRunner ??
initTaskRunner<TokenCountTask, number>({
numOfTasks: processedFiles.length,
workerType: 'calculateMetrics',
runtime: 'worker_threads',
extraWorkerData: tiktokenWasmModule ? { tiktokenWasmModule } : undefined,
});

try {
Expand Down
2 changes: 1 addition & 1 deletion src/core/metrics/calculateOutputMetrics.ts
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import type { TiktokenEncoding } from 'tiktoken';
import type { TiktokenEncoding } from 'tiktoken/init';
import { logger } from '../../shared/logger.js';
import type { TaskRunner } from '../../shared/processConcurrency.js';
import type { TokenCountTask } from './workers/calculateMetricsWorker.js';
Expand Down
2 changes: 1 addition & 1 deletion src/core/metrics/calculateSelectiveFileMetrics.ts
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import pc from 'picocolors';
import type { TiktokenEncoding } from 'tiktoken';
import type { TiktokenEncoding } from 'tiktoken/init';
import { logger } from '../../shared/logger.js';
import type { TaskRunner } from '../../shared/processConcurrency.js';
import type { RepomixProgressCallback } from '../../shared/types.js';
Expand Down
2 changes: 1 addition & 1 deletion src/core/metrics/tokenCounterFactory.ts
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import type { TiktokenEncoding } from 'tiktoken';
import type { TiktokenEncoding } from 'tiktoken/init';
import { logger } from '../../shared/logger.js';
import { TokenCounter } from './TokenCounter.js';

Expand Down
37 changes: 37 additions & 0 deletions src/core/metrics/wasmModuleCache.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
import fs from 'node:fs/promises';
import { createRequire } from 'node:module';
import { logger } from '../../shared/logger.js';

let compiledModule: WebAssembly.Module | null = null;

/**
 * Locate the tiktoken WASM binary on disk.
 *
 * Resolution uses a `require` function created relative to this module's URL.
 *
 * @returns The resolved path of `tiktoken/tiktoken_bg.wasm`.
 */
const getTiktokenWasmPath = (): string =>
  createRequire(import.meta.url).resolve('tiktoken/tiktoken_bg.wasm');

// Compilation that has started but not yet settled. Sharing it lets callers that
// arrive while the binary is still being read/compiled reuse the same work instead
// of each compiling the binary again.
let pendingTiktokenCompilation: Promise<WebAssembly.Module> | null = null;

/**
 * Read and compile the tiktoken WASM binary, record the compile time at debug
 * level, and store the result in the module-level cache.
 */
const compileTiktokenWasm = async (): Promise<WebAssembly.Module> => {
  const startTime = process.hrtime.bigint();

  const wasmPath = getTiktokenWasmPath();
  const wasmBinary = await fs.readFile(wasmPath);
  const compiled = await WebAssembly.compile(wasmBinary);

  const compileTime = Number(process.hrtime.bigint() - startTime) / 1e6;
  logger.debug(`Tiktoken WASM compilation took ${compileTime.toFixed(2)}ms`);

  compiledModule = compiled;
  return compiled;
};

/**
 * Compile the tiktoken WASM binary once and cache the resulting WebAssembly.Module.
 *
 * The compiled module can be passed to worker threads via structured clone,
 * avoiding the ~250ms per-worker compile cost.
 *
 * Concurrent callers share a single in-flight compilation. If that compilation
 * fails, every waiting caller receives the rejection and the in-flight state is
 * cleared, so a later call starts a fresh attempt.
 *
 * @returns The cached (or freshly compiled) tiktoken WebAssembly.Module.
 * @throws If the WASM binary cannot be resolved, read, or compiled.
 */
export const getCompiledTiktokenWasmModule = async (): Promise<WebAssembly.Module> => {
  if (compiledModule) {
    return compiledModule;
  }

  if (!pendingTiktokenCompilation) {
    // `.finally` runs asynchronously, i.e. after the assignment below has completed,
    // so the slot is reliably cleared even when compilation fails immediately.
    pendingTiktokenCompilation = compileTiktokenWasm().finally(() => {
      pendingTiktokenCompilation = null;
    });
  }

  return pendingTiktokenCompilation;
};
17 changes: 16 additions & 1 deletion src/core/metrics/workers/calculateMetricsWorker.ts
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import type { TiktokenEncoding } from 'tiktoken';
import { workerData } from 'node:worker_threads';
import type { TiktokenEncoding } from 'tiktoken/init';
import { logger, setLogLevelByWorkerData } from '../../../shared/logger.js';
import { initTiktokenWasm } from '../TokenCounter.js';
import { freeTokenCounters, getTokenCounter } from '../tokenCounterFactory.js';

/**
Expand All @@ -14,13 +16,26 @@ import { freeTokenCounters, getTokenCounter } from '../tokenCounterFactory.js';
// This must be called before any logging operations in the worker
setLogLevelByWorkerData();

// Extract the pre-compiled WASM module from workerData.
// Tinypool wraps workerData as [tinypoolPrivateData, userWorkerData], so we access index [1];
// when workerData is not an array it is used as-is.
const userWorkerData = Array.isArray(workerData) ? workerData[1] : workerData;
const wasmModule = userWorkerData?.tiktokenWasmModule;

// Initialize tiktoken WASM with the pre-compiled module from the main thread.
// If a valid WebAssembly.Module is present, this avoids re-compiling the ~5.3MB
// WASM binary in each worker thread (~6ms instantiation vs ~250ms compile+instantiate).
// Anything else (missing or wrong type) falls back to compiling from disk in this worker.
const wasmInitPromise = initTiktokenWasm(wasmModule instanceof WebAssembly.Module ? wasmModule : undefined);

// Initialization starts at module load, possibly before any task awaits the promise.
// Attach a no-op handler so a failed init is not reported as an unhandled rejection
// (which would take down the worker); the failure is still surfaced to callers because
// countTokens awaits this same promise before counting.
wasmInitPromise.catch(() => {
  // Intentionally empty: the rejection is re-thrown where wasmInitPromise is awaited.
});

/** Payload describing one unit of work passed to `countTokens` in this worker. */
export interface TokenCountTask {
/** Text supplied for the task (presumably the content whose tokens are counted; confirm in countTokens). */
content: string;
/** Tiktoken encoding associated with the task. */
encoding: TiktokenEncoding;
/** Optional path accompanying the content. NOTE(review): looks informational (e.g. for logging); confirm in countTokens. */
path?: string;
}

export const countTokens = async (task: TokenCountTask): Promise<number> => {
// Ensure WASM is initialized before first token count
await wasmInitPromise;

const processStartAt = process.hrtime.bigint();

try {
Expand Down
4 changes: 3 additions & 1 deletion src/shared/processConcurrency.ts
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ export interface WorkerOptions {
numOfTasks: number;
workerType: WorkerType;
runtime: WorkerRuntime;
extraWorkerData?: Record<string, unknown>;
}

/**
Expand Down Expand Up @@ -62,7 +63,7 @@ export const getWorkerThreadCount = (numOfTasks: number): { minThreads: number;
};

export const createWorkerPool = (options: WorkerOptions): Tinypool => {
const { numOfTasks, workerType, runtime = 'child_process' } = options;
const { numOfTasks, workerType, runtime = 'child_process', extraWorkerData } = options;
const { minThreads, maxThreads } = getWorkerThreadCount(numOfTasks);

// Get worker path - uses unified worker in bundled env, individual files otherwise
Expand All @@ -84,6 +85,7 @@ export const createWorkerPool = (options: WorkerOptions): Tinypool => {
workerData: {
workerType,
logLevel: logger.getLogLevel(),
...extraWorkerData,
},
// Only add env for child_process workers
...(runtime === 'child_process' && {
Expand Down
62 changes: 62 additions & 0 deletions src/types/webassembly.d.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
/**
* Minimal WebAssembly type declarations.
*
* The project targets es2022, whose TypeScript library typings do not include WebAssembly types.
* These declarations cover only the subset needed to share the tiktoken WASM module with workers.
*/

declare namespace WebAssembly {
  // ---------------------------------------------------------------------------
  // Descriptors (constructor option bags)
  // ---------------------------------------------------------------------------

  /** Options accepted by the `Global` constructor. */
  interface GlobalDescriptor {
    value: string;
    mutable?: boolean;
  }

  /** Options accepted by the `Memory` constructor. */
  interface MemoryDescriptor {
    initial: number;
    maximum?: number;
    shared?: boolean;
  }

  /** Options accepted by the `Table` constructor. */
  interface TableDescriptor {
    element: string;
    initial: number;
    maximum?: number;
  }

  // ---------------------------------------------------------------------------
  // Runtime objects
  // ---------------------------------------------------------------------------

  /** A WebAssembly module constructed from binary bytes. */
  class Module {
    constructor(bytes: BufferSource);
  }

  /** An instance created from a `Module` and an optional import object. */
  class Instance {
    constructor(module: Module, importObject?: Imports);
    readonly exports: Record<string, unknown>;
  }

  /** A global value that can be supplied as an import. */
  class Global {
    constructor(descriptor: GlobalDescriptor, value?: number);
    value: number;
  }

  /** A memory object exposing its underlying buffer. */
  class Memory {
    constructor(descriptor: MemoryDescriptor);
    readonly buffer: ArrayBuffer;
  }

  /** A table object exposing its current length. */
  class Table {
    constructor(descriptor: TableDescriptor);
    readonly length: number;
  }

  // ---------------------------------------------------------------------------
  // Import / result shapes
  // ---------------------------------------------------------------------------

  // biome-ignore lint/complexity/noBannedTypes: WebAssembly spec defines ImportValue as accepting any function
  type ImportValue = Function | Global | Memory | Table | number;

  /** Import object: outer key -> inner key -> imported value. */
  type Imports = Record<string, Record<string, ImportValue>>;

  /** Result of instantiating from raw bytes: both the module and its instance. */
  interface WebAssemblyInstantiatedSource {
    module: Module;
    instance: Instance;
  }

  // ---------------------------------------------------------------------------
  // Functions
  // ---------------------------------------------------------------------------

  /** Compile binary bytes into a `Module`. */
  function compile(bytes: BufferSource): Promise<Module>;

  /** Instantiate an already-compiled `Module`; resolves to the `Instance` only. */
  function instantiate(moduleObject: Module, importObject?: Imports): Promise<Instance>;

  /** Instantiate from raw bytes; resolves to both the `Module` and the `Instance`. */
  function instantiate(bytes: BufferSource, importObject?: Imports): Promise<WebAssemblyInstantiatedSource>;
}
5 changes: 3 additions & 2 deletions tests/core/metrics/TokenCounter.test.ts
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
import { get_encoding, type Tiktoken } from 'tiktoken';
import { get_encoding, type Tiktoken } from 'tiktoken/init';
import { afterEach, beforeEach, describe, expect, type Mock, test, vi } from 'vitest';
import { TokenCounter } from '../../../src/core/metrics/TokenCounter.js';
import { logger } from '../../../src/shared/logger.js';

vi.mock('tiktoken', () => ({
vi.mock('tiktoken/init', () => ({
get_encoding: vi.fn(),
init: vi.fn(),
}));

vi.mock('../../../src/shared/logger');
Expand Down
Loading