forked from torch/nngraph
-
Notifications
You must be signed in to change notification settings - Fork 1
/
graphinspecting.lua
139 lines (123 loc) · 4.25 KB
/
graphinspecting.lua
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
-- Walks up the call stack, starting at the caller of this function,
-- looking for a frame whose function is named "neteval".
-- When the first local variable of such a frame is called "node",
-- its value is returned.
-- This depends on the names of the local variables
-- in the nngraph.gModule source code.
-- Returns nil when the stack is exhausted without a match.
local function findCurrentNode()
  local level = 2
  while true do
    local frame = debug.getinfo(level, "n")
    if not frame then
      return nil
    end
    if frame.name == "neteval" then
      local localName, localValue = debug.getlocal(level, 1)
      if localName == "node" then
        return localValue
      end
    end
    level = level + 1
  end
end
-- Runs func() under xpcall and returns all of its results.
-- On an error, onError(failedNode, ...) is called with the extra arguments
-- given to runChecked, and then the error is raised again.
-- The stack is inspected to find the failedNode;
-- failedNode is nil when findCurrentNode() finds nothing.
local function runChecked(func, onError, ...)
  -- unpack is a global in Lua 5.1 and lives in the table library later on.
  local unpackResults = unpack or table.unpack

  -- The current node needs to be searched-for inside the error handler,
  -- i.e. before the stack is unwound.
  local failedNode
  local function errorHandler(message)
    -- Only a string message can carry a traceback. Any other error value
    -- (e.g. a table) is passed through untouched; handing it to string.find
    -- would raise inside the handler and lose the original error.
    -- The stack traceback is added only if not already present.
    if type(message) == "string"
        and not string.find(message, 'stack traceback:\n', 1, true) then
      message = debug.traceback(message, 2)
    end
    failedNode = findCurrentNode()
    return message
  end

  -- The number of values is recorded so that nil results survive unpacking.
  local function collect(...)
    return select('#', ...), {...}
  end
  local count, results = collect(xpcall(func, errorHandler))
  if results[1] then
    return unpackResults(results, 2, count)
  end

  onError(failedNode, ...)
  -- Passing the level 0 avoids adding an additional error position info
  -- to the message.
  error(results[2], 0)
end
-- Renders the graph in the dot format via graph:todot(title).
-- When failedNode is given and a node in graph.nodes has the same data,
-- a statement giving that node a red fill is appended to the output.
-- Otherwise the output of graph:todot(title) is returned unchanged.
local function customToDot(graph, title, failedNode)
  local dot = graph:todot(title)
  if not failedNode then
    return dot
  end

  local matchId
  for _, candidate in ipairs(graph.nodes) do
    if candidate.data == failedNode.data then
      matchId = candidate.id
      break
    end
  end
  if matchId == nil then
    return dot
  end

  -- The closing '}' is stripped, the highlight statement is appended,
  -- and the graph is closed again.
  local body = string.gsub(dot, '}%s*$', '')
  local highlight = string.format(
    'n%s[style=filled, fillcolor=red];\n}', matchId)
  return body .. highlight
end
-- Writes dotStr to svgPathPrefix .. '.dot' and runs the "dot" command
-- to render it into svgPathPrefix .. '.svg'.
-- A notice naming the svg file is written to stderr first.
-- When the dot file cannot be opened for writing, the reason is written
-- to stderr and nothing is rendered.
local function saveSvg(svgPathPrefix, dotStr)
  io.stderr:write(string.format("saving %s.svg\n", svgPathPrefix))
  local dotPath = svgPathPrefix .. '.dot'
  -- io.open returns nil plus a message on failure. saveSvg is called from
  -- onError while another error is being reported, so the failure is
  -- logged instead of raised; raising here would hide that other error.
  local dotFile, openError = io.open(dotPath, 'w')
  if not dotFile then
    io.stderr:write(string.format("cannot write the dot file: %s\n",
      tostring(openError)))
    return
  end
  dotFile:write(dotStr)
  dotFile:close()

  local svgPath = svgPathPrefix .. '.svg'
  -- The paths are double-quoted so that a name containing spaces
  -- reaches the command as a single argument.
  local cmd = string.format('dot -Tsvg -o "%s" "%s"', svgPath, dotPath)
  os.execute(cmd)
end
-- Error callback handed to runChecked by nngraph.setDebug.
-- Converts gmodule.fg to dot with customToDot (which marks failedNode)
-- and passes the result to saveSvg.
-- The path prefix, also used as the graph title, is gmodule.name when set;
-- otherwise it is built from the input and output counts of the gmodule.
local function onError(failedNode, gmodule)
  local inputCount = gmodule.nInputs
  if not inputCount then
    inputCount = #gmodule.innode.children
  end

  local pathPrefix = gmodule.name
  if not pathPrefix then
    pathPrefix = string.format(
      'nngraph_%sin_%sout', inputCount, #gmodule.outnode.children)
  end

  local dot = customToDot(gmodule.fg, pathPrefix, failedNode)
  saveSvg(pathPrefix, dot)
end
-- The nn.gModule functions as they are when this file is loaded,
-- keyed by their name. nngraph.setDebug wraps and restores them.
local origFuncs = {}
for _, funcName in ipairs({
  "runForwardFunction",
  "updateGradInput",
  "accGradParameters",
}) do
  origFuncs[funcName] = nn.gModule[funcName]
end
-- When debug is enabled,
-- a picture of the graph will be saved as an .svg file
-- if an exception occurs in a graph execution.
-- The file is named after gmodule.name when set (see onError).
-- The problematic node will be marked by red color.
--
-- enable: a true value replaces the nn.gModule functions listed in
--   origFuncs by wrappers that run them through runChecked;
--   false or nil puts the functions from origFuncs back.
function nngraph.setDebug(enable)
  if not enable then
    -- When debug is disabled,
    -- the origFuncs are restored on nn.gModule.
    for funcName, origFunc in pairs(origFuncs) do
      nn.gModule[funcName] = origFunc
    end
    return
  end

  for funcName, origFunc in pairs(origFuncs) do
    nn.gModule[funcName] = function(...)
      -- The argument count is recorded explicitly: the length of a table
      -- built from {...} is not reliable when some arguments are nil,
      -- so a plain unpack(args) could drop or truncate arguments.
      local nArgs = select('#', ...)
      local args = {...}
      local gmodule = args[1]
      return runChecked(function()
        return origFunc(unpack(args, 1, nArgs))
      end, onError, gmodule)
    end
  end
end
-- Names the nngraph nodes after the local variables that hold them.
-- The local variables at the given stack level are scanned, and every
-- value whose torch.typename is "nngraph.Node" gets
-- node.data.annotations.name set to the name of its variable.
-- A node that already has a name keeps it.
-- The default stack level is 1 (the function that called annotateNodes()).
function nngraph.annotateNodes(stackLevel)
  -- One is added to skip the frame of annotateNodes itself.
  local level = (stackLevel or 1) + 1
  local index = 1
  while true do
    local localName, localValue = debug.getlocal(level, index)
    if not localName then
      return
    end
    if torch.typename(localValue) == "nngraph.Node"
        and not localValue.data.annotations.name then
      localValue:annotate({name = localName})
    end
    index = index + 1
  end
end