263 changes: 263 additions & 0 deletions examples/dualheaded_datautils/Example_.ipynb
@@ -0,0 +1,263 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Testing split functions and dataloading\n",
"\n",
"In this section, we test split functions (utils), custom datasets classes and dataloading (with standard pytorch dataloader). "
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"from __future__ import print_function\n",
"import syft as sy\n",
"import torch\n",
"from torch.utils.data import Dataset\n",
"from torch.utils.data import DataLoader\n",
"from torch.utils.data._utils.collate import default_collate\n",
"from typing import List, Tuple\n",
"from uuid import UUID\n",
"from uuid import uuid4\n",
"from torch.utils.data import SequentialSampler, RandomSampler, BatchSampler\n",
"\n",
"from abc import ABC, abstractmethod\n",
"from torchvision import datasets, transforms\n",
"\n",
"import utils\n",
"import dataloaders\n",
"\n",
"hook = sy.TorchHook(torch)\n",
"\n",
"transform = transforms.Compose([transforms.ToTensor(),\n",
" transforms.Normalize((0.5,), (0.5,)),\n",
" ])\n",
"trainset = datasets.MNIST('mnist', download=True, train=True, transform=transform)\n",
"#trainloader = torch.utils.data.DataLoader(trainset, batch_size=64, shuffle=True)\n",
"\n",
"# create some workers\n",
"client_1 = sy.VirtualWorker(hook, id=\"client_1\")\n",
"client_2 = sy.VirtualWorker(hook, id=\"client_2\")\n",
"\n",
"server = sy.VirtualWorker(hook, id= \"server\") \n",
"\n",
"data_owners = [client_1, client_2]"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"#get a verticalFederatedDatase\n",
"vfd = utils.split_data_create_vertical_dataset(trainset, data_owners)"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"loader = DataLoader(vfd, batch_size=4, shuffle=True)"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"{'client_1': [tensor([[[[-1., -1., -1., ..., -1., -1., -1.],\n",
" [-1., -1., -1., ..., -1., -1., -1.],\n",
" [-1., -1., -1., ..., -1., -1., -1.],\n",
" ...,\n",
" [-1., -1., -1., ..., -1., -1., -1.],\n",
" [-1., -1., -1., ..., -1., -1., -1.],\n",
" [-1., -1., -1., ..., -1., -1., -1.]]],\n",
"\n",
"\n",
" [[[-1., -1., -1., ..., -1., -1., -1.],\n",
" [-1., -1., -1., ..., -1., -1., -1.],\n",
" [-1., -1., -1., ..., -1., -1., -1.],\n",
" ...,\n",
" [-1., -1., -1., ..., -1., -1., -1.],\n",
" [-1., -1., -1., ..., -1., -1., -1.],\n",
" [-1., -1., -1., ..., -1., -1., -1.]]],\n",
"\n",
"\n",
" [[[-1., -1., -1., ..., -1., -1., -1.],\n",
" [-1., -1., -1., ..., -1., -1., -1.],\n",
" [-1., -1., -1., ..., -1., -1., -1.],\n",
" ...,\n",
" [-1., -1., -1., ..., -1., -1., -1.],\n",
" [-1., -1., -1., ..., -1., -1., -1.],\n",
" [-1., -1., -1., ..., -1., -1., -1.]]],\n",
"\n",
"\n",
" [[[-1., -1., -1., ..., -1., -1., -1.],\n",
" [-1., -1., -1., ..., -1., -1., -1.],\n",
" [-1., -1., -1., ..., -1., -1., -1.],\n",
" ...,\n",
" [-1., -1., -1., ..., -1., -1., -1.],\n",
" [-1., -1., -1., ..., -1., -1., -1.],\n",
" [-1., -1., -1., ..., -1., -1., -1.]]]]), tensor([9, 8, 2, 1]), tensor([51574., 39668., 24844., 32204.], dtype=torch.float64)], 'client_2': [tensor([[[[-1.0000, -1.0000, -1.0000, ..., -1.0000, -1.0000, -1.0000],\n",
" [-1.0000, -1.0000, -1.0000, ..., -1.0000, -1.0000, -1.0000],\n",
" [-1.0000, -1.0000, -1.0000, ..., -1.0000, -1.0000, -1.0000],\n",
" ...,\n",
" [ 0.1686, 0.9922, 0.5137, ..., -1.0000, -1.0000, -1.0000],\n",
" [ 0.1686, 0.9922, 0.3412, ..., -1.0000, -1.0000, -1.0000],\n",
" [-1.0000, -1.0000, -1.0000, ..., -1.0000, -1.0000, -1.0000]]],\n",
"\n",
"\n",
" [[[-1.0000, -1.0000, -1.0000, ..., -1.0000, -1.0000, -1.0000],\n",
" [-1.0000, -1.0000, -1.0000, ..., -1.0000, -1.0000, -1.0000],\n",
" [-1.0000, -1.0000, -1.0000, ..., -1.0000, -1.0000, -1.0000],\n",
" ...,\n",
" [-1.0000, -1.0000, -1.0000, ..., -1.0000, -1.0000, -1.0000],\n",
" [-1.0000, -1.0000, -1.0000, ..., -1.0000, -1.0000, -1.0000],\n",
" [-1.0000, -1.0000, -1.0000, ..., -1.0000, -1.0000, -1.0000]]],\n",
"\n",
"\n",
" [[[-1.0000, -1.0000, -1.0000, ..., -1.0000, -1.0000, -1.0000],\n",
" [-1.0000, -1.0000, -1.0000, ..., -1.0000, -1.0000, -1.0000],\n",
" [-1.0000, -1.0000, -1.0000, ..., -1.0000, -1.0000, -1.0000],\n",
" ...,\n",
" [-1.0000, -1.0000, -1.0000, ..., -1.0000, -1.0000, -1.0000],\n",
" [-1.0000, -1.0000, -1.0000, ..., -1.0000, -1.0000, -1.0000],\n",
" [-1.0000, -1.0000, -1.0000, ..., -1.0000, -1.0000, -1.0000]]],\n",
"\n",
"\n",
" [[[-1.0000, -1.0000, -1.0000, ..., -1.0000, -1.0000, -1.0000],\n",
" [-1.0000, -1.0000, -1.0000, ..., -1.0000, -1.0000, -1.0000],\n",
" [-1.0000, -1.0000, -1.0000, ..., -1.0000, -1.0000, -1.0000],\n",
" ...,\n",
" [-1.0000, -1.0000, -1.0000, ..., -1.0000, -1.0000, -1.0000],\n",
" [-1.0000, -1.0000, -1.0000, ..., -1.0000, -1.0000, -1.0000],\n",
" [-1.0000, -1.0000, -1.0000, ..., -1.0000, -1.0000, -1.0000]]]]), tensor([9, 8, 2, 1]), tensor([51574., 39668., 24844., 32204.], dtype=torch.float64)]}\n"
]
}
],
"source": [
"for el in loader: \n",
" print(el)\n",
" break"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
"#as in https://github.com/abbas5253/SplitNN-for-Vertically-Partitioned-Data/blob/master/Configuration_1.ipynb\n",
"from torch import nn, optim\n",
"\n",
"model_locations = [client_1, client_2, server]\n",
"\n",
"input_size= [28*14, 28*14]\n",
"hidden_sizes= {\"client_1\": [32, 64], \"client_2\":[32, 64], \"server\":[128, 64]}\n",
"\n",
"#create model segment for each worker\n",
"models = {\n",
" \"client_1\": nn.Sequential(\n",
" nn.Linear(input_size[0], hidden_sizes[\"client_1\"][0]),\n",
" nn.ReLU(),\n",
" nn.Linear(hidden_sizes[\"client_1\"][0], hidden_sizes[\"client_1\"][1]),\n",
" nn.ReLU(),\n",
" ),\n",
" \"client_2\": nn.Sequential(\n",
" nn.Linear(input_size[1], hidden_sizes[\"client_2\"][0]),\n",
" nn.ReLU(),\n",
" nn.Linear(hidden_sizes[\"client_2\"][0], hidden_sizes[\"client_2\"][1]),\n",
" nn.ReLU(),\n",
" ),\n",
" \"server\": nn.Sequential(\n",
" nn.Linear(hidden_sizes[\"server\"][0], hidden_sizes[\"server\"][1]),\n",
" nn.ReLU(),\n",
" nn.Linear(hidden_sizes[\"server\"][1], 10),\n",
" nn.LogSoftmax(dim=1)\n",
" )\n",
"}\n",
"\n",
"\n",
"\n",
"# Create optimisers for each segment and link to their segment\n",
"optimizers = [\n",
" optim.SGD(models[location.id].parameters(), lr=0.05,)\n",
" for location in model_locations\n",
"]\n",
"\n",
"\n",
"#send model segement to each client and server\n",
"for location in model_locations:\n",
" models[location.id].send(location)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.6"
}
},
"nbformat": 4,
"nbformat_minor": 4
}
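
Note: utils.split_data_create_vertical_dataset and the VerticalFederatedDataset it returns are not part of this diff. The sketch below is only an illustration of the idea, using hypothetical names (ToyVerticalDataset, split_mnist_vertically) rather than the actual utils/datasets API: each 28x28 MNIST image is cut into a left and a right half (matching the 28*14 input sizes in the model cell above), and __getitem__ returns a per-owner dict so that the standard DataLoader can batch it.

# Hypothetical sketch only -- the real helpers live in utils.py / datasets.py,
# which are not shown in this diff.
import torch
from torch.utils.data import Dataset


class ToyVerticalDataset(Dataset):
    """Each item is a dict: owner id -> (image half, label, shared example id)."""

    def __init__(self, partitions, labels, ids):
        self.partitions = partitions  # {owner_id: tensor of image halves}
        self.labels = labels
        self.ids = ids

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        return {
            owner: (halves[idx], self.labels[idx], self.ids[idx])
            for owner, halves in self.partitions.items()
        }


def split_mnist_vertically(trainset, data_owners):
    """Cut every 28x28 image into a left and a right half, one half per owner."""
    images = torch.stack([img for img, _ in trainset])       # (N, 1, 28, 28)
    labels = torch.tensor([label for _, label in trainset])
    ids = torch.arange(len(trainset), dtype=torch.float64)   # shared example ids

    halves = [images[..., :14], images[..., 14:]]            # split along the width
    partitions = {owner.id: half for owner, half in zip(data_owners, halves)}
    return ToyVerticalDataset(partitions, labels, ids)
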
70 changes: 70 additions & 0 deletions examples/dualheaded_datautils/dataloaders.py
@@ -0,0 +1,70 @@
from __future__ import print_function
import syft as sy
import torch
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
from torch.utils.data._utils.collate import default_collate
from typing import List, Tuple
from uuid import UUID
from uuid import uuid4
from torch.utils.data import SequentialSampler, RandomSampler, BatchSampler

import datasets


"""I think this is not needed anymore"""
Member

Do you mean we don't need the partitioned dataloader?

Contributor Author

No, I mean that the default PyTorch DataLoader works, so we do not need a custom one (for how it is done now). See the notebook for an example.
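
For illustration, a minimal self-contained sketch of that point (the toy dataset and its names are made up here, not taken from this PR): when __getitem__ returns a dict keyed by owner id, the plain DataLoader with its default_collate already produces the per-owner batches seen in the notebook output.

# Minimal sketch (hypothetical toy dataset, not this PR's datasets.py):
# default_collate recurses into the per-owner dict, so no custom loader is needed.
import torch
from torch.utils.data import Dataset, DataLoader


class TinyVerticalDataset(Dataset):
    def __len__(self):
        return 8

    def __getitem__(self, idx):
        half = torch.zeros(1, 28, 14)   # stand-in for one owner's image half
        label = idx % 10
        uid = float(idx)                # shared example id across owners
        return {"client_1": (half, label, uid), "client_2": (half, label, uid)}


loader = DataLoader(TinyVerticalDataset(), batch_size=4, shuffle=True)
batch = next(iter(loader))
# batch is {'client_1': [images (4,1,28,14), labels (4,), ids (4,)], 'client_2': [...]}
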



class SinglePartitionDataLoader(DataLoader):
    """DataLoader for a single vertically-partitioned dataset"""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

        # self.collate_fn = id_collate_fn


class VerticalFederatedDataLoader(DataLoader):
    """Dataloader which batches data from a complete
    set of vertically-partitioned datasets


    DataLoader(dataset, batch_size=1, shuffle=False, sampler=None,
               batch_sampler=None, num_workers=0, collate_fn=None,
               pin_memory=False, drop_last=False, timeout=0,
               worker_init_fn=None, *, prefetch_factor=2,
               persistent_workers=False)
    """

    def __init__(self, dataset, batch_size=1, shuffle=False, sampler=None,
                 batch_sampler=None, num_workers=0, collate_fn=None,
                 pin_memory=False, drop_last=False, timeout=0,
                 worker_init_fn=None, *, prefetch_factor=2,
                 persistent_workers=False):

        self.dataset = dataset
        self.batch_size = batch_size
        self.shuffle = shuffle
        self.num_workers = num_workers

        # the workers that hold the vertical partitions of the dataset
        self.workers = dataset.workers

        # one batch sampler per worker, drawing indices over the whole dataset
        self.batch_samplers = {}
        for worker in self.workers:
            data_range = range(len(self.dataset))
            if shuffle:
                sampler = RandomSampler(data_range)
            else:
                sampler = SequentialSampler(data_range)
            batch_sampler = BatchSampler(sampler, self.batch_size, drop_last)
            self.batch_samplers[worker] = batch_sampler

        # one SinglePartitionDataLoader per partition, driven by its batch sampler
        single_loaders = []
        for k in self.dataset.datasets.keys():
            single_loaders.append(
                SinglePartitionDataLoader(
                    self.dataset.datasets[k], batch_sampler=self.batch_samplers[k]
                )
            )

        self.single_loaders = single_loaders

    def __len__(self):
        # total samples across partitions divided by the number of workers
        return sum(len(x) for x in self.dataset.datasets.values()) // len(self.workers)