NumPyro のインストール方法

2022年12月3日

最近 Google Colab で NumPyro をインストールをしようとしたら、ちょっと手間取ることがあったので、簡単に手順の方を紹介したいと思います。なお、NumPyro は標準では Windows をサポートしていません(JAX が Windows をサポートしていないため)。

以前のNumPyro は JAX に関するバージョンの縛りが厳しく、うまく JAX のバージョンを選ばないとちゃんと動いてくれなかったのですが、0.7.0 以降はこうしたバージョンの縛りが緩くなるという話が流れていました。

一時期はこの方針もうまく行っていたみたいで、次のようなコマンドで NumPyro をインストールするだけで使えるようになっていました。Google Colab には JAX が標準でインストールされているためです。

pip install numpyro==0.7.2

しかし、Google Colab の JAX がより新しいバージョンにアップデートされると、NumPyro のインストールはできるものの、うまく動かなくなってしまいました。

NumPyro の方も近いうちにアップデートされると思うので、大きな問題ではないだろうと思うのですが、すぐに使いたい人もいると思うので、NumPyro の version 0.7.2 と対応している JAX のインストール方法を記しておきたいと思います。

基本的には、ちょっとだけ JAX のバージョンを落とすだけです。この記事を書いている時点(2021年11月)での JAX のバージョンは 0.2.21 なのですが、これを少し落として バージョン 0.2.17 の JAX をインストールしてから NumPyro をインストールします。

pip install --upgrade jax==0.2.17 jaxlib==0.1.71+cuda111 -f https://storage.googleapis.com/jax-releases/jax_releases.html
pip install numpyro==0.7.2

一応、上のコマンドでは CUDA 11.1 に対応した jaxlib をインストールするようなコマンドにしています。GPU 対応が不要であれば、以下のような感じでも大丈夫なはずです。

pip install --upgrade jax==0.2.17 jaxlib==0.1.71
pip install numpyro==0.7.2

上のコマンドの方が CUDA のバージョンとの絡みがないので、トラブルには遭遇しにくいかもしれないです…

NumPyro

Posted by 管理者