Skip to content

Commit

Permalink
simplify psd cone projection
Browse files Browse the repository at this point in the history
  • Loading branch information
bodono committed Oct 3, 2021
1 parent 16f9eac commit e07d5ab
Show file tree
Hide file tree
Showing 2 changed files with 54 additions and 61 deletions.
2 changes: 1 addition & 1 deletion include/cones.h
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ struct SCS_CONE_WORK {
#ifdef USE_LAPACK
/* workspace for eigenvector decompositions: */
scs_float *Xs, *Z, *e, *work;
blas_int *iwork, lwork, liwork;
blas_int lwork;
#endif
};

Expand Down
113 changes: 53 additions & 60 deletions src/cones.c
Original file line number Diff line number Diff line change
Expand Up @@ -22,18 +22,15 @@
#define MAX_BOX_VAL (1e15)

#ifdef USE_LAPACK
void BLAS(syevr)(const char *jobz, const char *range, const char *uplo,
blas_int *n, scs_float *a, blas_int *lda, scs_float *vl,
scs_float *vu, blas_int *il, blas_int *iu, scs_float *abstol,
blas_int *m, scs_float *w, scs_float *z, blas_int *ldz,
blas_int *isuppz, scs_float *work, blas_int *lwork,
blas_int *iwork, blas_int *liwork, blas_int *info);
void BLAS(syr)(const char *uplo, const blas_int *n, const scs_float *alpha,
const scs_float *x, const blas_int *incx, scs_float *a,
const blas_int *lda);
void BLAS(syev)(const char *jobz, const char *uplo, blas_int *n, scs_float *a,
blas_int *lda, scs_float *w, scs_float *work, blas_int *lwork,
blas_int *info);
blas_int BLAS(syrk)(const char *uplo, const char *trans, const blas_int *n,
const blas_int *k, const scs_float *alpha,
const scs_float *a, const blas_int *lda,
const scs_float *beta, scs_float *c, const blas_int *ldc);
void BLAS(scal)(const blas_int *n, const scs_float *sa, scs_float *sx,
const blas_int *incx);
scs_float BLAS(nrm2)(const blas_int *n, scs_float *x, const blas_int *incx);
#endif

