Dataloading utils #85
Draft
daler3 wants to merge 22 commits into OpenMined:master from daler3:dataloading-utils
Commits (22)
327e663 Started exploration notebook with synthea data
d92667a Feature loading
fd333ca Added model and training loop
30a2a38 Added confusion matrix
5154a38 started experimenting with dualhead
c782ddc Added verticalfederateddataset class
465a91c Added dataset parameter to split_data
d4615e3 Reformatted and added create vertical method
ca1ff75 added comments and TODOs to the dataset file
693a334 first version of verticalFederatedDataLoader
f566e37 Added comments for TODOs
e382c6d corrected index of last tensor in split_data
74bafec compacted sum of the len in verticalfeddataloader
bb8cae7 Uploaded skeleton for dualheaded loaders
c6eba1f Initial commit
593a1ff dataloaders and datasets for dualheaded
a3ad412 Updated utility functions
901dff8 updated datasets
8bdc9f3 updated to use custom dataloaders, updated example notebook
54a4f16 added enhanced worker class (wip)
f21d8af changed workers with worker's ids in dictionaries; added models' segm…
96ed8e9 Addressed Tom's comments
@@ -0,0 +1,263 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Testing split functions and dataloading\n",
"\n",
"In this section, we test the split functions (utils), custom dataset classes and dataloading (with the standard PyTorch DataLoader). "
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"from __future__ import print_function\n",
"import syft as sy\n",
"import torch\n",
"from torch.utils.data import Dataset\n",
"from torch.utils.data import DataLoader\n",
"from torch.utils.data._utils.collate import default_collate\n",
"from typing import List, Tuple\n",
"from uuid import UUID\n",
"from uuid import uuid4\n",
"from torch.utils.data import SequentialSampler, RandomSampler, BatchSampler\n",
"\n",
"from abc import ABC, abstractmethod\n",
"from torchvision import datasets, transforms\n",
"\n",
"import utils\n",
"import dataloaders\n",
"\n",
"hook = sy.TorchHook(torch)\n",
"\n",
"transform = transforms.Compose([transforms.ToTensor(),\n",
" transforms.Normalize((0.5,), (0.5,)),\n",
" ])\n",
"trainset = datasets.MNIST('mnist', download=True, train=True, transform=transform)\n",
"#trainloader = torch.utils.data.DataLoader(trainset, batch_size=64, shuffle=True)\n",
"\n",
"# create some workers\n",
"client_1 = sy.VirtualWorker(hook, id=\"client_1\")\n",
"client_2 = sy.VirtualWorker(hook, id=\"client_2\")\n",
"\n",
"server = sy.VirtualWorker(hook, id=\"server\")\n",
"\n",
"data_owners = [client_1, client_2]"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"#get a VerticalFederatedDataset\n",
"vfd = utils.split_data_create_vertical_dataset(trainset, data_owners)"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"loader = DataLoader(vfd, batch_size=4, shuffle=True)"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"{'client_1': [tensor([[[[-1., -1., -1., ..., -1., -1., -1.],\n",
" [-1., -1., -1., ..., -1., -1., -1.],\n",
" [-1., -1., -1., ..., -1., -1., -1.],\n",
" ...,\n",
" [-1., -1., -1., ..., -1., -1., -1.],\n",
" [-1., -1., -1., ..., -1., -1., -1.],\n",
" [-1., -1., -1., ..., -1., -1., -1.]]],\n",
"\n",
"\n",
" [[[-1., -1., -1., ..., -1., -1., -1.],\n",
" [-1., -1., -1., ..., -1., -1., -1.],\n",
" [-1., -1., -1., ..., -1., -1., -1.],\n",
" ...,\n",
" [-1., -1., -1., ..., -1., -1., -1.],\n",
" [-1., -1., -1., ..., -1., -1., -1.],\n",
" [-1., -1., -1., ..., -1., -1., -1.]]],\n",
"\n",
"\n",
" [[[-1., -1., -1., ..., -1., -1., -1.],\n",
" [-1., -1., -1., ..., -1., -1., -1.],\n",
" [-1., -1., -1., ..., -1., -1., -1.],\n",
" ...,\n",
" [-1., -1., -1., ..., -1., -1., -1.],\n",
" [-1., -1., -1., ..., -1., -1., -1.],\n",
" [-1., -1., -1., ..., -1., -1., -1.]]],\n",
"\n",
"\n",
" [[[-1., -1., -1., ..., -1., -1., -1.],\n",
" [-1., -1., -1., ..., -1., -1., -1.],\n",
" [-1., -1., -1., ..., -1., -1., -1.],\n",
" ...,\n",
" [-1., -1., -1., ..., -1., -1., -1.],\n",
" [-1., -1., -1., ..., -1., -1., -1.],\n",
" [-1., -1., -1., ..., -1., -1., -1.]]]]), tensor([9, 8, 2, 1]), tensor([51574., 39668., 24844., 32204.], dtype=torch.float64)], 'client_2': [tensor([[[[-1.0000, -1.0000, -1.0000, ..., -1.0000, -1.0000, -1.0000],\n",
" [-1.0000, -1.0000, -1.0000, ..., -1.0000, -1.0000, -1.0000],\n",
" [-1.0000, -1.0000, -1.0000, ..., -1.0000, -1.0000, -1.0000],\n",
" ...,\n",
" [ 0.1686, 0.9922, 0.5137, ..., -1.0000, -1.0000, -1.0000],\n",
" [ 0.1686, 0.9922, 0.3412, ..., -1.0000, -1.0000, -1.0000],\n",
" [-1.0000, -1.0000, -1.0000, ..., -1.0000, -1.0000, -1.0000]]],\n",
"\n",
"\n",
" [[[-1.0000, -1.0000, -1.0000, ..., -1.0000, -1.0000, -1.0000],\n",
" [-1.0000, -1.0000, -1.0000, ..., -1.0000, -1.0000, -1.0000],\n",
" [-1.0000, -1.0000, -1.0000, ..., -1.0000, -1.0000, -1.0000],\n",
" ...,\n",
" [-1.0000, -1.0000, -1.0000, ..., -1.0000, -1.0000, -1.0000],\n",
" [-1.0000, -1.0000, -1.0000, ..., -1.0000, -1.0000, -1.0000],\n",
" [-1.0000, -1.0000, -1.0000, ..., -1.0000, -1.0000, -1.0000]]],\n",
"\n",
"\n",
" [[[-1.0000, -1.0000, -1.0000, ..., -1.0000, -1.0000, -1.0000],\n",
" [-1.0000, -1.0000, -1.0000, ..., -1.0000, -1.0000, -1.0000],\n",
" [-1.0000, -1.0000, -1.0000, ..., -1.0000, -1.0000, -1.0000],\n",
" ...,\n",
" [-1.0000, -1.0000, -1.0000, ..., -1.0000, -1.0000, -1.0000],\n",
" [-1.0000, -1.0000, -1.0000, ..., -1.0000, -1.0000, -1.0000],\n",
" [-1.0000, -1.0000, -1.0000, ..., -1.0000, -1.0000, -1.0000]]],\n",
"\n",
"\n",
" [[[-1.0000, -1.0000, -1.0000, ..., -1.0000, -1.0000, -1.0000],\n",
" [-1.0000, -1.0000, -1.0000, ..., -1.0000, -1.0000, -1.0000],\n",
" [-1.0000, -1.0000, -1.0000, ..., -1.0000, -1.0000, -1.0000],\n",
" ...,\n",
" [-1.0000, -1.0000, -1.0000, ..., -1.0000, -1.0000, -1.0000],\n",
" [-1.0000, -1.0000, -1.0000, ..., -1.0000, -1.0000, -1.0000],\n",
" [-1.0000, -1.0000, -1.0000, ..., -1.0000, -1.0000, -1.0000]]]]), tensor([9, 8, 2, 1]), tensor([51574., 39668., 24844., 32204.], dtype=torch.float64)]}\n"
]
}
],
"source": [
"for el in loader: \n",
" print(el)\n",
" break"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
"#as in https://github.com/abbas5253/SplitNN-for-Vertically-Partitioned-Data/blob/master/Configuration_1.ipynb\n",
"from torch import nn, optim\n",
"\n",
"model_locations = [client_1, client_2, server]\n",
"\n",
"input_size= [28*14, 28*14]\n",
"hidden_sizes= {\"client_1\": [32, 64], \"client_2\":[32, 64], \"server\":[128, 64]}\n",
"\n",
"#create model segment for each worker\n",
"models = {\n",
" \"client_1\": nn.Sequential(\n",
" nn.Linear(input_size[0], hidden_sizes[\"client_1\"][0]),\n",
" nn.ReLU(),\n",
" nn.Linear(hidden_sizes[\"client_1\"][0], hidden_sizes[\"client_1\"][1]),\n",
" nn.ReLU(),\n",
" ),\n",
" \"client_2\": nn.Sequential(\n",
" nn.Linear(input_size[1], hidden_sizes[\"client_2\"][0]),\n",
" nn.ReLU(),\n",
" nn.Linear(hidden_sizes[\"client_2\"][0], hidden_sizes[\"client_2\"][1]),\n",
" nn.ReLU(),\n",
" ),\n",
" \"server\": nn.Sequential(\n",
" nn.Linear(hidden_sizes[\"server\"][0], hidden_sizes[\"server\"][1]),\n",
" nn.ReLU(),\n",
" nn.Linear(hidden_sizes[\"server\"][1], 10),\n",
" nn.LogSoftmax(dim=1)\n",
" )\n",
"}\n",
"\n",
"\n",
"\n",
"# Create optimisers for each segment and link to their segment\n",
"optimizers = [\n",
" optim.SGD(models[location.id].parameters(), lr=0.05,)\n",
" for location in model_locations\n",
"]\n",
"\n",
"\n",
"#send model segments to each client and the server\n",
"for location in model_locations:\n",
" models[location.id].send(location)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.6"
}
},
"nbformat": 4,
"nbformat_minor": 4
}
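The notebook stops after the model segments are sent to their workers. As a rough sketch of how one training step over these segments could look, here is a purely local illustration in plain PyTorch (no PySyft pointers or remote execution; the function name, the per-owner batch dict, and the assumption that each owner holds a 28x14 half of every image are inferred from the layer sizes above and are not part of the PR):

import torch
from torch import nn


def local_split_train_step(models, optimizers, batch, labels):
    """One illustrative split-learning step, run entirely locally.

    `batch` is assumed to be a dict mapping "client_1"/"client_2" to
    tensors of shape (batch_size, 1, 28, 14): each owner holds half of
    every MNIST image, matching input_size = [28*14, 28*14] above.
    """
    for opt in optimizers:
        opt.zero_grad()

    # each data owner runs its own segment on its feature partition
    activations = [
        models[owner](batch[owner].flatten(start_dim=1))
        for owner in ("client_1", "client_2")
    ]

    # the server concatenates the two 64-unit client outputs (128 units total)
    # and finishes the forward pass with its own head
    output = models["server"](torch.cat(activations, dim=1))

    # LogSoftmax output pairs with the negative log-likelihood loss
    loss = nn.NLLLoss()(output, labels)
    loss.backward()

    for opt in optimizers:
        opt.step()
    return loss.item()

With the segments actually sent to VirtualWorkers as in the notebook, the intermediate activations would have to be moved between workers rather than concatenated locally.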
@@ -0,0 +1,70 @@
from __future__ import print_function
import syft as sy
import torch
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
from torch.utils.data._utils.collate import default_collate
from typing import List, Tuple
from uuid import UUID
from uuid import uuid4
from torch.utils.data import SequentialSampler, RandomSampler, BatchSampler

import datasets


"""I think this is not needed anymore"""


class SinglePartitionDataLoader(DataLoader):
    """DataLoader for a single vertically-partitioned dataset"""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

        #self.collate_fn = id_collate_fn


class VerticalFederatedDataLoader(DataLoader):
    """Dataloader which batches data from a complete
    set of vertically-partitioned datasets


    DataLoader(dataset, batch_size=1, shuffle=False, sampler=None,
               batch_sampler=None, num_workers=0, collate_fn=None,
               pin_memory=False, drop_last=False, timeout=0,
               worker_init_fn=None, *, prefetch_factor=2,
               persistent_workers=False)
    """

    def __init__(self, dataset, batch_size=1, shuffle=False, sampler=None,
                 batch_sampler=None, num_workers=0, collate_fn=None,
                 pin_memory=False, drop_last=False, timeout=0,
                 worker_init_fn=None, *, prefetch_factor=2,
                 persistent_workers=False):

        self.dataset = dataset
        self.batch_size = batch_size
        self.shuffle = shuffle
        self.num_workers = num_workers

        self.workers = dataset.workers

        # one batch sampler per worker, each built over the same index range
        self.batch_samplers = {}
        for worker in self.workers:
            data_range = range(len(self.dataset))
            if shuffle:
                sampler = RandomSampler(data_range)
            else:
                sampler = SequentialSampler(data_range)
            batch_sampler = BatchSampler(sampler, self.batch_size, drop_last)
            self.batch_samplers[worker] = batch_sampler

        # wrap each partition in its own single-partition loader
        single_loaders = []
        for k in self.dataset.datasets.keys():
            single_loaders.append(
                SinglePartitionDataLoader(
                    self.dataset.datasets[k],
                    batch_sampler=self.batch_samplers[k],
                )
            )

        self.single_loaders = single_loaders

    def __len__(self):
        return sum(len(x) for x in self.dataset.datasets.values()) // len(self.workers)
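The class above builds one BatchSampler per worker but leaves iteration to the wrapped SinglePartitionDataLoaders; this diff does not define __iter__. A small self-contained sketch of the underlying idea (the toy dataset and all names here are illustrative, not part of the PR): when every partition is driven by a sampler over the same index range, zipping the per-partition loaders yields row-aligned batches.

import torch
from torch.utils.data import BatchSampler, DataLoader, Dataset, SequentialSampler


class ToyPartition(Dataset):
    """Stand-in for one owner's share of a vertically partitioned dataset."""

    def __init__(self, features):
        self.features = features

    def __len__(self):
        return len(self.features)

    def __getitem__(self, idx):
        # return the index alongside the value so alignment is visible
        return self.features[idx], idx


partitions = {
    "client_1": ToyPartition(torch.arange(10.0)),        # e.g. left half of each row
    "client_2": ToyPartition(torch.arange(10.0) + 100),  # e.g. right half of each row
}

# one BatchSampler per partition over the same index range, mirroring
# VerticalFederatedDataLoader.__init__ (sequential here, so rows stay aligned)
samplers = {
    name: BatchSampler(SequentialSampler(range(len(ds))), batch_size=4, drop_last=False)
    for name, ds in partitions.items()
}
loaders = {
    name: DataLoader(ds, batch_sampler=samplers[name])
    for name, ds in partitions.items()
}

# zip pairs up batch i of every partition; the index tensors match batch by batch
for (x1, idx1), (x2, idx2) in zip(loaders["client_1"], loaders["client_2"]):
    assert torch.equal(idx1, idx2)
    print(idx1.tolist(), x1.tolist(), x2.tolist())

Alignment across partitions only holds as long as every sampler produces the same index order, which is why the batch samplers are all built over the same range(len(dataset)).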
Reviewer: Do you mean we don't need the partitioned dataloader?
daler3: No, I mean that the default DataLoader in PyTorch works here, so we do not need a custom one (for how it is done now). See the notebook for an example.
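For context, the behaviour this reply relies on can be reproduced standalone: the stock DataLoader's default_collate already batches dict-shaped samples key by key. A minimal sketch (the toy dataset and key names below are illustrative, not the PR's VerticalFederatedDataset):

import torch
from torch.utils.data import DataLoader, Dataset


class DictSamples(Dataset):
    """Toy dataset whose samples are dicts, one entry per data owner."""

    def __len__(self):
        return 8

    def __getitem__(self, idx):
        return {
            "client_1": torch.full((3,), float(idx)),  # this owner's features
            "client_2": torch.full((2,), float(idx)),  # the other owner's features
            "label": idx % 2,
        }


loader = DataLoader(DictSamples(), batch_size=4, shuffle=False)
batch = next(iter(loader))
# default_collate groups by key: batch["client_1"] has shape (4, 3),
# batch["client_2"] has shape (4, 2), batch["label"] is a length-4 tensor
print({k: v.shape for k, v in batch.items()})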