@@ -56,9 +56,24 @@ defmodule Axon.Display do
5656 vertical_symbol: "|"
5757 )
5858 |> then ( & ( & 1 <> "Total Parameters: #{ model_info . num_params } \n " ) )
59- |> then ( & ( & 1 <> "Total Parameters Memory: #{ model_info . total_param_byte_size } bytes\n " ) )
59+ |> then (
60+ & ( & 1 <> "Total Parameters Memory: #{ readable_size ( model_info . total_param_byte_size ) } \n " )
61+ )
6062 end
6163
64+ defp readable_size ( n ) when n < 1_000 , do: "#{ n } bytes"
65+
66+ defp readable_size ( n ) when n >= 1_000 and n < 1_000_000 ,
67+ do: "#{ float_format ( n / 1_000 ) } kilobytes"
68+
69+ defp readable_size ( n ) when n >= 1_000_000 and n < 1_000_000_000 ,
70+ do: "#{ float_format ( n / 1_000_000 ) } megabytes"
71+
72+ defp readable_size ( n ) when n >= 1_000_000_000 and n < 1_000_000_000_000 ,
73+ do: "#{ float_format ( n / 1_000_000_000 ) } gigabytes"
74+
75+ defp float_format ( value ) , do: :io_lib . format ( "~.2f" , [ value ] )
76+
6277 defp assert_table_rex! ( fn_name ) do
6378 unless Code . ensure_loaded? ( TableRex ) do
6479 raise RuntimeError , """
@@ -93,7 +108,6 @@ defmodule Axon.Display do
93108 defp do_axon_to_rows (
94109 % Axon.Node {
95110 id: id ,
96- op: structure ,
97111 op_name: :container ,
98112 parent: [ parents ] ,
99113 name: name_fn
@@ -104,7 +118,7 @@ defmodule Axon.Display do
104118 op_counts ,
105119 model_info
106120 ) do
107- { input_names , { cache , op_counts , model_info } } =
121+ { _ , { cache , op_counts , model_info } } =
108122 Enum . map_reduce ( parents , { cache , op_counts , model_info } , fn
109123 parent_id , { cache , op_counts , model_info } ->
110124 { _ , name , _shape , cache , op_counts , model_info } =
@@ -119,11 +133,11 @@ defmodule Axon.Display do
119133 shape = Axon . get_output_shape ( % Axon { output: id , nodes: nodes } , templates )
120134
121135 row = [
122- "#{ name } ( #{ op_string } #{ inspect ( apply ( structure , input_names ) ) } )" ,
136+ "#{ name } ( #{ op_string } )" ,
123137 "#{ inspect ( { } ) } " ,
124- " #{ inspect ( shape ) } " ,
138+ render_output_shape ( shape ) ,
125139 render_options ( [ ] ) ,
126- render_parameters ( % { } , [ ] )
140+ render_parameters ( nil , % { } , [ ] )
127141 ]
128142
129143 { row , name , shape , cache , op_counts , model_info }
@@ -136,7 +150,7 @@ defmodule Axon.Display do
136150 parameters: params ,
137151 name: name_fn ,
138152 opts: opts ,
139- policy: % { params: { _ , bitsize } } ,
153+ policy: % { params: params_policy } ,
140154 op_name: op_name
141155 } ,
142156 nodes ,
@@ -145,6 +159,12 @@ defmodule Axon.Display do
145159 op_counts ,
146160 model_info
147161 ) do
162+ bitsize =
163+ case params_policy do
164+ nil -> 32
165+ { _ , bitsize } -> bitsize
166+ end
167+
148168 { input_names_and_shapes , { cache , op_counts , model_info } } =
149169 Enum . map_reduce ( parents , { cache , op_counts , model_info } , fn
150170 parent_id , { cache , op_counts , model_info } ->
@@ -154,39 +174,34 @@ defmodule Axon.Display do
154174 { { name , shape } , { cache , op_counts , model_info } }
155175 end )
156176
157- { input_names , input_shapes } = Enum . unzip ( input_names_and_shapes )
177+ { _ , input_shapes } = Enum . unzip ( input_names_and_shapes )
178+
179+ inputs =
180+ Map . new ( input_names_and_shapes , fn { name , shape } ->
181+ { name , render_output_shape ( shape ) }
182+ end )
158183
159184 num_params =
160185 Enum . reduce ( params , 0 , fn
161186 % Parameter { shape: { :tuple , shapes } } , acc ->
162187 Enum . reduce ( shapes , acc , & ( Nx . size ( apply ( & 1 , input_shapes ) ) + & 2 ) )
163188
164- % Parameter { shape : shape_fn } , acc ->
189+ % Parameter { template : shape_fn } , acc when is_function ( shape_fn ) ->
165190 acc + Nx . size ( apply ( shape_fn , input_shapes ) )
166191 end )
167192
168193 param_byte_size = num_params * div ( bitsize , 8 )
169194
170195 op_inspect = Atom . to_string ( op_name )
171-
172- inputs =
173- case input_names do
174- [ ] ->
175- ""
176-
177- [ _ | _ ] = input_names ->
178- "#{ inspect ( input_names ) } "
179- end
180-
181196 name = name_fn . ( op_name , op_counts )
182197 shape = Axon . get_output_shape ( % Axon { output: id , nodes: nodes } , templates )
183198
184199 row = [
185- "#{ name } ( #{ op_inspect } #{ inputs } )" ,
186- "#{ inspect ( input_shapes ) } " ,
187- " #{ inspect ( shape ) } " ,
200+ "#{ name } ( #{ op_inspect } )" ,
201+ "#{ inspect ( inputs ) } " ,
202+ render_output_shape ( shape ) ,
188203 render_options ( opts ) ,
189- render_parameters ( params , input_shapes )
204+ render_parameters ( params_policy , params , input_shapes )
190205 ]
191206
192207 model_info =
@@ -200,6 +215,14 @@ defmodule Axon.Display do
200215 { row , name , shape , cache , op_counts , model_info }
201216 end
202217
218+ defp render_output_shape ( % Nx.Tensor { } = template ) do
219+ type = type_str ( Nx . type ( template ) )
220+ shape = shape_string ( Nx . shape ( template ) )
221+ "#{ type } #{ shape } "
222+ end
223+
224+ defp type_str ( { type , size } ) , do: "#{ Atom . to_string ( type ) } #{ size } "
225+
203226 defp render_options ( opts ) do
204227 opts
205228 |> Enum . map ( fn { key , val } ->
@@ -209,21 +232,23 @@ defmodule Axon.Display do
209232 |> Enum . join ( "\n " )
210233 end
211234
212- defp render_parameters ( params , input_shapes ) do
235+ defp render_parameters ( policy , params , input_shapes ) do
236+ type = policy || { :f , 32 }
237+
213238 params
214239 |> Enum . map ( fn
215240 % Parameter { name: name , shape: { :tuple , shape_fns } } ->
216241 shapes =
217242 shape_fns
218243 |> Enum . map ( & apply ( & 1 , input_shapes ) )
219- |> Enum . map ( fn shape -> "f32 #{ shape_string ( shape ) } " end )
244+ |> Enum . map ( fn shape -> "#{ type_str ( type ) } #{ shape_string ( shape ) } " end )
220245 |> List . to_tuple ( )
221246
222247 "#{ name } : tuple#{ inspect ( shapes ) } "
223248
224- % Parameter { name: name , shape : shape_fn } ->
225- shape = apply ( shape_fn , input_shapes )
226- "#{ name } : f32 #{ shape_string ( shape ) } "
249+ % Parameter { name: name , template : shape_fn } when is_function ( shape_fn ) ->
250+ shape = Nx . shape ( apply ( shape_fn , input_shapes ) )
251+ "#{ name } : #{ type_str ( type ) } #{ shape_string ( shape ) } "
227252 end )
228253 |> Enum . join ( "\n " )
229254 end
0 commit comments