Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 1 addition & 9 deletions Task_1/fets_challenge/checkpoint_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,8 +35,6 @@ def save_checkpoint(checkpoint_folder, aggregator,
"""
# Save aggregator tensor_db
aggregator.tensor_db.tensor_db.to_pickle(f'checkpoint/{checkpoint_folder}/aggregator_tensor_db.pkl')
for col in collaborator_names:
collaborators[col].tensor_db.tensor_db.to_pickle(f'checkpoint/{checkpoint_folder}/{col}_tensor_db.pkl')
with open(f'checkpoint/{checkpoint_folder}/state.pkl', 'wb') as f:
pickle.dump([collaborator_names, round_num, collaborator_time_stats, total_simulated_time,
best_dice, best_dice_over_time_auc, collaborators_chosen_each_round,
Expand All @@ -50,10 +48,4 @@ def load_checkpoint(checkpoint_folder):
with open(f'checkpoint/{checkpoint_folder}/state.pkl', 'rb') as f:
state = pickle.load(f)

# load each collaborator tensor_db
collaborator_names = state[0]
collaborator_tensor_dbs = {}
for col in collaborator_names:
collaborator_tensor_dbs[col] = pd.read_pickle(f'checkpoint/{checkpoint_folder}/{col}_tensor_db.pkl')

return state + [aggregator_tensor_db] + [collaborator_tensor_dbs]
return state + [aggregator_tensor_db]
7 changes: 2 additions & 5 deletions Task_1/fets_challenge/experiment.py
Original file line number Diff line number Diff line change
Expand Up @@ -356,16 +356,13 @@ def run_challenge_experiment(aggregation_function,
[loaded_collaborator_names, starting_round_num, collaborator_time_stats,
total_simulated_time, best_dice, best_dice_over_time_auc,
collaborators_chosen_each_round, collaborator_times_per_round,
experiment_results, summary, agg_tensor_db, col_tensor_dbs] = state
experiment_results, summary, agg_tensor_db] = state

if loaded_collaborator_names != collaborator_names:
logger.error(f'Collaborator names found in checkpoint ({loaded_collaborator_names}) '
f'do not match provided collaborators ({collaborator_names})')
exit(1)

for col in loaded_collaborator_names:
collaborators[col].tensor_db.tensor_db = col_tensor_dbs[col]

logger.info(f'Previous summary for round {starting_round_num}')
logger.info(summary)

Expand Down Expand Up @@ -457,7 +454,7 @@ def run_challenge_experiment(aggregation_function,
# FIXME: this doesn't break up each task. We need this if we're doing straggler handling
for t, col in times_list:
# set the task_runner data loader
task_runner.data_loader = collaborator_data_loaders[col]
task_runner.data = collaborator_data_loaders[col]

# run the collaborator
collaborators[col].run_simulation()
Expand Down