Jax编程: 踩坑记录

摘要

记录 jax 使用过程中的坑。

gpu 上安装 jax 需要

1
pip install -U "jax[cuda12_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html

如果没有 IB 网卡,多机训练要关闭防火墙并加上如下环境变量

1
2
3
export NCCL_DEBUG=INFO 
export NCCL_SOCKET_IFNAME=网卡ID
export NCCL_IB_DISABLE=1

控制 jax 允许使用显存的变量,笔者一般设置为 0.95(%),如发生 OOM 可调小

1
export XLA_PYTHON_CLIENT_MEM_FRACTION=0.95

关闭 jax 的预分配显存机制,但容易产生 GPU 内存碎片,导致出现 OOM

1
export XLA_PYTHON_CLIENT_PREALLOCATE=false

使 jax 按需准确分配所需的内容,并释放不再需要的内存,会导致性能下降,不建议使用,但对小的显存占用运行或调试 OOM 故障可能很有用

1
export XLA_PYTHON_CLIENT_ALLOCATOR=platform
- ETX   Thank you for reading -
  • Copyright: All posts on this blog except otherwise stated, All adopt CC BY-NC-ND 4.0 license agreement. Please indicate the source of reprint!