-
Notifications
You must be signed in to change notification settings - Fork 3k
/
ndarray.cc
277 lines (252 loc) · 8.83 KB
/
ndarray.cc
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
/*!
* Copyright (c) 2017 by Contributors
* \file ndarray.cc
* \brief NDArray container infrastructure.
*/
#include <dmlc/logging.h>
#include <dgl/runtime/ndarray.h>
#include <dgl/runtime/c_runtime_api.h>
#include <dgl/runtime/device_api.h>
#include "runtime_base.h"
// Deleter for arrays used by the DLPack exporter.
// Forward-declared with C linkage here because Internal::ToDLPack (inside the
// dgl::runtime namespace below) stores it in DLManagedTensor::deleter; the
// definition follows the namespace.
extern "C" void NDArrayDLPackDeleter(DLManagedTensor* tensor);
namespace dgl {
namespace runtime {
/*!
 * \brief Sanity-check a data type descriptor before it is used for allocation.
 *
 * Requires at least one lane, a bit width that is a whole number of bytes,
 * and a bit width that is a power of two (or zero). The same byte-multiple
 * rule applies to every type code, so no per-code dispatch is needed.
 *
 * \param dtype The data type to validate; a failed CHECK aborts via dmlc logging.
 */
inline void VerifyDataType(DLDataType dtype) {
  CHECK_GE(dtype.lanes, 1);
  // Whole-byte element width, regardless of dtype.code.
  CHECK_EQ(dtype.bits % 8, 0);
  // x & (x - 1) clears the lowest set bit; zero result means at most one bit set.
  CHECK_EQ(dtype.bits & (dtype.bits - 1), 0);
}
inline size_t GetDataSize(const DLTensor& arr) {
size_t size = 1;
for (dgl_index_t i = 0; i < arr.ndim; ++i) {
size *= arr.shape[i];
}
size *= (arr.dtype.bits * arr.dtype.lanes + 7) / 8;
return size;
}
/*!
 * \brief Pick the allocation alignment for a tensor's data buffer.
 * \param arr The tensor whose dtype is read.
 * \return The element size in bytes ((bits / 8) * lanes), but never less
 *         than kAllocAlignment.
 */
inline size_t GetDataAlignment(const DLTensor& arr) {
  const size_t elem_align = (arr.dtype.bits / 8) * arr.dtype.lanes;
  return elem_align < kAllocAlignment ? kAllocAlignment : elem_align;
}
/*!
 * \brief Helper struct bundling the static routines that manipulate
 *        NDArray / NDArray::Container internals (data_, shape_, stride_,
 *        manager_ctx, deleter).
 */
struct NDArray::Internal {
  /*!
   * \brief Deleter installed on containers produced by Create().
   *
   * If the container has a manager_ctx it is treated as another
   * NDArray::Container (see NDArray::CreateView) and one reference on it is
   * dropped. Otherwise, a non-null data pointer is returned to the device API
   * of the tensor's context. The container itself is always deleted.
   *
   * \param ptr The container being destroyed.
   */
  static void DefaultDeleter(NDArray::Container* ptr) {
    void* const parent = ptr->manager_ctx;
    if (parent != nullptr) {
      // The memory belongs to another container; just release our reference.
      static_cast<NDArray::Container*>(parent)->DecRef();
    } else if (ptr->dl_tensor.data != nullptr) {
      // We own the buffer: hand it back to the device it came from.
      DeviceAPI::Get(ptr->dl_tensor.ctx)->FreeDataSpace(
          ptr->dl_tensor.ctx, ptr->dl_tensor.data);
    }
    delete ptr;
  }

