Skip to content
344 changes: 344 additions & 0 deletions onnxscript/function_libs/torch_lib/ops/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -3761,6 +3761,171 @@
raise NotImplementedError()


@torch_op("aten::gru.input", trace_only=True)
def aten_gru(
input: TFloat,
hx: TFloat,
params: Sequence[TFloat],
has_biases: bool,
num_layers: int,
dropout: float,
train: bool,
bidirectional: bool,
batch_first: bool,
) -> tuple[TFloat, TFloat]:
"""gru.input(Tensor input, Tensor hx, Tensor[] params, bool has_biases, int num_layers, float dropout, bool train, bool bidirectional, bool batch_first) -> (Tensor, Tensor)"""

Check warning

Code scanning / lintrunner

RUFF/W293 Warning

Check warning

Code scanning / lintrunner

EDITORCONFIG-CHECKER/editorconfig Warning

Trailing whitespace
# Determine number of directions
num_directions = 2 if bidirectional else 1

Check warning

Code scanning / lintrunner

RUFF/W293 Warning

Check warning

Code scanning / lintrunner

EDITORCONFIG-CHECKER/editorconfig Warning

Trailing whitespace
# Get dimensions
if batch_first:
# Convert from [batch, seq, input_size] to [seq, batch, input_size]
input = op.Transpose(input, perm=[1, 0, 2])

Check warning

Code scanning / lintrunner

RUFF/W293 Warning

Check warning

Code scanning / lintrunner

EDITORCONFIG-CHECKER/editorconfig Warning

Trailing whitespace
seq_length = op.Shape(input, start=0, end=1)

Check warning

Code scanning / lintrunner

PYLINT/W0612 Warning

Unused variable 'seq_length' (unused-variable)
See unused-variable. To disable, use # pylint: disable=unused-variable

Check warning

Code scanning / lintrunner

RUFF/F841 Warning

Local variable seq_length is assigned to but never used.
See https://docs.astral.sh/ruff/rules/unused-variable
batch_size = op.Shape(input, start=1, end=2)

Check warning

Code scanning / lintrunner

PYLINT/W0612 Warning

Unused variable 'batch_size' (unused-variable)
See unused-variable. To disable, use # pylint: disable=unused-variable

Check warning

Code scanning / lintrunner

RUFF/F841 Warning

Local variable batch_size is assigned to but never used.
See https://docs.astral.sh/ruff/rules/unused-variable
input_size = op.Shape(input, start=2, end=3)

Check warning

Code scanning / lintrunner

PYLINT/W0612 Warning

Unused variable 'input_size' (unused-variable)
See unused-variable. To disable, use # pylint: disable=unused-variable

Check warning

Code scanning / lintrunner

RUFF/F841 Warning

Local variable input_size is assigned to but never used.
See https://docs.astral.sh/ruff/rules/unused-variable
hidden_size = op.Shape(hx, start=2, end=3)

Check warning

Code scanning / lintrunner

RUFF/W293 Warning

Check warning

Code scanning / lintrunner

EDITORCONFIG-CHECKER/editorconfig Warning

Trailing whitespace
# Process each layer
current_input = input
output_h_list = []

Check warning

Code scanning / lintrunner

RUFF/W293 Warning

Check warning

Code scanning / lintrunner

EDITORCONFIG-CHECKER/editorconfig Warning

Trailing whitespace
for layer_idx in range(num_layers):
# Extract hidden state for this layer
layer_start = layer_idx * num_directions
layer_end = (layer_idx + 1) * num_directions
layer_h = op.Slice(hx, layer_start, layer_end, axes=[0])

Check warning

Code scanning / lintrunner

RUFF/W293 Warning

Check warning

Code scanning / lintrunner

EDITORCONFIG-CHECKER/editorconfig Warning

Trailing whitespace
# Extract parameters for this layer
# Parameter layout: [W_ih, W_hh, b_ih, b_hh] for each direction
params_per_direction = 4 if has_biases else 2
params_per_layer = params_per_direction * num_directions
param_start_idx = layer_idx * params_per_layer

