Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Update lip reading example #13647

Merged
merged 48 commits into from
Feb 13, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
48 commits
Select commit Hold shift + click to select a range
4ce3c9d
update lipnet
Dec 14, 2018
a2d237c
update utils
soeque1 Dec 14, 2018
c6007ea
Update example/gluon/lipnet/README.md
aaronmarkham Dec 27, 2018
6cd8667
Update example/gluon/lipnet/README.md
aaronmarkham Dec 27, 2018
a0071d5
Update example/gluon/lipnet/utils/multi.py
aaronmarkham Dec 27, 2018
5f78f05
Update example/gluon/lipnet/utils/preprocess_data.py
aaronmarkham Dec 27, 2018
089455d
Update example/gluon/lipnet/utils/multi.py
aaronmarkham Dec 27, 2018
ab79109
Update example/gluon/lipnet/utils/download_data.py
aaronmarkham Dec 27, 2018
9f10967
fix error for using gpu mode
seujung Dec 28, 2018
4aa4640
Add requirements
soeque1 Dec 28, 2018
c5503d9
Remove unnecessary requirements
soeque1 Dec 28, 2018
efe6295
Update .gitignore
soeque1 Dec 28, 2018
a958ad9
Remove inappropriate license file
soeque1 Dec 28, 2018
3e8a709
Changed relative path
soeque1 Dec 31, 2018
4e7ba27
Fix description
soeque1 Dec 31, 2018
b8fbb26
Fix description
soeque1 Dec 31, 2018
ac509a5
Fix description
soeque1 Dec 31, 2018
ddeb117
Fix description
soeque1 Dec 31, 2018
271f3ac
Change doc strings and add url reference
soeque1 Dec 31, 2018
2ba0b90
Fix align_path
soeque1 Dec 31, 2018
71d779d
Remove zip files
soeque1 Dec 31, 2018
a9da0e0
Fix bugs: source_path, n_process
soeque1 Dec 31, 2018
c003210
Fix target_path
soeque1 Dec 31, 2018
e2f1b42
Fix exception handler and resume the preprocess
soeque1 Jan 1, 2019
81b0185
Pass the output when it fails to detect the mouth
soeque1 Jan 3, 2019
54afdc5
Add exception during collecting images
soeque1 Jan 3, 2019
39d3378
Add the disk space and fix default align_path
soeque1 Jan 3, 2019
fcf5251
Change optimizer
soeque1 Jan 3, 2019
22afc90
Update readme for pip
soeque1 Jan 3, 2019
8e0d34b
Update README
soeque1 Jan 4, 2019
7a1bffc
Add checkpoint folder
soeque1 Jan 5, 2019
9bf3483
Apply to train using multiprocess
soeque1 Jan 8, 2019
37a0759
update network.py
seujung Jan 10, 2019
49c0861
Update readme
soeque1 Jan 10, 2019
f2b60f5
Add test code for beamsearch
soeque1 Jan 10, 2019
b3804e6
add space
Jan 23, 2019
7d6900d
delete line and fix code
Jan 23, 2019
0ad9d29
Add shebang in BeamSearch
soeque1 Jan 23, 2019
bf550fd
Fix trainer
soeque1 Jan 23, 2019
8a42b00
Fix trainer
soeque1 Jan 24, 2019
f487255
Hybridize lip model
soeque1 Jan 24, 2019
a18a96b
Fix the shape of model
soeque1 Jan 25, 2019
66c1b94
Apply to split train and validation
soeque1 Jan 25, 2019
05009c8
Add images
soeque1 Jan 25, 2019
b2f8d51
Update readme
soeque1 Jan 25, 2019
97dbcde
Fix typo and pylint
soeque1 Jan 25, 2019
ed3e4c1
Fix loss digits of save_file and typo
soeque1 Jan 27, 2019
de1eb6b
Add info of data split and batch size
soeque1 Jan 27, 2019
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions example/gluon/lipnet/.gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
__pycache__/
utils/*.dat

170 changes: 170 additions & 0 deletions example/gluon/lipnet/BeamSearch.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,170 @@
#!/usr/bin/env python3

# Licensed to the Apache Software Foundation (ASF) under one
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can we add shebang with python version? Preferably python3? thanks.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Added. Thanks!

# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

"""
Module : this module to decode using beam search
https://github.com/ThomasDelteil/HandwrittenTextRecognition_MXNet/blob/master/utils/CTCDecoder/BeamSearch.py
"""

from __future__ import division
from __future__ import print_function
import numpy as np

class BeamEntry:
"""
information about one single beam at specific time-step
"""
def __init__(self):
self.prTotal = 0 # blank and non-blank
self.prNonBlank = 0 # non-blank
self.prBlank = 0 # blank
self.prText = 1 # LM score
self.lmApplied = False # flag if LM was already applied to this beam
self.labeling = () # beam-labeling

class BeamState:
"""
information about the beams at specific time-step
"""
def __init__(self):
self.entries = {}

def norm(self):
"""
length-normalise LM score
"""
for (k, _) in self.entries.items():
labelingLen = len(self.entries[k].labeling)
self.entries[k].prText = self.entries[k].prText ** (1.0 / (labelingLen if labelingLen else 1.0))

def sort(self):
"""
return beam-labelings, sorted by probability
"""
beams = [v for (_, v) in self.entries.items()]
sortedBeams = sorted(beams, reverse=True, key=lambda x: x.prTotal*x.prText)
return [x.labeling for x in sortedBeams]

def applyLM(parentBeam, childBeam, classes, lm):
"""
calculate LM score of child beam by taking score from parent beam and bigram probability of last two chars
"""
if lm and not childBeam.lmApplied:
c1 = classes[parentBeam.labeling[-1] if parentBeam.labeling else classes.index(' ')] # first char
c2 = classes[childBeam.labeling[-1]] # second char
lmFactor = 0.01 # influence of language model
bigramProb = lm.getCharBigram(c1, c2) ** lmFactor # probability of seeing first and second char next to each other
childBeam.prText = parentBeam.prText * bigramProb # probability of char sequence
childBeam.lmApplied = True # only apply LM once per beam entry

def addBeam(beamState, labeling):
"""
add beam if it does not yet exist
"""
if labeling not in beamState.entries:
beamState.entries[labeling] = BeamEntry()

def ctcBeamSearch(mat, classes, lm, k, beamWidth):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Any chance you could add a quick unit test for this function? It looks complex, and could very easily contain a bug.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Added

"""
beam search as described by the paper of Hwang et al. and the paper of Graves et al.
"""

blankIdx = len(classes)
maxT, maxC = mat.shape

# initialise beam state
last = BeamState()
labeling = ()
last.entries[labeling] = BeamEntry()
last.entries[labeling].prBlank = 1
last.entries[labeling].prTotal = 1

# go over all time-steps
for t in range(maxT):
curr = BeamState()

# get beam-labelings of best beams
bestLabelings = last.sort()[0:beamWidth]

# go over best beams
for labeling in bestLabelings:

# probability of paths ending with a non-blank
prNonBlank = 0
# in case of non-empty beam
if labeling:
# probability of paths with repeated last char at the end
try:
prNonBlank = last.entries[labeling].prNonBlank * mat[t, labeling[-1]]
except FloatingPointError:
prNonBlank = 0

# probability of paths ending with a blank
prBlank = (last.entries[labeling].prTotal) * mat[t, blankIdx]

# add beam at current time-step if needed
addBeam(curr, labeling)

# fill in data
curr.entries[labeling].labeling = labeling
curr.entries[labeling].prNonBlank += prNonBlank
curr.entries[labeling].prBlank += prBlank
curr.entries[labeling].prTotal += prBlank + prNonBlank
curr.entries[labeling].prText = last.entries[labeling].prText # beam-labeling not changed, therefore also LM score unchanged from
curr.entries[labeling].lmApplied = True # LM already applied at previous time-step for this beam-labeling

# extend current beam-labeling
for c in range(maxC - 1):
# add new char to current beam-labeling
newLabeling = labeling + (c,)

# if new labeling contains duplicate char at the end, only consider paths ending with a blank
if labeling and labeling[-1] == c:
prNonBlank = mat[t, c] * last.entries[labeling].prBlank
else:
prNonBlank = mat[t, c] * last.entries[labeling].prTotal

# add beam at current time-step if needed
addBeam(curr, newLabeling)

# fill in data
curr.entries[newLabeling].labeling = newLabeling
curr.entries[newLabeling].prNonBlank += prNonBlank
curr.entries[newLabeling].prTotal += prNonBlank

# apply LM
applyLM(curr.entries[labeling], curr.entries[newLabeling], classes, lm)

# set new beam state
last = curr

# normalise LM scores according to beam-labeling-length
last.norm()

# sort by probability
bestLabelings = last.sort()[:k] # get most probable labeling

output = []
for bestLabeling in bestLabelings:
# map labels to chars
res = ''
for l in bestLabeling:
res += classes[l]
output.append(res)
return output
194 changes: 194 additions & 0 deletions example/gluon/lipnet/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,194 @@
# LipNet: End-to-End Sentence-level Lipreading
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
# LipNet: End-to-End Sentence-level Lipreading
<!---
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. See accompanying LICENSE file.
-->
# LipNet: End-to-End Sentence-level Lipreading

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

License isn't required on readme files. @szha if you feel strongly about adding it, I'm going to modify the readme in another PR later today and I can add it then.


---

Gluon implementation of [LipNet: End-to-End Sentence-level Lipreading](https://arxiv.org/abs/1611.01599)

![net_structure](asset/network_structure.png)

## Requirements
- Python 3.6.4
- MXnet 1.3.0
- The Required Disk Space: 35Gb
```
pip install -r requirements.txt
```

aaronmarkham marked this conversation as resolved.
Show resolved Hide resolved
---

## The Data
- The GRID audiovisual sentence corpus (http://spandh.dcs.shef.ac.uk/gridcorpus/)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Might be nice to add the description from the website here:

GRID is a large multitalker audiovisual sentence corpus to support joint computational-behavioral studies in speech perception. In brief, the corpus consists of high-quality audio and video (facial) recordings of 1000 sentences spoken by each of 34 talkers (18 male, 16 female). Sentences are of the form "put red at G9 now". The corpus, together with transcriptions, is freely available for research use.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Updated

- GRID is a large multitalker audiovisual sentence corpus to support joint computational-behavioral studies in speech perception. In brief, the corpus consists of high-quality audio and video (facial) recordings of 1000 sentences spoken by each of 34 talkers (18 male, 16 female). Sentences are of the form "put red at G9 now". The corpus, together with transcriptions, is freely available for research use.
- Video: (normal)(480 M each)
- Each movie has one sentence consist of 6 words.
- Align: word alignments(190 K each)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

One sentence explaining 'word alignments' would be really useful for people new to the domain.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Updated

- One align has 6 words. Each word has start time and end time. But this tutorial needs just sentence because of using ctc-loss.

---

## Prepare the Data
### (1) Download the data
- Outputs
- The Total Moives(mp4): 16GB
- The Total Aligns(text): 134MB
- Arguments
- src_path : Path for videos (default='./data/mp4s/')
- align_path : Path for aligns (default='./data/')
- n_process : num of process (default=1)

```
cd ./utils && python download_data.py --n_process=$(nproc)
```

### (2) Preprocess the Data: Extracting the mouth images from a video and save it.

* Using Face Landmark Detection(http://dlib.net/)

#### Preprocess (preprocess_data.py)
* If there is no landmark, it download automatically.
* Using Face Landmark Detection, It extract the mouth from a video.

- example:
- video: ./data/mp4s/s2/bbbf7p.mpg
- align(target): ./data/align/s2/bbbf7p.align
: 'sil bin blue by f seven please sil'


- Video to the images (75 Frames)

Frame 0 | Frame 1 | ... | Frame 74 |
:-------------------------:|:-------------------------:|:-------------------------:|:-------------------------:
![](asset/s2_bbbf7p_000.png) | ![](asset/s2_bbbf7p_001.png) | ... | ![](asset/s2_bbbf7p_074.png)

- Extract the mouth from images

Frame 0 | Frame 1 | ... | Frame 74 |
:-------------------------:|:-------------------------:|:-------------------------:|:-------------------------:
![](asset/mouth_000.png) | ![](asset/mouth_001.png) | ... | ![](asset/mouth_074.png)

* Save the result images into tgt_path.

----

### How to run

- Arguments
- src_path : Path for videos (default='./data/mp4s/')
- tgt_path : Path for preprocessed images (default='./data/datasets/')
- n_process : num of process (default=1)

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
You can run the preprocessing with just one processor, but this will take a long time (>48 hours). To use all of the available processors, use the following command:

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Update

- Outputs
- The Total Images(png): 19GB
- Elapsed time
- About 54 Hours using 1 process
- If you use the multi-processes, you can finish the number of processes faster.
- e.g) 9 hours using 6 processes

You can run the preprocessing with just one processor, but this will take a long time (>48 hours). To use all of the available processors, use the following command:

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Would be great to add pre-processing time estimates (for specified hardware that you used) with multiple processors.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Updated

```
cd ./utils && python preprocess_data.py --n_process=$(nproc)
```

## Output: Data Structure

```
The training data folder should look like :
<train_data_root>
|--datasets
|--s1
|--bbir7s
|--mouth_000.png
|--mouth_001.png
...
|--bgaa8p
|--mouth_000.png
|--mouth_001.png
...
|--s2
...
|--align
|--bw1d8a.align
|--bggzzs.align
...

```

---

## Training

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Would be great to add training time estimates (for specified hardware that you used).

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Updated

- According to [LipNet: End-to-End Sentence-level Lipreading](https://arxiv.org/abs/1611.01599), four (S1, S2, S20, S22) of the 34 subjects are used for evaluation.
The other subjects are used for training.

- To use the multi-gpu, it is recommended to make the batch size $(num_gpus) times larger.

- e.g) 1-gpu and 128 batch_size > 2-gpus 256 batch_size


- arguments
- batch_size : Define batch size (default=64)
- epochs : Define total epochs (default=100)
- image_path : Path for lip image files (default='./data/datasets/')
- align_path : Path for align files (default='./data/align/')
- dr_rate : Dropout rate(default=0.5)
- num_gpus : Num of gpus (if num_gpus is 0, then use cpu) (default=1)
- num_workers : Num of workers when generating data (default=0)
- model_path : Path of pretrained model (defalut=None)

```
python main.py
```

---

## Test Environment
- 72 CPU cores
- 1 GPU (NVIDIA Tesla V100 SXM2 32 GB)
- 128 Batch Size

- It takes over 24 hours (60 epochs) to get some good results.

---

## Inference

- arguments
- batch_size : Define batch size (default=64)
- image_path : Path for lip image files (default='./data/datasets/')
- align_path : Path for align files (default='./data/align/')
- num_gpus : Num of gpus (if num_gpus is 0, then use cpu) (default=1)
- num_workers : Num of workers when generating data (default=0)
- data_type : 'train' or 'valid' (defalut='valid')
- model_path : Path of pretrained model (defalut=None)

```
python infer.py --model_path=$(model_path)
```


```
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Add comment about how to generate these, either notebook or main.py.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Remove iypnb and Add the infer.py file

[Target]
['lay green with a zero again',
'bin blue with r nine please',
'set blue with e five again',
'bin green by t seven soon',
'lay red at d five now',
'bin green in x eight now',
'bin blue with e one now',
'lay red at j nine now']
```

```
[Pred]
['lay green with s zero again',
'bin blue with r nine please',
'set blue with e five again',
'bin green by t seven soon',
'lay red at c five now',
'bin green in x eight now',
'bin blue with m one now',
'lay red at j nine now']
```


Binary file added example/gluon/lipnet/asset/mouth_000.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added example/gluon/lipnet/asset/mouth_001.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added example/gluon/lipnet/asset/mouth_074.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added example/gluon/lipnet/asset/s2_bbbf7p_000.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added example/gluon/lipnet/asset/s2_bbbf7p_001.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added example/gluon/lipnet/asset/s2_bbbf7p_074.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
16 changes: 16 additions & 0 deletions example/gluon/lipnet/checkpoint/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
Loading