Prevent erroneous deduping of the full op (#1018)
Summary:
Pull Request resolved: #1018

The `full` op can be erroneously deduped in `refine_graph` when the shape and/or dtype of the output differ but the fill value is the same. Here we add the shape and dtype to the op's `_attrs` to avoid this erroneous deduping.
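
As a minimal sketch of the problematic pattern (mirroring the new unit test below; only the public `ops.full` front end is shown, not the internals of `refine_graph`):

from aitemplate.compiler import ops

# Same fill value, different output shapes: before this fix, refine_graph
# could treat these two ops as duplicates and collapse them into one.
y1 = ops.full()(shape=[10, 20], fill_value=1.0, dtype="float16")
y2 = ops.full()(shape=[100, 20], fill_value=1.0, dtype="float16")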

Reviewed By: ColinPeppler, 22quinn

Differential Revision: D60990475

fbshipit-source-id: 490c7327cb542650e0c8892cdeed75bdb7848dc4
aakhundov authored and facebook-github-bot committed Aug 9, 2024
1 parent dc13d36 commit 2aef297
Showing 2 changed files with 67 additions and 0 deletions.
6 changes: 6 additions & 0 deletions python/aitemplate/compiler/ops/tensor/full.py
@@ -64,6 +64,12 @@ def __call__(
        self._attrs["inputs"] = []
        self._attrs["fill_value"] = fill_value

        # although not used downstream, these attrs
        # are necessary to avoid erroneously deduping
        # legitimately different full op instances
        self._attrs["shape"] = shape
        self._attrs["dtype"] = dtype

        self._set_depth()
        output = Tensor(
            shape, src_ops={self}, dtype=dtype, skip_constant_folding=not static_shape
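
The fix implies that `refine_graph` decides whether two ops can be merged by comparing their `_attrs`, so an op's identity is only as fine-grained as the attributes it records. A hypothetical sketch of that idea in isolation (`dedup_key` is illustrative, not AITemplate's actual dedup code):

# Hypothetical: if the dedup key covers only "fill_value", then
# full([10, 20], 1.0) and full([100, 20], 1.0) collide; adding
# "shape" and "dtype" to _attrs makes the keys distinct.
def dedup_key(attrs, fields=("fill_value", "shape", "dtype")):
    return tuple((f, repr(attrs.get(f))) for f in fields)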
61 changes: 61 additions & 0 deletions tests/unittest/compiler/test_refine_graph.py
@@ -26,6 +26,7 @@
    filter_test_cases_by_params,
    get_random_torch_tensor,
    get_torch_empty_tensor,
    get_torch_full_tensor,
    TestEnv,
)
from aitemplate.utils import graph_utils
@@ -371,6 +372,66 @@ def test_refine_graph_group_gemms(self, dtype):
        self.assertTrue(torch.allclose(Y1_pt, outputs["y3"], atol=1e-1, rtol=1e-1))
        self.assertTrue(torch.allclose(Y2_pt, outputs["y4"], atol=1e-1, rtol=1e-1))

    @parameterized.expand(
        **filter_test_cases_by_params(
            {
                TestEnv.CUDA_SM80: [
                    (0, [10, 20], [10, 20], 1.0, 1.0, "float16", "float16"),
                    (1, [10, 20], [100, 20], 1.0, 1.0, "float16", "float16"),
                    (2, [10, 20], [10, 20], 1.0, 3.14, "float16", "float16"),
                    (3, [10, 20], [100, 20], 1.0, 3.14, "float16", "float16"),
                    (4, [10, 20], [10, 20], 1.0, 1.0, "float16", "float32"),
                    (5, [10, 20], [100, 20], 1.0, 1.0, "float16", "float32"),
                    (6, [10, 20], [10, 20], 1.0, 3.14, "float16", "float32"),
                    (7, [10, 20], [100, 20], 1.0, 3.14, "float16", "float32"),
                ],
            }
        )
    )
    def test_refine_graph_full(
        self,
        test_id,
        shape1,
        shape2,
        fill_value1,
        fill_value2,
        dtype1,
        dtype2,
    ):
        Y1 = ops.full()(
            shape=shape1,
            fill_value=fill_value1,
            dtype=dtype1,
        )
        Y2 = ops.full()(
            shape=shape2,
            fill_value=fill_value2,
            dtype=dtype2,
        )
        Y1._attrs["name"] = "Y1"
        Y1._attrs["is_output"] = True
        Y2._attrs["name"] = "Y2"
        Y2._attrs["is_output"] = True

        module = compile_model(
            [Y1, Y2],
            detect_target(),
            "./tmp",
            f"test_refine_graph_full_{test_id}",
        )

        inputs = {}
        outputs = {}
        outputs["Y1"] = get_torch_empty_tensor(shape1, dtype1)
        outputs["Y2"] = get_torch_empty_tensor(shape2, dtype2)

        module.run_with_tensors(inputs, outputs)
        y1 = get_torch_full_tensor(shape1, fill_value1, dtype1)
        y2 = get_torch_full_tensor(shape2, fill_value2, dtype2)

        torch.testing.assert_close(y1, outputs["Y1"])
        torch.testing.assert_close(y2, outputs["Y2"])


if __name__ == "__main__":
    unittest.main()
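
To run just the new test locally (a standard unittest invocation; assumes a working AITemplate install and a CUDA SM80 target, per the test's environment filter):

python -m unittest tests.unittest.compiler.test_refine_graph -k test_refine_graph_full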
