-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathinit.lua
48 lines (39 loc) · 1.77 KB
/
init.lua
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
require 'torch'
require 'nn'
-- Module table: reuse a pre-existing global `tbc` if there is one, otherwise
-- start from an empty table.
local tbc = tbc
if not tbc then
   tbc = {}
end
-- Hold on to the loaded C library object so it is not garbage collected.
tbc.C = require 'tbc.ffi'
--- Resolve per-type C kernels in `lib` and wrap them as Lua functions.
-- For each base name `n`, the symbol `TemporalConvolutionTBC_<n>_<type_name>`
-- is looked up in `lib`. Symbols that cannot be resolved are reported with
-- `print` and left out of the result.
-- @param lib          library object to index for the kernel symbols
-- @param base_names   sequence of kernel base names, e.g. {"updateOutput", ...}
-- @param type_name    type suffix of the C symbols, e.g. 'Float' or 'Cuda'
-- @param state_getter optional value passed as the first argument of every
--                     kernel call (this file passes cutorch.getState() for
--                     CUDA types)
-- @return table mapping each resolved base name to a wrapper that forwards
--         its arguments to the kernel and returns whatever the kernel returns
function tbc.bind(lib, base_names, type_name, state_getter)
   local ftable = {}
   local prefix = 'TemporalConvolutionTBC_'
   for _, n in ipairs(base_names) do
      local symbol = prefix .. n .. '_' .. type_name
      -- Indexing `lib` may raise (e.g. a LuaJIT FFI library namespace errors
      -- on a missing symbol), so the lookup is guarded with pcall.
      local ok, v = pcall(function() return lib[symbol] end)
      if ok then
         -- Tail-call the kernel so its return values reach the caller.
         if state_getter then
            ftable[n] = function(...) return v(state_getter, ...) end
         else
            ftable[n] = function(...) return v(...) end
         end
      else
         print('not found: ' .. symbol)
      end
   end
   return ftable
end
-- Base names of the C kernels; each is resolved per tensor type by tbc.bind.
local function_names = {"updateOutput", "updateGradInput", "accGradParameters"}

-- Maps a torch tensor type name to the table of bound kernels for that type.
tbc.kernels = {}

--- Bind the kernels for one tensor type and expose them as `<tensor>.TBC`.
-- @param tensor_type torch tensor type name, e.g. 'torch.FloatTensor'
-- @param type_name   suffix of the C symbols, e.g. 'Float'
-- @param state       optional value passed as first argument to each kernel
local function register_kernels(tensor_type, type_name, state)
   local kernels = tbc.bind(tbc.C, function_names, type_name, state)
   tbc.kernels[tensor_type] = kernels
   torch.getmetatable(tensor_type).TBC = kernels
end

register_kernels('torch.FloatTensor', 'Float')
register_kernels('torch.DoubleTensor', 'Double')

if tbc.C.torchtbc_has_cuda() == 1 then
   -- The C library reports CUDA support; load cutorch explicitly so that the
   -- `cutorch` global and the Cuda tensor metatables exist even if the caller
   -- has not required it yet (require is a no-op if it is already loaded).
   require('cutorch')
   local state = cutorch.getState()
   register_kernels('torch.CudaTensor', 'Cuda', state)
   register_kernels('torch.CudaDoubleTensor', 'CudaDouble', state)
   if cutorch.hasHalf then
      register_kernels('torch.CudaHalfTensor', 'CudaHalf', state)
   end
end
-- Load the TemporalConvolutionTBC sub-module for its side effects (its return
-- value is not used here), then hand the module table back to `require`.
require 'tbc.TemporalConvolutionTBC'
return tbc