
Commit 96ed93d

yy462 and yuyang2S2023 authored
Monitor comm cost (#40)
* update the downloading path of cora dataset
* Add ogbn dataset
* Update data_process.py
* Fix issues detected by pre-commit
* Add setup_cluster.sh script and update README for Ray cluster setup
* Update readme for setup cluster
* add setup cluster doc
* update config paths and adjust Ray cluster setup
* adjust the structure of the folder and finish testing the code
* feat: add communication monitoring during initialization in NC, GC, LP
* fix: pre-commit formatting
* Moved use_cluster check inside Monitor class to cleanly separate time tracking from cluster-specific metrics.
* Fix some parts based on the feedback
* Fix if/else inside the monitor
* Update setup_cluster.sh and test all benchmarks
* add Theoretical Comm Cost
* Update pretrain comm cost for NC
* Add comm cost extraction and visualization scripts, update benchmarks and EKS configs
* update configs and figures
* delete redundant files
* Add PDF output for plots and update log extraction scripts
* Update figures
* Add accuracy figures and extract script

---------

Co-authored-by: Yu Yang <[email protected]>
1 parent 74292e3 commit 96ed93d

35 files changed: +2433 -1187 lines
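One commit-message bullet above moves the use_cluster check inside the Monitor class so callers no longer branch on it. Below is a minimal sketch of that separation. The method names and metric names are taken from the old benchmark code in the diff further down (pretrain_time_start, pretrain_time_end(30), train_time_start, train_time_end(30), and the Gauge names); the internals are illustrative, not the actual fedgraph implementation:

import time

from ray.util.metrics import Gauge


class Monitor:
    def __init__(self, use_cluster: bool = False) -> None:
        self.use_cluster = use_cluster
        self._start = 0.0
        # Cluster-specific metrics exist only when running on a cluster;
        # plain wall-clock tracking works everywhere.
        if self.use_cluster:
            self._pretrain_gauge = Gauge(
                "pretrain_time_cost", description="Latency of pretraining in ms."
            )
            self._train_gauge = Gauge(
                "train_time_cost", description="Latency of training in ms."
            )

    def pretrain_time_start(self) -> None:
        self._start = time.time()

    def pretrain_time_end(self, interval: int = 30) -> None:
        # `interval` only mirrors the old call signature pretrain_time_end(30);
        # what fedgraph does with it is not shown here.
        elapsed_ms = (time.time() - self._start) * 1000
        print(f"Pretrain took {elapsed_ms:.1f} ms")
        if self.use_cluster:
            self._pretrain_gauge.set(elapsed_ms)

    def train_time_start(self) -> None:
        self._start = time.time()

    def train_time_end(self, interval: int = 30) -> None:
        elapsed_ms = (time.time() - self._start) * 1000
        print(f"Train took {elapsed_ms:.1f} ms")
        if self.use_cluster:
            self._train_gauge.set(elapsed_ms)

With this shape, the benchmark scripts can call the same four methods whether or not a Ray cluster is in use, which is exactly the cleanup the old benchmark_LP.py needed: it created the two Gauges inline next to the Monitor calls.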

benchmark/benchmark_GC.py

Lines changed: 95 additions & 289 deletions
Large diffs are not rendered by default.

benchmark/benchmark_LP.py

Lines changed: 75 additions & 242 deletions
@@ -1,262 +1,95 @@
 """
-Federated Link Prediction Example
-================
+Federated Link Prediction Benchmark
+===================================
 
-In this tutorial, you will learn the basic workflow of
-Federated Link Prediction with a runnable example. This tutorial assumes that
-you have basic familiarity with PyTorch and PyTorch Geometric (PyG).
+Run benchmarks for various federated link prediction algorithms using a simplified approach.
 
-(Time estimate: 20 minutes)
+(Time estimate: 30 minutes)
 """
 
-import argparse
-import copy
-import datetime
 import os
-import random
-import sys
-from pathlib import Path
+import time
 
 import attridict
-import numpy as np
 import ray
 import torch
 import yaml
-from ray.util.metrics import Counter, Gauge, Histogram
 
-from fedgraph.federated_methods import LP_train_global_round
-from fedgraph.monitor_class import Monitor
-from fedgraph.server_class import Server_LP
-from fedgraph.trainer_class import Trainer_LP
-from fedgraph.utils_lp import *
+from fedgraph.federated_methods import run_fedgraph
 
+# Methods to benchmark
+methods = ["4D-FED-GNN+", "STFL", "StaticGNN", "FedLink"]
 
-def run(method, country_codes):
-    print(
-        f"Running experiment with: Dataset={'+'.join(country_codes)}, Number of Trainers=10, Distribution Type={method}, IID Beta=1.0, Number of Hops=1, Batch Size=-1"
-    )
-    # Determine the directory of the current script
-    current_dir = os.path.dirname(os.path.abspath(__file__))
-
-    # Append paths relative to the current script's directory
-    sys.path.append(os.path.join(current_dir, "../fedgraph"))
-    sys.path.append(os.path.join(current_dir, "../../"))
-    ray.init()
-
-    #######################################################################
-    # Load configuration and check arguments
-    # ------------
-    # Here we load the configuration file for the experiment.
-    # The configuration file contains the parameters for the experiment.
-    # The algorithm and dataset (represented by the country code) are specified by the user here.
-    # We also specify some prechecks to ensure the validity of the arguments.
-
-    config_file = os.path.join(current_dir, "configs/config_LP.yaml")
-    with open(config_file, "r") as file:
-        args = attridict(yaml.safe_load(file))
-    args.method = method
-    args.country_codes = country_codes
-    dataset_path = os.path.join(
-        os.path.dirname(os.path.abspath(__file__)), args.dataset_path
-    )
-    print(dataset_path)
-    global_file_path = os.path.join(dataset_path, "data_global.txt")
-    traveled_file_path = os.path.join(dataset_path, "traveled_users.txt")
-    print(f"traveled_file_path: {traveled_file_path}")
-    assert args.method in [
-        "STFL",
-        "StaticGNN",
-        "4D-FED-GNN+",
-        "FedLink",
-    ], "Invalid method."
-    assert all(
-        code in ["US", "BR", "ID", "TR", "JP"] for code in args.country_codes
-    ), "The country codes should be in 'US', 'BR', 'ID', 'TR', 'JP'"
-    if args.use_buffer:
-        assert args.buffer_size > 0, "The buffer size should be greater than 0."
-
-    #######################################################################
-    # Generate data
-    # ------------
-    # Here we generate the data for the experiment.
-    # If the data is already generated, we load the data from the file.
-    # Otherwise, we download the data from the website and generate the data.
-    # We also create the mappings and meta_data for the data.
-
-    check_data_files_existance(args.country_codes, dataset_path)
-
-    (
-        user_id_mapping,
-        item_id_mapping,
-    ) = get_global_user_item_mapping(  # get global user and item mapping
-        global_file_path=global_file_path
-    )
-
-    meta_data = (
-        ["user", "item"],
-        [("user", "select", "item"), ("item", "rev_select", "user")],
-    )  # set meta_data
-
-    #######################################################################
-    # Initialize server and trainers
-    # ------------
-    # Starting from this block, we formally begin the training process.
-    # If you want to run multiple experiments, you can wrap the following code in a loop.
-    # In this block, we initialize the server and trainers for the experiment.
+# Country code combinations to test
+country_codes_list = [["US"], ["US", "BR"], ["US", "BR", "ID", "TR", "JP"]]
 
-    number_of_clients = len(args.country_codes)
-    number_of_users, number_of_items = len(user_id_mapping.keys()), len(
-        item_id_mapping.keys()
-    )
-    num_cpus_per_client = 3
-    if args.device == "gpu":
-        device = torch.device("cuda")
-        print("gpu detected")
-        num_gpus_per_client = 1
-    else:
-        device = torch.device("cpu")
-        num_gpus_per_client = 0
-        print("gpu not detected")
+# Number of runs per configuration
+runs_per_config = 1
 
-    @ray.remote(
-        num_gpus=num_gpus_per_client,
-        num_cpus=num_cpus_per_client,
-        scheduling_strategy="SPREAD",
-    )
-    class Trainer(Trainer_LP):
-        def __init__(self, *args, **kwargs):  # type: ignore
-            super().__init__(*args, **kwargs)
+# Define additional required parameters that might be missing from YAML
+required_params = {
+    "fedgraph_task": "LP",
+    "num_cpus_per_trainer": 3,
+    "num_gpus_per_trainer": 1 if torch.cuda.is_available() else 0,
+    "use_cluster": True,
+    "gpu": torch.cuda.is_available(),
+    "ray_address": "auto",
+}
 
-    clients = [
-        Trainer.remote(  # type: ignore
-            i,
-            country_code=args.country_codes[i],
-            user_id_mapping=user_id_mapping,
-            item_id_mapping=item_id_mapping,
-            number_of_users=number_of_users,
-            number_of_items=number_of_items,
-            meta_data=meta_data,
-            hidden_channels=args.hidden_channels,
-            dataset_path=dataset_path,
+# Main benchmark loop
+for method in methods:
+    for country_codes in country_codes_list:
+        # Load the base configuration file
+        config_file = os.path.join(
+            os.path.dirname(__file__), "configs", "config_LP.yaml"
         )
-        for i in range(number_of_clients)
-    ]
-
-    server = Server_LP(  # the concrete information of users and items is not available in the server
-        number_of_users=number_of_users,
-        number_of_items=number_of_items,
-        meta_data=meta_data,
-        trainers=clients,
-    )
-    pretrain_time_costs_gauge = Gauge(
-        "pretrain_time_cost", description="Latencies of pretrain_time_costs in ms."
-    )
-    train_time_costs_gauge = Gauge(
-        "train_time_cost", description="Latencies of train_time_costs in ms."
-    )
-
-    #######################################################################
-    # Training preparation
-    # ------------
-    # Here we prepare the training for the experiment.
-    # (1) We brodcast the initial model parameter to all clients.
-    # (2) We determine the start and end time of the conditional information.
-    # (3) We open the file to record the results if the user wants to record the results.
-
-    """Broadcast the global model parameter to all clients"""
-    monitor = Monitor()
-    monitor.pretrain_time_start()
-    global_model_parameter = (
-        server.get_model_parameter()
-    )  # fetch the global model parameter
-    for i in range(number_of_clients):
-        # broadcast the global model parameter to all clients
-        clients[i].set_model_parameter.remote(global_model_parameter)
-
-    """Determine the start and end time of the conditional information"""
-    (
-        start_time,
-        end_time,
-        prediction_days,
-        start_time_float_format,
-        end_time_float_format,
-    ) = get_start_end_time(online_learning=args.online_learning, method=args.method)
-
-    if not args.record_results:
-        result_writer = None
-        time_writer = None
-    else:
-        file_name = f"{args.method}_buffer_{args.use_buffer}_{args.buffer_size}_online_{args.online_learning}.txt"
-        result_writer = open(file_name, "a+")
-        time_writer = open("train_time_" + file_name, "a+")
-
-    monitor.pretrain_time_end(30)
-    monitor.train_time_start()
-
-    #######################################################################
-    # Train the model
-    # ------------
-    # Here we train the model for the experiment.
-    # For each prediction day, we train the model for each client.
-    # We also record the results if the user wants to record the results.
-    for day in range(prediction_days):  # make predictions for each day
-        # get the train and test data for each client at the current time step
-        for i in range(number_of_clients):
-            clients[i].get_train_test_data_at_current_time_step.remote(
-                start_time_float_format,
-                end_time_float_format,
-                use_buffer=args.use_buffer,
-                buffer_size=args.buffer_size,
-            )
-            clients[i].calculate_traveled_user_edge_indices.remote(
-                file_path=traveled_file_path
+        with open(config_file, "r") as file:
+            config = attridict(yaml.safe_load(file))
+
+        # Update the configuration with specific parameters for this run
+        config.method = method
+        config.country_codes = country_codes
+
+        # Add required parameters that might be missing
+        for param, value in required_params.items():
+            if not hasattr(config, param):
+                setattr(config, param, value)
+
+        # Set dataset path
+        if not hasattr(config, "dataset_path") or not config.dataset_path:
+            config.dataset_path = os.path.join(
+                os.path.dirname(os.path.abspath(__file__)), "data", "LPDataset"
             )
 
-        if args.online_learning:
-            print(f"start training for day {day + 1}")
-        else:
-            print(f"start training")
-        for iteration in range(args.global_rounds):
-            # each client train on local graph
-            print(iteration)
-
-            current_loss = LP_train_global_round(
-                server=server,
-                local_steps=args.local_steps,
-                use_buffer=args.use_buffer,
-                method=args.method,
-                online_learning=args.online_learning,
-                prediction_day=day,
-                curr_iteration=iteration,
-                global_rounds=args.global_rounds,
-                record_results=args.record_results,
-                result_writer=result_writer,
-                time_writer=time_writer,
-            )
-
-            if current_loss >= 0.3:
-                print("training is not complete")
-
-        # go to next day
-        (
-            start_time,
-            end_time,
-            start_time_float_format,
-            end_time_float_format,
-        ) = to_next_day(start_time=start_time, end_time=end_time, method=args.method)
-    monitor.train_time_end(30)
-    if result_writer is not None and time_writer is not None:
-        result_writer.close()
-        time_writer.close()
-
-    print("The whole process has ended")
-    ray.shutdown()
-
-
-methods = ["4D-FED-GNN+", "STFL", "StaticGNN", "FedLink"]
-country_codes_list = [["US"], ["US", "BR"], ["US", "BR", "ID", "TR", "JP"]]
-
-for method in methods:
-    for country_codes in country_codes_list:
-        print(f"Running method {method} with country codes {country_codes}")
-        run(method, country_codes)
+        # Run multiple times for statistical significance
+        for i in range(runs_per_config):
+            print(f"\n{'-'*80}")
+            print(f"Running experiment {i+1}/{runs_per_config}:")
+            print(f"Method: {method}, Countries: {', '.join(country_codes)}")
+            print(f"{'-'*80}\n")
+
+            # To ensure each run uses a fresh configuration object
+            run_config = attridict({})
+            for key, value in config.items():
+                run_config[key] = value
+
+            # Run the federated learning process with clean Ray environment
+            try:
+                # Make sure Ray is shut down from any previous runs
+                if ray.is_initialized():
+                    ray.shutdown()
+
+                # Run the experiment
+                run_fedgraph(run_config)
+            except Exception as e:
+                print(f"Error running experiment: {e}")
+                print(f"Configuration: {run_config}")
+            finally:
+                # Always ensure Ray is shut down before the next experiment
+                if ray.is_initialized():
+                    ray.shutdown()

+            # Add a short delay between runs
+            time.sleep(5)
 
+print("Benchmark completed.")
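Among the commit-message bullets is "add Theoretical Comm Cost". As a back-of-envelope illustration of what such an estimate involves for FedAvg-style training, where every trainer downloads and then uploads the full model each round, here is a hypothetical sketch (the function name and all numbers are placeholders, not fedgraph's actual formula):

def theoretical_comm_cost_mb(
    num_params: int,
    num_trainers: int,
    num_rounds: int,
    bytes_per_param: int = 4,  # float32 weights
) -> float:
    # Each round, every trainer receives and then returns the full model.
    per_round_bytes = 2 * num_trainers * num_params * bytes_per_param
    return num_rounds * per_round_bytes / 1e6


# e.g. a 1M-parameter model, 5 trainers (one per country code), 100 rounds:
# 2 * 5 * 1_000_000 * 4 * 100 / 1e6 = 4000 MB
print(theoretical_comm_cost_mb(1_000_000, 5, 100))

Comparing such a theoretical figure against the communication volume the Monitor records is one way to sanity-check the measured comm cost plots this commit adds.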

0 commit comments