diff --git a/README.md b/README.md index 33b0f592e..f05de6cdc 100644 --- a/README.md +++ b/README.md @@ -27,13 +27,13 @@ pip install git+https://github.com/wenet-e2e/wenet.git ``` -Command-line usage(use `-h` for parameters): +**Command-line usage** (use `-h` for parameters): ``` sh wenet --language chinese audio.wav ``` -Python programming usage: +**Python programming usage**: ``` python import wenet @@ -43,6 +43,8 @@ result = model.transcribe('audio.wav') print(result['text']) ``` +Please refer to [Python usage](docs/python_package.md) for more command-line and Python programming usage. + ### Install for training & deployment - Clone the repo diff --git a/docs/index.rst b/docs/index.rst index 007d85ad9..2d712cf73 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -13,6 +13,7 @@ wenet is an tansformer-based end-to-end ASR toolkit. :maxdepth: 2 :caption: Contents: + ./python_package.md ./train.rst ./production.rst ./reference.rst diff --git a/docs/python_package.md b/docs/python_package.md new file mode 100644 index 000000000..e01e4e822 --- /dev/null +++ b/docs/python_package.md @@ -0,0 +1,32 @@ +# Python Package + + +## Install + +``` sh +pip install git+https://github.com/wenet-e2e/wenet.git +``` + +## Command line Usage + +``` sh +wenet --language chinese audio.wav +``` + +You can specify the following parameters. + +* `-l` or `--language`: chinese/english are supported now. +* `-m` or `--model_dir`: your own model dir +* `-t` or `--show_tokens_info`: show the token level information such as timestamp, confidence, etc. 
+ + +## Python Programming Usage + +``` python +import wenet + +model = wenet.load_model('chinese') +# or model = wenet.load_model(model_dir='xxx') +result = model.transcribe('audio.wav') +print(result['text']) +``` diff --git a/wenet/cli/model.py b/wenet/cli/model.py index 53ab3f069..57bb88a20 100644 --- a/wenet/cli/model.py +++ b/wenet/cli/model.py @@ -26,8 +26,7 @@ class Model: - def __init__(self, language: str): - model_dir = Hub.get_model_by_lang(language) + def __init__(self, model_dir: str): model_path = os.path.join(model_dir, 'final.zip') units_path = os.path.join(model_dir, 'units.txt') self.model = torch.jit.load(model_path) @@ -74,5 +73,7 @@ def transcribe(self, audio_file: str, tokens_info: bool = False): return result -def load_model(language: str) -> Model: - return Model(language) +def load_model(language: str = None, model_dir: str = None) -> Model: + if model_dir is None: + model_dir = Hub.get_model_by_lang(language) + return Model(model_dir) diff --git a/wenet/cli/transcribe.py b/wenet/cli/transcribe.py index 3d517c401..b2efef21a 100644 --- a/wenet/cli/transcribe.py +++ b/wenet/cli/transcribe.py @@ -14,19 +14,24 @@ import argparse -from wenet.cli.model import Model +from wenet.cli.model import load_model def get_args(): parser = argparse.ArgumentParser(description='') parser.add_argument('audio_file', help='audio file to transcribe') - parser.add_argument('--language', + parser.add_argument('-l', + '--language', choices=[ 'chinese', 'english', ], default='chinese', help='language type') + parser.add_argument('-m', + '--model_dir', + default=None, + help='specify your own model dir') parser.add_argument('-t', '--show_tokens_info', action='store_true', @@ -38,7 +43,7 @@ def get_args(): def main(): args = get_args() - model = Model(args.language) + model = load_model(args.language, args.model_dir) result = model.transcribe(args.audio_file, args.show_tokens_info) print(result)