diff --git a/common/dbinterface.cpp b/common/dbinterface.cpp index 53e6ccd31..a088960e9 100644 --- a/common/dbinterface.cpp +++ b/common/dbinterface.cpp @@ -55,7 +55,7 @@ bool DBInterface::exists(const string& dbName, const std::string& key) return m_redisClient.at(dbName).exists(key); } -std::string DBInterface::get(const std::string& dbName, const std::string& hash, const std::string& key, bool blocking) +std::shared_ptr DBInterface::get(const std::string& dbName, const std::string& hash, const std::string& key, bool blocking) { auto innerfunc = [&] { @@ -67,9 +67,9 @@ std::string DBInterface::get(const std::string& dbName, const std::string& hash, throw UnavailableDataError(message, hash); } const std::string& value = *pvalue; - return value == "None" ? "" : value; + return value == "None" ? std::shared_ptr() : std::make_shared(value); }; - return blockable(innerfunc, dbName, blocking); + return blockable>(innerfunc, dbName, blocking); } bool DBInterface::hexists(const std::string& dbName, const std::string& hash, const std::string& key) diff --git a/common/dbinterface.h b/common/dbinterface.h index ccf114a07..a1fcf2a2b 100644 --- a/common/dbinterface.h +++ b/common/dbinterface.h @@ -3,6 +3,7 @@ #include #include #include +#include #include "dbconnector.h" #include "logger.h" @@ -37,7 +38,7 @@ class DBInterface // Delete all keys which match %pattern from DB void delete_all_by_pattern(const std::string& dbName, const std::string& pattern); bool exists(const std::string& dbName, const std::string& key); - std::string get(const std::string& dbName, const std::string& hash, const std::string& key, bool blocking = false); + std::shared_ptr get(const std::string& dbName, const std::string& hash, const std::string& key, bool blocking = false); bool hexists(const std::string& dbName, const std::string& hash, const std::string& key); std::map get_all(const std::string& dbName, const std::string& hash, bool blocking = false); std::vector keys(const std::string& dbName, const char *pattern = "*", bool blocking = false); diff --git a/common/sonicv2connector.cpp b/common/sonicv2connector.cpp index fcce5c2d4..59552cde2 100644 --- a/common/sonicv2connector.cpp +++ b/common/sonicv2connector.cpp @@ -74,7 +74,7 @@ std::pair> SonicV2Connector_Native::scan(const std return m_dbintf.scan(db_name, cursor, match, count); } -std::string SonicV2Connector_Native::get(const std::string& db_name, const std::string& _hash, const std::string& key, bool blocking) +std::shared_ptr SonicV2Connector_Native::get(const std::string& db_name, const std::string& _hash, const std::string& key, bool blocking) { return m_dbintf.get(db_name, _hash, key, blocking); } diff --git a/common/sonicv2connector.h b/common/sonicv2connector.h index d34593185..541b5835f 100644 --- a/common/sonicv2connector.h +++ b/common/sonicv2connector.h @@ -35,7 +35,7 @@ class SonicV2Connector_Native std::pair> scan(const std::string& db_name, int cursor = 0, const char *match = "", uint32_t count = 10); - std::string get(const std::string& db_name, const std::string& _hash, const std::string& key, bool blocking=false); + std::shared_ptr get(const std::string& db_name, const std::string& _hash, const std::string& key, bool blocking=false); bool hexists(const std::string& db_name, const std::string& _hash, const std::string& key); diff --git a/pyext/swsscommon.i b/pyext/swsscommon.i index 99fc6eb05..5380250ea 100644 --- a/pyext/swsscommon.i +++ b/pyext/swsscommon.i @@ -41,6 +41,7 @@ %include %include %include +%include %include %include %include @@ -88,6 +89,21 @@ %template(GetTableResult) std::map>; %template(GetConfigResult) std::map>>; +%typemap(out) std::shared_ptr %{ + { + auto& p = static_cast&>($1); + if(p) + { + $result = PyUnicode_FromStringAndSize(p->c_str(), p->size()); + } + else + { + $result = Py_None; + Py_INCREF(Py_None); + } + } +%} + %pythoncode %{ def _FieldValueMap__get(self, key, default=None): if key in self: diff --git a/tests/test_redis_ut.py b/tests/test_redis_ut.py index 8a60ab7ae..1bebd2e04 100644 --- a/tests/test_redis_ut.py +++ b/tests/test_redis_ut.py @@ -203,7 +203,26 @@ def test_DBInterface(): db.connect("TEST_DB") redisclient = db.get_redis_client("TEST_DB") redisclient.flushdb() + + # Case: hset and hget normally db.set("TEST_DB", "key0", "field1", "value2") + val = db.get("TEST_DB", "key0", "field1") + assert val == "value2" + # Case: hset an empty value + db.set("TEST_DB", "kkk3", "field3", "") + val = db.get("TEST_DB", "kkk3", "field3") + assert val == "" + # Case: hset an "None" string value, hget will intepret it as true None (feature) + db.set("TEST_DB", "kkk3", "field3", "None") + val = db.get("TEST_DB", "kkk3", "field3") + assert val == None + # hget on an existing key but non-existing field + val = db.get("TEST_DB", "kkk3", "missing") + assert val == None + # hget on an non-existing key and non-existing field + val = db.get("TEST_DB", "kkk_missing", "missing") + assert val == None + fvs = db.get_all("TEST_DB", "key0") assert "field1" in fvs assert fvs["field1"] == "value2"