Skip to content

Commit b4c783b

Browse files
authored
Merge pull request #18 from FETS-AI/train_loop_stops
Train loop stops
2 parents 6e8694d + da26b46 commit b4c783b

File tree

3 files changed

+105
-25
lines changed

3 files changed

+105
-25
lines changed

Task_1/FeTS_Challenge.ipynb

Lines changed: 56 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
},
1818
{
1919
"cell_type": "code",
20-
"execution_count": 1,
20+
"execution_count": null,
2121
"metadata": {},
2222
"outputs": [],
2323
"source": [
@@ -27,7 +27,7 @@
2727
},
2828
{
2929
"cell_type": "code",
30-
"execution_count": 2,
30+
"execution_count": null,
3131
"metadata": {},
3232
"outputs": [],
3333
"source": [
@@ -59,7 +59,7 @@
5959
},
6060
{
6161
"cell_type": "code",
62-
"execution_count": 3,
62+
"execution_count": null,
6363
"metadata": {},
6464
"outputs": [],
6565
"source": [
@@ -106,7 +106,7 @@
106106
},
107107
{
108108
"cell_type": "code",
109-
"execution_count": 4,
109+
"execution_count": null,
110110
"metadata": {},
111111
"outputs": [],
112112
"source": [
@@ -256,7 +256,7 @@
256256
},
257257
{
258258
"cell_type": "code",
259-
"execution_count": 5,
259+
"execution_count": null,
260260
"metadata": {},
261261
"outputs": [],
262262
"source": [
@@ -345,7 +345,7 @@
345345
},
346346
{
347347
"cell_type": "code",
348-
"execution_count": 6,
348+
"execution_count": null,
349349
"metadata": {},
350350
"outputs": [],
351351
"source": [
@@ -489,7 +489,7 @@
489489
},
490490
{
491491
"cell_type": "code",
492-
"execution_count": 7,
492+
"execution_count": null,
493493
"metadata": {},
494494
"outputs": [],
495495
"source": [
@@ -606,46 +606,81 @@
606606
"- ```device``` : Which device to use for training and validation"
607607
]
608608
},
609+
{
610+
"cell_type": "markdown",
611+
"metadata": {},
612+
"source": [
613+
"## Setting up the experiment\n",
614+
"Now that we've defined our custom functions, the last thing to do is to configure the experiment. The following cell shows the various settings you can change in your experiment.\n",
615+
"\n",
616+
"Note that ```rounds_to_train``` can be set as high as you want. However, the experiment will exit once the simulated time value exceeds 1 week of simulated time, or if the specified number of rounds has completed."
617+
]
618+
},
609619
{
610620
"cell_type": "code",
611621
"execution_count": null,
612622
"metadata": {},
613623
"outputs": [],
614624
"source": [
625+
"# change any of these you wish to your custom functions. You may leave defaults if you wish.\n",
615626
"aggregation_function = weighted_average_aggregation\n",
616627
"choose_training_collaborators = all_collaborators_train\n",
617628
"training_hyper_parameters_for_round = constant_hyper_parameters\n",
618629
"validation_functions = [('sensitivity', sensitivity), ('specificity', specificity)]\n",
630+
"\n",
631+
"# Final scoring will be on partitioning_1, partitioning_2, and a hidden partitioning\n",
632+
"# We encourage you to experiment with other partitionings\n",
619633
"institution_split_csv_filename = 'partitioning_1.csv'\n",
634+
"\n",
635+
"# change this to point to the parent directory of the data\n",
620636
"brats_training_data_parent_dir = '/raid/datasets/FeTS21/MICCAI_FeTS2021_TrainingData'\n",
637+
"\n",
638+
"# increase this if you need a longer history for your algorithms\n",
639+
"# decrease this if you need to reduce system RAM consumption\n",
621640
"db_store_rounds = 5\n",
622-
"rounds_to_train = 5\n",
641+
"\n",
642+
"# this is passed to PyTorch, so set it accordingly for your system\n",
623643
"device = 'cuda'\n",
624644
"\n",
625-
"run_challenge_experiment(aggregation_function=aggregation_function,\n",
626-
" choose_training_collaborators=choose_training_collaborators,\n",
627-
" training_hyper_parameters_for_round=training_hyper_parameters_for_round,\n",
628-
" validation_functions=validation_functions,\n",
629-
" institution_split_csv_filename=institution_split_csv_filename,\n",
630-
" brats_training_data_parent_dir=brats_training_data_parent_dir,\n",
631-
" db_store_rounds=db_store_rounds,\n",
632-
" rounds_to_train=rounds_to_train,\n",
633-
" device=device)"
645+
"# you'll want to increase this most likely. You can set it as high as you like, \n",
646+
"# however, the experiment will exit once the simulated time exceeds one week. \n",
647+
"rounds_to_train = 5"
648+
]
649+
},
650+
{
651+
"cell_type": "code",
652+
"execution_count": null,
653+
"metadata": {},
654+
"outputs": [],
655+
"source": [
656+
"# the scores are returned in a Pandas dataframe\n",
657+
"scores_dataframe = run_challenge_experiment(\n",
658+
" aggregation_function=aggregation_function,\n",
659+
" choose_training_collaborators=choose_training_collaborators,\n",
660+
" training_hyper_parameters_for_round=training_hyper_parameters_for_round,\n",
661+
" validation_functions=validation_functions,\n",
662+
" institution_split_csv_filename=institution_split_csv_filename,\n",
663+
" brats_training_data_parent_dir=brats_training_data_parent_dir,\n",
664+
" db_store_rounds=db_store_rounds,\n",
665+
" rounds_to_train=rounds_to_train,\n",
666+
" device=device)"
634667
]
635668
},
636669
{
637670
"cell_type": "code",
638671
"execution_count": null,
639672
"metadata": {},
640673
"outputs": [],
641-
"source": []
674+
"source": [
675+
"scores_dataframe"
676+
]
642677
}
643678
],
644679
"metadata": {
645680
"kernelspec": {
646-
"display_name": "fets_challenge_test_2",
681+
"display_name": "openfl",
647682
"language": "python",
648-
"name": "fets_challenge_test_2"
683+
"name": "openfl"
649684
},
650685
"language_info": {
651686
"codemirror_mode": {
@@ -657,7 +692,7 @@
657692
"name": "python",
658693
"nbconvert_exporter": "python",
659694
"pygments_lexer": "ipython3",
660-
"version": "3.6.12"
695+
"version": "3.6.13"
661696
}
662697
},
663698
"nbformat": 4,

Task_1/README.md

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,9 @@ Along with the typical DICE and Hausdorff metrics, we include a "time to converg
3232

3333
The time to convergence metric will be computed as the area under the validation learning curve over 1 week of simulated time where the horizontal axis measures simulated runtime and the vertical axis measures the current best score, computed as the average of enhancing tumor, tumor core, and whole tumor DICE scores over the validation split of the training data.
3434

35+
You can find the code for the "time to convergence metric" in the experiment.py file by searching for ## CONVERGENCE METRIC COMPUTATION.
36+
37+
### How Simulated Time is computed
3538
The simulated time is stochastic, and computed per collaborator, per round, with the round time equaling the greatest round time of all collaborators in the round.
3639

3740
A given collaborator's round time is computed as the sum of:
@@ -57,8 +60,6 @@ We assign these network and compute distributions by drawing uniform-randomly fr
5760

5861
For a given collaborator, these normal distributions are constant throughout the experiment. Again, each possible timing distribution is based on actual timing information from a subset of the hospitals in the FeTS intitiative. You can find these distributions in the experiment.py file (search for ## COLLABORATOR TIMING DISTRIBUTIONS), as well as the random seed used to ensure reproducibility.
5962

60-
You can find the code for the "time to convergence metric" in the experiment.py file by searching for ## CONVERGENCE METRIC COMPUTATION.
61-
6263
## Data Partitioning and Sharding
6364
The FeTS 2021 data release consists of a training set and two CSV files - each providing information for how to partition the training data into non-IID institutional subsets. The release will contain subfolders for single patient records whose names have the format `FeTS21_Training_###`, and two CSV files:
6465
- **partitioning_1.csv**

Task_1/fets_challenge/experiment.py

Lines changed: 46 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
from pathlib import Path
1212

1313
import numpy as np
14+
import pandas as pd
1415
from openfl.utilities import split_tensor_dict_for_holdouts, TensorKey
1516
from openfl.protocols import utils
1617
import openfl.native as fx
@@ -19,6 +20,10 @@
1920
from .custom_aggregation_wrapper import CustomAggregationWrapper
2021

2122
# one week
23+
# MINUTE = 60
24+
# HOUR = 60 * MINUTE
25+
# DAY = 24 * HOUR
26+
# WEEK = 7 * DAY
2227
MAX_SIMULATION_TIME = 7 * 24 * 60 * 60
2328

2429
## COLLABORATOR TIMING DISTRIBUTIONS
@@ -194,6 +199,13 @@ def compute_times_per_collaborator(collaborator_names,
194199
data_size *= epochs_per_round
195200
time += data_size * training_time_per
196201

202+
# if training, we also validate the locally updated model
203+
data_size = data.get_valid_data_size()
204+
validation_time_per = np.random.normal(loc=stats.validation_mean,
205+
scale=stats.validation_std)
206+
validation_time_per = max(1, validation_time_per)
207+
time += data_size * validation_time_per
208+
197209
# upload time
198210
upload_time = np.random.normal(loc=stats.upload_speed_mean,
199211
scale=stats.upload_speed_std)
@@ -295,6 +307,19 @@ def run_challenge_experiment(aggregation_function,
295307
best_dice = -1.0
296308
best_dice_over_time_auc = 0
297309

310+
# results dataframe data
311+
experiment_results = {
312+
'round':[],
313+
'time': [],
314+
'convergence_score': [],
315+
'binary_dice_wt': [],
316+
'binary_dice_et': [],
317+
'binary_dice_tc': [],
318+
'hausdorff95_wt': [],
319+
'hausdorff95_et': [],
320+
'hausdorff95_tc': [],
321+
}
322+
298323
for round_num in range(rounds_to_train):
299324
# pick collaborators to train for the round
300325
training_collaborators = choose_training_collaborators(collaborator_names,
@@ -416,13 +441,32 @@ def run_challenge_experiment(aggregation_function,
416441
# End of round summary
417442
summary = '"**** END OF ROUND {} SUMMARY *****"'.format(round_num)
418443
summary += "\n\tSimulation Time: {} minutes".format(round(total_simulated_time / 60, 2))
419-
summary += "\n\tProjected Convergence Score: {}".format(projected_auc)
444+
summary += "\n\t(Projected) Convergence Score: {}".format(projected_auc)
420445
summary += "\n\tBinary DICE WT: {}".format(binary_dice_wt)
421446
summary += "\n\tBinary DICE ET: {}".format(binary_dice_et)
422447
summary += "\n\tBinary DICE TC: {}".format(binary_dice_tc)
423448
summary += "\n\tHausdorff95 WT: {}".format(hausdorff95_wt)
424449
summary += "\n\tHausdorff95 ET: {}".format(hausdorff95_et)
425450
summary += "\n\tHausdorff95 TC: {}".format(hausdorff95_tc)
426451

452+
experiment_results['round'].append(round_num)
453+
experiment_results['time'].append(total_simulated_time)
454+
experiment_results['convergence_score'].append(projected_auc)
455+
experiment_results['binary_dice_wt'].append(binary_dice_wt)
456+
experiment_results['binary_dice_et'].append(binary_dice_et)
457+
experiment_results['binary_dice_tc'].append(binary_dice_tc)
458+
experiment_results['hausdorff95_wt'].append(hausdorff95_wt)
459+
experiment_results['hausdorff95_et'].append(hausdorff95_et)
460+
experiment_results['hausdorff95_tc'].append(hausdorff95_tc)
461+
427462
logger.info(summary)
428-
463+
464+
# if the total_simulated_time has exceeded the maximum time, we break
465+
# in practice, this means that the previous round's model is the last model scored,
466+
# so a long final round should not actually benefit the competitor, since that final
467+
# model is never globally validated
468+
if total_simulated_time > MAX_SIMULATION_TIME:
469+
logger.info("Simulation time exceeded. Ending Experiment")
470+
break
471+
472+
return pd.DataFrame.from_dict(experiment_results)

0 commit comments

Comments
 (0)