This repository has been archived by the owner on Nov 17, 2023. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 6.8k
/
lib_api.h
2010 lines (1815 loc) · 76.1 KB
/
lib_api.h
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* Copyright (c) 2019 by Contributors
* \file lib_api.h
* \brief APIs to interact with libraries
* This API specifies function prototypes to
* register custom ops, partitioner, and passes
* for library authors
* See example/extension/lib_custom_op/README.md
* See example/extension/lib_subgraph/README.md
* See example/extension/lib_pass/README.md
*/
#ifndef MXNET_LIB_API_H_
#define MXNET_LIB_API_H_
#include <stdint.h>
#include <stdlib.h>
#include <string.h>
#include <vector>
#include <map>
#include <unordered_map>
#include <string>
#include <iostream>
#include <utility>
#include <stdexcept>
#include <random>
#if defined(__NVCC__)
#include <curand_kernel.h>
#endif
/* Make sure to update the version number every time you make changes */
#define MX_LIBRARY_VERSION 7
/*!
* \brief For loading multiple custom op libraries in Linux, exporting same symbol multiple
* times may lead to undefined behaviour, so we need to set symbol visibility to hidden
* see https://labjack.com/news/simple-cpp-symbol-visibility-demo for details
*/
#if defined(_WIN32) || defined(_WIN64) || defined(__WINDOWS__)
#define PRIVATE_SYMBOL
#else
#define PRIVATE_SYMBOL __attribute__ ((visibility ("hidden")))
#endif
/*
* Import from DLPack https://github.com/dmlc/dlpack/blob/master/include/dlpack/dlpack.h
*/
#ifndef DLPACK_VERSION
#ifdef __cplusplus
#define DLPACK_EXTERN_C extern "C"
#else
#define DLPACK_EXTERN_C
#endif
/*! \brief The current version of dlpack */
#define DLPACK_VERSION 020
/*! \brief DLPACK_DLL prefix for windows */
#ifdef _WIN32
#ifdef DLPACK_EXPORTS
#define DLPACK_DLL __declspec(dllexport)
#else
#define DLPACK_DLL __declspec(dllimport)
#endif
#else
#define DLPACK_DLL
#endif
#include <stdint.h>
#include <stddef.h>
#ifdef __cplusplus
extern "C" {
#endif
/*!
* \brief The device type in DLContext.
* Vendored from DLPack (see the link above); keep the enumerator values in sync
* with upstream DLPack.
*/
typedef enum {
/*! \brief CPU device */
kDLCPU = 1,
/*! \brief CUDA GPU device */
kDLGPU = 2,
/*!
* \brief Pinned CUDA GPU device by cudaMallocHost
* \note kDLCPUPinned = kDLCPU | kDLGPU
*/
kDLCPUPinned = 3,
/*! \brief OpenCL devices. */
kDLOpenCL = 4,
/*! \brief Vulkan buffer for next generation graphics. */
kDLVulkan = 7,
/*! \brief Metal for Apple GPU. */
kDLMetal = 8,
/*! \brief Verilog simulator buffer */
kDLVPI = 9,
/*! \brief ROCm GPUs for AMD GPUs */
kDLROCM = 10,
/*!
* \brief Reserved extension device type,
* used to quickly test an extension device.
* The semantics can differ depending on the implementation.
* MXTensor::setDLTensor maps any unrecognized MXContext::dev_type string to this value.
*/
kDLExtDev = 12,
} DLDeviceType;
/*!
* \brief A Device context for Tensor and operator.
* (Vendored from DLPack.)
*/
typedef struct {
/*! \brief The kind of device (CPU, CUDA GPU, OpenCL, ...). */
DLDeviceType device_type;
/*! \brief The device index among devices of the same type */
int device_id;
} DLContext;
/*!
* \brief The type code options DLDataType.
* (Vendored from DLPack; stored in DLDataType::code.)
*/
typedef enum {
/*! \brief signed integer (e.g. int8 has code 0) */
kDLInt = 0U,
/*! \brief unsigned integer (MXNet's kUint8 maps here) */
kDLUInt = 1U,
/*! \brief floating point (e.g. float has code 2) */
kDLFloat = 2U,
} DLDataTypeCode;
/*!
* \brief The data type the tensor can hold.
* (Vendored from DLPack.)
*
* Examples
* - float: type_code = 2, bits = 32, lanes=1
* - float4(vectorized 4 float): type_code = 2, bits = 32, lanes=4
* - int8: type_code = 0, bits = 8, lanes=1
*/
typedef struct {
/*!
* \brief Type code of base types.
* We keep it uint8_t instead of DLDataTypeCode for minimal memory
* footprint, but the value should be one of DLDataTypeCode enum values.
* */
uint8_t code;
/*!
* \brief Number of bits per lane, common choices are 8, 16, 32, 64.
*/
uint8_t bits;
/*! \brief Number of lanes in the type, used for vector types (1 for scalars). */
uint16_t lanes;
} DLDataType;
/*!
* \brief Plain C Tensor object, does not manage memory.
* (Vendored from DLPack.)
*/
typedef struct {
/*!
* \brief The opaque data pointer points to the allocated data. This will be
* CUDA device pointer or cl_mem handle in OpenCL. This pointer is always
* aligned to 256 bytes as in CUDA.
*
* For given DLTensor, the size of memory required to store the contents of
* data is calculated as follows:
*
* \code{.c}
* static inline size_t GetDataSize(const DLTensor* t) {
* size_t size = 1;
* for (tvm_index_t i = 0; i < t->ndim; ++i) {
* size *= t->shape[i];
* }
* size *= (t->dtype.bits * t->dtype.lanes + 7) / 8;
* return size;
* }
* \endcode
*/
void* data;
/*! \brief The device context of the tensor */
DLContext ctx;
/*! \brief Number of dimensions */
int ndim;
/*! \brief The data type of the pointer*/
DLDataType dtype;
/*! \brief The shape of the tensor (ndim entries; not owned by the DLTensor) */
int64_t* shape;
/*!
* \brief strides of the tensor (in number of elements, not bytes)
* can be nullptr, indicating tensor is compact and row-majored.
*/
int64_t* strides;
/*! \brief The offset in bytes to the beginning pointer to data */
uint64_t byte_offset;
} DLTensor;
#ifdef __cplusplus
} // DLPACK_EXTERN_C
#endif
#endif
/*!
* \brief Tensor data type, consistent with mshadow data type
* The integer values cross the library boundary (e.g. the `int dtype` argument of
* nd_malloc_t), so they must stay in sync with MXNet.
* kUNSET is the dtype of a default-constructed MXTensor, i.e. "not assigned yet".
*/
enum MXDType {
kFloat32 = 0,
kFloat64 = 1,
kFloat16 = 2,
kUint8 = 3,
kInt32 = 4,
kInt8 = 5,
kInt64 = 6,
kUNSET = 100,
};
/*
* MXTensor storage type.
* For the sparse types, MXTensor::data_ptr points to an MXSparse instead of raw data.
*/
enum MXStorageType {
// dense
kDefaultStorage = 0,
// row sparse: only rows containing non-zero values are stored
kRowSparseStorage = 1,
// csr: compressed sparse row (indices + indptr)
kCSRStorage = 2,
};
/*!
 * \brief Context info passing from MXNet OpContext
 * dev_type is string repr of supported context, currently only "cpu" and "gpu"
 * dev_id is the device index where the tensor locates
 * A default-constructed context holds the placeholder values ("error", -1).
 */