/* set the vector of rho y terms, based on scale and cones */
Expand Down Expand Up @@ -218,9 +215,6 @@ void SCS(finish_cone)(ScsConeWork *c) {
if (c->work) {
scs_free(c->work);
}
if (c->iwork) {
scs_free(c->iwork);
}
#endif
if (c->s) {
scs_free(c->s);
Expand Down Expand Up @@ -407,9 +401,7 @@ static scs_int set_up_sd_cone_work_space(ScsConeWork *c, const ScsCone *k) {
#ifdef USE_LAPACK
scs_int i;
blas_int n_max = 0;
scs_float eig_tol = 1e-8;
blas_int neg_one = -1;
blas_int m = 0;
blas_int info = 0;
scs_float wkopt = 0.0;
#if VERBOSITY > 0
Expand All @@ -426,22 +418,20 @@ static scs_int set_up_sd_cone_work_space(ScsConeWork *c, const ScsCone *k) {
c->Xs = (scs_float *)scs_calloc(n_max * n_max, sizeof(scs_float));
c->Z = (scs_float *)scs_calloc(n_max * n_max, sizeof(scs_float));
c->e = (scs_float *)scs_calloc(n_max, sizeof(scs_float));
c->liwork = 0;

BLAS(syevr)
("Vectors", "All", "Lower", &n_max, c->Xs, &n_max, SCS_NULL, SCS_NULL,
SCS_NULL, SCS_NULL, &eig_tol, &m, c->e, c->Z, &n_max, SCS_NULL, &wkopt,
&neg_one, &(c->liwork), &neg_one, &info);
/* workspace query */
BLAS(syev)
("Vectors", "Lower", &n_max, c->Xs, &n_max, SCS_NULL, &wkopt, &neg_one,
&info);

if (info != 0) {
scs_printf("FATAL: syevr failure, info = %li\n", (long)info);
scs_printf("FATAL: syev failure, info = %li\n", (long)info);
return -1;
}
c->lwork = (blas_int)(wkopt + 0.01); /* 0.01 for int casting safety */
c->lwork = (blas_int)(wkopt + 1); /* +1 for int casting safety */
c->work = (scs_float *)scs_calloc(c->lwork, sizeof(scs_float));
c->iwork = (blas_int *)scs_calloc(c->liwork, sizeof(blas_int));

if (!c->Xs || !c->Z || !c->e || !c->work || !c->iwork) {
if (!c->Xs || !c->Z || !c->e || !c->work) {
return -1;
}
return 0;
Expand Down Expand Up @@ -473,12 +463,6 @@ static scs_int project_2x2_sdc(scs_float *X) {
l1 = 0.5 * (a + d + rad);
l2 = 0.5 * (a + d - rad);

#if VERBOSITY > 0
scs_printf("2x2 SD: a = %4f, b = %4f, (X[1] = %4f, X[2] = %4f), d = %4f, "
"rad = %4f, l1 = %4f, l2 = %4f\n",
a, b, X[1], X[2], d, rad, l1, l2);
#endif

if (l2 >= 0) { /* both eigs positive already */
return 0;
}
Expand All @@ -504,28 +488,22 @@ static scs_int proj_semi_definite_cone(scs_float *X, const scs_int n,
ScsConeWork *c) {
/* project onto the positive semi-definite cone */
#ifdef USE_LAPACK
scs_int i;
blas_int one = 1;
blas_int m = 0;
scs_int i, first_idx;
blas_int nb = (blas_int)n;
blas_int ncols_z;
blas_int nb_plus_one = (blas_int)(n + 1);
blas_int cone_sz = (blas_int)(get_sd_cone_size(n));

blas_int one_int = 1;
scs_float zero = 0., one = 1.;
scs_float sqrt2 = SQRTF(2.0);
scs_float sqrt2Inv = 1.0 / sqrt2;
scs_float sqrt2_inv = 1.0 / sqrt2;
scs_float *Xs = c->Xs;
scs_float *Z = c->Z;
scs_float *e = c->e;
scs_float *work = c->work;
blas_int *iwork = c->iwork;
blas_int lwork = c->lwork;
blas_int liwork = c->liwork;

scs_float eig_tol = CONE_TOL; /* iter < 0 ? CONE_TOL : MAX(CONE_TOL, 1 /
POWF(iter + 1, CONE_RATE)); */
scs_float zero = 0.0;
blas_int info = 0;
scs_float vupper = 0.0;
scs_float sq_eig_pos;

#endif
if (n == 0) {
return 0;
Expand All @@ -540,8 +518,7 @@ static scs_int proj_semi_definite_cone(scs_float *X, const scs_int n,
}
#ifdef USE_LAPACK

memset(Xs, 0, n * n * sizeof(scs_float));
/* expand lower triangular matrix to full matrix */
/* copy lower triangular matrix into full matrix */
for (i = 0; i < n; ++i) {
memcpy(&(Xs[i * (n + 1)]), &(X[i * n - ((i - 1) * i) / 2]),
(n - i) * sizeof(scs_float));
Expand All @@ -553,29 +530,45 @@ static scs_int proj_semi_definite_cone(scs_float *X, const scs_int n,
/* scale diags by sqrt(2) */
BLAS(scal)(&nb, &sqrt2, Xs, &nb_plus_one); /* not n_squared */

/* max-eig upper bounded by frobenius norm */
/* mult by factor to make sure is upper bound */
vupper = 1.1 * sqrt2 * BLAS(nrm2)(&cone_sz, X, &one);
vupper = MAX(vupper, 0.01);
/* Solve eigenproblem, reuse workspaces */
BLAS(syevr)
("Vectors", "VInterval", "Lower", &nb, Xs, &nb, &zero, &vupper, SCS_NULL,
SCS_NULL, &eig_tol, &m, e, Z, &nb, SCS_NULL, work, &lwork, iwork, &liwork,
&info);
BLAS(syev)("Vectors", "Lower", &nb, Xs, &nb, e, work, &lwork, &info);
if (info != 0) {
scs_printf("WARN: LAPACK syevr error, info = %i\n", info);
scs_printf("WARN: LAPACK syev error, info = %i\n", info);
}
if (info < 0) {
return -1;
}

memset(Xs, 0, n * n * sizeof(scs_float));
for (i = 0; i < m; ++i) {
scs_float a = e[i];
BLAS(syr)("Lower", &nb, &a, &(Z[i * n]), &one, Xs, &nb);
first_idx = -1;
/* e is eigvals in ascending order, find first entry > 0 */
for (i = 0; i < n; ++i) {
if (e[i] > 0) {
first_idx = i;
break;
}
}

if (first_idx == -1) {
/* there are no positive eigenvalues, set X to 0 and return */
memset(X, 0, sizeof(scs_float) * n * (n + 1) / 2);
return 0;
}
/* scale diags by 1/sqrt(2) */
BLAS(scal)(&nb, &sqrt2Inv, Xs, &nb_plus_one); /* not n_squared */

/* Z is matrix of eigenvectors with positive eigenvalues */
memcpy(Z, &Xs[first_idx * n], sizeof(scs_float) * n * (n - first_idx));

/* scale Z by sqrt(eig) */
for (i = first_idx; i < n; ++i) {
sq_eig_pos = SQRTF(e[i]);
BLAS(scal)(&nb, &sq_eig_pos, &Z[(i - first_idx) * n], &one_int);
}

/* Xs = Z Z' = V E V' */
ncols_z = (blas_int)(n - first_idx);
BLAS(syrk)("Lower", "NoTrans", &nb, &ncols_z, &one, Z, &nb, &zero, Xs, &nb);

/* undo rescaling: scale diags by 1/sqrt(2) */
BLAS(scal)(&nb, &sqrt2_inv, Xs, &nb_plus_one); /* not n_squared */
/* extract just lower triangular matrix */
for (i = 0; i < n; ++i) {
memcpy(&(X[i * n - ((i - 1) * i) / 2]), &(Xs[i * (n + 1)]),
Expand Down

0 comments on commit e07d5ab

Please sign in to comment.