Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

vector-search: fix attach #1799

Merged
merged 8 commits into from
Oct 29, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
102 changes: 80 additions & 22 deletions libsql-ffi/bundled/SQLite3MultipleCiphers/src/sqlite3.c
Original file line number Diff line number Diff line change
Expand Up @@ -85652,7 +85652,7 @@ int vectorIdxParseColumnType(const char *, int *, int *, const char **);
int vectorIndexCreate(Parse*, const Index*, const char *, const IdList*);
int vectorIndexClear(sqlite3 *, const char *, const char *);
int vectorIndexDrop(sqlite3 *, const char *, const char *);
int vectorIndexSearch(sqlite3 *, const char *, int, sqlite3_value **, VectorOutRows *, int *, int *, char **);
int vectorIndexSearch(sqlite3 *, int, sqlite3_value **, VectorOutRows *, int *, int *, char **);
int vectorIndexCursorInit(sqlite3 *, const char *, const char *, VectorIdxCursor **);
void vectorIndexCursorClose(sqlite3 *, VectorIdxCursor *, int *, int *);
int vectorIndexInsert(VectorIdxCursor *, const UnpackedRecord *, char **);
Expand Down Expand Up @@ -215829,17 +215829,25 @@ int vectorIndexTryGetParametersFromBinFormat(sqlite3 *db, const char *zSql, cons

int vectorIndexGetParameters(
sqlite3 *db,
const char *zDbSName,
const char *zIdxName,
VectorIdxParams *pParams
) {
int rc = SQLITE_OK;
assert( zDbSName != NULL );

static const char* zSelectSql = "SELECT metadata FROM " VECTOR_INDEX_GLOBAL_META_TABLE " WHERE name = ?";
static const char *zSelectSqlTemplate = "SELECT metadata FROM \"%w\"." VECTOR_INDEX_GLOBAL_META_TABLE " WHERE name = ?";
char* zSelectSql;
zSelectSql = sqlite3_mprintf(zSelectSqlTemplate, zDbSName);
if( zSelectSql == NULL ){
return SQLITE_NOMEM_BKPT;
}
// zSelectSqlPekkaLegacy handles the case when user created DB before 04 July 2024 (https://discord.com/channels/933071162680958986/1225560924526477322/1258367912402489397)
// back then, index settings were stored in a table with a rigid schema instead of a table with binary parameters
// we should drop this eventually - but for now we postponed this decision
static const char* zSelectSqlPekkaLegacy = "SELECT vector_type, block_size, dims, distance_ops FROM libsql_vector_index WHERE name = ?";
rc = vectorIndexTryGetParametersFromBinFormat(db, zSelectSql, zIdxName, pParams);
sqlite3_free(zSelectSql);
if( rc == SQLITE_OK ){
return SQLITE_OK;
}
Expand Down Expand Up @@ -216022,19 +216030,46 @@ int vectorIndexCreate(Parse *pParse, const Index *pIdx, const char *zDbSName, co
return CREATE_OK;
}

// Splits a composite index name ("schema_name.index_name") at its FIRST '.' character.
//
// Parameters:
//   db            - connection used for the sqlite3DbStrNDup/sqlite3DbFree allocations
//   zIdxFullName  - NUL-terminated full index name; must not be NULL (it is dereferenced immediately)
//   pzIdxDbSName  - out: receives a newly allocated copy of the part before the first '.'
//   pzIdxName     - out: receives a newly allocated copy of everything after the first '.'
//
// Returns:
//   SQLITE_OK     - if the name has no '.', both out-parameters are left untouched;
//                   otherwise both are set to non-NULL strings which the caller must
//                   release with sqlite3DbFree(db, ...)
//   SQLITE_NOMEM  - if either copy could not be allocated; in this case nothing is
//                   leaked and both out-parameters are reset to NULL, so a caller that
//                   unconditionally frees them on its cleanup path will not double free
int getIndexNameParts(sqlite3 *db, const char *zIdxFullName, char **pzIdxDbSName, char **pzIdxName) {
  int nFullName, nDbSName;
  const char *pDot = zIdxFullName;
  while( *pDot != '.' && *pDot != '\0' ){
    pDot++;
  }
  if( *pDot == '\0' ){
    // no schema part - leave outputs as they are
    return SQLITE_OK;
  }
  assert( *pDot == '.' );
  nFullName = sqlite3Strlen30(zIdxFullName);
  nDbSName = (int)(pDot - zIdxFullName);
  *pzIdxDbSName = sqlite3DbStrNDup(db, zIdxFullName, nDbSName);
  *pzIdxName = sqlite3DbStrNDup(db, pDot + 1, nFullName - nDbSName - 1);
  // check the duplicated strings themselves (not the out-parameter pointers, which are never NULL)
  if( *pzIdxName == NULL || *pzIdxDbSName == NULL ){
    sqlite3DbFree(db, *pzIdxName);
    sqlite3DbFree(db, *pzIdxDbSName);
    // reset outputs so the caller never sees (or frees) dangling pointers
    *pzIdxName = NULL;
    *pzIdxDbSName = NULL;
    return SQLITE_NOMEM_BKPT;
  }
  return SQLITE_OK;
}

int vectorIndexSearch(
sqlite3 *db,
const char* zDbSName,
int argc,
sqlite3_value **argv,
VectorOutRows *pRows,
int *nReads,
int *nWrites,
char **pzErrMsg
) {
int type, dims, k, rc;
int type, dims, k, rc, iDb = -1;
double kDouble;
const char *zIdxName;
const char *zIdxFullName;
char *zIdxDbSNameAlloc = NULL; // allocated managed schema name string - must be freed if not null
char *zIdxNameAlloc = NULL; // allocated managed index name string - must be freed if not null
const char *zIdxDbSName = NULL; // schema name of the index (can be static in cases where explicit schema is omitted - so must not be freed)
const char *zIdxName = NULL; // index name (can be extracted with sqlite3_value_text and managed by SQLite - so must not be freed)
const char *zErrMsg;
Vector *pVector = NULL;
DiskAnnIndex *pDiskAnn = NULL;
Expand All @@ -216043,8 +216078,6 @@ int vectorIndexSearch(
VectorIdxParams idxParams;
vectorIdxParamsInit(&idxParams, NULL, 0);

assert( zDbSName != NULL );

if( argc != 3 ){
*pzErrMsg = sqlite3_mprintf("vector index(search): got %d parameters, expected 3", argc);
rc = SQLITE_ERROR;
Expand Down Expand Up @@ -216095,19 +216128,45 @@ int vectorIndexSearch(
rc = SQLITE_ERROR;
goto out;
}
zIdxName = (const char*)sqlite3_value_text(argv[0]);
if( vectorIndexGetParameters(db, zIdxName, &idxParams) != 0 ){
zIdxFullName = (const char*)sqlite3_value_text(argv[0]);
rc = getIndexNameParts(db, zIdxFullName, &zIdxDbSNameAlloc, &zIdxNameAlloc);
if( rc != SQLITE_OK ){
*pzErrMsg = sqlite3_mprintf("vector index(search): failed to parse index name");
goto out;
}
assert( (zIdxDbSNameAlloc == NULL && zIdxNameAlloc == NULL) || (zIdxDbSNameAlloc != NULL && zIdxNameAlloc != NULL) );
if( zIdxDbSNameAlloc == NULL && zIdxNameAlloc == NULL ){
zIdxDbSName = "main";
zIdxName = zIdxFullName;
} else{
zIdxDbSName = zIdxDbSNameAlloc;
zIdxName = zIdxNameAlloc;
iDb = sqlite3FindDbName(db, zIdxDbSName);
if( iDb < 0 ){
*pzErrMsg = sqlite3_mprintf("vector index(search): unknown schema '%s'", zIdxDbSName);
rc = SQLITE_ERROR;
goto out;
}
// we need to hold mutex to protect schema against unwanted changes
// this code is necessary, otherwise sqlite3SchemaMutexHeld assert will fail
if( iDb !=1 ){
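// skip index 1: in SQLite's db->aDb[] that slot is the "temp" DB ("main" is index 0) - NOTE(review): the original comment said "main"; confirm the intent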
sqlite3BtreeEnter(db->aDb[iDb].pBt);
}
}

if( vectorIndexGetParameters(db, zIdxDbSName, zIdxName, &idxParams) != 0 ){
*pzErrMsg = sqlite3_mprintf("vector index(search): failed to parse vector index parameters");
rc = SQLITE_ERROR;
goto out;
}
pIndex = sqlite3FindIndex(db, zIdxName, zDbSName);
pIndex = sqlite3FindIndex(db, zIdxName, zIdxDbSName);
if( pIndex == NULL ){
*pzErrMsg = sqlite3_mprintf("vector index(search): index not found");
rc = SQLITE_ERROR;
goto out;
}
rc = diskAnnOpenIndex(db, zDbSName, zIdxName, &idxParams, &pDiskAnn);
rc = diskAnnOpenIndex(db, zIdxDbSName, zIdxName, &idxParams, &pDiskAnn);
if( rc != SQLITE_OK ){
*pzErrMsg = sqlite3_mprintf("vector index(search): failed to open diskann index");
goto out;
Expand All @@ -216127,6 +216186,11 @@ int vectorIndexSearch(
if( pVector != NULL ){
vectorFree(pVector);
}
sqlite3DbFree(db, zIdxNameAlloc);
sqlite3DbFree(db, zIdxDbSNameAlloc);
if( iDb >= 0 && iDb != 1 ){
sqlite3BtreeLeave(db->aDb[iDb].pBt);
}
return rc;
}

Expand Down Expand Up @@ -216176,7 +216240,7 @@ int vectorIndexCursorInit(

assert( zDbSName != NULL );

if( vectorIndexGetParameters(db, zIndexName, &params) != 0 ){
if( vectorIndexGetParameters(db, zDbSName, zIndexName, &params) != 0 ){
return SQLITE_ERROR;
}
pCursor = sqlite3DbMallocZero(db, sizeof(VectorIdxCursor));
Expand Down Expand Up @@ -216240,7 +216304,6 @@ typedef struct vectorVtab vectorVtab;
struct vectorVtab {
sqlite3_vtab base; /* Base class - must be first */
sqlite3 *db; /* Database connection */
char* zDbSName; /* Database schema name */
};

typedef struct vectorVtab_cursor vectorVtab_cursor;
Expand All @@ -216266,7 +216329,6 @@ static int vectorVtabConnect(
sqlite3_vtab **ppVtab,
char **pzErr
){
char *zDbSName = NULL;
vectorVtab *pVtab = NULL;
int rc;
/*
Expand All @@ -216281,21 +216343,17 @@ static int vectorVtabConnect(
if( pVtab == NULL ){
return SQLITE_NOMEM_BKPT;
}
zDbSName = sqlite3DbStrDup(db, argv[1]); // argv[1] is the database schema name by spec (see https://www.sqlite.org/vtab.html#the_xcreate_method)
if( zDbSName == NULL ){
sqlite3_free(pVtab);
return SQLITE_NOMEM_BKPT;
}
// > Eponymous virtual tables exist in the "main" schema only, so they will not work if prefixed with a different schema name.
// so argv[1] is always equal to "main" and we can safely ignore it
// (see https://www.sqlite.org/vtab.html#epovtab)
memset(pVtab, 0, sizeof(*pVtab));
pVtab->db = db;
pVtab->zDbSName = zDbSName;
*ppVtab = (sqlite3_vtab*)pVtab;
return SQLITE_OK;
}

static int vectorVtabDisconnect(sqlite3_vtab *pVtab){
vectorVtab *pVTab = (vectorVtab*)pVtab;
sqlite3DbFree(pVTab->db, pVTab->zDbSName);
sqlite3_free(pVtab);
return SQLITE_OK;
}
Expand Down Expand Up @@ -216362,7 +216420,7 @@ static int vectorVtabFilter(
pCur->rows.aIntValues = NULL;
pCur->rows.ppValues = NULL;

if( vectorIndexSearch(pVTab->db, pVTab->zDbSName, argc, argv, &pCur->rows, &pCur->nReads, &pCur->nWrites, &pVTab->base.zErrMsg) != 0 ){
if( vectorIndexSearch(pVTab->db, argc, argv, &pCur->rows, &pCur->nReads, &pCur->nWrites, &pVTab->base.zErrMsg) != 0 ){
return SQLITE_ERROR;
}

Expand Down
Loading
Loading