Skip to content

Commit

Permalink
* Fix comments
Browse files Browse the repository at this point in the history
  • Loading branch information
hanke580 committed Feb 2, 2021
1 parent 8578678 commit 605c4db
Show file tree
Hide file tree
Showing 6 changed files with 104 additions and 53 deletions.
95 changes: 54 additions & 41 deletions include/tvm/topi/einsum.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,10 @@
#ifndef TVM_TOPI_EINSUM_H_
#define TVM_TOPI_EINSUM_H_

#define LABELRANGE 128
#define NPY_MAXDIMS 16
#define NPY_MAXARGS 16

#include <tvm/te/operation.h>
#include <tvm/tir/data_layout.h>
#include <tvm/topi/detail/constant_utils.h>
Expand All @@ -35,6 +39,7 @@
#include <bitset>
#include <iterator>
#include <string>
#include <tuple>
#include <unordered_set>
#include <vector>

Expand Down Expand Up @@ -220,8 +225,8 @@ inline int ParseOutputSubscripts(const char* subscripts, int length, int ndim_br
<< "in an input";

/* Check that there is room in out_labels for this label. */
CHECK(ndim < 16) << "einstein sum subscripts string contains "
<< "too many subscripts in the output";
CHECK(ndim < NPY_MAXDIMS) << "einstein sum subscripts string contains "
<< "too many subscripts in the output";

out_labels[ndim++] = label;
} else if (label == '.') {
Expand All @@ -233,8 +238,8 @@ inline int ParseOutputSubscripts(const char* subscripts, int length, int ndim_br
<< "an ellipsis ('...') in the output";

/* Check there is room in out_labels for broadcast dims. */
CHECK(ndim + ndim_broadcast <= 16) << "einstein sum subscripts string contains "
<< "too many subscripts in the output";
CHECK(ndim + ndim_broadcast <= NPY_MAXDIMS) << "einstein sum subscripts string contains "
<< "too many subscripts in the output";

ellipsis = 1;
for (bdim = 0; bdim < ndim_broadcast; ++bdim) {
Expand Down Expand Up @@ -273,7 +278,7 @@ inline int ParseOutputSubscripts(const char* subscripts, int length, int ndim_br
inline void GetCombinedDimsView(const Tensor& op, int iop, char* labels, Array<PrimExpr>* newshape,
Array<PrimExpr>* newstride) {
int idim, ndim, icombine, combineoffset;
int icombinemap[16];
int icombinemap[NPY_MAXDIMS];
int newdim;

Array<PrimExpr> shape = op->shape;
Expand Down Expand Up @@ -402,8 +407,8 @@ inline int CountSubstring(const std::string& str, const std::string& sub) {
*
* \return bitset.
*/
inline std::bitset<128> Str2Set(const std::string& str) {
std::bitset<128> ret;
inline std::bitset<LABELRANGE> Str2Set(const std::string& str) {
std::bitset<LABELRANGE> ret;
for (const char& c : str) {
ret.set(static_cast<int>(c));
}
Expand Down Expand Up @@ -438,10 +443,10 @@ inline std::vector<std::string> Split(const std::string& str, const std::string&
* if no output, the vector[1] is NULL.
* "ab, bc -> ac" => ["ab,bc", "ac"]
*/
inline std::vector<std::string> ParseEinsumInput(std::string subscripts,
const std::vector<Array<PrimExpr>>& operands) {
inline std::tuple<std::string, std::string> ParseEinsumInput(
std::string subscripts, const std::vector<Array<PrimExpr>>& operands) {
const std::string einsum_symbols = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ";
std::bitset<128> einsum_symbols_set;
std::bitset<LABELRANGE> einsum_symbols_set;
for (const char& c : einsum_symbols) {
einsum_symbols_set.set(c);
}
Expand Down Expand Up @@ -475,7 +480,7 @@ inline std::vector<std::string> ParseEinsumInput(std::string subscripts,
[](const char& c) { return c == '.' || c == ',' || c == '-' || c == '>'; }),
used.end());

std::bitset<128> used_set = Str2Set(used);
std::bitset<LABELRANGE> used_set = Str2Set(used);
std::string ellipse_inds = "";
for (const char& c : einsum_symbols) {
if (!used_set.test(static_cast<int>(c))) {
Expand Down Expand Up @@ -544,7 +549,7 @@ inline std::vector<std::string> ParseEinsumInput(std::string subscripts,
subscripts += "->" + output_sub;
} else {
// Special care for outputless ellipses
std::bitset<128> out_ellipse_set = Str2Set(out_ellipse);
std::bitset<LABELRANGE> out_ellipse_set = Str2Set(out_ellipse);
std::string tmp_subscripts = subscripts, output_subscript = "";
size_t len_tmp_subscripts = tmp_subscripts.length();
std::sort(tmp_subscripts.begin(), tmp_subscripts.end());
Expand All @@ -565,12 +570,14 @@ inline std::vector<std::string> ParseEinsumInput(std::string subscripts,
}

// Build output string if does not exist
std::vector<std::string> ret(2);
std::tuple<std::string, std::string> ret;
if (subscripts.find("->") != std::string::npos) {
ret = Split(subscripts, "->");
std::vector<std::string> tmp(2);
tmp = Split(subscripts, "->");
ret = std::make_tuple(tmp[0], tmp[1]);
} else {
ret[0] = subscripts;
ret[1] = "";
std::string first = subscripts;
std::string second = "";
// Build output subscripts
std::string tmp_subscripts = subscripts;
size_t len_tmp_subscripts = tmp_subscripts.length();
Expand All @@ -583,20 +590,21 @@ inline std::vector<std::string> ParseEinsumInput(std::string subscripts,
CHECK(einsum_symbols_set.test(c)) << "Character " << c << " is not a valid symbol.";
if ((i == 0 || tmp_subscripts[i - 1] != c) &&
(i == len_tmp_subscripts - 1 || tmp_subscripts[i + 1] != c)) {
ret[1].append(1, c);
second.append(1, c);
}
}
ret = std::make_tuple(first, second);
}

// Make sure output subscripts are in the input
std::bitset<128> input_subscripts_set = Str2Set(ret[0]);
for (const char& c : ret[1]) {
std::bitset<LABELRANGE> input_subscripts_set = Str2Set(std::get<0>(ret));
for (const char& c : std::get<1>(ret)) {
CHECK(input_subscripts_set.test(c))
<< "Output character " << c << " did not appear in the input";
}

// Make sure number operands is equivalent to the number of terms
CHECK_EQ(std::count(ret[0].begin(), ret[0].end(), ',') + 1, operands.size())
CHECK_EQ(std::count(std::get<0>(ret).begin(), std::get<0>(ret).end(), ',') + 1, operands.size())
<< "Number of einsum subscripts must be equal to the "
<< "number of operands.";

Expand All @@ -613,14 +621,14 @@ inline std::vector<std::string> ParseEinsumInput(std::string subscripts,
inline Array<PrimExpr> NumpyEinsumShape(const std::string subscripts,
const std::vector<Array<PrimExpr>>& operands) {
// Parsing
std::vector<std::string> parsed_subscripts = ParseEinsumInput(subscripts, operands);
std::tuple<std::string, std::string> parsed_subscripts = ParseEinsumInput(subscripts, operands);

// Build a few useful list and sets
std::vector<std::string> input_list = Split(parsed_subscripts[0], ",");
std::vector<std::string> input_list = Split(std::get<0>(parsed_subscripts), ",");
size_t isize = input_list.size();

// Get length of each unique dimension and ensure all dimensions are correct
int dimension_dict[128];
int dimension_dict[LABELRANGE];
memset(dimension_dict, -1, sizeof(dimension_dict));
for (size_t i = 0; i < isize; ++i) {
const std::string& term = input_list[i];
Expand Down Expand Up @@ -649,7 +657,7 @@ inline Array<PrimExpr> NumpyEinsumShape(const std::string subscripts,
}

// Get oshape
const std::string& output_str = parsed_subscripts[1];
const std::string& output_str = std::get<1>(parsed_subscripts);
size_t odim = output_str.size();
Array<PrimExpr> oshape(odim, -1);
for (size_t i = 0; i < odim; ++i) {
Expand All @@ -662,22 +670,24 @@ inline Array<PrimExpr> NumpyEinsumShape(const std::string subscripts,
/*!
* \brief Evaluates the Einstein summation convention on the operands.
*
* \param subscripts_str Specifies the subscripts for summation as comma separated list of subscript
* labels. \param inputs Arrays for the operation. \param name The name of the operation. \param tag
* The tag to mark the operation.
* \param subscripts_str Specifies the subscripts for summation as comma separated list of
* subscript labels.
* \param inputs Arrays for the operation.
* \param name The name of the operation.
* \param tag The tag to mark the operation.
*
* \return The calculation based on the Einstein summation convention.
*/
inline Tensor einsum(const std::string& subscripts_str, const Array<Tensor> inputs,
std::string name = "T_einsum", std::string tag = kMatMul) {
// able to compute op: trace, diag, sum, transpose, matmul, dot, inner, outer, multiply, tensordot
std::string name = "T_einsum", std::string tag = kEinsum) {
bool back = false;
const char* subscripts = subscripts_str.data();
const char* head = subscripts;
const int nop = inputs.size();

/* Step 1: Parse the subscripts string into label_counts and op_labels */
int iop, idim, min_label = 127, max_label = 0;
char label_counts[128], op_labels[16][16];
int iop, idim, min_label = LABELRANGE - 1, max_label = 0;
char label_counts[LABELRANGE], op_labels[NPY_MAXARGS][NPY_MAXDIMS];
memset(label_counts, 0, sizeof(label_counts));
for (iop = 0; iop < nop; ++iop) {
int length = static_cast<int>(strcspn(subscripts, ",-"));
Expand All @@ -688,12 +698,15 @@ inline Tensor einsum(const std::string& subscripts_str, const Array<Tensor> inpu
CHECK(!(iop < nop - 1 && subscripts[length] != ','))
<< "fewer operands provided to einstein sum function "
<< "than specified in the subscripts string";
ParseOperandSubscripts(subscripts, length, inputs[iop + back].ndim(), iop, op_labels[iop],
label_counts, &min_label, &max_label);
CHECK_EQ(ParseOperandSubscripts(subscripts, length, inputs[iop + back].ndim(), iop,
op_labels[iop], label_counts, &min_label, &max_label),
0);

/* Move subscripts to the start of the labels for the next op */
subscripts += length;

if (iop < nop - 1) {
CHECK_LT(subscripts - head, subscripts_str.length()) << "subscripts out of range";
subscripts++;
}
}
Expand Down Expand Up @@ -724,16 +737,16 @@ inline Tensor einsum(const std::string& subscripts_str, const Array<Tensor> inpu
* using each label that appeared once, in alphabetical order.
*/
int label, ndim_output;
char output_labels[16];
char output_labels[NPY_MAXDIMS];
if (subscripts[0] == '\0') {
/* If no output was specified, always broadcast left, as usual. */
for (ndim_output = 0; ndim_output < ndim_broadcast; ++ndim_output) {
output_labels[ndim_output] = 0;
}
for (label = min_label; label <= max_label; ++label) {
if (label_counts[label] == 1) {
CHECK(ndim_output < 16) << "einstein sum subscript string has too many "
<< "distinct labels";
CHECK(ndim_output < NPY_MAXDIMS) << "einstein sum subscript string has too many "
<< "distinct labels";
output_labels[ndim_output++] = label;
}
}
Expand Down Expand Up @@ -795,7 +808,7 @@ inline Tensor einsum(const std::string& subscripts_str, const Array<Tensor> inpu
int ndim_iter = ndim_output;
for (label = min_label; label <= max_label; ++label) {
if (label_counts[label] > 0 && memchr(output_labels, label, ndim_output) == nullptr) {
CHECK(ndim_iter < 16) << "too many subscripts in einsum";
CHECK(ndim_iter < NPY_MAXDIMS) << "too many subscripts in einsum";
iter_labels[ndim_iter++] = label;
}
}
Expand All @@ -813,8 +826,8 @@ inline Tensor einsum(const std::string& subscripts_str, const Array<Tensor> inpu
Array<PrimExpr> ostride_true = GetStride(oshape);
Array<PrimExpr> reduceshape;
std::vector<Array<PrimExpr>> remainshape(nop);
int op_axes_arrays[16][16];
int* op_axes[16];
int op_axes_arrays[NPY_MAXARGS][NPY_MAXDIMS];
int* op_axes[NPY_MAXARGS];
for (iop = 0; iop < nop; ++iop) {
op_axes[iop] = op_axes_arrays[iop];
CHECK_GE(PrepareOpAxes(opshape[iop].size(), iop, op_labels[iop], op_axes[iop], ndim_iter,
Expand Down Expand Up @@ -891,15 +904,15 @@ inline Tensor einsum(const std::string& subscripts_str, const Array<Tensor> inpu
}
}
Array<PrimExpr> ridx = UnravelIndex(0, reduceshape);
// AType needs to fit the output data

PrimExpr sum = 0;
bool rec_flag = false;
do {
PrimExpr tmp = 1;
for (int iop = 0; iop < nop; ++iop) {
if (iop != -1) {
PrimExpr k = 0;
// k = dot(input_indices, ostride) + dot(ridx, rstride)

for (size_t i = 0; i < input_indices.size(); ++i) {
k += input_indices[i] * ostride[iop][i];
}
Expand Down
1 change: 1 addition & 0 deletions include/tvm/topi/tags.h
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ constexpr auto kDepthwiseConv2dNCHW = "depthwise_conv2d_nchw";
constexpr auto kDepthwiseConv2dNHWC = "depthwise_conv2d_nhwc";
constexpr auto kDepthwiseConv2dBackInputNHWC = "depthwise_conv2d_back_input_nhwc";
constexpr auto kDepthwiseConv2dBackWeightNHWC = "depthwise_conv2d_back_weight_nhwc";
constexpr auto kEinsum = "einsum";
constexpr auto kGroupConv2d = "group_conv2d";

inline bool is_broadcast(std::string tag) {
Expand Down
1 change: 1 addition & 0 deletions python/tvm/topi/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@
from .scatter_add import *
from .argwhere import *
from .cumsum import *
from .einsum import *
from . import generic
from . import nn
from . import x86
Expand Down
44 changes: 44 additions & 0 deletions python/tvm/topi/einsum.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=invalid-name,consider-using-enumerate,redefined-outer-name
"""Einsum operator"""
from . import cpp


def einsum(subscripts, *operand):
    """Evaluates the Einstein summation convention on the operands.

    Parameters
    ----------
    subscripts : string
        Specifies the subscripts for summation as comma separated list of subscript labels.
        An implicit (classical Einstein summation) calculation is performed unless the
        explicit indicator '->' is included as well as subscript labels of the precise
        output form.

    operand : sequence of tvm.te.Tensor
        The tensors for the operation, passed as separate positional arguments,
        e.g. ``topi.einsum("ij, jk -> ik", A, B)``.

    Returns
    -------
    out : tvm.te.Tensor
        The calculation based on the Einstein summation convention.
    """
    # The C++ FFI entry point expects the operands bundled as one sequence,
    # so forward the variadic arguments as a single tuple.
    return cpp.einsum(subscripts, operand)
8 changes: 0 additions & 8 deletions python/tvm/topi/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -538,14 +538,6 @@ def tensordot(a, b, axes):
return cpp.tensordot(a, b, axes[0], axes[1])


def einsum(subscripts, a_tuple):
# Grab non-einsum kwargs; do not optimize by default.
# Temporally not implement the optimized einsum
if not isinstance(a_tuple, tuple):
a_tuple = (a_tuple,)
return cpp.einsum(subscripts, a_tuple)


def arange(start, stop=None, step=1, dtype="float32"):
"""Creates a tensor with evenly spaced values within a given interval.
Expand Down
8 changes: 4 additions & 4 deletions tests/python/topi/python/test_topi_einsum.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
import tvm.testing
from tvm import te
from tvm import topi
from tvm.topi.util import get_const_tuple
from tvm.topi.utils import get_const_tuple


def with_tvm(lam, *args):
Expand Down Expand Up @@ -50,11 +50,11 @@ def verify_einsum(subscripts, shapes):
c1 = np.einsum(subscripts, *ops)

if len(ops) == 1:
c2 = with_tvm(lambda A: topi.einsum(subscripts, (A)), *ops)
c2 = with_tvm(lambda A: topi.einsum(subscripts, A), *ops)
elif len(ops) == 2:
c2 = with_tvm(lambda A, B: topi.einsum(subscripts, (A, B)), *ops)
c2 = with_tvm(lambda A, B: topi.einsum(subscripts, A, B), *ops)
elif len(ops) == 3:
c2 = with_tvm(lambda A, B, C: topi.einsum(subscripts, (A, B, C)), *ops)
c2 = with_tvm(lambda A, B, C: topi.einsum(subscripts, A, B, C), *ops)

tvm.testing.assert_allclose(c1, c2, rtol=1e-5, atol=1e-5)

Expand Down

0 comments on commit 605c4db

Please sign in to comment.