Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Bugfix/SQLite agent memory node #3650

Merged
merged 4 commits into from
Dec 6, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -1,49 +1,47 @@
import { BaseCheckpointSaver, Checkpoint, CheckpointMetadata } from '@langchain/langgraph'
import { RunnableConfig } from '@langchain/core/runnables'
import { BaseMessage } from '@langchain/core/messages'
import { DataSource, QueryRunner } from 'typeorm'
import { DataSource } from 'typeorm'
import { CheckpointTuple, SaverOptions, SerializerProtocol } from '../interface'
import { IMessage, MemoryMethods } from '../../../../src/Interface'
import { mapChatMessageToBaseMessage } from '../../../../src/utils'

export class SqliteSaver extends BaseCheckpointSaver implements MemoryMethods {
protected isSetup: boolean

datasource: DataSource

queryRunner: QueryRunner

config: SaverOptions

threadId: string

tableName = 'checkpoints'

constructor(config: SaverOptions, serde?: SerializerProtocol<Checkpoint>) {
super(serde)
this.config = config
const { datasourceOptions, threadId } = config
const { threadId } = config
this.threadId = threadId
this.datasource = new DataSource(datasourceOptions)
}

private async setup(): Promise<void> {
private async getDataSource(): Promise<DataSource> {
const { datasourceOptions } = this.config
const dataSource = new DataSource(datasourceOptions)
await dataSource.initialize()
return dataSource
}

private async setup(dataSource: DataSource): Promise<void> {
if (this.isSetup) {
return
}

try {
const appDataSource = await this.datasource.initialize()

this.queryRunner = appDataSource.createQueryRunner()
await this.queryRunner.manager.query(`
const queryRunner = dataSource.createQueryRunner()
await queryRunner.manager.query(`
CREATE TABLE IF NOT EXISTS ${this.tableName} (
thread_id TEXT NOT NULL,
checkpoint_id TEXT NOT NULL,
parent_id TEXT,
checkpoint BLOB,
metadata BLOB,
PRIMARY KEY (thread_id, checkpoint_id));`)
await queryRunner.release()
} catch (error) {
console.error(`Error creating ${this.tableName} table`, error)
throw new Error(`Error creating ${this.tableName} table`)
Expand All @@ -53,16 +51,20 @@ CREATE TABLE IF NOT EXISTS ${this.tableName} (
}

async getTuple(config: RunnableConfig): Promise<CheckpointTuple | undefined> {
await this.setup()
const dataSource = await this.getDataSource()
await this.setup(dataSource)

const thread_id = config.configurable?.thread_id || this.threadId
const checkpoint_id = config.configurable?.checkpoint_id

if (checkpoint_id) {
try {
const queryRunner = dataSource.createQueryRunner()
const keys = [thread_id, checkpoint_id]
const sql = `SELECT checkpoint, parent_id, metadata FROM ${this.tableName} WHERE thread_id = ? AND checkpoint_id = ?`

const rows = await this.queryRunner.manager.query(sql, [...keys])
const rows = await queryRunner.manager.query(sql, [...keys])
await queryRunner.release()

if (rows && rows.length > 0) {
return {
Expand All @@ -82,39 +84,53 @@ CREATE TABLE IF NOT EXISTS ${this.tableName} (
} catch (error) {
console.error(`Error retrieving ${this.tableName}`, error)
throw new Error(`Error retrieving ${this.tableName}`)
} finally {
await dataSource.destroy()
}
} else {
const keys = [thread_id]
const sql = `SELECT thread_id, checkpoint_id, parent_id, checkpoint, metadata FROM ${this.tableName} WHERE thread_id = ? ORDER BY checkpoint_id DESC LIMIT 1`
try {
const queryRunner = dataSource.createQueryRunner()
const keys = [thread_id]
const sql = `SELECT thread_id, checkpoint_id, parent_id, checkpoint, metadata FROM ${this.tableName} WHERE thread_id = ? ORDER BY checkpoint_id DESC LIMIT 1`

const rows = await this.queryRunner.manager.query(sql, [...keys])
const rows = await queryRunner.manager.query(sql, [...keys])
await queryRunner.release()

if (rows && rows.length > 0) {
return {
config: {
configurable: {
thread_id: rows[0].thread_id,
checkpoint_id: rows[0].checkpoint_id
}
},
checkpoint: (await this.serde.parse(rows[0].checkpoint)) as Checkpoint,
metadata: (await this.serde.parse(rows[0].metadata)) as CheckpointMetadata,
parentConfig: rows[0].parent_id
? {
configurable: {
thread_id: rows[0].thread_id,
checkpoint_id: rows[0].parent_id
if (rows && rows.length > 0) {
return {
config: {
configurable: {
thread_id: rows[0].thread_id,
checkpoint_id: rows[0].checkpoint_id
}
},
checkpoint: (await this.serde.parse(rows[0].checkpoint)) as Checkpoint,
metadata: (await this.serde.parse(rows[0].metadata)) as CheckpointMetadata,
parentConfig: rows[0].parent_id
? {
configurable: {
thread_id: rows[0].thread_id,
checkpoint_id: rows[0].parent_id
}
}
}
: undefined
: undefined
}
}
} catch (error) {
console.error(`Error retrieving ${this.tableName}`, error)
throw new Error(`Error retrieving ${this.tableName}`)
} finally {
await dataSource.destroy()
}
}
return undefined
}

async *list(config: RunnableConfig, limit?: number, before?: RunnableConfig): AsyncGenerator<CheckpointTuple> {
await this.setup()
const dataSource = await this.getDataSource()
await this.setup(dataSource)

const queryRunner = dataSource.createQueryRunner()
const thread_id = config.configurable?.thread_id || this.threadId
let sql = `SELECT thread_id, checkpoint_id, parent_id, checkpoint, metadata FROM ${this.tableName} WHERE thread_id = ? ${
before ? 'AND checkpoint_id < ?' : ''
Expand All @@ -125,7 +141,8 @@ CREATE TABLE IF NOT EXISTS ${this.tableName} (
const args = [thread_id, before?.configurable?.checkpoint_id].filter(Boolean)

try {
const rows = await this.queryRunner.manager.query(sql, [...args])
const rows = await queryRunner.manager.query(sql, [...args])
await queryRunner.release()

if (rows && rows.length > 0) {
for (const row of rows) {
Expand All @@ -152,13 +169,18 @@ CREATE TABLE IF NOT EXISTS ${this.tableName} (
} catch (error) {
console.error(`Error listing ${this.tableName}`, error)
throw new Error(`Error listing ${this.tableName}`)
} finally {
await dataSource.destroy()
}
}

async put(config: RunnableConfig, checkpoint: Checkpoint, metadata: CheckpointMetadata): Promise<RunnableConfig> {
await this.setup()
const dataSource = await this.getDataSource()
await this.setup(dataSource)

if (!config.configurable?.checkpoint_id) return {}
try {
const queryRunner = dataSource.createQueryRunner()
const row = [
config.configurable?.thread_id || this.threadId,
checkpoint.id,
Expand All @@ -169,10 +191,13 @@ CREATE TABLE IF NOT EXISTS ${this.tableName} (

const query = `INSERT OR REPLACE INTO ${this.tableName} (thread_id, checkpoint_id, parent_id, checkpoint, metadata) VALUES (?, ?, ?, ?, ?)`

await this.queryRunner.manager.query(query, row)
await queryRunner.manager.query(query, row)
await queryRunner.release()
} catch (error) {
console.error('Error saving checkpoint', error)
throw new Error('Error saving checkpoint')
} finally {
await dataSource.destroy()
}

return {
Expand All @@ -187,13 +212,20 @@ CREATE TABLE IF NOT EXISTS ${this.tableName} (
if (!threadId) {
return
}
await this.setup()

const dataSource = await this.getDataSource()
await this.setup(dataSource)

const query = `DELETE FROM "${this.tableName}" WHERE thread_id = ?;`

try {
await this.queryRunner.manager.query(query, [threadId])
const queryRunner = dataSource.createQueryRunner()
await queryRunner.manager.query(query, [threadId])
await queryRunner.release()
} catch (error) {
console.error(`Error deleting thread_id ${threadId}`, error)
} finally {
await dataSource.destroy()
}
}

Expand Down
Loading