Skip to content

Commit 72adc25

Browse files
committed
[FFI] Introduce small string/bytes
This PR introduces small string support to the FFI system. When the string length fit into the space in TVMFFIAny (aka len(str) <= 7). We directly store the string content into the TVMFFIAny content instead of creating a Object. This change will likely speedup small string access. Some implications: - Always check for kTVMFFISmallStr code as well as kTVMFFIStr - Always check for kTVMFFISmallBytes when checking kTVMFFIBytes - Avoid using details::StringObj, instead, always use any.as<String>() - Always set any.padding to 0 for other values (in compiler and runtime) to enable fast cmp
1 parent 3c189f0 commit 72adc25

File tree

72 files changed

+1138
-362
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

72 files changed

+1138
-362
lines changed

ffi/include/tvm/ffi/any.h

Lines changed: 84 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,7 @@ class AnyView {
6060
void reset() {
6161
data_.type_index = TypeIndex::kTVMFFINone;
6262
// invariance: always set the union padding part to 0
63+
data_.zero_padding = 0;
6364
data_.v_int64 = 0;
6465
}
6566
/*!
@@ -72,6 +73,7 @@ class AnyView {
7273
// default constructors
7374
AnyView() {
7475
data_.type_index = TypeIndex::kTVMFFINone;
76+
data_.zero_padding = 0;
7577
data_.v_int64 = 0;
7678
}
7779
~AnyView() = default;
@@ -80,6 +82,7 @@ class AnyView {
8082
AnyView& operator=(const AnyView&) = default;
8183
AnyView(AnyView&& other) : data_(other.data_) {
8284
other.data_.type_index = TypeIndex::kTVMFFINone;
85+
other.data_.zero_padding = 0;
8386
other.data_.v_int64 = 0;
8487
}
8588
TVM_FFI_INLINE AnyView& operator=(AnyView&& other) {
@@ -198,22 +201,19 @@ TVM_FFI_INLINE void InplaceConvertAnyViewToAny(TVMFFIAny* data,
198201
if (data->type_index == TypeIndex::kTVMFFIRawStr) {
199202
// convert raw string to owned string object
200203
String temp(data->v_c_str);
201-
data->type_index = TypeIndex::kTVMFFIStr;
202-
data->v_obj = details::ObjectUnsafe::MoveObjectRefToTVMFFIObjectPtr(std::move(temp));
204+
TypeTraits<String>::MoveToAny(std::move(temp), data);
203205
} else if (data->type_index == TypeIndex::kTVMFFIByteArrayPtr) {
204206
// convert byte array to owned bytes object
205207
Bytes temp(*static_cast<TVMFFIByteArray*>(data->v_ptr));
206-
data->type_index = TypeIndex::kTVMFFIBytes;
207-
data->v_obj = details::ObjectUnsafe::MoveObjectRefToTVMFFIObjectPtr(std::move(temp));
208+
TypeTraits<Bytes>::MoveToAny(std::move(temp), data);
208209
} else if (data->type_index == TypeIndex::kTVMFFIObjectRValueRef) {
209210
// convert rvalue ref to owned object
210211
Object** obj_addr = static_cast<Object**>(data->v_ptr);
211212
TVM_FFI_ICHECK(obj_addr[0] != nullptr) << "RValueRef already moved";
212213
ObjectRef temp(details::ObjectUnsafe::ObjectPtrFromOwned<Object>(obj_addr[0]));
213214
// set the rvalue ref to nullptr to avoid double move
214215
obj_addr[0] = nullptr;
215-
data->type_index = temp->type_index();
216-
data->v_obj = details::ObjectUnsafe::MoveObjectRefToTVMFFIObjectPtr(std::move(temp));
216+
TypeTraits<ObjectRef>::MoveToAny(std::move(temp), data);
217217
}
218218
}
219219
}
@@ -239,6 +239,7 @@ class Any {
239239
details::ObjectUnsafe::DecRefObjectHandle(data_.v_obj);
240240
}
241241
data_.type_index = TVMFFITypeIndex::kTVMFFINone;
242+
data_.zero_padding = 0;
242243
data_.v_int64 = 0;
243244
}
244245
/*!
@@ -251,6 +252,7 @@ class Any {
251252
// default constructors
252253
Any() {
253254
data_.type_index = TypeIndex::kTVMFFINone;
255+
data_.zero_padding = 0;
254256
data_.v_int64 = 0;
255257
}
256258
~Any() { this->reset(); }
@@ -262,6 +264,7 @@ class Any {
262264
}
263265
Any(Any&& other) : data_(other.data_) {
264266
other.data_.type_index = TypeIndex::kTVMFFINone;
267+
other.data_.zero_padding = 0;
265268
other.data_.v_int64 = 0;
266269
}
267270
TVM_FFI_INLINE Any& operator=(const Any& other) {
@@ -408,7 +411,8 @@ class Any {
408411
* \return True if the two Any are same type and value, false otherwise.
409412
*/
410413
TVM_FFI_INLINE bool same_as(const Any& other) const noexcept {
411-
return data_.type_index == other.data_.type_index && data_.v_int64 == other.data_.v_int64;
414+
return data_.type_index == other.data_.type_index &&
415+
data_.zero_padding == other.data_.zero_padding && data_.v_int64 == other.data_.v_int64;
412416
}
413417

414418
/*
@@ -485,6 +489,7 @@ struct AnyUnsafe : public ObjectUnsafe {
485489
TVM_FFI_INLINE static TVMFFIAny MoveAnyToTVMFFIAny(Any&& any) {
486490
TVMFFIAny result = any.data_;
487491
any.data_.type_index = TypeIndex::kTVMFFINone;
492+
any.data_.zero_padding = 0;
488493
any.data_.v_int64 = 0;
489494
return result;
490495
}
@@ -493,6 +498,7 @@ struct AnyUnsafe : public ObjectUnsafe {
493498
Any any;
494499
any.data_ = data;
495500
data.type_index = TypeIndex::kTVMFFINone;
501+
data.zero_padding = 0;
496502
data.v_int64 = 0;
497503
return any;
498504
}
@@ -543,17 +549,24 @@ struct AnyHash {
543549
* \return Hash code of a, string hash for strings and pointer address otherwise.
544550
*/
545551
uint64_t operator()(const Any& src) const {
546-
uint64_t val_hash = [&]() -> uint64_t {
547-
if (src.data_.type_index == TypeIndex::kTVMFFIStr ||
548-
src.data_.type_index == TypeIndex::kTVMFFIBytes) {
549-
const details::BytesObjBase* src_str =
550-
details::AnyUnsafe::CopyFromAnyViewAfterCheck<const details::BytesObjBase*>(src);
551-
return details::StableHashBytes(src_str->data, src_str->size);
552-
} else {
553-
return src.data_.v_uint64;
554-
}
555-
}();
556-
return details::StableHashCombine(src.data_.type_index, val_hash);
552+
if (src.data_.type_index == TypeIndex::kTVMFFISmallStr) {
553+
// for small string, we use the same type key hash as normal string
554+
// so heap allocated string and on stack string will have the same hash
555+
return details::StableHashCombine(TypeIndex::kTVMFFIStr,
556+
details::StableHashSmallStrBytes(&src.data_));
557+
} else if (src.data_.type_index == TypeIndex::kTVMFFISmallBytes) {
558+
// use byte the same type key as bytes
559+
return details::StableHashCombine(TypeIndex::kTVMFFIBytes,
560+
details::StableHashSmallStrBytes(&src.data_));
561+
} else if (src.data_.type_index == TypeIndex::kTVMFFIStr ||
562+
src.data_.type_index == TypeIndex::kTVMFFIBytes) {
563+
const details::BytesObjBase* src_str =
564+
details::AnyUnsafe::CopyFromAnyViewAfterCheck<const details::BytesObjBase*>(src);
565+
return details::StableHashCombine(src.data_.type_index,
566+
details::StableHashBytes(src_str->data, src_str->size));
567+
} else {
568+
return details::StableHashCombine(src.data_.type_index, src.data_.v_uint64);
569+
}
557570
}
558571
};
559572

