Skip to content

Commit a8d3149

Browse files
committed
Fix display
1 parent cc7dec6 commit a8d3149

File tree

1 file changed

+53
-28
lines changed

1 file changed

+53
-28
lines changed

lib/axon/display.ex

+53-28
Original file line numberDiff line numberDiff line change
@@ -56,9 +56,24 @@ defmodule Axon.Display do
5656
vertical_symbol: "|"
5757
)
5858
|> then(&(&1 <> "Total Parameters: #{model_info.num_params}\n"))
59-
|> then(&(&1 <> "Total Parameters Memory: #{model_info.total_param_byte_size} bytes\n"))
59+
|> then(
60+
&(&1 <> "Total Parameters Memory: #{readable_size(model_info.total_param_byte_size)}\n")
61+
)
6062
end
6163

64+
defp readable_size(n) when n < 1_000, do: "#{n} bytes"
65+
66+
defp readable_size(n) when n >= 1_000 and n < 1_000_000,
67+
do: "#{float_format(n / 1_000)} kilobytes"
68+
69+
defp readable_size(n) when n >= 1_000_000 and n < 1_000_000_000,
70+
do: "#{float_format(n / 1_000_000)} megabytes"
71+
72+
defp readable_size(n) when n >= 1_000_000_000 and n < 1_000_000_000_000,
73+
do: "#{float_format(n / 1_000_000_000)} gigabytes"
74+
75+
defp float_format(value), do: :io_lib.format("~.2f", [value])
76+
6277
defp assert_table_rex!(fn_name) do
6378
unless Code.ensure_loaded?(TableRex) do
6479
raise RuntimeError, """
@@ -93,7 +108,6 @@ defmodule Axon.Display do
93108
defp do_axon_to_rows(
94109
%Axon.Node{
95110
id: id,
96-
op: structure,
97111
op_name: :container,
98112
parent: [parents],
99113
name: name_fn
@@ -104,7 +118,7 @@ defmodule Axon.Display do
104118
op_counts,
105119
model_info
106120
) do
107-
{input_names, {cache, op_counts, model_info}} =
121+
{_, {cache, op_counts, model_info}} =
108122
Enum.map_reduce(parents, {cache, op_counts, model_info}, fn
109123
parent_id, {cache, op_counts, model_info} ->
110124
{_, name, _shape, cache, op_counts, model_info} =
@@ -119,11 +133,11 @@ defmodule Axon.Display do
119133
shape = Axon.get_output_shape(%Axon{output: id, nodes: nodes}, templates)
120134

121135
row = [
122-
"#{name} ( #{op_string} #{inspect(apply(structure, input_names))} )",
136+
"#{name} ( #{op_string} )",
123137
"#{inspect({})}",
124-
"#{inspect(shape)}",
138+
render_output_shape(shape),
125139
render_options([]),
126-
render_parameters(%{}, [])
140+
render_parameters(nil, %{}, [])
127141
]
128142

129143
{row, name, shape, cache, op_counts, model_info}
@@ -136,7 +150,7 @@ defmodule Axon.Display do
136150
parameters: params,
137151
name: name_fn,
138152
opts: opts,
139-
policy: %{params: {_, bitsize}},
153+
policy: %{params: params_policy},
140154
op_name: op_name
141155
},
142156
nodes,
@@ -145,6 +159,12 @@ defmodule Axon.Display do
145159
op_counts,
146160
model_info
147161
) do
162+
bitsize =
163+
case params_policy do
164+
nil -> 32
165+
{_, bitsize} -> bitsize
166+
end
167+
148168
{input_names_and_shapes, {cache, op_counts, model_info}} =
149169
Enum.map_reduce(parents, {cache, op_counts, model_info}, fn
150170
parent_id, {cache, op_counts, model_info} ->
@@ -154,39 +174,34 @@ defmodule Axon.Display do
154174
{{name, shape}, {cache, op_counts, model_info}}
155175
end)
156176

157-
{input_names, input_shapes} = Enum.unzip(input_names_and_shapes)
177+
{_, input_shapes} = Enum.unzip(input_names_and_shapes)
178+
179+
inputs =
180+
Map.new(input_names_and_shapes, fn {name, shape} ->
181+
{name, render_output_shape(shape)}
182+
end)
158183

159184
num_params =
160185
Enum.reduce(params, 0, fn
161186
%Parameter{shape: {:tuple, shapes}}, acc ->
162187
Enum.reduce(shapes, acc, &(Nx.size(apply(&1, input_shapes)) + &2))
163188

164-
%Parameter{shape: shape_fn}, acc ->
189+
%Parameter{template: shape_fn}, acc when is_function(shape_fn) ->
165190
acc + Nx.size(apply(shape_fn, input_shapes))
166191
end)
167192

168193
param_byte_size = num_params * div(bitsize, 8)
169194

170195
op_inspect = Atom.to_string(op_name)
171-
172-
inputs =
173-
case input_names do
174-
[] ->
175-
""
176-
177-
[_ | _] = input_names ->
178-
"#{inspect(input_names)}"
179-
end
180-
181196
name = name_fn.(op_name, op_counts)
182197
shape = Axon.get_output_shape(%Axon{output: id, nodes: nodes}, templates)
183198

184199
row = [
185-
"#{name} ( #{op_inspect}#{inputs} )",
186-
"#{inspect(input_shapes)}",
187-
"#{inspect(shape)}",
200+
"#{name} ( #{op_inspect} )",
201+
"#{inspect(inputs)}",
202+
render_output_shape(shape),
188203
render_options(opts),
189-
render_parameters(params, input_shapes)
204+
render_parameters(params_policy, params, input_shapes)
190205
]
191206

192207
model_info =
@@ -200,6 +215,14 @@ defmodule Axon.Display do
200215
{row, name, shape, cache, op_counts, model_info}
201216
end
202217

218+
defp render_output_shape(%Nx.Tensor{} = template) do
219+
type = type_str(Nx.type(template))
220+
shape = shape_string(Nx.shape(template))
221+
"#{type}#{shape}"
222+
end
223+
224+
defp type_str({type, size}), do: "#{Atom.to_string(type)}#{size}"
225+
203226
defp render_options(opts) do
204227
opts
205228
|> Enum.map(fn {key, val} ->
@@ -209,21 +232,23 @@ defmodule Axon.Display do
209232
|> Enum.join("\n")
210233
end
211234

212-
defp render_parameters(params, input_shapes) do
235+
defp render_parameters(policy, params, input_shapes) do
236+
type = policy || {:f, 32}
237+
213238
params
214239
|> Enum.map(fn
215240
%Parameter{name: name, shape: {:tuple, shape_fns}} ->
216241
shapes =
217242
shape_fns
218243
|> Enum.map(&apply(&1, input_shapes))
219-
|> Enum.map(fn shape -> "f32#{shape_string(shape)}" end)
244+
|> Enum.map(fn shape -> "#{type_str(type)}#{shape_string(shape)}" end)
220245
|> List.to_tuple()
221246

222247
"#{name}: tuple#{inspect(shapes)}"
223248

224-
%Parameter{name: name, shape: shape_fn} ->
225-
shape = apply(shape_fn, input_shapes)
226-
"#{name}: f32#{shape_string(shape)}"
249+
%Parameter{name: name, template: shape_fn} when is_function(shape_fn) ->
250+
shape = Nx.shape(apply(shape_fn, input_shapes))
251+
"#{name}: #{type_str(type)}#{shape_string(shape)}"
227252
end)
228253
|> Enum.join("\n")
229254
end

0 commit comments

Comments
 (0)