Skip to content

Commit b611046

Browse files
sidmlZigRazor
andauthored
Kruskal's Algorithm for MST (#127)
* function for checking if a subset or edgelist contains Cycle * cycle check for graph & more test cases for union-find * implement kruskal algorithm * add test for kruskal * update cmake and fix formatting * use priority queue instead of vector Co-authored-by: ZigRazor <[email protected]>
1 parent 0304eec commit b611046

File tree

3 files changed

+325
-28
lines changed

3 files changed

+325
-28
lines changed

CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,7 @@ add_executable(test_exe test/main.cpp
5454
test/FWTest.cpp
5555
test/PrimTest.cpp
5656
test/BoruvkaTest.cpp
57+
test/KruskalTest.cpp
5758
test/BFSTest.cpp
5859
test/DFSTest.cpp
5960
test/CycleCheckTest.cpp

include/Graph/Graph.hpp

Lines changed: 110 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -164,10 +164,10 @@ namespace CXXGRAPH
164164
/**
165165
* @brief This function modifies the original subset array
166166
* such that it becomes the union of two sets a and b
167-
* @param subset query subset, we want to find target in this subset
167+
* @param subset original subset is modified to obtain union of a & b
168168
* @param a parent id of set1
169169
* @param b parent id of set2
170-
*
170+
* NOTE: Original subset is no longer available after union.
171171
* Note: Not Thread Safe
172172
*/
173173
virtual void setUnion(std::vector<Subset>*, const unsigned long set1, const unsigned long elem2) const;
@@ -217,11 +217,25 @@ namespace CXXGRAPH
217217
* @brief Function runs the boruvka algorithm and returns the minimum spanning tree & cost
218218
* if the graph is undirected.
219219
* Note: Not Thread Safe
220-
* @return a vector containing id of nodes in minimum spanning tree & cost of MST
221-
* returns errors if graph is undirected
220+
* @return struct of type MstResult with following fields
221+
* success: true if algorithm completed successfully ELSE false
222+
* mst: vector containing id of nodes in minimum spanning tree & cost of MST
223+
* mstCost: Cost of MST
224+
* errorMessage: "" if no error ELSE report the encountered error
222225
*/
223226
virtual const MstResult boruvka() const;
224227
/**
228+
* @brief Function runs the kruskal algorithm and returns the minimum spanning tree
229+
* if the graph is undirected.
230+
* Note: Not Thread Safe
231+
* @return struct of type MstResult with following fields
232+
* success: true if algorithm completed successfully ELSE false
233+
* mst: vector of node pairs, one pair per edge of the minimum spanning tree
234+
* mstCost: Cost of MST
235+
* errorMessage: "" if no error ELSE report the encountered error
236+
*/
237+
virtual const MstResult kruskal() const;
238+
/**
225239
* \brief
226240
* Function performs the breadth first search algorithm over the graph
227241
* Note: Not Thread Safe
@@ -909,7 +923,7 @@ namespace CXXGRAPH
909923
}
910924

911925
template <typename T>
912-
const unsigned long Graph<T>::setFind(std::vector<Subset>* subsets, const unsigned long nodeId) const
926+
const unsigned long Graph<T>::setFind(std::vector<Subset> *subsets, const unsigned long nodeId) const
913927
{
914928
// find root and make root as parent of i
915929
// (path compression)
@@ -927,21 +941,21 @@ namespace CXXGRAPH
927941
// return;
928942
// if both sets have same parent
929943
// then there's nothing to be done
930-
if ((*subsets)[elem1].parent==(*subsets)[elem2].parent)
944+
if ((*subsets)[elem1].parent == (*subsets)[elem2].parent)
931945
return;
932946
auto elem1Parent = Graph<T>::setFind(subsets, elem1);
933947
auto elem2Parent = Graph<T>::setFind(subsets, elem2);
934-
if((*subsets)[elem1Parent].rank < (*subsets)[elem2Parent].rank)
948+
if ((*subsets)[elem1Parent].rank < (*subsets)[elem2Parent].rank)
935949
(*subsets)[elem1].parent = elem2Parent;
936-
else if((*subsets)[elem1Parent].rank > (*subsets)[elem2Parent].rank)
950+
else if ((*subsets)[elem1Parent].rank > (*subsets)[elem2Parent].rank)
937951
(*subsets)[elem2].parent = elem1Parent;
938952
else
939953
{
940954
(*subsets)[elem2].parent = elem1Parent;
941955
(*subsets)[elem1Parent].rank++;
942-
}
943-
}
944-
956+
}
957+
}
958+
945959
template <typename T>
946960
const AdjacencyMatrix<T> Graph<T>::getAdjMatrix() const
947961
{
@@ -1303,14 +1317,14 @@ namespace CXXGRAPH
13031317
// mark source node as done
13041318
// otherwise we get (0, 0) also in mst
13051319
doneNode.push_back(source->getId());
1306-
// stores the parent and corresponding child node
1320+
// stores the parent and corresponding child node
13071321
// of the edges that are part of MST
13081322
std::map<unsigned long, unsigned long> parentNode;
13091323
while (!pq.empty())
13101324
{
13111325
// second element of pair denotes the node / vertex
13121326
const Node<T> *currentNode = pq.top().second;
1313-
auto nodeId = currentNode->getId();
1327+
auto nodeId = currentNode->getId();
13141328
if (std::find(doneNode.begin(), doneNode.end(), nodeId) == doneNode.end())
13151329
{
13161330
auto pair = std::make_pair(parentNode[nodeId], nodeId);
@@ -1330,10 +1344,9 @@ namespace CXXGRAPH
13301344
if (elem.second->isWeighted().has_value() && elem.second->isWeighted().value())
13311345
{
13321346
const UndirectedWeightedEdge<T> *udw_edge = dynamic_cast<const UndirectedWeightedEdge<T> *>(elem.second);
1333-
if (
1347+
if (
13341348
(udw_edge->getWeight() < dist[elem.first]) &&
1335-
(std::find(doneNode.begin(), doneNode.end(), elem.first->getId()) == doneNode.end())
1336-
)
1349+
(std::find(doneNode.begin(), doneNode.end(), elem.first->getId()) == doneNode.end()))
13371350
{
13381351
dist[elem.first] = udw_edge->getWeight();
13391352
parentNode[elem.first->getId()] = currentNode->getId();
@@ -1353,7 +1366,6 @@ namespace CXXGRAPH
13531366
return result;
13541367
}
13551368

1356-
13571369
template <typename T>
13581370
const MstResult Graph<T>::boruvka() const
13591371
{
@@ -1368,22 +1380,22 @@ namespace CXXGRAPH
13681380
}
13691381
auto nodeSet = Graph<T>::getNodeSet();
13701382
auto n = nodeSet.size();
1371-
1383+
13721384
// Use std vector for storing n subsets.
13731385
std::vector<Subset> subsets;
13741386

13751387
// Initially there are n different trees.
13761388
// Finally there will be one tree that will be MST
13771389
int numTrees = n;
1378-
1390+
13791391
// check if all edges are weighted and store the weights
13801392
// in a map whose keys are the edge ids and values are the edge weights
13811393
auto edgeSet = Graph<T>::getEdgeSet();
13821394
std::map<unsigned long, double> edgeWeight;
13831395
for (auto edge : edgeSet)
13841396
{
13851397
if (edge->isWeighted().has_value() && edge->isWeighted().value())
1386-
edgeWeight[edge->getId()] = (dynamic_cast<const Weighted *>(edge))->getWeight();
1398+
edgeWeight[edge->getId()] = (dynamic_cast<const Weighted *>(edge))->getWeight();
13871399
else
13881400
{
13891401
// No Weighted Edge
@@ -1404,7 +1416,7 @@ namespace CXXGRAPH
14041416
for (auto node : nodeSet)
14051417
{
14061418
userNodeMap[node->getId()] = i;
1407-
Subset set{i, 0};
1419+
Subset set{i, 0};
14081420
subsets.push_back(set);
14091421
i++;
14101422
}
@@ -1415,7 +1427,7 @@ namespace CXXGRAPH
14151427
{
14161428
// Everytime initialize cheapest array
14171429
std::fill(cheapest.begin(), cheapest.end(), -1);
1418-
1430+
14191431
// Traverse through all edges and update
14201432
// cheapest of every component
14211433
for (auto edge : edgeSet)
@@ -1426,26 +1438,26 @@ namespace CXXGRAPH
14261438
// of current edge
14271439
auto set1 = Graph<T>::setFind(&subsets, userNodeMap[elem.first->getId()]);
14281440
auto set2 = Graph<T>::setFind(&subsets, userNodeMap[elem.second->getId()]);
1429-
1441+
14301442
// If two corners of current edge belong to
14311443
// same set, ignore current edge
14321444
if (set1 == set2)
14331445
continue;
1434-
1446+
14351447
// Else check if current edge is closer to previous
14361448
// cheapest edges of set1 and set2
14371449
if (cheapest[set1] == -1 ||
14381450
edgeWeight[cheapest[set1]] > edgeWeight[edgeId])
14391451
cheapest[set1] = edgeId;
1440-
1452+
14411453
if (cheapest[set2] == -1 ||
14421454
edgeWeight[cheapest[set2]] > edgeWeight[edgeId])
14431455
cheapest[set2] = edgeId;
14441456
}
1445-
1446-
// iterate over all the vertices and add picked
1457+
1458+
// iterate over all the vertices and add picked
14471459
// cheapest edges to MST
1448-
for(int i=0; i<n;i++)
1460+
for (int i = 0; i < n; i++)
14491461
{
14501462
// Check if cheapest for current set exists
14511463
if (cheapest[i] != -1)
@@ -1468,6 +1480,76 @@ namespace CXXGRAPH
14681480
return result;
14691481
}
14701482

1483+
template <typename T>
1484+
const MstResult Graph<T>::kruskal() const
1485+
{
1486+
MstResult result;
1487+
result.success = false;
1488+
result.errorMessage = "";
1489+
result.mstCost = INF_DOUBLE;
1490+
if (!isUndirectedGraph())
1491+
{
1492+
result.errorMessage = ERR_DIR_GRAPH;
1493+
return result;
1494+
}
1495+
auto nodeSet = Graph<T>::getNodeSet();
1496+
auto n = nodeSet.size();
1497+
1498+
// check if all edges are weighted and store the weights
1499+
// in a map whose keys are the edge ids and values are the edge weights
1500+
auto edgeSet = Graph<T>::getEdgeSet();
1501+
std::priority_queue< std::pair<double, const Edge<T> *>, std::vector<std::pair<double, const Edge<T> *>>,
1502+
std::greater<std::pair<double, const Edge<T> *>>>
1503+
sortedEdges;
1504+
for (auto edge : edgeSet)
1505+
{
1506+
if (edge->isWeighted().has_value() && edge->isWeighted().value())
1507+
{
1508+
auto weight = (dynamic_cast<const Weighted *>(edge))->getWeight();
1509+
sortedEdges.push(std::make_pair(weight, edge));
1510+
}
1511+
else
1512+
{
1513+
// No Weighted Edge
1514+
result.errorMessage = ERR_NO_WEIGHTED_EDGE;
1515+
return result;
1516+
}
1517+
}
1518+
1519+
std::vector<Subset> subset;
1520+
1521+
// user can give arbitrary ids to nodes
1522+
// we map these ids from 0 to 1 for consistency
1523+
// NOTE: WE CAN REMOVE THIS WHEN WE TAKE CARE OF THIS GLOBALLY
1524+
// WHILE CONSTRUCTING THE GRAPH
1525+
std::map<unsigned long, unsigned long> userNodeMap;
1526+
unsigned long i = 0;
1527+
for (auto node : nodeSet)
1528+
{
1529+
userNodeMap[node->getId()] = i;
1530+
Subset set{i, 0};
1531+
subset.push_back(set);
1532+
i++;
1533+
}
1534+
result.mstCost = 0;
1535+
while ((!sortedEdges.empty()) && (result.mst.size() < n))
1536+
{
1537+
auto [edgeWeight, cheapestEdge] = sortedEdges.top();
1538+
sortedEdges.pop();
1539+
auto &[first, second] = cheapestEdge->getNodePair();
1540+
auto set1 = Graph<T>::setFind(&subset, userNodeMap[first->getId()]);
1541+
auto set2 = Graph<T>::setFind(&subset, userNodeMap[second->getId()]);
1542+
if (set1 != set2)
1543+
{
1544+
result.mst.push_back(std::make_pair(userNodeMap[first->getId()], userNodeMap[second->getId()]));
1545+
result.mstCost += edgeWeight;
1546+
}
1547+
Graph<T>::setUnion(&subset, set1, set2);
1548+
}
1549+
result.success = true;
1550+
return result;
1551+
}
1552+
14711553
template <typename T>
14721554
const std::vector<Node<T>> Graph<T>::breadth_first_search(const Node<T> &start) const
14731555
{

0 commit comments

Comments
 (0)