-
Notifications
You must be signed in to change notification settings - Fork 90
feat: implement LSTM and GRU operators for torchlib #2674
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 6 commits
5c72122
3d4addc
f4881f4
039ffb0
ea8c549
4163c7e
5411ca0
72ea8a1
1f5e764
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -3761,6 +3761,171 @@ | |
| raise NotImplementedError() | ||
|
|
||
|
|
||
def _gru_gates_to_onnx_order(stacked, hidden_size, hidden_size_x2, hidden_size_x3):
    """Reorder gate blocks stacked on axis 0 from PyTorch [r, z, n] to ONNX [z, r, n].

    Only axis 0 is sliced, so the same routine serves the weight matrices
    ([3*hidden_size, K]) and the bias vectors ([3*hidden_size]).

    Args:
        stacked: Tensor whose axis 0 holds the three gate blocks in PyTorch order.
        hidden_size: 1-D INT64 tensor holding ``hidden_size``.
        hidden_size_x2: 1-D INT64 tensor holding ``2 * hidden_size``.
        hidden_size_x3: 1-D INT64 tensor holding ``3 * hidden_size``.

    Returns:
        The same data with the gate blocks in ONNX order.
    """
    reset_gate = op.Slice(stacked, starts=[0], ends=hidden_size, axes=[0])
    update_gate = op.Slice(stacked, starts=hidden_size, ends=hidden_size_x2, axes=[0])
    new_gate = op.Slice(stacked, starts=hidden_size_x2, ends=hidden_size_x3, axes=[0])
    return op.Concat(update_gate, reset_gate, new_gate, axis=0)


