Skip to content

Commit 09d20fa

Browse files
fix: change data source lifecycle on agent memory mysql saver (#3578)
* fix: change data source lifecycle on agent memory mysql saver * Update mysqlSaver.ts * Update pgSaver.ts * linting fix --------- Co-authored-by: Henry Heng <[email protected]>
1 parent 371da23 commit 09d20fa

File tree

2 files changed

+159
-129
lines changed

2 files changed

+159
-129
lines changed
Original file line numberDiff line numberDiff line change
@@ -1,50 +1,46 @@
11
import { BaseCheckpointSaver, Checkpoint, CheckpointMetadata } from '@langchain/langgraph'
22
import { RunnableConfig } from '@langchain/core/runnables'
33
import { BaseMessage } from '@langchain/core/messages'
4-
import { DataSource, QueryRunner } from 'typeorm'
4+
import { DataSource } from 'typeorm'
55
import { CheckpointTuple, SaverOptions, SerializerProtocol } from './interface'
66
import { IMessage, MemoryMethods } from '../../../src/Interface'
77
import { mapChatMessageToBaseMessage } from '../../../src/utils'
88

99
export class MySQLSaver extends BaseCheckpointSaver implements MemoryMethods {
1010
protected isSetup: boolean
11-
12-
datasource: DataSource
13-
14-
queryRunner: QueryRunner
15-
1611
config: SaverOptions
17-
1812
threadId: string
19-
2013
tableName = 'checkpoints'
2114

2215
constructor(config: SaverOptions, serde?: SerializerProtocol<Checkpoint>) {
2316
super(serde)
2417
this.config = config
25-
const { datasourceOptions, threadId } = config
18+
const { threadId } = config
2619
this.threadId = threadId
27-
this.datasource = new DataSource(datasourceOptions)
2820
}
2921

30-
private async setup(): Promise<void> {
31-
if (this.isSetup) {
32-
return
33-
}
22+
private async getDataSource(): Promise<DataSource> {
23+
const { datasourceOptions } = this.config
24+
const dataSource = new DataSource(datasourceOptions)
25+
await dataSource.initialize()
26+
return dataSource
27+
}
28+
29+
private async setup(dataSource: DataSource): Promise<void> {
30+
if (this.isSetup) return
3431

3532
try {
36-
const appDataSource = await this.datasource.initialize()
37-
38-
this.queryRunner = appDataSource.createQueryRunner()
39-
await this.queryRunner.manager.query(`
40-
CREATE TABLE IF NOT EXISTS ${this.tableName} (
41-
thread_id VARCHAR(255) NOT NULL,
42-
checkpoint_id VARCHAR(255) NOT NULL,
43-
parent_id VARCHAR(255),
44-
checkpoint BLOB,
45-
metadata BLOB,
46-
PRIMARY KEY (thread_id, checkpoint_id)
47-
);`)
33+
const queryRunner = dataSource.createQueryRunner()
34+
await queryRunner.manager.query(`
35+
CREATE TABLE IF NOT EXISTS ${this.tableName} (
36+
thread_id VARCHAR(255) NOT NULL,
37+
checkpoint_id VARCHAR(255) NOT NULL,
38+
parent_id VARCHAR(255),
39+
checkpoint BLOB,
40+
metadata BLOB,
41+
PRIMARY KEY (thread_id, checkpoint_id)
42+
);`)
43+
await queryRunner.release()
4844
} catch (error) {
4945
console.error(`Error creating ${this.tableName} table`, error)
5046
throw new Error(`Error creating ${this.tableName} table`)
@@ -54,79 +50,67 @@ CREATE TABLE IF NOT EXISTS ${this.tableName} (
5450
}
5551

5652
async getTuple(config: RunnableConfig): Promise<CheckpointTuple | undefined> {
57-
await this.setup()
53+
const dataSource = await this.getDataSource()
54+
await this.setup(dataSource)
55+
5856
const thread_id = config.configurable?.thread_id || this.threadId
5957
const checkpoint_id = config.configurable?.checkpoint_id
6058

61-
if (checkpoint_id) {
62-
try {
63-
const keys = [thread_id, checkpoint_id]
64-
const sql = `SELECT checkpoint, parent_id, metadata FROM ${this.tableName} WHERE thread_id = ? AND checkpoint_id = ?`
65-
66-
const rows = await this.queryRunner.manager.query(sql, keys)
67-
68-
if (rows && rows.length > 0) {
69-
return {
70-
config,
71-
checkpoint: (await this.serde.parse(rows[0].checkpoint.toString())) as Checkpoint,
72-
metadata: (await this.serde.parse(rows[0].metadata.toString())) as CheckpointMetadata,
73-
parentConfig: rows[0].parent_id
74-
? {
75-
configurable: {
76-
thread_id,
77-
checkpoint_id: rows[0].parent_id
78-
}
79-
}
80-
: undefined
81-
}
82-
}
83-
} catch (error) {
84-
console.error(`Error retrieving ${this.tableName}`, error)
85-
throw new Error(`Error retrieving ${this.tableName}`)
86-
}
87-
} else {
88-
const keys = [thread_id]
89-
const sql = `SELECT thread_id, checkpoint_id, parent_id, checkpoint, metadata FROM ${this.tableName} WHERE thread_id = ? ORDER BY checkpoint_id DESC LIMIT 1`
59+
try {
60+
const queryRunner = dataSource.createQueryRunner()
61+
const sql = checkpoint_id
62+
? `SELECT checkpoint, parent_id, metadata FROM ${this.tableName} WHERE thread_id = ? AND checkpoint_id = ?`
63+
: `SELECT thread_id, checkpoint_id, parent_id, checkpoint, metadata FROM ${this.tableName} WHERE thread_id = ? ORDER BY checkpoint_id DESC LIMIT 1`
9064

91-
const rows = await this.queryRunner.manager.query(sql, keys)
65+
const rows = await queryRunner.manager.query(sql, checkpoint_id ? [thread_id, checkpoint_id] : [thread_id])
66+
await queryRunner.release()
9267

9368
if (rows && rows.length > 0) {
69+
const row = rows[0]
9470
return {
9571
config: {
9672
configurable: {
97-
thread_id: rows[0].thread_id,
98-
checkpoint_id: rows[0].checkpoint_id
73+
thread_id: row.thread_id || thread_id,
74+
checkpoint_id: row.checkpoint_id || checkpoint_id
9975
}
10076
},
101-
checkpoint: (await this.serde.parse(rows[0].checkpoint.toString())) as Checkpoint,
102-
metadata: (await this.serde.parse(rows[0].metadata.toString())) as CheckpointMetadata,
103-
parentConfig: rows[0].parent_id
77+
checkpoint: (await this.serde.parse(row.checkpoint.toString())) as Checkpoint,
78+
metadata: (await this.serde.parse(row.metadata.toString())) as CheckpointMetadata,
79+
parentConfig: row.parent_id
10480
? {
10581
configurable: {
106-
thread_id: rows[0].thread_id,
107-
checkpoint_id: rows[0].parent_id
82+
thread_id,
83+
checkpoint_id: row.parent_id
10884
}
10985
}
11086
: undefined
11187
}
11288
}
89+
} catch (error) {
90+
console.error(`Error retrieving ${this.tableName}`, error)
91+
throw new Error(`Error retrieving ${this.tableName}`)
92+
} finally {
93+
await dataSource.destroy()
11394
}
11495
return undefined
11596
}
11697

117-
async *list(config: RunnableConfig, limit?: number, before?: RunnableConfig): AsyncGenerator<CheckpointTuple> {
118-
await this.setup()
119-
const thread_id = config.configurable?.thread_id || this.threadId
120-
let sql = `SELECT thread_id, checkpoint_id, parent_id, checkpoint, metadata FROM ${this.tableName} WHERE thread_id = ? ${
121-
before ? 'AND checkpoint_id < ?' : ''
122-
} ORDER BY checkpoint_id DESC`
123-
if (limit) {
124-
sql += ` LIMIT ${limit}`
125-
}
126-
const args = [thread_id, before?.configurable?.checkpoint_id].filter(Boolean)
127-
98+
async *list(config: RunnableConfig, limit?: number, before?: RunnableConfig): AsyncGenerator<CheckpointTuple, void, unknown> {
99+
const dataSource = await this.getDataSource()
100+
await this.setup(dataSource)
101+
const queryRunner = dataSource.createQueryRunner()
128102
try {
129-
const rows = await this.queryRunner.manager.query(sql, args)
103+
const threadId = config.configurable?.thread_id || this.threadId
104+
let sql = `SELECT thread_id, checkpoint_id, parent_id, checkpoint, metadata FROM ${this.tableName} WHERE thread_id = ? ${
105+
before ? 'AND checkpoint_id < ?' : ''
106+
} ORDER BY checkpoint_id DESC`
107+
if (limit) {
108+
sql += ` LIMIT ${limit}`
109+
}
110+
const args = [threadId, before?.configurable?.checkpoint_id].filter(Boolean)
111+
112+
const rows = await queryRunner.manager.query(sql, args)
113+
await queryRunner.release()
130114

131115
if (rows && rows.length > 0) {
132116
for (const row of rows) {
@@ -151,15 +135,20 @@ CREATE TABLE IF NOT EXISTS ${this.tableName} (
151135
}
152136
}
153137
} catch (error) {
154-
console.error(`Error listing ${this.tableName}`, error)
155-
throw new Error(`Error listing ${this.tableName}`)
138+
console.error(`Error listing checkpoints`, error)
139+
throw new Error(`Error listing checkpoints`)
140+
} finally {
141+
await dataSource.destroy()
156142
}
157143
}
158144

159145
async put(config: RunnableConfig, checkpoint: Checkpoint, metadata: CheckpointMetadata): Promise<RunnableConfig> {
160-
await this.setup()
146+
const dataSource = await this.getDataSource()
147+
await this.setup(dataSource)
148+
161149
if (!config.configurable?.checkpoint_id) return {}
162150
try {
151+
const queryRunner = dataSource.createQueryRunner()
163152
const row = [
164153
config.configurable?.thread_id || this.threadId,
165154
checkpoint.id,
@@ -172,10 +161,13 @@ CREATE TABLE IF NOT EXISTS ${this.tableName} (
172161
VALUES (?, ?, ?, ?, ?)
173162
ON DUPLICATE KEY UPDATE checkpoint = VALUES(checkpoint), metadata = VALUES(metadata)`
174163

175-
await this.queryRunner.manager.query(query, row)
164+
await queryRunner.manager.query(query, row)
165+
await queryRunner.release()
176166
} catch (error) {
177167
console.error('Error saving checkpoint', error)
178168
throw new Error('Error saving checkpoint')
169+
} finally {
170+
await dataSource.destroy()
179171
}
180172

181173
return {
@@ -187,16 +179,20 @@ CREATE TABLE IF NOT EXISTS ${this.tableName} (
187179
}
188180

189181
async delete(threadId: string): Promise<void> {
190-
if (!threadId) {
191-
return
192-
}
193-
await this.setup()
194-
const query = `DELETE FROM ${this.tableName} WHERE thread_id = ?;`
182+
if (!threadId) return
183+
184+
const dataSource = await this.getDataSource()
185+
await this.setup(dataSource)
195186

196187
try {
197-
await this.queryRunner.manager.query(query, [threadId])
188+
const queryRunner = dataSource.createQueryRunner()
189+
const query = `DELETE FROM ${this.tableName} WHERE thread_id = ?;`
190+
await queryRunner.manager.query(query, [threadId])
191+
await queryRunner.release()
198192
} catch (error) {
199193
console.error(`Error deleting thread_id ${threadId}`, error)
194+
} finally {
195+
await dataSource.destroy()
200196
}
201197
}
202198

@@ -232,6 +228,7 @@ CREATE TABLE IF NOT EXISTS ${this.tableName} (
232228
type: m.role
233229
})
234230
}
231+
235232
return returnIMessages
236233
}
237234

@@ -240,6 +237,7 @@ CREATE TABLE IF NOT EXISTS ${this.tableName} (
240237
}
241238

242239
async clearChatMessages(overrideSessionId = ''): Promise<void> {
240+
if (!overrideSessionId) return
243241
await this.delete(overrideSessionId)
244242
}
245243
}

0 commit comments

Comments
 (0)