Skip to content

Commit e5c8a37

Browse files
authored
Remove dead code related to namespaces (#602)
They were removed in this commit 44f36b5
1 parent d1dc13b commit e5c8a37

File tree

2 files changed

+2
-94
lines changed

2 files changed

+2
-94
lines changed

lib/axon.ex

+2-2
Original file line numberDiff line numberDiff line change
@@ -3875,7 +3875,7 @@ defmodule Axon do
38753875
## `init_fn`
38763876
38773877
The `init_fn` receives two arguments, the input template and
3878-
an optional map with initial parameters for layers or namespaces:
3878+
an optional map with initial parameters for layers:
38793879
38803880
{init_fn, predict_fn} = Axon.build(model)
38813881
init_fn.(Nx.template({1, 1}, {:f, 32}), %{"dense_0" => dense_params})
@@ -3968,7 +3968,7 @@ defmodule Axon do
39683968
purposes.
39693969
39703970
You may optionally specify initial parameters for some layers or
3971-
namespaces by passing a partial parameter map:
3971+
by passing a partial parameter map:
39723972
39733973
Axon.trace_init(model, %{"dense_0" => dense_params})
39743974

lib/axon/compiler.ex

-92
Original file line numberDiff line numberDiff line change
@@ -687,98 +687,6 @@ defmodule Axon.Compiler do
687687
{id, model_funs, cache, op_counts, block_cache, model_state_meta}
688688
end
689689

690-
defp recur_model_funs(
691-
%Axon.Node{id: id, op: :namespace, name: name_fn, parent: [parent]},
692-
nodes,
693-
{cache, op_counts, block_cache, model_state_meta},
694-
config
695-
) do
696-
name = name_fn.(:namespace, op_counts)
697-
# To ensure that a namespace always has the same layer names,
698-
# we reset op_counts, input layers always belong to the global
699-
# namespace, so we include those regardless
700-
input_count = op_counts[:input] || 0
701-
namespace_op_counts = %{input: input_count}
702-
namespace_model_state_meta = %{parameters: %{}, state: %{}, frozen_parameters: %{}}
703-
704-
# All of the children of this namespace belong to it, so
705-
# we forward this name to the namespace, but everything after
706-
# it belongs to whatever namespace we're currently in
707-
{parent_id, {cache, namespace_op_counts, block_cache, namespace_model_state_meta}} =
708-
to_model_funs(
709-
parent,
710-
nodes,
711-
{cache, namespace_op_counts, block_cache, namespace_model_state_meta},
712-
config
713-
)
714-
715-
# Update the global op_count of input layers, since they
716-
# are a global operation regardless of where they are
717-
input_count = namespace_op_counts[:input] || 0
718-
op_counts = Map.put(op_counts, :input, input_count)
719-
720-
# Update the model state meta to include the namespace model state meta
721-
model_state_meta =
722-
model_state_meta
723-
|> Map.update!(:parameters, &Map.put(&1, name, namespace_model_state_meta[:parameters]))
724-
|> Map.update!(:state, &Map.put(&1, name, namespace_model_state_meta[:state]))
725-
|> Map.update!(
726-
:frozen_parameters,
727-
&Map.put(&1, name, namespace_model_state_meta[:frozen_parameters])
728-
)
729-
730-
# The function just returns the result of it's child,
731-
# or parent depending on how you view the tree
732-
predict_fun = fn params, inputs, state, cache, result_cache, fn_stacktrace ->
733-
# We're only concerned with this namespaces parameters, so we pair
734-
# down parameters first given the namespace
735-
namespace_params = params[name]
736-
737-
# TODO: How should hooks be handled here?
738-
# TODO: I think we can actually handle parameter freezing and access
739-
# better here by only forwarding params[namespace] to the child function
740-
{out, {state, result_cache}} =
741-
call_predict_cache(
742-
parent_id,
743-
namespace_params,
744-
inputs,
745-
state,
746-
cache,
747-
result_cache,
748-
fn_stacktrace
749-
)
750-
751-
state =
752-
if map_size(state) == 0 do
753-
state
754-
else
755-
%{name => state}
756-
end
757-
758-
{out, {state, result_cache}}
759-
end
760-
761-
init_fun = fn template, cache, result_cache, fn_stacktrace, keys ->
762-
{_parent_template, {namespace_params, result_cache}} =
763-
call_init_cache(parent_id, template, %{}, cache, result_cache, fn_stacktrace, keys)
764-
765-
params =
766-
if namespace_params == %{} do
767-
%{}
768-
else
769-
%{name => namespace_params}
770-
end
771-
772-
{pred_expr, {_, result_cache}} =
773-
predict_fun.(params, template, %{}, cache, result_cache, fn_stacktrace)
774-
775-
{Nx.to_template(pred_expr), {params, result_cache}}
776-
end
777-
778-
model_funs = %{predict: predict_fun, init: init_fun}
779-
{id, model_funs, cache, op_counts, block_cache, model_state_meta}
780-
end
781-
782690
defp recur_model_funs(
783691
%Axon.Node{
784692
id: id,

0 commit comments

Comments
 (0)