Check warning

Code scanning / lintrunner

RUFF/W293 Warning

Check warning

Code scanning / lintrunner

EDITORCONFIG-CHECKER/editorconfig Warning

Trailing whitespace
# Build weight matrices for ONNX GRU
# ONNX expects: W[zrh] shape [num_directions, 3*hidden_size, input_size]
# PyTorch provides: W_ih shape [3*hidden_size, input_size]
W_list = []
R_list = []
B_list = [] if has_biases else None

Check warning

Code scanning / lintrunner

RUFF/W293 Warning

Check warning

Code scanning / lintrunner

EDITORCONFIG-CHECKER/editorconfig Warning

Trailing whitespace
for dir_idx in range(num_directions):
dir_param_start = param_start_idx + dir_idx * params_per_direction
W_ih = params[dir_param_start] # [3*hidden_size, input_size] - PyTorch order: [r,z,n]
W_hh = params[dir_param_start + 1] # [3*hidden_size, hidden_size] - PyTorch order: [r,z,n]

Check warning

Code scanning / lintrunner

RUFF/W293 Warning

Check warning

Code scanning / lintrunner

EDITORCONFIG-CHECKER/editorconfig Warning

Trailing whitespace
# Reorder gates from PyTorch [r,z,n] to ONNX [z,r,n]
# Split into individual gates
W_ir = op.Slice(W_ih, starts=[0], ends=hidden_size, axes=[0])
W_iz = op.Slice(W_ih, starts=hidden_size, ends=hidden_size * 2, axes=[0])
W_in = op.Slice(W_ih, starts=hidden_size * 2, ends=hidden_size * 3, axes=[0])

Check warning

Code scanning / lintrunner

RUFF/W293 Warning

Check warning

Code scanning / lintrunner

EDITORCONFIG-CHECKER/editorconfig Warning

Trailing whitespace
W_hr = op.Slice(W_hh, starts=[0], ends=hidden_size, axes=[0])
W_hz = op.Slice(W_hh, starts=hidden_size, ends=hidden_size * 2, axes=[0])
W_hn = op.Slice(W_hh, starts=hidden_size * 2, ends=hidden_size * 3, axes=[0])

Check warning

Code scanning / lintrunner

RUFF/W293 Warning

Check warning

Code scanning / lintrunner

EDITORCONFIG-CHECKER/editorconfig Warning

Trailing whitespace
# Reorder: [z,r,n]
W_ih_reordered = op.Concat(W_iz, W_ir, W_in, axis=0) # [3*hidden_size, input_size] - ONNX order
W_hh_reordered = op.Concat(W_hz, W_hr, W_hn, axis=0) # [3*hidden_size, hidden_size] - ONNX order

Check warning

Code scanning / lintrunner

RUFF/W293 Warning

Check warning

Code scanning / lintrunner

EDITORCONFIG-CHECKER/editorconfig Warning

Trailing whitespace
# Add direction dimension
W_ih_expanded = op.Unsqueeze(W_ih_reordered, [0]) # [1, 3*hidden_size, input_size]
W_hh_expanded = op.Unsqueeze(W_hh_reordered, [0]) # [1, 3*hidden_size, hidden_size]

Check warning

Code scanning / lintrunner

RUFF/W293 Warning

Check warning

Code scanning / lintrunner

EDITORCONFIG-CHECKER/editorconfig Warning

Trailing whitespace
W_list.append(W_ih_expanded)
R_list.append(W_hh_expanded)

Check warning

Code scanning / lintrunner

RUFF/W293 Warning

Check warning

Code scanning / lintrunner

EDITORCONFIG-CHECKER/editorconfig Warning

Trailing whitespace
if has_biases:
b_ih = params[dir_param_start + 2] # [3*hidden_size] - PyTorch order: [r,z,n]
b_hh = params[dir_param_start + 3] # [3*hidden_size] - PyTorch order: [r,z,n]

Check warning

