Initial infra of boolean ndarray
Add np.equal implemented using tvmop

Fix setting DLDataType conversion for boolean ndarray

Add equal_gpu

Fix inputs with different ndims

Fix copying boolean ndarrays across devices

Refactor binary logic op impl by tvm

Add more logic ops

Refactor TVMOpModule::Call to CallEx

Add binary scalar logic op expr and schedule

Add binary scalar logic ops

Add free functions for logic ops

Rebase with master to fix SetDLTensor bug

Fix pylint

Add sum op for boolean ndarrays using tvm op module

Add sum boolean gpu compute

Add bool type support to boolean_mask
reminisce committed Aug 28, 2019
1 parent 8df9469 commit 595e2f7
Showing 26 changed files with 1,416 additions and 228 deletions.
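Taken together, the changes below add a kBool dtype to mshadow, build a TVM-generated operator library (lib/libtvmop.so) when USE_TVM_OP=1, and register element-wise logic ops plus a boolean sum reduction through it. A minimal sketch of the front-end usage this is aimed at, illustrative only and assuming a build with USE_TVM_OP=1 and MXNet's numpy-compatible API:

# Illustrative sketch only; assumes USE_TVM_OP=1 and the numpy-compatible front end.
from mxnet import np, npx
npx.set_np()  # enable numpy-compatible semantics

a = np.array([1, 2, 3], dtype='int32')
b = np.array([1, 0, 3], dtype='int32')

mask = np.equal(a, b)  # element-wise logic op backed by the TVM-generated 'equal' kernel
print(mask.dtype)      # expected: a boolean dtype
print(mask.sum())      # boolean reduction backed by the TVM-generated 'sum' kernel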
62 changes: 61 additions & 1 deletion 3rdparty/mshadow/mshadow/base.h
@@ -311,6 +311,7 @@ enum TypeFlag {
  kInt32 = 4,
  kInt8 = 5,
  kInt64 = 6,
  kBool = 7,
};

template<typename DType>
@@ -411,6 +412,11 @@ struct DataType<int64_t> {
  static const int kFlag = kInt64;
  static const int kLanes = 1;
};
template<>
struct DataType<bool> {
  static const int kFlag = kBool;
  static const int kLanes = 1;
};

/*! \brief type enum value for default real type */
const int default_type_flag = DataType<default_real_t>::kFlag;
@@ -1099,10 +1105,64 @@ struct minimum {
LOG(FATAL) << "Unknown type enum " << type; \
}

#define MSHADOW_TYPE_SWITCH_WITH_BOOL(type, DType, ...) \
  switch (type) { \
  case mshadow::kFloat32: \
    { \
      typedef float DType; \
      {__VA_ARGS__} \
    } \
    break; \
  case mshadow::kFloat64: \
    { \
      typedef double DType; \
      {__VA_ARGS__} \
    } \
    break; \
  case mshadow::kFloat16: \
    { \
      typedef mshadow::half::half_t DType; \
      {__VA_ARGS__} \
    } \
    break; \
  case mshadow::kUint8: \
    { \
      typedef uint8_t DType; \
      {__VA_ARGS__} \
    } \
    break; \
  case mshadow::kInt8: \
    { \
      typedef int8_t DType; \
      {__VA_ARGS__} \
    } \
    break; \
  case mshadow::kInt32: \
    { \
      typedef int32_t DType; \
      {__VA_ARGS__} \
    } \
    break; \
  case mshadow::kInt64: \
    { \
      typedef int64_t DType; \
      {__VA_ARGS__} \
    } \
    break; \
  case mshadow::kBool: \
    { \
      typedef bool DType; \
      {__VA_ARGS__} \
    } \
    break; \
  default: \
    LOG(FATAL) << "Unknown type enum " << type; \
  }

/*! \brief get data type size from type enum */
inline size_t mshadow_sizeof(int type) {
  int size = 0;
  MSHADOW_TYPE_SWITCH(type, DType, size = sizeof(DType););
  MSHADOW_TYPE_SWITCH_WITH_BOOL(type, DType, size = sizeof(DType););
  return size;
}

59 changes: 30 additions & 29 deletions Makefile
@@ -570,34 +570,6 @@ ifeq ($(UNAME_S), Darwin)
LDFLAGS += -Wl,-install_name,@rpath/libmxnet.so
endif

# NOTE: to statically link libmxnet.a we need the option
# --Wl,--whole-archive -lmxnet --Wl,--no-whole-archive
lib/libmxnet.a: $(ALLX_DEP)
	@mkdir -p $(@D)
	ar crv $@ $(filter %.o, $?)

lib/libmxnet.so: $(ALLX_DEP)
	@mkdir -p $(@D)
	$(CXX) $(CFLAGS) -shared -o $@ $(filter-out %libnnvm.a, $(filter %.o %.a, $^)) $(LDFLAGS) \
		-Wl,${WHOLE_ARCH} $(filter %libnnvm.a, $^) -Wl,${NO_WHOLE_ARCH}
ifeq ($(USE_MKLDNN), 1)
ifeq ($(UNAME_S), Darwin)
	install_name_tool -change '@rpath/libmklml.dylib' '@loader_path/libmklml.dylib' $@
	install_name_tool -change '@rpath/libiomp5.dylib' '@loader_path/libiomp5.dylib' $@
	install_name_tool -change '@rpath/libmkldnn.0.dylib' '@loader_path/libmkldnn.0.dylib' $@
endif
endif

$(PS_PATH)/build/libps.a: PSLITE

PSLITE:
	$(MAKE) CXX="$(CXX)" DEPS_PATH="$(DEPS_PATH)" -C $(PS_PATH) ps

$(DMLC_CORE)/libdmlc.a: DMLCCORE

DMLCCORE:
	+ cd $(DMLC_CORE); $(MAKE) libdmlc.a USE_SSE=$(USE_SSE) config=$(ROOTDIR)/$(config); cd $(ROOTDIR)

ifeq ($(USE_TVM_OP), 1)
LIB_DEP += lib/libtvm_runtime.so lib/libtvmop.so
CFLAGS += -I$(TVM_PATH)/include -DMXNET_USE_TVM_OP=1
@@ -617,16 +589,45 @@ lib/libtvm_runtime.so:
	cmake -DUSE_LLVM="$(LLVM_PATH)/bin/llvm-config" \
		-DUSE_SORT=OFF -DUSE_CUDA=$(TVM_USE_CUDA) -DUSE_CUDNN=OFF ..; \
	$(MAKE) VERBOSE=1; \
	mkdir -p $(ROOTDIR)/lib
	cp $(TVM_PATH)/build/libtvm_runtime.so $(ROOTDIR)/lib/libtvm_runtime.so; \
	cd $(ROOTDIR)

lib/libtvmop.so: lib/libtvm_runtime.so $(wildcard contrib/tvmop/*/*.py contrib/tvmop/*.py)
echo "Compile TVM operators"
PYTHONPATH=$(TVM_PATH)/python:$(TVM_PATH)/topi/python:$(ROOTDIR)/contrib:$PYTHONPATH \
PYTHONPATH=$(TVM_PATH)/python:$(TVM_PATH)/topi/python:$(ROOTDIR)/contrib:$(PYTHONPATH) \
LD_LIBRARY_PATH=lib \
python3 $(ROOTDIR)/contrib/tvmop/compile.py -o $(ROOTDIR)/lib/libtvmop.so
endif

# NOTE: to statically link libmxnet.a we need the option
# --Wl,--whole-archive -lmxnet --Wl,--no-whole-archive
lib/libmxnet.a: $(ALLX_DEP)
	@mkdir -p $(@D)
	ar crv $@ $(filter %.o, $?)

lib/libmxnet.so: $(ALLX_DEP)
	@mkdir -p $(@D)
	$(CXX) $(CFLAGS) -shared -o $@ $(filter-out %libnnvm.a, $(filter %.o %.a, $^)) $(LDFLAGS) \
		-Wl,${WHOLE_ARCH} $(filter %libnnvm.a, $^) -Wl,${NO_WHOLE_ARCH}
ifeq ($(USE_MKLDNN), 1)
ifeq ($(UNAME_S), Darwin)
	install_name_tool -change '@rpath/libmklml.dylib' '@loader_path/libmklml.dylib' $@
	install_name_tool -change '@rpath/libiomp5.dylib' '@loader_path/libiomp5.dylib' $@
	install_name_tool -change '@rpath/libmkldnn.0.dylib' '@loader_path/libmkldnn.0.dylib' $@
endif
endif

$(PS_PATH)/build/libps.a: PSLITE

PSLITE:
	$(MAKE) CXX="$(CXX)" DEPS_PATH="$(DEPS_PATH)" -C $(PS_PATH) ps

$(DMLC_CORE)/libdmlc.a: DMLCCORE

DMLCCORE:
	+ cd $(DMLC_CORE); $(MAKE) libdmlc.a USE_SSE=$(USE_SSE) config=$(ROOTDIR)/$(config); cd $(ROOTDIR)

NNVM_INC = $(wildcard $(NNVM_PATH)/include/*/*.h)
NNVM_SRC = $(wildcard $(NNVM_PATH)/src/*/*/*.cc $(NNVM_PATH)/src/*/*.cc $(NNVM_PATH)/src/*.cc)
$(NNVM_PATH)/lib/libnnvm.a: $(NNVM_INC) $(NNVM_SRC)
1 change: 1 addition & 0 deletions contrib/tvmop/__init__.py
@@ -21,3 +21,4 @@
from .utils import assign_by_req, reduce_axes

from . import basic
from . import core
18 changes: 18 additions & 0 deletions contrib/tvmop/core/__init__.py
@@ -0,0 +1,18 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

from . import umath, fromnumeric
63 changes: 63 additions & 0 deletions contrib/tvmop/core/fromnumeric.py
@@ -0,0 +1,63 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.


import tvm
from .. import defop
from ..utils import reduce_axes, assign_by_req


def _compute_sum(itype, otype, ndim, reduce1st_dim, req):
    axes = ([reduce1st_dim, 1 - reduce1st_dim] * ndim)[:ndim]
    a = tvm.placeholder([tvm.var() for _ in range(ndim)], name='a', dtype=itype)
    reduce_output = reduce_axes(a, axes, tvm.sum, otype)
    output_placeholder, final_output = assign_by_req(reduce_output, req)
    s = tvm.create_schedule(final_output.op)
    return s, a, output_placeholder, final_output, [reduce_output, final_output]


@defop(name='sum_cpu', target='cpu', itype=['bool'],
       otype=['float32', 'float64', 'int32', 'int64'],
       ndim=[5], req=['kWriteTo', 'kAddTo'], reduce1st_dim=[0, 1],
       attrs=["reduce1st_dim", "req"])
def _sum_cpu(itype, otype, ndim, reduce1st_dim, req):
    s, a, output_placeholder, final_output, tensor_list = _compute_sum(
        itype, otype, ndim, reduce1st_dim, req)
    for t in tensor_list:
        axes = [axis for axis in t.op.axis]
        fused = s[t].fuse(*axes)
        s[t].parallel(fused)
    return s, [a, output_placeholder, final_output]


@defop(name='sum_gpu', target='gpu', itype=['bool'],
       otype=['float32', 'float64', 'int32', 'int64'],
       ndim=[5], req=['kWriteTo', 'kAddTo'], reduce1st_dim=[0, 1],
       attrs=["reduce1st_dim", "req"])
def _sum_gpu(itype, otype, ndim, reduce1st_dim, req):
    s, a, output_placeholder, final_output, tensor_list = _compute_sum(
        itype, otype, ndim, reduce1st_dim, req)
    num_threads = 64
    for t in tensor_list:
        block_x = tvm.thread_axis("blockIdx.x")
        thread_x = tvm.thread_axis("threadIdx.x")
        axes = [axis for axis in t.op.axis]
        fused = s[t].fuse(*axes)
        bx, tx = s[t].split(fused, factor=num_threads)
        s[t].bind(bx, block_x)
        s[t].bind(tx, thread_x)
    return s, [a, output_placeholder, final_output]
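
The axes expression in _compute_sum interleaves reduced and kept dimensions starting from reduce1st_dim. A plain-Python illustration of the pattern; the 0/1 convention is inferred from the reduce_axes helper in contrib/tvmop/utils.py, which is not part of this diff:

# Illustration only: reproduces the axes expression used by _compute_sum above.
def reduction_pattern(ndim, reduce1st_dim):
    return ([reduce1st_dim, 1 - reduce1st_dim] * ndim)[:ndim]

print(reduction_pattern(5, 1))  # [1, 0, 1, 0, 1], assuming 1 marks a reduced axis
print(reduction_pattern(5, 0))  # [0, 1, 0, 1, 0], i.e. the alternating dims are reduced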
122 changes: 122 additions & 0 deletions contrib/tvmop/core/umath.py
@@ -0,0 +1,122 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

import tvm
from .. import defop, AllTypes

_bin_logic_op_map = {
    'equal': lambda a, b, *idx: a[idx] == b[idx],
    'not_equal': lambda a, b, *idx: a[idx] != b[idx],
    'greater': lambda a, b, *idx: a[idx] > b[idx],
    'less': lambda a, b, *idx: a[idx] < b[idx],
    'greater_equal': lambda a, b, *idx: a[idx] >= b[idx],
    'less_equal': lambda a, b, *idx: a[idx] <= b[idx],
}


def _compute_binary_logic(op, dtype, ndim):
    a = tvm.placeholder([tvm.var() for _ in range(ndim)], dtype=dtype, name='a')
    b = tvm.placeholder([tvm.var() for _ in range(ndim)], dtype=dtype, name='b')
    c = tvm.compute([tvm.var() for _ in range(ndim)],
                    lambda *idx: _bin_logic_op_map[op](a, b, *idx), name='c')
    s = tvm.create_schedule(c.op)
    return s, a, b, c


_bin_logic_cpu_attrs = {
    'compute_func': _compute_binary_logic,
    'target': 'cpu',
    'auto_broadcast': True,
    'itype': AllTypes + ['bool'],
    'ndim': list(range(6))
}

_bin_logic_gpu_attrs = {
    'compute_func': _compute_binary_logic,
    'target': 'gpu',
    'auto_broadcast': True,
    'itype': AllTypes + ['bool'],
    'ndim': list(range(6))
}


def _binary_logic_cpu(compute_func, op, itype, ndim):
    s, a, b, c = compute_func(op, itype, ndim)
    axes = [axis for axis in c.op.axis]
    fused = s[c].fuse(*axes)
    s[c].parallel(fused)
    return s, [a, b, c]


def _binary_logic_gpu(compute_func, op, itype, ndim):
    s, a, b, c = compute_func(op, itype, ndim)
    axes = [axis for axis in c.op.axis]
    fused = s[c].fuse(*axes)
    bx, tx = s[c].split(fused, factor=64)
    s[c].bind(bx, tvm.thread_axis('blockIdx.x'))
    s[c].bind(tx, tvm.thread_axis('threadIdx.x'))
    return s, [a, b, c]


# register binary element-wise logic ops with broadcasting supported
for op_name in _bin_logic_op_map.keys():
    defop(name='{}_cpu'.format(op_name), op=op_name, **_bin_logic_cpu_attrs)(_binary_logic_cpu)
    defop(name='{}_gpu'.format(op_name), op=op_name, **_bin_logic_gpu_attrs)(_binary_logic_gpu)


# Note that `b.dtype` is hard-coded as 'float64'.
# We should always promote `a`'s elements to `b.dtype`.
_bin_scalar_logic_op_map = {
    'equal_scalar': lambda a, b, *idx: a[idx].astype(b.dtype) == b,
    'not_equal_scalar': lambda a, b, *idx: a[idx].astype(b.dtype) != b,
    'greater_scalar': lambda a, b, *idx: a[idx].astype(b.dtype) > b,
    'less_scalar': lambda a, b, *idx: a[idx].astype(b.dtype) < b,
    'greater_equal_scalar': lambda a, b, *idx: a[idx].astype(b.dtype) >= b,
    'less_equal_scalar': lambda a, b, *idx: a[idx].astype(b.dtype) <= b,
}


def _compute_binary_scalar_logic(op, dtype, ndim):
    a = tvm.placeholder([tvm.var() for _ in range(ndim)], name='a', dtype=dtype)
    b = tvm.var('b', dtype='float64')
    c = tvm.compute([tvm.var() for _ in range(ndim)],
                    lambda *idx: _bin_scalar_logic_op_map[op](a, b, *idx), name='c')
    s = tvm.create_schedule(c.op)
    return s, a, b, c


_bin_scalar_logic_cpu_attrs = {
    'compute_func': _compute_binary_scalar_logic,
    'target': 'cpu',
    'itype': AllTypes + ['bool'],
    'ndim': list(range(6))
}

_bin_scalar_logic_gpu_attrs = {
    'compute_func': _compute_binary_scalar_logic,
    'target': 'gpu',
    'itype': AllTypes + ['bool'],
    'ndim': list(range(6))
}


# register binary element-wise scalar logic ops
for op_name in _bin_scalar_logic_op_map.keys():
    defop(name='{}_cpu'.format(op_name), op=op_name,
          **_bin_scalar_logic_cpu_attrs)(_binary_logic_cpu)
    defop(name='{}_gpu'.format(op_name), op=op_name,
          **_bin_scalar_logic_gpu_attrs)(_binary_logic_gpu)
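
Since defop(...) is used as a decorator factory in fromnumeric.py above, the registration loops here are a compact way of decorating the two shared schedule functions once per op name. A hand-expanded sketch of one iteration, assuming only that defop(...) returns a decorator (its internals are not part of this diff):

# One iteration of the CPU registration loop above, written out explicitly.
decorator = defop(name='equal_cpu', op='equal', **_bin_logic_cpu_attrs)
decorator(_binary_logic_cpu)  # same effect as stacking @defop(name='equal_cpu', ...) on _binary_logic_cpu

The *_scalar variants compare each element of a against a float64 scalar b, so elements are promoted to b.dtype first, as the comment above _bin_scalar_logic_op_map notes.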
(Diffs for the remaining 20 changed files are not shown.)
