@@ -4,15 +4,19 @@ use super::helpers::{
4
4
remove_job_handle, schedule_job_retry, set_job_handle, startup_hook,
5
5
} ;
6
6
use super :: types:: {
7
- ClientJobsMap , EmbeddingJob , EmbeddingProcessorArgs , JobBatchingHashMap , JobEvent ,
8
- JobEventHandlersMap , JobInsertNotification , JobRunArgs , JobUpdateNotification ,
7
+ ClientJobsMap , EmbeddingProcessorArgs , JobBatchingHashMap , JobEvent , JobEventHandlersMap ,
8
+ JobInsertNotification , JobRunArgs , JobUpdateNotification ,
9
9
} ;
10
10
use crate :: daemon:: helpers:: anyhow_wrap_connection;
11
- use crate :: embeddings:: cli:: { EmbeddingArgs , EmbeddingJobType } ;
11
+ use crate :: embeddings:: cli:: { EmbeddingArgs , EmbeddingJobType , Runtime } ;
12
+ use crate :: embeddings:: core:: utils:: get_clean_model_name;
12
13
use crate :: embeddings:: get_default_batch_size;
13
14
use crate :: logger:: Logger ;
14
- use crate :: utils:: { get_common_embedding_ignore_filters, get_full_table_name, quote_ident} ;
15
+ use crate :: utils:: {
16
+ get_common_embedding_ignore_filters, get_full_table_name, quote_ident, quote_literal,
17
+ } ;
15
18
use crate :: { embeddings, types:: * } ;
19
+ use itertools:: Itertools ;
16
20
use std:: collections:: HashMap ;
17
21
use std:: ops:: Deref ;
18
22
use std:: path:: Path ;
@@ -22,8 +26,7 @@ use std::time::SystemTime;
22
26
use tokio:: fs;
23
27
use tokio:: sync:: mpsc:: { Receiver , Sender , UnboundedReceiver , UnboundedSender } ;
24
28
use tokio:: sync:: { mpsc, Mutex , RwLock } ;
25
- use tokio_postgres:: types:: ToSql ;
26
- use tokio_postgres:: { Client , NoTls } ;
29
+ use tokio_postgres:: { types:: ToSql , Client , NoTls , Row } ;
27
30
use tokio_util:: sync:: CancellationToken ;
28
31
29
32
pub const JOB_TABLE_DEFINITION : & ' static str = r#"
@@ -71,6 +74,104 @@ const EMB_USAGE_TABLE_NAME: &'static str = "embedding_usage_info";
71
74
const EMB_FAILURE_TABLE_NAME : & ' static str = "embedding_failure_info" ;
72
75
const EMB_LOCK_TABLE_NAME : & ' static str = "_lantern_emb_job_locks" ;
73
76
77
+ #[ derive( Debug , Clone ) ]
78
+ pub struct EmbeddingJob {
79
+ pub id : i32 ,
80
+ pub is_init : bool ,
81
+ pub db_uri : String ,
82
+ pub schema : String ,
83
+ pub table : String ,
84
+ pub column : String ,
85
+ pub pk : String ,
86
+ pub filter : Option < String > ,
87
+ pub label : Option < String > ,
88
+ pub job_type : EmbeddingJobType ,
89
+ pub column_type : String ,
90
+ pub out_column : String ,
91
+ pub model : String ,
92
+ pub runtime_params : String ,
93
+ pub runtime : Runtime ,
94
+ pub batch_size : Option < usize > ,
95
+ pub row_ids : Option < Vec < String > > ,
96
+ }
97
+
98
+ impl EmbeddingJob {
99
+ pub fn new ( row : Row , data_path : & str , db_uri : & str ) -> Result < EmbeddingJob , anyhow:: Error > {
100
+ let runtime = Runtime :: try_from ( row. get :: < & str , Option < & str > > ( "runtime" ) . unwrap_or ( "ort" ) ) ?;
101
+ let runtime_params = if runtime == Runtime :: Ort {
102
+ format ! ( r#"{{ "data_path": "{data_path}" }}"# )
103
+ } else {
104
+ row. get :: < & str , Option < String > > ( "runtime_params" )
105
+ . unwrap_or ( "{}" . to_owned ( ) )
106
+ } ;
107
+
108
+ let batch_size = if let Some ( batch_size) = row. get :: < & str , Option < i32 > > ( "batch_size" ) {
109
+ Some ( batch_size as usize )
110
+ } else {
111
+ None
112
+ } ;
113
+
114
+ Ok ( Self {
115
+ id : row. get :: < & str , i32 > ( "id" ) ,
116
+ pk : row. get :: < & str , String > ( "pk" ) ,
117
+ label : row. get :: < & str , Option < String > > ( "label" ) ,
118
+ db_uri : db_uri. to_owned ( ) ,
119
+ schema : row. get :: < & str , String > ( "schema" ) ,
120
+ table : row. get :: < & str , String > ( "table" ) ,
121
+ column : row. get :: < & str , String > ( "column" ) ,
122
+ out_column : row. get :: < & str , String > ( "dst_column" ) ,
123
+ model : get_clean_model_name ( row. get :: < & str , & str > ( "model" ) , runtime) ,
124
+ runtime,
125
+ runtime_params,
126
+ filter : None ,
127
+ row_ids : None ,
128
+ is_init : true ,
129
+ batch_size,
130
+ job_type : EmbeddingJobType :: try_from (
131
+ row. get :: < & str , Option < & str > > ( "job_type" )
132
+ . unwrap_or ( "embedding" ) ,
133
+ ) ?,
134
+ column_type : row
135
+ . get :: < & str , Option < String > > ( "column_type" )
136
+ . unwrap_or ( "REAL[]" . to_owned ( ) ) ,
137
+ } )
138
+ }
139
+
140
+ pub fn set_filter ( & mut self , filter : & str ) {
141
+ self . filter = Some ( filter. to_owned ( ) ) ;
142
+ }
143
+
144
+ pub fn set_is_init ( & mut self , is_init : bool ) {
145
+ self . is_init = is_init;
146
+ }
147
+
148
+ pub fn set_row_ids ( & mut self , row_ids : Vec < String > ) {
149
+ self . row_ids = Some ( row_ids) ;
150
+ }
151
+
152
+ #[ allow( dead_code) ]
153
+ pub fn set_ctid_filter ( & mut self , row_ids : & Vec < String > ) {
154
+ let row_ctids_str = row_ids
155
+ . iter ( )
156
+ . map ( |r| {
157
+ format ! (
158
+ "currtid2('{table_name}','{r}'::tid)" ,
159
+ table_name = & self . table
160
+ )
161
+ } )
162
+ . join ( "," ) ;
163
+ self . set_filter ( & format ! ( "ctid IN ({row_ctids_str})" ) ) ;
164
+ }
165
+
166
+ pub fn set_id_filter ( & mut self , row_ids : & Vec < String > ) {
167
+ let row_ctids_str = row_ids. iter ( ) . map ( |s| quote_literal ( s) ) . join ( "," ) ;
168
+ self . set_filter ( & format ! (
169
+ "id IN ({row_ctids_str}) AND {common_filter}" ,
170
+ common_filter = get_common_embedding_ignore_filters( & self . column)
171
+ ) ) ;
172
+ }
173
+ }
174
+
74
175
async fn lock_row (
75
176
client : Arc < Client > ,
76
177
lock_table_name : & str ,
0 commit comments