摘要
记录 jax 使用过程中的坑。
gpu 上安装 jax 需要
1 | pip install -U "jax[cuda12_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html |
如果没有 IB 网卡,多机训练要关闭防火墙并加上如下环境变量
1 | export NCCL_DEBUG=INFO |
控制 jax 允许使用显存的变量,笔者一般设置为 0.95(%),如发生 OOM 可调小
1 | export XLA_PYTHON_CLIENT_MEM_FRACTION=0.95 |
关闭 jax 的预分配显存机制,但容易产生 GPU 内存碎片,导致出现 OOM
1 | export XLA_PYTHON_CLIENT_PREALLOCATE=false |
使 jax 按需准确分配所需的内容,并释放不再需要的内存,会导致性能下降,不建议使用,但对小的显存占用运行或调试 OOM 故障可能很有用
1 | export XLA_PYTHON_CLIENT_ALLOCATOR=platform |