@@ -687,98 +687,6 @@ defmodule Axon.Compiler do
687687 { id , model_funs , cache , op_counts , block_cache , model_state_meta }
688688 end
689689
690- defp recur_model_funs (
691- % Axon.Node { id: id , op: :namespace , name: name_fn , parent: [ parent ] } ,
692- nodes ,
693- { cache , op_counts , block_cache , model_state_meta } ,
694- config
695- ) do
696- name = name_fn . ( :namespace , op_counts )
697- # To ensure that a namespace always has the same layer names,
698- # we reset op_counts, input layers always belong to the global
699- # namespace, so we include those regardless
700- input_count = op_counts [ :input ] || 0
701- namespace_op_counts = % { input: input_count }
702- namespace_model_state_meta = % { parameters: % { } , state: % { } , frozen_parameters: % { } }
703-
704- # All of the children of this namespace belong to it, so
705- # we forward this name to the namespace, but everything after
706- # it belongs to whatever namespace we're currently in
707- { parent_id , { cache , namespace_op_counts , block_cache , namespace_model_state_meta } } =
708- to_model_funs (
709- parent ,
710- nodes ,
711- { cache , namespace_op_counts , block_cache , namespace_model_state_meta } ,
712- config
713- )
714-
715- # Update the global op_count of input layers, since they
716- # are a global operation regardless of where they are
717- input_count = namespace_op_counts [ :input ] || 0
718- op_counts = Map . put ( op_counts , :input , input_count )
719-
720- # Update the model state meta to include the namespace model state meta
721- model_state_meta =
722- model_state_meta
723- |> Map . update! ( :parameters , & Map . put ( & 1 , name , namespace_model_state_meta [ :parameters ] ) )
724- |> Map . update! ( :state , & Map . put ( & 1 , name , namespace_model_state_meta [ :state ] ) )
725- |> Map . update! (
726- :frozen_parameters ,
727- & Map . put ( & 1 , name , namespace_model_state_meta [ :frozen_parameters ] )
728- )
729-
730- # The function just returns the result of it's child,
731- # or parent depending on how you view the tree
732- predict_fun = fn params , inputs , state , cache , result_cache , fn_stacktrace ->
733- # We're only concerned with this namespaces parameters, so we pair
734- # down parameters first given the namespace
735- namespace_params = params [ name ]
736-
737- # TODO: How should hooks be handled here?
738- # TODO: I think we can actually handle parameter freezing and access
739- # better here by only forwarding params[namespace] to the child function
740- { out , { state , result_cache } } =
741- call_predict_cache (
742- parent_id ,
743- namespace_params ,
744- inputs ,
745- state ,
746- cache ,
747- result_cache ,
748- fn_stacktrace
749- )
750-
751- state =
752- if map_size ( state ) == 0 do
753- state
754- else
755- % { name => state }
756- end
757-
758- { out , { state , result_cache } }
759- end
760-
761- init_fun = fn template , cache , result_cache , fn_stacktrace , keys ->
762- { _parent_template , { namespace_params , result_cache } } =
763- call_init_cache ( parent_id , template , % { } , cache , result_cache , fn_stacktrace , keys )
764-
765- params =
766- if namespace_params == % { } do
767- % { }
768- else
769- % { name => namespace_params }
770- end
771-
772- { pred_expr , { _ , result_cache } } =
773- predict_fun . ( params , template , % { } , cache , result_cache , fn_stacktrace )
774-
775- { Nx . to_template ( pred_expr ) , { params , result_cache } }
776- end
777-
778- model_funs = % { predict: predict_fun , init: init_fun }
779- { id , model_funs , cache , op_counts , block_cache , model_state_meta }
780- end
781-
782690 defp recur_model_funs (
783691 % Axon.Node {
784692 id: id ,
0 commit comments