Skip to content

Commit

Permalink
ddl: Support parsing VectorIndex defined in IndexInfo (pingcap#274)
Browse files Browse the repository at this point in the history
  • Loading branch information
JaySon-Huang authored Aug 26, 2024
1 parent 0b0750a commit 45f99a1
Show file tree
Hide file tree
Showing 7 changed files with 176 additions and 51 deletions.
5 changes: 3 additions & 2 deletions dbms/src/Storages/DeltaMerge/Index/VectorIndex.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
#include <Storages/DeltaMerge/File/dtpb/dmfile.pb.h>
#include <Storages/DeltaMerge/Index/VectorIndex.h>
#include <Storages/DeltaMerge/Index/VectorIndexHNSW/Index.h>
#include <TiDB/Schema/VectorIndex.h>
#include <tipb/executor.pb.h>

namespace DB::ErrorCodes
Expand All @@ -40,7 +41,7 @@ bool VectorIndexBuilder::isSupportedType(const IDataType & type)
VectorIndexBuilderPtr VectorIndexBuilder::create(const TiDB::VectorIndexDefinitionPtr & definition)
{
RUNTIME_CHECK(definition->dimension > 0);
RUNTIME_CHECK(definition->dimension <= std::numeric_limits<UInt32>::max());
RUNTIME_CHECK(definition->dimension <= TiDB::MAX_VECTOR_DIMENSION);

switch (definition->kind)
{
Expand All @@ -57,7 +58,7 @@ VectorIndexBuilderPtr VectorIndexBuilder::create(const TiDB::VectorIndexDefiniti
VectorIndexViewerPtr VectorIndexViewer::view(const dtpb::VectorIndexFileProps & file_props, std::string_view path)
{
RUNTIME_CHECK(file_props.dimensions() > 0);
RUNTIME_CHECK(file_props.dimensions() <= std::numeric_limits<UInt32>::max());
RUNTIME_CHECK(file_props.dimensions() <= TiDB::MAX_VECTOR_DIMENSION);

tipb::VectorIndexKind kind;
RUNTIME_CHECK(tipb::VectorIndexKind_Parse(file_props.index_kind(), &kind));
Expand Down
16 changes: 6 additions & 10 deletions dbms/src/Storages/S3/FileCache.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -291,7 +291,7 @@ FileSegmentPtr FileCache::getOrWait(const S3::S3FilenameView & s3_fname, const s
lock.unlock();

PerfContext::file_cache.fg_download_from_s3++;
fgDownload(lock, s3_key, file_seg);
fgDownload(s3_key, file_seg);
if (!file_seg || !file_seg->isReadyToRead())
throw Exception( //
ErrorCodes::S3_ERROR,
Expand Down Expand Up @@ -338,22 +338,18 @@ void FileCache::removeDiskFile(const String & local_fname, bool update_fsize_met
}
}

void FileCache::remove(std::unique_lock<std::mutex> &, const String & s3_key, bool force)
void FileCache::remove(const String & s3_key, bool force)
{
auto file_type = getFileType(s3_key);
auto & table = tables[static_cast<UInt64>(file_type)];

std::unique_lock lock(mtx);
auto f = table.get(s3_key, /*update_lru*/ false);
if (f == nullptr)
return;
std::ignore = removeImpl(table, s3_key, f, force);
}

void FileCache::remove(const String & s3_key, bool force)
{
std::unique_lock lock(mtx);
remove(lock, s3_key, force);
}

std::pair<Int64, std::list<String>::iterator> FileCache::removeImpl(
LRUFileTable & table,
const String & s3_key,
Expand Down Expand Up @@ -789,7 +785,7 @@ void FileCache::bgDownload(const String & s3_key, FileSegmentPtr & file_seg)
[this, s3_key = s3_key, file_seg = file_seg]() mutable { download(s3_key, file_seg); });
}

void FileCache::fgDownload(std::unique_lock<std::mutex> & cache_lock, const String & s3_key, FileSegmentPtr & file_seg)
void FileCache::fgDownload(const String & s3_key, FileSegmentPtr & file_seg)
{
SYNC_FOR("FileCache::fgDownload"); // simulate long s3 download

Expand All @@ -809,7 +805,7 @@ void FileCache::fgDownload(std::unique_lock<std::mutex> & cache_lock, const Stri
file_seg->setStatus(FileSegment::Status::Failed);
GET_METRIC(tiflash_storage_remote_cache, type_dtfile_download_failed).Increment();
file_seg.reset();
remove(cache_lock, s3_key);
remove(s3_key);
}

LOG_DEBUG(
Expand Down
3 changes: 1 addition & 2 deletions dbms/src/Storages/S3/FileCache.h
Original file line number Diff line number Diff line change
Expand Up @@ -259,7 +259,7 @@ class FileCache
const std::optional<UInt64> & filesize = std::nullopt);

void bgDownload(const String & s3_key, FileSegmentPtr & file_seg);
void fgDownload(std::unique_lock<std::mutex> & cache_lock, const String & s3_key, FileSegmentPtr & file_seg);
void fgDownload(const String & s3_key, FileSegmentPtr & file_seg);
void download(const String & s3_key, FileSegmentPtr & file_seg);
void downloadImpl(const String & s3_key, FileSegmentPtr & file_seg);

Expand All @@ -276,7 +276,6 @@ class FileCache
void restoreTable(const std::filesystem::directory_entry & table_entry);
void restoreDMFile(const std::filesystem::directory_entry & dmfile_entry);

void remove(std::unique_lock<std::mutex> & cache_lock, const String & s3_key, bool force = false);
void remove(const String & s3_key, bool force = false);
std::pair<Int64, std::list<String>::iterator> removeImpl(
LRUFileTable & table,
Expand Down
102 changes: 68 additions & 34 deletions dbms/src/TiDB/Schema/TiDB.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
#include <TiDB/Decode/Vector.h>
#include <TiDB/Schema/SchemaNameMapper.h>
#include <TiDB/Schema/TiDB.h>
#include <TiDB/Schema/VectorIndex.h>
#include <common/logger_useful.h>
#include <tipb/executor.pb.h>

Expand Down Expand Up @@ -103,6 +104,47 @@ using DB::Exception;
using DB::Field;
using DB::SchemaNameMapper;

VectorIndexDefinitionPtr parseVectorIndexFromJSON(const Poco::JSON::Object::Ptr & json)
{
assert(json); // not nullptr

tipb::VectorIndexKind kind = tipb::VectorIndexKind::INVALID_INDEX_KIND;
auto kind_field = json->getValue<String>("kind");
RUNTIME_CHECK_MSG(tipb::VectorIndexKind_Parse(kind_field, &kind), "invalid kind of vector index, {}", kind_field);
RUNTIME_CHECK(kind != tipb::VectorIndexKind::INVALID_INDEX_KIND);

auto dimension = json->getValue<UInt64>("dimension");
RUNTIME_CHECK(dimension > 0);
RUNTIME_CHECK(dimension <= TiDB::MAX_VECTOR_DIMENSION); // Just a protection

tipb::VectorDistanceMetric distance_metric = tipb::VectorDistanceMetric::INVALID_DISTANCE_METRIC;
auto distance_metric_field = json->getValue<String>("distance_metric");
RUNTIME_CHECK_MSG(
tipb::VectorDistanceMetric_Parse(distance_metric_field, &distance_metric),
"invalid distance_metric of vector index, {}",
distance_metric_field);
RUNTIME_CHECK(distance_metric != tipb::VectorDistanceMetric::INVALID_DISTANCE_METRIC);

return std::make_shared<const VectorIndexDefinition>(VectorIndexDefinition{
.kind = kind,
.dimension = dimension,
.distance_metric = distance_metric,
});
}

Poco::JSON::Object::Ptr vectorIndexToJSON(const VectorIndexDefinitionPtr & vector_index)
{
assert(vector_index != nullptr);
RUNTIME_CHECK(vector_index->kind != tipb::VectorIndexKind::INVALID_INDEX_KIND);
RUNTIME_CHECK(vector_index->distance_metric != tipb::VectorDistanceMetric::INVALID_DISTANCE_METRIC);

Poco::JSON::Object::Ptr vector_index_json = new Poco::JSON::Object();
vector_index_json->set("kind", tipb::VectorIndexKind_Name(vector_index->kind));
vector_index_json->set("dimension", vector_index->dimension);
vector_index_json->set("distance_metric", tipb::VectorDistanceMetric_Name(vector_index->distance_metric));
return vector_index_json;
}

////////////////////////
////// ColumnInfo //////
////////////////////////
Expand Down Expand Up @@ -412,15 +454,7 @@ try

if (vector_index)
{
RUNTIME_CHECK(vector_index->kind != tipb::VectorIndexKind::INVALID_INDEX_KIND);
RUNTIME_CHECK(vector_index->distance_metric != tipb::VectorDistanceMetric::INVALID_DISTANCE_METRIC);

Poco::JSON::Object::Ptr vector_index_json = new Poco::JSON::Object();
vector_index_json->set("kind", tipb::VectorIndexKind_Name(vector_index->kind));
vector_index_json->set("dimension", vector_index->dimension);
vector_index_json->set("distance_metric", tipb::VectorDistanceMetric_Name(vector_index->distance_metric));

json->set("vector_index", vector_index_json);
json->set("vector_index", vectorIndexToJSON(vector_index));
}

#ifndef NDEBUG
Expand Down Expand Up @@ -475,32 +509,9 @@ try
}
state = static_cast<SchemaState>(json->getValue<Int32>("state"));

auto vector_index_json = json->getObject("vector_index");
if (vector_index_json)
if (auto vector_index_json = json->getObject("vector_index"); vector_index_json)
{
tipb::VectorIndexKind kind = tipb::VectorIndexKind::INVALID_INDEX_KIND;
auto ok = tipb::VectorIndexKind_Parse( //
vector_index_json->getValue<String>("kind"),
&kind);
RUNTIME_CHECK(ok);
RUNTIME_CHECK(kind != tipb::VectorIndexKind::INVALID_INDEX_KIND);

auto dimension = vector_index_json->getValue<UInt64>("dimension");
RUNTIME_CHECK(dimension > 0);
RUNTIME_CHECK(dimension <= 16383); // Just a protection

tipb::VectorDistanceMetric distance_metric = tipb::VectorDistanceMetric::INVALID_DISTANCE_METRIC;
ok = tipb::VectorDistanceMetric_Parse( //
vector_index_json->getValue<String>("distance_metric"),
&distance_metric);
RUNTIME_CHECK(ok);
RUNTIME_CHECK(distance_metric != tipb::VectorDistanceMetric::INVALID_DISTANCE_METRIC);

vector_index = std::make_shared<const VectorIndexDefinition>(VectorIndexDefinition{
.kind = kind,
.dimension = dimension,
.distance_metric = distance_metric,
});
vector_index = parseVectorIndexFromJSON(vector_index_json);
}
}
catch (const Poco::Exception & e)
Expand Down Expand Up @@ -843,6 +854,11 @@ try
json->set("is_invisible", is_invisible);
json->set("is_global", is_global);

if (vector_index)
{
json->set("vector_index", vectorIndexToJSON(vector_index));
}

#ifndef NDEBUG
std::stringstream str;
json->stringify(str);
Expand Down Expand Up @@ -883,6 +899,11 @@ try
is_invisible = json->getValue<bool>("is_invisible");
if (json->has("is_global"))
is_global = json->getValue<bool>("is_global");

if (auto vector_index_json = json->getObject("vector_index"); vector_index_json)
{
vector_index = parseVectorIndexFromJSON(vector_index_json);
}
}
catch (const Poco::Exception & e)
{
Expand Down Expand Up @@ -1021,6 +1042,10 @@ try
// always put the primary_index at the front of all index_info
index_infos.insert(index_infos.begin(), std::move(index_info));
}
else if (index_info.vector_index != nullptr)
{
index_infos.emplace_back(std::move(index_info));
}
}
}

Expand Down Expand Up @@ -1169,6 +1194,15 @@ KeyspaceID TableInfo::getKeyspaceID() const
return keyspace_id;
}

const IndexInfo & TableInfo::getPrimaryIndexInfo() const
{
assert(is_common_handle);
#ifndef NDEBUG
RUNTIME_CHECK(index_infos[0].is_primary);
#endif
return index_infos[0];
}

String TableInfo::getColumnName(const ColumnID id) const
{
for (const auto & col : columns)
Expand Down
6 changes: 3 additions & 3 deletions dbms/src/TiDB/Schema/TiDB.h
Original file line number Diff line number Diff line change
Expand Up @@ -335,6 +335,8 @@ struct IndexInfo
bool is_primary = false;
bool is_invisible = false;
bool is_global = false;

VectorIndexDefinitionPtr vector_index = nullptr;
};

struct TableInfo
Expand Down Expand Up @@ -404,9 +406,7 @@ struct TableInfo
}