struct MXContext {
  MXContext() : dev_type("error"), dev_id(-1) {}
  // Sink parameter: the string is taken by value and moved into place, so callers
  // passing a temporary pay for no copy and callers passing an lvalue pay for one.
  explicit MXContext(std::string dev_type_, int dev_id_)
    : dev_type(std::move(dev_type_)), dev_id(dev_id_) {}
  explicit MXContext(const char* dev_type_, int dev_id_)
    : dev_type(dev_type_), dev_id(dev_id_) {}
  /*! \brief CPU context, device 0 */
  static MXContext CPU() { return MXContext("cpu", 0); }
  /*! \brief GPU context, device 0 */
  static MXContext GPU() { return MXContext("gpu", 0); }
  /*! \brief CPU context for device `dev_id` */
  static MXContext CPU(int dev_id) { return MXContext("cpu", dev_id); }
  /*! \brief GPU context for device `dev_id` */
  static MXContext GPU(int dev_id) { return MXContext("gpu", dev_id); }
  std::string dev_type;
  int dev_id;
};
/*! \brief Status returned by library callbacks: MX_FAIL (0) on error, MX_SUCCESS (1) on success */
enum MXReturnValue {
MX_FAIL = 0,
MX_SUCCESS = 1,
};
// For sparse tensors, read/write the data from NDarray via pointers.
// Every field has a default value, so a freshly constructed MXSparse never exposes
// indeterminate lengths or pointers.
struct MXSparse {
  // Pointer to data.
  void *data = nullptr;
  // length of (non-zero) data.
  int64_t data_len = 0;
  // To store aux data for sparse.
  // For CSR, indices stores the col index of non-zero elements.
  // For row sparse, indices store row index of rows which have non-zero elements.
  int64_t* indices = nullptr;
  int64_t indices_len = 0;
  // For CSR, indptr gives the start and end index of data for each row.
  // For row sparse, indptr is not used and is kept at nullptr / 0.
  int64_t* indptr = nullptr;
  int64_t indptr_len = 0;
  /*!
   * \brief point this MXSparse at existing buffers (no copy, no ownership taken)
   * \param data_ptr    buffer holding the stored values
   * \param dims        dense shape of the tensor (ndims entries)
   * \param ndims       number of entries in dims
   * \param idx         indices buffer (num_idx entries)
   * \param num_idx     number of entries in idx
   * \param idx_ptr     CSR indptr buffer; nullptr selects the row-sparse layout
   * \param num_idx_ptr number of entries in idx_ptr
   */
  void set(void *data_ptr, const int64_t* dims, int ndims, void *idx,
           int64_t num_idx, void *idx_ptr = nullptr, int64_t num_idx_ptr = 0) {
    data = data_ptr;
    indices = static_cast<int64_t*>(idx);
    indices_len = num_idx;
    // If CSR, num of non-zero elements is num_idx,
    // If row sparse, num of elements is num_idx * width.
    data_len = num_idx;
    if (idx_ptr) {
      indptr = static_cast<int64_t*>(idx_ptr);
      indptr_len = num_idx_ptr;
    } else {
      // Row sparse: each stored row holds dims[1] * ... * dims[ndims-1] values.
      for (int i = 1; i < ndims; ++i)
        data_len *= dims[i];
      // Clear any indptr left over from an earlier CSR set() on this object.
      indptr = nullptr;
      indptr_len = 0;
    }
  }
};
/*!
* \brief Tensor data structure used by custom operator
*/
struct MXTensor {
MXTensor() : data_ptr(nullptr), dtype(kUNSET), verID(0), stype(kDefaultStorage) {}
MXTensor(const MXTensor& oth) : data_ptr(oth.data_ptr), shape(oth.shape),
dtype(oth.dtype), verID(oth.verID), ctx(oth.ctx), stype(oth.stype) {
setDLTensor();
}
MXTensor(void *data_ptr, const std::vector<int64_t> &shape, MXDType dtype,
size_t vID, MXContext mx_ctx, MXStorageType stype = kDefaultStorage)
: data_ptr(data_ptr), shape(shape), dtype(dtype), verID(vID), ctx(mx_ctx), stype(stype) {
setDLTensor();
}
/*! \brief populate internal tensor fields */
void setTensor(void *dptr, MXDType type, const int64_t* dims, int ndims,
size_t vID, MXContext mx_ctx, MXStorageType storage_type) {
data_ptr = dptr; dtype = type; verID = vID; ctx = mx_ctx; stype = storage_type;
shape.clear();
for (int j = 0; j < ndims; j++) {
shape.push_back(dims[j]);
}
setDLTensor();
}
/*! \brief populate DLTensor fields */
void setDLTensor() {
dltensor.data = data_ptr;
dltensor.ndim = shape.size();
dltensor.shape = const_cast<int64_t*>(shape.data());
dltensor.strides = nullptr;
dltensor.byte_offset = 0;
dltensor.dtype.lanes = 1;
dltensor.ctx.device_id = ctx.dev_id;
if (ctx.dev_type == "cpu")
dltensor.ctx.device_type = kDLCPU;
else if (ctx.dev_type == "gpu")
dltensor.ctx.device_type = kDLGPU;
else if (ctx.dev_type == "opencl")
dltensor.ctx.device_type = kDLOpenCL;
else if (ctx.dev_type == "vulcan")
dltensor.ctx.device_type = kDLVulkan;
else if (ctx.dev_type == "metal")
dltensor.ctx.device_type = kDLMetal;
else if (ctx.dev_type == "vpi")
dltensor.ctx.device_type = kDLVPI;
else if (ctx.dev_type == "rocm")
dltensor.ctx.device_type = kDLROCM;
else
dltensor.ctx.device_type = kDLExtDev;
switch (dtype) {
case kFloat32:
dltensor.dtype.code = kDLFloat;
dltensor.dtype.bits = 32;
break;
case kFloat64:
dltensor.dtype.code = kDLFloat;
dltensor.dtype.bits = 64;
break;
case kFloat16:
dltensor.dtype.code = kDLFloat;
dltensor.dtype.bits = 16;
break;
case kUint8:
dltensor.dtype.code = kDLUInt;
dltensor.dtype.bits = 8;
break;
case kInt32:
dltensor.dtype.code = kDLInt;
dltensor.dtype.bits = 32;
break;
case kInt8:
dltensor.dtype.code = kDLInt;
dltensor.dtype.bits = 8;
break;
case kInt64:
dltensor.dtype.code = kDLInt;
dltensor.dtype.bits = 64;
break;
default:
dltensor.dtype.code = 0;
dltensor.dtype.bits = 0;
throw std::runtime_error("Error! Invalid dtype flag: "
+ std::to_string(static_cast<int>(dtype))
+ " when constructing MXTensor");
}
}
/*! \brief helper function to cast data pointer */
template<typename data_type>
inline data_type* data() {
return reinterpret_cast<data_type*>(data_ptr);
}
/*! \brief helper function to get data size */
inline int64_t size() const {
int64_t size = 1;
for (unsigned int i = 0; i < shape.size(); i++) {
size *= shape[i];
}
return size;
}
/*! \brief helper function to compare two MXTensors */
inline bool isSame(const MXTensor &oth) const {
return data_ptr == oth.data_ptr &&
dtype == oth.dtype &&
verID == oth.verID &&
ctx.dev_type == oth.ctx.dev_type &&
ctx.dev_id == oth.ctx.dev_id &&
shape == oth.shape &&
stype == oth.stype;
}
// For dense, data_ptr points to 1D flattened tensor data
// For sparse, data_ptr points to MXSparse
void *data_ptr;
// shape is in [2,3,4] format to represent high-dim tensor
std::vector<int64_t> shape;
// type can only be MXDType enum types
MXDType dtype;
// version number updated if the tensor has changed since the last use by custom op
size_t verID;
// context of MXTensor representing which device the tensor data is located
MXContext ctx;
// corresponding DLTensor repr of MXTensor
// easy way to reuse functions taking DLTensor
DLTensor dltensor;
// storage type
MXStorageType stype;
};
/*! \brief resource malloc function to allocate memory inside Forward/Backward functions */
/* Called as xpu_malloc(alloc_handle, size_in_bytes) by OpResource::alloc_cpu/alloc_gpu */
typedef void* (*xpu_malloc_t)(void*, int);
/*! \brief sparse alloc function to allocate memory inside Forward/Backward functions */
/* Called by OpResource::alloc_sparse as
 * sparse_malloc(alloc_handle, index, indices_len, indptr_len, &data, &indices, &indptr);
 * the last three arguments are output pointers filled in by MXNet. */
