Skip to content

Commit

Permalink
Update nl2sql_baseline.py
Browse files Browse the repository at this point in the history
  • Loading branch information
bojone authored Sep 10, 2019
1 parent 992ded0 commit 68b358f
Showing 1 changed file with 5 additions and 3 deletions.
8 changes: 5 additions & 3 deletions nl2sql_baseline.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
#! -*- coding: utf-8 -*-
# 追一科技2019年NL2SQL挑战赛的一个Baseline(个人作品,非官方发布,基于Bert)
# 比赛地址:https://tianchi.aliyun.com/competition/entrance/231716/introduction
# 目前全匹配率大概是50%左右
# 目前全匹配率大概是58%左右

import json
import uniout
Expand Down Expand Up @@ -225,9 +225,11 @@ def seq_gather(x):

x = Lambda(lambda x: K.expand_dims(x, 2))(x)
x4h = Lambda(lambda x: K.expand_dims(x, 1))(x4h)
pcsel_1 = Dense(1)(x)
pcsel_2 = Dense(1)(x4h)
pcsel_1 = Dense(256)(x)
pcsel_2 = Dense(256)(x4h)
pcsel = Lambda(lambda x: x[0] + x[1])([pcsel_1, pcsel_2])
pcsel = Activation('tanh')(pcsel)
pcsel = Dense(1)(pcsel)
pcsel = Lambda(lambda x: x[0][..., 0] - (1 - x[1]) * 1e10)([pcsel, hm])
pcsel = Activation('softmax')(pcsel)

Expand Down

0 comments on commit 68b358f

Please sign in to comment.