diff --git a/packages/components/nodes/memory/AgentMemory/SQLiteAgentMemory/sqliteSaver.ts b/packages/components/nodes/memory/AgentMemory/SQLiteAgentMemory/sqliteSaver.ts index dc746936dac..e38b5bbb3a3 100644 --- a/packages/components/nodes/memory/AgentMemory/SQLiteAgentMemory/sqliteSaver.ts +++ b/packages/components/nodes/memory/AgentMemory/SQLiteAgentMemory/sqliteSaver.ts @@ -1,42 +1,39 @@ import { BaseCheckpointSaver, Checkpoint, CheckpointMetadata } from '@langchain/langgraph' import { RunnableConfig } from '@langchain/core/runnables' import { BaseMessage } from '@langchain/core/messages' -import { DataSource, QueryRunner } from 'typeorm' +import { DataSource } from 'typeorm' import { CheckpointTuple, SaverOptions, SerializerProtocol } from '../interface' import { IMessage, MemoryMethods } from '../../../../src/Interface' import { mapChatMessageToBaseMessage } from '../../../../src/utils' export class SqliteSaver extends BaseCheckpointSaver implements MemoryMethods { protected isSetup: boolean - - datasource: DataSource - - queryRunner: QueryRunner - config: SaverOptions - threadId: string - tableName = 'checkpoints' constructor(config: SaverOptions, serde?: SerializerProtocol) { super(serde) this.config = config - const { datasourceOptions, threadId } = config + const { threadId } = config this.threadId = threadId - this.datasource = new DataSource(datasourceOptions) } - private async setup(): Promise { + private async getDataSource(): Promise { + const { datasourceOptions } = this.config + const dataSource = new DataSource(datasourceOptions) + await dataSource.initialize() + return dataSource + } + + private async setup(dataSource: DataSource): Promise { if (this.isSetup) { return } try { - const appDataSource = await this.datasource.initialize() - - this.queryRunner = appDataSource.createQueryRunner() - await this.queryRunner.manager.query(` + const queryRunner = dataSource.createQueryRunner() + await queryRunner.manager.query(` CREATE TABLE IF NOT EXISTS ${this.tableName} ( thread_id TEXT NOT NULL, checkpoint_id TEXT NOT NULL, @@ -44,6 +41,7 @@ CREATE TABLE IF NOT EXISTS ${this.tableName} ( checkpoint BLOB, metadata BLOB, PRIMARY KEY (thread_id, checkpoint_id));`) + await queryRunner.release() } catch (error) { console.error(`Error creating ${this.tableName} table`, error) throw new Error(`Error creating ${this.tableName} table`) @@ -53,16 +51,20 @@ CREATE TABLE IF NOT EXISTS ${this.tableName} ( } async getTuple(config: RunnableConfig): Promise { - await this.setup() + const dataSource = await this.getDataSource() + await this.setup(dataSource) + const thread_id = config.configurable?.thread_id || this.threadId const checkpoint_id = config.configurable?.checkpoint_id if (checkpoint_id) { try { + const queryRunner = dataSource.createQueryRunner() const keys = [thread_id, checkpoint_id] const sql = `SELECT checkpoint, parent_id, metadata FROM ${this.tableName} WHERE thread_id = ? AND checkpoint_id = ?` - const rows = await this.queryRunner.manager.query(sql, [...keys]) + const rows = await queryRunner.manager.query(sql, [...keys]) + await queryRunner.release() if (rows && rows.length > 0) { return { @@ -82,39 +84,53 @@ CREATE TABLE IF NOT EXISTS ${this.tableName} ( } catch (error) { console.error(`Error retrieving ${this.tableName}`, error) throw new Error(`Error retrieving ${this.tableName}`) + } finally { + await dataSource.destroy() } } else { - const keys = [thread_id] - const sql = `SELECT thread_id, checkpoint_id, parent_id, checkpoint, metadata FROM ${this.tableName} WHERE thread_id = ? ORDER BY checkpoint_id DESC LIMIT 1` + try { + const queryRunner = dataSource.createQueryRunner() + const keys = [thread_id] + const sql = `SELECT thread_id, checkpoint_id, parent_id, checkpoint, metadata FROM ${this.tableName} WHERE thread_id = ? ORDER BY checkpoint_id DESC LIMIT 1` - const rows = await this.queryRunner.manager.query(sql, [...keys]) + const rows = await queryRunner.manager.query(sql, [...keys]) + await queryRunner.release() - if (rows && rows.length > 0) { - return { - config: { - configurable: { - thread_id: rows[0].thread_id, - checkpoint_id: rows[0].checkpoint_id - } - }, - checkpoint: (await this.serde.parse(rows[0].checkpoint)) as Checkpoint, - metadata: (await this.serde.parse(rows[0].metadata)) as CheckpointMetadata, - parentConfig: rows[0].parent_id - ? { - configurable: { - thread_id: rows[0].thread_id, - checkpoint_id: rows[0].parent_id + if (rows && rows.length > 0) { + return { + config: { + configurable: { + thread_id: rows[0].thread_id, + checkpoint_id: rows[0].checkpoint_id + } + }, + checkpoint: (await this.serde.parse(rows[0].checkpoint)) as Checkpoint, + metadata: (await this.serde.parse(rows[0].metadata)) as CheckpointMetadata, + parentConfig: rows[0].parent_id + ? { + configurable: { + thread_id: rows[0].thread_id, + checkpoint_id: rows[0].parent_id + } } - } - : undefined + : undefined + } } + } catch (error) { + console.error(`Error retrieving ${this.tableName}`, error) + throw new Error(`Error retrieving ${this.tableName}`) + } finally { + await dataSource.destroy() } } return undefined } async *list(config: RunnableConfig, limit?: number, before?: RunnableConfig): AsyncGenerator { - await this.setup() + const dataSource = await this.getDataSource() + await this.setup(dataSource) + + const queryRunner = dataSource.createQueryRunner() const thread_id = config.configurable?.thread_id || this.threadId let sql = `SELECT thread_id, checkpoint_id, parent_id, checkpoint, metadata FROM ${this.tableName} WHERE thread_id = ? ${ before ? 'AND checkpoint_id < ?' : '' @@ -125,7 +141,8 @@ CREATE TABLE IF NOT EXISTS ${this.tableName} ( const args = [thread_id, before?.configurable?.checkpoint_id].filter(Boolean) try { - const rows = await this.queryRunner.manager.query(sql, [...args]) + const rows = await queryRunner.manager.query(sql, [...args]) + await queryRunner.release() if (rows && rows.length > 0) { for (const row of rows) { @@ -152,13 +169,18 @@ CREATE TABLE IF NOT EXISTS ${this.tableName} ( } catch (error) { console.error(`Error listing ${this.tableName}`, error) throw new Error(`Error listing ${this.tableName}`) + } finally { + await dataSource.destroy() } } async put(config: RunnableConfig, checkpoint: Checkpoint, metadata: CheckpointMetadata): Promise { - await this.setup() + const dataSource = await this.getDataSource() + await this.setup(dataSource) + if (!config.configurable?.checkpoint_id) return {} try { + const queryRunner = dataSource.createQueryRunner() const row = [ config.configurable?.thread_id || this.threadId, checkpoint.id, @@ -169,10 +191,13 @@ CREATE TABLE IF NOT EXISTS ${this.tableName} ( const query = `INSERT OR REPLACE INTO ${this.tableName} (thread_id, checkpoint_id, parent_id, checkpoint, metadata) VALUES (?, ?, ?, ?, ?)` - await this.queryRunner.manager.query(query, row) + await queryRunner.manager.query(query, row) + await queryRunner.release() } catch (error) { console.error('Error saving checkpoint', error) throw new Error('Error saving checkpoint') + } finally { + await dataSource.destroy() } return { @@ -187,13 +212,20 @@ CREATE TABLE IF NOT EXISTS ${this.tableName} ( if (!threadId) { return } - await this.setup() + + const dataSource = await this.getDataSource() + await this.setup(dataSource) + const query = `DELETE FROM "${this.tableName}" WHERE thread_id = ?;` try { - await this.queryRunner.manager.query(query, [threadId]) + const queryRunner = dataSource.createQueryRunner() + await queryRunner.manager.query(query, [threadId]) + await queryRunner.release() } catch (error) { console.error(`Error deleting thread_id ${threadId}`, error) + } finally { + await dataSource.destroy() } }