typedef void (*sparse_malloc_t)(void*, int, int, int, void**, int64_t**, int64_t**);
/*! \brief resource malloc function to allocate ndarrays for graph passes */
/* PassResource passes isArg = 1 for new args and isArg = 0 for new aux states;
 * the allocated buffer is returned through `data`. */
typedef void (*nd_malloc_t)(const void* _ndarray_alloc, const int64_t* shapes, int num_shapes,
const char* dev_str, int dev_id, int dtype, const char* name,
int isArg, void** data);
/*! \brief GPU stream pointer, is void* when not compiled with CUDA */
/* The GPU random-state type likewise degrades to void* outside NVCC builds. */
#if defined(__NVCC__)
typedef cudaStream_t mx_stream_t;
typedef curandStatePhilox4_32_10_t mx_gpu_rand_t;
#else
typedef void* mx_stream_t;
typedef void* mx_gpu_rand_t;
#endif
/*! \brief CPU random-state type (a Mersenne Twister engine) */
typedef std::mt19937 mx_cpu_rand_t;
/*! \brief MXNet initialized random states for each device, used for parallelism */
/* Each thread should generate random number unique sequence out of different states */
/* These are the state counts; valid state indices are 0 .. count-1. */
#define MX_NUM_CPU_RANDOM_STATES 1024
#define MX_NUM_GPU_RANDOM_STATES 32768
/*!
 * \brief Allocation helper handed to graph passes. It allocates NDArrays through
 * MXNet's nd_malloc callback and records them as new args (alloc_arg) or new aux
 * states (alloc_aux) in the maps supplied at construction.
 */
