@@ -12,11 +12,11 @@ import numpy as np
12
12
cimport numpy as np
13
13
14
14
# declare the interface to the C code
15
- cdef extern void c_vbgmm_fit (double * adX, int nN, int nD, int nK, int seed, int * anAssign, int nThreads)
15
+ cdef extern void c_vbgmm_fit (double * adX, int nN, int nD, int nK, int seed, int * anAssign, int nThreads, int nIter )
16
16
@ cython.boundscheck (False )
17
17
@ cython.wraparound (False )
18
18
19
- def fit (np.ndarray[double , ndim = 2 , mode = " c" ] xarray not None , nClusters , seed , threads ):
19
+ def fit (np.ndarray[double , ndim = 2 , mode = " c" ] xarray not None , nClusters , seed , threads , piter ):
20
20
"""
21
21
fit (xarray, nClusters, seed, threads)
22
22
@@ -26,18 +26,20 @@ def fit(np.ndarray[double, ndim=2, mode="c"] xarray not None, nClusters, seed, t
26
26
param: nClusters -- an int, number of start clusters
27
27
param: seed -- an int, the random seed
28
28
param: threads -- int, the number of threads to use
29
-
29
+ param: piter -- int, the number of VB iterations to use
30
30
"""
31
- cdef int nN, nD, nK, nThreads
31
+ cdef int nN, nD, nK, nThreads, nIter
32
32
33
33
nN, nD = xarray.shape[0 ], xarray.shape[1 ]
34
34
35
35
nK = nClusters
36
36
37
+ nIter = piter
38
+
37
39
nThreads = threads
38
40
39
41
cdef np.ndarray[int , ndim= 1 ,mode= " c" ] assign = np.zeros((nN), dtype = np.intc)
40
42
41
- c_vbgmm_fit (& xarray[0 ,0 ], nN, nD, nK, seed, & assign[0 ], nThreads)
43
+ c_vbgmm_fit (& xarray[0 ,0 ], nN, nD, nK, seed, & assign[0 ], nThreads, nIter )
42
44
43
45
return assign
0 commit comments