|
| 1 | +__author__ = 'jcapde87' |
| 2 | + |
| 3 | +import sys, pickle |
| 4 | +import numpy as np |
| 5 | +from os.path import expanduser |
| 6 | + |
| 7 | +PROJECT_PATH = "/Documents/Tweet-models/Biterm/" |
| 8 | + |
| 9 | +def gibbs_sampler_LDA(It, V, B, num_topics, b, alpha=1., beta=0.1): |
| 10 | + print "Biterm model ------ " |
| 11 | + print "Corpus length: " + str(len(b)) |
| 12 | + print "Number of topics: " + str(num_topics) |
| 13 | + print "alpha: " + str(alpha) + " beta: " + str(beta) |
| 14 | + |
| 15 | + Z = np.zeros(B) |
| 16 | + Nwz = np.zeros((V, num_topics)) |
| 17 | + Nz = np.zeros(num_topics) |
| 18 | + |
| 19 | + theta = np.random.dirichlet([alpha]*num_topics, 1) |
| 20 | + for ibi, bi in enumerate(b): |
| 21 | + topics = np.random.choice(num_topics, 1, p=theta[0,:])[0] |
| 22 | + Nwz[bi[0], topics] += 1 |
| 23 | + Nwz[bi[1], topics] += 1 |
| 24 | + Nz[topics] += 1 |
| 25 | + Z[ibi] = topics |
| 26 | + |
| 27 | + for it in xrange(It): |
| 28 | + print "Iteration: " + str(it) |
| 29 | + Nzold = np.copy(Nz) |
| 30 | + for ibi, bi in enumerate(b): |
| 31 | + Nwz[bi[0], Z[ibi]] -= 1 |
| 32 | + Nwz[bi[1], Z[ibi]] -= 1 |
| 33 | + Nz[Z[ibi]] -= 1 |
| 34 | + pz = (Nz + alpha)*(Nwz[bi[0],:]+beta)*(Nwz[bi[1],:]+beta)/(Nwz.sum(axis=0)+beta*V)**2 |
| 35 | + pz = pz/pz.sum() |
| 36 | + Z[ibi] = np.random.choice(num_topics, 1, p=pz) |
| 37 | + Nwz[bi[0], Z[ibi]] += 1 |
| 38 | + Nwz[bi[1], Z[ibi]] += 1 |
| 39 | + Nz[Z[ibi]] += 1 |
| 40 | + print "Variation between iterations: " + str(np.sqrt(np.sum((Nz-Nzold)**2))) |
| 41 | + return Nz, Nwz, Z |
| 42 | + |
| 43 | +def pbd(doc,names): |
| 44 | + ret = [] |
| 45 | + retnames = [] |
| 46 | + for term1 in set(doc): |
| 47 | + cnts = 0. |
| 48 | + for term2 in doc: |
| 49 | + if term1 == term2: |
| 50 | + cnts +=1. |
| 51 | + ret.append(cnts/len(doc)) |
| 52 | + retnames.append(term1) |
| 53 | + if names: |
| 54 | + return retnames |
| 55 | + else: |
| 56 | + return ret |
| 57 | + |
| 58 | +if __name__ == "__main__": |
| 59 | + home = expanduser("~") |
| 60 | + project = home + PROJECT_PATH |
| 61 | + |
| 62 | + print "Set project directory: " + project |
| 63 | + |
| 64 | + INFILE = "14S2015_cl" |
| 65 | + |
| 66 | + file = open(project + "data/"+INFILE+".pkl", 'rb') |
| 67 | + tweets, hashtags = pickle.load(file) |
| 68 | + |
| 69 | + tweets = [tweet for tweet in tweets if len(tweet)>3] |
| 70 | + tweets = tweets |
| 71 | + N = len(tweets) |
| 72 | + dictionary = np.array(list(set([word for tweet in tweets for word in tweet]))) |
| 73 | + V = len(dictionary) |
| 74 | + alpha = 1. |
| 75 | + beta = 0.1 |
| 76 | + |
| 77 | + btmp = [[(np.where(dictionary==word1)[0][0], np.where(dictionary==word2)[0][0]) for iword1, word1 in enumerate(tweet) for iword2, word2 in enumerate(tweet) if iword1 < iword2] for tweet in tweets] |
| 78 | + |
| 79 | + aux = [] |
| 80 | + for bi in btmp: |
| 81 | + aux.extend(bi) |
| 82 | + b = aux |
| 83 | + B = len(b) |
| 84 | + bset = set(b) |
| 85 | + num_topics = 10 |
| 86 | + pbd_cts = [pbd(doc, False) for doc in btmp] |
| 87 | + pbd_names = [pbd(doc, True) for doc in btmp] |
| 88 | + |
| 89 | + Nz, Nwz, Z = gibbs_sampler_LDA(It=20, V=V, B=B, num_topics=num_topics, b=b, alpha=alpha, beta=beta) |
| 90 | + |
| 91 | + topics = [[dictionary[ident] for ident in np.argsort(-Nwz[:,k])[0:10]] for k in xrange(num_topics)] |
| 92 | + print "TOP 10 words per topic" |
| 93 | + for topic in topics: |
| 94 | + print topic |
| 95 | + print " ---- " |
| 96 | + |
| 97 | + thetaz = (Nz + alpha)/(B + num_topics*alpha) |
| 98 | + phiwz = (Nwz + beta)/np.tile((Nwz.sum(axis=0)+V*beta),(V,1)) |
| 99 | + |
| 100 | + pzb = [[list(thetaz*phiwz[term[0],:]*phiwz[term[1],:]/(thetaz*phiwz[term[0],:]*phiwz[term[1],:]).sum()) for term in set(doc)] for doc in btmp] |
| 101 | + |
| 102 | + pdz = [] |
| 103 | + for idoc, doc in enumerate(pzb): |
| 104 | + aux = 0 |
| 105 | + for iterm, term in enumerate(doc): |
| 106 | + aux += np.array(term) * pbd_cts[idoc][iterm] |
| 107 | + pdz.append(aux) |
| 108 | + |
| 109 | + pdz = np.array(pdz) |
| 110 | + |
| 111 | + topics = [[tweets[ident] for ident in np.argsort(-pdz[:,k])[0:5]] for k in xrange(num_topics)] |
| 112 | + print "TOP 5 tweets per topic" |
| 113 | + for topic in topics: |
| 114 | + for tweet in topic: |
| 115 | + print tweet |
| 116 | + print " ---- " |
0 commit comments