class PassResource {
 public:
  PassResource(std::unordered_map<std::string, MXTensor>* new_args,
               std::unordered_map<std::string, MXTensor>* new_aux,
               nd_malloc_t nd_malloc, const void* nd_alloc)
    : new_args_(new_args), new_aux_(new_aux), nd_malloc_(nd_malloc), nd_alloc_(nd_alloc) {}

  /*!
   * \brief allocate a new argument tensor called `name` and store it in new_args.
   * \return pointer to the stored tensor; it stays valid until that entry is erased
   */
  MXTensor* alloc_arg(const std::string& name, const std::vector<int64_t>& shapes,
                      const MXContext &ctx, MXDType dtype) const {
    return alloc(new_args_, 1, name, shapes, ctx, dtype);
  }

  /*!
   * \brief allocate a new auxiliary-state tensor called `name` and store it in new_aux.
   * \return pointer to the stored tensor; it stays valid until that entry is erased
   */
  MXTensor* alloc_aux(const std::string& name, const std::vector<int64_t>& shapes,
                      const MXContext &ctx, MXDType dtype) const {
    return alloc(new_aux_, 0, name, shapes, ctx, dtype);
  }

 private:
  /*! \brief shared body of alloc_arg/alloc_aux; is_arg is forwarded to nd_malloc */
  MXTensor* alloc(std::unordered_map<std::string, MXTensor>* tensors, int is_arg,
                  const std::string& name, const std::vector<int64_t>& shapes,
                  const MXContext &ctx, MXDType dtype) const {
    void* data = nullptr;
    nd_malloc_(nd_alloc_, shapes.data(), static_cast<int>(shapes.size()),
               ctx.dev_type.c_str(), ctx.dev_id, dtype, name.c_str(), is_arg, &data);
    // Build (and validate the dtype of) the tensor before touching the map.
    MXTensor tensor(data, shapes, dtype, 0, ctx, kDefaultStorage);
    MXTensor& slot = (*tensors)[name];
    slot = tensor;
    // The stored copy's DLTensor view must point at the stored shape vector, not at the
    // local `tensor`, which is destroyed on return.
    slot.setDLTensor();
    return &slot;
  }

  std::unordered_map<std::string, MXTensor>* new_args_;
  std::unordered_map<std::string, MXTensor>* new_aux_;
  nd_malloc_t nd_malloc_;
  const void* nd_alloc_;
};
/*!
 * \brief provide resource APIs memory allocation mechanism to Forward/Backward functions
 *
 * All allocation callbacks, opaque handles, the stream and the RNG states are supplied
 * by MXNet through the constructor; the accessors below forward to them.
 */
