Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 1 addition & 4 deletions python/tvm/relax/op/nn/nn.py
Original file line number Diff line number Diff line change
Expand Up @@ -1813,7 +1813,7 @@ def rms_norm(

.. math::

out = \frac{data}{\sqrt{mean(data, axis)+\epsilon}} * weight + bias
out = \frac{data}{\sqrt{mean(data, axis)+\epsilon}} * weight

Parameters
----------
Expand All @@ -1823,9 +1823,6 @@ def rms_norm(
weight : relax.Expr
The scale factor.

bias : relax.Expr
The offset factor.

axes : Union[int, List[int]]
The axes that along which the normalization is applied.

Expand Down
9 changes: 4 additions & 5 deletions src/relax/op/nn/nn.cc
Original file line number Diff line number Diff line change
Expand Up @@ -848,13 +848,12 @@ InferLayoutOutput InferLayoutRMSNorm(const Call& call,

LayoutDecision layout = GetLayoutDecision(var_layout_map, call->args[0]);
ObjectPtr<RMSNormAttrs> new_attrs = make_object<RMSNormAttrs>(*attrs);
std::vector<Integer> new_axis;
std::vector<Integer> new_axes;
for (const auto& axis : attrs->axes) {
new_axis.push_back(FindAxis(layout->layout, axis->value));
new_axes.push_back(FindAxis(layout->layout, axis->value));
}
new_attrs->axes = std::move(new_axis);
return InferLayoutOutput({layout, initial_layouts[1], initial_layouts[2]}, {layout},
Attrs(new_attrs));
new_attrs->axes = std::move(new_axes);
return InferLayoutOutput({layout, initial_layouts[1]}, {layout}, Attrs(new_attrs));
}

TVM_REGISTER_OP("relax.nn.rms_norm")
Expand Down