Merged
Changes from all commits
91 commits
2ac57eb
implement TensorCaster specializations for int4/uint4
amarin16 Jun 4, 2025
e8dc543
Create unit tests
amarin16 Jun 5, 2025
a4dc230
update shape inside tests
amarin16 Jun 5, 2025
6d0c033
fix bug in tests
amarin16 Jun 6, 2025
4979206
update test names
amarin16 Jun 6, 2025
6e34f5b
Add CastTo/FromString and TensorCasterNoSat specializations
amarin16 Jun 6, 2025
61d50d8
fix warning
amarin16 Jun 6, 2025
2487411
apply lintrunner
amarin16 Jun 6, 2025
e2f6f02
output low nibble first
amarin16 Jun 6, 2025
4ffef5f
Update to use Rv10 instead of Rv9 (breaks build)
amarin16 Jun 6, 2025
e0079c0
more than 1 partial specialization
amarin16 Jun 6, 2025
6157a9d
cannot convert OldType to NewType
amarin16 Jun 6, 2025
8cd5237
remove macros, expand specializations
amarin16 Jun 7, 2025
5e23ee2
merge together similar specializations
amarin16 Jun 7, 2025
9d50386
remove unused aliases
amarin16 Jun 9, 2025
8474ca7
clean up template specializations
amarin16 Jun 9, 2025
1ddcb96
reuse existing aliases, refactor for consistency
amarin16 Jun 9, 2025
1d02ca1
fix multiple partial specializations issue for MLFloat16
amarin16 Jun 9, 2025
406666f
clean up std::string specializations
amarin16 Jun 9, 2025
13a9c17
Add specializations for Int4x2 -> UInt4x2 and UInt4x2 -> Int4x2
amarin16 Jun 9, 2025
bcb2412
more concise Int4ElementConverter
amarin16 Jun 9, 2025
cb205e5
more concise ToInt4ElementConverter
amarin16 Jun 9, 2025
1e8f680
merge a few TensorCaster specializations
amarin16 Jun 9, 2025
2d1ca00
merge a few more TensorCaster specializations
amarin16 Jun 9, 2025
18b3b11
styling suggestions, lint
amarin16 Jun 9, 2025
8ceae1b
update iterations over the input
amarin16 Jun 9, 2025
bb806e4
update iteration
amarin16 Jun 9, 2025
6825974
Update opset in unit tests to support int4
amarin16 Jun 10, 2025
b980762
Add unit tests
amarin16 Jun 10, 2025
207e5df
add more unit tests
amarin16 Jun 10, 2025
8ab0161
update string implementation and add tests
amarin16 Jun 10, 2025
c81c502
lint
amarin16 Jun 10, 2025
7e4ce01
add MLFloat16 tests
amarin16 Jun 10, 2025
6ec215c
update iteration, add test for odd elements
amarin16 Jun 12, 2025
b39a4f4
add specialization from float to MLFloat16
amarin16 Jun 12, 2025
351626e
try to fix pipeline errors
amarin16 Jun 12, 2025
3b84c14
lint
amarin16 Jun 12, 2025
b5ecbf9
Try [[noreturn]] to fix pipelines
amarin16 Jun 12, 2025
b19df98
suppress 4702 warning
amarin16 Jun 12, 2025
77b1916
try pipeline fix
amarin16 Jun 13, 2025
8381956
lint
amarin16 Jun 16, 2025
3d9be2f
disable warning
amarin16 Jun 17, 2025
9ce545a
move pragma statements
amarin16 Jun 17, 2025
082ae75
update pragma
amarin16 Jun 17, 2025
234e7f0
Merge branch 'main' into dev/emarin/cast_int4
amarin16 Jun 24, 2025
c70e7f3
Update docs for Cast
amarin16 Jun 25, 2025
950a8e3
Update onnx patch
amarin16 Jun 25, 2025
41af71b
update patch
amarin16 Jun 26, 2025
5b79476
update patch
amarin16 Jun 26, 2025
b39c660
keep binskim.patch in sync with onnx.patch
amarin16 Jun 26, 2025
10b2eea
update patches
amarin16 Jun 26, 2025
c89e11f
exclude onnx tests in TestCase.cc
amarin16 Jun 26, 2025
da7e444
remove patch fixes
amarin16 Jun 30, 2025
408ae63
merge main
amarin16 Jun 30, 2025
be31bdd
Add newline at end of patch files
amarin16 Jun 30, 2025
ea2956c
explicitly mention next onnx version in skipped tests
amarin16 Jul 1, 2025
989719a
use std::is_floating_point_v
amarin16 Jul 1, 2025
e2be244
use constants for min and max (u)int4 values
amarin16 Jul 2, 2025
c7c05fa
use constexpr if to merge specializations
amarin16 Jul 2, 2025
3932199
use constexpr if to merge specializations for Int4
amarin16 Jul 2, 2025
5e00b26
remove extra line
amarin16 Jul 2, 2025
af00c7d
remove anonymous namespace
amarin16 Jul 7, 2025
fe03e23
update IsOrtFloat8Type usage
amarin16 Jul 7, 2025
d382f2c
update cast between int4x2 and uint4x2 and tests, add Int4x2ToUInt64 …
amarin16 Jul 7, 2025
82b1fb5
Update cast down implementation and tests
amarin16 Jul 7, 2025
f1f4e2e
Fix pipeline issue
amarin16 Jul 8, 2025
2b4c325
fix pipeline issue
amarin16 Jul 8, 2025
18521cc
debug pipeline issue
amarin16 Jul 9, 2025
de3c01b
rename
amarin16 Jul 9, 2025
0dfb771
remove debugging tests
amarin16 Jul 10, 2025
338c440
lint
amarin16 Jul 10, 2025
1a22e4b
Merge branch 'main' into dev/emarin/cast_int4
amarin16 Jul 16, 2025
6d76f96
Update TensorCasterNoSat for Int4/UInt4, add tests with saturate = false
amarin16 Jul 19, 2025
6cbeea4
Add comments, extract common logic into lambda
amarin16 Jul 19, 2025
ed92b06
extract FromInt4Converter into separate struct, rename converter
amarin16 Jul 19, 2025
fad0445
add IsOrtInt4Type, use UnpackedType, merge specializations
amarin16 Jul 19, 2025
343a22c
specialize ToInt4Converter for bool, merge TensorCaster from bool spe…
amarin16 Jul 19, 2025
195780e
refactor ToInt4Converter, merge TensorCaster to Int4 specializations
amarin16 Jul 19, 2025
506a48b
enforce float 8 DstType for TensorCasterNoSat
amarin16 Jul 19, 2025
bc9edef
refactor ToInt4Converter, merge specializations
amarin16 Jul 20, 2025
4d652f0
Refactor ToInt4Converter, merge specializations
amarin16 Jul 20, 2025
27a448a
rename type, update comments
amarin16 Jul 20, 2025
fa705d6
small refactor
amarin16 Jul 20, 2025
4a65285
update comments, test values
amarin16 Jul 20, 2025
226c6dc
Add 2 tests
amarin16 Jul 20, 2025
6d97d37
merge int4 -> string specialization with int4 -> numeric
amarin16 Jul 20, 2025
be6c176
reuse ToInt4Converter inside string specializations
amarin16 Jul 20, 2025
003adae
Merge string -> int4 and string -> uint4 specializations
amarin16 Jul 20, 2025
27040c1
parse string as double, add test
amarin16 Jul 20, 2025
9a64f1c
Merge string -> int4 with numeric -> int4 specializations
amarin16 Jul 20, 2025
416635f
Add test
amarin16 Jul 22, 2025
10 changes: 5 additions & 5 deletions docs/OperatorKernels.md
@@ -58,11 +58,11 @@ Do not modify directly.*
|BitwiseOr|*in* A:**T**<br> *in* B:**T**<br> *out* C:**T**|18+|**T** = tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
|BitwiseXor|*in* A:**T**<br> *in* B:**T**<br> *out* C:**T**|18+|**T** = tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
|BlackmanWindow|*in* size:**T1**<br> *out* output:**T2**|17+|**T1** = tensor(int32), tensor(int64)<br/> **T2** = tensor(double), tensor(float), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
-|Cast|*in* input:**T1**<br> *out* output:**T2**|23+|**T1** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(float8e4m3fn), tensor(float8e4m3fnuz), tensor(float8e5m2), tensor(float8e5m2fnuz), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)<br/> **T2** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(float8e4m3fn), tensor(float8e4m3fnuz), tensor(float8e5m2), tensor(float8e5m2fnuz), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
-|||[21, 22]|**T1** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(float8e4m3fn), tensor(float8e4m3fnuz), tensor(float8e5m2), tensor(float8e5m2fnuz), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)<br/> **T2** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(float8e4m3fn), tensor(float8e4m3fnuz), tensor(float8e5m2), tensor(float8e5m2fnuz), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
-|||[19, 20]|**T1** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(float8e4m3fn), tensor(float8e4m3fnuz), tensor(float8e5m2), tensor(float8e5m2fnuz), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)<br/> **T2** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(float8e4m3fn), tensor(float8e4m3fnuz), tensor(float8e5m2), tensor(float8e5m2fnuz), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
-|||[13, 18]|**T1** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(float8e4m3fn), tensor(float8e4m3fnuz), tensor(float8e5m2), tensor(float8e5m2fnuz), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)<br/> **T2** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(float8e4m3fn), tensor(float8e4m3fnuz), tensor(float8e5m2), tensor(float8e5m2fnuz), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
-|||[6, 12]|**T1** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(float8e4m3fn), tensor(float8e4m3fnuz), tensor(float8e5m2), tensor(float8e5m2fnuz), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)<br/> **T2** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(float8e4m3fn), tensor(float8e4m3fnuz), tensor(float8e5m2), tensor(float8e5m2fnuz), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
+|Cast|*in* input:**T1**<br> *out* output:**T2**|23+|**T1** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(float8e4m3fn), tensor(float8e4m3fnuz), tensor(float8e5m2), tensor(float8e5m2fnuz), tensor(int16), tensor(int32), tensor(int4), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint4), tensor(uint64), tensor(uint8)<br/> **T2** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(float8e4m3fn), tensor(float8e4m3fnuz), tensor(float8e5m2), tensor(float8e5m2fnuz), tensor(int16), tensor(int32), tensor(int4), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint4), tensor(uint64), tensor(uint8)|
+|||[21, 22]|**T1** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(float8e4m3fn), tensor(float8e4m3fnuz), tensor(float8e5m2), tensor(float8e5m2fnuz), tensor(int16), tensor(int32), tensor(int4), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint4), tensor(uint64), tensor(uint8)<br/> **T2** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(float8e4m3fn), tensor(float8e4m3fnuz), tensor(float8e5m2), tensor(float8e5m2fnuz), tensor(int16), tensor(int32), tensor(int4), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint4), tensor(uint64), tensor(uint8)|
+|||[19, 20]|**T1** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(float8e4m3fn), tensor(float8e4m3fnuz), tensor(float8e5m2), tensor(float8e5m2fnuz), tensor(int16), tensor(int32), tensor(int4), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint4), tensor(uint64), tensor(uint8)<br/> **T2** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(float8e4m3fn), tensor(float8e4m3fnuz), tensor(float8e5m2), tensor(float8e5m2fnuz), tensor(int16), tensor(int32), tensor(int4), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint4), tensor(uint64), tensor(uint8)|
+|||[13, 18]|**T1** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(float8e4m3fn), tensor(float8e4m3fnuz), tensor(float8e5m2), tensor(float8e5m2fnuz), tensor(int16), tensor(int32), tensor(int4), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint4), tensor(uint64), tensor(uint8)<br/> **T2** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(float8e4m3fn), tensor(float8e4m3fnuz), tensor(float8e5m2), tensor(float8e5m2fnuz), tensor(int16), tensor(int32), tensor(int4), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint4), tensor(uint64), tensor(uint8)|
+|||[6, 12]|**T1** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(float8e4m3fn), tensor(float8e4m3fnuz), tensor(float8e5m2), tensor(float8e5m2fnuz), tensor(int16), tensor(int32), tensor(int4), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint4), tensor(uint64), tensor(uint8)<br/> **T2** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(float8e4m3fn), tensor(float8e4m3fnuz), tensor(float8e5m2), tensor(float8e5m2fnuz), tensor(int16), tensor(int32), tensor(int4), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint4), tensor(uint64), tensor(uint8)|
|Ceil|*in* X:**T**<br> *out* Y:**T**|13+|**T** = tensor(double), tensor(float)|
|||[6, 12]|**T** = tensor(double), tensor(float)|
|Celu|*in* X:**T**<br> *out* Y:**T**|12+|**T** = tensor(float)|
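The updated Cast rows also pair tensor(string) with the new 4-bit types. Going by the "parse string as double" commit above, a plausible reading of that path is parse-then-saturate; the sketch below is a hypothetical helper illustrating that reading, not ORT's actual implementation.

```cpp
// String -> int4 as suggested by the commit log: parse as double first so
// inputs like "3.9" or "1e2" are handled, then saturate into [-8, 7].
#include <algorithm>
#include <cstdint>
#include <string>

int8_t CastStringToInt4(const std::string& s) {
  const double parsed = std::stod(s);              // throws on malformed input
  const double clamped = std::clamp(parsed, -8.0, 7.0);
  return static_cast<int8_t>(clamped);             // nibble is packed later
}
// CastStringToInt4("1e2") == 7, CastStringToInt4("-3.5") == -3
```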
9 changes: 8 additions & 1 deletion include/onnxruntime/core/framework/data_types_internal.h
@@ -319,6 +319,10 @@ class CallableDispatchableHelper {
public:
explicit CallableDispatchableHelper(int32_t dt_type) noexcept : dt_type_(dt_type), called_(0) {}

+#if defined(_MSC_VER)
+#pragma warning(push)
+#pragma warning(disable : 4702)
+#endif
// Must return integer to be in an expandable context
template <class T, class Fn, class... Args>
int Invoke(Fn&& fn, Args&&... args) {
@@ -328,6 +332,9 @@
}
return 0;
}
+#if defined(_MSC_VER)
+#pragma warning(pop)
+#endif

void CheckCalledOnce() const {
ORT_ENFORCE(called_ == 1, "Unsupported data type: ", dt_type_);
@@ -338,7 +345,7 @@
// Other policies may set the second result argument accordingly.
template <class Ret>
struct UnsupportedTypeDefaultPolicy {
-void operator()(int32_t dt_type, Ret& /*result*/) const {
+[[noreturn]] void operator()(int32_t dt_type, Ret& /*result*/) const {
ORT_THROW("Unsupported data type: ", dt_type);
}
};
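For context on the two framework changes above: marking the unsupported-type policy `[[noreturn]]` lets MSVC prove that code after the policy call is unreachable, which raises warning C4702 inside dispatch helpers like `Invoke`, and the `push`/`disable`/`pop` pragmas silence it there. A minimal repro of that interaction, as I read the diff (names are hypothetical):

```cpp
// A throwing function marked [[noreturn]], mirroring UnsupportedTypeDefaultPolicy.
#include <stdexcept>

[[noreturn]] void Fail() { throw std::runtime_error("unsupported"); }

#if defined(_MSC_VER)
#pragma warning(push)
#pragma warning(disable : 4702)  // unreachable code
#endif
template <bool kMatches>
int InvokeLike() {
  if (!kMatches) {
    Fail();
  }
  return 0;  // unreachable when kMatches == false; C4702 without the pragma
}
#if defined(_MSC_VER)
#pragma warning(pop)
#endif

int main() { return InvokeLike<true>(); }
```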