class OpResource {
 public:
  OpResource(xpu_malloc_t cpu_malloc_fp, void* cpu_alloc_fp,
             xpu_malloc_t gpu_malloc_fp, void* gpu_alloc_fp, void* stream,
             sparse_malloc_t sparse_malloc_fp, void* sparse_alloc_fp,
             void* rng_cpu_states, void* rng_gpu_states)
    : cpu_malloc(cpu_malloc_fp), gpu_malloc(gpu_malloc_fp),
      cpu_alloc(cpu_alloc_fp), gpu_alloc(gpu_alloc_fp), cuda_stream(stream),
      sparse_malloc(sparse_malloc_fp), sparse_alloc(sparse_alloc_fp),
      rand_cpu_states(rng_cpu_states), rand_gpu_states(rng_gpu_states) {}

  /*! \brief allocate `size` bytes of cpu memory controlled by MXNet */
  void* alloc_cpu(int size) const {
    return cpu_malloc(cpu_alloc, size);
  }

  /*! \brief allocate `size` bytes of gpu memory controlled by MXNet */
  void* alloc_gpu(int size) const {
    return gpu_malloc(gpu_alloc, size);
  }

  /*! \brief return the cuda stream object with correct type */
  mx_stream_t get_cuda_stream() const {
    return static_cast<mx_stream_t>(cuda_stream);
  }

  /*! \brief allocate sparse memory controlled by MXNet; fills sparse->data/indices/indptr */
  void alloc_sparse(MXSparse* sparse, int index, int indices_len, int indptr_len = 0) const {
    sparse_malloc(sparse_alloc, index, indices_len, indptr_len,
                  &(sparse->data), &(sparse->indices), &(sparse->indptr));
  }

  /*!
   * \brief get pointer to initialized and seeded random number states located on CPU
   * Access each state by states[id] with 0 <= id < MX_NUM_CPU_RANDOM_STATES
   * (id == MX_NUM_CPU_RANDOM_STATES is one past the end).
   */
  mx_cpu_rand_t* get_cpu_rand_states() const {
    return static_cast<mx_cpu_rand_t*>(rand_cpu_states);
  }

  /*!
   * \brief get pointer to initialized and seeded random number states located on GPU
   * Access each state by states[id] with 0 <= id < MX_NUM_GPU_RANDOM_STATES.
   * Note that if you are using cpu build, it will return a nullptr
   */
  mx_gpu_rand_t* get_gpu_rand_states() const {
    return static_cast<mx_gpu_rand_t*>(rand_gpu_states);
  }

 private:
  /*! \brief cpu allocation callback */
  xpu_malloc_t cpu_malloc;
  /*! \brief gpu allocation callback */
  xpu_malloc_t gpu_malloc;
  /*! \brief opaque handle passed back to cpu_malloc */
  void *cpu_alloc;
  /*! \brief opaque handle passed back to gpu_malloc */
  void *gpu_alloc;
  /*! \brief cuda stream passed from MXNet */
  void *cuda_stream;
  /*! \brief sparse allocation callback */
  sparse_malloc_t sparse_malloc;
  /*! \brief opaque handle passed back to sparse_malloc */
  void *sparse_alloc;
  /*! \brief cpu rng fully inited and seeded states */
  void *rand_cpu_states;
  /*! \brief gpu rng fully inited and seeded states */
  void *rand_gpu_states;
};
/*! \brief Macro to help passing serialized subgraph through attribute dict */
#define MX_STR_SUBGRAPH_SYM_JSON "subgraph_sym_json"
/*! \brief Attribute keys for per-node dtype / shape strings.
* Presumably formatted like the getDtypeAt ("[1,2]") and getShapeAt ("[[1],[2,3]]")
* examples below -- TODO confirm against the code that writes them. */
#define MX_STR_DTYPE "__ext_dtype__"
#define MX_STR_SHAPE "__ext_shape__"
/* \brief get shape value from list of shapes string
 *
 * Examples:
 *
 * getShapeAt("[[1]]", 0) returns "[1]"
 * getShapeAt("[[1],[2,3]]", 1) returns "[2,3]"
 *
 * Throws std::out_of_range if `shape` holds fewer than index+1 shapes.
 * Declared inline because this header defines it and may be included by several
 * translation units.
 */
inline std::string getShapeAt(const std::string& shape, unsigned index) {
  std::string::size_type idx = 1;  // start at 1 to skip the outer square bracket [
  // find the beginning of the output shape for the particular output index
  for (unsigned x = 0; x < index; x++) {
    idx = shape.find('[', idx + 1);
    if (idx == std::string::npos)
      throw std::out_of_range("getShapeAt: index " + std::to_string(index)
                              + " out of range for shapes '" + shape + "'");
  }
  // find stop index for this output shape; if missing, substr keeps the remainder
  const std::string::size_type stop = shape.find(']', idx);
  return shape.substr(idx, stop - idx + 1);
}
/* \brief get dtype value from list of dtypes string
 *
 * Examples:
 *
 * getDtypeAt("[1]", 0) returns "1"
 * getDtypeAt("[1,2]", 1) returns "2"
 *
 * Throws std::out_of_range if `dtype` holds fewer than index+1 entries.
 * Declared inline because this header defines it and may be included by several
 * translation units.
 */
