Skip to content

Commit 28a3bd1

Browse files
committed
Making some changes to improve speed (some ideas taken from shogun's implementation)
1 parent 56816fe commit 28a3bd1

File tree

1 file changed

+15
-28
lines changed

1 file changed

+15
-28
lines changed

string_kernel.py

+15-28
Original file line numberDiff line numberDiff line change
@@ -3,46 +3,33 @@
33
# Kernel defined by Lodhi et al. (2002)
44
def ssk(s, t, n, lbda, accum=False):
    """Compute the string subsequence kernel (SSK) of Lodhi et al. (2002).

    The kernel counts the (possibly non-contiguous) subsequences shared by
    ``s`` and ``t``, weighting every occurrence by ``lbda`` raised to the
    length of the span it covers in each string.

    Args:
        s: First sequence; anything supporting ``len`` and integer indexing
            whose elements can be compared with ``==`` (e.g. a ``str``).
        t: Second sequence, same requirements as ``s``.
        n: Subsequence length (positive integer).
        lbda: Decay factor applied per position spanned by a subsequence.
        accum: If true, return the sum of the kernels for every subsequence
            length ``1..n``; otherwise return only the kernel for length
            ``n``.

    Returns:
        The kernel value as a float (``0.`` when nothing matches).

    Raises:
        ValueError: If ``n`` is smaller than 1.
    """
    if n < 1:
        raise ValueError("n must be a positive integer")

    lens, lent = len(s), len(t)

    # k_prim[i, a, b] is the auxiliary kernel K'_i evaluated on the prefixes
    # s[:a] and t[:b].  K'_0 is 1 everywhere; for i >= 1 every entry whose
    # prefix is shorter than i stays at 0.
    k_prim = np.zeros((n, lens, lent))
    k_prim[0, :, :] = 1

    for i in range(1, n):
        for sj in range(i, lens):
            # Running value of K''_i for the fixed prefix s[:sj] while the
            # prefix of t grows.  Carrying it along avoids rescanning t for
            # every cell (the trick used by Shogun's SSK implementation).
            k_second = 0.
            for tk in range(i, lent):
                if s[sj-1] == t[tk-1]:
                    k_second = lbda * (k_second + lbda * k_prim[i-1, sj-1, tk-1])
                else:
                    k_second *= lbda
                k_prim[i, sj, tk] = k_second + lbda * k_prim[i, sj-1, tk]

    # Layer i of k_prim yields the kernel for subsequences of length i + 1,
    # so without accumulation only the last layer contributes.
    start = 0 if accum else n - 1
    lbda_sq = lbda * lbda
    total = 0.
    for i in range(start, n):
        for sj in range(i, lens):
            for tk in range(i, lent):
                if s[sj] == t[tk]:
                    total += lbda_sq * k_prim[i, sj, tk]

    return total
4633

4734
def string_kernel(xs, ys, n, lbda):
4835
if len(xs.shape) != 2 or len(ys.shape) != 2 or xs.shape[1] != 1 or ys.shape[1] != 1:

0 commit comments

Comments
 (0)