Skip to content

Commit bd080e6

Browse files
committed
Add weighted clustering
1 parent 1fea43d commit bd080e6

File tree

12 files changed

+252
-13
lines changed

12 files changed

+252
-13
lines changed

src/clustering/Clustering.cpp

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,10 +4,12 @@
44
#include "Util.h"
55
#include "itoa.h"
66
#include "Timer.h"
7+
#include "SequenceWeights.h"
78

89
Clustering::Clustering(const std::string &seqDB, const std::string &seqDBIndex,
910
const std::string &alnDB, const std::string &alnDBIndex,
1011
const std::string &outDB, const std::string &outDBIndex,
12+
const std::string &sequenceWeightFile,
1113
unsigned int maxIteration, int similarityScoreType, int threads, int compressed) : maxIteration(maxIteration),
1214
similarityScoreType(similarityScoreType),
1315
threads(threads),
@@ -16,7 +18,21 @@ Clustering::Clustering(const std::string &seqDB, const std::string &seqDBIndex,
1618
outDBIndex(outDBIndex) {
1719

1820
seqDbr = new DBReader<unsigned int>(seqDB.c_str(), seqDBIndex.c_str(), threads, DBReader<unsigned int>::USE_INDEX);
19-
seqDbr->open(DBReader<unsigned int>::SORT_BY_LENGTH);
21+
22+
if (!sequenceWeightFile.empty()) {
23+
24+
seqDbr->open(DBReader<unsigned int>::SORT_BY_ID);
25+
26+
SequenceWeights *sequenceWeights = new SequenceWeights(sequenceWeightFile.c_str());
27+
float localid2weight[seqDbr->getSize()];
28+
for (size_t id = 0; id < seqDbr->getSize(); id++) {
29+
size_t key = seqDbr->getDbKey(id);
30+
localid2weight[id] = sequenceWeights->getWeightById(key);
31+
}
32+
seqDbr->sortIndex(localid2weight);
33+
34+
} else
35+
seqDbr->open(DBReader<unsigned int>::SORT_BY_LENGTH);
2036

2137
alnDbr = new DBReader<unsigned int>(alnDB.c_str(), alnDBIndex.c_str(), threads, DBReader<unsigned int>::USE_DATA|DBReader<unsigned int>::USE_INDEX);
2238
alnDbr->open(DBReader<unsigned int>::NOSORT);

src/clustering/Clustering.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ class Clustering {
1111
Clustering(const std::string &seqDB, const std::string &seqDBIndex,
1212
const std::string &alnResultsDB, const std::string &alnResultsDBIndex,
1313
const std::string &outDB, const std::string &outDBIndex,
14+
const std::string &weightFileName,
1415
unsigned int maxIteration, int similarityScoreType, int threads, int compressed);
1516

1617
void run(int mode);

src/clustering/ClusteringAlgorithms.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -282,7 +282,7 @@ void ClusteringAlgorithms::greedyIncrementalLowMem( unsigned int *assignedcluste
282282
#pragma omp for schedule(dynamic, 1000)
283283
for(size_t i = 0; i < dbSize; i++) {
284284
unsigned int clusterKey = seqDbr->getDbKey(i);
285-
unsigned int clusterId = seqDbr->getId(clusterKey);
285+
unsigned int clusterId = i;
286286

287287
// try to set your self as cluster centriod
288288
// if some other cluster covered

src/clustering/Main.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ int clust(int argc, const char **argv, const Command& command) {
66
par.parseParameters(argc, argv, command, true, 0, 0);
77

88
Clustering clu(par.db1, par.db1Index, par.db2, par.db2Index,
9-
par.db3, par.db3Index, par.maxIteration,
9+
par.db3, par.db3Index, par.weightFile, par.maxIteration,
1010
par.similarityScoreType, par.threads, par.compressed);
1111
clu.run(par.clusteringMode);
1212
return EXIT_SUCCESS;

src/commons/CMakeLists.txt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@ set(commons_header_files
3434
commons/PatternCompiler.h
3535
commons/ScoreMatrix.h
3636
commons/Sequence.h
37+
commons/SequenceWeights.h
3738
commons/StringBlock.h
3839
commons/SubstitutionMatrix.h
3940
commons/SubstitutionMatrixProfileStates.h
@@ -69,6 +70,7 @@ set(commons_source_files
6970
commons/ProfileStates.cpp
7071
commons/LibraryReader.cpp
7172
commons/Sequence.cpp
73+
commons/SequenceWeights.cpp
7274
commons/SubstitutionMatrix.cpp
7375
commons/tantan.cpp
7476
commons/UniprotKB.cpp

src/commons/DBReader.cpp

Lines changed: 26 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -213,6 +213,9 @@ template <typename T> bool DBReader<T>::open(int accessType){
213213
template<typename T>
214214
void DBReader<T>::sortIndex(bool) {
215215
}
216+
template<typename T>
217+
void DBReader<T>::sortIndex(float*) {
218+
}
216219

217220
template<typename T>
218221
bool DBReader<T>::isSortedByOffset(){
@@ -234,8 +237,31 @@ void DBReader<std::string>::sortIndex(bool isSortedById) {
234237
}
235238
}
236239

240+
template<>
241+
void DBReader<unsigned int>::sortIndex(float *weights) {
242+
243+
this->accessType=DBReader::SORT_BY_WEIGHTS;
244+
std::pair<unsigned int, float> *sortForMapping = new std::pair<unsigned int, float>[size];
245+
id2local = new unsigned int[size];
246+
local2id = new unsigned int[size];
247+
incrementMemory(sizeof(unsigned int) * 2 * size);
248+
for (size_t i = 0; i < size; i++) {
249+
id2local[i] = i;
250+
local2id[i] = i;
251+
sortForMapping[i] = std::make_pair(i, weights[i]);
252+
}
253+
//this sort has to be stable to assure same clustering results
254+
SORT_PARALLEL(sortForMapping, sortForMapping + size, comparePairByWeight());
255+
for (size_t i = 0; i < size; i++) {
256+
id2local[sortForMapping[i].first] = i;
257+
local2id[i] = sortForMapping[i].first;
258+
}
259+
delete[] sortForMapping;
260+
}
261+
237262
template<>
238263
void DBReader<unsigned int>::sortIndex(bool isSortedById) {
264+
239265
// First, we sort the index by IDs and we keep track of the original
240266
// ordering in mappingToOriginalIndex array
241267
size_t* mappingToOriginalIndex=NULL;
@@ -294,7 +320,6 @@ void DBReader<unsigned int>::sortIndex(bool isSortedById) {
294320
mappingToOriginalIndex[i] = i;
295321
}
296322
}
297-
298323
if (accessType == SORT_BY_LENGTH) {
299324
// sort the entries by the length of the sequences
300325
std::pair<unsigned int, unsigned int> *sortForMapping = new std::pair<unsigned int, unsigned int>[size];

src/commons/DBReader.h

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -257,6 +257,7 @@ class DBReader : public MemoryTracker {
257257
static const int HARDNOSORT = 6; // do not even sort by ids.
258258
static const int SORT_BY_ID_OFFSET = 7;
259259
static const int SORT_BY_OFFSET = 8; // only offset sorting saves memory and does not support random access
260+
static const int SORT_BY_WEIGHTS= 9;
260261

261262

262263
static const unsigned int USE_INDEX = 0;
@@ -317,6 +318,7 @@ class DBReader : public MemoryTracker {
317318
void mlock();
318319

319320
void sortIndex(bool isSortedById);
321+
void sortIndex(float *weights);
320322
bool isSortedByOffset();
321323

322324
void unmapData();
@@ -392,6 +394,20 @@ class DBReader : public MemoryTracker {
392394
}
393395
};
394396

397+
struct comparePairByWeight {
398+
bool operator() (const std::pair<unsigned int, float>& lhs, const std::pair<unsigned int, float>& rhs) const{
399+
if(lhs.second > rhs.second)
400+
return true;
401+
if(rhs.second > lhs.second)
402+
return false;
403+
if(lhs.first < rhs.first )
404+
return true;
405+
if(rhs.first < lhs.first )
406+
return false;
407+
return false;
408+
}
409+
};
410+
395411
struct comparePairByIdAndOffset {
396412
bool operator() (const std::pair<unsigned int, Index>& lhs, const std::pair<unsigned int, Index>& rhs) const{
397413
if(lhs.second.id < rhs.second.id)

src/commons/Parameters.cpp

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -151,7 +151,8 @@ Parameters::Parameters():
151151
PARAM_PICK_N_SIMILAR(PARAM_PICK_N_SIMILAR_ID, "--pick-n-sim-kmer", "Add N similar to search", "Add N similar k-mers to search", typeid(int), (void *) &pickNbest, "^[1-9]{1}[0-9]*$", MMseqsParameter::COMMAND_CLUSTLINEAR | MMseqsParameter::COMMAND_EXPERT),
152152
PARAM_ADJUST_KMER_LEN(PARAM_ADJUST_KMER_LEN_ID, "--adjust-kmer-len", "Adjust k-mer length", "Adjust k-mer length based on specificity (only for nucleotides)", typeid(bool), (void *) &adjustKmerLength, "", MMseqsParameter::COMMAND_CLUSTLINEAR | MMseqsParameter::COMMAND_EXPERT),
153153
PARAM_RESULT_DIRECTION(PARAM_RESULT_DIRECTION_ID, "--result-direction", "Result direction", "result is 0: query, 1: target centric", typeid(int), (void *) &resultDirection, "^[0-1]{1}$", MMseqsParameter::COMMAND_CLUSTLINEAR | MMseqsParameter::COMMAND_EXPERT),
154-
154+
PARAM_WEIGHT_FILE(PARAM_WEIGHT_FILE_ID, "--weights", "Weight file name", "Weights used for cluster priorization", typeid(std::string), (void*) &weightFile, "", MMseqsParameter::COMMAND_CLUSTLINEAR | MMseqsParameter::COMMAND_EXPERT ),
155+
PARAM_WEIGHT_THR(PARAM_WEIGHT_THR_ID, "--weightThr", "Weight threshold", "Weight threshold used for cluster priorization", typeid(float), (void*) &weightThr, "^[0-9]*(\\.[0-9]+)?$", MMseqsParameter::COMMAND_CLUSTLINEAR | MMseqsParameter::COMMAND_EXPERT ),
155156
// workflow
156157
PARAM_RUNNER(PARAM_RUNNER_ID, "--mpi-runner", "MPI runner", "Use MPI on compute cluster with this MPI command (e.g. \"mpirun -np 42\")", typeid(std::string), (void *) &runner, "", MMseqsParameter::COMMAND_COMMON | MMseqsParameter::COMMAND_EXPERT),
157158
PARAM_REUSELATEST(PARAM_REUSELATEST_ID, "--force-reuse", "Force restart with latest tmp", "Reuse tmp filse in tmp/latest folder ignoring parameters and version changes", typeid(bool), (void *) &reuseLatest, "", MMseqsParameter::COMMAND_COMMON | MMseqsParameter::COMMAND_EXPERT),
@@ -435,6 +436,8 @@ Parameters::Parameters():
435436
clust.push_back(&PARAM_THREADS);
436437
clust.push_back(&PARAM_COMPRESSED);
437438
clust.push_back(&PARAM_V);
439+
clust.push_back(&PARAM_WEIGHT_FILE);
440+
clust.push_back(&PARAM_WEIGHT_THR);
438441
439442
// rescorediagonal
440443
rescorediagonal.push_back(&PARAM_SUB_MAT);
@@ -914,6 +917,8 @@ Parameters::Parameters():
914917
kmermatcher.push_back(&PARAM_THREADS);
915918
kmermatcher.push_back(&PARAM_COMPRESSED);
916919
kmermatcher.push_back(&PARAM_V);
920+
kmermatcher.push_back(&PARAM_WEIGHT_FILE);
921+
kmermatcher.push_back(&PARAM_WEIGHT_THR);
917922
918923
// kmermatcher
919924
kmersearch.push_back(&PARAM_SEED_SUB_MAT);
@@ -2431,6 +2436,9 @@ void Parameters::setDefaults() {
24312436
pickNbest = 1;
24322437
adjustKmerLength = false;
24332438
resultDirection = Parameters::PARAM_RESULT_DIRECTION_TARGET;
2439+
weightThr = 0.9;
2440+
weightFile = "";
2441+
24342442
// result2stats
24352443
stat = "";
24362444

src/commons/Parameters.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -400,6 +400,7 @@ class Parameters {
400400
std::string spacedKmerPattern; // User-specified kmer pattern
401401
std::string localTmp; // Local temporary path
402402

403+
403404
// ALIGNMENT
404405
int alignmentMode; // alignment mode 0=fastest on parameters,
405406
// 1=score only, 2=score, cov, start/end pos, 3=score, cov, start/end pos, seq.id,
@@ -526,6 +527,8 @@ class Parameters {
526527
int pickNbest;
527528
int adjustKmerLength;
528529
int resultDirection;
530+
float weightThr;
531+
std::string weightFile;
529532

530533
// indexdb
531534
int checkCompatible;
@@ -828,6 +831,9 @@ class Parameters {
828831
PARAMETER(PARAM_PICK_N_SIMILAR)
829832
PARAMETER(PARAM_ADJUST_KMER_LEN)
830833
PARAMETER(PARAM_RESULT_DIRECTION)
834+
PARAMETER(PARAM_WEIGHT_FILE)
835+
PARAMETER(PARAM_WEIGHT_THR)
836+
831837
// workflow
832838
PARAMETER(PARAM_RUNNER)
833839
PARAMETER(PARAM_REUSELATEST)

src/commons/SequenceWeights.cpp

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,55 @@
1+
//
2+
// Created by annika on 09.11.22.
3+
//
4+
#include "SequenceWeights.h"
5+
#include "Util.h"
6+
#include "Debug.h"
7+
8+
#include <algorithm>
9+
#include <fstream>
10+
11+
SequenceWeights::SequenceWeights(const char* dataFileName) {
12+
13+
//parse file and fill weightIndex
14+
std::ifstream tsv(dataFileName);
15+
if (tsv.fail()) {
16+
Debug(Debug::ERROR) << "File " << dataFileName << " not found!\n";
17+
EXIT(EXIT_FAILURE);
18+
}
19+
20+
char keyData[255];
21+
std::string line;
22+
this->indexSize = 0;
23+
unsigned int pos = 0;
24+
while(std::getline(tsv, line)) {
25+
this->indexSize++;
26+
}
27+
28+
this->weightIndex = new WeightIndexEntry[this->indexSize];
29+
30+
tsv.clear();
31+
tsv.seekg(0);
32+
33+
while(std::getline(tsv, line)) {
34+
char *current = (char *) line.c_str();
35+
Util::parseKey(current, keyData);
36+
const std::string key(keyData);
37+
unsigned int keyId = strtoull(key.c_str(), NULL, 10);
38+
39+
char *restStart = current + key.length();
40+
restStart = restStart + Util::skipWhitespace(restStart);
41+
float weight = static_cast<float>(strtod(restStart, NULL));
42+
this->weightIndex[pos].id = keyId;
43+
this->weightIndex[pos].weight = weight;
44+
pos++;
45+
}
46+
}
47+
48+
float SequenceWeights::getWeightById(unsigned int id) {
49+
50+
WeightIndexEntry val;
51+
val.id = id;
52+
size_t pos = std::upper_bound(weightIndex, weightIndex + indexSize, val, WeightIndexEntry::compareByIdOnly) - weightIndex;
53+
return (pos < indexSize && weightIndex[pos].id == id ) ? weightIndex[pos].weight : std::numeric_limits<float>::min();
54+
}
55+

0 commit comments

Comments
 (0)