inline std::string getDtypeAt(const std::string& dtype, unsigned index) {
  // idx tracks the '[' or ',' immediately before the wanted entry
  std::string::size_type idx = 0;
  for (unsigned x = 0; x < index; x++) {
    idx = dtype.find(',', idx + 1);
    if (idx == std::string::npos)
      throw std::out_of_range("getDtypeAt: index " + std::to_string(index)
                              + " out of range for dtypes '" + dtype + "'");
  }
  // the entry ends at the next ',' or, for the last entry, at the closing ']'
  std::string::size_type stop = dtype.find(',', idx + 1);
  if (stop == std::string::npos) stop = dtype.find(']', idx + 1);
  return dtype.substr(idx + 1, stop - idx - 1);
}
/*!
 * \brief Json utility to parse serialized subgraph symbol
 */
/*! \brief Types of JSON objects */
enum JsonType {ERR, STR, NUM, LIST, MAP};
/*!
 * \brief definition of JSON objects
 *
 * JsonVal is used as a std::map key (see `map` below), so operator< must be a strict
 * weak ordering: values are ordered first by type, then by payload.
 */
struct JsonVal {
  JsonVal() : type(ERR), num(-1), str("") {}  // default constructor
  // construct a JSON object by type
  explicit JsonVal(JsonType t) : type(t), num(-1), str("") {}
  // construct a string JSON object
  explicit JsonVal(std::string s) : type(STR), num(-1), str(std::move(s)) {}
  // construct a number JSON object
  explicit JsonVal(int n) : type(NUM), num(n), str(std::to_string(n)) {}
  // complex constructor
  JsonVal(JsonType t, int n, std::string s) : type(t), num(n), str(std::move(s)) {}
  bool operator<(const JsonVal &o) const {
    // values of different types never compare equivalent: order by type first
    if (type != o.type) return type < o.type;
    switch (type) {
      // for string JSON objects compare the string
      case STR: return str < o.str;
      // for number JSON objects compare the number
      case NUM: return num < o.num;
      // lists and maps compare lexicographically, element by element
      // (recursively using this operator), so that a < a is always false
      case LIST: return list < o.list;
      case MAP: return map < o.map;
      // ERR values carry no payload; all of them are equivalent
      default: return false;
    }
  }
  JsonType type;
  int num;
  std::string str;
  std::vector<JsonVal> list;
  std::map<JsonVal, JsonVal> map;
};
/*! \brief functions used for parsing JSON */
struct JsonParser {
  /*! \brief parse a complete JSON document; malformed input yields an ERR value */
  JsonVal parse_to_json(const std::string& json) {
    unsigned int idx = 0;
    return parse(json, &idx);
  }
  void print_json_val(const JsonVal& val) {
    std::cout << json_val_string(val) << std::endl;
  }
  // debug function to dump data structure to string
  std::string json_val_string(const JsonVal &val) {
    std::string ret;
    switch (val.type) {
      case ERR:
        ret = "json(Error)";
        break;
      case STR:
        ret = "json(STR:" + val.str + ")";
        break;
      case NUM:
        ret = "json(INT:" + val.str + ")";
        break;
      case LIST:
        ret = "json(LIST:[";
        for (auto &item : val.list)
          ret += json_val_string(item) + ",";
        ret += "])";
        break;
      case MAP:
        ret = "json(MAP:{";
        for (auto &item : val.map)
          ret += json_val_string(item.first) + " : " + json_val_string(item.second) + ",";
        ret += "})";
        break;
    }
    return ret;
  }
  // parse a string JSON object; *idx points just past the opening quote
  JsonVal parse_string(const std::string& json, unsigned int* idx) {
    JsonVal ret(STR);
    while (*idx < json.size()) {
      const char c = json[*idx];
      if (c == '"') {
        ++(*idx);
        return ret;
      }
      ret.str += c;
      ++(*idx);
      // Copy an escaped character (e.g. \" or \\) through verbatim, so an escaped
      // quote does not end the string and dump() reproduces the original text.
      if (c == '\\' && *idx < json.size()) {
        ret.str += json[*idx];
        ++(*idx);
      }
    }
    std::cout << "Error! Unable to parse string" << std::endl;
    return JsonVal();
  }
  // parse a number JSON object (non-negative integers only)
  JsonVal parse_num(const std::string& json, unsigned int* idx) {
    JsonVal ret(NUM);
    while (*idx < json.size()) {
      if (json[*idx] >= '0' && json[*idx] <= '9') {
        ret.str += json[*idx];
        ++(*idx);
      } else {
        break;
      }
    }
    ret.num = std::stoi(ret.str);
    return ret;
  }
  // parse a list of JSON objects; *idx points just past the opening '['
  JsonVal parse_list(const std::string& json, unsigned int* idx) {
    JsonVal ret(LIST);
    while (*idx < json.size()) {
      if (json[*idx] == ']') {
        ++(*idx);
        return ret;
      }
      // A mismatched '}' makes parse() return without consuming anything; stop here
      // instead of looping forever.
      if (json[*idx] == '}')
        break;
      JsonVal item = parse(json, idx);
      if (item.type != ERR)
        ret.list.push_back(item);
    }
    std::cout << "Error! Unable to parse list" << std::endl;
    return JsonVal();
  }
  // parse a map of JSON objects; *idx points just past the opening '{'
  JsonVal parse_map(const std::string& json, unsigned int* idx) {
    JsonVal ret(MAP), key;
    while (*idx < json.size()) {
      if (json[*idx] == '}') {
        ++(*idx);
        return ret;
      }
      // A mismatched ']' makes parse() return without consuming anything; stop here
      // instead of looping forever.
      if (json[*idx] == ']')
        break;
      JsonVal item = parse(json, idx);
      if (key.type == ERR) {
        key = item;
      } else {
        ret.map[key] = item;
        key.type = ERR;
      }
    }
    std::cout << "Error! Unable to parse map" << std::endl;
    return JsonVal();
  }
  // generic parse function: skips separators until a value starts, or stops at a closer
  JsonVal parse(const std::string& json, unsigned int *idx) {
    JsonVal ret;
    while (*idx < json.size()) {
      if (json[*idx] == '"') {
        ++(*idx);
        ret = parse_string(json, idx);
      } else if (json[*idx] >= '0' && json[*idx] <= '9') {
        ret = parse_num(json, idx);
      } else if (json[*idx] == '[') {
        ++(*idx);
        ret = parse_list(json, idx);
      } else if (json[*idx] == '{') {
        ++(*idx);
        ret = parse_map(json, idx);
      } else if (json[*idx] == ']' || json[*idx] == '}') {
        return ret;  // leave the closer for the enclosing list/map
      }
      if (ret.type != ERR) return ret;
      ++(*idx);
    }
    return ret;
  }
  // convert JSON object back to JSON-compatible string
  std::string dump(const JsonVal &val) {
    std::string ret;
    switch (val.type) {
      case ERR:
        ret = "json(Error)";
        break;
      case STR:
        ret = "\"" + val.str + "\"";
        break;
      case NUM:
        ret = val.str;
        break;
      case LIST:
        ret = "[";
        for (unsigned i = 0; i < val.list.size(); i++) {
          auto &item = val.list[i];
          ret += dump(item);
          if (i < val.list.size() - 1)
            ret += ",";
        }
        ret += "]";
        break;
      case MAP:
        ret = "{";
        unsigned cnt = 0;
        for (auto &item : val.map) {
          ret += dump(item.first) + " : " + dump(item.second);
          if (cnt++ < val.map.size() - 1)
            ret += ",";
        }
        ret += "}";
        break;
    }
    return ret;
  }
};
/* \brief An abstract class for library authors creating custom
 * partitioners. Optional, can just implement supportedOps instead
 */
