NumPyro のインストール方法
最近 Google Colab で NumPyro をインストールをしようとしたら、ちょっと手間取ることがあったので、簡単に手順の方を紹介したいと思います。なお、NumPyro は標準では Windows をサポートしていません(JAX が Windows をサポートしていないため)。
以前のNumPyro は JAX に関するバージョンの縛りが厳しく、うまく JAX のバージョンを選ばないとちゃんと動いてくれなかったのですが、0.7.0 以降はこうしたバージョンの縛りが緩くなるという話が流れていました。
一時期はこの方針もうまく行っていたみたいで、次のようなコマンドで NumPyro をインストールするだけで使えるようになっていました。Google Colab には JAX が標準でインストールされているためです。
pip install numpyro==0.7.2
しかし、Google Colab の JAX がより新しいバージョンにアップデートされると、NumPyro のインストールはできるものの、うまく動かなくなってしまいました。
NumPyro の方も近いうちにアップデートされると思うので、大きな問題ではないだろうと思うのですが、すぐに使いたい人もいると思うので、NumPyro の version 0.7.2 と対応している JAX のインストール方法を記しておきたいと思います。
基本的には、ちょっとだけ JAX のバージョンを落とすだけです。この記事を書いている時点(2021年11月)での JAX のバージョンは 0.2.21 なのですが、これを少し落として バージョン 0.2.17 の JAX をインストールしてから NumPyro をインストールします。
pip install --upgrade jax==0.2.17 jaxlib==0.1.71+cuda111 -f https://storage.googleapis.com/jax-releases/jax_releases.html
pip install numpyro==0.7.2
一応、上のコマンドでは CUDA 11.1 に対応した jaxlib をインストールするようなコマンドにしています。GPU 対応が不要であれば、以下のような感じでも大丈夫なはずです。
pip install --upgrade jax==0.2.17 jaxlib==0.1.71
pip install numpyro==0.7.2
上のコマンドの方が CUDA のバージョンとの絡みがないので、トラブルには遭遇しにくいかもしれないです…