Monitor comm cost #40
Changes from 1 commit
@@ -50,21 +50,26 @@ def run_fedgraph(args: attridict) -> None:
         Input data for the federated learning task. Format depends on the specific task and
         will be explained in more detail below inside specific functions.
     """
+    monitor = Monitor()
+    monitor.init_time_start()

     if args.fedgraph_task != "NC" or not args.use_huggingface:
         data = data_loader(args)
     else:
         # use hugging_face instead of use data_loader
         print("Using hugging_face for local loading")
         data = None
     if args.fedgraph_task == "NC":
-        run_NC(args, data)
+        run_NC(args, data, monitor)
     elif args.fedgraph_task == "GC":
-        run_GC(args, data)
+        run_GC(args, data, monitor)
     elif args.fedgraph_task == "LP":
-        run_LP(args)
+        run_LP(args, monitor)


-def run_NC(args: attridict, data: Any = None) -> None:
+def run_NC(
+    args: attridict, data: Any = None, monitor: Optional[Monitor] = None
+) -> None:
     """
     Train a Federated Graph Classification model using multiple trainers.
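For readers outside the codebase: the hunk above now creates the Monitor once in run_fedgraph and threads it into the task runners. The calls rely on a Monitor with paired start/end timers for the init, pretrain, and train phases. A minimal sketch of that interface, with method names taken from the diff and bodies that are illustrative guesses, not fedgraph's actual implementation:

```python
import time


class Monitor:
    """Sketch of the phase-timer interface the diff calls into.

    Only the method names come from the PR; the bodies below are
    assumptions for illustration.
    """

    def __init__(self) -> None:
        self._starts: dict = {}
        self.durations: dict = {}

    def _start(self, phase: str) -> None:
        self._starts[phase] = time.time()

    def _end(self, phase: str, sleep: int = 0) -> None:
        # Optional wait (e.g., letting cluster metrics flush) before stamping.
        if sleep:
            time.sleep(sleep)
        self.durations[phase] = time.time() - self._starts[phase]

    def init_time_start(self) -> None:
        self._start("init")

    def init_time_end(self, sleep: int = 0) -> None:
        self._end("init", sleep)

    def pretrain_time_start(self) -> None:
        self._start("pretrain")

    def pretrain_time_end(self, sleep: int = 0) -> None:
        self._end("pretrain", sleep)

    def train_time_start(self) -> None:
        self._start("train")

    def train_time_end(self, sleep: int = 0) -> None:
        self._end("train", sleep)
```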
@@ -80,12 +85,10 @@ def run_NC(args: attridict, data: Any = None) -> None:
         Configuration arguments
     data: tuple
     """
+    if monitor is None:
+        monitor = Monitor()
     ray.init()
     start_time = time.time()
-    if args.use_cluster:
-        monitor = Monitor()
-        # Start tracking initialization time
-        monitor.init_time_start()
     torch.manual_seed(42)
     if args.num_hops == 0:
         print("Changing method to FedAvg")
@@ -219,13 +222,12 @@ def __init__(self, *args: Any, **kwds: Any):
     # without knowing the local trainer data

     server = Server(features.shape[1], args_hidden, class_num, device, trainers, args)
-    if args.use_cluster:
-        # End initialization time tracking
-        monitor.init_time_end()
+    # End initialization time tracking
+    monitor.init_time_end()
     server.broadcast_params(-1)
     pretrain_start = time.time()
-    if args.use_cluster:
-        monitor.pretrain_time_start()
+    monitor.pretrain_time_start()
     if args.method != "FedAvg":
         #######################################################################
         # Pre-Train Communication of FedGCN
@@ -301,9 +303,9 @@ def __init__(self, *args: Any, **kwds: Any):
             server.trainers[i].load_feature_aggregation.remote(trainer_aggregation)
     print("clients received feature aggregation from server")
     [trainer.relabel_adj.remote() for trainer in server.trainers]
-    if args.use_cluster:
-        monitor.pretrain_time_end(30)
-        monitor.train_time_start()
+    monitor.pretrain_time_end(30)
+    monitor.train_time_start()
Contributor: we could only sleep 30s when use_cluster
     #######################################################################
     # Federated Training
     # ------------------
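On the comment above: since pretrain_time_end(30) now runs unconditionally, one way to sleep only on a cluster, without re-adding a guard at every call site, is to gate the wait inside the monitor. A sketch extending the Monitor outline earlier; the use_cluster constructor flag is an assumption, not fedgraph's API:

```python
class ClusterAwareMonitor(Monitor):
    """Hypothetical variant that pays the flush wait only on a cluster."""

    def __init__(self, use_cluster: bool = False) -> None:
        super().__init__()
        self.use_cluster = use_cluster

    def _end(self, phase: str, sleep: int = 0) -> None:
        # Call sites keep writing pretrain_time_end(30) / train_time_end(30);
        # local runs simply skip the 30s wait.
        super()._end(phase, sleep if self.use_cluster else 0)
```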
@@ -313,8 +315,7 @@ def __init__(self, *args: Any, **kwds: Any):
     print("global_rounds", args.global_rounds)
     for i in range(args.global_rounds):
         server.train(i)
-    if args.use_cluster:
-        monitor.train_time_end(30)
+    monitor.train_time_end(30)
     training_time = time.time() - training_start
     if args.use_encryption:
         if hasattr(server, "aggregation_stats") and server.aggregation_stats:
@@ -376,7 +377,7 @@ def __init__(self, *args: Any, **kwds: Any):
     ray.shutdown()


-def run_GC(args: attridict, data: Any, base_model: Any = GIN) -> None:
+def run_GC(args: attridict, data: Any, monitor: Optional[Monitor] = None) -> None:
     """
     Entrance of the training process for graph classification.
@@ -416,7 +417,7 @@ def run_GC(args: attridict, data: Any, base_model: Any = GIN) -> None:
         args.device = torch.device("cpu")
         num_gpus_per_trainer = 0

-    if args.use_cluster:
-        monitor = Monitor()
+    if args.use_cluster and monitor is not None:
         # Start tracking initialization time
         monitor.init_time_start()
@@ -511,7 +512,7 @@ def __init__(self, idx, splited_data, dataset_trainer_name, cmodel_gc, args): #
     # TODO: check and modify whether deepcopy should be added.
    # trainers = copy.deepcopy(init_trainers)
    # server = copy.deepcopy(init_server)
-    if args.use_cluster:
+    if args.use_cluster and monitor is not None:
         # End initialization time tracking after server setup is complete
         monitor.init_time_end()
     print("\nDone setting up devices.")
@@ -521,14 +522,18 @@ def __init__(self, idx, splited_data, dataset_trainer_name, cmodel_gc, args): #

     model_parameters = {
         "SelfTrain": lambda: run_GC_selftrain(
-            trainers=trainers, server=server, local_epoch=args.local_epoch
+            trainers=trainers,
+            server=server,
+            local_epoch=args.local_epoch,
+            monitor=monitor,
         ),
         "FedAvg": lambda: run_GC_Fed_algorithm(
             trainers=trainers,
             server=server,
             communication_rounds=args.num_rounds,
             local_epoch=args.local_epoch,
             algorithm="FedAvg",
+            monitor=monitor,
         ),
         "FedProx": lambda: run_GC_Fed_algorithm(
             trainers=trainers,
@@ -546,6 +551,7 @@ def __init__(self, idx, splited_data, dataset_trainer_name, cmodel_gc, args): #
             EPS_1=args.epsilon1,
             EPS_2=args.epsilon2,
             algorithm_type="gcfl",
+            monitor=monitor,
         ),
         "GCFL+": lambda: run_GCFL_algorithm(
             trainers=trainers,
@@ -585,7 +591,9 @@ def __init__(self, idx, splited_data, dataset_trainer_name, cmodel_gc, args): #


 # The following code is the implementation of different federated graph classification methods.
-def run_GC_selftrain(trainers: list, server: Any, local_epoch: int) -> dict:
+def run_GC_selftrain(
+    trainers: list, server: Any, local_epoch: int, monitor: Optional[Monitor] = None
+) -> dict:
     """
     Run the training and testing process of self-training algorithm.
     It only trains the model locally, and does not perform weights aggregation.
@@ -607,17 +615,16 @@ def run_GC_selftrain(trainers: list, server: Any, local_epoch: int) -> dict:
     """

     # all trainers are initialized with the same weights
-    if server.use_cluster:
-        monitor = Monitor()
+    if monitor is not None:
         monitor.pretrain_time_start()
     global_params_id = ray.put(server.W)
     for trainer in trainers:
         trainer.update_params.remote(global_params_id)
-    if server.use_cluster:
+    if monitor is not None:
         monitor.pretrain_time_end(30)
     all_accs = {}
     acc_refs = []
-    if server.use_cluster:
+    if monitor is not None:
         monitor.train_time_start()
     for trainer in trainers:
         trainer.local_train.remote(local_epoch=local_epoch)
@@ -638,7 +645,7 @@ def run_GC_selftrain(trainers: list, server: Any, local_epoch: int) -> dict:
         acc_refs = left
         if not acc_refs:
             break
-    if server.use_cluster:
+    if monitor is not None:
         monitor.train_time_end(30)
     frame = pd.DataFrame(all_accs).T.iloc[:, [2]]
     frame.columns = ["test_acc"]
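For context on the Ray pattern these hunks keep touching: the server puts one copy of the weights into the object store, every trainer actor receives that shared reference, and accuracy results are drained with ray.wait, much like the acc_refs loop above. A self-contained toy version; the actor and its methods are placeholders, not fedgraph's trainer API:

```python
import ray

ray.init(ignore_reinit_error=True)


@ray.remote
class ToyTrainer:
    """Placeholder actor standing in for a fedgraph trainer."""

    def update_params(self, params: dict) -> None:
        self.params = params

    def local_train(self, local_epoch: int) -> float:
        # Fake training: return a made-up accuracy.
        return 0.5 + 0.01 * local_epoch


trainers = [ToyTrainer.remote() for _ in range(3)]

# One shared copy in the object store; every actor reads the same blob.
params_id = ray.put({"w": [0.0, 1.0]})
for t in trainers:
    t.update_params.remote(params_id)

# Drain results as they finish, mirroring the acc_refs loop in the diff.
acc_refs = [t.local_train.remote(local_epoch=2) for t in trainers]
while acc_refs:
    done, acc_refs = ray.wait(acc_refs, num_returns=1)
    print("accuracy:", ray.get(done[0]))

ray.shutdown()
```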
@@ -656,6 +663,7 @@ def run_GC_Fed_algorithm(
     algorithm: str,
     mu: float = 0.0,
     sampling_frac: float = 1.0,
+    monitor: Optional[Monitor] = None,
 ) -> pd.DataFrame:
     """
     Run the training and testing process of FedAvg or FedProx algorithm.
@@ -684,14 +692,14 @@ def run_GC_Fed_algorithm(
     frame: pd.DataFrame
         Pandas dataframe with test accuracies
     """
-    if server.use_cluster:
-        monitor = Monitor()
+    if monitor is not None:
         monitor.pretrain_time_start()
     global_params_id = ray.put(server.W)
     for trainer in trainers:
         trainer.update_params.remote(global_params_id)
-    if server.use_cluster:
+    if monitor is not None:
         monitor.pretrain_time_end(30)
+    if monitor is not None:
         monitor.train_time_start()
     for c_round in range(1, communication_rounds + 1):
         if (c_round) % 10 == 0:
@@ -740,7 +748,7 @@ def highlight_max(s: pd.Series) -> list:
         is_max = s == s.max()
         return ["background-color: yellow" if v else "" for v in is_max]

-    if server.use_cluster:
+    if monitor is not None:
         monitor.train_time_end(30)
     fs = frame.style.apply(highlight_max).data
Contributor: init inside the function so it won't be None, too many ifs
     print(fs)
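The suggestion in the comment above, defaulting the monitor at the top of each entry point so the body never branches on None, is what the author later reports adopting. A sketch of that shape, reusing the Monitor outline from earlier (body trimmed; not the merged code):

```python
from typing import Any, Optional


def run_GC_selftrain(
    trainers: list, server: Any, local_epoch: int, monitor: Optional[Monitor] = None
) -> dict:
    # Default once, up front: every later call site can use the monitor
    # unconditionally instead of repeating `if monitor is not None:`.
    if monitor is None:
        monitor = Monitor()
    monitor.pretrain_time_start()
    ...  # rest of the self-training logic unchanged
```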
@@ -758,6 +766,7 @@ def run_GCFL_algorithm(
     algorithm_type: str,
     seq_length: int = 0,
     standardize: bool = True,
+    monitor: Optional[Monitor] = None,
 ) -> pd.DataFrame:
     """
     Run the specified GCFL algorithm.
@@ -792,8 +801,7 @@ def run_GCFL_algorithm(
         raise ValueError(
             "Invalid algorithm_type. Must be 'gcfl', 'gcfl_plus', or 'gcfl_plus_dWs'."
         )
-    if server.use_cluster:
-        monitor = Monitor()
+    if monitor is not None:
         monitor.pretrain_time_start()
     cluster_indices = [np.arange(len(trainers)).astype("int")]
     trainer_clusters = [[trainers[i] for i in idcs] for idcs in cluster_indices]
@@ -806,10 +814,10 @@ def run_GCFL_algorithm(

     for trainer in trainers:
         trainer.update_params.remote(global_params_id)
-    if server.use_cluster:
+    if monitor is not None:
         monitor.pretrain_time_end(30)
     acc_trainers: List[Any] = []
-    if server.use_cluster:
+    if monitor is not None:
         monitor.train_time_start()
     for c_round in range(1, communication_rounds + 1):
         if (c_round) % 10 == 0:
@@ -897,7 +905,7 @@ def run_GCFL_algorithm(
         server.cache_model(
             idc, ray.get(trainers[idc[0]].get_total_weight.remote()), acc_trainers
         )
-    if server.use_cluster:
+    if monitor is not None:
         monitor.train_time_end(30)
     results = np.zeros([len(trainers), len(server.model_cache)])
     for i, (idcs, W, accs) in enumerate(server.model_cache):
@@ -920,7 +928,7 @@ def run_GCFL_algorithm(
     return frame


-def run_LP(args: attridict) -> None:
+def run_LP(args: attridict, monitor: Monitor) -> None:
     """
     Implements various federated learning methods for link prediction tasks with support
     for online learning and buffer mechanisms. Handles temporal aspects of link prediction
@@ -1023,10 +1031,9 @@ def __init__(self, *args, **kwargs): # type: ignore

     current_dir = os.path.dirname(os.path.abspath(__file__))
     ray.init()
-    if args.use_cluster:
-        # Initialize monitor and start tracking initialization time
-        monitor = Monitor()
+    if args.use_cluster and monitor is not None:
         monitor.init_time_start()

     # Append paths relative to the current script's directory
Contributor: same issue, still uses args.use_cluster
     sys.path.append(os.path.join(current_dir, "../fedgraph"))
     sys.path.append(os.path.join(current_dir, "../../"))
@@ -1063,12 +1070,12 @@ def __init__(self, *args, **kwargs): # type: ignore
         meta_data=meta_data,
         hidden_channels=hidden_channels,
     )
-    if args.use_cluster:
+    if args.use_cluster and monitor is not None:
         # End initialization time tracking
         monitor.init_time_end()

     """Broadcast the global model parameter to all clients"""
-    if args.use_cluster:
+    if args.use_cluster and monitor is not None:
         monitor.pretrain_time_start()
     global_model_parameter = (
         server.get_model_parameter()
@@ -1097,7 +1104,7 @@ def __init__(self, *args, **kwargs): # type: ignore
     else:
         result_writer = None
         time_writer = None
-    if args.use_cluster:
+    if args.use_cluster and monitor is not None:
         monitor.pretrain_time_end(30)
         monitor.train_time_start()
     # from 2012-04-03 to 2012-04-13
@@ -1147,7 +1154,7 @@ def __init__(self, *args, **kwargs): # type: ignore
         start_time_float_format,
         end_time_float_format,
     ) = to_next_day(start_time=start_time, end_time=end_time, method=method)
-    if args.use_cluster:
+    if args.use_cluster and monitor is not None:
         monitor.train_time_end(30)
     if result_writer is not None and time_writer is not None:
         result_writer.close()
Contributor: remove this, always init inside of the function

Reply: Yeah, that makes sense, thanks for the advice. Already fixed it.
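With the monitor created inside run_fedgraph, callers never touch it. A minimal invocation sketch; the module path and the config keys are inferred from this diff and from the attridict usage in it, so treat both as assumptions:

```python
import attridict

from fedgraph.federated_methods import run_fedgraph  # assumed module path

# Keys below appear in the diff (fedgraph_task, use_huggingface, use_cluster,
# num_hops, global_rounds, use_encryption); a real config carries more fields.
config = {
    "fedgraph_task": "NC",
    "use_huggingface": False,
    "use_cluster": False,
    "num_hops": 0,
    "global_rounds": 10,
    "use_encryption": False,
}

run_fedgraph(attridict(config))
```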