Skip to content

Commit e95bdc8

Browse files
authored
Prim's algorithm for finding minimum spanning tree (#89)
* add graph slicing algorithm * add graph slicing test cases * update readme * improve test coverage * add prim's algorithm * prim test & result struct * cmake & new error message
1 parent 88f6f71 commit e95bdc8

File tree

5 files changed

+306
-4
lines changed

5 files changed

+306
-4
lines changed

CMakeLists.txt

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -51,13 +51,14 @@ add_executable(test_exe test/main.cpp
5151
test/GraphTest.cpp
5252
test/DijkstraTest.cpp
5353
test/BellmanFordTest.cpp
54+
test/FWTest.cpp
55+
test/PrimTest.cpp
5456
test/BFSTest.cpp
5557
test/DFSTest.cpp
5658
test/CycleCheckTest.cpp
5759
test/RWOutputTest.cpp
5860
test/PartitionTest.cpp
5961
test/DialTest.cpp
60-
test/FWTest.cpp
6162
test/GraphSlicingTest.cpp
6263
)
6364
target_include_directories(test_exe PUBLIC

include/Graph/Graph.hpp

Lines changed: 113 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -188,6 +188,13 @@ namespace CXXGRAPH
188188
*/
189189
virtual const FWResult floydWarshall() const;
190190
/**
191+
* @brief Function runs the prim algorithm and returns the minimum spanning tree
192+
* if the graph is undirected.
193+
* Note: No Thread Safe
194+
* @return a vector containing id of nodes in minimum spanning tree.
195+
*/
196+
virtual const PrimResult prim() const;
197+
/**
191198
* \brief
192199
* Function performs the breadth first search algorithm over the graph
193200
* Note: No Thread Safe
@@ -241,6 +248,15 @@ namespace CXXGRAPH
241248

242249
/**
243250
* \brief
251+
* This function checks if a graph is undirected
252+
* Note: No Thread Safe
253+
*
254+
* @return true if the graph is undirected, else false.
255+
*/
256+
virtual bool isUndirectedGraph() const;
257+
258+
/**
259+
* \brief
244260
* This function performs Graph Slicing based on connectivity
245261
*
246262
* Mathematical definition of the problem:
@@ -1099,9 +1115,6 @@ namespace CXXGRAPH
10991115
result.errorMessage = "";
11001116
std::map<std::pair<unsigned long, unsigned long>, double> pairwise_dist;
11011117
auto nodeSet = getNodeSet();
1102-
const AdjacencyMatrix<T> adj = getAdjMatrix();
1103-
// n denotes the number of vertices in graph
1104-
auto n = nodeSet.size();
11051118
// create a pairwise distance matrix with distance node distances
11061119
// set to inf. Distance of node to itself is set as 0.
11071120
for (auto elem1 : nodeSet)
@@ -1174,6 +1187,87 @@ namespace CXXGRAPH
11741187
return result;
11751188
}
11761189

1190+
template <typename T>
1191+
const PrimResult Graph<T>::prim() const
1192+
{
1193+
PrimResult result;
1194+
result.success = false;
1195+
result.errorMessage = "";
1196+
result.mstCost = INF_DOUBLE;
1197+
if (!isUndirectedGraph())
1198+
{
1199+
result.errorMessage = ERR_DIR_GRAPH;
1200+
return result;
1201+
}
1202+
auto nodeSet = getNodeSet();
1203+
auto n = nodeSet.size();
1204+
const AdjacencyMatrix<T> adj = getAdjMatrix();
1205+
1206+
// setting all the distances initially to INF_DOUBLE
1207+
std::map<const Node<T> *, double> dist;
1208+
for (auto elem : adj)
1209+
{
1210+
dist[elem.first] = INF_DOUBLE;
1211+
}
1212+
1213+
// creating a min heap using priority queue
1214+
// first element of pair contains the distance
1215+
// second element of pair contains the vertex
1216+
std::priority_queue<std::pair<double, const Node<T> *>, std::vector<std::pair<double, const Node<T> *>>,
1217+
std::greater<std::pair<double, const Node<T> *>>>
1218+
pq;
1219+
1220+
// pushing the source vertex 's' with 0 distance in min heap
1221+
auto source = nodeSet.front();
1222+
pq.push(std::make_pair(0.0, source));
1223+
// initialize cost and start node of mst
1224+
result.result.push_back(source->getId());
1225+
result.mstCost = 0;
1226+
while (!pq.empty())
1227+
{
1228+
// second element of pair denotes the node / vertex
1229+
const Node<T> *currentNode = pq.top().second;
1230+
auto nodeId = currentNode->getId();
1231+
if (std::find(result.result.begin(), result.result.end(), nodeId) == result.result.end())
1232+
{
1233+
result.result.push_back(nodeId);
1234+
result.mstCost += pq.top().first;
1235+
}
1236+
1237+
pq.pop();
1238+
// for all the reachable vertex from the currently exploring vertex
1239+
// we will try to minimize the distance
1240+
if (adj.find(currentNode) != adj.end())
1241+
{
1242+
for (std::pair<const Node<T> *, const Edge<T> *> elem : adj.at(currentNode))
1243+
{
1244+
// minimizing distances
1245+
if (elem.second->isWeighted().has_value() && elem.second->isWeighted().value())
1246+
{
1247+
const UndirectedWeightedEdge<T> *udw_edge = dynamic_cast<const UndirectedWeightedEdge<T> *>(elem.second);
1248+
if (
1249+
(udw_edge->getWeight() < dist[elem.first]) &&
1250+
(std::find(result.result.begin(), result.result.end(), elem.first->getId()) == result.result.end())
1251+
)
1252+
1253+
{
1254+
dist[elem.first] = udw_edge->getWeight();
1255+
pq.push(std::make_pair(dist[elem.first], elem.first));
1256+
}
1257+
}
1258+
else
1259+
{
1260+
// No Weighted Edge
1261+
result.errorMessage = ERR_NO_WEIGHTED_EDGE;
1262+
return result;
1263+
}
1264+
}
1265+
}
1266+
}
1267+
result.success = true;
1268+
return result;
1269+
}
1270+
11771271
template <typename T>
11781272
const std::vector<Node<T>> Graph<T>::breadth_first_search(const Node<T> &start) const
11791273
{
@@ -1421,6 +1515,22 @@ namespace CXXGRAPH
14211515
return true;
14221516
}
14231517

1518+
template <typename T>
1519+
bool Graph<T>::isUndirectedGraph() const
1520+
{
1521+
auto edgeSet = this->getEdgeSet();
1522+
for (auto edge : edgeSet)
1523+
{
1524+
if ((edge->isDirected().has_value() && edge->isDirected().value()))
1525+
{
1526+
//Found Directed Edge
1527+
return false;
1528+
}
1529+
}
1530+
//No Directed Edge
1531+
return true;
1532+
}
1533+
14241534
template <typename T>
14251535
const DialResult Graph<T>::dial(const Node<T> &source, int maxWeight) const
14261536
{

include/Utility/ConstString.hpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@ namespace CXXGRAPH
2626
{
2727
//STRING ERROR CONST EXPRESSION
2828
constexpr char ERR_NO_DIR_OR_UNDIR_EDGE[] = "Edge are neither Directed neither Undirected";
29+
constexpr char ERR_DIR_GRAPH[] = "Graph is directed";
2930
constexpr char ERR_NO_WEIGHTED_EDGE[] = "Edge are not Weighted";
3031
constexpr char ERR_TARGET_NODE_NOT_REACHABLE[] = "Target Node not Reachable";
3132
constexpr char ERR_TARGET_NODE_NOT_IN_GRAPH[] = "Target Node not inside Graph";

include/Utility/Typedef.hpp

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -86,6 +86,16 @@ namespace CXXGRAPH
8686
};
8787
typedef FWResult_struct FWResult;
8888

89+
/// Struct that contains the information about Prim Algorithm results
90+
struct PrimResult_struct
91+
{
92+
bool success; // TRUE if the function does not return error, FALSE otherwise
93+
std::string errorMessage; //message of error
94+
std::vector<unsigned long> result; // MST
95+
double mstCost; // MST
96+
};
97+
typedef PrimResult_struct PrimResult;
98+
8999
/// Struct that contains the information about Dijsktra's Algorithm results
90100
struct DialResult_struct
91101
{

test/PrimTest.cpp

Lines changed: 180 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,180 @@
1+
#include "gtest/gtest.h"
2+
#include "CXXGraph.hpp"
3+
4+
// minimum spanning tree can differ so instead of checking
5+
// the exact order of elements, we can check some properties
6+
// like the length & cost of mst which must remain the same
7+
8+
// example taken from
9+
// https://www.geeksforgeeks.org/prims-mst-for-adjacency-list-representation-greedy-algo-6/TEST(FWTest, test_1)
10+
TEST(PrimTest, test_1)
11+
{
12+
CXXGRAPH::Node<int> node0(0, 0);
13+
CXXGRAPH::Node<int> node1(1, 1);
14+
CXXGRAPH::Node<int> node2(2, 2);
15+
CXXGRAPH::Node<int> node3(3, 3);
16+
CXXGRAPH::Node<int> node4(4, 4);
17+
CXXGRAPH::Node<int> node5(5, 5);
18+
CXXGRAPH::Node<int> node6(6, 6);
19+
CXXGRAPH::Node<int> node7(7, 7);
20+
CXXGRAPH::Node<int> node8(8, 8);
21+
22+
CXXGRAPH::UndirectedWeightedEdge<int> edge1(1, node0, node1, 4);
23+
CXXGRAPH::UndirectedWeightedEdge<int> edge2(2, node0, node7, 8);
24+
CXXGRAPH::UndirectedWeightedEdge<int> edge3(3, node1, node7, 11);
25+
CXXGRAPH::UndirectedWeightedEdge<int> edge4(3, node1, node2, 8);
26+
CXXGRAPH::UndirectedWeightedEdge<int> edge5(4, node7, node8, 7);
27+
CXXGRAPH::UndirectedWeightedEdge<int> edge6(3, node7, node6, 1);
28+
CXXGRAPH::UndirectedWeightedEdge<int> edge7(3, node8, node2, 2);
29+
CXXGRAPH::UndirectedWeightedEdge<int> edge8(3, node8, node6, 6);
30+
CXXGRAPH::UndirectedWeightedEdge<int> edge9(3, node2, node5, 4);
31+
CXXGRAPH::UndirectedWeightedEdge<int> edge10(3, node2, node3, 7);
32+
CXXGRAPH::UndirectedWeightedEdge<int> edge11(3, node6, node5, 2);
33+
CXXGRAPH::UndirectedWeightedEdge<int> edge12(3, node3, node4, 9);
34+
CXXGRAPH::UndirectedWeightedEdge<int> edge13(3, node3, node5, 14);
35+
CXXGRAPH::UndirectedWeightedEdge<int> edge14(3, node5, node4, 10);
36+
37+
std::list<const CXXGRAPH::Edge<int> *> edgeSet;
38+
edgeSet.push_back(&edge1);
39+
edgeSet.push_back(&edge2);
40+
edgeSet.push_back(&edge3);
41+
edgeSet.push_back(&edge4);
42+
edgeSet.push_back(&edge5);
43+
edgeSet.push_back(&edge6);
44+
edgeSet.push_back(&edge7);
45+
edgeSet.push_back(&edge8);
46+
edgeSet.push_back(&edge9);
47+
edgeSet.push_back(&edge10);
48+
edgeSet.push_back(&edge11);
49+
edgeSet.push_back(&edge12);
50+
edgeSet.push_back(&edge13);
51+
edgeSet.push_back(&edge14);
52+
53+
CXXGRAPH::Graph<int> graph(edgeSet);
54+
CXXGRAPH::PrimResult res = graph.prim();
55+
56+
ASSERT_TRUE(res.success);
57+
ASSERT_EQ(res.result.size(), graph.getNodeSet().size());
58+
ASSERT_EQ(res.mstCost, 37);
59+
ASSERT_EQ(res.errorMessage, "");
60+
}
61+
62+
63+
// example taken from
64+
// https://www.gatevidyalay.com/prims-algorithm-prim-algorithm-example/
65+
TEST(PrimTest, test_2)
66+
{
67+
CXXGRAPH::Node<int> node1(1, 1);
68+
CXXGRAPH::Node<int> node2(2, 2);
69+
CXXGRAPH::Node<int> node3(3, 3);
70+
CXXGRAPH::Node<int> node4(4, 4);
71+
CXXGRAPH::Node<int> node5(5, 5);
72+
CXXGRAPH::Node<int> node6(6, 6);
73+
CXXGRAPH::Node<int> node7(7, 7);
74+
75+
CXXGRAPH::UndirectedWeightedEdge<int> edge1(1, node1, node2, 28);
76+
CXXGRAPH::UndirectedWeightedEdge<int> edge2(2, node1, node6, 10);
77+
CXXGRAPH::UndirectedWeightedEdge<int> edge3(3, node2, node7, 14);
78+
CXXGRAPH::UndirectedWeightedEdge<int> edge4(4, node2, node3, 16);
79+
CXXGRAPH::UndirectedWeightedEdge<int> edge5(5, node6, node5, 25);
80+
CXXGRAPH::UndirectedWeightedEdge<int> edge6(6, node7, node5, 24);
81+
CXXGRAPH::UndirectedWeightedEdge<int> edge7(7, node7, node4, 18);
82+
CXXGRAPH::UndirectedWeightedEdge<int> edge8(8, node5, node4, 22);
83+
CXXGRAPH::UndirectedWeightedEdge<int> edge9(9, node4, node3, 12);
84+
85+
std::list<const CXXGRAPH::Edge<int> *> edgeSet;
86+
edgeSet.push_back(&edge1);
87+
edgeSet.push_back(&edge2);
88+
edgeSet.push_back(&edge3);
89+
edgeSet.push_back(&edge4);
90+
edgeSet.push_back(&edge5);
91+
edgeSet.push_back(&edge6);
92+
edgeSet.push_back(&edge7);
93+
edgeSet.push_back(&edge8);
94+
edgeSet.push_back(&edge9);
95+
96+
CXXGRAPH::Graph<int> graph(edgeSet);
97+
CXXGRAPH::PrimResult res = graph.prim();
98+
99+
// double values[4][4] = {{0, -1, -2, 0}, {4, 0, 2, 4}, {5, 1, 0, 2}, {3, -1, 1, 0}};
100+
// unsigned long mst[] = {}
101+
ASSERT_TRUE(res.success);
102+
ASSERT_EQ(res.result.size(), graph.getNodeSet().size());
103+
ASSERT_EQ(res.mstCost, 99);
104+
ASSERT_EQ(res.errorMessage, "");
105+
106+
}
107+
108+
109+
// example taken from
110+
// https://www.gatevidyalay.com/prims-algorithm-prim-algorithm-example/
111+
TEST(PrimTest, test_3)
112+
{
113+
CXXGRAPH::Node<int> node1(1, 1);
114+
CXXGRAPH::Node<int> node2(2, 2);
115+
CXXGRAPH::Node<int> node3(3, 3);
116+
CXXGRAPH::Node<int> node4(4, 4);
117+
CXXGRAPH::Node<int> node5(5, 5);
118+
CXXGRAPH::Node<int> node6(6, 6);
119+
CXXGRAPH::Node<int> node7(7, 7);
120+
121+
CXXGRAPH::UndirectedWeightedEdge<int> edge1(1, node1, node2, 1);
122+
CXXGRAPH::UndirectedWeightedEdge<int> edge2(2, node1, node3, 5);
123+
CXXGRAPH::UndirectedWeightedEdge<int> edge3(3, node2, node5, 7);
124+
CXXGRAPH::UndirectedWeightedEdge<int> edge4(4, node2, node4, 8);
125+
CXXGRAPH::UndirectedWeightedEdge<int> edge5(5, node2, node3, 4);
126+
CXXGRAPH::UndirectedWeightedEdge<int> edge6(6, node3, node4, 6);
127+
CXXGRAPH::UndirectedWeightedEdge<int> edge7(7, node3, node6, 2);
128+
CXXGRAPH::UndirectedWeightedEdge<int> edge8(8, node4, node6, 9);
129+
CXXGRAPH::UndirectedWeightedEdge<int> edge9(9, node4, node5, 11);
130+
CXXGRAPH::UndirectedWeightedEdge<int> edge10(9, node5, node7, 10);
131+
CXXGRAPH::UndirectedWeightedEdge<int> edge11(9, node5, node6, 3);
132+
CXXGRAPH::UndirectedWeightedEdge<int> edge12(9, node6, node7, 12);
133+
134+
std::list<const CXXGRAPH::Edge<int> *> edgeSet;
135+
edgeSet.push_back(&edge1);
136+
edgeSet.push_back(&edge2);
137+
edgeSet.push_back(&edge3);
138+
edgeSet.push_back(&edge4);
139+
edgeSet.push_back(&edge5);
140+
edgeSet.push_back(&edge6);
141+
edgeSet.push_back(&edge7);
142+
edgeSet.push_back(&edge8);
143+
edgeSet.push_back(&edge9);
144+
edgeSet.push_back(&edge10);
145+
edgeSet.push_back(&edge11);
146+
edgeSet.push_back(&edge12);
147+
148+
CXXGRAPH::Graph<int> graph(edgeSet);
149+
CXXGRAPH::PrimResult res = graph.prim();
150+
151+
ASSERT_TRUE(res.success);
152+
ASSERT_EQ(res.result.size(), graph.getNodeSet().size());
153+
ASSERT_EQ(res.mstCost, 26);
154+
ASSERT_EQ(res.errorMessage, "");
155+
}
156+
157+
// test for directed and no weighted edge errors
158+
TEST(PrimTest, test_4)
159+
{
160+
CXXGRAPH::Node<int> node1(1, 1);
161+
CXXGRAPH::Node<int> node2(2, 2);
162+
CXXGRAPH::Node<int> node3(3, 3);
163+
CXXGRAPH::DirectedWeightedEdge<int> edge1(1, node1, node2, 1);
164+
CXXGRAPH::DirectedWeightedEdge<int> edge2(2, node2, node3, 1);
165+
std::list<const CXXGRAPH::Edge<int> *> edgeSet;
166+
edgeSet.push_back(&edge1);
167+
edgeSet.push_back(&edge2);
168+
CXXGRAPH::Graph<int> graph(edgeSet);
169+
CXXGRAPH::PrimResult res = graph.prim();
170+
ASSERT_FALSE(res.success);
171+
ASSERT_EQ(res.errorMessage, CXXGRAPH::ERR_DIR_GRAPH);
172+
173+
CXXGRAPH::UndirectedEdge<int> edge3(3, node1, node2);
174+
std::list<const CXXGRAPH::Edge<int> *> edgeSet1;
175+
edgeSet1.push_back(&edge3);
176+
CXXGRAPH::Graph<int> graph1(edgeSet1);
177+
res = graph1.prim();
178+
ASSERT_FALSE(res.success);
179+
ASSERT_EQ(res.errorMessage, CXXGRAPH::ERR_NO_WEIGHTED_EDGE);
180+
}

0 commit comments

Comments
 (0)