Skip to content
Merged
Show file tree
Hide file tree
Changes from 20 commits
Commits
Show all changes
45 commits
Select commit Hold shift + click to select a range
7360f91
Add pytorch style dataloader
dayo05 Nov 28, 2021
3de29c3
Remove GetDataEnumerable from interface
dayo05 Nov 30, 2021
857f9e6
Resolve review except get random value
dayo05 Nov 30, 2021
954fa8c
Rename method and create reset method
dayo05 Nov 30, 2021
63568f9
Add copyright string
dayo05 Nov 30, 2021
d8d249f
Use new shuffle algorithm
dayo05 Nov 30, 2021
ab2eb09
Add summery
dayo05 Nov 30, 2021
ca49340
Make able to create non-shuffle dataloader
dayo05 Nov 30, 2021
2cc2300
Make able to create non-shuffle dataset
dayo05 Nov 30, 2021
4cff460
Change tensor tuple to dictionary
dayo05 Nov 30, 2021
124790d
Merge branch 'main' into main
dayo05 Nov 30, 2021
17a9022
Replace files and make dataset abstract class
dayo05 Nov 30, 2021
d9dddb8
Merge remote-tracking branch 'origin/main'
dayo05 Nov 30, 2021
07d2689
Merge branch 'dotnet:main' into main
dayo05 Nov 30, 2021
2315e66
Merge branch 'main' into main
dayo05 Nov 30, 2021
ab6bd3e
Make dataloader disposable
dayo05 Dec 1, 2021
9631ed6
Make count priority abstract
dayo05 Dec 1, 2021
43dfcc1
Make dataloader to stack data as end of tensor
dayo05 Dec 1, 2021
225b9f8
Create simple test for dataset and dataloader
dayo05 Dec 1, 2021
a44ba7d
Merge remote-tracking branch 'origin/main'
dayo05 Dec 1, 2021
6e336ca
Make dispose enumerator
dayo05 Dec 1, 2021
5b5d9d5
Rename methods and add copyright notice
dayo05 Dec 3, 2021
49d7afd
Rename reset to Reset
dayo05 Dec 3, 2021
b250fe0
Make Count of dataset to long type
dayo05 Dec 3, 2021
703cb01
Make type of Count to long
dayo05 Dec 3, 2021
65f06cf
Make Count to long
dayo05 Dec 3, 2021
5f1707f
Rename methods
dayo05 Dec 3, 2021
3bda582
Make move tensor automatically to device
dayo05 Dec 3, 2021
22d0556
Make able to use custom seed
dayo05 Dec 3, 2021
eb53d61
Edit test for long
dayo05 Dec 3, 2021
921bb3d
Create test for custom seed
dayo05 Dec 3, 2021
8733f10
Merge branch 'main' into main
dayo05 Dec 3, 2021
168f87c
Make dataloader tensor dispose on MoveNext or Reset
dayo05 Dec 3, 2021
4db9b64
Change GCD algorithm
dayo05 Dec 3, 2021
5315175
Merge branch 'dotnet:main' into main
dayo05 Dec 25, 2021
eab28eb
Added document comments
dayo05 Dec 25, 2021
0deaddb
Add document comment for classes
dayo05 Dec 25, 2021
e9c20a4
Make catenate every tensor once
dayo05 Dec 25, 2021
cc5dfe0
Update doc comment
dayo05 Dec 25, 2021
dfff08a
Make able to set custom shuffler
dayo05 Jan 6, 2022
9dffab6
Fix mistake on creating custom shuffler
dayo05 Jan 8, 2022
991c377
Add fisher yates shuffler and make that as default
dayo05 Jan 8, 2022
2efce39
Fix mistake on shuffler
dayo05 Jan 7, 2022
00e16ad
Make dispose dataset once
dayo05 Jan 10, 2022
9291dc4
Undo changes on global.json
dayo05 Jan 12, 2022
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
114 changes: 114 additions & 0 deletions src/TorchSharp/DataLoader.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,114 @@
// Copyright (c) .NET Foundation and Contributors. All Rights Reserved. See LICENSE in the project root for license information.

using System;
using System.Collections;
using System.Collections.Generic;

