|
3 | 3 | # Kernel defined by Lodhi et al. (2002)
|
4 | 4 | def ssk(s, t, n, lbda, accum=False):
|
5 | 5 | lens, lent = len(s), len(t)
|
6 |
| - #dynamic = (-1)*np.ones( (n+1, lens, lent) ) |
| 6 | + #k_prim = (-1)*np.ones( (n+1, lens, lent) ) |
7 | 7 | k_prim = np.zeros( (n, lens, lent) )
|
8 | 8 | indices = { x : [i for i, e in enumerate(t) if e == x] for x in set(s) }
|
9 | 9 |
|
10 | 10 | k_prim[0,:,:] = 1
|
11 | 11 |
|
12 | 12 | for i in range(1,n):
|
13 | 13 | for sj in range(i,lens):
|
| 14 | + toret = 0. |
14 | 15 | for tk in range(i,lent):
|
15 |
| - x = s[sj-1] |
16 |
| - toret = lbda * k_prim[i, sj-1, tk] |
17 |
| - for k_ in indices[x]: |
18 |
| - if k_ >= tk: |
19 |
| - break |
20 |
| - toret += k_prim[i-1, sj-1, k_] * (lbda**(tk-k_+1)) |
21 |
| - k_prim[i,sj,tk] = toret |
| 16 | + if s[sj-1]==t[tk-1]: # trick taken from shogun implemantion of SSK |
| 17 | + toret = lbda * (toret + lbda*k_prim[i-1,sj-1,tk-1]) |
| 18 | + else: |
| 19 | + toret *= lbda |
| 20 | + k_prim[i,sj,tk] = toret + lbda * k_prim[i, sj-1, tk] |
22 | 21 |
|
23 |
| - def k(sj, tk, n): |
24 |
| - # print( "k({},{},{})".format(s, t, n) ) |
25 |
| - if n <= 0: |
26 |
| - raise "Error, n must be bigger than zero" |
27 |
| - if min(sj, tk) < n: |
28 |
| - # print( "k({},{},{}) => 0".format(s, t, n) ) |
29 |
| - return 0. |
30 |
| - x = s[sj-1] |
31 |
| - toret = k(sj-1, tk, n) |
32 |
| - for k_ in indices[x]: |
33 |
| - if k_ >= tk: |
34 |
| - break |
35 |
| - toret += lbda**2 * k_prim[n-1, sj-1, k_] |
36 |
| - # print( "k({},{},{}) => {}".format(s, t, n, toret) ) |
37 |
| - return toret |
38 | 22 |
|
39 |
| - if accum: |
40 |
| - toret = sum( k(lens, lent, i) for i in range(1, min(n,lens,lent)+1) ) |
41 |
| - else: |
42 |
| - toret = k(lens, lent, n) |
| 23 | + start = 0 if accum else n-1 |
| 24 | + k = 0. |
| 25 | + for i in range(n): |
| 26 | + for sj in range(i,lens): |
| 27 | + for tk in range(i,lent): |
| 28 | + if s[sj]==t[tk]: |
| 29 | + k += lbda*lbda*k_prim[i,sj,tk] |
43 | 30 |
|
44 | 31 | # print( [len(list(i for (sj,tk,i) in k_prim if i==m-1)) for m in range(n)] )
|
45 |
| - return toret |
| 32 | + return k |
46 | 33 |
|
47 | 34 | def string_kernel(xs, ys, n, lbda):
|
48 | 35 | if len(xs.shape) != 2 or len(ys.shape) != 2 or xs.shape[1] != 1 or ys.shape[1] != 1:
|
|
0 commit comments