Code scanning / lintrunner

RUFF/W293 Warning

Check warning

Code scanning / lintrunner

EDITORCONFIG-CHECKER/editorconfig Warning

Trailing whitespace
# Reorder biases from PyTorch [r,z,n] to ONNX [z,r,n]
b_ir = op.Slice(b_ih, starts=[0], ends=hidden_size, axes=[0])
b_iz = op.Slice(b_ih, starts=hidden_size, ends=hidden_size * 2, axes=[0])
b_in = op.Slice(b_ih, starts=hidden_size * 2, ends=hidden_size * 3, axes=[0])

Check warning

Code scanning / lintrunner

RUFF/W293 Warning

Check warning

Code scanning / lintrunner

EDITORCONFIG-CHECKER/editorconfig Warning

Trailing whitespace
b_hr = op.Slice(b_hh, starts=[0], ends=hidden_size, axes=[0])
b_hz = op.Slice(b_hh, starts=hidden_size, ends=hidden_size * 2, axes=[0])
b_hn = op.Slice(b_hh, starts=hidden_size * 2, ends=hidden_size * 3, axes=[0])

Check warning

Code scanning / lintrunner

RUFF/W293 Warning

Check warning

Code scanning / lintrunner

EDITORCONFIG-CHECKER/editorconfig Warning

Trailing whitespace
# Reorder: [z,r,n]
b_ih_reordered = op.Concat(b_iz, b_ir, b_in, axis=0) # [3*hidden_size] - ONNX order
b_hh_reordered = op.Concat(b_hz, b_hr, b_hn, axis=0) # [3*hidden_size] - ONNX order

Check warning

Code scanning / lintrunner

RUFF/W293 Warning

Check warning

Code scanning / lintrunner

EDITORCONFIG-CHECKER/editorconfig Warning

Trailing whitespace
# ONNX expects biases concatenated: [Wb[zrh], Rb[zrh]]
b_combined = op.Concat(b_ih_reordered, b_hh_reordered, axis=0) # [6*hidden_size]
b_expanded = op.Unsqueeze(b_combined, [0]) # [1, 6*hidden_size]
B_list.append(b_expanded)

Check warning

Code scanning / lintrunner

RUFF/W293 Warning

Check warning

Code scanning / lintrunner

EDITORCONFIG-CHECKER/editorconfig Warning

Trailing whitespace
# Concatenate weights for all directions
W = op.Concat(*W_list, axis=0) if len(W_list) > 1 else W_list[0]
R = op.Concat(*R_list, axis=0) if len(R_list) > 1 else R_list[0]
B = op.Concat(*B_list, axis=0) if has_biases and len(B_list) > 1 else (B_list[0] if has_biases else None)

Check warning

Code scanning / lintrunner

RUFF/W293 Warning

Check warning

Code scanning / lintrunner

EDITORCONFIG-CHECKER/editorconfig Warning

Trailing whitespace
# Call ONNX GRU operator
direction = "bidirectional" if bidirectional else "forward"

Check warning

Code scanning / lintrunner

EDITORCONFIG-CHECKER/editorconfig Warning

Trailing whitespace

Check warning

Code scanning / lintrunner

RUFF/W293 Warning

# Extract hidden_size from hx shape: [num_layers * num_directions, batch, hidden_size]
hidden_size_attr = hx.shape[2]

if B is not None:
Y, Y_h = op.GRU(
current_input,

Check warning

Code scanning / lintrunner

RUFF/W293 Warning

Check warning

Code scanning / lintrunner

EDITORCONFIG-CHECKER/editorconfig Warning

Trailing whitespace
W,
R,
B,
initial_h=layer_h,
direction=direction,
hidden_size=hidden_size_attr,
)
else:
Y, Y_h = op.GRU(
current_input,
W,
R,
initial_h=layer_h,
direction=direction,
hidden_size=hidden_size_attr,
)

