Skip to content

Commit

Permalink
Merge pull request #33072 from vespa-engine/toregge/use-issue-reporti…
Browse files Browse the repository at this point in the history
…ng-for-streaming-search

Use issue reporting for streaming search.
  • Loading branch information
toregge authored Jan 3, 2025
2 parents c547d5c + 7ddc040 commit c70d7cf
Show file tree
Hide file tree
Showing 14 changed files with 206 additions and 7 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -265,6 +265,11 @@ private Result buildResultFromCompletedVisitor(Query query, Visitor visitor) {
result.hits().addError(ErrorMessage.createTimeout("Missing hit summary data for " + skippedHits + " hits"));
}

var errors = visitor.getErrors();
for (var error : errors) {
result.hits().addError(ErrorMessage.createSearchReplyError(error));
}

return result;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.TreeSet;
import java.util.Collection;
import java.util.concurrent.ConcurrentHashMap;
import java.util.logging.Logger;
Expand Down Expand Up @@ -63,6 +64,7 @@ class StreamingVisitor extends VisitorDataHandler implements Visitor {
private static final Logger log = Logger.getLogger(StreamingVisitor.class.getName());
private final VisitorParameters params = new VisitorParameters("");
private List<SearchResult.Hit> hits = new ArrayList<>();
private Set<String> errors = new TreeSet<>();
private int totalHitCount = 0;

private final Map<String, DocumentSummary.Summary> summaryMap = new HashMap<>();
Expand Down Expand Up @@ -323,6 +325,10 @@ private void handleSearchResult(SearchResult result) {
synchronized (this) {
totalHitCount += result.getTotalHitCount();
hits = ListMerger.mergeIntoArrayList(hits, newHits, query.getOffset() + query.getHits());
var newErrors = result.getErrors();
for (var error : newErrors) {
errors.add(error);
}
}

Map<Integer, byte[]> newGroupingMap = result.getGroupingList();
Expand Down Expand Up @@ -387,4 +393,6 @@ final public List<Grouping> getGroupings() {
return new ArrayList<>(groupings);
}

@Override
public Set<String> getErrors() { return Set.copyOf(errors); }
}
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@

import java.util.List;
import java.util.Map;
import java.util.Set;

/**
* Visitor for performing searches and accessing results.
Expand All @@ -31,6 +32,8 @@ interface Visitor {

List<Grouping> getGroupings();

Set<String> getErrors();

Trace getTrace();

}
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.TimeUnit;

import static org.junit.jupiter.api.Assertions.*;
Expand Down Expand Up @@ -142,6 +143,9 @@ public List<Grouping> getGroupings() {
return groupings;
}

@Override
public Set<String> getErrors() { return Set.of(); }

@Override
public Trace getTrace() {
return new Trace();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import java.util.Map;
import java.util.function.Consumer;

import static org.junit.Assert.assertArrayEquals;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertFalse;
import static org.junit.Assert.assertNotNull;
Expand Down Expand Up @@ -458,6 +459,7 @@ class QueryResultMessageTest implements RunnableTest {
@Override
public void run() throws Exception {
test_result_with_match_features();
test_result_with_errors();

Routable routable = deserialize("QueryResultMessage-1", DocumentProtocol.MESSAGE_QUERYRESULT, Language.CPP);
assertTrue(routable instanceof QueryResultMessage);
Expand Down Expand Up @@ -552,6 +554,16 @@ void test_result_with_match_features() {
assertEquals(1.0, mf.field("foo").asDouble(), 1E-6);
assertEqualsData(new byte[] { 'H', 'i' }, mf.field("bar").asData());
}

void test_result_with_errors() {
Routable routable = deserialize("QueryResultMessage-7", DocumentProtocol.MESSAGE_QUERYRESULT, Language.CPP);
assertTrue(routable instanceof QueryResultMessage);

var msg = (QueryResultMessage) routable;
assertEquals(0, msg.getResult().getHitCount());
var errors = msg.getResult().getErrors();
assertArrayEquals(new String[]{"hello", "world!"}, errors);
}
}

class QueryResultReplyTest implements RunnableTest {
Expand Down
14 changes: 14 additions & 0 deletions documentapi/src/tests/messages/messages80test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -602,6 +602,20 @@ TEST_F(Messages80Test, query_result_message) {
EXPECT_EQ(mf_names[0], "foo");
EXPECT_EQ(mf_names[1], "bar");
}
QueryResultMessage qrm4;
auto& sr4 = qrm4.getSearchResult();
sr4.set_errors(std::vector<std::string>{"hello", "world!"});
sr4.sort();
serialize("QueryResultMessage-7", qrm4);
{
auto routable = deserialize("QueryResultMessage-7", DocumentProtocol::MESSAGE_QUERYRESULT, LANG_CPP);
ASSERT_TRUE(routable);
auto& dm = dynamic_cast<QueryResultMessage&>(*routable);
auto& dr = dm.getSearchResult();
EXPECT_EQ(dr.getHitCount(), size_t(0));
EXPECT_EQ((std::vector<std::string>{"hello", "world!"}), dr.get_errors());
}

}

TEST_F(Messages80Test, query_result_reply) {
Expand Down
Binary file not shown.
26 changes: 26 additions & 0 deletions streamingvisitors/src/tests/searchvisitor/searchvisitor_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,8 @@ class RequestBuilder {
RequestBuilder& rank_profile(const std::string& value) { return set_param("rankprofile", value); }
RequestBuilder& summary_class(const std::string& value) { return set_param("summaryclass", value); }
RequestBuilder& summary_count(uint32_t value) { return set_param("summarycount", std::to_string(value)); }
RequestBuilder& sort(const std::string& value) { return set_param("sort", value); }
RequestBuilder& query_stack_count(uint32_t value) { return set_param("querystackcount", std::to_string(value)); }
RequestBuilder& string_term(const std::string& term, const std::string& field) {
_builder.addStringTerm(term, field, _term_id++, Weight(100));
return *this;
Expand Down Expand Up @@ -272,6 +274,30 @@ TEST_F(SearchVisitorTest, visitor_only_require_weak_read_consistency)
EXPECT_TRUE(session->visitor.getRequiredReadConsistency() == spi::ReadConsistency::WEAK);
}

namespace {

void
check_sorting(SearchVisitorTest& test, const std::string& sort_spec, const HitVector& exp_hits,
const std::vector<std::string>& exp_errors) {
SCOPED_TRACE(sort_spec);
auto res = test.execute_query(RequestBuilder().rank_profile("default").
number_term("[4;10]", "id").sort(sort_spec).
query_stack_count(1).build(),
{{5}, {4}, {3}, {7}});
expect_hits(exp_hits, *res);
EXPECT_EQ(exp_errors, res->getSearchResult().get_errors());
}

}

TEST_F(SearchVisitorTest, sorting_works)
{
check_sorting(*this, "-id", {{7,17.0}, {5,15.0}, {4, 14.0}}, {});
check_sorting(*this, "+id", {{4,14.0}, {5,15.0}, {7, 17.0}}, {});
check_sorting(*this, "-badid", {{7,17.0}, {5,15.0}, {4, 14.0}},
{"Cannot locate field 'badid' in field name registry"});
}

}

GTEST_MAIN_RUN_ALL_TESTS()
29 changes: 24 additions & 5 deletions streamingvisitors/src/vespa/searchvisitor/searchvisitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
#include <vespa/vespalib/util/size_literals.h>
#include <vespa/vespalib/data/slime/slime.h>
#include <vespa/vespalib/text/stringtokenizer.h>
#include <vespa/vespalib/util/issue.h>
#include <vespa/fnet/databuffer.h>
#include <vespa/fastlib/text/normwordfolder.h>
#include <optional>
Expand All @@ -48,6 +49,7 @@ using search::Normalizing;
using search::streaming::QueryTermList;
using storage::StorageComponent;
using storage::VisitorEnvironment;
using vespalib::Issue;
using vdslib::Parameters;
using vsm::DocsumFilter;
using vsm::FieldPath;
Expand Down Expand Up @@ -333,7 +335,8 @@ SearchVisitor::SearchVisitor(StorageComponent& component,
_rankAttribute(dynamic_cast<search::SingleFloatExtAttribute &>(*_rankAttributeBacking)),
_shouldFillRankAttribute(false),
_syntheticFieldsController(),
_rankController()
_rankController(),
_unique_issues()
{
LOG(debug, "Created SearchVisitor");
}
Expand Down Expand Up @@ -524,7 +527,7 @@ SearchVisitor::init(const Parameters & params)
setupAttributeVectorsForSorting(_sortSpec);

_rankController.setRankManagerSnapshot(_env->get_rank_manager_snapshot());
_rankController.setupRankProcessors(_query, location, wantedSummaryCount, ! _sortSpec.empty(), _attrMan, _attributeFields);
_rankController.setupRankProcessors(_query, location, wantedSummaryCount, ! _sortList.empty(), _attrMan, _attributeFields);

// This depends on _fieldPathMap (from setupScratchDocument),
// and IQueryEnvironment (from setupRankProcessors).
Expand Down Expand Up @@ -1016,13 +1019,13 @@ SearchVisitor::setupAttributeVectorsForSorting(const search::common::SortSpec &
}
_sortList.push_back(index);
} else {
LOG(warning, "Attribute '%s' is not sortable", sInfo._field.c_str());
Issue::report("Attribute '%s' is not sortable", sInfo._field.c_str());
}
} else {
LOG(warning, "Attribute '%s' is not valid", sInfo._field.c_str());
Issue::report("Attribute '%s' is not valid", sInfo._field.c_str());
}
} else {
LOG(warning, "Cannot locate field '%s' in field name registry", sInfo._field.c_str());
Issue::report("Cannot locate field '%s' in field name registry", sInfo._field.c_str());
}
}
} else {
Expand Down Expand Up @@ -1088,6 +1091,7 @@ SearchVisitor::compatibleDocumentTypes(const document::DocumentType& typeA,
void
SearchVisitor::handleDocuments(const document::BucketId&, DocEntryList & entries, HitCounter& )
{
auto capture_issues = Issue::listen(_unique_issues);
if (!_init_called) {
init(_params);
}
Expand Down Expand Up @@ -1262,6 +1266,7 @@ SearchVisitor::generate_query_result(HitCounter& counter)
void
SearchVisitor::completedVisitingInternal(HitCounter& hitCounter)
{
auto capture_issues = std::make_unique<Issue::Binding>(_unique_issues);
if (!_init_called) {
init(_params);
}
Expand All @@ -1287,6 +1292,8 @@ SearchVisitor::completedVisitingInternal(HitCounter& hitCounter)
generateGroupingResults();
generateDocumentSummaries();
documentSummary.sort();
capture_issues.reset();
generate_errors();
LOG(debug, "Docsum count: %lu", documentSummary.getSummaryCount());
}

Expand Down Expand Up @@ -1339,5 +1346,17 @@ SearchVisitor::generateDocumentSummaries()
}
}

void
SearchVisitor::generate_errors()
{
auto num_issues = _unique_issues.size();
if (num_issues == 0) {
return;
}
std::vector<std::string> errors;
errors.reserve(num_issues);
_unique_issues.for_each_message([&](const std::string &issue) { errors.emplace_back(issue); });
_queryResult->getSearchResult().set_errors(std::move(errors));
}

}
4 changes: 4 additions & 0 deletions streamingvisitors/src/vespa/searchvisitor/searchvisitor.h
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
#include <vespa/searchlib/attribute/attributevector.h>
#include <vespa/searchlib/attribute/extendableattributes.h>
#include <vespa/searchlib/common/sortspec.h>
#include <vespa/searchlib/common/unique_issues.h>
#include <vespa/storage/visiting/visitor.h>
#include <vespa/document/fieldvalue/fieldvalues.h>
#include <vespa/documentapi/messagebus/messages/queryresultmessage.h>
Expand Down Expand Up @@ -384,6 +385,8 @@ class SearchVisitor : public storage::Visitor,
**/
void generateDocumentSummaries();

void generate_errors();

class GroupingEntry : std::shared_ptr<Grouping> {
public:
explicit GroupingEntry(Grouping * grouping);
Expand Down Expand Up @@ -489,6 +492,7 @@ class SearchVisitor : public storage::Visitor,
SyntheticFieldsController _syntheticFieldsController;
RankController _rankController;
vsm::StringFieldIdTMapT _fieldsUnion;
search::UniqueIssues _unique_issues;

void setupAttributeVector(const vsm::FieldPath &fieldPath);
bool is_text_matching(std::string_view index) const noexcept override;
Expand Down
21 changes: 21 additions & 0 deletions vdslib/src/main/java/com/yahoo/vdslib/SearchResult.java
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
package com.yahoo.vdslib;

import com.yahoo.data.access.helpers.MatchFeatureData;
import com.yahoo.text.Utf8;
import com.yahoo.vespa.objects.BufferSerializer;
import com.yahoo.vespa.objects.Deserializer;

Expand Down Expand Up @@ -64,8 +65,10 @@ public int compareTo(Hit h) {
private final Hit[] hits;
private final TreeMap<Integer, byte []> aggregatorList;
private final TreeMap<Integer, byte []> groupingList;
private String[] errors;
private static final int EXTENSION_FLAGS_PRESENT = -1;
private static final int MATCH_FEATURES_PRESENT_MASK = 1;
private static final int ERRORS_PRESENT_MASK = 2;

public SearchResult(Deserializer buf) {
BufferSerializer bser = (BufferSerializer) buf; // TODO: dirty cast. must do this differently
Expand Down Expand Up @@ -123,6 +126,11 @@ public SearchResult(Deserializer buf) {
if (hasMatchFeatures(extensionFlags)) {
deserializeMatchFeatures(buf, numHits);
}
if (hasErrors(extensionFlags)) {
deserializeErrors(buf);
} else {
this.errors = new String[0];
}
}

private void deserializeMatchFeatures(Deserializer buf, int numHits) {
Expand All @@ -147,6 +155,14 @@ private void deserializeMatchFeatures(Deserializer buf, int numHits) {
}
}

private void deserializeErrors(Deserializer buf) {
int numErrors = buf.getInt(null);
this.errors = new String[numErrors];
for (int i = 0; i < numErrors; ++i) {
errors[i] = Utf8.toString(buf.getBytes(null, buf.getInt(null)));
}
}

private static boolean hasExtensionFlags(int numHits) {
return numHits == EXTENSION_FLAGS_PRESENT;
}
Expand All @@ -155,6 +171,10 @@ private static boolean hasMatchFeatures(int extensionFlags) {
return (extensionFlags & MATCH_FEATURES_PRESENT_MASK) != 0;
}

private static boolean hasErrors(int extensionFlags) {
return (extensionFlags & ERRORS_PRESENT_MASK) != 0;
}

private static boolean isDoubleFeature(byte featureType) {
return featureType == 0;
}
Expand All @@ -163,4 +183,5 @@ private static boolean isDoubleFeature(byte featureType) {
final public int getTotalHitCount() { return (totalHits != 0) ? totalHits : getHitCount(); }
final public Hit getHit(int hitNo) { return hits[hitNo]; }
final public Map<Integer, byte []> getGroupingList() { return groupingList; }
final public String[] getErrors() { return errors; }
}
17 changes: 17 additions & 0 deletions vdslib/src/tests/container/searchresulttest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,10 @@ void deserialize(SearchResult& sr, std::span<const char> buf)
EXPECT_EQ(0, dbuf.getRemaining());
}

void set_errors(SearchResult& sr, std::vector<std::string> errors) {
sr.set_errors(std::move(errors));
}

void populate(SearchResult& sr, FeatureValues& mf)
{
sr.addHit(7, "doc1", 5);
Expand Down Expand Up @@ -74,6 +78,12 @@ void check_match_features(const std::vector<char> & buf, const std::string& labe
check_match_features(sr, label, sort_remap);
}

std::vector<std::string> get_errors(const std::vector<char>& buf) {
SearchResult sr;
deserialize(sr, buf);
return sr.get_errors();
}

}

TEST(SearchResultTest, test_simple)
Expand Down Expand Up @@ -164,4 +174,11 @@ TEST(SearchResultTest, test_deserialized_match_features)
check_match_features(serialize(sr), "deserialized sorted", true);
}

TEST(SearchResultTest, test_errors)
{
SearchResult sr;
set_errors(sr, { "one two", "three four"});
EXPECT_EQ((std::vector<std::string>{"one two", "three four"}), get_errors(serialize(sr)));
}

}
Loading

0 comments on commit c70d7cf

Please sign in to comment.