This repository has been archived by the owner on Nov 17, 2023. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 6.8k
/
transposecsr_lib.cc
204 lines (181 loc) · 7.2 KB
/
transposecsr_lib.cc
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
 * \file transposecsr_lib.cc
 * \brief Sample custom operator that transposes a 2D CSR sparse matrix,
 *        registered in both a stateless and a stateful form.
*/
#include <cstring>
#include <iostream>
#include <stdexcept>
#include <string>
#include <utility>
#include "mxnet/lib_api.h"
using namespace mxnet::ext;
/*!
 * \brief Transpose a 2D CSR matrix held in `src` into CSR matrix `dst`.
 *
 * Counting-sort style algorithm:
 *  1. count the non-zeros in every column of `src`,
 *  2. prefix-sum those counts into the row offsets of the transposed matrix,
 *  3. scatter each (row i, col c, value v) of `src` to (row c, col i, v) of `dst`.
 * Source rows are visited in increasing order, so the column indices inside
 * each output row come out sorted.
 *
 * Only inputs whose storage type is kCSRStorage are handled; for any other
 * storage type nothing is written to `dst`. Values are read and written as float.
 *
 * \param src input tensor of shape (h, w).
 * \param dst output tensor; its sparse buffers are allocated through `res`.
 * \param res resource used to allocate `dst`'s sparse buffers.
 */
void transpose(MXTensor& src, MXTensor& dst, const OpResource& res) {
  MXSparse* A = src.data<MXSparse>();
  MXSparse* B = dst.data<MXSparse>();
  const auto& shape = src.shape;
  const int64_t h = shape[0];
  const int64_t w = shape[1];
  if (src.stype != kCSRStorage) {
    return;
  }

  const float* Aval = static_cast<const float*>(A->data);

  // rowPtr has w + 2 slots. Column counts are stored shifted by two so that,
  // after the prefix sum, rowPtr[c + 1] is the first output slot of
  // transposed row c. The scatter loop bumps rowPtr[c + 1] once per element
  // of column c, which leaves rowPtr[0..w] equal to the final indptr of `dst`.
  std::vector<int64_t> rowPtr(static_cast<std::size_t>(w + 2), 0);

  // Count the non-zeros of every source column.
  for (int64_t i = 0; i < A->data_len; ++i) {
    ++rowPtr[A->indices[i] + 2];
  }
  // Accumulate. Afterwards rowPtr[1:w+2) holds the start offsets of the
  // transposed rows, shifted by one.
  for (std::size_t i = 2; i < rowPtr.size(); ++i) {
    rowPtr[i] += rowPtr[i - 1];
  }

  // Allocate dst's sparse buffers; 0 is the index of B in the output vector.
  // NOTE(review): the last two arguments are assumed to be (number of
  // non-zeros, indptr length) — confirm against lib_api.h.
  res.alloc_sparse(B, 0, A->data_len, w + 1);
  float* Bval = static_cast<float*>(B->data);

  // Scatter: use int64_t throughout so matrices with more than 2^31
  // non-zeros do not overflow the position counters.
  for (int64_t i = 0; i < h; ++i) {
    for (int64_t j = A->indptr[i]; j < A->indptr[i + 1]; ++j) {
      const int64_t dstPos = rowPtr[A->indices[j] + 1]++;
      Bval[dstPos] = Aval[j];
      B->indices[dstPos] = i;
    }
  }
  std::memcpy(B->indptr, rowPtr.data(), sizeof(int64_t) * static_cast<std::size_t>(w + 1));
}
/*!
 * \brief Forward pass of my_transposecsr: writes the transpose of input 0
 *        into output 0.
 * \param attrs operator attributes (unused).
 * \param inputs expected to hold exactly one tensor.
 * \param outputs expected to hold exactly one tensor.
 * \param res resource used to allocate the sparse output.
 * \return MX_FAIL if a tensor is missing or the input and output differ in
 *         dtype or storage type; MX_SUCCESS otherwise.
 */
