@@ -56,9 +56,24 @@ defmodule Axon.Display do
56
56
vertical_symbol: "|"
57
57
)
58
58
|> then ( & ( & 1 <> "Total Parameters: #{ model_info . num_params } \n " ) )
59
- |> then ( & ( & 1 <> "Total Parameters Memory: #{ model_info . total_param_byte_size } bytes\n " ) )
59
+ |> then (
60
+ & ( & 1 <> "Total Parameters Memory: #{ readable_size ( model_info . total_param_byte_size ) } \n " )
61
+ )
60
62
end
61
63
64
+ defp readable_size ( n ) when n < 1_000 , do: "#{ n } bytes"
65
+
66
+ defp readable_size ( n ) when n >= 1_000 and n < 1_000_000 ,
67
+ do: "#{ float_format ( n / 1_000 ) } kilobytes"
68
+
69
+ defp readable_size ( n ) when n >= 1_000_000 and n < 1_000_000_000 ,
70
+ do: "#{ float_format ( n / 1_000_000 ) } megabytes"
71
+
72
+ defp readable_size ( n ) when n >= 1_000_000_000 and n < 1_000_000_000_000 ,
73
+ do: "#{ float_format ( n / 1_000_000_000 ) } gigabytes"
74
+
75
+ defp float_format ( value ) , do: :io_lib . format ( "~.2f" , [ value ] )
76
+
62
77
defp assert_table_rex! ( fn_name ) do
63
78
unless Code . ensure_loaded? ( TableRex ) do
64
79
raise RuntimeError , """
@@ -93,7 +108,6 @@ defmodule Axon.Display do
93
108
defp do_axon_to_rows (
94
109
% Axon.Node {
95
110
id: id ,
96
- op: structure ,
97
111
op_name: :container ,
98
112
parent: [ parents ] ,
99
113
name: name_fn
@@ -104,7 +118,7 @@ defmodule Axon.Display do
104
118
op_counts ,
105
119
model_info
106
120
) do
107
- { input_names , { cache , op_counts , model_info } } =
121
+ { _ , { cache , op_counts , model_info } } =
108
122
Enum . map_reduce ( parents , { cache , op_counts , model_info } , fn
109
123
parent_id , { cache , op_counts , model_info } ->
110
124
{ _ , name , _shape , cache , op_counts , model_info } =
@@ -119,11 +133,11 @@ defmodule Axon.Display do
119
133
shape = Axon . get_output_shape ( % Axon { output: id , nodes: nodes } , templates )
120
134
121
135
row = [
122
- "#{ name } ( #{ op_string } #{ inspect ( apply ( structure , input_names ) ) } )" ,
136
+ "#{ name } ( #{ op_string } )" ,
123
137
"#{ inspect ( { } ) } " ,
124
- " #{ inspect ( shape ) } " ,
138
+ render_output_shape ( shape ) ,
125
139
render_options ( [ ] ) ,
126
- render_parameters ( % { } , [ ] )
140
+ render_parameters ( nil , % { } , [ ] )
127
141
]
128
142
129
143
{ row , name , shape , cache , op_counts , model_info }
@@ -136,7 +150,7 @@ defmodule Axon.Display do
136
150
parameters: params ,
137
151
name: name_fn ,
138
152
opts: opts ,
139
- policy: % { params: { _ , bitsize } } ,
153
+ policy: % { params: params_policy } ,
140
154
op_name: op_name
141
155
} ,
142
156
nodes ,
@@ -145,6 +159,12 @@ defmodule Axon.Display do
145
159
op_counts ,
146
160
model_info
147
161
) do
162
+ bitsize =
163
+ case params_policy do
164
+ nil -> 32
165
+ { _ , bitsize } -> bitsize
166
+ end
167
+
148
168
{ input_names_and_shapes , { cache , op_counts , model_info } } =
149
169
Enum . map_reduce ( parents , { cache , op_counts , model_info } , fn
150
170
parent_id , { cache , op_counts , model_info } ->
@@ -154,39 +174,34 @@ defmodule Axon.Display do
154
174
{ { name , shape } , { cache , op_counts , model_info } }
155
175
end )
156
176
157
- { input_names , input_shapes } = Enum . unzip ( input_names_and_shapes )
177
+ { _ , input_shapes } = Enum . unzip ( input_names_and_shapes )
178
+
179
+ inputs =
180
+ Map . new ( input_names_and_shapes , fn { name , shape } ->
181
+ { name , render_output_shape ( shape ) }
182
+ end )
158
183
159
184
num_params =
160
185
Enum . reduce ( params , 0 , fn
161
186
% Parameter { shape: { :tuple , shapes } } , acc ->
162
187
Enum . reduce ( shapes , acc , & ( Nx . size ( apply ( & 1 , input_shapes ) ) + & 2 ) )
163
188
164
- % Parameter { shape : shape_fn } , acc ->
189
+ % Parameter { template : shape_fn } , acc when is_function ( shape_fn ) ->
165
190
acc + Nx . size ( apply ( shape_fn , input_shapes ) )
166
191
end )
167
192
168
193
param_byte_size = num_params * div ( bitsize , 8 )
169
194
170
195
op_inspect = Atom . to_string ( op_name )
171
-
172
- inputs =
173
- case input_names do
174
- [ ] ->
175
- ""
176
-
177
- [ _ | _ ] = input_names ->
178
- "#{ inspect ( input_names ) } "
179
- end
180
-
181
196
name = name_fn . ( op_name , op_counts )
182
197
shape = Axon . get_output_shape ( % Axon { output: id , nodes: nodes } , templates )
183
198
184
199
row = [
185
- "#{ name } ( #{ op_inspect } #{ inputs } )" ,
186
- "#{ inspect ( input_shapes ) } " ,
187
- " #{ inspect ( shape ) } " ,
200
+ "#{ name } ( #{ op_inspect } )" ,
201
+ "#{ inspect ( inputs ) } " ,
202
+ render_output_shape ( shape ) ,
188
203
render_options ( opts ) ,
189
- render_parameters ( params , input_shapes )
204
+ render_parameters ( params_policy , params , input_shapes )
190
205
]
191
206
192
207
model_info =
@@ -200,6 +215,14 @@ defmodule Axon.Display do
200
215
{ row , name , shape , cache , op_counts , model_info }
201
216
end
202
217
218
+ defp render_output_shape ( % Nx.Tensor { } = template ) do
219
+ type = type_str ( Nx . type ( template ) )
220
+ shape = shape_string ( Nx . shape ( template ) )
221
+ "#{ type } #{ shape } "
222
+ end
223
+
224
+ defp type_str ( { type , size } ) , do: "#{ Atom . to_string ( type ) } #{ size } "
225
+
203
226
defp render_options ( opts ) do
204
227
opts
205
228
|> Enum . map ( fn { key , val } ->
@@ -209,21 +232,23 @@ defmodule Axon.Display do
209
232
|> Enum . join ( "\n " )
210
233
end
211
234
212
- defp render_parameters ( params , input_shapes ) do
235
+ defp render_parameters ( policy , params , input_shapes ) do
236
+ type = policy || { :f , 32 }
237
+
213
238
params
214
239
|> Enum . map ( fn
215
240
% Parameter { name: name , shape: { :tuple , shape_fns } } ->
216
241
shapes =
217
242
shape_fns
218
243
|> Enum . map ( & apply ( & 1 , input_shapes ) )
219
- |> Enum . map ( fn shape -> "f32 #{ shape_string ( shape ) } " end )
244
+ |> Enum . map ( fn shape -> "#{ type_str ( type ) } #{ shape_string ( shape ) } " end )
220
245
|> List . to_tuple ( )
221
246
222
247
"#{ name } : tuple#{ inspect ( shapes ) } "
223
248
224
- % Parameter { name: name , shape : shape_fn } ->
225
- shape = apply ( shape_fn , input_shapes )
226
- "#{ name } : f32 #{ shape_string ( shape ) } "
249
+ % Parameter { name: name , template : shape_fn } when is_function ( shape_fn ) ->
250
+ shape = Nx . shape ( apply ( shape_fn , input_shapes ) )
251
+ "#{ name } : #{ type_str ( type ) } #{ shape_string ( shape ) } "
227
252
end )
228
253
|> Enum . join ( "\n " )
229
254
end
0 commit comments