Skip to content

Commit 31f9fc0

Browse files
committed
Make cython compatible with python3 (apache#12)
1 parent bed950d commit 31f9fc0

File tree

3 files changed

+51
-14
lines changed

3 files changed

+51
-14
lines changed

nnvm/Makefile

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ export CFLAGS = -std=c++11 -Wall -O3 -msse2 -Wno-unknown-pragmas -funroll-loop
33
-Iinclude -Idmlc-core/include -fPIC
44

55
# specify tensor path
6-
.PHONY: clean all test lint doc cython cython3
6+
.PHONY: clean all test lint doc cython cython3 cyclean
77

88
all: lib/libnnvm.so lib/libnnvm.a cli_test
99

@@ -37,6 +37,8 @@ cython:
3737
cython3:
3838
cd python; python3 setup.py build_ext --inplace
3939

40+
cyclean:
41+
rm -rf python/nnvm/*/*.so python/nnvm/*/*.cpp
4042

4143
lint:
4244
python2 dmlc-core/scripts/lint.py nnvm cpp include src

nnvm/python/nnvm/cython/base.pyi

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,21 @@ cdef py_str(const char* x):
99
return x.decode("utf-8")
1010

1111

12+
cdef c_str(pystr):
13+
"""Create ctypes char * from a python string
14+
Parameters
15+
----------
16+
string : string type
17+
python string
18+
19+
Returns
20+
-------
21+
str : c_char_p
22+
A char pointer that can be passed to C API
23+
"""
24+
return pystr.encode("utf-8")
25+
26+
1227
cdef CALL(int ret):
1328
if ret != 0:
1429
raise NNVMError(NNGetLastError())
@@ -20,6 +35,13 @@ cdef const char** CBeginPtr(vector[const char*]& vec):
2035
else:
2136
return NULL
2237

38+
cdef vector[const char*] SVec2Ptr(vector[string]& vec):
39+
cdef vector[const char*] svec
40+
svec.resize(vec.size())
41+
for i in range(vec.size()):
42+
svec[i] = vec[i].c_str()
43+
return svec
44+
2345

2446
cdef BuildDoc(nn_uint num_args,
2547
const char** arg_names,

nnvm/python/nnvm/cython/symbol.pyx

Lines changed: 26 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ from .._base import NNVMError
66
from ..name import NameManager
77
from ..attribute import AttrScope
88
from libcpp.vector cimport vector
9+
from libcpp.string cimport string
910
from cpython.version cimport PY_MAJOR_VERSION
1011

1112
include "./base.pyi"
@@ -110,7 +111,7 @@ cdef class Symbol:
110111
CALL(NNSymbolGetOutput(self.handle, c_index, &handle))
111112
return NewSymbol(handle)
112113

113-
def attr(self, const char* key):
114+
def attr(self, key):
114115
"""Get attribute string from the symbol, this function only works for non-grouped symbol.
115116
116117
Parameters
@@ -125,6 +126,8 @@ cdef class Symbol:
125126
"""
126127
cdef const char* ret
127128
cdef int success
129+
key = c_str(key)
130+
128131
CALL(NNSymbolGetAttr(
129132
self.handle, key, &ret, &success))
130133
if success != 0:
@@ -203,16 +206,19 @@ cdef class Symbol:
203206
def debug_str(self):
204207
cdef const char* out_str
205208
CALL(NNSymbolPrint(self.handle, &out_str))
206-
return str(out_str)
209+
return py_str(out_str)
207210

208211

209212
cdef SymbolSetAttr(SymbolHandle handle, dict kwargs):
210-
cdef vector[const char*] param_keys
211-
cdef vector[const char*] param_vals
213+
cdef vector[string] sparam_keys
214+
cdef vector[string] sparam_vals
212215
cdef nn_uint num_args
213216
for k, v in kwargs.items():
214-
param_keys.push_back(k)
215-
param_vals.push_back(str(v))
217+
sparam_keys.push_back(c_str(k))
218+
sparam_vals.push_back(c_str(str(v)))
219+
# keep strings in vector
220+
cdef vector[const char*] param_keys = SVec2Ptr(sparam_keys)
221+
cdef vector[const char*] param_vals = SVec2Ptr(sparam_vals)
216222
num_args = param_keys.size()
217223
CALL(NNSymbolSetAttrs(
218224
handle, num_args, CBeginPtr(param_keys), CBeginPtr(param_vals)))
@@ -225,7 +231,7 @@ cdef NewSymbol(SymbolHandle handle):
225231
return sym
226232

227233

228-
def Variable(const char* name, **kwargs):
234+
def Variable(name, **kwargs):
229235
"""Create a symbolic variable with specified name.
230236
231237
Parameters
@@ -241,6 +247,7 @@ def Variable(const char* name, **kwargs):
241247
The created variable symbol.
242248
"""
243249
cdef SymbolHandle handle
250+
name = c_str(name)
244251
CALL(NNSymbolCreateVariable(name, &handle))
245252
return NewSymbol(handle)
246253

@@ -274,10 +281,10 @@ cdef _make_atomic_symbol_function(AtomicSymbolCreator handle):
274281
func_hint = func_name.lower()
275282

276283
def creator(*args, **kwargs):
277-
cdef vector[const char*] param_keys
278-
cdef vector[const char*] param_vals
284+
cdef vector[string] sparam_keys
285+
cdef vector[string] sparam_vals
279286
cdef vector[SymbolHandle] symbol_args
280-
cdef vector[const char*] symbol_keys
287+
cdef vector[string] ssymbol_keys
281288
cdef SymbolHandle ret_handle
282289

283290
name = kwargs.pop("name", None)
@@ -286,11 +293,11 @@ cdef _make_atomic_symbol_function(AtomicSymbolCreator handle):
286293
if len(kwargs) != 0:
287294
for k, v in kwargs.items():
288295
if isinstance(v, Symbol):
289-
symbol_keys.push_back(k)
296+
ssymbol_keys.push_back(c_str(k))
290297
symbol_args.push_back((<Symbol>v).handle)
291298
else:
292-
param_keys.push_back(k)
293-
param_vals.push_back(str(v))
299+
sparam_keys.push_back(c_str(k))
300+
sparam_vals.push_back(c_str(str(v)))
294301

295302
if len(args) != 0:
296303
if symbol_args.size() != 0:
@@ -301,6 +308,10 @@ cdef _make_atomic_symbol_function(AtomicSymbolCreator handle):
301308
raise TypeError('Compose expect `Symbol` as arguments')
302309
symbol_args.push_back((<Symbol>v).handle)
303310

311+
cdef vector[const char*] param_keys = SVec2Ptr(sparam_keys)
312+
cdef vector[const char*] param_vals = SVec2Ptr(sparam_vals)
313+
cdef vector[const char*] symbol_keys = SVec2Ptr(ssymbol_keys)
314+
304315
CALL(NNSymbolCreateAtomicSymbol(
305316
handle,
306317
<nn_uint>param_keys.size(),
@@ -315,7 +326,9 @@ cdef _make_atomic_symbol_function(AtomicSymbolCreator handle):
315326
name = NameManager.current.get(name, func_hint)
316327

317328
cdef const char* c_name = NULL
329+
318330
if name:
331+
name = c_str(name)
319332
c_name = name
320333

321334
CALL(NNSymbolCompose(

0 commit comments

Comments
 (0)