-
Notifications
You must be signed in to change notification settings - Fork 194
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
9 changed files
with
196 additions
and
10 deletions.
There are no files selected for viewing
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,26 @@ | ||
from rare.decoders import AttentionDecoder | ||
from rare.protos import decoder_pb2 | ||
from rare.builders import rnn_cell_builder | ||
|
||
def build(decoder_config): | ||
if not isinstance(decoder_config, decoder_pb2.Decoder): | ||
raise ValueError('decoder_config not of type ' | ||
'decoder_pb2.Decoder') | ||
decoder_oneof = decoder_config.WhichOneof('decoder_oneof') | ||
|
||
if decoder_oneof == 'attention_decoder': | ||
attention_decoder_config = decoder_config.attention_decoder | ||
|
||
rnn_cell = rnn_cell_builder.build(attention_decoder_config.rnn_cell) | ||
num_attention_units = attention_decoder_config.num_attention_units | ||
attention_conv_kernel_size = attention_decoder_config.attention_conv_kernel_size | ||
|
||
attention_decoder_object = AttentionDecoder( | ||
rnn_cell, | ||
num_attention_units, | ||
attention_conv_kernel_size | ||
) | ||
return attention_decoder_object | ||
|
||
else: | ||
raise ValueError('Unknown decoder_oneof: {}'.format(decoder_oneof)) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,27 @@ | ||
import tensorflow as tf | ||
|
||
from google.protobuf import text_format | ||
from rare.builders import decoder_builder | ||
from rare.protos import decoder_pb2 | ||
|
||
|
||
class DecoderBuilderTest(tf.test.TestCase): | ||
|
||
def test_build_decoder(self): | ||
decoder_text_proto = """ | ||
attention_decoder { | ||
rnn_cell { | ||
gru_cell { | ||
num_units: 256 | ||
} | ||
} | ||
num_attention_units: 256 | ||
attention_conv_kernel_size: 5 | ||
} | ||
""" | ||
decoder_proto = decoder_pb2.Decoder() | ||
text_format.Merge(decoder_proto, decoder_proto) | ||
decoder_object = decoder_builder.build(decoder_proto) | ||
|
||
self.assertEqual(decoder_object.num_attention_units, 256) | ||
self.assertEqual(decoder_object.attention_conv_kernel_size, 5) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,30 @@ | ||
import tensorflow as tf | ||
from rare.protos import rnn_cell_pb2 | ||
|
||
|
||
def build(rnn_cell_config): | ||
if not isinstance(rnn_cell_config, rnn_cell_pb2.RnnCell): | ||
raise ValueError('rnn_cell_config not of type ' | ||
'rnn_cell_pb2.RnnCell') | ||
rnn_cell_oneof = rnn_cell_config.WhichOneof('rnn_cell_oneof') | ||
|
||
if rnn_cell_oneof == 'lstm_cell': | ||
lstm_cell_config = rnn_cell_config.lstm_cell | ||
|
||
lstm_cell_object = tf.contrib.rnn.LSTMCell( | ||
lstm_cell_config.num_units, | ||
use_peepholes=lstm_cell_config.use_peepholes, | ||
forget_bias=lstm_cell_config.forget_bias | ||
) | ||
return lstm_cell_object | ||
|
||
elif rnn_cell_oneof == 'gru_cell': | ||
gru_cell_config = rnn_cell_config.gru_cell | ||
|
||
gru_cell_object = tf.contrib.rnn.GRUCell( | ||
gru_cell_config.num_units, | ||
) | ||
return gru_cell_object | ||
|
||
else: | ||
raise ValueError('Unknown rnn_cell_oneof: {}'.format(rnn_cell_oneof)) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,37 @@ | ||
import tensorflow as tf | ||
|
||
from google.protobuf import text_format | ||
from rare.builders import rnn_cell_builder | ||
from rare.protos import rnn_cell_pb2 | ||
|
||
|
||
class RnnCellBuilderTest(tf.test.TestCase): | ||
|
||
def test_build_lstm_cell(self): | ||
rnn_cell_text_proto = """ | ||
lstm_cell { | ||
num_units: 1024 | ||
use_peepholes: true | ||
forget_bias: 1.5 | ||
} | ||
""" | ||
rnn_cell_proto = rnn_cell_pb2.RnnCell() | ||
text_format.Merge(rnn_cell_text_proto, rnn_cell_proto) | ||
rnn_cell_object = rnn_cell_builder.build(rnn_cell_proto) | ||
|
||
self.assertEqual(rnn_cell_object.state_size, 1024) | ||
|
||
def test_build_gru_cell(self): | ||
rnn_cell_text_proto = """ | ||
gru_cell { | ||
num_units: 1024 | ||
} | ||
""" | ||
rnn_cell_proto = rnn_cell_pb2.RnnCell() | ||
text_format.Merge(rnn_cell_text_proto, rnn_cell_proto) | ||
rnn_cell_object = rnn_cell_builder.build(rnn_cell_proto) | ||
|
||
self.assertEqual(rnn_cell_object.state_size, 1024) | ||
|
||
if __name__ == '__main__': | ||
tf.test.main() |
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,16 @@ | ||
syntax = "proto2"; | ||
package rare.protos; | ||
|
||
import "rare/protos/rnn_cell.proto"; | ||
|
||
message DecoderConfig { | ||
oneof decoder_oneof { | ||
AttentionDecoder attention_decoder = 1; | ||
} | ||
} | ||
|
||
message AttentionDecoder { | ||
optional RnnCell rnn_cell = 1; | ||
optional uint32 num_attention_units = 2 [default=128]; | ||
optional uint32 attention_conv_kernel_size = 3 [default=5]; | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,19 @@ | ||
syntax = "proto2"; | ||
package rare.protos; | ||
|
||
message RnnCell { | ||
oneof rnn_cell_oneof { | ||
LstmCell lstm_cell = 1; | ||
GruCell gru_cell = 2; | ||
} | ||
} | ||
|
||
message LstmCell { | ||
optional uint32 num_units = 1 [default=128]; | ||
optional boolean use_peepholes = 2 [default=false]; | ||
optional float32 forget_bias = 3 [default=1.0]; | ||
} | ||
|
||
message GruCell { | ||
optional uint32 num_units = 1 [default=128]; | ||
} |