MXReturnValue forward(const std::unordered_map<std::string, std::string>& attrs,
                      std::vector<MXTensor>* inputs,
                      std::vector<MXTensor>* outputs,
                      const OpResource& res) {
  // Report a missing tensor through MXNet rather than letting at(0) throw
  // std::out_of_range out of the library callback.
  if (inputs->empty() || outputs->empty()) {
    MX_ERROR_MSG << "Expected 1 input and 1 output to forward";
    return MX_FAIL;
  }
  MXTensor& in = inputs->at(0);
  MXTensor& out = outputs->at(0);
  // The data types and storage types of inputs and outputs should be the same.
  if (in.dtype != out.dtype || in.stype != out.stype) {
    MX_ERROR_MSG << "Error! Expected all inputs and outputs to be the same type."
                 << " Found input storage type:" << in.stype
                 << " Found output storage type:" << out.stype
                 << " Found input data type:" << in.dtype
                 << " Found output data type:" << out.dtype;
    return MX_FAIL;
  }
  transpose(in, out, res);
  return MX_SUCCESS;
}
/*!
 * \brief Backward pass placeholder for my_transposecsr.
 *
 * Performs no computation and writes nothing to `outputs`; it only reports
 * success.
 */
MXReturnValue backward(const std::unordered_map<std::string, std::string>& attrs,
                       std::vector<MXTensor>* inputs,
                       std::vector<MXTensor>* outputs,
                       const OpResource& res) {
  // NOTE(review): gradients are not propagated. If training through this op
  // is required, the input gradient would presumably be the transpose of the
  // output gradient — confirm the backward tensor layout in lib_api.h first.
  const MXReturnValue status = MX_SUCCESS;
  return status;
}
/*!
 * \brief Declares the operator's arity: one input tensor and one output
 *        tensor. No attributes are inspected.
 */
MXReturnValue parseAttrs(const std::unordered_map<std::string, std::string>& attrs,
                         int* num_in,
                         int* num_out) {
  constexpr int kNumInputs = 1;
  constexpr int kNumOutputs = 1;
  *num_out = kNumOutputs;
  *num_in = kNumInputs;
  return MX_SUCCESS;
}
/*!
 * \brief Type inference: requires exactly one input of dtype kFloat32 and
 *        gives the output that same dtype.
 * \return MX_FAIL on a wrong input count or dtype, MX_SUCCESS otherwise.
 */
MXReturnValue inferType(const std::unordered_map<std::string, std::string>& attrs,
                        std::vector<int>* intypes,
                        std::vector<int>* outtypes) {
  const bool singleInput = intypes->size() == 1;
  if (!singleInput) {
    MX_ERROR_MSG << "Expected 1 inputs to inferType";
    return MX_FAIL;
  }
  const int inType = intypes->at(0);
  if (inType != kFloat32) {
    MX_ERROR_MSG << "Expected input to have float32 type";
    return MX_FAIL;
  }
  // The transpose neither changes nor converts the element type.
  (*outtypes).at(0) = inType;
  return MX_SUCCESS;
}
/*!
 * \brief Storage-type inference: requires exactly one input stored as
 *        kCSRStorage and gives the output the same storage type.
 * \return MX_FAIL on a wrong input count or storage type, MX_SUCCESS otherwise.
 */
MXReturnValue inferSType(const std::unordered_map<std::string, std::string>& attrs,
                         std::vector<int>* instypes,
                         std::vector<int>* outstypes) {
  // validate inputs (mirrors inferType/inferShape; also keeps at(0) from
  // throwing on an empty vector)
  if (instypes->size() != 1) {
    MX_ERROR_MSG << "Expected 1 inputs to inferSType";
    return MX_FAIL;
  }
  if (instypes->at(0) != kCSRStorage) {
    MX_ERROR_MSG << "Expected storage type is kCSRStorage";
    return MX_FAIL;
  }
  outstypes->at(0) = instypes->at(0);
  return MX_SUCCESS;
}
/*!
 * \brief Shape inference: for a single 2D input of shape (h, w) the output
 *        shape is (w, h).
 * \return MX_FAIL if there is not exactly one input or it is not 2D,
 *         MX_SUCCESS otherwise.
 */
MXReturnValue inferShape(const std::unordered_map<std::string, std::string>& attrs,
                         std::vector<std::vector<unsigned int>>* inshapes,
                         std::vector<std::vector<unsigned int>>* outshapes) {
  // validate inputs
  if (inshapes->size() != 1) {
    MX_ERROR_MSG << "Expected 1 inputs to inferShape";
    return MX_FAIL;
  }
  const std::vector<unsigned int>& inshape = inshapes->at(0);
  // Indexing [1] on a rank-1 shape would read out of bounds.
  if (inshape.size() != 2) {
    MX_ERROR_MSG << "Expected input to be a 2D matrix, got " << inshape.size() << " dims";
    return MX_FAIL;
  }
  // Swap the two dimensions. Assign instead of push_back so the result is
  // exactly 2D even if the output shape vector was not empty on entry.
  outshapes->at(0) = {inshape[1], inshape[0]};
  return MX_SUCCESS;
}
// Register the stateless operator "my_transposecsr": CPU forward/backward
// functions plus attribute parsing and dtype, storage-type and shape inference.
REGISTER_OP(my_transposecsr)
.setForward(forward, "cpu")
.setBackward(backward, "cpu")
.setParseAttrs(parseAttrs)
.setInferType(inferType)
.setInferSType(inferSType)
.setInferShape(inferShape);
/* ------------------------------------------------------------------------- */
/*!
 * \brief Stateful variant of the CSR transpose operator.
 *
 * Holds a counter (initial value given to the constructor) that is
 * incremented and printed on every Forward call. The actual computation is
 * delegated to the stateless forward/backward functions, called with the
 * attributes captured at construction.
 */
