-
Notifications
You must be signed in to change notification settings - Fork 1
/
logsumexp1.c
65 lines (57 loc) · 1.53 KB
/
logsumexp1.c
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
#include "mex.h"
#include "matrix.h"
#include <math.h>
/* Copyright (c) 2015, Lloyd T. Elliott. */
/* Check the arguments of a logsumexp1 call before any input data is read.
 *
 * Returns 0 when the caller should go on to compute the result, and 1 when
 * the call is already finished. That happens in two cases:
 *   - an error was raised with mexErrMsgTxt. That call aborts the MEX
 *     function; the `return 1` after each one only documents intent.
 *   - the input is empty, and plhs[0] has been set to NaN.
 *
 * Accepted input: exactly one real, full (non-sparse) double array or scalar.
 * At most one output may be requested. nlhs == 0 is accepted because MATLAB
 * still provides plhs[0], so the result can be assigned to `ans`.
 */
int validate(int nlhs,
             mxArray *plhs[],
             int nrhs,
             const mxArray *prhs[]) {
    if (nrhs != 1) {
        mexErrMsgTxt("Error: logsumexp1 requires 1 argument");
        return 1;
    }
    if (nlhs > 1) {
        mexErrMsgTxt("Error: logsumexp1 requires at most 1 output argument");
        return 1;
    }
    /* mxIsDouble already excludes cell, struct and other non-numeric classes.
       Complex input is rejected because only the real part (mxGetPr) would be
       read, silently dropping the imaginary part. */
    if (!mxIsDouble(prhs[0]) || mxIsComplex(prhs[0]) || mxIsSparse(prhs[0])) {
        mexErrMsgTxt("Error: logsumexp1 argument must be real double array or scalar");
        return 1;
    }
    if (mxIsEmpty(prhs[0])) {
        /* Empty input: the result is defined as NaN and no further work is needed. */
        plhs[0] = mxCreateDoubleScalar(NAN);
        return 1;
    }
    return 0;
}
/* Numerically stable log(sum(exp(x[0..n-1]))).
 *
 * Subtracts the maximum before exponentiating so exp() cannot overflow.
 * Special values are resolved before that shift, because shifting by an
 * infinite maximum would compute Inf - Inf = NaN:
 *   - n == 0            -> NaN (same result validate() produces for empty input)
 *   - any NaN in x      -> NaN
 *   - any +Inf (no NaN) -> +Inf
 *   - every element -Inf -> -Inf
 */
static double logsumexp_values(const double *x, size_t n) {
    double max = -INFINITY;
    double y = 0.0;
    size_t i;
    if (n == 0) {
        /* Do not touch x: it may be NULL for an empty array. */
        return NAN;
    }
    for (i = 0; i < n; i++) {
        if (isnan(x[i])) {
            return NAN;
        }
        if (x[i] > max) {
            max = x[i];
        }
    }
    if (isinf(max)) {
        /* +Inf dominates the sum; all -Inf means log(0) = -Inf. */
        return max;
    }
    for (i = 0; i < n; i++) {
        y += exp(x[i] - max);
    }
    return log(y) + max;
}

/* Compute logsumexp over every element of prhs[0] and store the scalar result
 * in plhs[0]. The array's dimensions are ignored; its elements are treated as
 * one flat list.
 * Only the real part is read (mxGetPr).
 * nlhs and nrhs are unused here; argument checking is done by validate().
 * NOTE(review): assumes prhs[0] exists and is a full double array. Callers must
 * guarantee this.
 */
void logsumexp(int nlhs,
               mxArray *plhs[],
               int nrhs,
               const mxArray *prhs[]) {
    (void)nlhs;
    (void)nrhs;
    plhs[0] = mxCreateDoubleScalar(
        logsumexp_values(mxGetPr(prhs[0]), mxGetNumberOfElements(prhs[0])));
}
/* MEX gateway for logsumexp1(x): returns log(sum(exp(x(:)))) as a double scalar.
 *
 * Arguments are validated before any data is read. logsumexp() dereferences
 * prhs[0] and its data without checks, so a missing, non-double or empty
 * argument would otherwise cause an out-of-bounds or misinterpreted read.
 */
void mexFunction(int nlhs,
                 mxArray *plhs[],
                 int nrhs,
                 const mxArray *prhs[]) {
    /* MATLAB provides plhs[0] even when nlhs == 0 (the result goes to `ans`),
       so validate as if one output were requested. */
    int nout = nlhs > 0 ? nlhs : 1;
    if (validate(nout, plhs, nrhs, prhs) != 0) {
        return;
    }
    logsumexp(nlhs, plhs, nrhs, prhs);
}