This repository has been archived by the owner on Nov 17, 2023. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 6.8k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* update lipnet * update utils * Update example/gluon/lipnet/README.md Co-Authored-By: seujung <[email protected]> * Update example/gluon/lipnet/README.md Co-Authored-By: seujung <[email protected]> * Update example/gluon/lipnet/utils/multi.py Co-Authored-By: seujung <[email protected]> * Update example/gluon/lipnet/utils/preprocess_data.py Co-Authored-By: seujung <[email protected]> * Update example/gluon/lipnet/utils/multi.py Co-Authored-By: seujung <[email protected]> * Update example/gluon/lipnet/utils/download_data.py Co-Authored-By: seujung <[email protected]> * fix error for using gpu mode * Add requirements * Remove unnecessary requirements * Update .gitignore * Remove inappropriate license file * Changed relative path * Fix description * Fix description * Fix description * Fix description * Change doc strings and add url reference * Fix align_path * Remove zip files * Fix bugs: source_path, n_process * Fix target_path * Fix exception handler and resume the preprocess * Pass the output when it fails to detect the mouth * Add exception during collecting images * Add the disk space and fix default align_path * Change optimizer * Update readme for pip * Update README * Add checkpoint folder * Apply to train using multiprocess * update network.py * delete batchnorm comment *fix dropout * fix loading ndarray as F * add space * Update readme * Add the info of GRID Data * Add the info of word alignments * Add total download size * Add time for preprocessing * Add test code for beamsearch * add space * delete line and fix code * Add shebang in BeamSearch * Fix trainer * Add space line * Fix appeding losses * Fix trainer * Delete debug line in data_loader * Move transpose of input into data_loader * Delete trailing-whitespace * Hybridize lip model * Hybridize model * Refactor the len of input sequence * Fix the shape of model * Apply to split train and validation * Split data into train and valid * Update Readme * Add infer.py * Remove ipynb * Apply to continual learning * Add images * Update readme * Fix typo and pylint * Fix loss digits of save_file and typo * Add info of data split and batch size
- Loading branch information
1 parent
199bc7e
commit 7ff6ad1
Showing
27 changed files
with
2,141 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
__pycache__/ | ||
utils/*.dat | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,170 @@ | ||
#!/usr/bin/env python3 | ||
|
||
# Licensed to the Apache Software Foundation (ASF) under one | ||
# or more contributor license agreements. See the NOTICE file | ||
# distributed with this work for additional information | ||
# regarding copyright ownership. The ASF licenses this file | ||
# to you under the Apache License, Version 2.0 (the | ||
# "License"); you may not use this file except in compliance | ||
# with the License. You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, | ||
# software distributed under the License is distributed on an | ||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY | ||
# KIND, either express or implied. See the License for the | ||
# specific language governing permissions and limitations | ||
# under the License. | ||
|
||
""" | ||
Module : this module to decode using beam search | ||
https://github.com/ThomasDelteil/HandwrittenTextRecognition_MXNet/blob/master/utils/CTCDecoder/BeamSearch.py | ||
""" | ||
|
||
from __future__ import division | ||
from __future__ import print_function | ||
import numpy as np | ||
|
||
class BeamEntry: | ||
""" | ||
information about one single beam at specific time-step | ||
""" | ||
def __init__(self): | ||
self.prTotal = 0 # blank and non-blank | ||
self.prNonBlank = 0 # non-blank | ||
self.prBlank = 0 # blank | ||
self.prText = 1 # LM score | ||
self.lmApplied = False # flag if LM was already applied to this beam | ||
self.labeling = () # beam-labeling | ||
|
||
class BeamState: | ||
""" | ||
information about the beams at specific time-step | ||
""" | ||
def __init__(self): | ||
self.entries = {} | ||
|
||
def norm(self): | ||
""" | ||
length-normalise LM score | ||
""" | ||
for (k, _) in self.entries.items(): | ||
labelingLen = len(self.entries[k].labeling) | ||
self.entries[k].prText = self.entries[k].prText ** (1.0 / (labelingLen if labelingLen else 1.0)) | ||
|
||
def sort(self): | ||
""" | ||
return beam-labelings, sorted by probability | ||
""" | ||
beams = [v for (_, v) in self.entries.items()] | ||
sortedBeams = sorted(beams, reverse=True, key=lambda x: x.prTotal*x.prText) | ||
return [x.labeling for x in sortedBeams] | ||
|
||
def applyLM(parentBeam, childBeam, classes, lm): | ||
""" | ||
calculate LM score of child beam by taking score from parent beam and bigram probability of last two chars | ||
""" | ||
if lm and not childBeam.lmApplied: | ||
c1 = classes[parentBeam.labeling[-1] if parentBeam.labeling else classes.index(' ')] # first char | ||
c2 = classes[childBeam.labeling[-1]] # second char | ||
lmFactor = 0.01 # influence of language model | ||
bigramProb = lm.getCharBigram(c1, c2) ** lmFactor # probability of seeing first and second char next to each other | ||
childBeam.prText = parentBeam.prText * bigramProb # probability of char sequence | ||
childBeam.lmApplied = True # only apply LM once per beam entry | ||
|
||
def addBeam(beamState, labeling): | ||
""" | ||
add beam if it does not yet exist | ||
""" | ||
if labeling not in beamState.entries: | ||
beamState.entries[labeling] = BeamEntry() | ||
|
||
def ctcBeamSearch(mat, classes, lm, k, beamWidth): | ||
""" | ||
beam search as described by the paper of Hwang et al. and the paper of Graves et al. | ||
""" | ||
|
||
blankIdx = len(classes) | ||
maxT, maxC = mat.shape | ||
|
||
# initialise beam state | ||
last = BeamState() | ||
labeling = () | ||
last.entries[labeling] = BeamEntry() | ||
last.entries[labeling].prBlank = 1 | ||
last.entries[labeling].prTotal = 1 | ||
|
||
# go over all time-steps | ||
for t in range(maxT): | ||
curr = BeamState() | ||
|
||
# get beam-labelings of best beams | ||
bestLabelings = last.sort()[0:beamWidth] | ||
|
||
# go over best beams | ||
for labeling in bestLabelings: | ||
|
||
# probability of paths ending with a non-blank | ||
prNonBlank = 0 | ||
# in case of non-empty beam | ||
if labeling: | ||
# probability of paths with repeated last char at the end | ||
try: | ||
prNonBlank = last.entries[labeling].prNonBlank * mat[t, labeling[-1]] | ||
except FloatingPointError: | ||
prNonBlank = 0 | ||
|
||
# probability of paths ending with a blank | ||
prBlank = (last.entries[labeling].prTotal) * mat[t, blankIdx] | ||
|
||
# add beam at current time-step if needed | ||
addBeam(curr, labeling) | ||
|
||
# fill in data | ||
curr.entries[labeling].labeling = labeling | ||
curr.entries[labeling].prNonBlank += prNonBlank | ||
curr.entries[labeling].prBlank += prBlank | ||
curr.entries[labeling].prTotal += prBlank + prNonBlank | ||
curr.entries[labeling].prText = last.entries[labeling].prText # beam-labeling not changed, therefore also LM score unchanged from | ||
curr.entries[labeling].lmApplied = True # LM already applied at previous time-step for this beam-labeling | ||
|
||
# extend current beam-labeling | ||
for c in range(maxC - 1): | ||
# add new char to current beam-labeling | ||
newLabeling = labeling + (c,) | ||
|
||
# if new labeling contains duplicate char at the end, only consider paths ending with a blank | ||
if labeling and labeling[-1] == c: | ||
prNonBlank = mat[t, c] * last.entries[labeling].prBlank | ||
else: | ||
prNonBlank = mat[t, c] * last.entries[labeling].prTotal | ||
|
||
# add beam at current time-step if needed | ||
addBeam(curr, newLabeling) | ||
|
||
# fill in data | ||
curr.entries[newLabeling].labeling = newLabeling | ||
curr.entries[newLabeling].prNonBlank += prNonBlank | ||
curr.entries[newLabeling].prTotal += prNonBlank | ||
|
||
# apply LM | ||
applyLM(curr.entries[labeling], curr.entries[newLabeling], classes, lm) | ||
|
||
# set new beam state | ||
last = curr | ||
|
||
# normalise LM scores according to beam-labeling-length | ||
last.norm() | ||
|
||
# sort by probability | ||
bestLabelings = last.sort()[:k] # get most probable labeling | ||
|
||
output = [] | ||
for bestLabeling in bestLabelings: | ||
# map labels to chars | ||
res = '' | ||
for l in bestLabeling: | ||
res += classes[l] | ||
output.append(res) | ||
return output |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,194 @@ | ||
# LipNet: End-to-End Sentence-level Lipreading | ||
|
||
--- | ||
|
||
Gluon implementation of [LipNet: End-to-End Sentence-level Lipreading](https://arxiv.org/abs/1611.01599) | ||
|
||
![net_structure](asset/network_structure.png) | ||
|
||
## Requirements | ||
- Python 3.6.4 | ||
- MXnet 1.3.0 | ||
- The Required Disk Space: 35Gb | ||
``` | ||
pip install -r requirements.txt | ||
``` | ||
|
||
--- | ||
|
||
## The Data | ||
- The GRID audiovisual sentence corpus (http://spandh.dcs.shef.ac.uk/gridcorpus/) | ||
- GRID is a large multitalker audiovisual sentence corpus to support joint computational-behavioral studies in speech perception. In brief, the corpus consists of high-quality audio and video (facial) recordings of 1000 sentences spoken by each of 34 talkers (18 male, 16 female). Sentences are of the form "put red at G9 now". The corpus, together with transcriptions, is freely available for research use. | ||
- Video: (normal)(480 M each) | ||
- Each movie has one sentence consist of 6 words. | ||
- Align: word alignments(190 K each) | ||
- One align has 6 words. Each word has start time and end time. But this tutorial needs just sentence because of using ctc-loss. | ||
|
||
--- | ||
|
||
## Prepare the Data | ||
### (1) Download the data | ||
- Outputs | ||
- The Total Moives(mp4): 16GB | ||
- The Total Aligns(text): 134MB | ||
- Arguments | ||
- src_path : Path for videos (default='./data/mp4s/') | ||
- align_path : Path for aligns (default='./data/') | ||
- n_process : num of process (default=1) | ||
|
||
``` | ||
cd ./utils && python download_data.py --n_process=$(nproc) | ||
``` | ||
|
||
### (2) Preprocess the Data: Extracting the mouth images from a video and save it. | ||
|
||
* Using Face Landmark Detection(http://dlib.net/) | ||
|
||
#### Preprocess (preprocess_data.py) | ||
* If there is no landmark, it download automatically. | ||
* Using Face Landmark Detection, It extract the mouth from a video. | ||
|
||
- example: | ||
- video: ./data/mp4s/s2/bbbf7p.mpg | ||
- align(target): ./data/align/s2/bbbf7p.align | ||
: 'sil bin blue by f seven please sil' | ||
|
||
|
||
- Video to the images (75 Frames) | ||
|
||
Frame 0 | Frame 1 | ... | Frame 74 | | ||
:-------------------------:|:-------------------------:|:-------------------------:|:-------------------------: | ||
![](asset/s2_bbbf7p_000.png) | ![](asset/s2_bbbf7p_001.png) | ... | ![](asset/s2_bbbf7p_074.png) | ||
|
||
- Extract the mouth from images | ||
|
||
Frame 0 | Frame 1 | ... | Frame 74 | | ||
:-------------------------:|:-------------------------:|:-------------------------:|:-------------------------: | ||
![](asset/mouth_000.png) | ![](asset/mouth_001.png) | ... | ![](asset/mouth_074.png) | ||
|
||
* Save the result images into tgt_path. | ||
|
||
---- | ||
|
||
### How to run | ||
|
||
- Arguments | ||
- src_path : Path for videos (default='./data/mp4s/') | ||
- tgt_path : Path for preprocessed images (default='./data/datasets/') | ||
- n_process : num of process (default=1) | ||
|
||
- Outputs | ||
- The Total Images(png): 19GB | ||
- Elapsed time | ||
- About 54 Hours using 1 process | ||
- If you use the multi-processes, you can finish the number of processes faster. | ||
- e.g) 9 hours using 6 processes | ||
|
||
You can run the preprocessing with just one processor, but this will take a long time (>48 hours). To use all of the available processors, use the following command: | ||
|
||
``` | ||
cd ./utils && python preprocess_data.py --n_process=$(nproc) | ||
``` | ||
|
||
## Output: Data Structure | ||
|
||
``` | ||
The training data folder should look like : | ||
<train_data_root> | ||
|--datasets | ||
|--s1 | ||
|--bbir7s | ||
|--mouth_000.png | ||
|--mouth_001.png | ||
... | ||
|--bgaa8p | ||
|--mouth_000.png | ||
|--mouth_001.png | ||
... | ||
|--s2 | ||
... | ||
|--align | ||
|--bw1d8a.align | ||
|--bggzzs.align | ||
... | ||
``` | ||
|
||
--- | ||
|
||
## Training | ||
|
||
- According to [LipNet: End-to-End Sentence-level Lipreading](https://arxiv.org/abs/1611.01599), four (S1, S2, S20, S22) of the 34 subjects are used for evaluation. | ||
The other subjects are used for training. | ||
|
||
- To use the multi-gpu, it is recommended to make the batch size $(num_gpus) times larger. | ||
|
||
- e.g) 1-gpu and 128 batch_size > 2-gpus 256 batch_size | ||
|
||
|
||
- arguments | ||
- batch_size : Define batch size (default=64) | ||
- epochs : Define total epochs (default=100) | ||
- image_path : Path for lip image files (default='./data/datasets/') | ||
- align_path : Path for align files (default='./data/align/') | ||
- dr_rate : Dropout rate(default=0.5) | ||
- num_gpus : Num of gpus (if num_gpus is 0, then use cpu) (default=1) | ||
- num_workers : Num of workers when generating data (default=0) | ||
- model_path : Path of pretrained model (defalut=None) | ||
|
||
``` | ||
python main.py | ||
``` | ||
|
||
--- | ||
|
||
## Test Environment | ||
- 72 CPU cores | ||
- 1 GPU (NVIDIA Tesla V100 SXM2 32 GB) | ||
- 128 Batch Size | ||
|
||
- It takes over 24 hours (60 epochs) to get some good results. | ||
|
||
--- | ||
|
||
## Inference | ||
|
||
- arguments | ||
- batch_size : Define batch size (default=64) | ||
- image_path : Path for lip image files (default='./data/datasets/') | ||
- align_path : Path for align files (default='./data/align/') | ||
- num_gpus : Num of gpus (if num_gpus is 0, then use cpu) (default=1) | ||
- num_workers : Num of workers when generating data (default=0) | ||
- data_type : 'train' or 'valid' (defalut='valid') | ||
- model_path : Path of pretrained model (defalut=None) | ||
|
||
``` | ||
python infer.py --model_path=$(model_path) | ||
``` | ||
|
||
|
||
``` | ||
[Target] | ||
['lay green with a zero again', | ||
'bin blue with r nine please', | ||
'set blue with e five again', | ||
'bin green by t seven soon', | ||
'lay red at d five now', | ||
'bin green in x eight now', | ||
'bin blue with e one now', | ||
'lay red at j nine now'] | ||
``` | ||
|
||
``` | ||
[Pred] | ||
['lay green with s zero again', | ||
'bin blue with r nine please', | ||
'set blue with e five again', | ||
'bin green by t seven soon', | ||
'lay red at c five now', | ||
'bin green in x eight now', | ||
'bin blue with m one now', | ||
'lay red at j nine now'] | ||
``` | ||
|
||
|
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,16 @@ | ||
# Licensed to the Apache Software Foundation (ASF) under one | ||
# or more contributor license agreements. See the NOTICE file | ||
# distributed with this work for additional information | ||
# regarding copyright ownership. The ASF licenses this file | ||
# to you under the Apache License, Version 2.0 (the | ||
# "License"); you may not use this file except in compliance | ||
# with the License. You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, | ||
# software distributed under the License is distributed on an | ||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY | ||
# KIND, either express or implied. See the License for the | ||
# specific language governing permissions and limitations | ||
# under the License. |
Oops, something went wrong.