Skip to content

Commit f416c13

Browse files
authored
initial work on biases for model with implicit feedback (#54)
* initial work on biases for model with implicit feedback * more work on implicit with biases - now rhs takes into account unobserved interactions. * use precomputed XtX in transform * fix bug in transform when modelling with biases * patches wrathematics/float#36 * add tests for models with biases * fixed bug in transform when using dynamic lambda * added extra transform() call in fit_transform() to ensure results from transform() and fit_transform() are identical
1 parent 9c8b2c5 commit f416c13

File tree

5 files changed

+225
-61
lines changed

5 files changed

+225
-61
lines changed

Diff for: R/RcppExports.R

+4-4
Original file line numberDiff line numberDiff line change
@@ -73,12 +73,12 @@ dense_csc_prod <- function(x_r, y_csc_r, num_threads = 1L) {
7373
.Call(`_rsparse_dense_csc_prod`, x_r, y_csc_r, num_threads)
7474
}
7575

76-
als_implicit_double <- function(m_csc_r, X, Y, XtX, lambda, n_threads, solver, cg_steps, is_x_bias_last_row) {
77-
.Call(`_rsparse_als_implicit_double`, m_csc_r, X, Y, XtX, lambda, n_threads, solver, cg_steps, is_x_bias_last_row)
76+
als_implicit_double <- function(m_csc_r, X, Y, XtX, lambda, n_threads, solver, cg_steps, with_biases, is_x_bias_last_row) {
77+
.Call(`_rsparse_als_implicit_double`, m_csc_r, X, Y, XtX, lambda, n_threads, solver, cg_steps, with_biases, is_x_bias_last_row)
7878
}
7979

80-
als_implicit_float <- function(m_csc_r, X_, Y_, XtX_, lambda, n_threads, solver, cg_steps, is_x_bias_last_row) {
81-
.Call(`_rsparse_als_implicit_float`, m_csc_r, X_, Y_, XtX_, lambda, n_threads, solver, cg_steps, is_x_bias_last_row)
80+
als_implicit_float <- function(m_csc_r, X_, Y_, XtX_, lambda, n_threads, solver, cg_steps, with_biases, is_x_bias_last_row) {
81+
.Call(`_rsparse_als_implicit_float`, m_csc_r, X_, Y_, XtX_, lambda, n_threads, solver, cg_steps, with_biases, is_x_bias_last_row)
8282
}
8383

8484
als_explicit_double <- function(m_csc_r, X, Y, cnt_X, lambda, n_threads, solver, cg_steps, dynamic_lambda, with_biases, is_x_bias_last_row) {

Diff for: R/model_WRMF.R

+52-26
Original file line numberDiff line numberDiff line change
@@ -85,15 +85,22 @@ WRMF = R6::R6Class(
8585
...) {
8686
stopifnot(is.null(init) || is.matrix(init))
8787
solver = match.arg(solver)
88-
private$non_negative = ifelse(solver == "nnls", TRUE, FALSE)
8988
feedback = match.arg(feedback)
9089

9190
if (feedback == 'implicit') {
9291
# FIXME
93-
# now only support bias for explicit feedback
94-
with_user_item_bias = FALSE
92+
93+
if (solver == "conjugate_gradient" && with_user_item_bias == TRUE) {
94+
msg = paste("'conjugate_gradient' is not supported for a model",
95+
"`with_user_item_bias == TRUE`. Setting to 'cholesky'."
96+
)
97+
warning(msg)
98+
solver = "cholesky"
99+
}
95100
with_global_bias = FALSE
96101
}
102+
private$non_negative = ifelse(solver == "nnls", TRUE, FALSE)
103+
97104
if (private$non_negative && with_global_bias == TRUE) {
98105
logger$warn("setting `with_global_bias=FALSE` for 'nnls' solver")
99106
with_global_bias = FALSE
@@ -257,9 +264,10 @@ WRMF = R6::R6Class(
257264
item_bias = float(n_item)
258265
}
259266

260-
self$global_bias = private$init_user_item_bias(c_ui, c_iu, user_bias, item_bias)
267+
global_bias = private$init_user_item_bias(c_ui, c_iu, user_bias, item_bias)
261268
self$components[1L, ] = item_bias
262269
private$U[private$rank, ] = user_bias
270+
if(private$with_global_bias) self$global_bias = global_bias
263271
} else if (private$feedback == "explicit" && private$with_global_bias) {
264272
self$global_bias = mean(c_ui@x)
265273
c_ui@x = c_ui@x - self$global_bias
@@ -282,15 +290,23 @@ WRMF = R6::R6Class(
282290
cnt_u = float::fl(cnt_u)
283291
cnt_i = float::fl(cnt_i)
284292
}
293+
private$cnt_u = cnt_u
285294

286295
# iterate
287296
for (i in seq_len(n_iter)) {
297+
288298
# solve for items
289-
loss = private$solver(c_ui, private$U, self$components, TRUE, cnt_X=cnt_i)
299+
loss = private$solver(c_ui, private$U, self$components,
300+
is_bias_last_row = TRUE,
301+
cnt_X = cnt_i)
302+
logger$info("iter %d (items) loss = %.4f", i, loss)
303+
290304
# solve for users
291-
loss = private$solver(c_iu, self$components, private$U, FALSE, cnt_X=cnt_u)
305+
loss = private$solver(c_iu, self$components, private$U,
306+
is_bias_last_row = FALSE,
307+
cnt_X = cnt_u)
308+
logger$info("iter %d (users) loss = %.4f", i, loss)
292309

293-
logger$info("iter %d loss = %.4f", i, loss)
294310
if (loss_prev_iter / loss - 1 < convergence_tol) {
295311
logger$info("Converged after %d iterations", i)
296312
break
@@ -299,31 +315,28 @@ WRMF = R6::R6Class(
299315
loss_prev_iter = loss
300316
}
301317

302-
rank_ = ifelse(private$with_user_item_bias, private$rank - 1L, private$rank)
303-
ridge = fl(diag(x = private$lambda, nrow = rank_, ncol = rank_))
304-
305-
X = if (private$with_user_item_bias) tcrossprod(self$components[-1L, ]) else self$components
306-
private$XtX = tcrossprod(X) + ridge
307-
308318
if (private$precision == "double")
309319
data.table::setattr(self$components, "dimnames", list(NULL, colnames(x)))
310320
else
311321
data.table::setattr(self$components@Data, "dimnames", list(NULL, colnames(x)))
312322

313-
res = t(private$U)
314-
private$U = NULL
315323

316-
if (private$precision == "double")
317-
setattr(res, "dimnames", list(rownames(x), NULL))
318-
else
319-
setattr(res@Data, "dimnames", list(rownames(x), NULL))
320-
res
324+
rank_ = ifelse(private$with_user_item_bias, private$rank - 1L, private$rank)
325+
ridge = fl(diag(x = private$lambda, nrow = rank_, ncol = rank_))
326+
XX = if (private$with_user_item_bias) self$components[-1L, , drop = FALSE] else self$components
327+
private$XtX = tcrossprod(XX) + ridge
328+
329+
# call extra transform to ensure results from transform() and fit_transform()
330+
# are the same (due to avoid_cg, etc)
331+
# this adds some extra computation, but not a big deal though
332+
self$transform(x)
321333
},
322334
# project new users into latent user space - just make ALS step given fixed items matrix
323335
#' @description create user embeddings for new input
324336
#' @param x user-item iteraction matrix
325337
#' @param ... not used at the moment
326338
transform = function(x, ...) {
339+
327340
stopifnot(ncol(x) == ncol(self$components))
328341
if (private$feedback == "implicit" ) {
329342
logger$trace("WRMF$transform(): calling `RhpcBLASctl::blas_set_num_threads(1)` (to avoid thread contention)")
@@ -346,7 +359,19 @@ WRMF = R6::R6Class(
346359
res = float(0, nrow = private$rank, ncol = nrow(x))
347360
}
348361

349-
loss = private$solver(t(x), self$components, res, FALSE, private$XtX, avoid_cg=TRUE)
362+
if (private$with_user_item_bias) {
363+
res[1, ] = if(private$precision == "double") 1.0 else float::fl(1.0)
364+
}
365+
366+
loss = private$solver(
367+
t(x),
368+
self$components,
369+
res,
370+
is_bias_last_row = FALSE,
371+
XtX = private$XtX,
372+
cnt_X = private$cnt_u,
373+
avoid_cg = TRUE
374+
)
350375

351376
res = t(res)
352377

@@ -367,6 +392,7 @@ WRMF = R6::R6Class(
367392
dynamic_lambda = FALSE,
368393
rank = NULL,
369394
non_negative = NULL,
395+
cnt_u = NULL,
370396
# user factor matrix = rank * n_users
371397
U = NULL,
372398
# item factor matrix = rank * n_items
@@ -404,15 +430,15 @@ als_implicit = function(
404430
rank = ifelse(with_user_item_bias, nrow(X) - 1L, nrow(X))
405431
ridge = fl(diag(x = lambda, nrow = rank, ncol = rank))
406432
if (with_user_item_bias) {
407-
index_row_to_discard = ifelse(is_bias_last_row, rank, 1L)
408-
XtX = tcrossprod(X[-index_row_to_discard, ])
433+
index_row_to_discard = ifelse(is_bias_last_row, nrow(X), 1L)
434+
XX = X[-index_row_to_discard, , drop = FALSE]
409435
} else {
410-
XtX = tcrossprod(X)
436+
XX = X
411437
}
412-
XtX = XtX + ridge
438+
XtX = tcrossprod(XX) + ridge
413439
}
414440
# Y is modified in-place
415-
loss = solver(x, X, Y, XtX, lambda, n_threads, solver_code, cg_steps, is_bias_last_row)
441+
loss = solver(x, X, Y, XtX, lambda, n_threads, solver_code, cg_steps, with_user_item_bias, is_bias_last_row)
416442
}
417443

418444
als_explicit = function(

Diff for: src/RcppExports.cpp

+10-8
Original file line numberDiff line numberDiff line change
@@ -246,8 +246,8 @@ BEGIN_RCPP
246246
END_RCPP
247247
}
248248
// als_implicit_double
249-
double als_implicit_double(const Rcpp::S4& m_csc_r, arma::mat& X, arma::mat& Y, const arma::mat& XtX, double lambda, unsigned n_threads, unsigned solver, unsigned cg_steps, bool is_x_bias_last_row);
250-
RcppExport SEXP _rsparse_als_implicit_double(SEXP m_csc_rSEXP, SEXP XSEXP, SEXP YSEXP, SEXP XtXSEXP, SEXP lambdaSEXP, SEXP n_threadsSEXP, SEXP solverSEXP, SEXP cg_stepsSEXP, SEXP is_x_bias_last_rowSEXP) {
249+
double als_implicit_double(const Rcpp::S4& m_csc_r, arma::mat& X, arma::mat& Y, const arma::mat& XtX, double lambda, unsigned n_threads, unsigned solver, unsigned cg_steps, const bool with_biases, bool is_x_bias_last_row);
250+
RcppExport SEXP _rsparse_als_implicit_double(SEXP m_csc_rSEXP, SEXP XSEXP, SEXP YSEXP, SEXP XtXSEXP, SEXP lambdaSEXP, SEXP n_threadsSEXP, SEXP solverSEXP, SEXP cg_stepsSEXP, SEXP with_biasesSEXP, SEXP is_x_bias_last_rowSEXP) {
251251
BEGIN_RCPP
252252
Rcpp::RObject rcpp_result_gen;
253253
Rcpp::RNGScope rcpp_rngScope_gen;
@@ -259,14 +259,15 @@ BEGIN_RCPP
259259
Rcpp::traits::input_parameter< unsigned >::type n_threads(n_threadsSEXP);
260260
Rcpp::traits::input_parameter< unsigned >::type solver(solverSEXP);
261261
Rcpp::traits::input_parameter< unsigned >::type cg_steps(cg_stepsSEXP);
262+
Rcpp::traits::input_parameter< const bool >::type with_biases(with_biasesSEXP);
262263
Rcpp::traits::input_parameter< bool >::type is_x_bias_last_row(is_x_bias_last_rowSEXP);
263-
rcpp_result_gen = Rcpp::wrap(als_implicit_double(m_csc_r, X, Y, XtX, lambda, n_threads, solver, cg_steps, is_x_bias_last_row));
264+
rcpp_result_gen = Rcpp::wrap(als_implicit_double(m_csc_r, X, Y, XtX, lambda, n_threads, solver, cg_steps, with_biases, is_x_bias_last_row));
264265
return rcpp_result_gen;
265266
END_RCPP
266267
}
267268
// als_implicit_float
268-
double als_implicit_float(const Rcpp::S4& m_csc_r, Rcpp::S4& X_, Rcpp::S4& Y_, Rcpp::S4& XtX_, double lambda, unsigned n_threads, unsigned solver, unsigned cg_steps, bool is_x_bias_last_row);
269-
RcppExport SEXP _rsparse_als_implicit_float(SEXP m_csc_rSEXP, SEXP X_SEXP, SEXP Y_SEXP, SEXP XtX_SEXP, SEXP lambdaSEXP, SEXP n_threadsSEXP, SEXP solverSEXP, SEXP cg_stepsSEXP, SEXP is_x_bias_last_rowSEXP) {
269+
double als_implicit_float(const Rcpp::S4& m_csc_r, Rcpp::S4& X_, Rcpp::S4& Y_, Rcpp::S4& XtX_, double lambda, unsigned n_threads, unsigned solver, unsigned cg_steps, const bool with_biases, bool is_x_bias_last_row);
270+
RcppExport SEXP _rsparse_als_implicit_float(SEXP m_csc_rSEXP, SEXP X_SEXP, SEXP Y_SEXP, SEXP XtX_SEXP, SEXP lambdaSEXP, SEXP n_threadsSEXP, SEXP solverSEXP, SEXP cg_stepsSEXP, SEXP with_biasesSEXP, SEXP is_x_bias_last_rowSEXP) {
270271
BEGIN_RCPP
271272
Rcpp::RObject rcpp_result_gen;
272273
Rcpp::RNGScope rcpp_rngScope_gen;
@@ -278,8 +279,9 @@ BEGIN_RCPP
278279
Rcpp::traits::input_parameter< unsigned >::type n_threads(n_threadsSEXP);
279280
Rcpp::traits::input_parameter< unsigned >::type solver(solverSEXP);
280281
Rcpp::traits::input_parameter< unsigned >::type cg_steps(cg_stepsSEXP);
282+
Rcpp::traits::input_parameter< const bool >::type with_biases(with_biasesSEXP);
281283
Rcpp::traits::input_parameter< bool >::type is_x_bias_last_row(is_x_bias_last_rowSEXP);
282-
rcpp_result_gen = Rcpp::wrap(als_implicit_float(m_csc_r, X_, Y_, XtX_, lambda, n_threads, solver, cg_steps, is_x_bias_last_row));
284+
rcpp_result_gen = Rcpp::wrap(als_implicit_float(m_csc_r, X_, Y_, XtX_, lambda, n_threads, solver, cg_steps, with_biases, is_x_bias_last_row));
283285
return rcpp_result_gen;
284286
END_RCPP
285287
}
@@ -538,8 +540,8 @@ static const R_CallMethodDef CallEntries[] = {
538540
{"_rsparse_cpp_glove_partial_fit", (DL_FUNC) &_rsparse_cpp_glove_partial_fit, 6},
539541
{"_rsparse_csr_dense_tcrossprod", (DL_FUNC) &_rsparse_csr_dense_tcrossprod, 3},
540542
{"_rsparse_dense_csc_prod", (DL_FUNC) &_rsparse_dense_csc_prod, 3},
541-
{"_rsparse_als_implicit_double", (DL_FUNC) &_rsparse_als_implicit_double, 9},
542-
{"_rsparse_als_implicit_float", (DL_FUNC) &_rsparse_als_implicit_float, 9},
543+
{"_rsparse_als_implicit_double", (DL_FUNC) &_rsparse_als_implicit_double, 10},
544+
{"_rsparse_als_implicit_float", (DL_FUNC) &_rsparse_als_implicit_float, 10},
543545
{"_rsparse_als_explicit_double", (DL_FUNC) &_rsparse_als_explicit_double, 11},
544546
{"_rsparse_als_explicit_float", (DL_FUNC) &_rsparse_als_explicit_float, 11},
545547
{"_rsparse_initialize_biases_double", (DL_FUNC) &_rsparse_initialize_biases_double, 8},

0 commit comments

Comments
 (0)