Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
34 commits
Select commit Hold shift + click to select a range
88a2c14
add GIN to gnn_models
Lijunhaohz Feb 23, 2024
4a54fa2
Merge branch 'main' into gc
Lijunhaohz Feb 23, 2024
baf400c
add utils_gc.py
Lijunhaohz Feb 23, 2024
a01d078
add Server_GC in server_class.py
Lijunhaohz Feb 23, 2024
cc048b6
add GC Trainer
Lijunhaohz Feb 23, 2024
bb4ab36
add GC train functions
Lijunhaohz Feb 23, 2024
30147e1
add entrance functions and data process functions for GC
Lijunhaohz Feb 23, 2024
86d3e42
add short example for GC
Lijunhaohz Feb 23, 2024
0d63f4e
revise format
Lijunhaohz Feb 23, 2024
3a621eb
add example for GC
Lijunhaohz Feb 23, 2024
01f0634
pre-commit formatted
Lijunhaohz Feb 23, 2024
ad1e588
solve mypy errors
Lijunhaohz Feb 23, 2024
ed0813a
resolve attribute errors
Lijunhaohz Feb 23, 2024
5fb35e9
Finish GC code examples
Lijunhaohz Feb 24, 2024
5f244b5
reformat gc files
Lijunhaohz Feb 28, 2024
2ba41db
add .rst files for gc
Lijunhaohz Feb 28, 2024
41b84fb
remove requirement.txt for gc
Lijunhaohz Feb 28, 2024
705adef
add requirements_gc.txt
Lijunhaohz Feb 28, 2024
c593d9c
Merge branch 'main' into gc
Lijunhaohz Feb 29, 2024
3f4f4a3
remove requirements_gc.txt
Lijunhaohz Feb 29, 2024
9c539fe
move gc methods into federated_methods.py
Lijunhaohz Feb 29, 2024
a885061
add dataloader interface for gc
Lijunhaohz Feb 29, 2024
e1dfce3
refactor codes
Lijunhaohz Feb 29, 2024
ea0134d
add interface for the model
Lijunhaohz Mar 8, 2024
19ca54e
revise comments for gc
Lijunhaohz Mar 8, 2024
6e6fbed
revise comments for gc
Lijunhaohz Mar 8, 2024
d311453
update models dir name
Lijunhaohz Mar 15, 2024
d36e58e
reformat gc files
Lijunhaohz Mar 15, 2024
56a73bb
standardize name "trainer"
Lijunhaohz Mar 21, 2024
260b521
revise comments for gc
Lijunhaohz Mar 22, 2024
be21c70
merge into one base model for gc
Lijunhaohz Mar 22, 2024
930735a
support multiple datasets
Lijunhaohz Mar 26, 2024
835eb47
reformat files
Lijunhaohz Mar 26, 2024
8c11013
support repetitive experiment
Lijunhaohz Mar 26, 2024
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 1 addition & 3 deletions docs/examples/configs/config_gc_FedAvg.yaml
Original file line number Diff line number Diff line change
@@ -1,7 +1,5 @@
model: 'FedAvg'
device: 'cpu'
num_repeat: 1
repeat: None
num_rounds: 200
local_epoch: 1
lr: 0.001
Expand All @@ -14,5 +12,5 @@ seed: 10
datapath: './data'
outbase: './outputs'
convert_x: False
num_clients: 10
num_trainers: 10
overlap: False
4 changes: 1 addition & 3 deletions docs/examples/configs/config_gc_FedProx.yaml
Original file line number Diff line number Diff line change
@@ -1,7 +1,5 @@
model: 'FedProx'
device: 'cpu'
num_repeat: 1
repeat: None
num_rounds: 200
local_epoch: 1
lr: 0.001
Expand All @@ -14,6 +12,6 @@ seed: 10
datapath: './data'
outbase: './outputs'
convert_x: False
num_clients: 10
num_trainers: 10
overlap: False
mu: 0.01
4 changes: 1 addition & 3 deletions docs/examples/configs/config_gc_GCFL+.yaml
Original file line number Diff line number Diff line change
@@ -1,7 +1,5 @@
model: 'GCFL+'
device: 'cpu'
num_repeat: 1
repeat: None
num_rounds: 200
local_epoch: 1
lr: 0.001
Expand All @@ -14,7 +12,7 @@ seed: 10
datapath: './data'
outbase: './outputs'
convert_x: False
num_clients: 10
num_trainers: 10
overlap: False
standardize: False
seq_length: 5
Expand Down
4 changes: 1 addition & 3 deletions docs/examples/configs/config_gc_GCFL+dWs.yaml
Original file line number Diff line number Diff line change
@@ -1,7 +1,5 @@
model: 'GCFL+dWs'
device: 'cpu'
num_repeat: 1
repeat: None
num_rounds: 200
local_epoch: 1
lr: 0.001
Expand All @@ -14,7 +12,7 @@ seed: 10
datapath: './data'
outbase: './outputs'
convert_x: False
num_clients: 10
num_trainers: 10
overlap: False
standardize: False
seq_length: 5
Expand Down
4 changes: 1 addition & 3 deletions docs/examples/configs/config_gc_GCFL.yaml
Original file line number Diff line number Diff line change
@@ -1,7 +1,5 @@
model: 'GCFL'
device: 'cpu'
num_repeat: 1
repeat: None
num_rounds: 200
local_epoch: 1
lr: 0.001
Expand All @@ -14,7 +12,7 @@ seed: 10
datapath: './data'
outbase: './outputs'
convert_x: False
num_clients: 10
num_trainers: 10
overlap: False
standardize: False
seq_length: 5
Expand Down
4 changes: 1 addition & 3 deletions docs/examples/configs/config_gc_SelfTrain.yaml
Original file line number Diff line number Diff line change
@@ -1,7 +1,5 @@
model: 'SelfTrain'
device: 'cpu'
num_repeat: 1
repeat: None
local_epoch: 1
lr: 0.001
weight_decay: 0.0005
Expand All @@ -13,5 +11,5 @@ seed: 10
datapath: './data'
outbase: './outputs'
convert_x: False
num_clients: 10
num_trainers: 10
overlap: False
129 changes: 74 additions & 55 deletions docs/examples/example_gc.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,29 +21,39 @@
import yaml

