// NamedTensorUtils.cpp (forked from pytorch/pytorch)

#include <ATen/NamedTensorUtils.h>
#include <ATen/TensorNames.h>
#include <ATen/WrapDimUtilsMulti.h>

#include <bitset>
#include <sstream>

namespace at {

// Returns "Tensor['N', 'C', 'H', 'W']" for a tensor with names ('N', 'C', 'H', 'W').
static std::string toDimnameRepr(const Tensor& tensor) {
  std::ostringstream os;
  os << "Tensor" << tensor.names();
  return os.str();
}

int64_t dimname_to_position(const Tensor& tensor, Dimname dim) {
  TORCH_CHECK(dim.type() != NameType::WILDCARD,
      "Please look up dimensions by name, got: name = None.");
  TORCH_CHECK(tensor.has_names(),
      "Name ", dim, " not found in ", toDimnameRepr(tensor), ".");
  const auto names = tensor.names();

  const auto it = std::find(names.begin(), names.end(), dim);
  TORCH_CHECK(it != names.end(),
      "Name ", dim, " not found in ", toDimnameRepr(tensor), ".");

  return std::distance(names.begin(), it);
}

std::vector<int64_t> dimnames_to_positions(const Tensor& tensor, DimnameList dims) {
  std::vector<int64_t> result;
  result.reserve(dims.size());
  for (const auto& name : dims) {
    result.push_back(dimname_to_position(tensor, name));
  }
  return result;
}

static void report_positional_error(
    const Dimname& name,
    const Dimname& other_name,
    DimnameList names,
    DimnameList other_names,
    const char* action) {
  // TODO(zou3519): Can improve message by checking if names are alignable and suggesting workarounds
  TORCH_CHECK(false,
      "Error when attempting to ", action, " dims ", names, " and dims ",
      other_names, ": dim ", name, " and dim ", other_name, " are at the same position "
      "from the right but do not match.");
}

static void check_for_misalignment(
    const Dimname& name,
    DimnameList names,
    DimnameList other_names,
    const char* action) {
  if (name.isWildcard()) {
    return;
  }
  auto it = std::find(other_names.begin(), other_names.end(), name);
  // TODO(zou3519): Can improve message by checking if names are alignable and suggesting workarounds
  TORCH_CHECK(it == other_names.end(),
      "Misaligned dims when attempting to ", action, " dims ", names, " and dims ",
      other_names, ": dim ", name, " appears in a different position from the right "
      "across both lists.");
}

// Assumption: A DimnameList can have no duplicate full names with
// the exception of wildcards
std::vector<Dimname> unify_from_right(
    DimnameList names,
    DimnameList other_names,
    const char* action) {
  const auto wildcard = Dimname::wildcard();
  const auto size = std::max(names.size(), other_names.size());
  auto result = std::vector<Dimname>(size, wildcard);

  auto names_it = names.rbegin();
  auto other_it = other_names.rbegin();
  auto result_it = result.rbegin();
  while (names_it != names.rend() || other_it != other_names.rend()) {
    const auto& name = names_it == names.rend() ? wildcard : *names_it;
    const auto& other_name = other_it == other_names.rend() ? wildcard : *other_it;

    // Step 1: Check that the names match
    const auto maybeName = name.unify(other_name);
    if (!maybeName) {
      report_positional_error(name, other_name, names, other_names, action);
    }
    *result_it = *maybeName;

    // Step 2: Check that the names are not misaligned
    if (!name.isBasic() || !other_name.isBasic()) {
      // Let: N = max(len(names), len(other_names))
      //      K = # of special names among names and other_names.
      // This search (including the outer loop) is O(N*K) but typically # of dims is small.
      check_for_misalignment(name, names, other_names, action);
      check_for_misalignment(other_name, other_names, names, action);
    }

    if (names_it != names.rend()) {
      ++names_it;
    }
    if (other_it != other_names.rend()) {
      ++other_it;
    }
    ++result_it;
  }
  return result;
}
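
// Illustrative examples (not part of the original source): names are matched
// up from the right, and a missing dim on the shorter side unifies with
// anything as a wildcard (None):
//
//   unify_from_right(['N', None], [None, 'C'], "add") -> ['N', 'C']
//   unify_from_right(['C'],       ['N', 'C'],  "add") -> ['N', 'C']
//   unify_from_right(['N', 'C'],  ['C', 'N'],  "add") -> error: the dims at
//     matching positions from the right do not match (and are misaligned).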

namespace namedinference {

static std::bitset<dim_bitset_size>
compute_included_idxs(IntArrayRef excluded_idxs, int64_t ndims) {
  auto result = dim_list_to_bitset(excluded_idxs, ndims);
  result.flip();
  return result;
}
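
// For example (illustrative): with ndims = 4 and excluded_idxs = {0, 2},
// dim_list_to_bitset sets bits 0 and 2; flipping marks dims 1 and 3 as
// included. Callers only inspect the low `ndims` bits, so the flipped high
// bits are harmless.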

static void assert_names_equal(DimnameList a, DimnameList b) {
  TORCH_CHECK(a == b,
      "Name mismatch: the specified out tensor has names ", a,
      ", which are not the same as the computed output names ", b,
      ". Please rename the out tensor's dims with `Tensor.rename`.");
}

Tensor& propagate_names_if_nonempty(Tensor& result,
    DimnameList maybe_names,
    bool validate_names) {
  propagate_names_if_nonempty(result.unsafeGetTensorImpl(), maybe_names, validate_names);
  return result;
}

TensorImpl* propagate_names_if_nonempty(TensorImpl* result,
    DimnameList maybe_names,
    bool validate_names) {
  if (maybe_names.empty()) {
    return result;
  }
  return propagate_names(result, maybe_names, validate_names);
}

Tensor& propagate_names(Tensor& result, DimnameList names, bool validate_names) {
  propagate_names(result.unsafeGetTensorImpl(), names, validate_names);
  return result;
}

TensorImpl* propagate_names(TensorImpl* result, DimnameList names, bool validate_names) {
  if (result->dim() > 0) {
    TORCH_INTERNAL_ASSERT(
        !names.empty(),
        "propagate_names: passed in empty names to propagate to result with",
        " shape ", result->sizes(), ". Empty names means that name inference did",
        " not occur; use `propagate_names_if_nonempty` instead of `propagate_names`.");
  }
  if (!impl::has_names(result)) {
    impl::internal_set_names_inplace(result, names, validate_names);
  } else {
    assert_names_equal(impl::get_names(result), names);
  }
  return result;
}

void propagate_names_except(Tensor& result, const Tensor& src, IntArrayRef excluded_idxs) {
  if (!result.has_names() && !src.has_names()) {
    return;
  }
  auto src_names = src.names();
  auto result_dim = result.dim();
  auto src_dim = src_names.size();
  TORCH_INTERNAL_ASSERT(src_dim - excluded_idxs.size() == static_cast<size_t>(result_dim));

  // fast path
  if (excluded_idxs.size() == 1) {
    std::vector<Dimname> outnames = src_names.vec();
    outnames.erase(outnames.begin() + maybe_wrap_dim(excluded_idxs[0], src_dim));
    propagate_names(result, outnames);
    return;
  }

  std::vector<Dimname> outnames;
  outnames.reserve(result_dim);
  auto included_idxs = compute_included_idxs(excluded_idxs, src_dim);
  for (size_t dim = 0; dim < src_dim; ++dim) {
    if (included_idxs[dim]) {
      outnames.push_back(src_names[dim]);
    }
  }
  propagate_names(result, outnames);
}
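
// Illustrative example (not in the original source): if `src` is named
// ('N', 'C', 'H', 'W') and excluded_idxs = {1}, the result is named
// ('N', 'H', 'W'). This is the rule for ops that drop dimensions, e.g.
// non-keepdim reductions, which reach here via propagate_names_for_reduction.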

void propagate_names_for_reduction(Tensor& result, const Tensor& src, IntArrayRef reduced_dims, bool keepdim) {
  if (keepdim) {
    propagate_names(result, src);
    return;
  }
  // This actually means "full reduction"
  if (reduced_dims.empty()) {
    return;
  }
  propagate_names_except(result, src, reduced_dims);
}
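
// Sketch of the rule (illustrative): reducing a tensor named
// ('N', 'C', 'H', 'W') over dim 1 yields ('N', 'H', 'W') with keepdim=false
// and ('N', 'C', 'H', 'W') with keepdim=true. An empty `reduced_dims` denotes
// a full reduction to a 0-dim tensor, which carries no names.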

void propagate_names(Tensor& result, const Tensor& src) {
  propagate_names(result.unsafeGetTensorImpl(), src.unsafeGetTensorImpl());
}

void propagate_names(TensorImpl* result, TensorImpl* src) {
  if (result == src) {
    return;
  }
  if (!impl::has_names(result) && !impl::has_names(src)) {
    return;
  }
  propagate_names(result, impl::get_names(src));
}

std::vector<Dimname> compute_squeeze_outnames(const Tensor& tensor) {
  if (!tensor.has_names()) {
    return {};
  }
  std::vector<Dimname> outnames;
  auto tensor_names = tensor.names();
  for (int64_t d = 0; d < tensor.dim(); d++) {
    if (tensor.sizes()[d] != 1) {
      outnames.push_back(tensor_names[d]);
    }
  }
  return outnames;
}
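
// For example (illustrative): a tensor named ('N', 'C', 'H') with sizes
// (2, 1, 3) squeezes to names ('N', 'H'); the names of size-1 dims are
// dropped along with the dims themselves.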

std::vector<Dimname> compute_diagonal_outnames(
    const Tensor& tensor,
    int64_t dim1,
    int64_t dim2) {
  if (!tensor.has_names()) {
    return {};
  }
  std::vector<Dimname> outnames;
  auto tensor_names = tensor.names();
  for (int64_t d = 0; d < tensor.dim(); d++) {
    if (d == dim1 || d == dim2) {
      continue;
    }
    outnames.push_back(tensor_names[d]);
  }
  outnames.push_back(Dimname::wildcard());
  return outnames;
}
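
// For example (illustrative): diagonal of a tensor named ('N', 'C', 'H', 'W')
// with dim1=2, dim2=3 is named ('N', 'C', None): both input dims are removed
// and the new diagonal dim is appended, unnamed, at the end.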

// tensor_dotted_dim and other_dotted_dim are the dimensions of the two
// tensors that we contract together. Usually other_dotted_dim is 0
// and tensor_dotted_dim is the last dim of tensor, but there are some special
// cases like einsum and tensordot where one can contract arbitrary dims.
static std::vector<Dimname> compute_dot_product_outnames(
    DimnameList tensor_names,
    int64_t tensor_dotted_dim,
    DimnameList other_names,
    int64_t other_dotted_dim) {
  int64_t num_outnames = tensor_names.size() + other_names.size() - 2;
  if (num_outnames == 0) {
    return {};
  }
  std::vector<Dimname> outnames(num_outnames, Dimname::wildcard());
  int64_t index = 0;
  for (int64_t j = 0; j < static_cast<int64_t>(tensor_names.size()); ++j) {
    if (j == tensor_dotted_dim) continue;
    outnames[index++] = tensor_names[j];
  }
  for (int64_t j = 0; j < static_cast<int64_t>(other_names.size()); ++j) {
    if (j == other_dotted_dim) continue;
    outnames[index++] = other_names[j];
  }
  return outnames;
}
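
// For example (illustrative): a matrix-vector product contracts the last dim
// of the matrix with dim 0 of the vector, so
//   compute_dot_product_outnames(['N', 'C'], 1, ['C'], 0) -> ['N'].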

static void check_feature_names_are_distinct(
    DimnameList self_names,
    DimnameList other_names,
    DimnameList outnames) {
  if (self_names.size() < 2 || other_names.size() < 2) {
    // There are fewer than 2 feature dims in outnames, so there is nothing to check
    return;
  }
  auto feature0 = outnames[outnames.size() - 2];
  auto feature1 = outnames[outnames.size() - 1];
  TORCH_CHECK(
      feature0 == Dimname::wildcard() || feature0 != feature1,
      "Matrix multiplying Tensor", self_names,
      " with Tensor", other_names,
      " would produce output tensor with duplicate names ",
      outnames,
      ". Please rename the input tensors with `Tensor.rename` to prevent this.");
}

static DimnameList batch_dims(DimnameList names) {
  if (names.size() <= 2) {
    return {};
  }
  return DimnameList(names.begin(), names.end() - 2);
}

static DimnameList feature_dims(DimnameList names) {
  if (names.size() <= 2) {
    return names;
  }
  return DimnameList(names.end() - 2, 2);
}

static bool are_distinct(DimnameList batch_dims, DimnameList feature_dims) {
  for (const auto& target : feature_dims) {
    if (target.isWildcard()) {
      continue;
    }
    if (std::any_of(batch_dims.begin(), batch_dims.end(),
          [&](const Dimname& dim) { return target == dim; })) {
      return false;
    }
  }
  return true;
}

static int64_t num_batch_dims(DimnameList names) {
  if (names.size() <= 2) {
    return 0;
  }
  return names.size() - 2;
}
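
// Taken together (illustrative): for names ('A', 'B', 'M', 'K'), batch_dims
// returns ('A', 'B'), feature_dims returns ('M', 'K'), and num_batch_dims
// returns 2. The last two dims are the "feature" (matrix) dims of a batched
// matmul operand; everything before them is batch.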

static std::vector<Dimname> compute_matmul_outnames(
    DimnameList self_names,
    DimnameList other_names) {
  TORCH_CHECK(self_names.size() >= 1 && other_names.size() >= 1,
      "both arguments to matmul need to be at least 1D, but they are ",
      self_names.size(), "D and ", other_names.size(), "D");

  // matmul performs a batch matrix multiply between self and other, each of which
  // can either be:
  // - a batch of matrices (if dim > 2)
  // - a matrix (if dim == 2)
  // - a vector (if dim == 1)
  //
  // To compute output names, we unify the batch dimensions (which are
  // broadcastable) to get the output batch dimensions.
  //
  // After that, we append some names that are equal to the result of the matmul
  // without batch dimensions. Those names are computed by removing the names
  // of the dimensions that were contracted away. We always contract the
  // last dim of the first tensor with the first feature dimension of the second.

  // Get the output's batch dimension names
  auto wrapped_self_names = TensorNames(self_names, 0, num_batch_dims(self_names));
  const auto wrapped_other_names = TensorNames(other_names, 0, num_batch_dims(other_names));
  auto& working_names = wrapped_self_names.unifyFromRightInplace(wrapped_other_names, "matmul");

  // Append the result of each individual (non-batched) matmul.
  // If either self or other has dim 1, it is a vector. Vectors get
  // completely contracted away during matmul, so we don't take any names from them.
  if (self_names.size() >= 2) {
    working_names.append(TensorName(self_names, -2));
  }
  if (other_names.size() >= 2) {
    working_names.append(TensorName(other_names, -1));
  }
  const auto result = working_names.toDimnameVec();

  check_feature_names_are_distinct(self_names, other_names, result);
  return result;
}
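
// Worked example (illustrative): matmul of Tensor['A', 'M', 'K'] with
// Tensor['A', 'K', 'N'] unifies the batch dims ('A') with ('A'), then appends
// self's dim -2 ('M') and other's dim -1 ('N'), giving ('A', 'M', 'N'). The
// contracted 'K' dims never reach the output.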

void propagate_names_for_addmv(
    Tensor& result,
    const Tensor& mat,
    const Tensor& vec,
    const Tensor& bias) {
  if (!result.has_names() && !mat.has_names() &&
      !vec.has_names() && !bias.has_names()) {
    return;
  }
  auto mv_outnames = compute_matmul_outnames(mat.names(), vec.names());
  auto add_outnames = unify_from_right(mv_outnames, bias.names());
  propagate_names(result, add_outnames);
}

void propagate_names_for_addmm(
    Tensor& result,
    const Tensor& m1,
    const Tensor& m2,
    const Tensor& bias) {
  if (!m1.has_names() && !m2.has_names() &&
      !bias.has_names() && !result.has_names()) {
    return;
  }
  auto mm_outnames = compute_matmul_outnames(m1.names(), m2.names());
  auto add_outnames = unify_from_right(mm_outnames, bias.names());
  propagate_names(result, add_outnames);
}

void check_names_for_dot(
    TensorImpl* vec1,
    TensorImpl* vec2) {
  if (!impl::has_names(vec1) && !impl::has_names(vec2)) {
    return;
  }
  compute_matmul_outnames(impl::get_names(vec1), impl::get_names(vec2));
}

// expand adds new None dimensions. This is consistent with name inference
// rules for binary ops that expect the named dims to line up positionally
// from the right. i.e.,
// Tensor[H, W].expand(3, 3, 3, 3) -> Tensor[None, None, H, W]
void propagate_names_for_expand(Tensor& result, const Tensor& self) {
  if (!self.has_names()) {
    return;
  }
  auto result_dim = result.dim();
  if (self.dim() == result_dim) {
    propagate_names(result, self);
    return;
  }
  std::vector<Dimname> outnames(result_dim, Dimname::wildcard());
  std::copy(
      self.opt_names()->begin(),
      self.opt_names()->end(),
      outnames.begin() + result_dim - self.dim());
  propagate_names(result, outnames);
}

std::vector<Dimname> compute_broadcast_outnames(
    const Tensor& self,
    const Tensor& other) {
  if (!self.has_names() && !other.has_names()) {
    return {};
  }
  return unify_from_right(self.names(), other.names());
}

std::vector<Dimname> broadcast_to_outnames(
    const Tensor& tensor,
    const Tensor& reference_tensor,
    const char* op_name) {
  if (!tensor.has_names() && !reference_tensor.has_names()) {
    return {};
  }
  auto reference_names = reference_tensor.names();
  auto tensor_names = tensor.names();
  TORCH_CHECK(
      reference_names.size() >= tensor_names.size(),
      op_name, ": attempted to broadcast Tensor", tensor_names, " to Tensor",
      reference_names, " but the number of dims (", tensor_names.size(),
      ") must be less than or equal to the number of dims in the reference tensor (",
      reference_names.size(), ")");
  return unify_from_right(reference_names, tensor_names);
}

std::vector<Dimname> compute_cat_outnames(TensorList tensors) {
  if (!at::has_names(tensors)) {
    return {};
  }
  std::vector<Dimname> result;
  for (const auto& tensor : tensors) {
    const auto tensor_names = tensor.names();
    TORCH_CHECK(!tensor_names.empty(), "zero-dimensional tensor cannot be concatenated");
    TORCH_CHECK(result.empty() || tensor_names.size() == result.size(),
        "Tensors must have same number of dimensions: got ", result.size(),
        " and ", tensor_names.size());
    result = unify_from_right(result, tensor_names, "cat");
  }
  return result;
}
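
// For example (illustrative): cat of tensors named ('N', 'C') and (None, 'C')
// produces names ('N', 'C'); each tensor's names are unified from the right
// into the running result.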

std::vector<Dimname> compute_matmul_outnames(
    const Tensor& self,
    const Tensor& other) {
  if (!self.has_names() && !other.has_names()) {
    return {};
  }
  return compute_matmul_outnames(self.names(), other.names());
}

std::vector<Dimname> compute_cdist_outnames(
    const Tensor& self,
    const Tensor& other) {
  if (!self.has_names() && !other.has_names()) {
    return {};
  }
  const auto self_names = self.names();
  const auto other_names = other.names();

  auto self_batch = TensorNames(self_names, 0, num_batch_dims(self_names));
  const auto other_batch = TensorNames(other_names, 0, num_batch_dims(other_names));
  auto& result = self_batch.unifyFromRightInplace(other_batch, "cdist");

  // cdist treats self and other like batches of M x D and N x D tensors, respectively.
  // It computes the pairwise distance between each of the M vectors (of size D)
  // in `self` and each of the N vectors in `other`, returning a batch of M x N
  // distance values. We propagate the names of the dimension of size M (in self)
  // and the dimension of size N (in other), both of which are second-from-last.
  result.append(TensorName(self_names, -2));
  result.append(TensorName(other_names, -2));
  result.checkUnique("cdist");

  return result.toDimnameVec();
}
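
// For example (illustrative): cdist between Tensor['B', 'M', 'D'] and
// Tensor['B', 'N', 'D'] yields Tensor['B', 'M', 'N']; the shared feature dim
// 'D' is consumed by the distance computation.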

std::vector<Dimname> compute_bmm_outnames(
    Tensor& result,
    const Tensor& self,
    const Tensor& other) {
  if (!result.has_names() && !self.has_names() && !other.has_names()) {
    return {};
  }
  return compute_matmul_outnames(self.names(), other.names());
}

std::vector<Dimname> compute_baddbmm_outnames(
    TensorImpl* result,
    TensorImpl* batch1,
    TensorImpl* batch2,
    TensorImpl* bias) {
  if (!impl::has_names(result) && !impl::has_names(batch1) &&
      !impl::has_names(batch2) && !impl::has_names(bias)) {
    return {};
  }
  auto bmm_names = compute_matmul_outnames(
      impl::get_names(batch1), impl::get_names(batch2));
  auto baddbmm_names = unify_from_right(impl::get_names(bias), bmm_names);
  return baddbmm_names;
}
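
// For example (illustrative): baddbmm with batch1 named ('B', 'M', 'K'),
// batch2 named ('B', 'K', 'N'), and bias named (None, 'M', 'N') first computes
// the bmm names ('B', 'M', 'N'), then unifies them with the bias names from
// the right, producing ('B', 'M', 'N').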

bool are_names_equal(TensorImpl* self, TensorImpl* other) {
  if (!impl::has_names(self) && !impl::has_names(other)) {
    return true;
  }
  return impl::get_names(self) == impl::get_names(other);
}

} // namespace namedinference
} // namespace at