@torch_op("aten::gru.input", trace_only=True)
def aten_gru(
    input: TFloat,
    hx: TFloat,
    params: Sequence[TFloat],
    has_biases: bool,
    num_layers: int,
    dropout: float,
    train: bool,
    bidirectional: bool,
    batch_first: bool,
) -> tuple[TFloat, TFloat]:
    """gru.input(Tensor input, Tensor hx, Tensor[] params, bool has_biases, int num_layers, float dropout, bool train, bool bidirectional, bool batch_first) -> (Tensor, Tensor)

    Lowers a (possibly multi-layer, possibly bidirectional) GRU to one ONNX
    ``GRU`` node per layer.

    Args:
        input: Input sequence, [seq, batch, input_size], or
            [batch, seq, input_size] when ``batch_first`` is true.
        hx: Initial hidden state, [num_layers * num_directions, batch, hidden_size].
        params: Flat weight list. For every layer and direction it holds
            ``W_ih, W_hh`` followed by ``b_ih, b_hh`` when ``has_biases`` is true.
        has_biases: Whether ``params`` contains bias tensors.
        num_layers: Number of stacked GRU layers.
        dropout: Dropout ratio applied between layers (never after the last one).
        train: Whether dropout is active.
        bidirectional: Whether every layer runs in both directions.
        batch_first: Whether ``input`` and the returned output are batch-major.

    Returns:
        A tuple ``(output, h_n)`` where ``output`` is the last layer's sequence
        output with the directions merged into the feature axis, and ``h_n``
        stacks the final hidden state of every layer along axis 0.
    """
    # Values that do not depend on the layer are computed once up front.
    num_directions = 2 if bidirectional else 1
    direction = "bidirectional" if bidirectional else "forward"
    # Parameter layout: [W_ih, W_hh, b_ih, b_hh] for each direction
    params_per_direction = 4 if has_biases else 2
    params_per_layer = params_per_direction * num_directions

    if batch_first:
        # Convert from [batch, seq, input_size] to [seq, batch, input_size]
        input = op.Transpose(input, perm=[1, 0, 2])

    # Dynamic hidden size (1-D tensor) drives the gate slicing below.
    hidden_size = op.Shape(hx, start=2, end=3)
    hidden_size_x2 = hidden_size * 2
    hidden_size_x3 = hidden_size * 3
    # The GRU node needs hidden_size as an attribute, taken from the static
    # shape of hx: [num_layers * num_directions, batch, hidden_size]
    hidden_size_attr = hx.shape[2]

    current_input = input
    output_h_list = []

    for layer_idx in range(num_layers):
        # Initial hidden state of this layer. Slice takes 1-D starts/ends.
        layer_start = layer_idx * num_directions
        layer_end = layer_start + num_directions
        layer_h = op.Slice(hx, [layer_start], [layer_end], axes=[0])

        param_start_idx = layer_idx * params_per_layer

        # ONNX expects W[zrh] with shape [num_directions, 3*hidden_size, input_size]
        # while PyTorch provides W_ih with shape [3*hidden_size, input_size].
        W_list = []
        R_list = []
        B_list = []

        for dir_idx in range(num_directions):
            dir_param_start = param_start_idx + dir_idx * params_per_direction
            W_ih = params[dir_param_start]  # [3*hidden_size, input_size]
            W_hh = params[dir_param_start + 1]  # [3*hidden_size, hidden_size]

            # Reorder the gates, then add the leading direction dimension.
            W_list.append(
                op.Unsqueeze(
                    _gru_gates_to_onnx_order(
                        W_ih, hidden_size, hidden_size_x2, hidden_size_x3
                    ),
                    [0],
                )
            )
            R_list.append(
                op.Unsqueeze(
                    _gru_gates_to_onnx_order(
                        W_hh, hidden_size, hidden_size_x2, hidden_size_x3
                    ),
                    [0],
                )
            )

            if has_biases:
                b_ih = params[dir_param_start + 2]  # [3*hidden_size]
                b_hh = params[dir_param_start + 3]  # [3*hidden_size]
                # ONNX expects biases concatenated: [Wb[zrh], Rb[zrh]]
                b_combined = op.Concat(
                    _gru_gates_to_onnx_order(
                        b_ih, hidden_size, hidden_size_x2, hidden_size_x3
                    ),
                    _gru_gates_to_onnx_order(
                        b_hh, hidden_size, hidden_size_x2, hidden_size_x3
                    ),
                    axis=0,
                )  # [6*hidden_size]
                B_list.append(op.Unsqueeze(b_combined, [0]))  # [1, 6*hidden_size]

        # Stack the per-direction tensors along the direction axis.
        W = op.Concat(*W_list, axis=0) if len(W_list) > 1 else W_list[0]
        R = op.Concat(*R_list, axis=0) if len(R_list) > 1 else R_list[0]

        gru_inputs = [current_input, W, R]
        if has_biases:
            gru_inputs.append(
                op.Concat(*B_list, axis=0) if len(B_list) > 1 else B_list[0]
            )

        # PyTorch computes n = tanh(W_in x + b_in + r * (W_hn h + b_hn)), i.e. the
        # reset gate is applied AFTER the hidden-state linear transform. ONNX GRU
        # only does that when linear_before_reset=1; the default (0) yields
        # different numerical results.
        Y, Y_h = op.GRU(
            *gru_inputs,
            initial_h=layer_h,
            direction=direction,
            hidden_size=hidden_size_attr,
            linear_before_reset=1,
        )

        # Y shape: [seq_length, num_directions, batch_size, hidden_size]
        # Reshape to [seq_length, batch_size, num_directions * hidden_size]
        Y = op.Transpose(Y, perm=[0, 2, 1, 3])
        Y_shape = op.Shape(Y)
        new_shape = op.Concat(
            op.Slice(Y_shape, [0], [1]),  # seq_length
            op.Slice(Y_shape, [1], [2]),  # batch_size
            op.Reshape(
                op.Mul(
                    op.Slice(Y_shape, [2], [3]),  # num_directions
                    op.Slice(Y_shape, [3], [4]),  # hidden_size
                ),
                op.Constant(value_ints=[-1]),
            ),
            axis=0,
        )
        current_input = op.Reshape(Y, new_shape)

        # Apply dropout if not last layer and dropout > 0
        if layer_idx < num_layers - 1 and dropout > 0.0 and train:
            current_input, _ = op.Dropout(current_input, dropout, train)

        # Store final hidden state
        output_h_list.append(Y_h)

    # Concatenate all layer outputs
    final_h = (
        output_h_list[0]
        if len(output_h_list) == 1
        else op.Concat(*output_h_list, axis=0)
    )

    # Handle batch_first for output
    if batch_first:
        # Convert from [seq, batch, features] to [batch, seq, features]
        current_input = op.Transpose(current_input, perm=[1, 0, 2])

    return current_input, final_h
|
|
||
|
|
||
Check warningCode scanning / lintrunner RUFF/W293 Warning
Blank line contains whitespace.
See https://docs.astral.sh/ruff/rules/blank-line-with-whitespace Check warningCode scanning / lintrunner EDITORCONFIG-CHECKER/editorconfig Warning
Trailing whitespace
|
||
@torch_op(("_operator::getitem", "aten::getitem"))
def aten_getitem(self: Sequence[TTensor], i: INT64) -> TTensor:
    """Return the element of the tensor sequence ``self`` at position ``i``."""
    element = op.SequenceAt(self, i)
    return element
