Description
Describe the bug
I followed the Burn Book to reproduce the MNIST training example. Out of interest, I tried the NDArray
backend, but the network does not converge: validation accuracy is far worse than with the WGPU backend.
To Reproduce
This repo contains the mnist code from the book:
https://github.com/jsosulski/mnist_mwe
To run on the CPU, use cargo run -r, and to run on the GPU, use cargo run -r --features wgpu.
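For context, the two commands differ only in a Cargo feature. Below is a minimal sketch of how such a feature can switch the backend at compile time. It assumes the feature is named wgpu, as in the command above. The helper backend_name is hypothetical and only for illustration; it is not code from the linked repository.

```rust
/// Hypothetical helper: reports which backend a build would select.
/// Assumes a Cargo feature named `wgpu`, matching `--features wgpu`.
fn backend_name() -> &'static str {
    if cfg!(feature = "wgpu") {
        "wgpu" // built with `cargo run -r --features wgpu`
    } else {
        "ndarray" // default build: `cargo run -r`
    }
}
```

In a real project the same condition would more likely gate a type alias for the backend than return a string, but the selection mechanism is the same.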
Expected behavior
Regardless of the backend chosen, the network should reach at least roughly similar validation accuracy.
Screenshots
This is the output for the training on WGPU backend:
$ cargo run --release --features wgpu
...
| Split | Metric | Min. | Epoch | Max. | Epoch |
|-------|----------|----------|----------|----------|----------|
| Train | Accuracy | 82.317 | 1 | 82.317 | 1 |
| Train | Loss | 0.615 | 1 | 0.615 | 1 |
| Valid | Accuracy | 92.300 | 1 | 92.300 | 1 |
| Valid | Loss | 0.252 | 1 | 0.252 | 1 |
This is the output for the training on NDArray backend:
$ cargo run --release
...
| Split | Metric | Min. | Epoch | Max. | Epoch |
|-------|----------|----------|----------|----------|----------|
| Train | Accuracy | 73.732 | 1 | 73.732 | 1 |
| Train | Loss | 1.006 | 1 | 1.006 | 1 |
| Valid | Accuracy | 8.470 | 1 | 8.470 | 1 |
| Valid | Loss | 3.637 | 1 | 3.637 | 1 |
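To put the two tables side by side: MNIST has 10 classes, so uniform random guessing would score about 10% accuracy. The sketch below compares the validation accuracies reported above. The function names and the 5-point tolerance are illustrative assumptions, not part of Burn.

```rust
/// Accuracy (in percent) expected from uniform random guessing over `classes` labels.
fn chance_accuracy(classes: u32) -> f64 {
    100.0 / classes as f64
}

/// True when two accuracies (in percent) differ by at most `tol` percentage points.
fn accuracies_agree(a: f64, b: f64, tol: f64) -> bool {
    (a - b).abs() <= tol
}

/// True when an accuracy is no better than chance for the given number of classes.
fn at_or_below_chance(acc: f64, classes: u32) -> bool {
    acc <= chance_accuracy(classes)
}
```

With the values from the tables, accuracies_agree(92.300, 8.470, 5.0) is false and at_or_below_chance(8.470, 10) is true. The NDArray run therefore validates no better than random guessing, even though its training accuracy is 73.732.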
Desktop (please complete the following information):
- OS: Manjaro
- Version: stable
Additional context
In case it is relevant, I am on a laptop that does not have a dedicated GPU. The CPU is an AMD Ryzen 7 PRO 5850U with Radeon Graphics (16) @ 4.507GHz.
I also tried the latest main, but got the same results.
Probably off-topic, but on main I had to remove the .devices(...) call from the learner, and as it turns out, the call is not needed on 0.18.0 either.