Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
Unify the style of comments
Browse files Browse the repository at this point in the history
Unify the style of comments suggested by @sandeep-krishnamurthy
  • Loading branch information
cchung100m committed Feb 11, 2019
1 parent 69f3334 commit f9f3313
Show file tree
Hide file tree
Showing 11 changed files with 40 additions and 95 deletions.
16 changes: 5 additions & 11 deletions example/cnn_text_classification/text_cnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,7 @@

# -*- coding: utf-8 -*-

"""
Implementing CNN + Highway Network for Text Classification in MXNet
"""
"""Implementing CNN + Highway Network for Text Classification in MXNet"""

import os
import logging
Expand Down Expand Up @@ -59,8 +57,7 @@


def save_model():
"""
Save cnn model
"""Save cnn model
Returns
----------
Expand All @@ -72,8 +69,7 @@ def save_model():


def data_iter(batch_size, num_embed, pre_trained_word2vec=False):
"""
Construct data iter
"""Construct data iter
Parameters
----------
Expand Down Expand Up @@ -136,8 +132,7 @@ def data_iter(batch_size, num_embed, pre_trained_word2vec=False):
def sym_gen(batch_size, sentences_size, num_embed, vocabulary_size,
num_label=2, filter_list=None, num_filter=100,
dropout=0.0, pre_trained_word2vec=False):
"""
Generate network symbol
"""Generate network symbol
Parameters
----------
Expand Down Expand Up @@ -204,8 +199,7 @@ def sym_gen(batch_size, sentences_size, num_embed, vocabulary_size,


def train(symbol_data, train_iterator, valid_iterator, data_column_names, target_names):
"""
Train cnn model
"""Train cnn model
Parameters
----------
Expand Down
46 changes: 14 additions & 32 deletions example/ctc/captcha_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
# specific language governing permissions and limitations
# under the License.
""" Helper classes for multiprocess captcha image generation
This module also provides script for saving captcha images to file using CLI.
"""

Expand All @@ -29,8 +28,7 @@


class CaptchaGen(object):
"""
Generates a captcha image
"""Generates a captcha image
"""
def __init__(self, h, w, font_paths):
"""
Expand All @@ -48,8 +46,7 @@ def __init__(self, h, w, font_paths):
self.w = w

def image(self, captcha_str):
"""
Generate a greyscale captcha image representing number string
"""Generate a greyscale captcha image representing number string
Parameters
----------
Expand All @@ -71,8 +68,7 @@ def image(self, captcha_str):


class DigitCaptcha(object):
"""
Provides shape() and get() interface for digit-captcha image generation
"""Provides shape() and get() interface for digit-captcha image generation
"""
def __init__(self, font_paths, h, w, num_digit_min, num_digit_max):
"""
Expand All @@ -95,8 +91,7 @@ def __init__(self, font_paths, h, w, num_digit_min, num_digit_max):

@property
def shape(self):
"""
Returns shape of the image data generated
"""Returns shape of the image data generated
Returns
-------
Expand All @@ -105,8 +100,7 @@ def shape(self):
return self.captcha.h, self.captcha.w

def get(self):
"""
Get an image from the queue
"""Get an image from the queue
Returns
-------
Expand All @@ -117,9 +111,8 @@ def get(self):

@staticmethod
def get_rand(num_digit_min, num_digit_max):
"""
Generates a character string of digits. Number of digits are
between self.num_digit_min and self.num_digit_max
"""Generates a character string of digits. Number of digits are
between self.num_digit_min and self.num_digit_max
Returns
-------
str
Expand All @@ -131,8 +124,7 @@ def get_rand(num_digit_min, num_digit_max):
return buf

def _gen_sample(self):
"""
Generate a random captcha image sample
"""Generate a random captcha image sample
Returns
-------
(numpy.ndarray, str)
Expand All @@ -143,13 +135,10 @@ def _gen_sample(self):


class MPDigitCaptcha(DigitCaptcha):
"""
Handles multi-process captcha image generation
"""Handles multi-process captcha image generation
"""
def __init__(self, font_paths, h, w, num_digit_min, num_digit_max, num_processes, max_queue_size):
"""
Parameters
"""Parameters
----------
font_paths: list of str
List of path to ttf font files
Expand All @@ -170,14 +159,11 @@ def __init__(self, font_paths, h, w, num_digit_min, num_digit_max, num_processes
self.mp_data = MPData(num_processes, max_queue_size, self._gen_sample)

def start(self):
"""
Starts the processes
"""
"""Starts the processes"""
self.mp_data.start()

def get(self):
"""
Get an image from the queue
"""Get an image from the queue
Returns
-------
Expand All @@ -187,19 +173,15 @@ def get(self):
return self.mp_data.get()

def reset(self):
"""
Resets the generator by stopping all processes
"""
"""Resets the generator by stopping all processes"""
self.mp_data.reset()


if __name__ == '__main__':
import argparse

def main():
"""
Program entry point
"""
"""Program entry point"""
parser = argparse.ArgumentParser()
parser.add_argument("font_path", help="Path to ttf font file")
parser.add_argument("output", help="Output filename including extension (e.g. 'sample.jpg')")
Expand Down
7 changes: 2 additions & 5 deletions example/ctc/ctc_metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,7 @@


class CtcMetrics(object):
"""
Module for calculating the prediction accuracy during training. Two accuracy measures are implemented:
"""Module for calculating the prediction accuracy during training. Two accuracy measures are implemented:
A simple accuracy measure that calculates number of correct predictions divided by total number of predictions
and a second accuracy measure based on sum of Longest Common Sequence(LCS) ratio of all predictions divided by total
number of predictions
Expand All @@ -33,9 +32,7 @@ def __init__(self, seq_len):

@staticmethod
def ctc_label(p):
"""
Iterates through p, identifying non-zero and non-repeating values, and returns them in a list
Parameters
"""Iterates through p, identifying non-zero and non-repeating values, and returns them in a list Parameters
----------
p: list of int
Expand Down
4 changes: 1 addition & 3 deletions example/ctc/hyperparams.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,7 @@


class Hyperparams(object):
"""
Hyperparameters for LSTM network
"""
"""Hyperparameters for LSTM network"""
def __init__(self):
# Training hyper parameters
self._train_epoch_size = 30000
Expand Down
7 changes: 2 additions & 5 deletions example/ctc/lstm_ocr_infer.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,8 +46,7 @@ def lstm_init_states(batch_size):


def load_module(prefix, epoch, data_names, data_shapes):
"""
Loads the model from checkpoint specified by prefix and epoch, binds it
"""Loads the model from checkpoint specified by prefix and epoch, binds it
to an executor, and sets its parameters and returns a mx.mod.Module
"""
sym, arg_params, aux_params = mx.model.load_checkpoint(prefix, epoch)
Expand All @@ -64,9 +63,7 @@ def load_module(prefix, epoch, data_names, data_shapes):


def main():
"""
Program entry point
"""
"""Program entry point"""
parser = argparse.ArgumentParser()
parser.add_argument("path", help="Path to the CAPTCHA image file")
parser.add_argument("--prefix", help="Checkpoint prefix [Default 'ocr']", default='ocr')
Expand Down
8 changes: 2 additions & 6 deletions example/ctc/lstm_ocr_train.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,9 +42,7 @@ def get_fonts(path):


def parse_args():
"""
Parse command line arguments
"""
"""Parse command line arguments"""
parser = argparse.ArgumentParser()
parser.add_argument("font_path", help="Path to ttf font file or directory containing ttf files")
parser.add_argument("--loss", help="'ctc' or 'warpctc' loss [Default 'ctc']", default='ctc')
Expand All @@ -58,9 +56,7 @@ def parse_args():


def main():
"""
Program entry point
"""
"""Program entry point"""
args = parse_args()
if not any(args.loss == s for s in ['ctc', 'warpctc']):
raise ValueError("Invalid loss '{}' (must be 'ctc' or 'warpctc')".format(args.loss))
Expand Down
26 changes: 8 additions & 18 deletions example/ctc/multiproc_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,19 +27,17 @@
from Queue import Full as QFullExcept
from Queue import Empty as QEmptyExcept


class MPData(object):
"""
Handles multi-process data generation.
"""Handles multi-process data generation.
Operation:
- call start() to start the data generation
- call get() (blocking) to read one sample
- call reset() to stop data generation
"""
def __init__(self, num_processes, max_queue_size, fn):
"""
Parameters
"""Parameters
----------
num_processes: int
Number of processes to spawn
Expand All @@ -55,15 +53,12 @@ def __init__(self, num_processes, max_queue_size, fn):
self.fn = fn

def start(self):
"""
Starts the processes
"""
"""Starts the processes"""
self._init_proc()

@staticmethod
def _proc_loop(proc_id, alive, queue, fn):
"""
Thread loop for generating data
"""Thread loop for generating data
Parameters
----------
Expand Down Expand Up @@ -94,9 +89,7 @@ def _proc_loop(proc_id, alive, queue, fn):
queue.close()

def _init_proc(self):
"""
Start processes if not already started
"""
"""Start processes if not already started"""
if not self.proc:
self.proc = [
mp.Process(target=self._proc_loop, args=(i, self.alive, self.queue, self.fn))
Expand All @@ -107,8 +100,7 @@ def _init_proc(self):
p.start()

def get(self):
"""
Get a datum from the queue
"""Get a datum from the queue
Returns
-------
Expand All @@ -119,9 +111,7 @@ def get(self):
return self.queue.get()

def reset(self):
"""
Resets the generator by stopping all processes
"""
"""Resets the generator by stopping all processes"""
self.alive.value = False
qsize = 0
try:
Expand Down
10 changes: 3 additions & 7 deletions example/ctc/ocr_iter.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,7 @@


class SimpleBatch(object):
"""
Batch class for getting label data
"""Batch class for getting label data
Operation:
- call get_label() to start label data generation
"""
Expand Down Expand Up @@ -72,12 +71,9 @@ def get_label(buf):


class OCRIter(mx.io.DataIter):
"""
Iterator class for generating captcha image data
"""
"""Iterator class for generating captcha image data"""
def __init__(self, count, batch_size, lstm_init_states, captcha, name):
"""
Parameters
"""Parameters
----------
count: int
Number of batches to produce for one epoch
Expand Down
7 changes: 2 additions & 5 deletions example/ctc/ocr_predict.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,9 +26,7 @@


class lstm_ocr_model(object):
"""
LSTM network for predicting the Optical Character Recognition
"""
"""LSTM network for predicting the Optical Character Recognition"""
# Keep Zero index for blank. (CTC request it)
CONST_CHAR = '0123456789'

Expand Down Expand Up @@ -63,8 +61,7 @@ def __init_ocr(self):
all_shapes_dict)

def forward_ocr(self, img_):
"""
Forward the image through the LSTM network model
"""Forward the image through the LSTM network model
Parameters
----------
Expand Down
1 change: 0 additions & 1 deletion example/deep-embedded-clustering/dec.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,6 @@
from solver import Solver, Monitor



def cluster_acc(Y_pred, Y):
from sklearn.utils.linear_assignment_ import linear_assignment
assert Y_pred.size == Y.size
Expand Down
3 changes: 1 addition & 2 deletions example/distributed_training/cifar10_dist.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,8 +161,7 @@ def forward_backward(network, data, label):

# Train a batch using multiple GPUs
def train_batch(batch_list, context, network, gluon_trainer):
"""
Training with multiple GPUs
""" Training with multiple GPUs
Parameters
----------
Expand Down

0 comments on commit f9f3313

Please sign in to comment.