@@ -85,15 +85,22 @@ WRMF = R6::R6Class(
      ...) {
      stopifnot(is.null(init) || is.matrix(init))
      solver = match.arg(solver)
-     private$non_negative = ifelse(solver == "nnls", TRUE, FALSE)
      feedback = match.arg(feedback)

      if (feedback == 'implicit') {
        # FIXME
-       # now only support bias for explicit feedback
-       with_user_item_bias = FALSE
+
+       if (solver == "conjugate_gradient" && with_user_item_bias == TRUE) {
+         msg = paste("'conjugate_gradient' is not supported for a model with",
+                     "`with_user_item_bias == TRUE`. Setting to 'cholesky'."
+         )
+         warning(msg)
+         solver = "cholesky"
+       }
        with_global_bias = FALSE
      }
+     private$non_negative = ifelse(solver == "nnls", TRUE, FALSE)
+
      if (private$non_negative && with_global_bias == TRUE) {
        logger$warn("setting `with_global_bias=FALSE` for 'nnls' solver")
        with_global_bias = FALSE
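
In effect, the new guard makes the constructor degrade gracefully instead of failing later inside the solver. A minimal sketch of the resulting behaviour (assuming the `WRMF$new()` arguments visible in this diff; the call itself is illustrative, not part of the commit):

    library(rsparse)

    # requesting CG together with per-user/item biases now warns and falls back
    model = WRMF$new(rank = 8L,
                     feedback = "implicit",
                     solver = "conjugate_gradient",
                     with_user_item_bias = TRUE)
    # expected warning: 'conjugate_gradient' is not supported for a model with
    # `with_user_item_bias == TRUE`. Setting to 'cholesky'.
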
@@ -257,9 +264,10 @@ WRMF = R6::R6Class(
        item_bias = float(n_item)
      }

-     self$global_bias = private$init_user_item_bias(c_ui, c_iu, user_bias, item_bias)
+     global_bias = private$init_user_item_bias(c_ui, c_iu, user_bias, item_bias)
      self$components[1L, ] = item_bias
      private$U[private$rank, ] = user_bias
+     if (private$with_global_bias) self$global_bias = global_bias
    } else if (private$feedback == "explicit" && private$with_global_bias) {
      self$global_bias = mean(c_ui@x)
      c_ui@x = c_ui@x - self$global_bias
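
For context, the explicit-feedback branch centers the ratings by their global mean before factorization, and the bias is added back at prediction time. A toy illustration of that centering on a sparse matrix's `@x` slot (illustration only, not part of the commit):

    library(Matrix)

    m = sparseMatrix(i = c(1, 2, 2), j = c(1, 1, 3), x = c(4, 2, 5))  # toy ratings
    global_bias = mean(m@x)   # mean over observed (non-zero) entries only: ~3.67
    m@x = m@x - global_bias   # residuals that the latent factors must explain
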
@@ -282,15 +290,23 @@ WRMF = R6::R6Class(
        cnt_u = float::fl(cnt_u)
        cnt_i = float::fl(cnt_i)
      }
+     private$cnt_u = cnt_u

      # iterate
      for (i in seq_len(n_iter)) {
+
        # solve for items
-       loss = private$solver(c_ui, private$U, self$components, TRUE, cnt_X = cnt_i)
+       loss = private$solver(c_ui, private$U, self$components,
+                             is_bias_last_row = TRUE,
+                             cnt_X = cnt_i)
+       logger$info("iter %d (items) loss = %.4f", i, loss)
+
        # solve for users
-       loss = private$solver(c_iu, self$components, private$U, FALSE, cnt_X = cnt_u)
+       loss = private$solver(c_iu, self$components, private$U,
+                             is_bias_last_row = FALSE,
+                             cnt_X = cnt_u)
+       logger$info("iter %d (users) loss = %.4f", i, loss)

-       logger$info("iter %d loss = %.4f", i, loss)
        if (loss_prev_iter / loss - 1 < convergence_tol) {
          logger$info("Converged after %d iterations", i)
          break
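
The stopping rule in the loop above is a relative-improvement test on the loss, for example:

    convergence_tol = 0.005
    loss_prev_iter = 0.500
    loss = 0.498
    loss_prev_iter / loss - 1   # ~0.004 < 0.005, so this iteration converges
    # if the loss ever increases, the ratio goes negative and also triggers the break
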
@@ -299,31 +315,28 @@ WRMF = R6::R6Class(
        loss_prev_iter = loss
      }

-     rank_ = ifelse(private$with_user_item_bias, private$rank - 1L, private$rank)
-     ridge = fl(diag(x = private$lambda, nrow = rank_, ncol = rank_))
-
-     X = if (private$with_user_item_bias) tcrossprod(self$components[-1L, ]) else self$components
-     private$XtX = tcrossprod(X) + ridge
-
      if (private$precision == "double")
        data.table::setattr(self$components, "dimnames", list(NULL, colnames(x)))
      else
        data.table::setattr(self$components@Data, "dimnames", list(NULL, colnames(x)))

-     res = t(private$U)
-     private$U = NULL

-     if (private$precision == "double")
-       setattr(res, "dimnames", list(rownames(x), NULL))
-     else
-       setattr(res@Data, "dimnames", list(rownames(x), NULL))
-     res
+     rank_ = ifelse(private$with_user_item_bias, private$rank - 1L, private$rank)
+     ridge = fl(diag(x = private$lambda, nrow = rank_, ncol = rank_))
+     XX = if (private$with_user_item_bias) self$components[-1L, , drop = FALSE] else self$components
+     private$XtX = tcrossprod(XX) + ridge
+
+     # call transform() once more so that results from transform() and
+     # fit_transform() are identical (due to avoid_cg, etc.)
+     # this adds some extra computation, but it is not a big deal
+     self$transform(x)
    },
    # project new users into latent user space - just make an ALS step given the fixed items matrix
    #' @description create user embeddings for new input
    #' @param x user-item interaction matrix
    #' @param ... not used at the moment
    transform = function(x, ...) {
+
      stopifnot(ncol(x) == ncol(self$components))
      if (private$feedback == "implicit") {
        logger$trace("WRMF$transform(): calling `RhpcBLASctl::blas_set_num_threads(1)` (to avoid thread contention)")
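
What `private$XtX` caches, roughly: the ridge-regularized Gram matrix of the item factors (with the bias row dropped when biases are enabled), so `transform()` does not have to rebuild it for every batch of new users. A plain-R sketch of the same computation, with toy sizes (not part of the commit):

    rank = 3L; n_items = 5L; lambda = 0.1
    components = matrix(rnorm(rank * n_items), nrow = rank)  # stand-in for self$components
    ridge = diag(lambda, nrow = rank, ncol = rank)
    XtX = tcrossprod(components) + ridge   # rank x rank, reused across ALS solves
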
@@ -346,7 +359,19 @@ WRMF = R6::R6Class(
        res = float(0, nrow = private$rank, ncol = nrow(x))
      }

-     loss = private$solver(t(x), self$components, res, FALSE, private$XtX, avoid_cg = TRUE)
+     if (private$with_user_item_bias) {
+       res[1, ] = if (private$precision == "double") 1.0 else float::fl(1.0)
+     }
+
+     loss = private$solver(
+       t(x),
+       self$components,
+       res,
+       is_bias_last_row = FALSE,
+       XtX = private$XtX,
+       cnt_X = private$cnt_u,
+       avoid_cg = TRUE
+     )

      res = t(res)
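
Why `res[1, ]` is pinned to 1 above: per this diff, row 1 of `self$components` stores the item biases and the last row of `private$U` stores the user biases, so the first coordinate of every user embedding must stay at exactly 1 (and the last row of the item factors at 1) for a plain dot product to add both biases exactly once. A toy dot product with that layout (illustrative numbers only):

    user = c(1, 0.5, -0.3, 0.2)   # leading 1 multiplies the item bias; 0.2 is the user bias
    item = c(0.7, 0.1, 0.4, 1)    # 0.7 is the item bias; trailing 1 multiplies the user bias
    sum(user * item)              # 0.7 + 0.05 - 0.12 + 0.2 = 0.83
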
@@ -367,6 +392,7 @@ WRMF = R6::R6Class(
    dynamic_lambda = FALSE,
    rank = NULL,
    non_negative = NULL,
+   cnt_u = NULL,
    # user factor matrix = rank * n_users
    U = NULL,
    # item factor matrix = rank * n_items
@@ -404,15 +430,15 @@ als_implicit = function(
  rank = ifelse(with_user_item_bias, nrow(X) - 1L, nrow(X))
  ridge = fl(diag(x = lambda, nrow = rank, ncol = rank))
  if (with_user_item_bias) {
-   index_row_to_discard = ifelse(is_bias_last_row, rank, 1L)
-   XtX = tcrossprod(X[-index_row_to_discard, ])
+   index_row_to_discard = ifelse(is_bias_last_row, nrow(X), 1L)
+   XX = X[-index_row_to_discard, , drop = FALSE]
  } else {
-   XtX = tcrossprod(X)
+   XX = X
  }
- XtX = XtX + ridge
+ XtX = tcrossprod(XX) + ridge
}
# Y is modified in-place
- loss = solver(x, X, Y, XtX, lambda, n_threads, solver_code, cg_steps, is_bias_last_row)
+ loss = solver(x, X, Y, XtX, lambda, n_threads, solver_code, cg_steps, with_user_item_bias, is_bias_last_row)
}
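
Two fixes land in this hunk: the discarded row is now `nrow(X)` rather than `rank` (with a bias row present, `rank` is `nrow(X) - 1L`, so the old code dropped a factor row instead of the bias row), and `drop = FALSE` keeps the slice a matrix. A small illustration of why `drop = FALSE` matters (illustration only):

    X = matrix(1:6, nrow = 2)          # 2 x 3: a bias row plus one latent factor
    dim(X[-1, ])                       # NULL - R dropped the result to a plain vector
    tcrossprod(X[-1, ])                # 3 x 3 outer product: the wrong shape entirely
    dim(X[-1, , drop = FALSE])         # 1 3  - still a matrix
    tcrossprod(X[-1, , drop = FALSE])  # 1 x 1 Gram matrix, as the ridge term expects
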

als_explicit = function(