using TorchSharp.Utils;

namespace TorchSharp
{
public static partial class torch
{
public static partial class utils
{
public static partial class data
{
/// <summary>
/// PyTorch-style data loader: enumerates a <see cref="Dataset"/> in mini-batches.
/// Every batch is a dictionary with the same keys as the dataset's samples, where each
/// value is the samples' tensors joined along a new leading dimension.
/// </summary>
/// <remarks>
/// The loader disposes the dataset it was given when the loader itself is disposed.
/// Enumerators obtained from it never dispose the dataset, so the loader can be
/// enumerated any number of times (one pass per epoch).
/// </remarks>
public class DataLoader : IEnumerable<Dictionary<string, Tensor>>, IDisposable
{
    private readonly Dataset dataset;
    private readonly int batchSize;
    private readonly bool shuffle;
    private readonly Device device;
    private bool disposed;

    /// <summary>
    /// Create pytorch style dataloader
    /// </summary>
    /// <param name="dataset">Dataset to draw samples from. Disposed together with the loader.</param>
    /// <param name="batchSize">Maximum number of samples per batch; must be positive.</param>
    /// <param name="shuffle">true to visit the samples in a pseudo-random order, false for index order.</param>
    /// <param name="device">Device every batch tensor is moved to; CPU when null.</param>
    /// <exception cref="ArgumentNullException"><paramref name="dataset"/> is null.</exception>
    /// <exception cref="ArgumentOutOfRangeException"><paramref name="batchSize"/> is zero or negative.</exception>
    public DataLoader(Dataset dataset, int batchSize, bool shuffle = false, Device device = null)
    {
        if (dataset is null)
            throw new ArgumentNullException(nameof(dataset));
        if (batchSize <= 0)
            throw new ArgumentOutOfRangeException(nameof(batchSize), batchSize, "Batch size must be positive.");

        this.dataset = dataset;
        this.batchSize = batchSize;
        this.shuffle = shuffle;
        this.device = device ?? CPU;
    }

    /// <summary>
    /// Starts a new pass over the dataset.
    /// </summary>
    /// <returns>An enumerator that yields one dictionary of batched tensors per step.</returns>
    public IEnumerator<Dictionary<string, Tensor>> GetEnumerator() =>
        new DataLoaderEnumerator(dataset, batchSize, shuffle, device);

    IEnumerator IEnumerable.GetEnumerator() => GetEnumerator();

    /// <summary>
    /// Number of batches in one pass: the dataset size divided by the batch size,
    /// rounded up. Zero for an empty dataset.
    /// </summary>
    public long Count => ((long)dataset.Count + batchSize - 1) / batchSize;

    private sealed class DataLoaderEnumerator : IEnumerator<Dictionary<string, Tensor>>
    {
        private readonly Dataset dataset;
        private readonly int batchSize;
        private readonly Device device;
        private readonly bool shuffle;
        private ShuffleGenerator shuffleGenerator;
        private int currentVal;

        public DataLoaderEnumerator(Dataset dataset, int batchSize, bool shuffle, Device device)
        {
            this.dataset = dataset;
            this.batchSize = batchSize;
            this.device = device;
            this.shuffle = shuffle;
            reset();
        }

        private bool isFinished() =>
            shuffle ? !shuffleGenerator.hasNext() : currentVal >= dataset.Count;

        private int getNextValue() => shuffle ? shuffleGenerator.next() : currentVal++;

        private void reset()
        {
            // The generator is only needed (and only consulted) when shuffling.
            shuffleGenerator = shuffle ? new ShuffleGenerator(dataset.Count) : null;
            currentVal = 0;
        }