  /*!
   * \brief Deleter for containers that wrap an external DLManagedTensor.
   *
   * Used for data that was not allocated inside DGL but handed over by
   * another DLPack-compatible framework (see NDArray::FromDLPack). The
   * managed tensor's own deleter is invoked when one is provided, then the
   * container is deleted.
   *
   * \param ptr The container being destroyed; its manager_ctx is the
   *            DLManagedTensor it wraps.
   */
  static void DLPackDeleter(NDArray::Container* ptr) {
    DLManagedTensor* managed = static_cast<DLManagedTensor*>(ptr->manager_ctx);
    if (managed->deleter != nullptr) {
      (*managed->deleter)(managed);
    }
    delete ptr;
  }

  /*!
   * \brief Build an NDArray whose metadata is filled in but whose data
   *        pointer is left untouched (no buffer is allocated here).
   *
   * \param shape Extents of the array; moved into the container.
   * \param dtype Element type; validated with VerifyDataType.
   * \param ctx   Device context recorded in the tensor.
   * \return An NDArray owning the new container, with DefaultDeleter installed.
   */
  static NDArray Create(std::vector<int64_t> shape,
                        DLDataType dtype,
                        DLContext ctx) {
    VerifyDataType(dtype);
    // Until the container is wrapped by `ret` nothing else owns it.
    NDArray::Container* container = new NDArray::Container();
    container->deleter = DefaultDeleter;
    NDArray ret(container);
    ret.data_ = container;
    // From here on `ret` is responsible for the container.
    DLTensor& meta = container->dl_tensor;
    meta.dtype = dtype;
    meta.ctx = ctx;
    // Shape storage lives in the container; the DLTensor points into it.
    container->shape_ = std::move(shape);
    meta.shape = dmlc::BeginPtr(container->shape_);
    meta.ndim = static_cast<int>(container->shape_.size());
    // Strides are filled in explicitly (row-major, innermost stride 1) because
    // some frameworks do not support NULL strides and would crash on them.
    container->stride_.resize(meta.ndim, 1);
    for (int dim = meta.ndim - 1; dim > 0; --dim) {
      container->stride_[dim - 1] =
          container->shape_[dim] * container->stride_[dim];
    }
    meta.strides = dmlc::BeginPtr(container->stride_);
    return ret;
  }

  /*!
   * \brief Turn an NDArray into a raw DLTensor handle for the C API.
   *
   * The array's data_ pointer is cleared before the by-value argument goes
   * out of scope, so the container is handed to the caller rather than
   * released here (presumably to be released later via DGLArrayFree).
   *
   * \param arr The array to convert (taken by value).
   * \return Pointer to the tensor, which must alias the container address.
   */
  static DLTensor* MoveAsDLTensor(NDArray arr) {
    DLTensor* tensor = const_cast<DLTensor*>(arr.operator->());
    // The C API casts handles back to Container*, so the two must coincide.
    CHECK(reinterpret_cast<DLTensor*>(arr.data_) == tensor);
    arr.data_ = nullptr;
    return tensor;
  }