/// should not be called if is_common_handle = false.
const IndexInfo & getPrimaryIndexInfo() const { return index_infos[0]; }

IndexInfo & getPrimaryIndexInfo() { return index_infos[0]; }
const IndexInfo & getPrimaryIndexInfo() const;
};

using DBInfoPtr = std::shared_ptr<DBInfo>;
Expand Down
3 changes: 3 additions & 0 deletions dbms/src/TiDB/Schema/VectorIndex.h
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,9 @@ struct VectorIndexDefinition
// ever try to modify it anyway.
using VectorIndexDefinitionPtr = std::shared_ptr<const VectorIndexDefinition>;

// Defined in TiDB pkg/types/vector.go
static constexpr Int64 MAX_VECTOR_DIMENSION = 16383;

} // namespace TiDB

template <>
Expand Down
92 changes: 92 additions & 0 deletions dbms/src/TiDB/Schema/tests/gtest_table_info.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,98 @@ try
}
CATCH

TEST(TiDBTableInfoTest, ParseVectorIndexJSON)
try
{
auto cases = {
ParseCase{
R"json({"cols":[{"default":null,"default_bit":null,"id":1,"name":{"L":"col1","O":"col1"},"offset":-1,"origin_default":null,"state":0,"type":{"Charset":null,"Collate":null,"Decimal":0,"Elems":null,"Flag":4097,"Flen":0,"Tp":8}},{"default":null,"default_bit":null,"id":2,"name":{"L":"vec","O":"vec"},"offset":-1,"origin_default":null,"state":0,"type":{"Charset":null,"Collate":null,"Decimal":0,"Elems":null,"Flag":4097,"Flen":0,"Tp":225}}],"id":30,"index_info":[{"id":3,"idx_cols":[{"length":-1,"name":{"L":"vec","O":"vec"},"offset":0}],"idx_name":{"L":"idx1","O":"idx1"},"index_type":-1,"is_global":false,"is_invisible":false,"is_primary":false,"is_unique":false,"state":5,"vector_index":{"dimension":3,"distance_metric":"L2","kind":"HNSW"}}],"is_common_handle":false,"name":{"L":"t1","O":"t1"},"partition":null,"pk_is_handle":false,"schema_version":-1,"state":0,"update_timestamp":1723778704444603})json",
[](const TableInfo & table_info) {
ASSERT_EQ(table_info.index_infos.size(), 1);
auto idx = table_info.index_infos[0];
ASSERT_EQ(idx.id, 3);
ASSERT_EQ(idx.idx_cols.size(), 1);
ASSERT_EQ(idx.idx_cols[0].name, "vec");
ASSERT_EQ(idx.idx_cols[0].offset, 0);
ASSERT_EQ(idx.idx_cols[0].length, -1);
ASSERT_NE(idx.vector_index, nullptr);
ASSERT_EQ(idx.vector_index->kind, tipb::VectorIndexKind::HNSW);
ASSERT_EQ(idx.vector_index->dimension, 3);
ASSERT_EQ(idx.vector_index->distance_metric, tipb::VectorDistanceMetric::L2);
ASSERT_EQ(table_info.columns.size(), 2);
auto col0 = table_info.columns[0];
ASSERT_EQ(col0.name, "col1");
ASSERT_EQ(col0.tp, TiDB::TP::TypeLongLong);
ASSERT_EQ(col0.id, 1);
auto col1 = table_info.columns[1];
ASSERT_EQ(col1.name, "vec");
ASSERT_EQ(col1.tp, TiDB::TP::TypeTiDBVectorFloat32);
ASSERT_EQ(col1.id, 2);
},
},
// VectorIndex defined in the ColumnInfo
ParseCase{
R"json({"cols":[{"comment":"hnsw(distance=l2)","default":null,"default_bit":null,"id":1,"name":{"L":"v","O":"v"},"offset":0,"origin_default":null,"state":5,"type":{"Charset":"binary","Collate":"binary","Decimal":0,"Elems":null,"Flag":128,"Flen":5,"Tp":225},"vector_index":{"dimension":5,"distance_metric":"L2","kind":"HNSW"}}],"comment":"","id":96,"index_info":[],"is_common_handle":false,"keyspace_id":1,"name":{"L":"t","O":"t"},"partition":null,"pk_is_handle":false,"schema_version":-1,"state":5,"tiflash_replica":{"Available":true,"Count":1},"update_timestamp":451956855279976452})json",
[](const TableInfo & table_info) {
ASSERT_EQ(table_info.index_infos.size(), 0);
ASSERT_EQ(table_info.columns.size(), 1);
auto col = table_info.columns[0];
ASSERT_EQ(col.name, "v");
ASSERT_EQ(col.tp, TiDB::TP::TypeTiDBVectorFloat32);
ASSERT_EQ(col.id, 1);
auto vector_index_on_col = col.vector_index;
ASSERT_NE(vector_index_on_col, nullptr);
ASSERT_EQ(vector_index_on_col->kind, tipb::VectorIndexKind::HNSW);
ASSERT_EQ(vector_index_on_col->dimension, 5);
ASSERT_EQ(vector_index_on_col->distance_metric, tipb::VectorDistanceMetric::L2);
},
},
ParseCase{
R"json({"cols":[{"comment":"","default":null,"default_bit":null,"id":1,"name":{"L":"col","O":"col"},"offset":0,"origin_default":null,"state":5,"type":{"Charset":"binary","Collate":"binary","Decimal":0,"Elems":null,"Flag":4099,"Flen":20,"Tp":8}},{"comment":"","default":null,"default_bit":null,"id":2,"name":{"L":"v","O":"v"},"offset":1,"origin_default":null,"state":5,"type":{"Charset":"binary","Collate":"binary","Decimal":0,"Elems":null,"Flag":128,"Flen":5,"Tp":225}}],"comment":"","id":96,"index_info":[{"id":4,"idx_cols":[{"length":-1,"name":{"L":"v","O":"v"},"offset":1}],"idx_name":{"L":"idx_v_l2","O":"idx_v_l2"},"index_type":5,"is_global":false,"is_invisible":false,"is_primary":false,"is_unique":false,"state":3,"vector_index":{"dimension":5,"distance_metric":"L2","kind":"HNSW"}},{"id":3,"idx_cols":[{"length":-1,"name":{"L":"col","O":"col"},"offset":0}],"idx_name":{"L":"primary","O":"primary"},"index_type":1,"is_global":false,"is_invisible":false,"is_primary":true,"is_unique":true,"state":5}],"is_common_handle":false,"keyspace_id":1,"name":{"L":"ti","O":"ti"},"partition":null,"pk_is_handle":false,"schema_version":-1,"state":5,"tiflash_replica":{"Available":true,"Count":1},"update_timestamp":452024291984670725})json",
[](const TableInfo & table_info) {
// vector index && primary index
// primary index alwasy be put at the first
ASSERT_EQ(table_info.index_infos.size(), 2);
auto idx0 = table_info.index_infos[0];
ASSERT_TRUE(idx0.is_primary);
ASSERT_TRUE(idx0.is_unique);
ASSERT_EQ(idx0.id, 3);
ASSERT_EQ(idx0.idx_name, "primary");
ASSERT_EQ(idx0.idx_cols.size(), 1);
ASSERT_EQ(idx0.idx_cols[0].name, "col");
ASSERT_EQ(idx0.idx_cols[0].offset, 0);
ASSERT_EQ(idx0.vector_index, nullptr);
// vec index
auto idx1 = table_info.index_infos[1];
ASSERT_EQ(idx1.id, 4);
ASSERT_EQ(idx1.idx_name, "idx_v_l2");
ASSERT_EQ(idx1.idx_cols.size(), 1);
ASSERT_EQ(idx1.idx_cols[0].name, "v");
ASSERT_EQ(idx1.idx_cols[0].offset, 1);
ASSERT_NE(idx1.vector_index, nullptr);
ASSERT_EQ(idx1.vector_index->kind, tipb::VectorIndexKind::HNSW);
ASSERT_EQ(idx1.vector_index->dimension, 5);
ASSERT_EQ(idx1.vector_index->distance_metric, tipb::VectorDistanceMetric::L2);

ASSERT_EQ(table_info.columns.size(), 2);
auto col0 = table_info.columns[0];
ASSERT_EQ(col0.name, "col");
ASSERT_EQ(col0.tp, TiDB::TP::TypeLongLong);
ASSERT_EQ(col0.id, 1);
auto col1 = table_info.columns[1];
ASSERT_EQ(col1.name, "v");
ASSERT_EQ(col1.tp, TiDB::TP::TypeTiDBVectorFloat32);
ASSERT_EQ(col1.id, 2);
}}};

for (const auto & c : cases)
{
TableInfo table_info(c.table_info_json, NullspaceID);
c.check(table_info);
}
}
CATCH

struct StmtCase
{
TableID table_or_partition_id;
Expand Down

0 comments on commit 45f99a1

Please sign in to comment.