        /// <summary>
        /// Builds the next batch from up to <c>batchSize</c> samples.
        /// </summary>
        /// <returns>false when no samples are left, true otherwise.</returns>
        public bool MoveNext()
        {
            if (isFinished()) return false;

            // Fetch the samples first so every key is concatenated exactly once,
            // instead of re-copying the growing batch for each added sample.
            var samples = new List<Dictionary<string, Tensor>>();
            while (samples.Count < batchSize && !isFinished())
                samples.Add(dataset.GetTensor(getNextValue()));

            // The batch goes into a fresh dictionary: the sample dictionaries and their
            // tensors belong to the dataset and are left untouched (no in-place unsqueeze_).
            var batch = new Dictionary<string, Tensor>();
            foreach (var key in samples[0].Keys) {
                var rows = new List<Tensor>();
                try {
                    foreach (var sample in samples)
                        rows.Add(sample[key].unsqueeze(0));

                    var joined = cat(rows, 0);
                    var moved = joined.to(device);
                    // Release the intermediate only if the move produced another tensor object.
                    if (!ReferenceEquals(moved, joined))
                        joined.Dispose();
                    batch[key] = moved;
                } finally {
                    // The unsqueezed views were created here and are no longer needed after cat.
                    foreach (var row in rows)
                        row.Dispose();
                }
            }

            Current = batch;
            return true;
        }

        /// <summary>
        /// Rewinds to the first sample; a shuffling enumerator draws a new order.
        /// </summary>
        public void Reset() => reset();

        public Dictionary<string, Tensor> Current { get; private set; }

        object IEnumerator.Current => Current;

        public void Dispose()
        {
            // Intentionally empty: the dataset is owned by the DataLoader. Disposing it
            // here would invalidate it as soon as the first foreach loop finished.
        }
    }

    /// <summary>
    /// Disposes the underlying dataset. Calling it more than once has no further effect.
    /// </summary>
    public void Dispose()
    {
        if (disposed) return;
        disposed = true;
        dataset.Dispose();
    }
}
}
}
}
}
26 changes: 26 additions & 0 deletions src/TorchSharp/Dataset.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
// Copyright (c) .NET Foundation and Contributors. All Rights Reserved. See LICENSE in the project root for license information.
using System;
using System.Collections.Generic;

namespace TorchSharp
{
public static partial class torch
{
public static partial class utils
{
public static partial class data
{
/// <summary>
/// Base class for a map-style dataset: an indexable collection of samples, each one
/// exposed as a dictionary of named tensors.
/// </summary>
public abstract class Dataset : IDisposable
{
    /// <summary>
    /// Number of samples in the dataset.
    /// </summary>
    public abstract int Count { get; }

    /// <summary>
    /// Returns the sample stored at <paramref name="index"/>.
    /// </summary>
    /// <param name="index">Index of the requested sample.</param>
    /// <returns>The sample's tensors, keyed by name.</returns>
    public abstract Dictionary<string, Tensor> GetTensor(int index);

    /// <summary>
    /// Does nothing in the base class; derived datasets may override it to free
    /// whatever they hold.
    /// </summary>
    public virtual void Dispose() { }
}
}
}
}
}
89 changes: 89 additions & 0 deletions src/TorchSharp/Utils/ShuffleGenerator.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
using System;

namespace TorchSharp.Utils
{
/// <summary>
/// Yields every index in [0, size) exactly once in a pseudo-random order, without
/// storing the permutation.
/// </summary>
/// <remarks>
/// The order is an arithmetic progression modulo <c>size</c>: it starts from a random
/// offset and repeatedly adds a random stride that is coprime with <c>size</c>. Coprimality
/// guarantees that no index repeats within the first <c>size</c> steps. This is cheap, but
/// it is not a uniformly random permutation.
/// </remarks>
public class ShuffleGenerator
{
    // Upper bound on random stride candidates tried before using the guaranteed fallback.
    private const int MaxStrideAttempts = 64;

    private readonly int maxrange;
    private readonly int prime;
    private int index;
    private int runningvalue;

    /// <summary>
    /// Creates a generator over [0, size) with a non-deterministic order.
    /// </summary>
    /// <param name="size">Number of indices to produce; must not be negative.</param>
    /// <exception cref="ArgumentOutOfRangeException"><paramref name="size"/> is negative.</exception>
    public ShuffleGenerator(int size) : this(size, new Random())
    {
    }

    /// <summary>
    /// Creates a generator over [0, size) whose order is fully determined by
    /// <paramref name="seed"/>.
    /// </summary>
    /// <param name="size">Number of indices to produce; must not be negative.</param>
    /// <param name="seed">Seed for the random source.</param>
    /// <exception cref="ArgumentOutOfRangeException"><paramref name="size"/> is negative.</exception>
    public ShuffleGenerator(int size, int seed) : this(size, new Random(seed))
    {
    }

