Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 13 additions & 1 deletion src/transformers/utils/fx.py
Original file line number Diff line number Diff line change
Expand Up @@ -305,12 +305,22 @@ def torch_matmul(input, other, *, out=None):

def torch_bmm(input, mat2, *, out=None):
if out is not None:
raise ValueError("Don't support in-place abs for MetaTensor analysis")
raise ValueError("Don't support in-place bmm for MetaTensor analysis")
batch_size, n, m = input.shape
_, _, p = mat2.shape
return torch.empty(batch_size, n, p, device="meta")


def torch_baddbmm(input, batch1, batch2, *, beta=1, alpha=1, out=None):
if out is not None:
raise ValueError("Don't support in-place baddbmm for MetaTensor analysis")
return torch_bmm(batch1, batch2)


def torch_tensor_baddbmm(self, batch1, batch2, *, beta=1, alpha=1, out=None):
return torch_baddbmm(self, batch1, batch2, beta=beta, alpha=alpha, out=out)


def torch_einsum(equation, *operands):
# TODO: infer shape without performing the computation, this might be quite hard.
concrete_operands = (torch.empty_like(operand, device="cpu") for operand in operands)
Expand Down Expand Up @@ -495,6 +505,8 @@ def to_concrete(t):
torch.Tensor.mul: torch_tensor_mul,
torch.matmul: torch_matmul,
torch.bmm: torch_bmm,
torch.baddbmm: torch_baddbmm,
torch.Tensor.baddbmm: torch_tensor_baddbmm,
torch.einsum: torch_einsum,
torch.Tensor.repeat: torch_tensor_repeat,
torch.roll: torch_roll,
Expand Down