This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Add vocabulary and embedding #10074

Merged: 20 commits, Mar 15, 2018
312 changes: 278 additions & 34 deletions LICENSE

Large diffs are not rendered by default.

41 changes: 41 additions & 0 deletions NEWS.md
@@ -1,5 +1,46 @@
MXNet Change Log
================
## 1.1.0
### Usability Improvements
- Improved the usability of examples and tutorials
### Bug-fixes
- Fixed I/O multiprocessing issues: too many open file handles (#8904), a race condition (#8995), and a deadlock (#9126)
- Fixed image IO integration with OpenCV 3.3 (#8757)
- Fixed Gluon block printing (#8956)
- Fixed float16 `argmax` for negative inputs (#9149)
- Fixed the random number generator to ensure sufficient randomness (#9119, #9256, #9300)
- Fixed custom op multi-GPU scaling (#9283)
- Fixed the gradient of `gather_nd` when the index contains duplicate entries (#9200)
- Fixed overridden contexts in the Module `group2ctx` option when using multiple contexts (#8867)
- Fixed the `swap_axes` operator with the "add_to" gradient req (#9541)
### New Features
- Added an experimental API in `contrib.text` for building vocabularies and loading pre-trained word embeddings, with built-in support for 307 GloVe and FastText pre-trained embeddings (#8763); see the sketch after this list
- Added experimental structural blocks in `gluon.contrib`: `Concurrent`, `HybridConcurrent`, `Identity` (#9427)
- Added `sparse.dot(dense, csr)` operator (#8938)
- Added `Khatri-Rao` operator (#7781)
- Added `FTML` and `Signum` optimizers (#9220, #9262)
- Added `ENABLE_CUDA_RTC` build option (#9428)
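
A minimal sketch of the new `contrib.text` vocabulary API referenced above; the helper and class names follow the 1.1.0 documentation and should be treated as assumptions here:

```python
from mxnet.contrib import text

# Count whitespace-delimited tokens, then build an indexed vocabulary from the counts.
counter = text.utils.count_tokens_from_str('the quick brown fox jumps over the lazy dog')
vocab = text.vocab.Vocabulary(counter, min_freq=1, unknown_token='<unk>')
# Out-of-vocabulary tokens map to the <unk> index.
print(vocab.to_indices(['the', 'fox', 'unicorn']))
```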
### API Changes
- Added zero gradients to rounding operators including `rint`, `ceil`, `floor`, `trunc`, and `fix` (#9040)
- Added `use_global_stats` in `nn.BatchNorm` (#9420)
- Added `axis` argument to `SequenceLast`, `SequenceMask` and `SequenceReverse` operators (#9306)
- Added `lazy_update` option for the standard `SGD` and `Adam` optimizers with `row_sparse` gradients (#9468, #9189)
- Added `select` option in `Block.collect_params` to support regex (#9348); see the sketch after this list
- Added support for (one-to-one and sequence-to-one) inference on explicit unrolled RNN models in R (#9022)
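
As a sketch of the new `select` option, which filters parameters by a regular expression over their fully qualified names (the layer sizes here are arbitrary):

```python
from mxnet import gluon

net = gluon.nn.Dense(4, in_units=8)
# Keep only the parameters whose names match the regex.
print(list(net.collect_params('.*bias').keys()))
```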
### Deprecations
- The Scala API namespace is still called `ml.dmlc`. The namespace is likely to be changed to `org.apache` in a future release, which might break existing applications and scripts (#9579, #9324)
### Performance Improvements
- Improved GPU inference speed by 20% when batch size is 1 (#9055)
- Improved `SequenceLast` operator speed (#9306)
- Added multithreading for the `broadcast_reduce` class of operators on CPU (#9444)
- Improved batching for GEMM/TRSM operators with large matrices on GPU (#8846)
### Known Issues
- "Predict with pre-trained models" tutorial is broken
- "example/numpy-ops/ndarray_softmax.py" is broken

For more information and examples, see the [full release notes](https://cwiki.apache.org/confluence/display/MXNET/Apache+MXNet+%28incubating%29+1.1.0+Release+Notes).


## 1.0.0
### Performance
- Enhanced the performance of `sparse.dot` operator.
1 change: 1 addition & 0 deletions README.md
@@ -22,6 +22,7 @@ deep learning systems, and interesting insights of DL systems for hackers.

What's New
----------
* [Version 1.1.0 Release](https://github.com/apache/incubator-mxnet/releases/tag/1.1.0) - MXNet 1.1.0 Release.
* [Version 1.0.0 Release](https://github.com/apache/incubator-mxnet/releases/tag/1.0.0) - MXNet 1.0.0 Release.
* [Version 0.12.1 Release](https://github.com/apache/incubator-mxnet/releases/tag/0.12.1) - MXNet 0.12.1 Patch Release.
* [Version 0.12.0 Release](https://github.com/apache/incubator-mxnet/releases/tag/0.12.0) - MXNet 0.12.0 Release.
12 changes: 6 additions & 6 deletions cpp-package/include/mxnet-cpp/ndarray.h
@@ -291,15 +291,15 @@ class NDArray {
   */
  void SyncCopyToCPU(std::vector<mx_float> *data, size_t size = 0);
  /*!
-  * \brief Copy the content of current array to other.
-  * \param other the new context of this NDArray
-  * \return the new copy
+  * \brief copy the content of current array to a target array.
+  * \param other the target NDArray
+  * \return the target NDArray
   */
  NDArray CopyTo(NDArray * other) const;
  /*!
-  * \brief return a new copy this NDArray
-  * \param other the target NDArray
-  * \return the copy target NDarray
+  * \brief return a new copy of this NDArray
+  * \param Context the new context of this NDArray
+  * \return the new copy
   */
  NDArray Copy(const Context &) const;
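The distinction the corrected comments draw is the same one the Python API makes with `NDArray.copyto`, which accepts either an existing array or a context. A short illustration (Python API, not part of this diff):

```python
import mxnet as mx

a = mx.nd.array([1, 2, 3])
b = mx.nd.zeros((3,))
a.copyto(b)              # like NDArray::CopyTo: fills an existing target array
c = a.copyto(mx.cpu(0))  # like NDArray::Copy: allocates a new copy on the given context
```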
131 changes: 72 additions & 59 deletions example/numpy-ops/custom_softmax_rtc.py
@@ -23,78 +23,91 @@
 class Softmax(mx.operator.CustomOp):
     def __init__(self):
-        self.fwd_kernel_mod = None
-        self.bwd_kernel_mod = None
-        super().__init__()
+        super(Softmax, self).__init__()
+        # Each thread processes a row (a sample in the batch).
+        fwd_src = r"""
+        template<class DType>
+        __global__ void fwd(const DType* x, DType* y, const int row_size, const int req) {
+            const int offset = row_size * threadIdx.x;
+            DType max = x[offset];
+            for(int i = 1; i < row_size; ++i) {
+                if(max < x[offset + i]) {
+                    max = x[offset + i];
+                }
+            }
+            DType sum = 0;
+            for(int i = 0; i < row_size; ++i) {
+                sum += exp(x[offset + i] - max);
+            }
+            switch(req) {
+                case 1:
+                    for(int i = 0; i < row_size; ++i) {
+                        y[offset + i] = exp(x[offset + i] - max) / sum;
+                    }
+                    break;
+                case 2:
+                    for(int i = 0; i < row_size; ++i) {
+                        y[offset + i] += exp(x[offset + i] - max) / sum;
+                    }
+                    break;
+            }
+        }
+        """
+
+        # Each block processes a row and each thread in a block calculates an element of `dx`.
+        bwd_src = r"""
+        template<class DType>
+        __global__ void bwd(const DType* l, const DType* y, DType* dx, const int req) {
+            const int z = static_cast<int>(l[blockIdx.x]);
+            const int i = threadIdx.x + blockDim.x * blockIdx.x;
+            if(req == 1) {
+                dx[i] = threadIdx.x == z ? y[i] - 1 : y[i];
+            } else {
+                dx[i] += threadIdx.x == z ? y[i] - 1 : y[i];
+            }
+        }
+        """
+        fwd_kernel_mod = mx.rtc.CudaModule(fwd_src, exports=["fwd<float>", "fwd<double>"])
+        bwd_kernel_mod = mx.rtc.CudaModule(bwd_src, exports=["bwd<float>", "bwd<double>"])
+
+        fwd_kernel_float_signature = "const float*, const float*, const int, const int"
+        self.fwd_float_kernel = fwd_kernel_mod.get_kernel("fwd<float>", fwd_kernel_float_signature)
+
+        bwd_kernel_float_signature = "const float*, const float*, float*, const int"
+        self.bwd_float_kernel = bwd_kernel_mod.get_kernel("bwd<float>", bwd_kernel_float_signature)
+
+        fwd_kernel_double_signature = "const double*, const double*, const int, const int"
+        self.fwd_double_kernel = fwd_kernel_mod.get_kernel("fwd<double>", fwd_kernel_double_signature)
+
+        bwd_kernel_double_signature = "const double*, const double*, double*, const int"
+        self.bwd_double_kernel = bwd_kernel_mod.get_kernel("bwd<double>", bwd_kernel_double_signature)
 
     def forward(self, is_train, req, in_data, out_data, aux):
         if req[0] == "null":
             return
         x = in_data[0]  # input
         y = out_data[0]  # output
-        if self.fwd_kernel_mod is None:
-            # Each thread processes a row (a sample in the batch).
-            src = r"""
-            template<class DType>
-            __global__ void fwd(const DType* x, DType* y, const int row_size, const int req) {
-                const int offset = row_size * threadIdx.x;
-                DType max = x[offset];
-                for(int i = 1; i < row_size; ++i) {
-                    if(max < x[offset + i]) {
-                        max = x[offset + i];
-                    }
-                }
-                DType sum = 0;
-                for(int i = 0; i < row_size; ++i) {
-                    sum += exp(x[offset + i] - max);
-                }
-                switch(req) {
-                    case 1:
-                        for(int i = 0; i < row_size; ++i) {
-                            y[offset + i] = exp(x[offset + i] - max) / sum;
-                        }
-                        break;
-                    case 2:
-                        for(int i = 0; i < row_size; ++i) {
-                            y[offset + i] += exp(x[offset + i] - max) / sum;
-                        }
-                        break;
-                }
-            }
-            """
-            self.fwd_kernel_mod = mx.rtc.CudaModule(src, exports=["fwd<float>", "fwd<double>"])
-        dtype = "double" if y.dtype == np.float64 else "float"
-        kernel_signature = "const {0}*, const {0}*, const int, const int".format(dtype)
-        kernel = self.fwd_kernel_mod.get_kernel("fwd<{}>".format(dtype), kernel_signature)
-        # args, ctx, grid_shape, block_shape, shared_mem = 0
-        kernel.launch((x, y, x.shape[1], self._reqCode(req[0])), mx.gpu(0), (1, 1, 1), (x.shape[0], 1, 1))
+        if y.dtype == np.float64:
+            # args, ctx, grid_shape, block_shape, shared_mem = 0
+            self.fwd_double_kernel.launch((x, y, x.shape[1], self._reqCode(req[0])), mx.gpu(0), (1, 1, 1), (x.shape[0], 1, 1))
+        else:
+            # args, ctx, grid_shape, block_shape, shared_mem = 0
+            self.fwd_float_kernel.launch((x, y, x.shape[1], self._reqCode(req[0])), mx.gpu(0), (1, 1, 1), (x.shape[0], 1, 1))
 
     def backward(self, req, out_grad, in_data, out_data, in_grad, aux):
         if req[0] == "null":
             return
         l = in_data[1]   # label
         y = out_data[0]  # output from the forward pass
         dx = in_grad[0]  # the storage for the gradient
-        if self.bwd_kernel_mod is None:
-            # Each block processes a row and each thread in a block calculate an element of `dx`.
-            src = r"""
-            template<class DType>
-            __global__ void bwd(const DType* l, const DType* y, DType* dx, const int req) {
-                const int z = static_cast<int>(l[blockIdx.x]);
-                const int i = threadIdx.x + blockDim.x * blockIdx.x;
-                if(req == 1) {
-                    dx[i] = threadIdx.x == z ? y[i] - 1 : y[i];
-                } else {
-                    dx[i] += threadIdx.x == z ? y[i] - 1 : y[i];
-                }
-            }
-            """
-            self.bwd_kernel_mod = mx.rtc.CudaModule(src, exports=["bwd<float>", "bwd<double>"])
-        dtype = "double" if dx.dtype == np.float64 else "float"
-        kernel_signature = "const {0}*, const {0}*, {0}*, const int".format(dtype)
-        kernel = self.bwd_kernel_mod.get_kernel("bwd<{}>".format(dtype), kernel_signature)
-        # args, ctx, grid_shape, block_shape, shared_mem = 0
-        kernel.launch((l, y, dx, self._reqCode(req[0])), mx.gpu(0), (y.shape[0], 1, 1), (y.shape[1], 1, 1))
+        if dx.dtype == np.float64:
+            # args, ctx, grid_shape, block_shape, shared_mem = 0
+            self.bwd_double_kernel.launch((l, y, dx, self._reqCode(req[0])), mx.gpu(0), (y.shape[0], 1, 1), (y.shape[1], 1, 1))
+        else:
+            # args, ctx, grid_shape, block_shape, shared_mem = 0
+            self.bwd_float_kernel.launch((l, y, dx, self._reqCode(req[0])), mx.gpu(0), (y.shape[0], 1, 1), (y.shape[1], 1, 1))
 
     def _reqCode(self, req):
         if(req == "write"):
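
The refactoring above leans on `mx.rtc.CudaModule`: compile source once with NVRTC, fetch a kernel by name and signature, then launch it with explicit grid and block shapes. Below is a minimal, self-contained sketch of that workflow; it assumes a CUDA-enabled MXNet build and an available GPU, and the `axpy` kernel is illustrative rather than part of this example file.

```python
import mxnet as mx

# Compile a tiny kernel at runtime with NVRTC.
src = r"""
__global__ void axpy(const float *x, float *y, const float alpha) {
    int i = threadIdx.x + blockDim.x * blockIdx.x;
    y[i] += alpha * x[i];
}
"""
module = mx.rtc.CudaModule(src)
kernel = module.get_kernel("axpy", "const float *x, float *y, const float alpha")

x = mx.nd.ones((10,), ctx=mx.gpu(0))
y = mx.nd.zeros((10,), ctx=mx.gpu(0))
# args, ctx, grid_shape, block_shape (one block of ten threads)
kernel.launch((x, y, 3.0), mx.gpu(0), (1, 1, 1), (10, 1, 1))
print(y.asnumpy())  # ten entries, all 3.0
```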
25 changes: 25 additions & 0 deletions python/mxnet/text/__init__.py
@@ -0,0 +1,25 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

# coding: utf-8
"""This module includes utilities for indexing and embedding text."""

from .vocab import *

from . import embedding

from .utils import *
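
To make the new package's intent concrete, here is a minimal sketch of the vocabulary-to-embedding workflow it exposes. The names follow the `contrib.text` API from the 1.1.0 release notes; the exact paths exported by this PR may differ, and the GloVe file is downloaded on first use.

```python
from mxnet.contrib import text

# Build a vocabulary from token counts, then attach pre-trained GloVe vectors to it.
counter = text.utils.count_tokens_from_str('hello world hello mxnet')
vocab = text.vocab.Vocabulary(counter, min_freq=1)
glove = text.embedding.create('glove', pretrained_file_name='glove.6B.50d.txt',
                              vocabulary=vocab)
print(glove.get_vecs_by_tokens(['hello', 'world']).shape)  # (2, 50)
```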