    private ShuffleGenerator(int size, Random rand)
    {
        if (size < 0)
            throw new ArgumentOutOfRangeException(nameof(size), size, "Size must not be negative.");

        maxrange = size;
        prime = selectCoprimeStride(size / 2, size, rand);
        runningvalue = rand.Next(size);
        index = 0;
    }

    /// <summary>
    /// Tells whether fewer than <c>size</c> indices have been handed out so far.
    /// </summary>
    public bool hasNext()
    {
        return index < maxrange;
    }

    /// <summary>
    /// Advances the sequence and returns the next index. Callers are expected to check
    /// <see cref="hasNext"/> first; past the end the sequence simply starts repeating.
    /// </summary>
    public int next()
    {
        // Both operands can approach int.MaxValue, so the sum is formed in 64 bits.
        var advanced = (long)runningvalue + prime;
        if (advanced >= maxrange) advanced -= maxrange;
        runningvalue = (int)advanced;
        index++;
        return runningvalue;
    }

    // Picks a random value in [min, target) that is coprime with target, or 0 when the
    // range is empty. Candidates are drawn at random rather than scanning the whole range.
    private static int selectCoprimeStride(int min, int target, Random rand)
    {
        if (min >= target) return 0;

        for (var attempt = 0; attempt < MaxStrideAttempts; attempt++) {
            var candidate = rand.Next(min, target);
            if (gcd(candidate, target) == 1) return candidate;
        }

        // target - 1 is always coprime with target and never below target / 2.
        return target - 1;
    }

    // Euclid's algorithm for non-negative operands; gcd(0, v) == v and gcd(u, 0) == u.
    private static int gcd(int u, int v)
    {
        while (v != 0)
            (u, v) = (v, u % v);
        return u;
    }
}
}
44 changes: 44 additions & 0 deletions test/TorchSharpTest/TestDataLoader.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
using System;
using System.Collections.Generic;
using Xunit;

namespace TorchSharp
{
/// <summary>
/// Smoke tests for <c>torch.utils.data.Dataset</c> and <c>torch.utils.data.DataLoader</c>.
/// </summary>
public class TestDataLoader
{
    // Ten samples; each one carries two constant tensors plus the index it was requested with.
    private class TestDataset : torch.utils.data.Dataset
    {
        public override int Count { get; } = 10;
        public override Dictionary<string, torch.Tensor> GetTensor(int index)
        {
            return new() {{"data", torch.tensor(1)}, {"label", torch.tensor(13)}, {"index", torch.tensor(index)}};
        }
    }

    [Fact]
    public void DatasetTest()
    {
        using var dataset = new TestDataset();
        var d = dataset.GetTensor(0);
        Assert.True(d.ContainsKey("data"));
        Assert.True(d.ContainsKey("index"));
        Assert.True(d.ContainsKey("label"));

        // xUnit convention: expected value first, actual value second.
        Assert.Equal(torch.tensor(1), d["data"]);
        Assert.Equal(torch.tensor(13), d["label"]);
        Assert.Equal(torch.tensor(0), d["index"]);
    }

    [Fact]
    public void DataLoaderTest()
    {
        using var dataset = new TestDataset();
        using var dataloader = new torch.utils.data.DataLoader(dataset, 2, false, torch.CPU);

        // 10 samples in batches of 2.
        Assert.Equal(5L, dataloader.Count);

        using var iterator = dataloader.GetEnumerator();
        Assert.True(iterator.MoveNext());

        // Without shuffling the first batch holds samples 0 and 1, stacked along dimension 0.
        Assert.Equal(torch.tensor(rawArray: new[]{1, 1}, dimensions: new[]{2L}), iterator.Current["data"]);
        Assert.Equal(torch.tensor(rawArray: new[]{13, 13}, dimensions: new[]{2L}), iterator.Current["label"]);
        Assert.Equal(torch.tensor(rawArray: new[]{0, 1}, dimensions: new[]{2L}), iterator.Current["index"]);
    }
}
}