forked from torch/cutorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
FFI.lua
110 lines (92 loc) · 2.69 KB
/
FFI.lua
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
-- Optional LuaJIT FFI bindings for cutorch.
-- Declares the C layout of THCState and of every THCuda<Real>Storage /
-- THCuda<Real>Tensor struct, then installs raw `cdata()` / `data()`
-- accessors on the matching torch metatables. Nothing happens when the
-- `ffi` module cannot be required.
local ok, ffi = pcall(require, 'ffi')
if ok then
   local unpack = unpack or table.unpack

   -- Declaration pieces; joined with table.concat and passed to one
   -- ffi.cdef call below.
   -- NOTE(review): these layouts are written by hand and must match the
   -- C structs in THC exactly, otherwise field reads return garbage.
   -- cublasHandle_t is declared exactly once. A second, conflicting
   -- `typedef struct CUhandle_st *cublasHandle_t;` used to follow it;
   -- LuaJIT's cdef parser ignores such redeclarations (the first one
   -- wins), and stricter FFI implementations may reject it outright.
   local cdefs = {[[
typedef struct CUstream_st *cudaStream_t;
struct cublasContext;
typedef struct cublasContext *cublasHandle_t;
typedef struct _THCCudaResourcesPerDevice {
   cudaStream_t* streams;
   cublasHandle_t* blasHandles;
   size_t scratchSpacePerStream;
   void** devScratchSpacePerStream;
} THCCudaResourcesPerDevice;
typedef struct THCState
{
   struct THCRNGState* rngState;
   struct cudaDeviceProp* deviceProperties;
   THCCudaResourcesPerDevice* resourcesPerDevice;
   int numDevices;
   int numUserStreams;
   int numUserBlasHandles;
   struct THAllocator* cudaHostAllocator;
} THCState;
cudaStream_t THCState_getCurrentStream(THCState *state);
]]}

   -- {C element type, torch type-name infix}; '' is the default float type.
   local CudaTypes = {
      {'float', ''},
      {'unsigned char', 'Byte'},
      {'char', 'Char'},
      {'short', 'Short'},
      {'int', 'Int'},
      {'long', 'Long'},
      {'double', 'Double'},
   }
   if cutorch.hasHalf then
      table.insert(CudaTypes, {'half', 'Half'})
   end

   -- Per-type struct template. 'real', 'THCStorage' and 'THCTensor' are
   -- textually substituted for every entry of CudaTypes.
   -- NOTE(review): `THAllocator` is not declared in this file; presumably
   -- torch's own FFI bindings declare it earlier. Confirm before reusing.
   local ctype_template = [[
typedef struct THCStorage
{
   real *data;
   ptrdiff_t size;
   int refcount;
   char flag;
   THAllocator *allocator;
   void *allocatorContext;
   struct THCStorage *view;
} THCStorage;
typedef struct THCTensor
{
   long *size;
   long *stride;
   int nDimension;
   THCStorage *storage;
   ptrdiff_t storageOffset;
   int refcount;
   char flag;
} THCTensor;
]]
   for _, typedata in ipairs(CudaTypes) do
      local real, Real = unpack(typedata)
      -- Each assignment keeps only gsub's first result (the new string).
      local ctype_def = ctype_template:gsub('real', real)
      ctype_def = ctype_def:gsub('THCStorage', 'THCuda' .. Real .. 'Storage')
      ctype_def = ctype_def:gsub('THCTensor', 'THCuda' .. Real .. 'Tensor')
      cdefs[#cdefs + 1] = ctype_def
   end

   -- `half` must exist before the Half storage struct (`half *data`) is parsed,
   -- so it is declared in its own cdef call first.
   if cutorch.hasHalf then
      ffi.cdef([[
typedef struct {
   unsigned short x;
} __half;
typedef __half half;
]])
   end
   ffi.cdef(table.concat(cdefs))

   for _, typedata in ipairs(CudaTypes) do
      local real, Real = unpack(typedata)

      -- Storage: the T** cast reads the struct pointer held in the torch
      -- userdata payload (LuaJIT converts a userdata to a pointer to its
      -- payload). rawset bypasses any __newindex on the metatable.
      local Storage = torch.getmetatable('torch.Cuda' .. Real .. 'Storage')
      local Storage_tt = ffi.typeof('THCuda' .. Real .. 'Storage**')
      -- cdata(): the THCuda<Real>Storage* struct pointer.
      rawset(Storage, "cdata", function(self) return Storage_tt(self)[0] end)
      -- data(): the storage's raw element pointer.
      rawset(Storage, "data", function(self) return Storage_tt(self)[0].data end)

      -- Tensor
      local Tensor = torch.getmetatable('torch.Cuda' .. Real .. 'Tensor')
      local Tensor_tt = ffi.typeof('THCuda' .. Real .. 'Tensor**')
      -- cdata(): the THCuda<Real>Tensor* struct pointer.
      rawset(Tensor, "cdata", function(self) return Tensor_tt(self)[0] end)
      -- data(): pointer to the tensor's first element (storage data plus
      -- storageOffset), or nil when the tensor has no storage.
      rawset(Tensor, "data",
         function(self)
            self = Tensor_tt(self)[0]
            return self.storage ~= nil and self.storage.data + self.storageOffset or nil
         end
      )
   end
end