Useful things collected while writing JAX

Environment variables for debugging

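JAX_DEBUG_NANS=1    # raise an error as soon as an operation produces a NaN
JAX_DISABLE_JIT=1   # run jit-decorated functions eagerly, so print() and debuggers see real values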
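The same two switches can also be flipped from inside Python, which is handy in a notebook. This is a minimal sketch, assuming a recent JAX: `jax.config.update("jax_debug_nans", True)` mirrors `JAX_DEBUG_NANS=1`, and the `jax.disable_jit()` context manager mirrors `JAX_DISABLE_JIT=1` for a single block. The helper names (`log_raises_on_nan`, `relu_with_python_if`, `try_jitted`) are hypothetical, made up for illustration.

```python
import jax
import jax.numpy as jnp

# In-code counterpart of JAX_DEBUG_NANS=1: any op that produces a NaN raises.
jax.config.update("jax_debug_nans", True)


def log_raises_on_nan() -> bool:
    """Hypothetical helper: True if the NaN checker catches log(-1)."""
    try:
        jnp.log(-1.0)  # log of a negative number is NaN
    except FloatingPointError:
        return True
    return False


nan_caught = log_raises_on_nan()
# Checking every op for NaNs is slow, so switch it off again afterwards.
jax.config.update("jax_debug_nans", False)


@jax.jit
def relu_with_python_if(x):
    # Python control flow on a traced value cannot be compiled by jit.
    if x > 0:
        return x
    return 0.0 * x


def try_jitted(x):
    """Hypothetical helper: returns None when tracing under jit fails."""
    try:
        return relu_with_python_if(x)
    except jax.errors.ConcretizationTypeError:
        return None


jitted_result = try_jitted(2.0)  # `if x > 0` fails on a tracer, so None

# In-code counterpart of JAX_DISABLE_JIT=1, scoped to this block: the function
# runs eagerly, so the Python `if`, print() and debuggers see concrete values.
with jax.disable_jit():
    eager_result = relu_with_python_if(2.0)

print(nan_caught, jitted_result, eager_result)
```

The environment variables are usually the better choice for a one-off debugging run, since they need no code changes; the in-code versions are useful when you only want to debug one call.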

Choosing GPUs and platforms for JAX

(this one applies to CUDA in general, not just JAX)

CUDA_VISIBLE_DEVICES=3 # only expose GPU 3, 0-indexed
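A minimal Python sketch of the same thing: the variable has to be set before JAX is imported (e.g. at the very top of a script), so it is in place before CUDA is initialized. Setting `CUDA_DEVICE_ORDER=PCI_BUS_ID` as well is an extra not mentioned in the note above: CUDA's default device ordering is not guaranteed to match the numbering shown by `nvidia-smi`, and this setting makes them agree. Inside the process the visible GPU is renumbered, so JAX reports it as device 0.

```python
import os

# Set these before importing JAX so they are in place before CUDA initializes.
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"  # make indices match nvidia-smi
os.environ["CUDA_VISIBLE_DEVICES"] = "3"  # physical GPU 3 only; "2,3" would expose two

import jax

# With a GPU 3 present, this lists a single GPU with id 0; on a CPU-only
# machine JAX falls back to the CPU device.
devices = jax.devices()
print(devices)
```

From the shell, `CUDA_VISIBLE_DEVICES=3 python script.py` does the same without touching the code.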

For JAX, if you don't want a simple job that still needs to import JAX to use GPU memory (by default JAX preallocates a large fraction of GPU memory once the GPU backend starts up):

JAX_PLATFORMS=      # blank for default/all
JAX_PLATFORMS=cpu   # just cpu
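The same setting from inside Python, as a minimal sketch: like `CUDA_VISIBLE_DEVICES`, set it before JAX is imported (from the shell, `JAX_PLATFORMS=cpu python script.py` does the same). The GPU backend is then never initialized, so every array lives on the CPU and no GPU memory is reserved.

```python
import os

# Restrict JAX to the CPU backend; the GPU backend is never initialized,
# so JAX does not preallocate any GPU memory.
os.environ["JAX_PLATFORMS"] = "cpu"

import jax
import jax.numpy as jnp

platforms = {d.platform for d in jax.devices()}
print(platforms)  # {'cpu'}

# Ordinary computations still work, just on the CPU.
total = jnp.arange(4.0).sum()
print(total)
```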

If you want to use tensorflow-datasets without TensorFlow also taking up GPU memory, make sure to install tensorflow-cpu specifically (instead of tensorflow).
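A quick way to check that the install did what you wanted, as a minimal sketch (it imports only `tensorflow`; `tensorflow_datasets` is left out to keep it short). With tensorflow-cpu, TensorFlow lists no GPUs at all. The `tf.config.set_visible_devices` branch is a different technique from the note above: a commonly used runtime fallback for when the full tensorflow package is already installed.

```python
import tensorflow as tf

# With tensorflow-cpu installed, TensorFlow has no GPU support, so this is empty.
gpus = tf.config.list_physical_devices("GPU")
print("GPUs TensorFlow can see:", gpus)

if gpus:
    # Full `tensorflow` is installed: hide every GPU from TensorFlow so it does
    # not grab GPU memory that JAX wants. Must run before TF initializes the GPUs.
    tf.config.set_visible_devices([], "GPU")

visible_gpus = tf.config.get_visible_devices("GPU")
print("GPUs TensorFlow will use:", visible_gpus)
```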