diff --git a/internal/core/src/exec/expression/BinaryRangeExpr.cpp b/internal/core/src/exec/expression/BinaryRangeExpr.cpp index 49d514cc8dbf4..0eccb3faf8bbc 100644 --- a/internal/core/src/exec/expression/BinaryRangeExpr.cpp +++ b/internal/core/src/exec/expression/BinaryRangeExpr.cpp @@ -68,24 +68,60 @@ PhyBinaryRangeFilterExpr::Eval(EvalCtx& context, VectorPtr& result) { } case DataType::JSON: { auto value_type = expr_->lower_val_.val_case(); - switch (value_type) { - case proto::plan::GenericValue::ValCase::kInt64Val: { - result = ExecRangeVisitorImplForJson(context); - break; - } - case proto::plan::GenericValue::ValCase::kFloatVal: { - result = ExecRangeVisitorImplForJson(context); - break; - } - case proto::plan::GenericValue::ValCase::kStringVal: { - result = ExecRangeVisitorImplForJson(context); - break; + if (is_index_mode_ && !has_offset_input_) { + switch (value_type) { + case proto::plan::GenericValue::ValCase::kInt64Val: { + proto::plan::GenericValue double_lower_val; + double_lower_val.set_float_val( + static_cast(expr_->lower_val_.int64_val())); + proto::plan::GenericValue double_upper_val; + double_upper_val.set_float_val( + static_cast(expr_->upper_val_.int64_val())); + + lower_arg_.SetValue(double_lower_val); + upper_arg_.SetValue(double_upper_val); + arg_inited_ = true; + + result = ExecRangeVisitorImplForIndex(); + break; + } + case proto::plan::GenericValue::ValCase::kFloatVal: { + result = ExecRangeVisitorImplForIndex(); + break; + } + case proto::plan::GenericValue::ValCase::kStringVal: { + result = + ExecRangeVisitorImplForJson(context); + break; + } + default: { + PanicInfo(DataTypeInvalid, + fmt::format( + "unsupported value type {} in expression", + value_type)); + } } - default: { - PanicInfo( - DataTypeInvalid, - fmt::format("unsupported value type {} in expression", - value_type)); + } else { + switch (value_type) { + case proto::plan::GenericValue::ValCase::kInt64Val: { + result = ExecRangeVisitorImplForJson(context); + break; + } + case proto::plan::GenericValue::ValCase::kFloatVal: { + result = ExecRangeVisitorImplForJson(context); + break; + } + case proto::plan::GenericValue::ValCase::kStringVal: { + result = + ExecRangeVisitorImplForJson(context); + break; + } + default: { + PanicInfo(DataTypeInvalid, + fmt::format( + "unsupported value type {} in expression", + value_type)); + } } } break; diff --git a/internal/core/src/exec/expression/BinaryRangeExpr.h b/internal/core/src/exec/expression/BinaryRangeExpr.h index 007e745bb0510..71c8b089a7f3d 100644 --- a/internal/core/src/exec/expression/BinaryRangeExpr.h +++ b/internal/core/src/exec/expression/BinaryRangeExpr.h @@ -252,7 +252,7 @@ class PhyBinaryRangeFilterExpr : public SegmentExpr { segment, expr->column_.field_id_, expr->column_.nested_path_, - DataType::NONE, + FromValCase(expr->lower_val_.val_case()), active_count, batch_size, consistency_level), diff --git a/internal/core/unittest/test_expr.cpp b/internal/core/unittest/test_expr.cpp index 31e5f5483bc93..7e4af02321d6b 100644 --- a/internal/core/unittest/test_expr.cpp +++ b/internal/core/unittest/test_expr.cpp @@ -10,6 +10,7 @@ // or implied. See the License for the specific language governing permissions and limitations under the License #include +#include #include #include #include @@ -27,8 +28,10 @@ #include "common/FieldDataInterface.h" #include "common/Json.h" +#include "common/JsonCastType.h" #include "common/LoadInfo.h" #include "common/Types.h" +#include "gtest/gtest.h" #include "index/Meta.h" #include "index/JsonInvertedIndex.h" #include "knowhere/comp/index_param.h" @@ -16880,3 +16883,193 @@ TEST_P(JsonIndexExistsTest, TestExistsExpr) { EXPECT_TRUE(result == expect_res); } } + +class JsonIndexBinaryExprTest : public testing::TestWithParam {}; + +INSTANTIATE_TEST_SUITE_P(JsonIndexBinaryExprTestParams, + JsonIndexBinaryExprTest, + testing::Values(JsonCastType::DOUBLE, + JsonCastType::VARCHAR)); + +TEST_P(JsonIndexBinaryExprTest, TestBinaryRangeExpr) { + auto json_strs = std::vector{ + R"({"a": 1})", + R"({"a": 2})", + R"({"a": 3})", + R"({"a": 4})", + + R"({"a": 1.0})", + R"({"a": 2.0})", + R"({"a": 3.0})", + R"({"a": 4.0})", + + R"({"a": "1"})", + R"({"a": "2"})", + R"({"a": "3"})", + R"({"a": "4"})", + + R"({"a": null})", + R"({"a": true})", + R"({"a": false})", + }; + + auto test_cases = std::vector>{ + // Exact match for integer 1 (matches both int 1 and float 1.0) + {std::make_any(1), + std::make_any(1), + true, + true, + 0b1000'1000'0000'000}, + + // Range [1, 3] inclusive (matches int 1,2,3 and float 1.0,2.0,3.0) + {std::make_any(1), + std::make_any(3), + true, + true, + 0b1110'1110'0000'000}, + + // Range (1, 3) exclusive (matches only int 2 and float 2.0) + {std::make_any(1), + std::make_any(3), + false, + false, + 0b0100'0100'0000'000}, + + // Range [1, 3) left inclusive, right exclusive (matches int 1,2 and float 1.0,2.0) + {std::make_any(1), + std::make_any(3), + true, + false, + 0b1100'1100'0000'000}, + + // Range (1, 3] left exclusive, right inclusive (matches int 2,3 and float 2.0,3.0) + {std::make_any(1), + std::make_any(3), + false, + true, + 0b0110'0110'0000'000}, + + // Float range test [1.0, 3.0] (matches int 1,2,3 and float 1.0,2.0,3.0) + {std::make_any(1.0), + std::make_any(3.0), + true, + true, + 0b1110'1110'0000'000}, + + // String range test ["1", "3"] (matches string "1","2","3") + {std::make_any("1"), + std::make_any("3"), + true, + true, + 0b0000'0000'1110'000}, + + // Range that should match nothing + {std::make_any(10), + std::make_any(20), + true, + true, + 0b0000'0000'0000'000}, + + // Range [2, 4] inclusive (matches int 2,3,4 and float 2.0,3.0,4.0) + {std::make_any(2), + std::make_any(4), + true, + true, + 0b0111'0111'0000'000}, + + // Mixed type range test - int to float [1, 3.0] + // {std::make_any(1), + // std::make_any(3.0), + // true, + // true, + // 0b1110'1110'0000'000}, + }; + + auto schema = std::make_shared(); + auto vec_fid = schema->AddDebugField( + "fakevec", DataType::VECTOR_FLOAT, 16, knowhere::metric::L2); + auto i64_fid = schema->AddDebugField("age64", DataType::INT64); + auto json_fid = schema->AddDebugField("json", DataType::JSON); + schema->set_primary_field_id(i64_fid); + + auto seg = CreateSealedSegment(schema); + segcore::LoadIndexInfo load_index_info; + + auto file_manager_ctx = storage::FileManagerContext(); + file_manager_ctx.fieldDataMeta.field_schema.set_data_type( + milvus::proto::schema::JSON); + file_manager_ctx.fieldDataMeta.field_schema.set_fieldid(json_fid.get()); + + auto inv_index = index::IndexFactory::GetInstance().CreateJsonIndex( + index::INVERTED_INDEX_TYPE, GetParam(), "/a", file_manager_ctx); + + using json_index_type = index::JsonInvertedIndex; + auto json_index = std::unique_ptr( + static_cast(inv_index.release())); + auto json_field = + std::make_shared>(DataType::JSON, false); + std::vector jsons; + + for (auto& json : json_strs) { + jsons.push_back(milvus::Json(simdjson::padded_string(json))); + } + json_field->add_json_data(jsons); + + json_index->BuildWithFieldData({json_field}); + json_index->finish(); + json_index->create_reader(); + + load_index_info.field_id = json_fid.get(); + load_index_info.field_type = DataType::JSON; + load_index_info.index = std::move(json_index); + load_index_info.index_params = {{JSON_PATH, "/a"}}; + seg->LoadIndex(load_index_info); + + auto json_field_data_info = + FieldDataInfo(json_fid.get(), json_strs.size(), {json_field}); + seg->LoadFieldData(json_fid, json_field_data_info); + + for (auto& [lower, upper, lower_inclusive, upper_inclusive, result] : + test_cases) { + proto::plan::GenericValue lower_val; + proto::plan::GenericValue upper_val; + if (lower.type() == typeid(int64_t)) { + lower_val.set_int64_val(std::any_cast(lower)); + } else if (lower.type() == typeid(double)) { + lower_val.set_float_val(std::any_cast(lower)); + } else if (lower.type() == typeid(std::string)) { + lower_val.set_string_val(std::any_cast(lower)); + } + + if (upper.type() == typeid(int64_t)) { + upper_val.set_int64_val(std::any_cast(upper)); + } else if (upper.type() == typeid(double)) { + upper_val.set_float_val(std::any_cast(upper)); + } else if (upper.type() == typeid(std::string)) { + upper_val.set_string_val(std::any_cast(upper)); + } + + BitsetType expect_result; + expect_result.resize(json_strs.size()); + for (int i = json_strs.size() - 1; result > 0; i--) { + expect_result.set(i, (result & 0x1) != 0); + result >>= 1; + } + + auto binary_expr = std::make_shared( + expr::ColumnInfo(json_fid, DataType::JSON, {"a"}), + lower_val, + upper_val, + lower_inclusive, + upper_inclusive); + auto plan = std::make_shared(DEFAULT_PLANNODE_ID, + binary_expr); + auto res = + ExecuteQueryExpr(plan, seg.get(), json_strs.size(), MAX_TIMESTAMP); + EXPECT_TRUE(res == expect_result); + } +}