Skip to content

Commit

Permalink
Fix memory issues in QueryMatcher
Browse files (browse the repository at this point in the history)
  • Loading branch information
martin-steinegger committed Jul 10, 2021
1 parent 17c8028 commit 442d898
Showing 1 changed file with 42 additions and 33 deletions.
75 changes: 42 additions & 33 deletions src/prefiltering/QueryMatcher.cpp
Original file line number | Diff line number | Diff line change
Expand Up @@ -24,7 +24,7 @@ QueryMatcher::QueryMatcher(IndexTable *indexTable, SequenceLookup *sequenceLooku
short kmerThr, int kmerSize, size_t dbSize,
unsigned int maxSeqLen, size_t maxHitsPerQuery, bool aaBiasCorrection,
bool diagonalScoring, unsigned int minDiagScoreThr, bool takeOnlyBestKmer, bool isNucleotide)
: idx(indexTable->getAlphabetSize(), kmerSize), isNucleotide(isNucleotide)
: idx(indexTable->getAlphabetSize(), kmerSize), isNucleotide(isNucleotide)
{
this->kmerSubMat = kmerSubMat;
this->ungappedAlignmentSubMat = ungappedAlignmentSubMat;
Expand Down Expand Up @@ -103,63 +103,72 @@ std::pair<hit_t*, size_t> QueryMatcher::matchQuery(Sequence *querySeq, unsigned
// write diagonal scores in count value
ungappedAlignment->processQuery(querySeq, compositionBias, foundDiagonals, resultSize);
memset(scoreSizes, 0, SCORE_RANGE * sizeof(unsigned int));
updateScoreBins(foundDiagonals, resultSize);
size_t elementsCntAboveMinDiagonalThr = radixSortByScoreSize(scoreSizes, foundDiagonals + resultSize,
minDiagScoreThr, foundDiagonals, resultSize);
if (isNucleotide) {
CounterResult * resultReadPos = foundDiagonals;
CounterResult * resultWritePos = foundDiagonals + resultSize;
const bool canBeSorted = (resultSize < (foundDiagonalsSize / 2));
if (isNucleotide && canBeSorted) {
updateScoreBins(resultReadPos, resultSize);
//TODO can crash
size_t elementsCntAboveMinDiagonalThr = radixSortByScoreSize(scoreSizes, resultWritePos,
minDiagScoreThr, resultReadPos, resultSize);
std::swap(resultReadPos, resultWritePos);
size_t len;
// only sort the 255 bucket
for (len = 0; len < elementsCntAboveMinDiagonalThr
&& (foundDiagonals + resultSize)[len].count >= (UCHAR_MAX - ungappedAlignment->getQueryBias()); len++) { ;
&& resultReadPos[len].count >= (UCHAR_MAX - ungappedAlignment->getQueryBias()); len++) { ;
}
SORT_SERIAL((foundDiagonals + resultSize), (foundDiagonals + resultSize) + len, CounterResult::sortById);
SORT_SERIAL(resultReadPos, resultReadPos + len, CounterResult::sortById);
size_t prevId = UINT_MAX;//(foundDiagonals + resultSize)[0].id;
size_t max = 0;
size_t firstPos = 0;
for (size_t i = 0; i < len; i++) {
if (prevId == (foundDiagonals + resultSize)[i].id) {
if (prevId == resultReadPos[i].id) {
unsigned int newScore = ungappedAlignment->scoreSingelSequenceByCounterResult(
(foundDiagonals + resultSize)[i]);
resultReadPos[i]);
if (newScore > max) {
max = newScore;
(foundDiagonals + resultSize)[firstPos].diagonal = (foundDiagonals + resultSize)[i].diagonal;
resultReadPos[firstPos].diagonal = resultReadPos[i].diagonal;
}
} else {
max = i+1<len && (foundDiagonals + resultSize)[i+1].id == (foundDiagonals + resultSize)[i].id ? \
max = i+1<len && resultReadPos[i+1].id == resultReadPos[i].id ? \
ungappedAlignment->scoreSingelSequenceByCounterResult(
(foundDiagonals + resultSize)[i]) : 0 ;
resultReadPos[i]) : 0 ;
firstPos = i;
}
prevId = (foundDiagonals + resultSize)[i].id;
prevId = resultReadPos[i].id;
}
resultSize = keepMaxScoreElementOnly(resultReadPos, elementsCntAboveMinDiagonalThr);
memset(scoreSizes, 0, SCORE_RANGE * sizeof(unsigned int));
}else{
resultSize = keepMaxScoreElementOnly(resultReadPos, resultSize);
resultWritePos = foundDiagonals + resultSize;
}

unsigned int maxScoreElementsCount = keepMaxScoreElementOnly(foundDiagonals + resultSize, elementsCntAboveMinDiagonalThr);
memset(scoreSizes, 0, SCORE_RANGE * sizeof(unsigned int));
updateScoreBins(foundDiagonals + resultSize, maxScoreElementsCount);

updateScoreBins(resultReadPos, resultSize);
unsigned int diagonalThr = computeScoreThreshold(scoreSizes, this->maxHitsPerQuery);
diagonalThr = std::max(minDiagScoreThr, diagonalThr);

// sort to not lose highest scoring hits if > 150.000 hits are searched
if(resultSize < foundDiagonalsSize / 2){
unsigned int maxDiagonalScoreThr = (UCHAR_MAX - ungappedAlignment->getQueryBias());
bool scoreIsTruncated = (diagonalThr >= maxDiagonalScoreThr) ? true : false;
size_t elementsCntAboveDiagonalThr = radixSortByScoreSize(scoreSizes, foundDiagonals, diagonalThr, foundDiagonals+resultSize, maxScoreElementsCount);
size_t elementsCntAboveDiagonalThr = radixSortByScoreSize(scoreSizes, resultWritePos, diagonalThr, resultReadPos, resultSize);
std::swap(resultReadPos, resultWritePos);
if (scoreIsTruncated == true) {
memset(scoreSizes, 0, SCORE_RANGE * sizeof(unsigned int));
std::pair<size_t, unsigned int> rescoreResult = rescoreHits(querySeq, scoreSizes, foundDiagonals, elementsCntAboveDiagonalThr, ungappedAlignment, maxDiagonalScoreThr);
std::pair<size_t, unsigned int> rescoreResult = rescoreHits(querySeq, scoreSizes, resultReadPos, elementsCntAboveDiagonalThr, ungappedAlignment, maxDiagonalScoreThr);
size_t newResultSize = rescoreResult.first;
unsigned int maxSelfScoreMinusDiag = rescoreResult.second;
elementsCntAboveDiagonalThr = radixSortByScoreSize(scoreSizes, foundDiagonals+newResultSize, 0, foundDiagonals, newResultSize);
queryResult = getResult<UNGAPPED_DIAGONAL_SCORE>(foundDiagonals+newResultSize, elementsCntAboveDiagonalThr, identityId, 0, ungappedAlignment, maxSelfScoreMinusDiag);
elementsCntAboveDiagonalThr = radixSortByScoreSize(scoreSizes, resultWritePos, 0, resultReadPos, newResultSize);
std::swap(resultReadPos, resultWritePos);
queryResult = getResult<UNGAPPED_DIAGONAL_SCORE>(resultReadPos, elementsCntAboveDiagonalThr, identityId, 0, ungappedAlignment, maxSelfScoreMinusDiag);
}else{
queryResult = getResult<UNGAPPED_DIAGONAL_SCORE>(foundDiagonals, elementsCntAboveDiagonalThr, identityId, diagonalThr, ungappedAlignment, false);
queryResult = getResult<UNGAPPED_DIAGONAL_SCORE>(resultReadPos, elementsCntAboveDiagonalThr, identityId, diagonalThr, ungappedAlignment, false);
}
stats->truncated = 0;
}else{
//Debug(Debug::WARNING) << "Sequence " << querySeq->getDbKey() << " produces too many hits. Results might be truncated\n";
queryResult = getResult<UNGAPPED_DIAGONAL_SCORE>(foundDiagonals + resultSize, maxScoreElementsCount, identityId, diagonalThr, ungappedAlignment, false);
queryResult = getResult<UNGAPPED_DIAGONAL_SCORE>(resultReadPos, resultSize, identityId, diagonalThr, ungappedAlignment, false);
stats->truncated = 1;
}
}else{
Expand Down Expand Up @@ -419,9 +428,9 @@ void QueryMatcher::deleteDiagonalMatcher(unsigned int activeCounter){
}

size_t QueryMatcher::findDuplicates(IndexEntryLocal **hitsByIndex,
CounterResult *output, size_t outputSize,
unsigned short indexFrom, unsigned short indexTo,
bool computeTotalScore) {
CounterResult *output, size_t outputSize,
unsigned short indexFrom, unsigned short indexTo,
bool computeTotalScore) {
size_t localResultSize = 0;
#define COUNT_CASE(x) case x: localResultSize += cachedOperation##x->findDuplicates(hitsByIndex, output, outputSize, indexFrom, indexTo, computeTotalScore); break;
switch (activeCounter){
Expand Down Expand Up @@ -457,10 +466,10 @@ size_t QueryMatcher::keepMaxScoreElementOnly(CounterResult *foundDiagonals, size
}

size_t QueryMatcher::radixSortByScoreSize(const unsigned int * scoreSizes,
CounterResult *writePos,
const unsigned int scoreThreshold,
const CounterResult *results,
const size_t resultSize) {
CounterResult *writePos,
const unsigned int scoreThreshold,
const CounterResult *results,
const size_t resultSize) {
CounterResult * ptr[SCORE_RANGE];
ptr[0] = writePos+resultSize;
CounterResult * ptr_prev=ptr[0];
Expand All @@ -484,7 +493,7 @@ size_t QueryMatcher::radixSortByScoreSize(const unsigned int * scoreSizes,
}

std::pair<size_t, unsigned int> QueryMatcher::rescoreHits(Sequence * querySeq, unsigned int * scoreSizes, CounterResult *results,
size_t resultSize, UngappedAlignment *align, int lowerBoundScore) {
size_t resultSize, UngappedAlignment *align, int lowerBoundScore) {
size_t elements = 0;
const unsigned char * query = querySeq->numSequence;
int maxSelfScore = align->scoreSingleSequence(std::make_pair(query, querySeq->L), 0,0);
Expand All @@ -505,8 +514,8 @@ std::pair<size_t, unsigned int> QueryMatcher::rescoreHits(Sequence * querySeq, u
}

template std::pair<hit_t *, size_t> QueryMatcher::getResult<0>(CounterResult * results, size_t resultSize,
const unsigned int id, const unsigned short thr,
UngappedAlignment * align, const int rescaleScore);
const unsigned int id, const unsigned short thr,
UngappedAlignment * align, const int rescaleScore);
template std::pair<hit_t *, size_t> QueryMatcher::getResult<1>(CounterResult * results, size_t resultSize,
const unsigned int id, const unsigned short thr,
UngappedAlignment * align, const int rescaleScore);
Expand Down

0 comments on commit 442d898

Please sign in to comment.