@@ -566,19 +579,60 @@ struct AnyEqual {
566579
* \return String equality if both are strings, pointer address equality otherwise.
567580
*/
568581
bool operator()(const Any& lhs, const Any& rhs) const {
569-
if (lhs.data_.type_index != rhs.data_.type_index) return false;
570-
// byte equivalence
571-
if (lhs.data_.v_int64 == rhs.data_.v_int64) return true;
572-
// specialy handle string hash
573-
if (lhs.data_.type_index == TypeIndex::kTVMFFIStr ||
574-
lhs.data_.type_index == TypeIndex::kTVMFFIBytes) {
575-
const details::BytesObjBase* lhs_str =
576-
details::AnyUnsafe::CopyFromAnyViewAfterCheck<const details::BytesObjBase*>(lhs);
577-
const details::BytesObjBase* rhs_str =
578-
details::AnyUnsafe::CopyFromAnyViewAfterCheck<const details::BytesObjBase*>(rhs);
579-
return Bytes::memequal(lhs_str->data, rhs_str->data, lhs_str->size, rhs_str->size);
582+
// header with type index
583+
const int64_t* lhs_as_int64 = reinterpret_cast<const int64_t*>(&lhs.data_);
584+
const int64_t* rhs_as_int64 = reinterpret_cast<const int64_t*>(&rhs.data_);
585+
static_assert(sizeof(TVMFFIAny) == 16);
586+
// fast path, check byte equality
587+
if (lhs_as_int64[0] == rhs_as_int64[0] && lhs_as_int64[1] == rhs_as_int64[1]) {
588+
return true;
589+
}
590+
// common false case type index match, in this case we only need to pay attention to string
591+
// equality
592+
if (lhs.data_.type_index == rhs.data_.type_index) {
593+
// specialy handle string hash
594+
if (lhs.data_.type_index == TypeIndex::kTVMFFIStr ||
595+
lhs.data_.type_index == TypeIndex::kTVMFFIBytes) {
596+
const details::BytesObjBase* lhs_str =
597+
details::AnyUnsafe::CopyFromAnyViewAfterCheck<const details::BytesObjBase*>(lhs);
598+
const details::BytesObjBase* rhs_str =
599+
details::AnyUnsafe::CopyFromAnyViewAfterCheck<const details::BytesObjBase*>(rhs);
600+
return Bytes::memequal(lhs_str->data, rhs_str->data, lhs_str->size, rhs_str->size);
601+
}
602+
return false;
603+
} else {
604+
// type_index mismatch, if index is not string, return false
605+
if (lhs.data_.type_index != kTVMFFIStr && lhs.data_.type_index != kTVMFFISmallStr &&
606+
lhs.data_.type_index != kTVMFFISmallBytes && lhs.data_.type_index != kTVMFFIBytes) {
607+
return false;
608+
}
609+
// small string and normal string comparison
610+
if (lhs.data_.type_index == kTVMFFIStr && rhs.data_.type_index == kTVMFFISmallStr) {
611+
const details::BytesObjBase* lhs_str =
612+
details::AnyUnsafe::CopyFromAnyViewAfterCheck<const details::BytesObjBase*>(lhs);
613+
return Bytes::memequal(lhs_str->data, rhs.data_.v_bytes, lhs_str->size,
614+
rhs.data_.small_str_len);
615+
}
616+
if (lhs.data_.type_index == kTVMFFISmallStr && rhs.data_.type_index == kTVMFFIStr) {
617+
const details::BytesObjBase* rhs_str =
618+
details::AnyUnsafe::CopyFromAnyViewAfterCheck<const details::BytesObjBase*>(rhs);
619+
return Bytes::memequal(lhs.data_.v_bytes, rhs_str->data, lhs.data_.small_str_len,
620+
rhs_str->size);
621+
}
622+
if (lhs.data_.type_index == kTVMFFIBytes && rhs.data_.type_index == kTVMFFISmallBytes) {
623+
const details::BytesObjBase* lhs_bytes =
624+
details::AnyUnsafe::CopyFromAnyViewAfterCheck<const details::BytesObjBase*>(lhs);
625+
return Bytes::memequal(lhs_bytes->data, rhs.data_.v_bytes, lhs_bytes->size,
626+
rhs.data_.small_str_len);
627+
}
628+
if (lhs.data_.type_index == kTVMFFISmallBytes && rhs.data_.type_index == kTVMFFIBytes) {
629+
const details::BytesObjBase* rhs_bytes =
630+
details::AnyUnsafe::CopyFromAnyViewAfterCheck<const details::BytesObjBase*>(rhs);
631+
return Bytes::memequal(lhs.data_.v_bytes, rhs_bytes->data, lhs.data_.small_str_len,
632+
rhs_bytes->size);
633+
}
634+
return false;
580635
}
581-
return false;
582636
}
583637
};
584638

ffi/include/tvm/ffi/base_details.h

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -170,7 +170,8 @@ TVM_FFI_INLINE uint64_t StableHashCombine(uint64_t key, const T& value) {
170170
* \param size The size of the bytes.
171171
* \return the hash value.
172172
*/
173-
TVM_FFI_INLINE uint64_t StableHashBytes(const char* data, size_t size) {
173+
TVM_FFI_INLINE uint64_t StableHashBytes(const void* data_ptr, size_t size) {
174+
const char* data = reinterpret_cast<const char*>(data_ptr);
174175
const constexpr uint64_t kMultiplier = 1099511628211ULL;
175176
const constexpr uint64_t kMod = 2147483647ULL;
176177
union Union {
@@ -250,6 +251,20 @@ TVM_FFI_INLINE uint64_t StableHashBytes(const char* data, size_t size) {
250251
return result;
251252
}
252253

254+
/*!
255+
* \brief Same as StableHashBytes, but for small string data.
256+
* \param data The data pointer
257+
* \return the hash value.
258+
*/
259+
TVM_FFI_INLINE uint64_t StableHashSmallStrBytes(const TVMFFIAny* data) {
260+
if constexpr (TVM_FFI_IO_NO_ENDIAN_SWAP) {
261+
// fast path, no endian swap, simply hash as uint64_t
262+
const constexpr uint64_t kMod = 2147483647ULL;
263+
return data->v_uint64 % kMod;
264+
}
265+
return StableHashBytes(reinterpret_cast<const void*>(data), sizeof(data->v_uint64));
266+
}
267+
253268
} // namespace details
254269
} // namespace ffi
255270
} // namespace tvm

ffi/include/tvm/ffi/c_api.h

Lines changed: 34 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -65,13 +65,7 @@ enum TVMFFITypeIndex : int32_t {
6565
#else
6666
typedef enum {
6767
#endif
68-
// [Section] On-stack POD and special types: [0, kTVMFFIStaticObjectBegin)
69-
// N.B. `kTVMFFIRawStr` is a string backed by a `\0`-terminated char array,
70-
// which is not owned by TVMFFIAny. It is required that the following
71-
// invariant holds:
72-
// - `Any::type_index` is never `kTVMFFIRawStr`
73-
// - `AnyView::type_index` can be `kTVMFFIRawStr`
74-
//
68+
7569
/*
7670
* \brief The root type of all FFI objects.
7771
*
@@ -80,6 +74,13 @@ typedef enum {
8074
* However, it may appear in field annotations during reflection.
8175
*/
8276
kTVMFFIAny = -1,
77+
// [Section] On-stack POD and special types: [0, kTVMFFIStaticObjectBegin)
78+
// N.B. `kTVMFFIRawStr` is a string backed by a `\0`-terminated char array,
79+
// which is not owned by TVMFFIAny. It is required that the following
80+
// invariant holds:
81+
// - `Any::type_index` is never `kTVMFFIRawStr`
82+
// - `AnyView::type_index` can be `kTVMFFIRawStr`
83+
//
8384
/*! \brief None/nullptr value */
8485
kTVMFFINone = 0,
8586
/*! \brief POD int value */
@@ -96,12 +97,16 @@ typedef enum {
9697
kTVMFFIDevice = 6,
9798
/*! \brief DLTensor* */
9899
kTVMFFIDLTensorPtr = 7,
99-
/*! \brief const char**/
100+
/*! \brief const char* */
100101
kTVMFFIRawStr = 8,
101102
/*! \brief TVMFFIByteArray* */
102103
kTVMFFIByteArrayPtr = 9,
103104
/*! \brief R-value reference to ObjectRef */
104105
kTVMFFIObjectRValueRef = 10,
106+
/*! \brief Small string on stack */
107+
kTVMFFISmallStr = 11,
108+
/*! \brief Small bytes on stack */
109+
kTVMFFISmallBytes = 12,
105110
/*! \brief Start of statically defined objects. */
106111
kTVMFFIStaticObjectBegin = 64,
107112
/*!
@@ -183,11 +188,17 @@ typedef struct TVMFFIAny {
183188
* \note The type index of Object and Any are shared in FFI.
184189
*/
185190
int32_t type_index;
186-
/*!
187-
* \brief length for on-stack Any object, such as small-string
188-
* \note This field is reserved for future compact.
189-
*/
190-
int32_t small_len;
191+
union { // 4 bytes
192+
/*! \brief padding, must set to zero for values other than small string. */
193+
uint32_t zero_padding;
194+
/*!
195+
* \brief Length of small string, with a max value of 7.
196+
*
197+
* We keep small str to start at next 4 bytes to ensure alignment
198+
* when accessing the small str content.
199+
*/
200+
uint32_t small_str_len;
201+
};
191202
union { // 8 bytes
192203
int64_t v_int64; // integers
193204
double v_float64; // floating-point numbers
@@ -823,7 +834,7 @@ TVM_FFI_DLL int TVMFFIDataTypeFromString(const TVMFFIByteArray* str, DLDataType*
823834
824835
* \note The input dtype is a pointer to the DLDataType to avoid ABI compatibility issues.
825836
*/
826-
TVM_FFI_DLL int TVMFFIDataTypeToString(const DLDataType* dtype, TVMFFIObjectHandle* out);
837+
TVM_FFI_DLL int TVMFFIDataTypeToString(const DLDataType* dtype, TVMFFIAny* out);
827838

828839
//------------------------------------------------------------
829840
// Section: Backend noexcept functions for internal use
@@ -903,6 +914,15 @@ inline int32_t TVMFFIObjectGetTypeIndex(TVMFFIObjectHandle obj) {
903914
return static_cast<TVMFFIObject*>(obj)->type_index;
904915
}
905916

917+
/*!
918+
* \brief Get the content of a small string in bytearray format.
919+
* \param obj The object handle.
920+
* \return The content of the small string in bytearray format.
921+
*/
922+
inline TVMFFIByteArray TVMFFISmallBytesGetContentByteArray(const TVMFFIAny* value) {
923+
return TVMFFIByteArray{value->v_bytes, static_cast<size_t>(value->small_str_len)};
924+
}
925+
906926
/*!
907927
* \brief Get the data pointer of a bytearray from a string or bytes object.
908928
* \param obj The object handle.

ffi/include/tvm/ffi/cast.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
#include <tvm/ffi/dtype.h>
2828
#include <tvm/ffi/error.h>
2929
#include <tvm/ffi/object.h>
30+
#include <tvm/ffi/optional.h>
3031

3132
#include <utility>
3233

ffi/include/tvm/ffi/container/variant.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -80,10 +80,12 @@ class VariantBase<true> : public ObjectRef {
8080
TVMFFIAny any_data;
8181
if (data_ == nullptr) {
8282
any_data.type_index = TypeIndex::kTVMFFINone;
83+
any_data.zero_padding = 0;
8384
any_data.v_int64 = 0;
8485
} else {
8586
TVM_FFI_CLEAR_PTR_PADDING_IN_FFI_ANY(&any_data);
8687
any_data.type_index = data_->type_index();
88+
any_data.zero_padding = 0;
8789
any_data.v_obj = details::ObjectUnsafe::TVMFFIObjectPtrFromObjectPtr<Object>(data_);
8890
}
8991
return AnyView::CopyFromTVMFFIAny(any_data);

ffi/include/tvm/ffi/dtype.h

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -115,14 +115,15 @@ inline const char* DLDataTypeCodeAsCStr(DLDataTypeCode type_code) { // NOLINT(*
115115

116116
inline DLDataType StringToDLDataType(const String& str) {
117117
DLDataType out;
118-
TVM_FFI_CHECK_SAFE_CALL(TVMFFIDataTypeFromString(str.get(), &out));
118+
TVMFFIByteArray data{str.data(), str.size()};
119+
TVM_FFI_CHECK_SAFE_CALL(TVMFFIDataTypeFromString(&data, &out));
119120
return out;
120121
}
121122

122123
inline String DLDataTypeToString(DLDataType dtype) {
123-
TVMFFIObjectHandle out;
124+
TVMFFIAny out;
124125
TVM_FFI_CHECK_SAFE_CALL(TVMFFIDataTypeToString(&dtype, &out));
125-
return String(details::ObjectUnsafe::ObjectPtrFromOwned<Object>(static_cast<TVMFFIObject*>(out)));
126+
return TypeTraits<String>::MoveFromAnyAfterCheck(&out);
126127
}
127128

128129
// DLDataType
@@ -134,13 +135,15 @@ struct TypeTraits<DLDataType> : public TypeTraitsBase {
134135
// clear padding part to ensure the equality check can always check the v_uint64 part
135136
result->v_uint64 = 0;
136137
result->type_index = TypeIndex::kTVMFFIDataType;
138+
result->zero_padding = 0;
137139
result->v_dtype = src;
138140
}
139141

140142
TVM_FFI_INLINE static void MoveToAny(DLDataType src, TVMFFIAny* result) {
141143
// clear padding part to ensure the equality check can always check the v_uint64 part
142144
result->v_uint64 = 0;
143145
result->type_index = TypeIndex::kTVMFFIDataType;
146+
result->zero_padding = 0;
144147
result->v_dtype = src;
145148
}
146149

ffi/include/tvm/ffi/object.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,8 @@ struct StaticTypeKey {
6060
static constexpr const char* kTVMFFIFunction = "ffi.Function";
6161
static constexpr const char* kTVMFFIArray = "ffi.Array";
6262
static constexpr const char* kTVMFFIMap = "ffi.Map";
63+
static constexpr const char* kTVMFFISmallStr = "ffi.SmallStr";
64+
static constexpr const char* kTVMFFISmallBytes = "ffi.SmallBytes";
6365
};
6466

6567
/*!

0 commit comments

Comments
 (0)