Skip to content

Commit

Permalink
* Fix comments
Browse files Browse the repository at this point in the history
  • Loading branch information
hanke580 committed Jan 12, 2021
1 parent 1514969 commit 71916d6
Show file tree
Hide file tree
Showing 5 changed files with 76 additions and 30 deletions.
49 changes: 28 additions & 21 deletions include/tvm/topi/einsum.h
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
#include <bitset>
#include <iterator>
#include <string>
#include <tuple>
#include <unordered_set>
#include <vector>

Expand Down Expand Up @@ -438,8 +439,8 @@ inline std::vector<std::string> Split(const std::string& str, const std::string&
* if no output, the vector[1] is NULL.
* "ab, bc -> ac" => ["ab,bc", "ac"]
*/
inline std::vector<std::string> ParseEinsumInput(std::string subscripts,
const std::vector<Array<PrimExpr>>& operands) {
inline std::tuple<std::string, std::string> ParseEinsumInput(
std::string subscripts, const std::vector<Array<PrimExpr>>& operands) {
const std::string einsum_symbols = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ";
std::bitset<128> einsum_symbols_set;
for (const char& c : einsum_symbols) {
Expand Down Expand Up @@ -565,12 +566,14 @@ inline std::vector<std::string> ParseEinsumInput(std::string subscripts,
}

// Build output string if does not exist
std::vector<std::string> ret(2);
std::tuple<std::string, std::string> ret;
if (subscripts.find("->") != std::string::npos) {
ret = Split(subscripts, "->");
std::vector<std::string> tmp(2);
tmp = Split(subscripts, "->");
ret = std::make_tuple(tmp[0], tmp[1]);
} else {
ret[0] = subscripts;
ret[1] = "";
std::string first = subscripts;
std::string second = "";
// Build output subscripts
std::string tmp_subscripts = subscripts;
size_t len_tmp_subscripts = tmp_subscripts.length();
Expand All @@ -583,20 +586,21 @@ inline std::vector<std::string> ParseEinsumInput(std::string subscripts,
CHECK(einsum_symbols_set.test(c)) << "Character " << c << " is not a valid symbol.";
if ((i == 0 || tmp_subscripts[i - 1] != c) &&
(i == len_tmp_subscripts - 1 || tmp_subscripts[i + 1] != c)) {
ret[1].append(1, c);
second.append(1, c);
}
}
ret = std::make_tuple(first, second);
}

// Make sure output subscripts are in the input
std::bitset<128> input_subscripts_set = Str2Set(ret[0]);
for (const char& c : ret[1]) {
std::bitset<128> input_subscripts_set = Str2Set(std::get<0>(ret));
for (const char& c : std::get<1>(ret)) {
CHECK(input_subscripts_set.test(c))
<< "Output character " << c << " did not appear in the input";
}

// Make sure number operands is equivalent to the number of terms
CHECK_EQ(std::count(ret[0].begin(), ret[0].end(), ',') + 1, operands.size())
CHECK_EQ(std::count(std::get<0>(ret).begin(), std::get<0>(ret).end(), ',') + 1, operands.size())
<< "Number of einsum subscripts must be equal to the "
<< "number of operands.";

Expand All @@ -613,10 +617,10 @@ inline std::vector<std::string> ParseEinsumInput(std::string subscripts,
inline Array<PrimExpr> NumpyEinsumShape(const std::string subscripts,
const std::vector<Array<PrimExpr>>& operands) {
// Parsing
std::vector<std::string> parsed_subscripts = ParseEinsumInput(subscripts, operands);
std::tuple<std::string, std::string> parsed_subscripts = ParseEinsumInput(subscripts, operands);

// Build a few useful list and sets
std::vector<std::string> input_list = Split(parsed_subscripts[0], ",");
std::vector<std::string> input_list = Split(std::get<0>(parsed_subscripts), ",");
size_t isize = input_list.size();

// Get length of each unique dimension and ensure all dimensions are correct
Expand Down Expand Up @@ -649,7 +653,7 @@ inline Array<PrimExpr> NumpyEinsumShape(const std::string subscripts,
}

// Get oshape
const std::string& output_str = parsed_subscripts[1];
const std::string& output_str = std::get<1>(parsed_subscripts);
size_t odim = output_str.size();
Array<PrimExpr> oshape(odim, -1);
for (size_t i = 0; i < odim; ++i) {
Expand All @@ -662,17 +666,17 @@ inline Array<PrimExpr> NumpyEinsumShape(const std::string subscripts,
/*!
* \brief Evaluates the Einstein summation convention on the operands.
*
* \param subscripts_str Specifies the subscripts for summation as comma separated list of subscript
* labels. \param inputs Arrays for the operation. \param name The name of the operation. \param tag
* The tag to mark the operation.
* \param subscripts_str Specifies the subscripts for summation as comma separated list of
* subscript labels. \param inputs Arrays for the operation. \param name The name of the
* operation. \param tag The tag to mark the operation.
*
* \return The calculation based on the Einstein summation convention.
*/
inline Tensor einsum(const std::string& subscripts_str, const Array<Tensor> inputs,
std::string name = "T_einsum", std::string tag = kMatMul) {
// able to compute op: trace, diag, sum, transpose, matmul, dot, inner, outer, multiply, tensordot
bool back = false;
const char* subscripts = subscripts_str.data();
const char* head = subscripts;
const int nop = inputs.size();

/* Step 1: Parse the subscripts string into label_counts and op_labels */
Expand All @@ -688,12 +692,15 @@ inline Tensor einsum(const std::string& subscripts_str, const Array<Tensor> inpu
CHECK(!(iop < nop - 1 && subscripts[length] != ','))
<< "fewer operands provided to einstein sum function "
<< "than specified in the subscripts string";
ParseOperandSubscripts(subscripts, length, inputs[iop + back].ndim(), iop, op_labels[iop],
label_counts, &min_label, &max_label);
CHECK_EQ(ParseOperandSubscripts(subscripts, length, inputs[iop + back].ndim(), iop,
op_labels[iop], label_counts, &min_label, &max_label),
0);

/* Move subscripts to the start of the labels for the next op */
subscripts += length;

if (iop < nop - 1) {
CHECK_LT(subscripts - head, subscripts_str.length()) << "subscripts out of range";
subscripts++;
}
}
Expand Down Expand Up @@ -891,15 +898,15 @@ inline Tensor einsum(const std::string& subscripts_str, const Array<Tensor> inpu
}
}
Array<PrimExpr> ridx = UnravelIndex(0, reduceshape);
// AType needs to fit the output data

PrimExpr sum = 0;
bool rec_flag = false;
do {
PrimExpr tmp = 1;
for (int iop = 0; iop < nop; ++iop) {
if (iop != -1) {
PrimExpr k = 0;
// k = dot(input_indices, ostride) + dot(ridx, rstride)

for (size_t i = 0; i < input_indices.size(); ++i) {
k += input_indices[i] * ostride[iop][i];
}
Expand Down
1 change: 1 addition & 0 deletions python/tvm/topi/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@
from .scatter import *
from .scatter_add import *
from .argwhere import *
from .einsum import *
from . import generic
from . import nn
from . import x86
Expand Down
46 changes: 46 additions & 0 deletions python/tvm/topi/einsum.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=invalid-name,consider-using-enumerate,redefined-outer-name
"""Einsum operator"""
from . import cpp


def einsum(subscripts, a_tuple):
"""Evaluates the Einstein summation convention on the operands.
Parameters
----------
subscripts : string
Specifies the subscripts for summation as comma separated list of subscript labels.
An implicit (classical Einstein summation) calculation is performed unless the
explicit indicator ‘->’ is included as well as subscript labels of the precise
output form.
a_tuple : tuple of tvm.te.Tensor
These are the Tensors for the operation.
The only difference of einsum between in tvm and numpy is it needs an extra brackets
for the tensors. For example, topi.einsum("ij, jk -> ik", (A, B)).
Returns
-------
out : tvm.te.Tensor
The calculation based on the Einstein summation convention.
"""

if not isinstance(a_tuple, tuple):
a_tuple = (a_tuple,)
return cpp.einsum(subscripts, a_tuple)
8 changes: 0 additions & 8 deletions python/tvm/topi/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -538,14 +538,6 @@ def tensordot(a, b, axes):
return cpp.tensordot(a, b, axes[0], axes[1])


def einsum(subscripts, a_tuple):
# Grab non-einsum kwargs; do not optimize by default.
# Temporally not implement the optimized einsum
if not isinstance(a_tuple, tuple):
a_tuple = (a_tuple,)
return cpp.einsum(subscripts, a_tuple)


def arange(start, stop=None, step=1, dtype="float32"):
"""Creates a tensor with evenly spaced values within a given interval.
Expand Down
2 changes: 1 addition & 1 deletion tests/python/topi/python/test_topi_einsum.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
import tvm.testing
from tvm import te
from tvm import topi
from tvm.topi.util import get_const_tuple
from tvm.topi.utils import get_const_tuple


def with_tvm(lam, *args):
Expand Down

0 comments on commit 71916d6

Please sign in to comment.