Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

支持断点继续训练,支持TensorFlow2.0,增加predict功能。 #21

Open
wants to merge 12 commits into
base: master
Choose a base branch
from

Conversation

LZY2006
Copy link

@LZY2006 LZY2006 commented Jan 10, 2021

支持断点继续训练,若未达到目标次数会按照最后一次保存的模型继续训练;若已经到达目标次数,会直接停止。

支持TF2.0,并未改动大多数代码,只是启用了TF2.0已经弃用的1.0功能,并且关闭了2.0的功能。

import tensorflow.compat.v1 as tf
tf.disable_v2_behavior()

增加了predict功能,给出start_string。

python predict.py \
  --converter_path model/torch_gen/converter.pkl \
  --checkpoint_path  model/torch_gen \
  --max_length 1500 \
  --start_string "    raise "

会输出如下结果:

    raise  ->  utized_inpu   probability: 0.6539345979690552
    raise  -> es()\r\n       probability: 0.1654084473848343
    raise  ->  pistent_and   probability: 0.07784435153007507
    raise  ->  al_module_t   probability: 0.0615621916949749
    raise  ->  Porgex(self   probability: 0.04125040024518967

另外加入了预处理好的pytorch的代码,在data/torch_code.txt中,去除了#注释,把所有字符串都替换成了"msg" 'msg' """msg""" '''msg'''的形式。

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

1 participant