  /*!
   * \brief Export a container as a newly allocated DLManagedTensor.
   *
   * The managed tensor copies the DLTensor descriptor, keeps the container
   * as its manager_ctx, and holds one extra reference on it that
   * NDArrayDLPackDeleter drops again.
   *
   * \param from The container to export; must not be null.
   * \return The heap-allocated DLManagedTensor.
   */
  static DLManagedTensor* ToDLPack(NDArray::Container* from) {
    CHECK(from != nullptr);
    DLManagedTensor* exported = new DLManagedTensor();
    exported->deleter = NDArrayDLPackDeleter;
    exported->manager_ctx = from;
    exported->dl_tensor = from->dl_tensor;
    from->IncRef();
    return exported;
  }
};
/*!
 * \brief Create a new NDArray that shares this array's memory but carries
 *        its own shape and dtype.
 *
 * The view keeps this array's container alive by taking a reference on it
 * and recording it as manager_ctx; data pointer and byte_offset are copied.
 *
 * \param shape Shape of the view.
 * \param dtype Element type of the view.
 * \return The view. A CHECK fails if this array is empty, is not laid out
 *         compactly, or if the view would cover more bytes than this array.
 */
NDArray NDArray::CreateView(std::vector<int64_t> shape,
                            DLDataType dtype) {
  CHECK(data_ != nullptr);
  // Internal::Create always fills in an explicit strides array, so testing
  // `strides == nullptr` alone would reject every array allocated through
  // it. Accept either null strides (compact by DLPack convention) or
  // explicit strides that describe a compact row-major layout.
  const DLTensor& self = data_->dl_tensor;
  bool is_compact = true;
  if (self.strides != nullptr) {
    int64_t expected = 1;
    for (int i = self.ndim - 1; i >= 0; --i) {
      // The stride of an extent-1 dimension never affects addressing.
      if (self.shape[i] == 1) continue;
      if (self.strides[i] != expected) {
        is_compact = false;
        break;
      }
      expected *= self.shape[i];
    }
  }
  CHECK(is_compact)
      << "Can only create view for compact tensor";
  NDArray ret = Internal::Create(std::move(shape), dtype, self.ctx);
  ret.data_->dl_tensor.byte_offset = self.byte_offset;
  size_t curr_size = GetDataSize(this->data_->dl_tensor);
  size_t view_size = GetDataSize(ret.data_->dl_tensor);
  CHECK_LE(view_size, curr_size)
      << "Tries to create a view that has bigger memory than current one";
  // The view borrows our buffer: hold a reference on our container so the
  // memory outlives the view (released again in Internal::DefaultDeleter).
  this->data_->IncRef();
  ret.data_->manager_ctx = this->data_;
  ret.data_->dl_tensor.data = this->data_->dl_tensor.data;
  return ret;
}
/*!
 * \brief Export this array as a DLManagedTensor.
 * \return The managed tensor produced by Internal::ToDLPack for this
 *         array's container.
 */
DLManagedTensor* NDArray::ToDLPack() const {
  NDArray::Container* const self = data_;
  return Internal::ToDLPack(self);
}
/*!
 * \brief Create an array with freshly allocated (uninitialized by this
 *        function) storage on the given device.
 * \param shape Extents of the array.
 * \param dtype Element type.
 * \param ctx   Device context to allocate on.
 * \return The new array, with its data pointer set to the buffer returned by
 *         the device API's AllocDataSpace.
 */
NDArray NDArray::Empty(std::vector<int64_t> shape,
                       DLDataType dtype,
                       DLContext ctx) {
  NDArray ret = Internal::Create(std::move(shape), dtype, ctx);
  // Metadata is complete at this point; only the buffer is missing.
  DLTensor& meta = ret.data_->dl_tensor;
  const size_t nbytes = GetDataSize(meta);
  const size_t alignment = GetDataAlignment(meta);
  meta.data = DeviceAPI::Get(meta.ctx)->AllocDataSpace(
      meta.ctx, nbytes, alignment, meta.dtype);
  return ret;
}
/*!
 * \brief Wrap an external DLManagedTensor in an NDArray without copying.
 *
 * The new container copies the tensor descriptor, remembers the managed
 * tensor as its manager_ctx, and installs Internal::DLPackDeleter so the
 * managed tensor's deleter runs when the container is destroyed.
 *
 * \param tensor The managed tensor to wrap; it is dereferenced here.
 * \return An NDArray owning the new container.
 */
NDArray NDArray::FromDLPack(DLManagedTensor* tensor) {
  NDArray::Container* wrapper = new NDArray::Container();
  wrapper->dl_tensor = tensor->dl_tensor;
  wrapper->manager_ctx = tensor;
  wrapper->deleter = Internal::DLPackDeleter;
  return NDArray(wrapper);
}
/*!
 * \brief Copy the contents of one tensor into another.
 *
 * Both tensors must describe the same number of bytes, and they must either
 * live on the same device type or one of them must be on the CPU.
 *
 * \param from   Source tensor.
 * \param to     Destination tensor.
 * \param stream Stream handle forwarded to the device API.
 */
void NDArray::CopyFromTo(DLTensor* from,
                         DLTensor* to,
                         DGLStreamHandle stream) {
  const size_t from_size = GetDataSize(*from);
  const size_t to_size = GetDataSize(*to);
  CHECK_EQ(from_size, to_size)
      << "DGLArrayCopyFromTo: The size must exactly match";
  CHECK(from->ctx.device_type == to->ctx.device_type
        || from->ctx.device_type == kDLCPU
        || to->ctx.device_type == kDLCPU)
      << "Can not copy across different ctx types directly";
  // Dispatch through the device API of whichever side is *not* the CPU;
  // if the source is on the CPU the destination's context is used.
  const bool src_off_cpu = from->ctx.device_type != kDLCPU;
  DGLContext ctx = src_off_cpu ? from->ctx : to->ctx;
  const size_t src_offset = static_cast<size_t>(from->byte_offset);
  const size_t dst_offset = static_cast<size_t>(to->byte_offset);
  DeviceAPI::Get(ctx)->CopyDataFromTo(
      from->data, src_offset,
      to->data, dst_offset,
      from_size, from->ctx, to->ctx, from->dtype, stream);
}
} // namespace runtime
} // namespace dgl
using namespace dgl::runtime;
/*!
 * \brief Deleter attached to DLManagedTensors produced by Internal::ToDLPack.
 *
 * Drops the reference that ToDLPack took on the backing container and
 * frees the DLManagedTensor itself.
 *
 * \param tensor The exported managed tensor; its manager_ctx is the container.
 */