class CustomOpSelector {
 public:
  /* \brief Select a node to include in subgraph, return true to include node
   * nodeID - index of node in graph
   */
  virtual bool Select(int nodeID) = 0;
  /* \brief Select an input node from current node to include in subgraph
   * return true to include node
   * nodeID - index of node in graph
   * input_nodeID - index of input node in graph
   */
  virtual bool SelectInput(int nodeID, int input_nodeID) = 0;
  /* \brief Select an output node from current node to include in subgraph
   * return true to include node
   * nodeID - index of node in graph
   * output_nodeID - index of output node in graph
   */
  virtual bool SelectOutput(int nodeID, int output_nodeID) = 0;
  /* \brief Review nodes to include in subgraph
   * return set of candidate nodes to keep in subgraph
   * candidates - indices of nodes to include in subgraph
   * keep - indices of nodes to keep in subgraph
   * The default keeps every candidate (appends them to *keep).
   */
  virtual void Filter(const std::vector<int>& candidates,
                      std::vector<int>* keep) {
    keep->insert(keep->end(), candidates.begin(), candidates.end());
  }
  /* \brief Reset any selector state, called after growing subgraph, before filter
   * Called after finished calling SelectInput/SelectOutput and growing subgraph
   */
  virtual void Reset() {}
  /* \brief Virtual so that deleting a selector through a CustomOpSelector* also runs
   * the derived class's destructor. Declared after the existing virtual functions so
   * their relative order is unchanged.
   */
  virtual ~CustomOpSelector() {}
};
/*!
 * \brief An abstract class for library authors creating stateful op
 * custom library should override Forward and destructor, and has an
 * option to implement Backward
 */
class CustomStatefulOp {
 public:
  virtual MXReturnValue Forward(std::vector<MXTensor>* inputs,
                                std::vector<MXTensor>* outputs,
                                const OpResource& op_res) = 0;
  /*! \brief default Backward: reports that backward is unsupported and returns MX_FAIL */
  virtual MXReturnValue Backward(std::vector<MXTensor>* inputs,
                                 std::vector<MXTensor>* outputs,
                                 const OpResource& op_res) {
    std::cout << "Error! Operator does not support backward" << std::endl;
    return MX_FAIL;
  }
  /*!
   * \brief The class comment asks libraries to override the destructor, which only takes
   * effect through a CustomStatefulOp* if the destructor is virtual. Declared after the
   * existing virtual functions so their relative order is unchanged.
   */
  virtual ~CustomStatefulOp() {}
};
/*! \brief StatefulOp wrapper class to pass to backend OpState */
// Holds a raw CustomStatefulOp pointer; the wrapper itself never deletes it.
class CustomStatefulOpWrapper {
 public:
  explicit CustomStatefulOpWrapper(CustomStatefulOp* inst) : instance_(inst) {}