# Y shape: [seq_length, num_directions, batch_size, hidden_size]
# Reshape to [seq_length, batch_size, num_directions * hidden_size]
Y = op.Transpose(Y, perm=[0, 2, 1, 3]) # [seq_length, batch_size, num_directions, hidden_size]

Check warning

Code scanning / lintrunner

RUFF/W293 Warning

Check warning

Code scanning / lintrunner

EDITORCONFIG-CHECKER/editorconfig Warning

Trailing whitespace
Y_shape = op.Shape(Y)
new_shape = op.Concat(
op.Slice(Y_shape, [0], [1]), # seq_length
op.Slice(Y_shape, [1], [2]), # batch_size
op.Reshape(
op.Mul(
op.Slice(Y_shape, [2], [3]), # num_directions
op.Slice(Y_shape, [3], [4]), # hidden_size
),
op.Constant(value_ints=[-1]),
),
axis=0,
)
current_input = op.Reshape(Y, new_shape)

# Apply dropout if not last layer and dropout > 0
if layer_idx < num_layers - 1 and dropout > 0.0 and train:
current_input, _ = op.Dropout(current_input, dropout, train)

Check warning

Code scanning / lintrunner

RUFF/W293 Warning

Check warning

Code scanning / lintrunner

EDITORCONFIG-CHECKER/editorconfig Warning

Trailing whitespace

# Store final hidden state
output_h_list.append(Y_h)

Check warning

Code scanning / lintrunner

RUFF/W293 Warning

Check warning

Code scanning / lintrunner

EDITORCONFIG-CHECKER/editorconfig Warning

Trailing whitespace
# Concatenate all layer outputs
final_h = output_h_list[0] if len(output_h_list) == 1 else op.Concat(*output_h_list, axis=0)

Check warning

Code scanning / lintrunner

RUFF/W293 Warning

Check warning

Code scanning / lintrunner

EDITORCONFIG-CHECKER/editorconfig Warning

Trailing whitespace
# Handle batch_first for output
if batch_first:
# Convert from [seq, batch, features] to [batch, seq, features]

Check warning

Code scanning / lintrunner

RUFF/W293 Warning

Check warning

Code scanning / lintrunner

EDITORCONFIG-CHECKER/editorconfig Warning

Trailing whitespace
current_input = op.Transpose(current_input, perm=[1, 0, 2])

return current_input, final_h


Check warning

Code scanning / lintrunner

RUFF/W293 Warning

Check warning

Code scanning / lintrunner

EDITORCONFIG-CHECKER/editorconfig Warning

Trailing whitespace
@torch_op(("_operator::getitem", "aten::getitem"))
def aten_getitem(self: Sequence[TTensor], i: INT64) -> TTensor:
return op.SequenceAt(self, i)
Expand Down Expand Up @@ -4991,6 +5156,185 @@
raise NotImplementedError()


@torch_op("aten::lstm.input", trace_only=True)
def aten_lstm(
input: TFloat,
hx: Sequence[TFloat],
params: Sequence[TFloat],
has_biases: bool,
num_layers: int,
dropout: float,
train: bool,
bidirectional: bool,
batch_first: bool,
) -> tuple[TFloat, TFloat, TFloat]:
"""lstm.input(Tensor input, Tensor[] hx, Tensor[] params, bool has_biases, int num_layers, float dropout, bool train, bool bidirectional, bool batch_first) -> (Tensor, Tensor, Tensor)"""

# Extract initial hidden and cell states
initial_h = hx[0] # Shape: [num_directions * num_layers, batch_size, hidden_size]
initial_c = hx[1] # Shape: [num_directions * num_layers, batch_size, hidden_size]

Check warning

Code scanning / lintrunner

RUFF/W293 Warning

Check warning

Code scanning / lintrunner

EDITORCONFIG-CHECKER/editorconfig Warning

Trailing whitespace

# Determine number of directions
num_directions = 2 if bidirectional else 1

Check warning

Code scanning / lintrunner

RUFF/W293 Warning

Check warning

Code scanning / lintrunner

EDITORCONFIG-CHECKER/editorconfig Warning

