|
| 1 | +import numpy as np |
| 2 | + |
| 3 | +# Kernel defined by Lodhi et al. (2002) |
| 4 | +def ssk(s, t, n, lbda, accum=False): |
| 5 | + dynamic = {} |
| 6 | + |
| 7 | + def k_prim(s, t, i): |
| 8 | + # print( "k_prim({},{},{})".format(s, t, i) ) |
| 9 | + if i == 0: |
| 10 | + # print( "k_prim({},{},{}) => 1".format(s, t, i) ) |
| 11 | + return 1. |
| 12 | + if min(len(s), len(t)) < i: |
| 13 | + # print( "k_prim({},{},{}) => 0".format(s, t, i) ) |
| 14 | + return 0. |
| 15 | + if (s,t,i) in dynamic: |
| 16 | + return dynamic[(s,t,i)] |
| 17 | + |
| 18 | + x = s[-1] |
| 19 | + s_ = s[:-1] |
| 20 | + indices = [i for i, e in enumerate(t) if e == x] |
| 21 | + toret = lbda * k_prim(s_, t, i) \ |
| 22 | + + sum( k_prim(s_, t[:j], i-1) * (lbda**(len(t)-j+1)) for j in indices ) |
| 23 | + # print( "k_prim({},{},{}) => {}".format(s, t, i, toret) ) |
| 24 | + dynamic[(s,t,i)] = toret |
| 25 | + return toret |
| 26 | + |
| 27 | + def k(s, t, n): |
| 28 | + # print( "k({},{},{})".format(s, t, n) ) |
| 29 | + if n <= 0: |
| 30 | + raise "Error, n must be bigger than zero" |
| 31 | + if min(len(s), len(t)) < n: |
| 32 | + # print( "k({},{},{}) => 0".format(s, t, n) ) |
| 33 | + return 0. |
| 34 | + x = s[-1] |
| 35 | + s_ = s[:-1] |
| 36 | + indices = [i for i, e in enumerate(t) if e == x] |
| 37 | + toret = k(s_, t, n) \ |
| 38 | + + lbda**2 * sum( k_prim(s_, t[:j], n-1) for j in indices ) |
| 39 | + # print( "k({},{},{}) => {}".format(s, t, n, toret) ) |
| 40 | + return toret |
| 41 | + |
| 42 | + if accum: |
| 43 | + toret = sum( k(s, t, i) for i in range(1, min(n,len(s),len(t))+1) ) |
| 44 | + else: |
| 45 | + toret = k(s, t, n) |
| 46 | + |
| 47 | + # print( len(dynamic) ) |
| 48 | + return toret |
| 49 | + |
| 50 | +def string_kernel(xs, ys, n, lbda): |
| 51 | + if len(xs.shape) != 2 or len(ys.shape) != 2 or xs.shape[1] != 1 or ys.shape[1] != 1: |
| 52 | + raise "The shape of the features is wrong, it must be (n,1)" |
| 53 | + |
| 54 | + lenxs, lenys = xs.shape[0], ys.shape[0] |
| 55 | + |
| 56 | + mat = np.zeros( (lenxs, lenys) ) |
| 57 | + for i in range(lenxs): |
| 58 | + for j in range(lenys): |
| 59 | + mat[i,j] = ssk(xs[i,0], ys[j,0], n, lbda, accum=True) |
| 60 | + |
| 61 | + mat_xs = np.zeros( (lenxs, 1) ) |
| 62 | + mat_ys = np.zeros( (lenys, 1) ) |
| 63 | + |
| 64 | + for i in range(lenxs): |
| 65 | + mat_xs[i] = ssk(xs[i,0], xs[i,0], n, lbda, accum=True) |
| 66 | + for j in range(lenys): |
| 67 | + mat_ys[j] = ssk(ys[j,0], ys[j,0], n, lbda, accum=True) |
| 68 | + |
| 69 | + return np.divide(mat, np.sqrt(mat_ys.T * mat_xs)) |
| 70 | + |
| 71 | +if __name__ == '__main__': |
| 72 | + print("Testing...") |
| 73 | + lbda = .6 |
| 74 | + assert abs( ssk("cat", "cart", 4, lbda, accum=True) - (3*lbda**2 + lbda**4 + lbda**5 + 2*lbda**7) ) < 1e-6 |
| 75 | + assert ssk("science is organized knowledge", "wisdom is organized life", 4, 1, accum=True) == 20538.0 |
| 76 | + |
| 77 | + xs = np.array( ["cat", "car", "cart", "camp", "shard"] ).reshape( (5,1) ) |
| 78 | + ys = np.array( ["a", "cd"] ).reshape( (2,1) ) |
| 79 | + assert string_kernel(xs, xs, 2, 1.)[0,0] == 1. |
| 80 | + |
| 81 | + print( string_kernel(xs, ys, 2, 1.) ) |
0 commit comments