Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
143 changes: 143 additions & 0 deletions egs/cifar/v1/image/get_allowed_lengths.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,143 @@
#!/usr/bin/env python3

# Copyright 2017 Hossein Hadian
# Apache 2.0


""" This script finds a set of allowed lengths for a given OCR/HWR data dir.
The allowed lengths are spaced by a factor (like 10%) and are written
in an output file named "allowed_lengths.txt" in the output data dir. This
file is later used by make_features.py to pad each image sufficiently so that
they all have an allowed length. This is intended for end2end chain training.
"""

import argparse
import os
import sys
import copy
import math
import logging

sys.path.insert(0, 'steps')
import libs.common as common_lib

logger = logging.getLogger('libs')
logger.setLevel(logging.INFO)
handler = logging.StreamHandler()
handler.setLevel(logging.INFO)
formatter = logging.Formatter("%(asctime)s [%(pathname)s:%(lineno)s - "
"%(funcName)s - %(levelname)s ] %(message)s")
handler.setFormatter(formatter)
logger.addHandler(handler)

def get_args():
parser = argparse.ArgumentParser(description="""This script finds a set of
allowed lengths for a given OCR/HWR data dir.
Intended for chain training.""")
parser.add_argument('factor', type=float, default=12,
help='Spacing (in percentage) between allowed lengths.')
parser.add_argument('srcdir', type=str,
help='path to source data dir')
parser.add_argument('--coverage-factor', type=float, default=0.05,
help="""Percentage of durations not covered from each
side of duration histogram.""")
parser.add_argument('--frame-subsampling-factor', type=int, default=3,
help="""Chain frame subsampling factor.
See steps/nnet3/chain/train.py""")

args = parser.parse_args()
return args


def read_kaldi_mapfile(path):
""" Read any Kaldi mapping file - like text, .scp files, etc.
"""

m = {}
with open(path, 'r', encoding='latin-1') as f:
for line in f:
line = line.strip()
sp_pos = line.find(' ')
key = line[:sp_pos]
val = line[sp_pos+1:]
m[key] = val
return m

def find_duration_range(img2len, coverage_factor):
"""Given a list of utterances, find the start and end duration to cover

If we try to cover
all durations which occur in the training set, the number of
allowed lengths could become very large.

Returns
-------
start_dur: int
end_dur: int
"""
durs = []
for im, imlen in img2len.items():
durs.append(int(imlen))
durs.sort()
to_ignore_dur = 0
tot_dur = sum(durs)
for d in durs:
to_ignore_dur += d
if to_ignore_dur * 100.0 / tot_dur > coverage_factor:
start_dur = d
break
to_ignore_dur = 0
for d in reversed(durs):
to_ignore_dur += d
if to_ignore_dur * 100.0 / tot_dur > coverage_factor:
end_dur = d
break
if start_dur < 30:
start_dur = 30 # a hard limit to avoid too many allowed lengths --not critical
return start_dur, end_dur


def find_allowed_durations(start_len, end_len, args):
"""Given the start and end duration, find a set of
allowed durations spaced by args.factor%. Also write
out the list of allowed durations and the corresponding
allowed lengths (in frames) on disk.

Returns
-------
allowed_durations: list of allowed durations (in seconds)
"""

allowed_lengths = []
length = start_len
with open(os.path.join(args.srcdir, 'allowed_lengths.txt'), 'w', encoding='latin-1') as fp:
while length < end_len:
if length % args.frame_subsampling_factor != 0:
length = (args.frame_subsampling_factor *
(length // args.frame_subsampling_factor))
allowed_lengths.append(length)
fp.write("{}\n".format(int(length)))
length *= args.factor
return allowed_lengths



def main():
args = get_args()
args.factor = 1.0 + args.factor / 100.0

image2length = read_kaldi_mapfile(os.path.join(args.srcdir, 'image2num_frames'))

start_dur, end_dur = find_duration_range(image2length, args.coverage_factor)
logger.info("Lengths in the range [{},{}] will be covered. "
"Coverage rate: {}%".format(start_dur, end_dur,
100.0 - args.coverage_factor * 2))
logger.info("There will be {} unique allowed lengths "
"for the images.".format(int(math.log(end_dur / start_dur) /
math.log(args.factor))))

allowed_durations = find_allowed_durations(start_dur, end_dur, args)


if __name__ == '__main__':
main()
62 changes: 62 additions & 0 deletions egs/cifar/v1/image/get_image2num_frames.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
#!/usr/bin/env python3

# Copyright 2018 Hossein Hadian


""" This script computes the image lengths (with padding) in an image data dir.
The output is written to 'image2num_frames' in the given data dir. This
file is later used by image/get_allowed_lengths.py to find a set of allowed lengths
for the data dir. The output format is similar to utt2num_frames

"""

import argparse
import os
import sys
import numpy as np
from scipy import misc

parser = argparse.ArgumentParser(description="""Computes the image lengths (i.e. width) in an image data dir
and writes them (by default) to image2num_frames.""")
parser.add_argument('dir', type=str,
help='Source data directory (containing images.scp)')
parser.add_argument('--out-ark', type=str, default=None,
help='Where to write the output image-to-num_frames info. '
'Default: "dir"/image2num_frames')
parser.add_argument('--feat-dim', type=int, default=40,
help='Size to scale the height of all images')
parser.add_argument('--padding', type=int, default=5,
help='Number of white pixels to pad on the left'
'and right side of the image.')
args = parser.parse_args()


def get_scaled_image_length(im):
scale_size = args.feat_dim
sx = im.shape[1]
sy = im.shape[0]
scale = (1.0 * scale_size) / sy
nx = int(scale * sx)
return nx

### main ###
data_list_path = os.path.join(args.dir,'images.scp')

if not args.out_ark:
args.out_ark = os.path.join(args.dir,'image2num_frames')
if args.out_ark == '-':
out_fh = sys.stdout
else:
out_fh = open(args.out_ark, 'w', encoding='latin-1')

with open(data_list_path) as f:
for line in f:
line = line.strip()
line_vect = line.split(' ')
image_id = line_vect[0]
image_path = line_vect[1]
im = misc.imread(image_path)
im_len = get_scaled_image_length(im) + (args.padding * 2)
print('{} {}'.format(image_id, im_len), file=out_fh)

out_fh.close()
Loading