
Commit 0d89160

feat(learning-stan): add LDA model
1 parent d5a0eed commit 0d89160

2 files changed: +107 -0 lines changed

learning-pystan/lda.stan

+37
@@ -0,0 +1,37 @@
// Latent Dirichlet Allocation (LDA)

data {
  int<lower=2> n_topics;
  int<lower=2> n_vocab;
  int<lower=1> n_words;
  int<lower=1> n_docs;
  int<lower=1, upper=n_vocab> words[n_words];  // ID of the i-th word
  int<lower=1, upper=n_docs> doc_of[n_words];  // ID of the document the i-th word belongs to
  vector<lower=0>[n_topics] alpha;  // prior for the i-th topic
  vector<lower=0>[n_vocab] beta;    // prior for the word with ID i
}

parameters {
  simplex[n_topics] theta[n_docs];  // topic distribution of the i-th document
  simplex[n_vocab] phi[n_topics];   // word distribution of the i-th topic
}

model {
  // Sample the parameters from their priors
  for (i in 1:n_docs) {
    theta[i] ~ dirichlet(alpha);
  }
  for (i in 1:n_topics) {
    phi[i] ~ dirichlet(beta);
  }

  // Accumulate the log likelihood
  for (w in 1:n_words) {
    real gamma[n_topics];
    for (t in 1:n_topics) {
      // log(probability that this document's topic is t)
      //   + log(probability that word w appears under topic t)
      gamma[t] = log(theta[doc_of[w], t]) + log(phi[t, words[w]]);
    }
    target += log_sum_exp(gamma);
  }
}
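
A note on the model block above: Stan cannot sample discrete parameters, so the per-word topic assignment is marginalized out rather than drawn explicitly. With T = n_topics, each word w_i in document d_i contributes the marginal log likelihood

\log p(w_i \mid \theta, \phi) = \log \sum_{t=1}^{T} \theta_{d_i,\,t} \, \phi_{t,\,w_i}

which is exactly what target += log_sum_exp(gamma) accumulates in a numerically stable way, since gamma[t] holds the logarithm of each summand.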

learning-pystan/stan_lda.py

+70
@@ -0,0 +1,70 @@
import pystan
import numpy as np
import matplotlib.pyplot as plt


class Vocabulary:
    def __init__(self):
        self._word2id = {}
        self._id2word = {}

    def intern(self, word):
        if word in self._word2id:
            return self._word2id[word]
        new_id = len(self._word2id) + 1
        self._word2id[word] = new_id
        self._id2word[new_id] = word
        return new_id

    def word(self, wid):
        return self._id2word[wid]

    @property
    def size(self):
        return len(self._word2id)


def read_corpus(filename, max_lines):
    vocab = Vocabulary()
    vocab.intern("<unk>")

    word_ids = []
    doc_ids = []

    with open(filename) as f:
        for i, line in enumerate(f):
            if i >= max_lines:
                break
            line = line.strip()
            words = line.split(" ")
            for word in words:
                wid = vocab.intern(word)
                word_ids.append(wid)
                doc_ids.append(i + 1)

    return (word_ids, doc_ids, vocab)


def run_stan(word_ids, doc_ids, vocab, n_topics=10):
    # https://stats.stackexchange.com/questions/59684/what-are-typical-values-to-use-for-alpha-and-beta-in-latent-dirichlet-allocation
    alpha = np.full(n_topics, 50 / n_topics)
    beta = np.full(vocab.size, 0.1)

    data = {
        "n_topics": n_topics,
        "n_vocab": vocab.size,
        "n_words": len(word_ids),
        "n_docs": max(doc_ids),
        "words": word_ids,
        "doc_of": doc_ids,
        "alpha": alpha,
        "beta": beta,
    }

    with open("lda.stan", encoding="utf-8") as f:
        model_code = f.read()
    # pystan raises an exception when the model code contains non-ASCII
    # characters, so strip them out
    model_code = model_code.encode("ascii", errors="ignore").decode("ascii")

    fit = pystan.stan(model_code=model_code, data=data, iter=300)
    return fit
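
For reference, a minimal sketch of how these functions might be driven (hypothetical: the corpus path corpus.txt and the top-words loop are not part of this commit; fit.extract() is PyStan 2's standard accessor for posterior draws):

# Hypothetical driver -- the file name, max_lines, and n_topics here
# are assumptions, not part of this commit.
word_ids, doc_ids, vocab = read_corpus("corpus.txt", max_lines=100)
fit = run_stan(word_ids, doc_ids, vocab, n_topics=10)

# Posterior mean of the topic-word distributions;
# the raw draws have shape (n_draws, n_topics, n_vocab).
phi = fit.extract()["phi"].mean(axis=0)
for t in range(phi.shape[0]):
    top = phi[t].argsort()[::-1][:10]  # indices of the 10 most probable words
    # Word IDs are 1-based, so shift the 0-based numpy indices.
    print(t + 1, [vocab.word(i + 1) for i in top])

Note that word ID 1 is the <unk> placeholder interned at the start of read_corpus.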