  /*! \brief the wrapped stateful op, exactly as passed to the constructor */
  CustomStatefulOp* get_instance() {
    return instance_;
  }

 private:
  CustomStatefulOp* instance_;
};
/*! \brief Custom Operator function templates */
typedef MXReturnValue (*fcomp_t)(const std::unordered_map<std::string,
std::string>& attributes,
std::vector<MXTensor>* inputs,
std::vector<MXTensor>* outputs,
const OpResource& res);
typedef MXReturnValue (*parseAttrs_t)(const std::unordered_map<std::string,
std::string>& attributes,
int* num_inputs, int* num_outputs);
typedef MXReturnValue (*inferType_t)(const std::unordered_map<std::string,
std::string>& attributes,
std::vector<int>* in_types,
std::vector<int>* out_types);
typedef MXReturnValue (*inferSType_t)(const std::unordered_map<std::string,
std::string>& attributes,
std::vector<int>* in_storage_types,
std::vector<int>* out_storage_types);
typedef MXReturnValue (*inferShape_t)(const std::unordered_map<std::string,
std::string>& attributes,
std::vector<std::vector<unsigned int> >* in_shapes,
std::vector<std::vector<unsigned int> >* out_shapes);
typedef MXReturnValue (*mutateInputs_t)(const std::unordered_map<std::string,
std::string>& attributes,
std::vector<int>* input_indices);
typedef MXReturnValue (*createOpState_t)(const std::unordered_map<std::string,
std::string>& attributes,
CustomStatefulOp**);
/*!
* \brief Class to hold custom operator registration
*/
class CustomOp {
public:
explicit CustomOp(const char* op_name) : name(op_name),
parse_attrs(NULL), infer_type(NULL), infer_storage_type(NULL), infer_shape(NULL),
mutate_inputs(NULL), isSGop(false) {}
CustomOp& setForward(fcomp_t fcomp, const char* ctx) {
if (forward_ctx_map.count(ctx) > 0)
raiseDuplicateContextError();
forward_ctx_map[ctx] = fcomp;
return *this;
}
CustomOp& setBackward(fcomp_t fgrad, const char* ctx) {
if (backward_ctx_map.count(ctx) > 0)
raiseDuplicateContextError();
backward_ctx_map[ctx] = fgrad;
return *this;
}
CustomOp& setParseAttrs(parseAttrs_t func) {
parse_attrs = func;
return *this;
}
CustomOp& setInferType(inferType_t func) {
infer_type = func;
return *this;
}
CustomOp& setInferSType(inferSType_t func) {
infer_storage_type = func;
return *this;
}
CustomOp& setInferShape(inferShape_t func) {
infer_shape = func;
return *this;
}
CustomOp& setMutateInputs(mutateInputs_t func) {
mutate_inputs = func;
return *this;
}
CustomOp& setCreateOpState(createOpState_t func, const char* ctx) {
if (create_op_ctx_map.count(ctx) > 0)
raiseDuplicateContextError();
create_op_ctx_map[ctx] = func;
return *this;
}
CustomOp& setIsSubgraphOp() {
isSGop = true;
return *this;
}
void mapToVector() {
for (auto kv : forward_ctx_map) {
forward_ctx_cstr.push_back(kv.first);
forward_fp.push_back(kv.second);
}
for (auto kv : backward_ctx_map) {
backward_ctx_cstr.push_back(kv.first);
backward_fp.push_back(kv.second);
}
for (auto kv : create_op_ctx_map) {
create_op_ctx_cstr.push_back(kv.first);
create_op_fp.push_back(kv.second);
}
}
~CustomOp() {}
/*! \brief operator name */
const char* name;
/*! \brief operator functions */
parseAttrs_t parse_attrs;
inferType_t infer_type;
inferSType_t infer_storage_type;
inferShape_t infer_shape;
mutateInputs_t mutate_inputs;
bool isSGop;
/*! \brief vector repr of ctx map to be easily loaded from c_api */
std::vector<const char*> forward_ctx_cstr, backward_ctx_cstr, create_op_ctx_cstr;
std::vector<fcomp_t> forward_fp, backward_fp;
std::vector<createOpState_t> create_op_fp;
private:
void raiseDuplicateContextError() {
std::string op_name_str(name);
throw std::runtime_error(
"Error! Error! Cannot register multiple functions under same context for operator '"
+ op_name_str + "'");
}
/*! \brief dedup context maps - static string ctx to custom function */
std::unordered_map<const char*, fcomp_t> forward_ctx_map, backward_ctx_map;
std::unordered_map<const char*, createOpState_t> create_op_ctx_map;
};
/*! \brief Custom Pass Create function template */
typedef MXReturnValue (*graphPass_t)(const std::string& in_graph, const std::string** out_graph,
const std::unordered_map<std::string, std::string>& options,
const std::unordered_map<std::string, MXTensor>& args,
const std::unordered_map<std::string, MXTensor>& aux,
const PassResource& res);