diff --git a/src/backends/onnxruntime.c b/src/backends/onnxruntime.c
index 7b12d8062..a1f05567e 100644
--- a/src/backends/onnxruntime.c
+++ b/src/backends/onnxruntime.c
@@ -206,6 +206,7 @@ RAI_Tensor* RAI_TensorCreateFromOrtValue(OrtValue* v, size_t batch_offset, size_
   if (status != NULL) goto error;
 
   int64_t total_batch_size = dims[0];
+  total_batch_size = total_batch_size > 0 ? total_batch_size : 1;
 
   shape = RedisModule_Calloc(ndims, sizeof(*shape));
   strides = RedisModule_Calloc(ndims, sizeof(*strides));
diff --git a/src/backends/tensorflow.c b/src/backends/tensorflow.c
index 750206679..baf7be54b 100644
--- a/src/backends/tensorflow.c
+++ b/src/backends/tensorflow.c
@@ -89,7 +89,8 @@ RAI_Tensor* RAI_TensorCreateFromTFTensor(TF_Tensor *tensor, size_t batch_offset,
   const size_t ndims = TF_NumDims(tensor);
 
-  const int64_t total_batch_size = TF_Dim(tensor, 0);
+  int64_t total_batch_size = TF_Dim(tensor, 0);
+  total_batch_size = total_batch_size > 0 ? total_batch_size : 1;
 
   int64_t* shape = RedisModule_Calloc(ndims, sizeof(*shape));
   int64_t* strides = RedisModule_Calloc(ndims, sizeof(*strides));
diff --git a/src/dag.c b/src/dag.c
index 94fc96b3d..53c76e259 100644
--- a/src/dag.c
+++ b/src/dag.c
@@ -91,6 +91,16 @@ void *RedisAI_DagRunSession(RedisAI_RunInfo *rinfo) {
           currentOp->result = REDISMODULE_ERR;
         }
       }
+      // since we increased the input tensors' reference counts before the model run, we need to decrease them here
+      const size_t ninputs = RAI_ModelRunCtxNumInputs(currentOp->mctx);
+      for (size_t inputNumber = 0; inputNumber < ninputs; inputNumber++) {
+        RAI_Tensor *tensor =
+            RAI_ModelRunCtxInputTensor(currentOp->mctx, inputNumber);
+        if (tensor) {
+          RAI_TensorFree(tensor);
+        }
+      }
+
     } else {
       currentOp->result = REDISMODULE_ERR;
     }
@@ -195,7 +205,6 @@ int RedisAI_DagRun_Reply(RedisModuleCtx *ctx, RedisModuleString **argv,
         }
         RedisModule_CloseKey(key);
         RedisAI_ReplicateTensorSet(ctx, tensor_keyname, tensor);
-        // TODO: free Tensor
       } else {
         RedisModule_ReplyWithError(
             ctx, "ERR specified persistent key that was not used on DAG");
diff --git a/src/model.c b/src/model.c
index a9b2247fb..e8e1079cb 100644
--- a/src/model.c
+++ b/src/model.c
@@ -389,17 +389,20 @@ RAI_Tensor* RAI_ModelRunCtxOutputTensor(RAI_ModelRunCtx* mctx, size_t index) {
   return mctx->outputs[index].tensor;
 }
 
-void RAI_ModelRunCtxFree(RAI_ModelRunCtx* mctx) {
-  for (size_t i=0; i<array_len(mctx->inputs); ++i) {
-    RAI_TensorFree(mctx->inputs[i].tensor);
-  }
-  array_free(mctx->inputs);
+void RAI_ModelRunCtxFree(RAI_ModelRunCtx* mctx, int freeTensors) {
+  if (freeTensors) {
+    for (size_t i=0; i<array_len(mctx->inputs); ++i) {
+      RAI_TensorFree(mctx->inputs[i].tensor);
+    }
 
-  for (size_t i = 0 ; i < array_len(mctx->outputs) ; ++i) {
-    if (mctx->outputs[i].tensor) {
-      RAI_TensorFree(mctx->outputs[i].tensor);
+    for (size_t i = 0 ; i < array_len(mctx->outputs) ; ++i) {
+      if (mctx->outputs[i].tensor) {
+        RAI_TensorFree(mctx->outputs[i].tensor);
+      }
     }
   }
+
+  array_free(mctx->inputs);
   array_free(mctx->outputs);
 
   RAI_Error err = {0};
diff --git a/src/model.h b/src/model.h
index 90b019dc4..6df0a62aa 100644
--- a/src/model.h
+++ b/src/model.h
@@ -79,8 +79,9 @@ RAI_ModelRunCtx* RAI_ModelRunCtxCreate(RAI_Model* model);
 * work
 *
 * @param mctx
+ * @param freeTensors if non-zero, free the input and output tensors as well; otherwise leave them allocated
 */
-void RAI_ModelRunCtxFree(RAI_ModelRunCtx* mctx);
+void RAI_ModelRunCtxFree(RAI_ModelRunCtx* mctx, int freeTensors);
 
 /**
 * Allocates a RAI_ModelCtxParam data structure, and enforces a shallow copy of
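The freeTensors flag splits tensor ownership between the two kinds of call sites: a
standalone run still owns its input and output tensors, while a DAG op's tensors are
owned by the DAG-level tensor dictionary and must not be freed twice. A minimal sketch
of the intended calling convention (the surrounding caller code is illustrative; only
the RAI_* signatures come from this patch):

    /* Standalone run: the run context is the sole owner of its tensors. */
    RAI_ModelRunCtx *mctx = RAI_ModelRunCtxCreate(model);
    /* ... attach inputs and outputs, run the model, read the results ... */
    RAI_ModelRunCtxFree(mctx, 1);  /* freeTensors = 1: release the tensors too */

    /* DAG op: its tensors live on in rinfo->dagTensorsContext, whose value
     * destructor releases them, so only the context bookkeeping is torn down. */
    RAI_ModelRunCtxFree(dagOp->mctx, 0);  /* freeTensors = 0: leave tensors alone */

diff --git a/src/run_info.c b/src/run_info.c
index 42070306b..af4f5bdc6 100644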
--- a/src/run_info.c
+++ b/src/run_info.c
@@ -16,6 +16,40 @@
 #include "util/arr_rm_alloc.h"
 #include "util/dict.h"
 
+
+static uint64_t RAI_TensorDictKeyHashFunction(const void *key){
+    return AI_dictGenHashFunction(key, strlen((char*)key));
+}
+
+static int RAI_TensorDictKeyStrcmp(void *privdata, const void *key1, const void *key2){
+    const char* strKey1 = key1;
+    const char* strKey2 = key2;
+    return strcmp(strKey1, strKey2) == 0;
+}
+
+static void RAI_TensorDictKeyFree(void *privdata, void *key){
+    RedisModule_Free(key);
+}
+
+static void* RAI_TensorDictKeyDup(void *privdata, const void *key){
+    return RedisModule_Strdup((char*)key);
+}
+
+static void RAI_TensorDictValFree(void *privdata, const void *obj){
+    RAI_TensorFree((RAI_Tensor*)obj);
+}
+
+
+AI_dictType AI_dictTypeTensorVals = {
+        .hashFunction = RAI_TensorDictKeyHashFunction,
+        .keyDup = RAI_TensorDictKeyDup,
+        .valDup = NULL,
+        .keyCompare = RAI_TensorDictKeyStrcmp,
+        .keyDestructor = RAI_TensorDictKeyFree,
+        .valDestructor = RAI_TensorDictValFree,
+};
+
+
 /**
 * Allocate the memory and initialise the RAI_DagOp.
 * @param result Output parameter to capture allocated RAI_DagOp.
@@ -76,7 +110,7 @@ int RAI_InitRunInfo(RedisAI_RunInfo **result) {
     return REDISMODULE_ERR;
   }
   rinfo->use_local_context = 0;
-  rinfo->dagTensorsContext = AI_dictCreate(&AI_dictTypeHeapStrings, NULL);
+  rinfo->dagTensorsContext = AI_dictCreate(&AI_dictTypeTensorVals, NULL);
   if (!(rinfo->dagTensorsContext)) {
     return REDISMODULE_ERR;
   }
@@ -116,6 +150,13 @@ void RAI_FreeDagOp(RedisModuleCtx *ctx, RAI_DagOp *dagOp) {
     }
     array_free(dagOp->outTensors);
 
+    if (dagOp->mctx) {
+      RAI_ModelRunCtxFree(dagOp->mctx, false);
+    }
+    if (dagOp->sctx) {
+      RAI_ScriptRunCtxFree(dagOp->sctx, false);
+    }
+
     RedisModule_Free(dagOp);
   }
 }
@@ -125,37 +166,48 @@ void RAI_FreeRunInfo(RedisModuleCtx *ctx, struct RedisAI_RunInfo *rinfo) {
     return;
   }
   if (rinfo->mctx) {
-    RAI_ModelRunCtxFree(rinfo->mctx);
+    RAI_ModelRunCtxFree(rinfo->mctx, true);
   }
   if (rinfo->sctx) {
-    RAI_ScriptRunCtxFree(rinfo->sctx);
+    RAI_ScriptRunCtxFree(rinfo->sctx, true);
   }
 
   RAI_FreeError(rinfo->err);
 
   if (rinfo->dagTensorsContext) {
     AI_dictIterator *iter = AI_dictGetSafeIterator(rinfo->dagTensorsContext);
-    AI_dictEntry *stats_entry = AI_dictNext(iter);
+    AI_dictEntry *entry = AI_dictNext(iter);
     RAI_Tensor *tensor = NULL;
 
-    while (stats_entry) {
-      tensor = AI_dictGetVal(stats_entry);
-      char *key = (char *)AI_dictGetKey(stats_entry);
+    while (entry) {
+      tensor = AI_dictGetVal(entry);
+      char *key = (char *)AI_dictGetKey(entry);
 
-      if (tensor&&key!=NULL) {
+      if (tensor && key != NULL) {
         // if the key is persistent then we should not delete it
         AI_dictEntry *persistent_entry =
             AI_dictFind(rinfo->dagTensorsPersistentContext, key);
-        // if the key was loaded from the keyspace then we should not delete
-        // it
+        // if the key was loaded from the keyspace then we should not delete it
         AI_dictEntry *loaded_entry =
             AI_dictFind(rinfo->dagTensorsLoadedContext, key);
+
         if (persistent_entry == NULL && loaded_entry == NULL) {
-          RAI_TensorFree(tensor);
+          AI_dictDelete(rinfo->dagTensorsContext, key);
+        }
+
+        if (persistent_entry) {
+          AI_dictDelete(rinfo->dagTensorsPersistentContext, key);
+        }
+        if (loaded_entry) {
+          AI_dictDelete(rinfo->dagTensorsLoadedContext, key);
        }
       }
-      stats_entry = AI_dictNext(iter);
+      entry = AI_dictNext(iter);
     }
     AI_dictReleaseIterator(iter);
+
+    AI_dictRelease(rinfo->dagTensorsContext);
+    AI_dictRelease(rinfo->dagTensorsLoadedContext);
+    AI_dictRelease(rinfo->dagTensorsPersistentContext);
   }
 
   if (rinfo->dagOps) {
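With AI_dictTypeTensorVals installed, dagTensorsContext owns both its keys and its
tensor values: keys are duplicated on insert, and deleting an entry runs the key and
value destructors declared above. A short sketch of the resulting lifecycle, assuming
the AI_dict API mirrors the Redis dict API it is derived from (AI_dictAdd and
AI_dictRelease are not part of this hunk; d and tensor are illustrative locals):

    AI_dict *d = AI_dictCreate(&AI_dictTypeTensorVals, NULL);

    /* keyDup copies the string, so the dict holds its own key; the tensor
     * reference is handed over to the dict. */
    AI_dictAdd(d, "t0", tensor);

    /* Runs RAI_TensorDictKeyFree on the key copy and RAI_TensorDictValFree
     * (i.e. RAI_TensorFree) on the tensor -- no separate free is needed. */
    AI_dictDelete(d, "t0");

    /* Any remaining entries go through the same destructors on release. */
    AI_dictRelease(d);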
diff --git a/src/script.c b/src/script.c
index 0fcba7958..eada1922e 100644
--- a/src/script.c
+++ b/src/script.c
@@ -182,17 +182,20 @@ RAI_Tensor* RAI_ScriptRunCtxOutputTensor(RAI_ScriptRunCtx* sctx, size_t index) {
   return sctx->outputs[index].tensor;
 }
 
-void RAI_ScriptRunCtxFree(RAI_ScriptRunCtx* sctx) {
-  for (size_t i = 0; i < array_len(sctx->inputs); ++i) {
-    RAI_TensorFree(sctx->inputs[i].tensor);
-  }
-  array_free(sctx->inputs);
+void RAI_ScriptRunCtxFree(RAI_ScriptRunCtx* sctx, int freeTensors) {
+  if (freeTensors) {
+    for (size_t i = 0; i < array_len(sctx->inputs); ++i) {
+      RAI_TensorFree(sctx->inputs[i].tensor);
+    }
 
-  for (size_t i = 0; i < array_len(sctx->outputs); ++i) {
-    if (sctx->outputs[i].tensor) {
-      RAI_TensorFree(sctx->outputs[i].tensor);
+    for (size_t i = 0; i < array_len(sctx->outputs); ++i) {
+      if (sctx->outputs[i].tensor) {
+        RAI_TensorFree(sctx->outputs[i].tensor);
+      }
     }
   }
+
+  array_free(sctx->inputs);
   array_free(sctx->outputs);
 
   RedisModule_Free(sctx->fnname);
diff --git a/src/script.h b/src/script.h
index 17337ca30..7936f8537 100644
--- a/src/script.h
+++ b/src/script.h
@@ -103,8 +103,9 @@ RAI_Tensor* RAI_ScriptRunCtxOutputTensor(RAI_ScriptRunCtx* sctx, size_t index);
 * work
 *
 * @param sctx
+ * @param freeTensors if non-zero, free the input and output tensors as well; otherwise leave them allocated
 */
-void RAI_ScriptRunCtxFree(RAI_ScriptRunCtx* sctx);
+void RAI_ScriptRunCtxFree(RAI_ScriptRunCtx* sctx, int freeTensors);
 
 /**
 * Given the input script context, run associated script
diff --git a/test/includes.py b/test/includes.py
index 5cf52554c..237f556a6 100755
--- a/test/includes.py
+++ b/test/includes.py
@@ -16,6 +16,7 @@
 except:
     pass
 
+MAX_ITERATIONS = 2 if os.environ.get("MAX_ITERATIONS") is None else int(os.environ.get("MAX_ITERATIONS"))
 TEST_TF = os.environ.get("TEST_TF") != "0" and os.environ.get("WITH_TF") != "0"
 TEST_TFLITE = os.environ.get("TEST_TFLITE") != "0" and os.environ.get("WITH_TFLITE") != "0"
 TEST_PT = os.environ.get("TEST_PT") != "0" and os.environ.get("WITH_PT") != "0"
@@ -24,7 +25,7 @@
 DEVICE = os.environ.get('DEVICE', 'CPU').upper().encode('utf-8', 'ignore').decode('utf-8')
 VALGRIND = os.environ.get("VALGRIND") == "1"
 print(f"Running tests on {DEVICE}\n")
-
+print(f"Using a max of {MAX_ITERATIONS} iterations per test\n")
 # change this to make inference tests longer
 MAX_TRANSACTIONS=100
 
@@ -67,11 +68,59 @@
 def info_to_dict(info):
     return dict(zip(info[::2], info[1::2]))
 
-def load_mobilenet_test_data():
+def load_resnet_test_data():
+    test_data_path = os.path.join(os.path.dirname(__file__), 'test_data/imagenet')
+    labels_filename = os.path.join(test_data_path, 'imagenet_class_index.json')
+    image_filename = os.path.join(test_data_path, 'dog.jpg')
+    model_filename = os.path.join(test_data_path, 'resnet50.pb')
+    script_filename = os.path.join(test_data_path, 'data_processing_script.txt')
+
+    with open(script_filename, 'rb') as f:
+        script = f.read()
+
+    with open(model_filename, 'rb') as f:
+        model_pb = f.read()
+
+    with open(labels_filename, 'r') as f:
+        labels = json.load(f)
+
+    img_height, img_width = 224, 224
+
+    img = imread(image_filename)
+    img = resize(img, (img_height, img_width), mode='constant', anti_aliasing=True)
+    img = img.astype(np.uint8)
+
+    return model_pb, script, labels, img
+
+def load_mobilenet_v1_test_data():
+    test_data_path = os.path.join(os.path.dirname(__file__), 'test_data')
+    labels_filename = os.path.join(test_data_path, 'imagenet_class_index.json')
+    image_filename = os.path.join(test_data_path, 'panda.jpg')
+    model_filename = os.path.join(test_data_path,
+                                  'mobilenet/mobilenet_v1_100_224_cpu_NxHxWxC.pb')
+    input_var = 'input'
+    output_var = 'MobilenetV1/Predictions/Reshape_1'
+
+    with open(model_filename, 'rb') as f:
+        model_pb = f.read()
+
+    with open(labels_filename, 'r') as f:
+        labels = json.load(f)
+
+    img_height, img_width = 224, 224
+
+    img = imread(image_filename)
+    img = resize(img, (img_height, img_width), mode='constant', anti_aliasing=True)
+    img = img.astype(np.float32)
+
+    return model_pb, input_var, output_var, labels, img
+
+def load_mobilenet_v2_test_data():
     test_data_path = os.path.join(os.path.dirname(__file__), 'test_data')
     labels_filename = os.path.join(test_data_path, 'imagenet_class_index.json')
     image_filename = os.path.join(test_data_path, 'panda.jpg')
-    model_filename = os.path.join(test_data_path, 'mobilenet_v2_1.4_224_frozen.pb')
+    model_filename = os.path.join(test_data_path, 'mobilenet/mobilenet_v2_1.4_224_frozen.pb')
+    input_var = 'input'
+    output_var = 'MobilenetV2/Predictions/Reshape_1'
 
     with open(model_filename, 'rb') as f:
         model_pb = f.read()
@@ -85,7 +134,7 @@
     img = resize(img, (img_height, img_width), mode='constant', anti_aliasing=True)
     img = img.astype(np.float32)
 
-    return model_pb, labels, img
+    return model_pb, input_var, output_var, labels, img
 
 def load_creditcardfraud_data(env,max_tensors=10000):
     test_data_path = os.path.join(os.path.dirname(__file__), 'test_data')
diff --git a/test/test_data/batchdim_mismatch.onnx b/test/test_data/batchdim_mismatch.onnx
new file mode 100644
index 000000000..49a4d5be8
--- /dev/null
+++ b/test/test_data/batchdim_mismatch.onnx
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:7be546e4ea8636bd79c26e4ff06e28a70a4f8082f8b9c28fbcae6adb0e46e24d
+size 143
diff --git a/test/test_data/batchdim_mismatch.pt b/test/test_data/batchdim_mismatch.pt
new file mode 100644
index 000000000..efe30d697
--- /dev/null
+++ b/test/test_data/batchdim_mismatch.pt
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:d58ef7dc37681521fa8c8856f4c256b614b535efdc33e838836fee3bde9106da
+size 1317
diff --git a/test/test_data/mobilenet/mobilenet_v1_100_224_cpu_NxHxWxC.pb b/test/test_data/mobilenet/mobilenet_v1_100_224_cpu_NxHxWxC.pb
new file mode 100644
index 000000000..30ba12b3c
--- /dev/null
+++ b/test/test_data/mobilenet/mobilenet_v1_100_224_cpu_NxHxWxC.pb
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:bbb2752038ff1749d2b55988bb5f6e999a799c19413a0691b82d29f7aec0bab3
+size 17198345
diff --git a/test/test_data/mobilenet/mobilenet_v1_100_224_gpu_NxHxWxC.pb b/test/test_data/mobilenet/mobilenet_v1_100_224_gpu_NxHxWxC.pb
new file mode 100644
index 000000000..2e8871769
--- /dev/null
+++ b/test/test_data/mobilenet/mobilenet_v1_100_224_gpu_NxHxWxC.pb
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:f1fe206dfd3cff261cf403b5757abec886da445a80056e55310ddac0b2805a3b
+size 17198345
diff --git a/test/test_data/mobilenet/mobilenet_v1_100_224_gpu_NxHxWxC_fp16_trt.pb b/test/test_data/mobilenet/mobilenet_v1_100_224_gpu_NxHxWxC_fp16_trt.pb
new file mode 100644
index 000000000..197733e00
--- /dev/null
+++ b/test/test_data/mobilenet/mobilenet_v1_100_224_gpu_NxHxWxC_fp16_trt.pb
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:fd925f4b59d8d5035ccb2ecdfbf9b0f47a5ba3acfa81bd5a18536f69021df74a
+size 34277746
diff --git a/test/test_data/mobilenet/mobilenet_v2_1.4_224_frozen.pb b/test/test_data/mobilenet/mobilenet_v2_1.4_224_frozen.pb
new file mode 100644
index 000000000..41e3481fd
--- /dev/null
+++ b/test/test_data/mobilenet/mobilenet_v2_1.4_224_frozen.pb
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:111479258f3841c93d0a7a377c976c24e8281077818991931429d2277dd88590
+size 24508794
diff --git a/test/test_data/mobilenet/model_saver.py b/test/test_data/mobilenet/model_saver.py
new file mode 100644
index 000000000..4ba9da4b7
--- /dev/null
+++ b/test/test_data/mobilenet/model_saver.py
@@ -0,0 +1,49 @@
+import tensorflow as tf
+import tensorflow_hub as hub
+import ml2rt
+import argparse
+import sys
+
+url = 'https://tfhub.dev/google/imagenet/mobilenet_v1_100_224/quantops/classification/3'
+model_name = 'mobilenet_v1_100_224'
+module = hub.Module(url)
+batch_size = 1
+number_channels = 3
+height, width = hub.get_expected_image_size(module)
+input_var = 'input'
+output_var = 'MobilenetV1/Predictions/Reshape_1'
+
+parser = argparse.ArgumentParser()
+parser.add_argument('--gpu', action="store_true", default=False)
+parser.add_argument('--input-shape', default="NxHxWxC", type=str)
+args = parser.parse_args()
+device = 'gpu' if args.gpu else 'cpu'
+
+gpu_available = tf.test.is_gpu_available(
+    cuda_only=True, min_cuda_compute_capability=None
+)
+
+if gpu_available is False and args.gpu:
+    print("No CUDA GPUs found. Exiting...")
+    sys.exit(1)
+
+var_converter = tf.compat.v1.graph_util.convert_variables_to_constants
+
+if args.input_shape == "NxHxWxC":
+    print("Saving N x H x W x C (1, 224, 224, 3) (with channels_last data format)")
+    images = tf.compat.v1.placeholder(tf.float32, shape=(
+        batch_size, height, width, number_channels), name=input_var)
+elif args.input_shape == "NxCxHxW":
+    print("Saving N x C x H x W (1, 3, 224, 224)")
+    images = tf.compat.v1.placeholder(tf.float32, shape=(
+        batch_size, number_channels, height, width), name=input_var)
+else:
+    print("inputs shape is either NxHxWxC or NxCxHxW. Exiting...")
Exiting...") + sys.exit(1) + +logits = module(images) +logits = tf.identity(logits, output_var) +with tf.compat.v1.Session() as sess: + sess.run([tf.compat.v1.global_variables_initializer()]) + ml2rt.save_tensorflow(sess, '{model_name}_{device}_{input_shape}.pb'.format( + model_name=model_name, device=device, input_shape=args.input_shape), output=[output_var]) diff --git a/test/test_data/onnx_batchdim_mismatch.py b/test/test_data/onnx_batchdim_mismatch.py new file mode 100644 index 000000000..fea913578 --- /dev/null +++ b/test/test_data/onnx_batchdim_mismatch.py @@ -0,0 +1,17 @@ +import torch + +class MyModule(torch.nn.Module): + def __init__(self): + super(MyModule, self).__init__() + + def forward(self, a, b): + return a, torch.tensor([]) + + +my_module = MyModule() + +dummy = (torch.rand(2), torch.rand(2)) +torch.onnx.export(my_module, dummy, "batchdim_mismatch.onnx") + +my_module_traced = torch.jit.trace(my_module, dummy) +torch.jit.save(my_module_traced, "batchdim_mismatch.pt") diff --git a/test/tests_onnx.py b/test/tests_onnx.py index 291c0f696..fff7f9ebb 100644 --- a/test/tests_onnx.py +++ b/test/tests_onnx.py @@ -152,6 +152,27 @@ def test_onnx_modelrun_mnist(env): env.assertEqual(values2, values) +def test_onnx_modelrun_batchdim_mismatch(env): + con = env.getConnection() + + test_data_path = os.path.join(os.path.dirname(__file__), 'test_data') + model_filename = os.path.join(test_data_path, 'batchdim_mismatch.onnx') + + with open(model_filename, 'rb') as f: + model_pb = f.read() + + ret = con.execute_command('AI.MODELSET', 'm', 'ONNX', DEVICE, 'BLOB', model_pb) + env.assertEqual(ret, b'OK') + + ensureSlaveSynced(con, env) + + con.execute_command('AI.TENSORSET', 'a', 'FLOAT', 2, 'VALUES', 1, 1) + con.execute_command('AI.TENSORSET', 'b', 'FLOAT', 2, 'VALUES', 1, 1) + + con.execute_command('AI.MODELRUN', 'm', 'INPUTS', 'a', 'b', 'OUTPUTS', 'c', 'd') + + + def test_onnx_modelrun_mnist_autobatch(env): if not TEST_ONNX: return diff --git a/test/tests_pytorch.py b/test/tests_pytorch.py index 638be7bf8..44206dada 100644 --- a/test/tests_pytorch.py +++ b/test/tests_pytorch.py @@ -167,6 +167,26 @@ def test_pytorch_modelrun(env): env.assertEqual(values2, values) +def test_pytorch_modelrun_batchdim_mismatch(env): + con = env.getConnection() + + test_data_path = os.path.join(os.path.dirname(__file__), 'test_data') + model_filename = os.path.join(test_data_path, 'batchdim_mismatch.pt') + + with open(model_filename, 'rb') as f: + model_pb = f.read() + + ret = con.execute_command('AI.MODELSET', 'm', 'TORCH', DEVICE, 'BLOB', model_pb) + env.assertEqual(ret, b'OK') + + ensureSlaveSynced(con, env) + + con.execute_command('AI.TENSORSET', 'a', 'FLOAT', 2, 'VALUES', 1, 1) + con.execute_command('AI.TENSORSET', 'b', 'FLOAT', 2, 'VALUES', 1, 1) + + con.execute_command('AI.MODELRUN', 'm', 'INPUTS', 'a', 'b', 'OUTPUTS', 'c', 'd') + + def test_pytorch_modelrun_autobatch(env): if not TEST_PT: return diff --git a/test/tests_sanitizer.py b/test/tests_sanitizer.py new file mode 100644 index 000000000..4f0c45ca4 --- /dev/null +++ b/test/tests_sanitizer.py @@ -0,0 +1,87 @@ +import redis +from functools import wraps +import multiprocessing as mp +from includes import * + +''' +python -m RLTest --test tests_sanitizer.py --module path/to/redisai.so +''' + + +def test_sanitizer_dagrun_mobilenet_v1(env): + if (not TEST_TF or not TEST_PT): + return + con = env.getConnection() + mem_allocator = con.execute_command('info', 'memory')['mem_allocator'] + if 'jemalloc' in mem_allocator: + print("exiting sanitizer test given 
we're not using stdlib allocator") + return + + model_name = 'mobilenet_v1' + model_pb, input_var, output_var, labels, img = load_mobilenet_v1_test_data() + + ret = con.execute_command('AI.MODELSET', model_name, 'TF', DEVICE, + 'INPUTS', input_var, + 'OUTPUTS', output_var, + 'BLOB', model_pb) + env.assertEqual(ret, b'OK') + + for opnumber in range(1, MAX_ITERATIONS): + image_key = 'image{}'.format(opnumber) + class_key = 'output' + + ret = con.execute_command( + 'AI.DAGRUN', '|>', + 'AI.TENSORSET', image_key, 'FLOAT', 1, 224, 224, 3, 'BLOB', img.tobytes(), + '|>', + 'AI.MODELRUN', model_name, + 'INPUTS', image_key, + 'OUTPUTS', class_key, + '|>', + 'AI.TENSORGET', class_key, 'blob' + ) + env.assertEqual([b'OK', b'OK'], ret[:2]) + env.assertEqual(1001.0, len(ret[2])/4) + + +def test_sanitizer_modelrun_mobilenet_v1(env): + if (not TEST_TF or not TEST_PT): + return + con = env.getConnection() + mem_allocator = con.execute_command('info', 'memory')['mem_allocator'] + if 'jemalloc' in mem_allocator: + print("exiting sanitizer test given we're not using stdlib allocator") + return + + model_name = 'mobilenet_v1' + model_pb, input_var, output_var, labels, img = load_mobilenet_v1_test_data() + + ret = con.execute_command('AI.MODELSET', model_name, 'TF', DEVICE, + 'INPUTS', input_var, + 'OUTPUTS', output_var, + 'BLOB', model_pb) + env.assertEqual(ret, b'OK') + + for opnumber in range(1, MAX_ITERATIONS): + image_key = 'image' + temp_key1 = 'temp_key1' + temp_key2 = 'temp_key2' + class_key = 'output' + ret = con.execute_command( + 'AI.TENSORSET', image_key, 'FLOAT', 1, 224, 224, 3, 'BLOB', img.tobytes() + ) + env.assertEqual(b'OK', ret) + + ret = con.execute_command( + 'AI.MODELRUN', model_name, + 'INPUTS', image_key, + 'OUTPUTS', class_key + ) + + env.assertEqual(b'OK', ret) + + ret = con.execute_command( + 'AI.TENSORGET', class_key, 'blob' + ) + + env.assertEqual(1001.0, len(ret)/4) diff --git a/test/tests_tensorflow.py b/test/tests_tensorflow.py index fda1ab59c..b494b1f07 100644 --- a/test/tests_tensorflow.py +++ b/test/tests_tensorflow.py @@ -24,10 +24,7 @@ def wrapper(env, *args, **kwargs): def test_run_mobilenet(env): con = env.getConnection() - input_var = 'input' - output_var = 'MobilenetV2/Predictions/Reshape_1' - - model_pb, labels, img = load_mobilenet_test_data() + model_pb, input_var, output_var, labels, img = load_mobilenet_v2_test_data() con.execute_command('AI.MODELSET', 'mobilenet', 'TF', DEVICE, 'INPUTS', input_var, 'OUTPUTS', output_var, 'BLOB', model_pb) @@ -94,10 +91,7 @@ def test_run_mobilenet_multiproc(env): con = env.getConnection() - input_var = 'input' - output_var = 'MobilenetV2/Predictions/Reshape_1' - - model_pb, labels, img = load_mobilenet_test_data() + model_pb, input_var, output_var, labels, img = load_mobilenet_v2_test_data() con.execute_command('AI.MODELSET', 'mobilenet', 'TF', DEVICE, 'INPUTS', input_var, 'OUTPUTS', output_var, 'BLOB', model_pb) ensureSlaveSynced(con, env) @@ -627,15 +621,12 @@ def test_tensorflow_modelrun_with_batch_and_minbatch(env): minbatch_size = 2 model_name = 'model' another_model_name = 'another_model' - inputvar = 'input' - outputvar = 'MobilenetV2/Predictions/Reshape_1' - - model_pb, labels, img = load_mobilenet_test_data() + model_pb, input_var, output_var, labels, img = load_mobilenet_v2_test_data() con.execute_command('AI.MODELSET', model_name, 'TF', DEVICE, 'BATCHSIZE', batch_size, 'MINBATCHSIZE', minbatch_size, - 'INPUTS', inputvar, - 'OUTPUTS', outputvar, + 'INPUTS', input_var, + 'OUTPUTS', output_var, 'BLOB', model_pb) 
 
     con.execute_command('AI.TENSORSET', 'input', 'FLOAT', 1, img.shape[1], img.shape[0], img.shape[2],
@@ -658,8 +649,8 @@ def run(name=model_name, output_name='output'):
 
     con.execute_command('AI.MODELSET', another_model_name, 'TF', DEVICE,
                         'BATCHSIZE', batch_size, 'MINBATCHSIZE', minbatch_size,
-                        'INPUTS', inputvar,
-                        'OUTPUTS', outputvar,
+                        'INPUTS', input_var,
+                        'OUTPUTS', output_var,
                         'BLOB', model_pb)
 
     p1b = mp.Process(target=run, args=(another_model_name, 'final1'))
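The clamp added in both backends (total_batch_size falling back to 1) is what the
batchdim_mismatch tests exercise: a model output such as torch.tensor([]) reports 0 in
its first dimension, and taking that value as the batch size would otherwise divide by
zero when computing per-sample sizes. A reduced sketch of the guarded pattern (the
sample_bytesize line illustrates how the value is consumed downstream and is not a
verbatim quote of either backend):

    int64_t total_batch_size = dims[0];               /* 0 for an empty output */
    total_batch_size = total_batch_size > 0 ? total_batch_size : 1;

    /* Safe now: an empty output degenerates to one zero-length sample
     * instead of a division by zero. */
    size_t sample_bytesize = total_bytesize / total_batch_size;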