NDArray backend does not converge, even though WGPU backend does #3565

@jsosulski


Describe the bug

I followed the Burn Book to reproduce the MNIST training example. Out of curiosity, I tried the NDArray backend, but the network does not converge: validation accuracy is far worse than with the WGPU backend.

To Reproduce

This repo contains the mnist code from the book:

https://github.com/jsosulski/mnist_mwe

To run on the CPU (NDArray backend), use cargo run -r; to run on the GPU (WGPU backend), use cargo run -r --features wgpu.
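For readers unfamiliar with how a Cargo feature can switch backends at compile time, here is a minimal sketch of the general mechanism. It is not taken from the linked repository: `BackendKind`, `SELECTED_BACKEND`, and `backend_name` are hypothetical stand-ins for the real backend types, and it assumes Cargo.toml declares a `wgpu` feature under `[features]`.

```rust
/// Hypothetical stand-in for the real backend types (not part of Burn).
#[allow(dead_code)] // only one variant is used in any given build
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum BackendKind {
    NdArray,
    Wgpu,
}

// Compiled in only when the crate is built with `--features wgpu`.
#[cfg(feature = "wgpu")]
pub const SELECTED_BACKEND: BackendKind = BackendKind::Wgpu;

// Default: no feature flag given, so the CPU backend is selected.
#[cfg(not(feature = "wgpu"))]
pub const SELECTED_BACKEND: BackendKind = BackendKind::NdArray;

/// Human-readable name of the backend chosen at compile time.
pub fn backend_name() -> &'static str {
    match SELECTED_BACKEND {
        BackendKind::NdArray => "ndarray",
        BackendKind::Wgpu => "wgpu",
    }
}
```

Because the choice is made at compile time, the two commands above build two different binaries; the training code itself stays the same, which is why similar results would be expected.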

Expected behavior

Regardless of the backend, the network should reach roughly similar validation accuracy.
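One way to narrow down where the two backends diverge might be to feed the same batch through the same initial weights on each backend and compare the outputs numerically. The sketch below assumes the outputs have already been copied into flat f32 slices (how to extract them from tensors is not shown), and the helper names `max_abs_diff` and `roughly_equal` are hypothetical.

```rust
/// Largest absolute element-wise difference between two buffers.
/// Returns `None` if the lengths differ or any difference is NaN.
pub fn max_abs_diff(a: &[f32], b: &[f32]) -> Option<f32> {
    if a.len() != b.len() {
        return None;
    }
    let mut max = 0.0f32;
    for (x, y) in a.iter().zip(b.iter()) {
        let d = (x - y).abs();
        if d.is_nan() {
            // A NaN on either side already indicates a numerical problem.
            return None;
        }
        max = max.max(d);
    }
    Some(max)
}

/// True when both buffers have the same length and every element
/// differs by at most `tol`.
pub fn roughly_equal(a: &[f32], b: &[f32], tol: f32) -> bool {
    matches!(max_abs_diff(a, b), Some(d) if d <= tol)
}
```

For example, `roughly_equal(&ndarray_logits, &wgpu_logits, 1e-4)` on the first forward pass would indicate whether the discrepancy is already present before any training step; the tolerance is a guess and would need tuning.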

Screenshots

This is the output for the training on WGPU backend:

$ cargo run --release --features wgpu
...
| Split | Metric   | Min.     | Epoch    | Max.     | Epoch    |
|-------|----------|----------|----------|----------|----------|
| Train | Accuracy | 82.317   | 1        | 82.317   | 1        |
| Train | Loss     | 0.615    | 1        | 0.615    | 1        |
| Valid | Accuracy | 92.300   | 1        | 92.300   | 1        |
| Valid | Loss     | 0.252    | 1        | 0.252    | 1        |

This is the output for the training on NDArray backend:

$ cargo run --release
...
| Split | Metric   | Min.     | Epoch    | Max.     | Epoch    |
|-------|----------|----------|----------|----------|----------|
| Train | Accuracy | 73.732   | 1        | 73.732   | 1        |
| Train | Loss     | 1.006    | 1        | 1.006    | 1        |
| Valid | Accuracy | 8.470    | 1        | 8.470    | 1        |
| Valid | Loss     | 3.637    | 1        | 3.637    | 1        |

Desktop (please complete the following information):

  • OS: Manjaro
  • Version: stable

Additional context

In case it is relevant: I am on a laptop that does not have a dedicated GPU. The CPU is an AMD Ryzen 7 PRO 5850U with Radeon Graphics (16) @ 4.507GHz.
I also tried the latest main and got the same results.

Probably off-topic, but on main I had to remove the .devices(...) call from the learner; as it turns out, that call is not needed on 0.18.0 either.