sys.path.append("../fedgraph")
from fedgraph.data_process_gc import *
from fedgraph.train_func import *
from fedgraph.utils_gc import *
from src.data_process_gc import *
from src.federated_methods import (
run_GC_fedavg,
run_GC_fedprox,
run_GC_gcfl,
run_GC_gcfl_plus,
run_GC_gcfl_plus_dWs,
run_GC_selftrain,
)
from src.gnn_models import GIN
from src.utils_gc import *

#######################################################################
# Load configuration
# ------------
# Here we load the configuration file for the experiment.
# The configuration file contains the parameters for the experiment.
# The model and dataset are specified by the user here. And the configuration
# The algorithm and dataset are specified by the user here, and the configuration
# file is stored in the `docs/examples/configs` directory.
# Once specified the model, the corresponding configuration file will be loaded.
# Once the algorithm is specified, the corresponding configuration file will be loaded.
# Feel free to modify the configuration file to suit your needs.

model = "GCFL"
algorithm = "GCFL"
dataset = "PROTEINS"
dataset_group = "biochem"
multiple_datasets = True # if True, load multiple datasets (dataset_group) instead of a single dataset (dataset)
save_files = False # if True, save the statistics and prediction results into files

config_file = f"docs/examples/configs/config_gc_{model}.yaml"
config_file = f"docs/examples/configs/config_gc_{algorithm}.yaml"
with open(config_file, "r") as file:
config = yaml.safe_load(file)

config["data_group"] = dataset
config["data_group"] = dataset_group if multiple_datasets else dataset

parser = argparse.ArgumentParser()
args = parser.parse_args()
Expand Down Expand Up @@ -72,28 +82,28 @@
# Set output directory
# ------------
# Here we set the output directory for the results.
# The output consists of the statistics of the data on clients and the
# The output consists of the statistics of the data on trainers and the
# accuracy of the model on the test set.

# outdir_base = os.path.join(args.outbase, f'seqLen{args.seq_length}')

if save_files:
outdir_base = args.outbase + "/" + f"{args.model}"
outdir = os.path.join(outdir_base, f"oneDS-nonOverlap")
if args.model in ["SelfTrain"]:
if algorithm in ["SelfTrain"]:
outdir = os.path.join(outdir, f"{args.data_group}")
elif args.model in ["FedAvg", "FedProx"]:
outdir = os.path.join(outdir, f"{args.data_group}-{args.num_clients}clients")
elif args.model in ["GCFL"]:
elif algorithm in ["FedAvg", "FedProx"]:
outdir = os.path.join(outdir, f"{args.data_group}-{args.num_trainers}trainers")
elif algorithm in ["GCFL"]:
outdir = os.path.join(
outdir,
f"{args.data_group}-{args.num_clients}clients",
f"{args.data_group}-{args.num_trainers}trainers",
f"eps_{args.epsilon1}_{args.epsilon2}",
)
elif args.model in ["GCFL+", "GCFL+dWs"]:
elif algorithm in ["GCFL+", "GCFL+dWs"]:
outdir = os.path.join(
outdir,
f"{args.data_group}-{args.num_clients}clients",
f"{args.data_group}-{args.num_trainers}trainers",
f"eps_{args.epsilon1}_{args.epsilon2}",
f"seqLen{args.seq_length}",
)
Expand All @@ -108,20 +118,29 @@
# Here we prepare the data for the experiment.
# The data is split into training and test sets, and then the training set
# is further split into training and validation sets.
# The statistics of the data on clients are also computed and saved.
# The statistics of the data on trainers are also computed and saved.

""" using original features """
print("Preparing data (original features) ...")

splited_data, df_stats = load_single_dataset(
args.datapath,
args.data_group,
num_client=args.num_clients,
batch_size=args.batch_size,
convert_x=args.convert_x,
seed=seed_split_data,
overlap=args.overlap,
)
if multiple_datasets:
splited_data, df_stats = load_multiple_datasets(
datapath=args.datapath,
dataset_group=args.data_group,
batch_size=args.batch_size,
convert_x=args.convert_x,
seed=seed_split_data,
)
else:
splited_data, df_stats = load_single_dataset(
args.datapath,
args.data_group,
num_trainer=args.num_trainers,
batch_size=args.batch_size,
convert_x=args.convert_x,
seed=seed_split_data,
overlap=args.overlap,
)
print("Data prepared.")

