@@ -687,98 +687,6 @@ defmodule Axon.Compiler do
687
687
{ id , model_funs , cache , op_counts , block_cache , model_state_meta }
688
688
end
689
689
690
- defp recur_model_funs (
691
- % Axon.Node { id: id , op: :namespace , name: name_fn , parent: [ parent ] } ,
692
- nodes ,
693
- { cache , op_counts , block_cache , model_state_meta } ,
694
- config
695
- ) do
696
- name = name_fn . ( :namespace , op_counts )
697
- # To ensure that a namespace always has the same layer names,
698
- # we reset op_counts, input layers always belong to the global
699
- # namespace, so we include those regardless
700
- input_count = op_counts [ :input ] || 0
701
- namespace_op_counts = % { input: input_count }
702
- namespace_model_state_meta = % { parameters: % { } , state: % { } , frozen_parameters: % { } }
703
-
704
- # All of the children of this namespace belong to it, so
705
- # we forward this name to the namespace, but everything after
706
- # it belongs to whatever namespace we're currently in
707
- { parent_id , { cache , namespace_op_counts , block_cache , namespace_model_state_meta } } =
708
- to_model_funs (
709
- parent ,
710
- nodes ,
711
- { cache , namespace_op_counts , block_cache , namespace_model_state_meta } ,
712
- config
713
- )
714
-
715
- # Update the global op_count of input layers, since they
716
- # are a global operation regardless of where they are
717
- input_count = namespace_op_counts [ :input ] || 0
718
- op_counts = Map . put ( op_counts , :input , input_count )
719
-
720
- # Update the model state meta to include the namespace model state meta
721
- model_state_meta =
722
- model_state_meta
723
- |> Map . update! ( :parameters , & Map . put ( & 1 , name , namespace_model_state_meta [ :parameters ] ) )
724
- |> Map . update! ( :state , & Map . put ( & 1 , name , namespace_model_state_meta [ :state ] ) )
725
- |> Map . update! (
726
- :frozen_parameters ,
727
- & Map . put ( & 1 , name , namespace_model_state_meta [ :frozen_parameters ] )
728
- )
729
-
730
- # The function just returns the result of it's child,
731
- # or parent depending on how you view the tree
732
- predict_fun = fn params , inputs , state , cache , result_cache , fn_stacktrace ->
733
- # We're only concerned with this namespaces parameters, so we pair
734
- # down parameters first given the namespace
735
- namespace_params = params [ name ]
736
-
737
- # TODO: How should hooks be handled here?
738
- # TODO: I think we can actually handle parameter freezing and access
739
- # better here by only forwarding params[namespace] to the child function
740
- { out , { state , result_cache } } =
741
- call_predict_cache (
742
- parent_id ,
743
- namespace_params ,
744
- inputs ,
745
- state ,
746
- cache ,
747
- result_cache ,
748
- fn_stacktrace
749
- )
750
-
751
- state =
752
- if map_size ( state ) == 0 do
753
- state
754
- else
755
- % { name => state }
756
- end
757
-
758
- { out , { state , result_cache } }
759
- end
760
-
761
- init_fun = fn template , cache , result_cache , fn_stacktrace , keys ->
762
- { _parent_template , { namespace_params , result_cache } } =
763
- call_init_cache ( parent_id , template , % { } , cache , result_cache , fn_stacktrace , keys )
764
-
765
- params =
766
- if namespace_params == % { } do
767
- % { }
768
- else
769
- % { name => namespace_params }
770
- end
771
-
772
- { pred_expr , { _ , result_cache } } =
773
- predict_fun . ( params , template , % { } , cache , result_cache , fn_stacktrace )
774
-
775
- { Nx . to_template ( pred_expr ) , { params , result_cache } }
776
- end
777
-
778
- model_funs = % { predict: predict_fun , init: init_fun }
779
- { id , model_funs , cache , op_counts , block_cache , model_state_meta }
780
- end
781
-
782
690
defp recur_model_funs (
783
691
% Axon.Node {
784
692
id: id ,
0 commit comments