class MyStatefulTransposeCSR : public CustomStatefulOp {
 public:
  explicit MyStatefulTransposeCSR(int count, std::unordered_map<std::string, std::string> attrs)
    : count_(count), attrs_(std::move(attrs)) {}

  MXReturnValue Forward(std::vector<MXTensor>* inputs,
                        std::vector<MXTensor>* outputs,
                        const OpResource& op_res) override {
    count_ += 1;
    std::cout << "Info: keyword + number of forward: " << count_ << std::endl;
    return forward(attrs_, inputs, outputs, op_res);
  }

  MXReturnValue Backward(std::vector<MXTensor>* inputs,
                         std::vector<MXTensor>* outputs,
                         const OpResource& op_res) override {
    return backward(attrs_, inputs, outputs, op_res);
  }

 private:
  int count_;  // number of Forward calls, offset by the initial value
  const std::unordered_map<std::string, std::string> attrs_;  // attributes forwarded to forward/backward
};
/*!
 * \brief Factory for my_state_transposecsr operator instances.
 *
 * Parses the optional integer attribute `test_kw` (0 when absent) as the
 * initial forward counter and returns a newly allocated
 * MyStatefulTransposeCSR through `op_inst`.
 * \return MX_FAIL if `test_kw` is present but not a parseable int,
 *         MX_SUCCESS otherwise.
 */
MXReturnValue createOpState(const std::unordered_map<std::string, std::string>& attrs,
                            const MXContext& ctx,
                            const std::vector<std::vector<unsigned int>>& in_shapes,
                            const std::vector<int> in_types,
                            CustomStatefulOp** op_inst) {
  // testing passing of keyword arguments
  int count = 0;
  const auto kw = attrs.find("test_kw");
  if (kw != attrs.end()) {
    // std::stoi throws on non-numeric or out-of-range text; report that
    // through MXNet instead of letting the exception escape the library.
    try {
      count = std::stoi(kw->second);
    } catch (const std::invalid_argument&) {
      MX_ERROR_MSG << "Expected integer value for test_kw, got '" << kw->second << "'";
      return MX_FAIL;
    } catch (const std::out_of_range&) {
      MX_ERROR_MSG << "Value of test_kw is out of int range: '" << kw->second << "'";
      return MX_FAIL;
    }
  }
  // creating stateful operator instance
  // NOTE(review): the raw allocation is presumably owned and freed by MXNet
  // once handed over through op_inst — confirm in lib_api.h.
  *op_inst = new MyStatefulTransposeCSR(count, attrs);
  std::cout << "Info: stateful operator created" << std::endl;
  return MX_SUCCESS;
}
// Register the stateful operator "my_state_transposecsr": it shares the
// attribute parsing and inference functions with my_transposecsr, and its CPU
// instances are created by createOpState instead of plain forward/backward.
REGISTER_OP(my_state_transposecsr)
.setParseAttrs(parseAttrs)
.setInferType(inferType)
.setInferSType(inferSType)
.setInferShape(inferShape)
.setCreateOpState(createOpState, "cpu");
/*!
 * \brief Library initialization hook: accepts MXNet version numbers of 10700
 *        or higher (presumably MXNet 1.7.0+).
 * \param version integer-encoded MXNet version supplied by the loader.
 * \return MX_SUCCESS when supported; otherwise MX_FAIL with an error message.
 */
MXReturnValue initialize(int version) {
  constexpr int kMinSupportedVersion = 10700;
  if (version < kMinSupportedVersion) {
    MX_ERROR_MSG << "MXNet version " << version << " not supported";
    return MX_FAIL;
  }
  std::cout << "MXNet version " << version << " supported" << std::endl;
  return MX_SUCCESS;
}