Trailing whitespace
# Get dimensions
if batch_first:
# Convert from [batch, seq, input_size] to [seq, batch, input_size]

Check warning

Code scanning / lintrunner

RUFF/W293 Warning

Check warning

Code scanning / lintrunner

EDITORCONFIG-CHECKER/editorconfig Warning

Trailing whitespace
input = op.Transpose(input, perm=[1, 0, 2])

seq_length = op.Shape(input, start=0, end=1)
batch_size = op.Shape(input, start=1, end=2)
input_size = op.Shape(input, start=2, end=3)

Check warning

Code scanning / lintrunner

RUFF/W293 Warning

Check warning

Code scanning / lintrunner

EDITORCONFIG-CHECKER/editorconfig Warning

Trailing whitespace
hidden_size = op.Shape(initial_h, start=2, end=3)

Check warning

Code scanning / lintrunner

PYLINT/W0612 Warning

Unused variable 'seq_length' (unused-variable)
See unused-variable. To disable, use # pylint: disable=unused-variable

Check warning

Code scanning / lintrunner

RUFF/F841 Warning

Local variable seq_length is assigned to but never used.
See https://docs.astral.sh/ruff/rules/unused-variable

Check warning

Code scanning / lintrunner

PYLINT/W0612 Warning

Unused variable 'batch_size' (unused-variable)
See unused-variable. To disable, use # pylint: disable=unused-variable

Check warning

Code scanning / lintrunner

RUFF/F841 Warning

Local variable batch_size is assigned to but never used.
See https://docs.astral.sh/ruff/rules/unused-variable
# Process each layer

Check warning

Code scanning / lintrunner

PYLINT/W0612 Warning

Unused variable 'input_size' (unused-variable)
See unused-variable. To disable, use # pylint: disable=unused-variable

Check warning

Code scanning / lintrunner

RUFF/F841 Warning

Local variable input_size is assigned to but never used.
See https://docs.astral.sh/ruff/rules/unused-variable
current_input = input
output_h_list = []

Check warning

Code scanning / lintrunner

RUFF/W293 Warning

Check warning

Code scanning / lintrunner

EDITORCONFIG-CHECKER/editorconfig Warning

Trailing whitespace
output_c_list = []

for layer_idx in range(num_layers):
# Extract hidden and cell states for this layer
layer_start = layer_idx * num_directions

Check warning

Code scanning / lintrunner

RUFF/W293 Warning

Check warning

Code scanning / lintrunner

EDITORCONFIG-CHECKER/editorconfig Warning

Trailing whitespace
layer_end = (layer_idx + 1) * num_directions
layer_h = op.Slice(initial_h, layer_start, layer_end, axes=[0])
layer_c = op.Slice(initial_c, layer_start, layer_end, axes=[0])

# Extract parameters for this layer
# Parameter layout: [W_ih, W_hh, b_ih, b_hh] for each direction
params_per_direction = 4 if has_biases else 2

Check warning

Code scanning / lintrunner

RUFF/W293 Warning

Check warning

Code scanning / lintrunner

EDITORCONFIG-CHECKER/editorconfig Warning

Trailing whitespace
params_per_layer = params_per_direction * num_directions
param_start_idx = layer_idx * params_per_layer

# Build weight matrices for ONNX LSTM
# ONNX expects: W[iofc] shape [num_directions, 4*hidden_size, input_size]
# PyTorch provides: W_ih shape [4*hidden_size, input_size]

Check warning

Code scanning / lintrunner

RUFF/W293 Warning

Check warning

Code scanning / lintrunner

EDITORCONFIG-CHECKER/editorconfig Warning

Trailing whitespace
W_list = []
R_list = []
B_list = [] if has_biases else None

for dir_idx in range(num_directions):
dir_param_start = param_start_idx + dir_idx * params_per_direction
W_ih = params[dir_param_start] # [4*hidden_size, input_size] - PyTorch order: [i,f,g,o]

Check warning

Code scanning / lintrunner

RUFF/W293 Warning