void NDArrayDLPackDeleter(DLManagedTensor* tensor) {
  NDArray::Container* backing =
      static_cast<NDArray::Container*>(tensor->manager_ctx);
  backing->DecRef();
  delete tensor;
}
int DGLArrayAlloc(const dgl_index_t* shape,
int ndim,
int dtype_code,
int dtype_bits,
int dtype_lanes,
int device_type,
int device_id,
DGLArrayHandle* out) {
API_BEGIN();
DLDataType dtype;
dtype.code = static_cast<uint8_t>(dtype_code);
dtype.bits = static_cast<uint8_t>(dtype_bits);
dtype.lanes = static_cast<uint16_t>(dtype_lanes);
DLContext ctx;
ctx.device_type = static_cast<DLDeviceType>(device_type);
ctx.device_id = device_id;
*out = NDArray::Internal::MoveAsDLTensor(
NDArray::Empty(std::vector<int64_t>(shape, shape + ndim), dtype, ctx));
API_END();
}
/*!
 * \brief C API: release one reference on the array behind a handle.
 * \param handle Handle previously produced by this API; it is reinterpreted
 *               as the NDArray::Container it aliases.
 * \return Status code produced by the API_BEGIN/API_END wrapper macros.
 */
int DGLArrayFree(DGLArrayHandle handle) {
  API_BEGIN();
  NDArray::Container* container =
      reinterpret_cast<NDArray::Container*>(handle);
  container->DecRef();
  API_END();
}
/*!
 * \brief C API: copy the contents of one array into another.
 * \param from   Source array handle.
 * \param to     Destination array handle.
 * \param stream Stream handle forwarded to NDArray::CopyFromTo.
 * \return Status code produced by the API_BEGIN/API_END wrapper macros.
 */
int DGLArrayCopyFromTo(DGLArrayHandle from,
                       DGLArrayHandle to,
                       DGLStreamHandle stream) {
  API_BEGIN();
  // All size and device checks live in NDArray::CopyFromTo.
  NDArray::CopyFromTo(
      from,
      to,
      stream);
  API_END();
}
/*!
 * \brief C API: wrap an external DLManagedTensor as an array handle.
 * \param from The managed tensor to wrap (see NDArray::FromDLPack).
 * \param out  Receives the handle of the wrapping array.
 * \return Status code produced by the API_BEGIN/API_END wrapper macros.
 */
