tests/compile/piecewise/test_multiple_graphs.py (3 additions & 3 deletions)
@@ -19,7 +19,7 @@
from vllm.utils import direct_register_custom_op

# create a library to hold the custom op
silly_lib = Library("silly", "FRAGMENT") # noqa
silly_lib = Library("silly_multiple", "FRAGMENT") # noqa

BATCH_SIZE = 32
MLP_SIZE = 128
@@ -90,7 +90,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
x = self.pre_attn(x)
x = self.rms_norm_ref(x)
attn_output = torch.empty_like(x)
-torch.ops.silly.attention(x, x, x, attn_output)
+torch.ops.silly_multiple.attention(x, x, x, attn_output)
x = attn_output
x = self.rms_norm_ref(x)
x = self.post_attn(x)
@@ -188,7 +188,7 @@ def __init__(self,
def forward(self, x: torch.Tensor) -> torch.Tensor:
x = x + x
attn_output = torch.empty_like(x)
-torch.ops.silly.attention(x, x, x, attn_output)
+torch.ops.silly_multiple.attention(x, x, x, attn_output)
x = attn_output
x = x * 3
return x
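For context, here is a minimal standalone sketch of the pattern these tests rely on: each test file now opens its own library namespace, so the custom op is defined and later invoked under that per-file name. The sketch uses the plain `torch.library` API rather than vLLM's `direct_register_custom_op` helper imported in the tests, and the attention body is a placeholder, so treat it as an illustration rather than the tests' actual registration code.

```python
import torch
from torch.library import Library

# Per-file namespace, matching the rename in this diff.
silly_lib = Library("silly_multiple", "FRAGMENT")  # noqa

# Define the op schema; the (a!) annotation marks `out` as mutated in place.
silly_lib.define(
    "attention(Tensor q, Tensor k, Tensor v, Tensor(a!) out) -> ()")


def _silly_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
                     out: torch.Tensor) -> None:
    # Placeholder body; the real tests provide their own implementation.
    out.copy_(q + k + v)


# Register the implementation for all backends.
silly_lib.impl("attention", _silly_attention, "CompositeExplicitAutograd")

# Call sites then go through the per-file namespace, as in the updated tests.
x = torch.randn(4, 8)
attn_output = torch.empty_like(x)
torch.ops.silly_multiple.attention(x, x, x, attn_output)
```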
tests/compile/piecewise/test_simple.py (3 additions & 3 deletions)
@@ -20,7 +20,7 @@
global_counter = 0

# create a library to hold the custom op
silly_lib = Library("silly", "FRAGMENT") # noqa
silly_lib = Library("silly_simple", "FRAGMENT") # noqa


def silly_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
@@ -66,12 +66,12 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
x = x + 1
x = x + 2
out = torch.empty_like(x)
-torch.ops.silly.attention(x, x, x, out)
+torch.ops.silly_simple.attention(x, x, x, out)
x = out
x = x - 2
x = x - 1
out = torch.empty_like(x)
-torch.ops.silly.attention(x, x, x, out)
+torch.ops.silly_simple.attention(x, x, x, out)
x = out
x = x + 1
return x
tests/compile/piecewise/test_toy_llama.py (2 additions & 2 deletions)
@@ -24,7 +24,7 @@
from vllm.utils import direct_register_custom_op

# create a library to hold the custom op
silly_lib = Library("silly", "FRAGMENT") # noqa
silly_lib = Library("silly_toy_llama", "FRAGMENT") # noqa


def silly_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
@@ -160,7 +160,7 @@ def forward(
k = k + positions.unsqueeze(1)

attn_output = torch.empty_like(q)
-torch.ops.silly.attention(q, k, v, attn_output)
+torch.ops.silly_toy_llama.attention(q, k, v, attn_output)

output = self.output_projection(attn_output)
return output
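As a rough illustration of why a shared name can be a problem (an assumption about the motivation, not something stated in the diff): if two test modules both open a fragment of the same "silly" namespace and try to define the same op schema, the second definition is rejected by the dispatcher. Sketched with plain `torch.library`:

```python
from torch.library import Library

schema = "attention(Tensor q, Tensor k, Tensor v, Tensor(a!) out) -> ()"

# Two modules each opening a fragment of the same namespace...
lib_a = Library("silly", "FRAGMENT")  # e.g. loaded by one test file
lib_b = Library("silly", "FRAGMENT")  # e.g. loaded by another test file

lib_a.define(schema)
try:
    lib_b.define(schema)  # duplicate definition in the shared namespace
except RuntimeError as exc:
    # PyTorch typically rejects re-registering the same op name.
    print(f"duplicate op registration rejected: {exc}")
```

Giving each test file its own namespace (`silly_multiple`, `silly_simple`, `silly_toy_llama`) keeps the registrations independent when the tests share a process.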