Skip to content

Commit

Permalink
Fix ByteArray typemap for Python. (#47)
Browse files Browse the repository at this point in the history
  • Loading branch information
AlphardWang authored and MaggieQi committed May 27, 2019
1 parent 4d15203 commit a00b58b
Show file tree
Hide file tree
Showing 2 changed files with 58 additions and 10 deletions.
50 changes: 48 additions & 2 deletions Wrappers/inc/PythonCommon.i
Original file line number Diff line number Diff line change
Expand Up @@ -63,9 +63,55 @@
}
%}

%typemap(in) ByteArray

%{
struct PyBufferHolder
{
PyBufferHolder() : shouldRelease(false) { }

~PyBufferHolder()
{
if (shouldRelease)
{
PyBuffer_Release(&buff);
}
}

Py_buffer buff;

bool shouldRelease;
};
%}

%typemap(in) ByteArray (PyBufferHolder bufferHolder)
%{
$1 = SPTAG::ByteArray((std::uint8_t*)PyBytes_AsString($input), PyBytes_Size($input), false);
if (PyBytes_Check($input))
{
$1 = SPTAG::ByteArray((std::uint8_t*)PyBytes_AsString($input), PyBytes_Size($input), false);
}
else if (PyObject_CheckBuffer($input))
{
if (PyObject_GetBuffer($input, &bufferHolder.buff, PyBUF_SIMPLE | PyBUF_C_CONTIGUOUS) == -1)
{
PyErr_SetString(PyExc_ValueError, "Failed get buffer.");
return NULL;
}

bufferHolder.shouldRelease = true;
$1 = SPTAG::ByteArray((std::uint8_t*)bufferHolder.buff.buf, bufferHolder.buff.len, false);
}
#if (PY_VERSION_HEX >= 0x03030000)
else if (PyUnicode_Check($input))
{
$1 = SPTAG::ByteArray((std::uint8_t*)PyUnicode_DATA($input), PyUnicode_GET_LENGTH($input), false);
}
#endif

if (nullptr == $1.Data())
{
PyErr_SetString(PyExc_ValueError, "Expected Bytes, Data Structure with Buffer Protocol, or Unicode String after Python 3.3 .");
return NULL;
}
%}

#endif
18 changes: 10 additions & 8 deletions docs/GettingStart.md
Original file line number Diff line number Diff line change
Expand Up @@ -124,28 +124,28 @@ def testBuild(algo, distmethod, x, out):
i = SPTAG.AnnIndex(algo, 'Float', x.shape[1])
i.SetBuildParam("NumberOfThreads", '4')
i.SetBuildParam("DistCalcMethod", distmethod)
ret = i.Build(x.tobytes(), x.shape[0])
ret = i.Build(x, x.shape[0])
i.Save(out)

def testBuildWithMetaData(algo, distmethod, x, s, out):
i = SPTAG.AnnIndex(algo, 'Float', x.shape[1])
i.SetBuildParam("NumberOfThreads", '4')
i.SetBuildParam("DistCalcMethod", distmethod)
if i.BuildWithMetaData(x.tobytes(), s, x.shape[0]):
if i.BuildWithMetaData(x, s, x.shape[0]):
i.Save(out)

def testSearch(index, q, k):
j = SPTAG.AnnIndex.Load(index)
for t in range(q.shape[0]):
result = j.Search(q[t].tobytes(), k)
result = j.Search(q[t], k)
print (result[0]) # ids
print (result[1]) # distances

def testSearchWithMetaData(index, q, k):
j = SPTAG.AnnIndex.Load(index)
j.SetSearchParam("MaxCheck", '1024')
for t in range(q.shape[0]):
result = j.SearchWithMetaData(q[t].tobytes(), k)
result = j.SearchWithMetaData(q[t], k)
print (result[0]) # ids
print (result[1]) # distances
print (result[2]) # metadata
Expand All @@ -157,7 +157,7 @@ def testAdd(index, x, out, algo, distmethod):
i = SPTAG.AnnIndex(algo, 'Float', x.shape[1])
i.SetBuildParam("NumberOfThreads", '4')
i.SetBuildParam("DistCalcMethod", distmethod)
if i.Add(x.tobytes(), x.shape[0]):
if i.Add(x, x.shape[0]):
i.Save(out)

def testAddWithMetaData(index, x, s, out, algo, distmethod):
Expand All @@ -168,12 +168,12 @@ def testAddWithMetaData(index, x, s, out, algo, distmethod):
i = SPTAG.AnnIndex(algo, 'Float', x.shape[1])
i.SetBuildParam("NumberOfThreads", '4')
i.SetBuildParam("DistCalcMethod", distmethod)
if i.AddWithMetaData(x.tobytes(), s, x.shape[0]):
if i.AddWithMetaData(x, s, x.shape[0]):
i.Save(out)

def testDelete(index, x, out):
i = SPTAG.AnnIndex.Load(index)
ret = i.Delete(x.tobytes(), x.shape[0])
ret = i.Delete(x, x.shape[0])
print (ret)
i.Save(out)

Expand All @@ -184,6 +184,8 @@ def Test(algo, distmethod):
for i in range(n):
m += str(i) + '\n'

m = m.encode()

print ("Build.............................")
testBuild(algo, distmethod, x, 'testindices')
testSearch('testindices', q, k)
Expand Down Expand Up @@ -221,7 +223,7 @@ def testSPTAGClient():

q = np.ones((10, 10), dtype=np.float32)
for t in range(q.shape[0]):
result = index.Search(q[t].tobytes(), 6, 'Float', False)
result = index.Search(q[t], 6, 'Float', False)
print (result[0])
print (result[1])

Expand Down

0 comments on commit a00b58b

Please sign in to comment.