1
1
import { BaseCheckpointSaver , Checkpoint , CheckpointMetadata } from '@langchain/langgraph'
2
2
import { RunnableConfig } from '@langchain/core/runnables'
3
3
import { BaseMessage } from '@langchain/core/messages'
4
- import { DataSource , QueryRunner } from 'typeorm'
4
+ import { DataSource } from 'typeorm'
5
5
import { CheckpointTuple , SaverOptions , SerializerProtocol } from './interface'
6
6
import { IMessage , MemoryMethods } from '../../../src/Interface'
7
7
import { mapChatMessageToBaseMessage } from '../../../src/utils'
8
8
9
9
export class MySQLSaver extends BaseCheckpointSaver implements MemoryMethods {
10
10
protected isSetup : boolean
11
-
12
- datasource : DataSource
13
-
14
- queryRunner : QueryRunner
15
-
16
11
config : SaverOptions
17
-
18
12
threadId : string
19
-
20
13
tableName = 'checkpoints'
21
14
22
15
constructor ( config : SaverOptions , serde ?: SerializerProtocol < Checkpoint > ) {
23
16
super ( serde )
24
17
this . config = config
25
- const { datasourceOptions , threadId } = config
18
+ const { threadId } = config
26
19
this . threadId = threadId
27
- this . datasource = new DataSource ( datasourceOptions )
28
20
}
29
21
30
- private async setup ( ) : Promise < void > {
31
- if ( this . isSetup ) {
32
- return
33
- }
22
+ private async getDataSource ( ) : Promise < DataSource > {
23
+ const { datasourceOptions } = this . config
24
+ const dataSource = new DataSource ( datasourceOptions )
25
+ await dataSource . initialize ( )
26
+ return dataSource
27
+ }
28
+
29
+ private async setup ( dataSource : DataSource ) : Promise < void > {
30
+ if ( this . isSetup ) return
34
31
35
32
try {
36
- const appDataSource = await this . datasource . initialize ( )
37
-
38
- this . queryRunner = appDataSource . createQueryRunner ( )
39
- await this . queryRunner . manager . query ( `
40
- CREATE TABLE IF NOT EXISTS ${ this . tableName } (
41
- thread_id VARCHAR(255) NOT NULL,
42
- checkpoint_id VARCHAR(255) NOT NULL,
43
- parent_id VARCHAR(255),
44
- checkpoint BLOB,
45
- metadata BLOB,
46
- PRIMARY KEY (thread_id, checkpoint_id)
47
- );` )
33
+ const queryRunner = dataSource . createQueryRunner ( )
34
+ await queryRunner . manager . query ( `
35
+ CREATE TABLE IF NOT EXISTS ${ this . tableName } (
36
+ thread_id VARCHAR(255) NOT NULL,
37
+ checkpoint_id VARCHAR(255) NOT NULL,
38
+ parent_id VARCHAR(255),
39
+ checkpoint BLOB,
40
+ metadata BLOB,
41
+ PRIMARY KEY (thread_id, checkpoint_id)
42
+ );` )
43
+ await queryRunner . release ( )
48
44
} catch ( error ) {
49
45
console . error ( `Error creating ${ this . tableName } table` , error )
50
46
throw new Error ( `Error creating ${ this . tableName } table` )
@@ -54,79 +50,67 @@ CREATE TABLE IF NOT EXISTS ${this.tableName} (
54
50
}
55
51
56
52
async getTuple ( config : RunnableConfig ) : Promise < CheckpointTuple | undefined > {
57
- await this . setup ( )
53
+ const dataSource = await this . getDataSource ( )
54
+ await this . setup ( dataSource )
55
+
58
56
const thread_id = config . configurable ?. thread_id || this . threadId
59
57
const checkpoint_id = config . configurable ?. checkpoint_id
60
58
61
- if ( checkpoint_id ) {
62
- try {
63
- const keys = [ thread_id , checkpoint_id ]
64
- const sql = `SELECT checkpoint, parent_id, metadata FROM ${ this . tableName } WHERE thread_id = ? AND checkpoint_id = ?`
65
-
66
- const rows = await this . queryRunner . manager . query ( sql , keys )
67
-
68
- if ( rows && rows . length > 0 ) {
69
- return {
70
- config,
71
- checkpoint : ( await this . serde . parse ( rows [ 0 ] . checkpoint . toString ( ) ) ) as Checkpoint ,
72
- metadata : ( await this . serde . parse ( rows [ 0 ] . metadata . toString ( ) ) ) as CheckpointMetadata ,
73
- parentConfig : rows [ 0 ] . parent_id
74
- ? {
75
- configurable : {
76
- thread_id,
77
- checkpoint_id : rows [ 0 ] . parent_id
78
- }
79
- }
80
- : undefined
81
- }
82
- }
83
- } catch ( error ) {
84
- console . error ( `Error retrieving ${ this . tableName } ` , error )
85
- throw new Error ( `Error retrieving ${ this . tableName } ` )
86
- }
87
- } else {
88
- const keys = [ thread_id ]
89
- const sql = `SELECT thread_id, checkpoint_id, parent_id, checkpoint, metadata FROM ${ this . tableName } WHERE thread_id = ? ORDER BY checkpoint_id DESC LIMIT 1`
59
+ try {
60
+ const queryRunner = dataSource . createQueryRunner ( )
61
+ const sql = checkpoint_id
62
+ ? `SELECT checkpoint, parent_id, metadata FROM ${ this . tableName } WHERE thread_id = ? AND checkpoint_id = ?`
63
+ : `SELECT thread_id, checkpoint_id, parent_id, checkpoint, metadata FROM ${ this . tableName } WHERE thread_id = ? ORDER BY checkpoint_id DESC LIMIT 1`
90
64
91
- const rows = await this . queryRunner . manager . query ( sql , keys )
65
+ const rows = await queryRunner . manager . query ( sql , checkpoint_id ? [ thread_id , checkpoint_id ] : [ thread_id ] )
66
+ await queryRunner . release ( )
92
67
93
68
if ( rows && rows . length > 0 ) {
69
+ const row = rows [ 0 ]
94
70
return {
95
71
config : {
96
72
configurable : {
97
- thread_id : rows [ 0 ] . thread_id ,
98
- checkpoint_id : rows [ 0 ] . checkpoint_id
73
+ thread_id : row . thread_id || thread_id ,
74
+ checkpoint_id : row . checkpoint_id || checkpoint_id
99
75
}
100
76
} ,
101
- checkpoint : ( await this . serde . parse ( rows [ 0 ] . checkpoint . toString ( ) ) ) as Checkpoint ,
102
- metadata : ( await this . serde . parse ( rows [ 0 ] . metadata . toString ( ) ) ) as CheckpointMetadata ,
103
- parentConfig : rows [ 0 ] . parent_id
77
+ checkpoint : ( await this . serde . parse ( row . checkpoint . toString ( ) ) ) as Checkpoint ,
78
+ metadata : ( await this . serde . parse ( row . metadata . toString ( ) ) ) as CheckpointMetadata ,
79
+ parentConfig : row . parent_id
104
80
? {
105
81
configurable : {
106
- thread_id : rows [ 0 ] . thread_id ,
107
- checkpoint_id : rows [ 0 ] . parent_id
82
+ thread_id,
83
+ checkpoint_id : row . parent_id
108
84
}
109
85
}
110
86
: undefined
111
87
}
112
88
}
89
+ } catch ( error ) {
90
+ console . error ( `Error retrieving ${ this . tableName } ` , error )
91
+ throw new Error ( `Error retrieving ${ this . tableName } ` )
92
+ } finally {
93
+ await dataSource . destroy ( )
113
94
}
114
95
return undefined
115
96
}
116
97
117
- async * list ( config : RunnableConfig , limit ?: number , before ?: RunnableConfig ) : AsyncGenerator < CheckpointTuple > {
118
- await this . setup ( )
119
- const thread_id = config . configurable ?. thread_id || this . threadId
120
- let sql = `SELECT thread_id, checkpoint_id, parent_id, checkpoint, metadata FROM ${ this . tableName } WHERE thread_id = ? ${
121
- before ? 'AND checkpoint_id < ?' : ''
122
- } ORDER BY checkpoint_id DESC`
123
- if ( limit ) {
124
- sql += ` LIMIT ${ limit } `
125
- }
126
- const args = [ thread_id , before ?. configurable ?. checkpoint_id ] . filter ( Boolean )
127
-
98
+ async * list ( config : RunnableConfig , limit ?: number , before ?: RunnableConfig ) : AsyncGenerator < CheckpointTuple , void , unknown > {
99
+ const dataSource = await this . getDataSource ( )
100
+ await this . setup ( dataSource )
101
+ const queryRunner = dataSource . createQueryRunner ( )
128
102
try {
129
- const rows = await this . queryRunner . manager . query ( sql , args )
103
+ const threadId = config . configurable ?. thread_id || this . threadId
104
+ let sql = `SELECT thread_id, checkpoint_id, parent_id, checkpoint, metadata FROM ${ this . tableName } WHERE thread_id = ? ${
105
+ before ? 'AND checkpoint_id < ?' : ''
106
+ } ORDER BY checkpoint_id DESC`
107
+ if ( limit ) {
108
+ sql += ` LIMIT ${ limit } `
109
+ }
110
+ const args = [ threadId , before ?. configurable ?. checkpoint_id ] . filter ( Boolean )
111
+
112
+ const rows = await queryRunner . manager . query ( sql , args )
113
+ await queryRunner . release ( )
130
114
131
115
if ( rows && rows . length > 0 ) {
132
116
for ( const row of rows ) {
@@ -151,15 +135,20 @@ CREATE TABLE IF NOT EXISTS ${this.tableName} (
151
135
}
152
136
}
153
137
} catch ( error ) {
154
- console . error ( `Error listing ${ this . tableName } ` , error )
155
- throw new Error ( `Error listing ${ this . tableName } ` )
138
+ console . error ( `Error listing checkpoints` , error )
139
+ throw new Error ( `Error listing checkpoints` )
140
+ } finally {
141
+ await dataSource . destroy ( )
156
142
}
157
143
}
158
144
159
145
async put ( config : RunnableConfig , checkpoint : Checkpoint , metadata : CheckpointMetadata ) : Promise < RunnableConfig > {
160
- await this . setup ( )
146
+ const dataSource = await this . getDataSource ( )
147
+ await this . setup ( dataSource )
148
+
161
149
if ( ! config . configurable ?. checkpoint_id ) return { }
162
150
try {
151
+ const queryRunner = dataSource . createQueryRunner ( )
163
152
const row = [
164
153
config . configurable ?. thread_id || this . threadId ,
165
154
checkpoint . id ,
@@ -172,10 +161,13 @@ CREATE TABLE IF NOT EXISTS ${this.tableName} (
172
161
VALUES (?, ?, ?, ?, ?)
173
162
ON DUPLICATE KEY UPDATE checkpoint = VALUES(checkpoint), metadata = VALUES(metadata)`
174
163
175
- await this . queryRunner . manager . query ( query , row )
164
+ await queryRunner . manager . query ( query , row )
165
+ await queryRunner . release ( )
176
166
} catch ( error ) {
177
167
console . error ( 'Error saving checkpoint' , error )
178
168
throw new Error ( 'Error saving checkpoint' )
169
+ } finally {
170
+ await dataSource . destroy ( )
179
171
}
180
172
181
173
return {
@@ -187,16 +179,20 @@ CREATE TABLE IF NOT EXISTS ${this.tableName} (
187
179
}
188
180
189
181
async delete ( threadId : string ) : Promise < void > {
190
- if ( ! threadId ) {
191
- return
192
- }
193
- await this . setup ( )
194
- const query = `DELETE FROM ${ this . tableName } WHERE thread_id = ?;`
182
+ if ( ! threadId ) return
183
+
184
+ const dataSource = await this . getDataSource ( )
185
+ await this . setup ( dataSource )
195
186
196
187
try {
197
- await this . queryRunner . manager . query ( query , [ threadId ] )
188
+ const queryRunner = dataSource . createQueryRunner ( )
189
+ const query = `DELETE FROM ${ this . tableName } WHERE thread_id = ?;`
190
+ await queryRunner . manager . query ( query , [ threadId ] )
191
+ await queryRunner . release ( )
198
192
} catch ( error ) {
199
193
console . error ( `Error deleting thread_id ${ threadId } ` , error )
194
+ } finally {
195
+ await dataSource . destroy ( )
200
196
}
201
197
}
202
198
@@ -232,6 +228,7 @@ CREATE TABLE IF NOT EXISTS ${this.tableName} (
232
228
type : m . role
233
229
} )
234
230
}
231
+
235
232
return returnIMessages
236
233
}
237
234
@@ -240,6 +237,7 @@ CREATE TABLE IF NOT EXISTS ${this.tableName} (
240
237
}
241
238
242
239
async clearChatMessages ( overrideSessionId = '' ) : Promise < void > {
240
+ if ( ! overrideSessionId ) return
243
241
await this . delete ( overrideSessionId )
244
242
}
245
243
}
0 commit comments