Check warning

Code scanning / lintrunner

EDITORCONFIG-CHECKER/editorconfig Warning

Trailing whitespace
W_hh = params[dir_param_start + 1] # [4*hidden_size, hidden_size] - PyTorch order: [i,f,g,o]

# Reorder gates from PyTorch [i,f,g,o] to ONNX [i,o,f,g]
# Split into individual gates
W_ii = op.Slice(W_ih, starts=[0], ends=hidden_size, axes=[0])

Check warning

Code scanning / lintrunner

RUFF/W293 Warning

Check warning

Code scanning / lintrunner

EDITORCONFIG-CHECKER/editorconfig Warning

Trailing whitespace
W_if = op.Slice(W_ih, starts=hidden_size, ends=hidden_size * 2, axes=[0])
W_ig = op.Slice(W_ih, starts=hidden_size * 2, ends=hidden_size * 3, axes=[0])
W_io = op.Slice(W_ih, starts=hidden_size * 3, ends=hidden_size * 4, axes=[0])

W_hi = op.Slice(W_hh, starts=[0], ends=hidden_size, axes=[0])
W_hf = op.Slice(W_hh, starts=hidden_size, ends=hidden_size * 2, axes=[0])
W_hg = op.Slice(W_hh, starts=hidden_size * 2, ends=hidden_size * 3, axes=[0])

Check warning

Code scanning / lintrunner

RUFF/W293 Warning

Check warning

Code scanning / lintrunner

EDITORCONFIG-CHECKER/editorconfig Warning

Trailing whitespace
W_ho = op.Slice(W_hh, starts=hidden_size * 3, ends=hidden_size * 4, axes=[0])

# Reorder: [i,o,f,g]
W_ih_reordered = op.Concat(W_ii, W_io, W_if, W_ig, axis=0) # [4*hidden_size, input_size] - ONNX order
W_hh_reordered = op.Concat(W_hi, W_ho, W_hf, W_hg, axis=0) # [4*hidden_size, hidden_size] - ONNX order

Check warning

Code scanning / lintrunner

RUFF/W293 Warning

Check warning

Code scanning / lintrunner

EDITORCONFIG-CHECKER/editorconfig Warning

Trailing whitespace

# Add direction dimension
W_ih_expanded = op.Unsqueeze(W_ih_reordered, [0]) # [1, 4*hidden_size, input_size]
W_hh_expanded = op.Unsqueeze(W_hh_reordered, [0]) # [1, 4*hidden_size, hidden_size]

Check warning

Code scanning / lintrunner

RUFF/W293 Warning

Check warning

Code scanning / lintrunner

EDITORCONFIG-CHECKER/editorconfig Warning

Trailing whitespace

W_list.append(W_ih_expanded)
R_list.append(W_hh_expanded)

Check warning

Code scanning / lintrunner

RUFF/W293 Warning

Check warning

Code scanning / lintrunner

EDITORCONFIG-CHECKER/editorconfig Warning

Trailing whitespace
if has_biases:
b_ih = params[dir_param_start + 2] # [4*hidden_size] - PyTorch order: [i,f,g,o]
b_hh = params[dir_param_start + 3] # [4*hidden_size] - PyTorch order: [i,f,g,o]

Check warning

Code scanning / lintrunner

RUFF/W293 Warning

Check warning

Code scanning / lintrunner

EDITORCONFIG-CHECKER/editorconfig Warning

Trailing whitespace

# Reorder biases from PyTorch [i,f,g,o] to ONNX [i,o,f,g]
b_ii = op.Slice(b_ih, starts=[0], ends=hidden_size, axes=[0])
b_if = op.Slice(b_ih, starts=hidden_size, ends=hidden_size * 2, axes=[0])

Check warning

Code scanning / lintrunner

RUFF/W293 Warning

Check warning

Code scanning / lintrunner

EDITORCONFIG-CHECKER/editorconfig Warning

