Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 0 additions & 2 deletions .github/workflows/windows_webgpu.yml
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@ jobs:
strategy:
matrix:
vcpkg_option: [novcpkg, vcpkg]
wgsl_template: [static, dynamic]
env:
OrtPackageId: Microsoft.ML.OnnxRuntime
OnnxRuntimeBuildDirectory: ${{ github.workspace }}
Expand Down Expand Up @@ -124,7 +123,6 @@ jobs:
--build_nodejs `
--build_java `
--use_webgpu `
--wgsl_template ${{ matrix.wgsl_template }} `
${{ matrix.vcpkg_option == 'vcpkg' && '--use_vcpkg' || '' }} `
--cmake_extra_defines `
onnxruntime_BUILD_UNIT_TESTS=ON `
Expand Down
1 change: 0 additions & 1 deletion cmake/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -151,7 +151,6 @@ option(onnxruntime_DISABLE_SPARSE_TENSORS "Disable sparse tensors data types" OF
option(onnxruntime_DISABLE_OPTIONAL_TYPE "Disable optional type" OFF)
option(onnxruntime_DISABLE_FLOAT8_TYPES "Disable float 8 types" OFF)
option(onnxruntime_MINIMAL_BUILD "Exclude as much as possible from the build. Support ORT format models. No support for ONNX format models." OFF)
option(onnxruntime_CLIENT_PACKAGE_BUILD "Enables default settings that are more appropriate for client/on-device workloads." OFF)
cmake_dependent_option(onnxruntime_DISABLE_RTTI "Disable RTTI" ON "NOT onnxruntime_ENABLE_PYTHON;NOT onnxruntime_USE_CUDA" OFF)
# For now onnxruntime_DISABLE_EXCEPTIONS will only work with onnxruntime_MINIMAL_BUILD, more changes (ONNX, non-CPU EP, ...) are required to run this standalone
cmake_dependent_option(onnxruntime_DISABLE_EXCEPTIONS "Disable exception handling. Requires onnxruntime_MINIMAL_BUILD currently." ON "onnxruntime_MINIMAL_BUILD;NOT onnxruntime_ENABLE_PYTHON" OFF)
Expand Down
5 changes: 0 additions & 5 deletions cmake/adjust_global_compile_flags.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -95,11 +95,6 @@ if (onnxruntime_MINIMAL_BUILD)
endif()
endif()

# ORT build with default settings more appropriate for client/on-device workloads.
if (onnxruntime_CLIENT_PACKAGE_BUILD)
add_compile_definitions(ORT_CLIENT_PACKAGE_BUILD)
endif()

if (onnxruntime_ENABLE_LTO)
include(CheckIPOSupported)
check_ipo_supported(RESULT ipo_enabled OUTPUT ipo_output)
Expand Down
25 changes: 7 additions & 18 deletions cmake/external/onnxruntime_external_deps.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -774,24 +774,13 @@ if (onnxruntime_USE_WEBGPU)
endif()

if (NOT CMAKE_SYSTEM_NAME STREQUAL "Emscripten" AND onnxruntime_WGSL_TEMPLATE STREQUAL "dynamic")
if(onnxruntime_USE_VCPKG)
find_package(unofficial-duktape CONFIG REQUIRED)
add_library(duktape_static ALIAS unofficial::duktape::duktape)
else()
onnxruntime_fetchcontent_declare(
duktape
URL ${DEP_URL_duktape}
URL_HASH SHA1=${DEP_SHA1_duktape}
EXCLUDE_FROM_ALL
)
onnxruntime_fetchcontent_makeavailable(duktape)

if(NOT TARGET duktape_static)
add_library(duktape_static STATIC "${duktape_SOURCE_DIR}/src/duktape.c")
target_compile_features(duktape_static PRIVATE c_std_99)
target_include_directories(duktape_static INTERFACE $<BUILD_INTERFACE:${duktape_SOURCE_DIR}/src>)
endif()
endif()
onnxruntime_fetchcontent_declare(
duktape
URL ${DEP_URL_duktape}
URL_HASH SHA1=${DEP_SHA1_duktape}
EXCLUDE_FROM_ALL
)
onnxruntime_fetchcontent_makeavailable(duktape)
endif()
endif()

Expand Down
1 change: 0 additions & 1 deletion cmake/onnxruntime_mlas.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,6 @@ onnxruntime_add_static_library(onnxruntime_mlas
${MLAS_SRC_DIR}/eltwise.cpp
${MLAS_SRC_DIR}/erf.cpp
${MLAS_SRC_DIR}/compute.cpp
${MLAS_SRC_DIR}/dequantize.cpp
${MLAS_SRC_DIR}/quantize.cpp
${MLAS_SRC_DIR}/qgemm_kernel_default.cpp
${MLAS_SRC_DIR}/qladd.cpp
Expand Down
23 changes: 18 additions & 5 deletions cmake/onnxruntime_providers_tensorrt.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -72,21 +72,26 @@
endif()

# TensorRT 10 GA onwards, the TensorRT libraries will have major version appended to the end on Windows,
# for example, nvinfer_10.dll, nvonnxparser_10.dll ...
# for example, nvinfer_10.dll, nvinfer_plugin_10.dll, nvonnxparser_10.dll ...
if (WIN32 AND TRT_GREATER_OR_EQUAL_TRT_10_GA)
set(NVINFER_LIB "nvinfer_${NV_TENSORRT_MAJOR}")
set(NVINFER_PLUGIN_LIB "nvinfer_plugin_${NV_TENSORRT_MAJOR}")
set(PARSER_LIB "nvonnxparser_${NV_TENSORRT_MAJOR}")
endif()

if (NOT NVINFER_LIB)
set(NVINFER_LIB "nvinfer")
endif()

if (NOT NVINFER_PLUGIN_LIB)
set(NVINFER_PLUGIN_LIB "nvinfer_plugin")
endif()

if (NOT PARSER_LIB)
set(PARSER_LIB "nvonnxparser")
endif()

MESSAGE(STATUS "Looking for ${NVINFER_LIB}")
MESSAGE(STATUS "Looking for ${NVINFER_LIB} and ${NVINFER_PLUGIN_LIB}")

find_library(TENSORRT_LIBRARY_INFER ${NVINFER_LIB}
HINTS ${TENSORRT_ROOT}
Expand All @@ -96,6 +101,14 @@
MESSAGE(STATUS "Can't find ${NVINFER_LIB}")
endif()

find_library(TENSORRT_LIBRARY_INFER_PLUGIN ${NVINFER_PLUGIN_LIB}
HINTS ${TENSORRT_ROOT}
PATH_SUFFIXES lib lib64 lib/x64)

if (NOT TENSORRT_LIBRARY_INFER_PLUGIN)
MESSAGE(STATUS "Can't find ${NVINFER_PLUGIN_LIB}")
endif()

if (onnxruntime_USE_TENSORRT_BUILTIN_PARSER)
MESSAGE(STATUS "Looking for ${PARSER_LIB}")

Expand All @@ -107,7 +120,7 @@
MESSAGE(STATUS "Can't find ${PARSER_LIB}")
endif()

set(TENSORRT_LIBRARY ${TENSORRT_LIBRARY_INFER} ${TENSORRT_LIBRARY_NVONNXPARSER})
set(TENSORRT_LIBRARY ${TENSORRT_LIBRARY_INFER} ${TENSORRT_LIBRARY_INFER_PLUGIN} ${TENSORRT_LIBRARY_NVONNXPARSER})
MESSAGE(STATUS "Find TensorRT libs at ${TENSORRT_LIBRARY}")
else()
if (TRT_GREATER_OR_EQUAL_TRT_10_GA)
Expand Down Expand Up @@ -140,15 +153,15 @@
endif()
# Static libraries are just nvonnxparser_static on all platforms
set(onnxparser_link_libs nvonnxparser_static)
set(TENSORRT_LIBRARY ${TENSORRT_LIBRARY_INFER})
set(TENSORRT_LIBRARY ${TENSORRT_LIBRARY_INFER} ${TENSORRT_LIBRARY_INFER_PLUGIN})
MESSAGE(STATUS "Find TensorRT libs at ${TENSORRT_LIBRARY}")
endif()

# ${TENSORRT_LIBRARY} is empty if we link nvonnxparser_static.
# nvonnxparser_static is linked against tensorrt libraries in onnx-tensorrt
# See https://github.com/onnx/onnx-tensorrt/blob/8af13d1b106f58df1e98945a5e7c851ddb5f0791/CMakeLists.txt#L121
# However, starting from TRT 10 GA, nvonnxparser_static doesn't link against tensorrt libraries.
# Therefore, the above code finds ${TENSORRT_LIBRARY_INFER}.
# Therefore, the above code finds ${TENSORRT_LIBRARY_INFER} and ${TENSORRT_LIBRARY_INFER_PLUGIN}.
if(onnxruntime_CUDA_MINIMAL)
set(trt_link_libs ${CMAKE_DL_LIBS} ${TENSORRT_LIBRARY})
else()
Expand Down
11 changes: 5 additions & 6 deletions cmake/onnxruntime_providers_webgpu.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -172,12 +172,10 @@
file(MAKE_DIRECTORY ${WGSL_GENERATED_DIR})

# Find all WGSL template input files
file(GLOB_RECURSE WGSL_TEMPLATE_FILES
"${ONNXRUNTIME_ROOT}/core/providers/webgpu/*.wgsl.template"
"${ONNXRUNTIME_ROOT}/contrib_ops/webgpu/*.wgsl.template")
file(GLOB_RECURSE WGSL_TEMPLATE_FILES "${ONNXRUNTIME_ROOT}/core/providers/webgpu/*.wgsl.template")

# Set wgsl-gen command line options as a list
set(WGSL_GEN_OPTIONS "-i" "${ONNXRUNTIME_ROOT}/core/providers/webgpu/" "-i" "${ONNXRUNTIME_ROOT}/contrib_ops/webgpu/" "--output" "${WGSL_GENERATED_DIR}" "-I" "wgsl_template_gen/" "--preserve-code-ref" "--verbose")
set(WGSL_GEN_OPTIONS "-i" "../" "--output" "${WGSL_GENERATED_DIR}" "-I" "wgsl_template_gen/" "--preserve-code-ref" "--verbose")
if (onnxruntime_WGSL_TEMPLATE STREQUAL "static")
if (CMAKE_BUILD_TYPE STREQUAL "Debug")
list(APPEND WGSL_GEN_OPTIONS "--generator" "static-cpp-literal")
Expand Down Expand Up @@ -209,9 +207,10 @@
# Add the generated directory to include paths
target_include_directories(onnxruntime_providers_webgpu PRIVATE ${WGSL_GENERATED_ROOT})
elseif(onnxruntime_WGSL_TEMPLATE STREQUAL "dynamic")
add_library(duktape_static STATIC "${duktape_SOURCE_DIR}/src/duktape.c")
target_compile_features(duktape_static PRIVATE c_std_99)
target_link_libraries(onnxruntime_providers_webgpu duktape_static)
onnxruntime_add_include_to_target(onnxruntime_providers_webgpu duktape_static)

target_include_directories(onnxruntime_providers_webgpu PRIVATE ${duktape_SOURCE_DIR}/src)
# Define the path to the generated templates.js file
target_compile_definitions(onnxruntime_providers_webgpu PRIVATE
"ORT_WGSL_TEMPLATES_JS_PATH=\"${WGSL_GENERATED_TEMPLATES_JS}\"")
Expand Down
9 changes: 1 addition & 8 deletions cmake/vcpkg.json
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@
"ms-gsl",
"nlohmann-json",
"onnx",
"optional-lite",
{
"name": "protobuf",
"version>=": "3.21.12"
Expand Down Expand Up @@ -93,10 +94,6 @@
"webgpu-ep": {
"description": "Build with WebGPU EP",
"dependencies": []
},
"webgpu-ep-wgsl-template-dynamic": {
"description": "Build with WebGPU EP with dynamic WGSL template code generator",
"dependencies": ["duktape"]
}
},
"overrides": [
Expand All @@ -107,10 +104,6 @@
{
"name": "flatbuffers",
"version": "23.5.26"
},
{
"name": "duktape",
"version": "2.7.0#2"
}
]
}
Original file line number Diff line number Diff line change
Expand Up @@ -40,12 +40,10 @@ public void RunPlatformUnitTest()
var serializedResultSummary = _app.Invoke(_getResultsBackdoorMethodName)?.ToString();
Assert.IsNotEmpty(serializedResultSummary, "Test results were not returned");

// Fix security issue (overflow with too much nesting): GHSA-5crp-9r3c-p9vr
JsonConvert.DefaultSettings = () => new JsonSerializerSettings { MaxDepth = 128 };
var testSummary = JsonConvert.DeserializeObject<TestResultSummary>(serializedResultSummary);
Assert.AreEqual(testSummary.Failed, 0, $"{testSummary.Failed} tests failed");

_app.Screenshot("Post-testing");
}
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -45,9 +45,8 @@ public TestResultSummary GetResults()
public string GetSerializedResults()
{
var resultSummary = GetResults();
JsonConvert.DefaultSettings = () => new JsonSerializerSettings { MaxDepth = 128 };
var serializedResultSummary = JsonConvert.SerializeObject(resultSummary, Formatting.Indented);
return serializedResultSummary;
}
}
}
}
10 changes: 2 additions & 8 deletions docs/ContribOperators.md
Original file line number Diff line number Diff line change
Expand Up @@ -2545,8 +2545,6 @@ This version of the operator has been available since version 1 of the 'com.micr
<dd>left_window_size for local attention (like Mistral). Default value is -1 meaning unused.</dd>
<dt><tt>num_heads</tt> : int (required)</dt>
<dd>Number of attention heads for q</dd>
<dt><tt>qk_output</tt> : int</dt>
<dd>Output values of QK matrix multiplication before (1) or after (2) softmax normalization. Default value is 0 (don't output).</dd>
<dt><tt>rotary_interleaved</tt> : int</dt>
<dd>Rotate using interleaved pattern. Default value is 0 (False).</dd>
<dt><tt>scale</tt> : float</dt>
Expand All @@ -2557,7 +2555,7 @@ This version of the operator has been available since version 1 of the 'com.micr
<dd>Softcap value for attention weights. Default value is 0.</dd>
</dl>

#### Inputs (7 - 12)
#### Inputs (7 - 11)

<dl>
<dt><tt>query</tt> : T</dt>
Expand All @@ -2582,11 +2580,9 @@ This version of the operator has been available since version 1 of the 'com.micr
<dd>2D tensor with shape (batch_size, sequence_length). When processing the first prompt the kernel uses only the first element</dd>
<dt><tt>attention_bias</tt> (optional) : T</dt>
<dd>additional add to QxK' with shape (batch_size or 1, num_heads or 1, sequence_length, total_sequence_length)</dd>
<dt><tt>head_sink</tt> (optional) : T</dt>
<dd>1D tensor with shape (num_heads). Each head has a smooth factor adding to the denominator of softmax.</dd>
</dl>

#### Outputs (3 - 4)
#### Outputs

<dl>
<dt><tt>output</tt> : T</dt>
Expand All @@ -2595,8 +2591,6 @@ This version of the operator has been available since version 1 of the 'com.micr
<dd>present state key with support for format BNSH. When past_key uses same tensor as present_key(k-v buffer), it is of length max_sequence_length... otherwise of length past_sequence_length +kv_sequence_length.</dd>
<dt><tt>present_value</tt> : T</dt>
<dd>present state value with support for format BNSH. When past_value uses same tensor as present_value(k-v buffer), it is of length max_sequence_length... otherwise of length past_sequence_length +kv_sequence_length.</dd>
<dt><tt>output_qk</tt> (optional) : T</dt>
<dd>Values of QK matrix multiplication, either before or after softmax normalization</dd>
</dl>

#### Type Constraints
Expand Down
6 changes: 3 additions & 3 deletions docs/OperatorKernels.md
Original file line number Diff line number Diff line change
Expand Up @@ -538,7 +538,7 @@ Do not modify directly.*
|Gelu|*in* X:**T**<br> *out* Y:**T**|1+|**T** = tensor(float)|
|GreedySearch|*in* input_ids:**I**<br> *in* max_length:**I**<br> *in* min_length:**I**<br> *in* repetition_penalty:**T**<br> *in* vocab_mask:**I**<br> *in* prefix_vocab_mask:**I**<br> *in* attention_mask:**I**<br> *out* sequences:**I**|1+|**T** = tensor(float)|
|GridSample|*in* X:**T1**<br> *in* Grid:**T1**<br> *out* Y:**T2**|1+|**T1** = tensor(float)<br/> **T2** = tensor(float)|
|GroupQueryAttention|*in* query:**T**<br> *in* key:**T**<br> *in* value:**T**<br> *in* past_key:**T**<br> *in* past_value:**T**<br> *in* seqlens_k:**M**<br> *in* total_sequence_length:**M**<br> *in* cos_cache:**T**<br> *in* sin_cache:**T**<br> *in* position_ids:**tensor(int64)**<br> *in* attention_bias:**T**<br> *in* head_sink:**T**<br> *out* output:**T**<br> *out* present_key:**T**<br> *out* present_value:**T**<br> *out* output_qk:**T**|1+|**M** = tensor(int32)<br/> **T** = tensor(float), tensor(float16)|
|GroupQueryAttention|*in* query:**T**<br> *in* key:**T**<br> *in* value:**T**<br> *in* past_key:**T**<br> *in* past_value:**T**<br> *in* seqlens_k:**M**<br> *in* total_sequence_length:**M**<br> *in* cos_cache:**T**<br> *in* sin_cache:**T**<br> *in* position_ids:**tensor(int64)**<br> *in* attention_bias:**T**<br> *out* output:**T**<br> *out* present_key:**T**<br> *out* present_value:**T**|1+|**M** = tensor(int32)<br/> **T** = tensor(float), tensor(float16)|
|Inverse|*in* X:**T**<br> *out* Y:**T**|1+|**T** = tensor(double), tensor(float), tensor(float16)|
|MatMulBnb4|*in* A:**T1**<br> *in* B:**T2**<br> *in* absmax:**T1**<br> *out* Y:**T1**|1+|**T1** = tensor(float)<br/> **T2** = tensor(uint8)|
|MatMulFpQ4|*in* A:**T1**<br> *in* B:**T2**<br> *in* B_shape:**T3**<br> *out* Y:**T1**|1+|**T1** = tensor(float)<br/> **T2** = tensor(uint8)<br/> **T3** = tensor(int64)|
Expand Down Expand Up @@ -942,7 +942,7 @@ Do not modify directly.*
|GreedySearch|*in* input_ids:**I**<br> *in* max_length:**I**<br> *in* min_length:**I**<br> *in* repetition_penalty:**T**<br> *in* vocab_mask:**I**<br> *in* prefix_vocab_mask:**I**<br> *in* attention_mask:**I**<br> *out* sequences:**I**|1+|**T** = tensor(float), tensor(float16)|
|GridSample|*in* X:**T1**<br> *in* Grid:**T1**<br> *out* Y:**T2**|1+|**T1** = tensor(float)<br/> **T2** = tensor(float)|
|GroupNorm|*in* X:**T**<br> *in* gamma:**M**<br> *in* beta:**M**<br> *out* Y:**T**|1+|**T** = tensor(float), tensor(float16)|
|GroupQueryAttention|*in* query:**T**<br> *in* key:**T**<br> *in* value:**T**<br> *in* past_key:**T**<br> *in* past_value:**T**<br> *in* seqlens_k:**M**<br> *in* total_sequence_length:**M**<br> *in* cos_cache:**T**<br> *in* sin_cache:**T**<br> *in* position_ids:**tensor(int64)**<br> *in* attention_bias:**T**<br> *in* head_sink:**T**<br> *out* output:**T**<br> *out* present_key:**T**<br> *out* present_value:**T**<br> *out* output_qk:**T**|1+|**M** = tensor(int32)<br/> **T** = tensor(bfloat16), tensor(float16)|
|GroupQueryAttention|*in* query:**T**<br> *in* key:**T**<br> *in* value:**T**<br> *in* past_key:**T**<br> *in* past_value:**T**<br> *in* seqlens_k:**M**<br> *in* total_sequence_length:**M**<br> *in* cos_cache:**T**<br> *in* sin_cache:**T**<br> *in* position_ids:**tensor(int64)**<br> *in* attention_bias:**T**<br> *out* output:**T**<br> *out* present_key:**T**<br> *out* present_value:**T**|1+|**M** = tensor(int32)<br/> **T** = tensor(bfloat16), tensor(float16)|
|Inverse|*in* X:**T**<br> *out* Y:**T**|1+|**T** = tensor(double), tensor(float), tensor(float16)|
|Irfft|*in* X:**T**<br> *out* Y:**T**|1+|**T** = tensor(double), tensor(float), tensor(float16)|
|LongformerAttention|*in* input:**T**<br> *in* weight:**T**<br> *in* bias:**T**<br> *in* mask:**T**<br> *in* global_weight:**T**<br> *in* global_bias:**T**<br> *in* global:**G**<br> *out* output:**T**|1+|**T** = tensor(float), tensor(float16)|
Expand Down Expand Up @@ -1420,7 +1420,7 @@ Do not modify directly.*
|FusedMatMulActivation|*in* A:**T**<br> *in* B:**T**<br> *out* Y:**T**|1+|**T** = tensor(float), tensor(float16)|
|Gelu|*in* X:**T**<br> *out* Y:**T**|1+|**T** = tensor(float), tensor(float16)|
|GroupNorm|*in* X:**T**<br> *in* gamma:**M**<br> *in* beta:**M**<br> *out* Y:**T**|1+|**M** = tensor(float), tensor(float16)<br/> **T** = tensor(float), tensor(float16)|
|GroupQueryAttention|*in* query:**T**<br> *in* key:**T**<br> *in* value:**T**<br> *in* past_key:**T**<br> *in* past_value:**T**<br> *in* seqlens_k:**M**<br> *in* total_sequence_length:**M**<br> *in* cos_cache:**T**<br> *in* sin_cache:**T**<br> *in* position_ids:**tensor(int64)**<br> *in* attention_bias:**T**<br> *in* head_sink:**T**<br> *out* output:**T**<br> *out* present_key:**T**<br> *out* present_value:**T**<br> *out* output_qk:**T**|1+|**M** = tensor(int32)<br/> **T** = tensor(float), tensor(float16)|
|GroupQueryAttention|*in* query:**T**<br> *in* key:**T**<br> *in* value:**T**<br> *in* past_key:**T**<br> *in* past_value:**T**<br> *in* seqlens_k:**M**<br> *in* total_sequence_length:**M**<br> *in* cos_cache:**T**<br> *in* sin_cache:**T**<br> *in* position_ids:**tensor(int64)**<br> *in* attention_bias:**T**<br> *out* output:**T**<br> *out* present_key:**T**<br> *out* present_value:**T**|1+|**M** = tensor(int32)<br/> **T** = tensor(float), tensor(float16)|
|MatMulIntegerToFloat|*in* A:**T1**<br> *in* B:**T2**<br> *in* a_scale:**T3**<br> *in* b_scale:**T3**<br> *in* a_zero_point:**T1**<br> *in* b_zero_point:**T2**<br> *in* bias:**T3**<br> *out* Y:**T3**|1+|**T1** = tensor(int8), tensor(uint8)<br/> **T2** = tensor(int8), tensor(uint8)<br/> **T3** = tensor(float), tensor(float16)|
|MatMulNBits|*in* A:**T1**<br> *in* B:**T2**<br> *in* scales:**T1**<br> *in* zero_points:**T3**<br> *in* g_idx:**T4**<br> *in* bias:**T1**<br> *out* Y:**T1**|1+|**T1** = tensor(float), tensor(float16)<br/> **T2** = tensor(uint8)|
|MultiHeadAttention|*in* query:**T**<br> *in* key:**T**<br> *in* value:**T**<br> *in* bias:**T**<br> *in* key_padding_mask:**M**<br> *in* attention_bias:**T**<br> *in* past_key:**T**<br> *in* past_value:**T**<br> *in* past_sequence_length:**M**<br> *in* cache_indirection:**M**<br> *out* output:**T**<br> *out* present_key:**T**<br> *out* present_value:**T**<br> *out* qk:**QK**|1+|**M** = tensor(int32)<br/> **T** = tensor(float), tensor(float16)|
Expand Down
5 changes: 1 addition & 4 deletions include/onnxruntime/core/graph/graph.h
Original file line number Diff line number Diff line change
Expand Up @@ -952,12 +952,9 @@ class Graph { // NOLINT(clang-analyzer-optin.performance.Padding): preserve exi
return const_cast<Graph*>(this)->GetNodeArg(name);
}

// Searches for a NodeArg in the current graph and its parent graphs, and returns the corresponding mutable NodeArg
// search this and up through any parent_graph_ instance for a NodeArg
NodeArg* GetNodeArgIncludingParentGraphs(const std::string& node_arg_name);

// Searches for a NodeArg in the current graph and its parent graphs, and returns the corresponding const NodeArg
const NodeArg* GetNodeArgIncludingParentGraphs(const std::string& node_arg_name) const;

/** Gets a mutable NodeArg by name. Creates a new NodeArg that is owned by this Graph if not found.
@param name The NodeArg name.
@param[in] p_arg_type Optional TypeProto to use if the NodeArg needs to be created.
Expand Down
Loading
Loading