|
|
@@ -4991,6 +5156,185 @@ | |
| raise NotImplementedError() | ||
|
|
||
|
|
||
def _lstm_gates_to_onnx_order(
    stacked, hidden_size, hidden_size_x2, hidden_size_x3, hidden_size_x4
):
    """Reorder gate blocks stacked on axis 0 from PyTorch [i, f, g, o] to ONNX [i, o, f, g].

    Only axis 0 is sliced, so the same routine serves the weight matrices
    ([4*hidden_size, K]) and the bias vectors ([4*hidden_size]).

    Args:
        stacked: Tensor whose axis 0 holds the four gate blocks in PyTorch order.
        hidden_size: 1-D INT64 tensor holding ``hidden_size``.
        hidden_size_x2: 1-D INT64 tensor holding ``2 * hidden_size``.
        hidden_size_x3: 1-D INT64 tensor holding ``3 * hidden_size``.
        hidden_size_x4: 1-D INT64 tensor holding ``4 * hidden_size``.

    Returns:
        The same data with the gate blocks in ONNX order.
    """
    input_gate = op.Slice(stacked, starts=[0], ends=hidden_size, axes=[0])
    forget_gate = op.Slice(stacked, starts=hidden_size, ends=hidden_size_x2, axes=[0])
    cell_gate = op.Slice(stacked, starts=hidden_size_x2, ends=hidden_size_x3, axes=[0])
    output_gate = op.Slice(
        stacked, starts=hidden_size_x3, ends=hidden_size_x4, axes=[0]
    )
    return op.Concat(input_gate, output_gate, forget_gate, cell_gate, axis=0)