if save_files:
Expand All @@ -131,16 +150,20 @@


#######################################################################
# Setup server and clients (trainers)
# Setup server and trainers
# ------------
# Here we set up the server and clients (trainers) for the experiment.
# The server is responsible for federated aggregation (e.g., FedAvg) without
# knowing the local trainer data.
# The clients (trainers) are responsible for local training and testing.

init_clients, _ = setup_clients(splited_data, args)
init_server = setup_server(args)
clients = copy.deepcopy(init_clients)
# Here we set up the server and trainers for the experiment.
# The server is responsible for federated aggregation (e.g., FedAvg) without knowing the local trainer data.
# The trainers are responsible for local training and testing.
# Before setting them up, the user has to specify the base model for federated learning, which applies to both the server and the trainers.
# The default model is `GIN` (Graph Isomorphism Network) for graph classification.
# The user can also use other models, but the customized model should be compatible.
# That is, `base_model` must have all the required methods and attributes of the default `GIN`.
# For the detailed expected format of the model, please refer to `fedgraph/gnn_models.py`.
base_model = GIN
init_trainers, _ = setup_trainers(splited_data, base_model, args)
init_server = setup_server(base_model, args)
trainers = copy.deepcopy(init_trainers)
server = copy.deepcopy(init_server)

print("\nDone setting up devices.")
Expand All @@ -150,47 +173,44 @@
# Federated Training for Graph Classification
# ------------
# Here we run the federated training for graph classification.
# The server starts training of all clients and aggregates the parameters.
# The server starts training of all trainers and aggregates the parameters.
# The output consists of the accuracy of the model on the test set.

print(f"Running {args.model} ...")
if args.model == "SelfTrain":
print(f"Running {algorithm} ...")
if algorithm == "SelfTrain":
output = run_GC_selftrain(
clients=clients, server=server, local_epoch=args.local_epoch
trainers=trainers, server=server, local_epoch=args.local_epoch
)

elif args.model == "FedAvg":
elif algorithm == "FedAvg":
output = run_GC_fedavg(
clients=clients,
trainers=trainers,
server=server,
communication_rounds=args.num_rounds,
local_epoch=args.local_epoch,
samp=None,
)

elif args.model == "FedProx":
elif algorithm == "FedProx":
output = run_GC_fedprox(
clients=clients,
trainers=trainers,
server=server,
communication_rounds=args.num_rounds,
local_epoch=args.local_epoch,
mu=args.mu,
samp=None,
)

elif args.model == "GCFL":
elif algorithm == "GCFL":
output = run_GC_gcfl(
clients=clients,
trainers=trainers,
server=server,
communication_rounds=args.num_rounds,
local_epoch=args.local_epoch,
EPS_1=args.epsilon1,
EPS_2=args.epsilon2,
)

elif args.model == "GCFL+":
elif algorithm == "GCFL+":
output = run_GC_gcfl_plus(
clients=clients,
trainers=trainers,
server=server,
communication_rounds=args.num_rounds,
local_epoch=args.local_epoch,
Expand All @@ -200,9 +220,9 @@
standardize=args.standardize,
)

elif args.model == "GCFL+dWs":
output = run_GC_gcfl_plus(
clients=clients,
elif algorithm == "GCFL+dWs":
output = run_GC_gcfl_plus_dWs(
trainers=trainers,
server=server,
communication_rounds=args.num_rounds,
local_epoch=args.local_epoch,
Expand All @@ -213,14 +233,13 @@
)

else:
raise ValueError(f"Unknown model: {args.model}")
raise ValueError(f"Unknown algorithm: {algorithm}")

#######################################################################
# Save the output
# ------------
# Here we save the results to a file, and the output directory can be specified by the user.
# If save_files == False, the output will not be saved and will only be printed in the console.

if save_files:
outdir_result = os.path.join(outdir, f"accuracy_seed{args.seed}.csv")
pd.DataFrame(output).to_csv(outdir_result)
Expand Down
8 changes: 4 additions & 4 deletions docs/examples/intro.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,10 @@

ray.init()

from fedgraph.data_process import load_data
from fedgraph.server_class import Server
from fedgraph.trainer_class import Trainer_General
from fedgraph.utils import (
from src.data_process import load_data
from src.server_class import Server
from src.trainer_class import Trainer_General
from src.utils import (
get_1hop_feature_sum,
get_in_comm_indexes,
label_dirichlet_partition,
Expand Down
4 changes: 2 additions & 2 deletions docs/examples/simple_code_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,8 @@
import attridict
import yaml

from fedgraph.federated_methods import FedGCN_Train
from fedgraph.utils import federated_data_loader
from src.federated_methods import FedGCN_Train
from src.utils import federated_data_loader

#######################################################################
# Load configuration and federated data
Expand Down
Loading