diff --git a/get_deps.sh b/get_deps.sh index 89ea86369..727f3a5c9 100755 --- a/get_deps.sh +++ b/get_deps.sh @@ -253,7 +253,7 @@ if [[ $WITH_PT != 0 ]]; then echo "Done." else - echo "librotch is in place." + echo "libtorch is in place." fi else echo "SKipping libtorch." diff --git a/src/dag.c b/src/dag.c index c818560d1..d47fc1965 100644 --- a/src/dag.c +++ b/src/dag.c @@ -158,7 +158,7 @@ void RedisAI_DagRunSession_ModelRun_Step(RedisAI_RunInfo *rinfo, RAI_DagOp *curr RAI_Tensor *tensor = RAI_ModelRunCtxOutputTensor(currentOp->mctx, outputNumber); const char *key_string = RedisModule_StringPtrLen( currentOp->outkeys[outputNumber], NULL); - AI_dictReplace(rinfo->dagTensorsContext, (void*)key_string, tensor); + AI_dictReplace(rinfo->dagTensorsContext, (void*)key_string, tensor ? RAI_TensorGetShallowCopy(tensor) : NULL); } currentOp->result = result; @@ -220,7 +220,7 @@ void RedisAI_DagRunSession_ScriptRun_Step(RedisAI_RunInfo *rinfo, RAI_DagOp *cur RAI_ScriptRunCtxOutputTensor(currentOp->sctx, outputNumber); const char *key_string = RedisModule_StringPtrLen( currentOp->outkeys[outputNumber], NULL); - AI_dictReplace(rinfo->dagTensorsContext, (void*)key_string, tensor); + AI_dictReplace(rinfo->dagTensorsContext, (void*)key_string, tensor ? RAI_TensorGetShallowCopy(tensor) : NULL); } currentOp->result = result; @@ -434,7 +434,7 @@ int RedisAI_DagRun_Reply(RedisModuleCtx *ctx, RedisModuleString **argv, AI_dictEntry *tensor_entry = AI_dictFind(rinfo->dagTensorsContext, persist_key_name); if (tensor_entry) { - RAI_Tensor *tensor = AI_dictGetVal(tensor_entry); + RAI_Tensor *tensor = RAI_TensorGetShallowCopy(AI_dictGetVal(tensor_entry)); RedisModuleKey *key; char *demangled_key_name = RedisModule_Strdup(persist_key_name); demangled_key_name[strlen(persist_key_name) - 4] = 0; @@ -444,11 +444,13 @@ int RedisAI_DagRun_Reply(RedisModuleCtx *ctx, RedisModuleString **argv, ctx, tensor_keyname, &key, REDISMODULE_READ | REDISMODULE_WRITE); RedisModule_Free(demangled_key_name); if (status == REDISMODULE_ERR) { + RAI_TensorFree(tensor); RedisModule_ReplyWithError(ctx, "ERR could not save tensor"); rinfo->dagReplyLength++; } else { if (RedisModule_ModuleTypeSetValue(key, RedisAI_TensorType, tensor) != REDISMODULE_OK) { + RAI_TensorFree(tensor); RedisModule_ReplyWithError(ctx, "ERR could not save tensor"); rinfo->dagReplyLength++; } @@ -532,7 +534,7 @@ int RAI_parseDAGLoadArgs(RedisModuleCtx *ctx, RedisModuleString **argv, RedisModule_CloseKey(key); char *dictKey = (char*) RedisModule_Alloc((strlen(arg_string) + 5)*sizeof(char)); sprintf(dictKey, "%s%04d", arg_string, 1); - AI_dictAdd(*localContextDict, (void*)dictKey, (void *)t); + AI_dictAdd(*localContextDict, (void*)dictKey, (void *)RAI_TensorGetShallowCopy(t)); AI_dictAdd(*loadedContextDict, (void*)dictKey, (void *)1); RedisModule_Free(dictKey); number_loaded_keys++; diff --git a/src/run_info.c b/src/run_info.c index b1d44d1c4..3ea893f02 100644 --- a/src/run_info.c +++ b/src/run_info.c @@ -246,36 +246,9 @@ void RAI_FreeRunInfo(RedisModuleCtx *ctx, struct RedisAI_RunInfo *rinfo) { } if (rinfo->dagTensorsContext) { - AI_dictIterator *iter = AI_dictGetSafeIterator(rinfo->dagTensorsContext); - AI_dictEntry *entry = AI_dictNext(iter); - RAI_Tensor *tensor = NULL; - - while (entry) { - tensor = AI_dictGetVal(entry); - char *key = (char *)AI_dictGetKey(entry); - - if (tensor && key != NULL) { - // if the key is persisted then we should not delete it - AI_dictEntry *persisted_entry = - AI_dictFind(rinfo->dagTensorsPersistedContext, key); - // if the key was loaded from the keyspace then we should not delete it - AI_dictEntry *loaded_entry = - AI_dictFind(rinfo->dagTensorsLoadedContext, key); - - if (persisted_entry == NULL && loaded_entry == NULL) { - AI_dictDelete(rinfo->dagTensorsContext, key); - } - - if (persisted_entry) { - AI_dictDelete(rinfo->dagTensorsPersistedContext, key); - } - if (loaded_entry) { - AI_dictDelete(rinfo->dagTensorsLoadedContext, key); - } - } - entry = AI_dictNext(iter); - } - AI_dictReleaseIterator(iter); + AI_dictRelease(rinfo->dagTensorsContext); + AI_dictRelease(rinfo->dagTensorsPersistedContext); + AI_dictRelease(rinfo->dagTensorsLoadedContext); } if (rinfo->dagOps) {