I am very tired.
docker run --gpus all -it --name vyjax -p 8888:8888 --shm-size=20g --ulimit memlock=-1 nvcr.io/nvidia/jax:23.08-py3 bash
pip install --upgrade pip
pip install --upgrade jupyterlab
jupyter lab --allow-root --ip=0.0.0.0 --no-browser
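Below is a minimal sanity check, sketched as an assumption about what "working" should look like. You can run it in a JupyterLab notebook cell, or as a script inside the container. It uses only `jax.default_backend()`, `jax.devices()` and a small `jax.jit`-compiled matrix product. The helper names `describe_backend` and `matmul_sum` are hypothetical and do not come from any library. If the backend prints as "cpu", or the compiled call fails, one likely cause is that the 23.08 image ships a CUDA toolkit older than the RTX 5090 needs. Blackwell GPUs such as the RTX 5090 (compute capability 12.0) require CUDA 12.8 or newer, so a newer JAX container or wheel would probably be needed.

```python
# Sanity check for JAX + CUDA inside the container.
# describe_backend and matmul_sum are hypothetical helper names for illustration.
import jax
import jax.numpy as jnp


def describe_backend():
    """Return the default JAX backend name and the list of visible devices."""
    return jax.default_backend(), jax.devices()


@jax.jit
def matmul_sum(x):
    # Small XLA-compiled computation; it runs on the default backend (GPU if CUDA is picked up).
    return (x @ x.T).sum()


if __name__ == "__main__":
    backend, devices = describe_backend()
    print("JAX version:", jax.__version__)
    print("Default backend:", backend)  # "gpu" when CUDA works, "cpu" otherwise
    print("Devices:", devices)
    x = jnp.ones((1024, 1024), dtype=jnp.float32)
    # x @ x.T is a 1024x1024 matrix filled with 1024.0, so its sum is 1024**3.
    print("Result:", float(matmul_sum(x)), "expected:", 1024.0 ** 3)
```

On a working setup the default backend should read "gpu" and the device list should show one CUDA device. If it reads "cpu", JAX fell back to the CPU and the CUDA side of the setup needs attention.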
Please help me run JAX with CUDA on this PC. It has an RTX 5090 (32 GB) and runs Windows 11. This is my ultimate goal. Thank you for your time.