int DGLArrayFromDLPack(DLManagedTensor* from,
                       DGLArrayHandle* out) {
  API_BEGIN();
  // Wrap first, then detach the container from the temporary NDArray so the
  // handle carries it.
  *out = NDArray::Internal::MoveAsDLTensor(
      NDArray::FromDLPack(from));
  API_END();
}
/*!
 * \brief C API: export an array handle as a DLManagedTensor.
 * \param from Array handle; it is reinterpreted as the NDArray::Container
 *             it aliases.
 * \param out  Receives the newly created managed tensor.
 * \return Status code produced by the API_BEGIN/API_END wrapper macros.
 */
int DGLArrayToDLPack(DGLArrayHandle from,
                     DLManagedTensor** out) {
  API_BEGIN();
  NDArray::Container* container =
      reinterpret_cast<NDArray::Container*>(from);
  *out = NDArray::Internal::ToDLPack(container);
  API_END();
}
/*!
 * \brief C API: invoke the deleter stored in a DLManagedTensor.
 *
 * Does nothing when the tensor pointer is null or when the tensor carries
 * no deleter, matching the null-deleter guard in
 * NDArray::Internal::DLPackDeleter.
 *
 * \param dltensor The managed tensor to release.
 */
void DGLDLManagedTensorCallDeleter(DLManagedTensor* dltensor) {
  // DLPack allows producers to leave `deleter` null; calling through it
  // unconditionally would dereference a null function pointer.
  if (dltensor == nullptr || dltensor->deleter == nullptr) {
    return;
  }
  (*(dltensor->deleter))(dltensor);
}
/*!
 * \brief C API: fill an array from a raw byte buffer.
 *
 * The buffer is described to the device API as living on CPU device 0.
 *
 * \param handle Destination array handle.
 * \param data   Source buffer.
 * \param nbytes Size of the source buffer; must equal the array's byte size.
 * \return Status code produced by the API_BEGIN/API_END wrapper macros.
 */
int DGLArrayCopyFromBytes(DGLArrayHandle handle,
                          void* data,
                          size_t nbytes) {
  API_BEGIN();
  const size_t arr_size = GetDataSize(*handle);
  CHECK_EQ(arr_size, nbytes)
      << "DGLArrayCopyFromBytes: size mismatch";
  // Context describing the caller's buffer.
  DGLContext host_ctx;
  host_ctx.device_type = kDLCPU;
  host_ctx.device_id = 0;
  const size_t dst_offset = static_cast<size_t>(handle->byte_offset);
  DeviceAPI::Get(handle->ctx)->CopyDataFromTo(
      data, 0,
      handle->data, dst_offset,
      nbytes, host_ctx, handle->ctx, handle->dtype, nullptr);
  API_END();
}
/*!
 * \brief C API: copy an array's contents out into a raw byte buffer.
 *
 * The buffer is described to the device API as living on CPU device 0.
 *
 * \param handle Source array handle.
 * \param data   Destination buffer.
 * \param nbytes Size of the destination buffer; must equal the array's byte size.
 * \return Status code produced by the API_BEGIN/API_END wrapper macros.
 */
int DGLArrayCopyToBytes(DGLArrayHandle handle,
                        void* data,
                        size_t nbytes) {
  API_BEGIN();
  const size_t arr_size = GetDataSize(*handle);
  CHECK_EQ(arr_size, nbytes)
      << "DGLArrayCopyToBytes: size mismatch";
  // Context describing the caller's buffer.
  DGLContext host_ctx;
  host_ctx.device_type = kDLCPU;
  host_ctx.device_id = 0;
  const size_t src_offset = static_cast<size_t>(handle->byte_offset);
  DeviceAPI::Get(handle->ctx)->CopyDataFromTo(
      handle->data, src_offset,
      data, 0,
      nbytes, handle->ctx, host_ctx, handle->dtype, nullptr);
  API_END();
}