@torch_op("aten::lstm.input", trace_only=True)
def aten_lstm(
    input: TFloat,
    hx: Sequence[TFloat],
    params: Sequence[TFloat],
    has_biases: bool,
    num_layers: int,
    dropout: float,
    train: bool,
    bidirectional: bool,
    batch_first: bool,
) -> tuple[TFloat, TFloat, TFloat]:
    """lstm.input(Tensor input, Tensor[] hx, Tensor[] params, bool has_biases, int num_layers, float dropout, bool train, bool bidirectional, bool batch_first) -> (Tensor, Tensor, Tensor)

    Lowers a (possibly multi-layer, possibly bidirectional) LSTM to one ONNX
    ``LSTM`` node per layer.

    Args:
        input: Input sequence, [seq, batch, input_size], or
            [batch, seq, input_size] when ``batch_first`` is true.
        hx: Pair ``(h_0, c_0)``; each is
            [num_layers * num_directions, batch, hidden_size].
        params: Flat weight list. For every layer and direction it holds
            ``W_ih, W_hh`` followed by ``b_ih, b_hh`` when ``has_biases`` is true.
        has_biases: Whether ``params`` contains bias tensors.
        num_layers: Number of stacked LSTM layers.
        dropout: Dropout ratio applied between layers (never after the last one).
        train: Whether dropout is active.
        bidirectional: Whether every layer runs in both directions.
        batch_first: Whether ``input`` and the returned output are batch-major.

    Returns:
        A tuple ``(output, h_n, c_n)`` where ``output`` is the last layer's
        sequence output with the directions merged into the feature axis, and
        ``h_n`` / ``c_n`` stack the final hidden / cell state of every layer
        along axis 0.

    Raises:
        NotImplementedError: If ``params`` does not hold exactly the expected
            number of tensors (for example an LSTM with projections, which adds
            an extra weight per direction).
    """
    # Extract initial hidden and cell states
    initial_h = hx[0]  # [num_directions * num_layers, batch_size, hidden_size]
    initial_c = hx[1]  # [num_directions * num_layers, batch_size, hidden_size]

    # Values that do not depend on the layer are computed once up front.
    num_directions = 2 if bidirectional else 1
    direction = "bidirectional" if bidirectional else "forward"
    # Parameter layout: [W_ih, W_hh, b_ih, b_hh] for each direction
    params_per_direction = 4 if has_biases else 2
    params_per_layer = params_per_direction * num_directions

    # Any other parameter count would make the indexing below pick the wrong
    # tensors, so fail loudly instead of emitting an incorrect graph.
    expected_param_count = params_per_layer * num_layers
    if len(params) != expected_param_count:
        raise NotImplementedError(
            f"aten::lstm expected {expected_param_count} parameter tensors but got "
            f"{len(params)}; LSTM with projections (proj_size > 0) is not supported."
        )

    if batch_first:
        # Convert from [batch, seq, input_size] to [seq, batch, input_size]
        input = op.Transpose(input, perm=[1, 0, 2])

    # Dynamic hidden size (1-D tensor) drives the gate slicing below.
    hidden_size = op.Shape(initial_h, start=2, end=3)
    hidden_size_x2 = hidden_size * 2
    hidden_size_x3 = hidden_size * 3
    hidden_size_x4 = hidden_size * 4
    # The LSTM node needs hidden_size as an attribute, taken from the static
    # shape of initial_h: [num_layers * num_directions, batch, hidden_size]
    hidden_size_attr = initial_h.shape[2]

    current_input = input
    output_h_list = []
    output_c_list = []

    for layer_idx in range(num_layers):
        # Initial states of this layer. Slice takes 1-D starts/ends.
        layer_start = layer_idx * num_directions
        layer_end = layer_start + num_directions
        layer_h = op.Slice(initial_h, [layer_start], [layer_end], axes=[0])
        layer_c = op.Slice(initial_c, [layer_start], [layer_end], axes=[0])

        param_start_idx = layer_idx * params_per_layer

        # ONNX expects W[iofc] with shape [num_directions, 4*hidden_size, input_size]
        # while PyTorch provides W_ih with shape [4*hidden_size, input_size].
        W_list = []
        R_list = []
        B_list = []

        for dir_idx in range(num_directions):
            dir_param_start = param_start_idx + dir_idx * params_per_direction
            W_ih = params[dir_param_start]  # [4*hidden_size, input_size]
            W_hh = params[dir_param_start + 1]  # [4*hidden_size, hidden_size]

            # Reorder the gates, then add the leading direction dimension.
            W_list.append(
                op.Unsqueeze(
                    _lstm_gates_to_onnx_order(
                        W_ih,
                        hidden_size,
                        hidden_size_x2,
                        hidden_size_x3,
                        hidden_size_x4,
                    ),
                    [0],
                )
            )
            R_list.append(
                op.Unsqueeze(
                    _lstm_gates_to_onnx_order(
                        W_hh,
                        hidden_size,
                        hidden_size_x2,
                        hidden_size_x3,
                        hidden_size_x4,
                    ),
                    [0],
                )
            )

            if has_biases:
                b_ih = params[dir_param_start + 2]  # [4*hidden_size]
                b_hh = params[dir_param_start + 3]  # [4*hidden_size]
                # ONNX expects biases concatenated: [Wb[iofc], Rb[iofc]]
                b_combined = op.Concat(
                    _lstm_gates_to_onnx_order(
                        b_ih,
                        hidden_size,
                        hidden_size_x2,
                        hidden_size_x3,
                        hidden_size_x4,
                    ),
                    _lstm_gates_to_onnx_order(
                        b_hh,
                        hidden_size,
                        hidden_size_x2,
                        hidden_size_x3,
                        hidden_size_x4,
                    ),
                    axis=0,
                )  # [8*hidden_size]
                B_list.append(op.Unsqueeze(b_combined, [0]))  # [1, 8*hidden_size]

        # Stack the per-direction tensors along the direction axis.
        W = op.Concat(*W_list, axis=0) if len(W_list) > 1 else W_list[0]
        R = op.Concat(*R_list, axis=0) if len(R_list) > 1 else R_list[0]

        lstm_inputs = [current_input, W, R]
        if has_biases:
            lstm_inputs.append(
                op.Concat(*B_list, axis=0) if len(B_list) > 1 else B_list[0]
            )

        # Call ONNX LSTM operator
        Y, Y_h, Y_c = op.LSTM(
            *lstm_inputs,
            initial_h=layer_h,
            initial_c=layer_c,
            direction=direction,
            hidden_size=hidden_size_attr,
        )

        # Y shape: [seq_length, num_directions, batch_size, hidden_size]
        # Reshape to [seq_length, batch_size, num_directions * hidden_size]
        Y = op.Transpose(Y, perm=[0, 2, 1, 3])
        Y_shape = op.Shape(Y)
        new_shape = op.Concat(
            op.Slice(Y_shape, [0], [1]),  # seq_length
            op.Slice(Y_shape, [1], [2]),  # batch_size
            op.Reshape(
                op.Mul(
                    op.Slice(Y_shape, [2], [3]),  # num_directions
                    op.Slice(Y_shape, [3], [4]),  # hidden_size
                ),
                op.Constant(value_ints=[-1]),
            ),
            axis=0,
        )
        current_input = op.Reshape(Y, new_shape)

        # Apply dropout if not last layer and dropout > 0
        if layer_idx < num_layers - 1 and dropout > 0.0 and train:
            current_input, _ = op.Dropout(current_input, dropout, train)

        # Store final hidden and cell states
        output_h_list.append(Y_h)
        output_c_list.append(Y_c)

    # Concatenate all layer outputs
    final_h = (
        output_h_list[0]
        if len(output_h_list) == 1
        else op.Concat(*output_h_list, axis=0)
    )
    final_c = (
        output_c_list[0]
        if len(output_c_list) == 1
        else op.Concat(*output_c_list, axis=0)
    )

    # Handle batch_first for output
    if batch_first:
        # Convert from [seq, batch, features] to [batch, seq, features]
        current_input = op.Transpose(current_input, perm=[1, 0, 2])

    return current_input, final_h, final_c
|
|
||
|
|
||
| @torch_op( | ||
| ("aten::lt.Tensor", "aten::lt.Scalar", "aten::less.Tensor", "_operator::lt"), | ||
| trace_only=True, | ||
|
|
||
Uh oh!
There was an error while loading. Please reload this page.