Trailing whitespace
b_ig = op.Slice(b_ih, starts=hidden_size * 2, ends=hidden_size * 3, axes=[0])
b_io = op.Slice(b_ih, starts=hidden_size * 3, ends=hidden_size * 4, axes=[0])

b_hi = op.Slice(b_hh, starts=[0], ends=hidden_size, axes=[0])
b_hf = op.Slice(b_hh, starts=hidden_size, ends=hidden_size * 2, axes=[0])
b_hg = op.Slice(b_hh, starts=hidden_size * 2, ends=hidden_size * 3, axes=[0])
b_ho = op.Slice(b_hh, starts=hidden_size * 3, ends=hidden_size * 4, axes=[0])

# Reorder: [i,o,f,g]
b_ih_reordered = op.Concat(b_ii, b_io, b_if, b_ig, axis=0) # [4*hidden_size] - ONNX order
b_hh_reordered = op.Concat(b_hi, b_ho, b_hf, b_hg, axis=0) # [4*hidden_size] - ONNX order

# ONNX expects biases concatenated: [Wb[iofc], Rb[iofc]]
b_combined = op.Concat(b_ih_reordered, b_hh_reordered, axis=0) # [8*hidden_size]
b_expanded = op.Unsqueeze(b_combined, [0]) # [1, 8*hidden_size]
B_list.append(b_expanded)

# Concatenate weights for all directions
W = op.Concat(*W_list, axis=0) if len(W_list) > 1 else W_list[0]
R = op.Concat(*R_list, axis=0) if len(R_list) > 1 else R_list[0]
B = op.Concat(*B_list, axis=0) if has_biases and len(B_list) > 1 else (B_list[0] if has_biases else None)

# Call ONNX LSTM operator
direction = "bidirectional" if bidirectional else "forward"

# Extract hidden_size from initial_h shape: [num_layers * num_directions, batch, hidden_size]
hidden_size_attr = initial_h.shape[2]

if B is not None:
Y, Y_h, Y_c = op.LSTM(
current_input,
W,
R,
B,
initial_h=layer_h,
initial_c=layer_c,
direction=direction,
hidden_size=hidden_size_attr,
)
else:
Y, Y_h, Y_c = op.LSTM(
current_input,
W,
R,
initial_h=layer_h,
initial_c=layer_c,
direction=direction,
hidden_size=hidden_size_attr,
)

# Y shape: [seq_length, num_directions, batch_size, hidden_size]
# Reshape to [seq_length, batch_size, num_directions * hidden_size]
Y = op.Transpose(Y, perm=[0, 2, 1, 3]) # [seq_length, batch_size, num_directions, hidden_size]
Y_shape = op.Shape(Y)
new_shape = op.Concat(
op.Slice(Y_shape, [0], [1]), # seq_length
op.Slice(Y_shape, [1], [2]), # batch_size
op.Reshape(
op.Mul(
op.Slice(Y_shape, [2], [3]), # num_directions
op.Slice(Y_shape, [3], [4]), # hidden_size
),
op.Constant(value_ints=[-1]),
),
axis=0,
)
current_input = op.Reshape(Y, new_shape)

# Apply dropout if not last layer and dropout > 0
if layer_idx < num_layers - 1 and dropout > 0.0 and train:
current_input, _ = op.Dropout(current_input, dropout, train)

# Store final hidden and cell states
output_h_list.append(Y_h)
output_c_list.append(Y_c)

# Concatenate all layer outputs
final_h = output_h_list[0] if len(output_h_list) == 1 else op.Concat(*output_h_list, axis=0)
final_c = output_c_list[0] if len(output_c_list) == 1 else op.Concat(*output_c_list, axis=0)

# Handle batch_first for output
if batch_first:
# Convert from [seq, batch, features] to [batch, seq, features]
current_input = op.Transpose(current_input, perm=[1, 0, 2])

return current_input, final_h, final_c


@torch_op(
("aten::lt.Tensor", "aten::lt.Scalar", "aten::less.Tensor", "_operator::lt"),
trace_only=True,
Expand Down
Loading
Loading