NumPyro の実行例 – Eight Schools
概要
有名な例題である Eight Schools を Google Colab 上で実行してみました。コードは、こちらのサイトを参考に書いています。
Install Packages
まずは、NumPyro をインストールします。インストール完了後に、ランタイムは再起動しておきます。
!pip install --upgrade jax==0.2.17 jaxlib==0.1.71+cuda111 -f https://storage.googleapis.com/jax-releases/jax_releases.html
!pip install numpyro==0.7.2
!pip install arviz==0.11.2
Import Packages
次に、必要なパッケージをインポートします。
import numpyro
import numpyro.distributions as dist
import jax
import jax.numpy as jnp
import arviz as az
numpyro.set_platform('cpu')
numpyro.set_host_device_count(4)
Define Data
データは、下のデータを利用しました。
J = 8
y = jnp.array([28.0, 8.0, -3.0, 7.0, -1.0, 1.0, 18.0, 12.0])
sigma = jnp.array([15.0, 10.0, 16.0, 11.0, 9.0, 11.0, 10.0, 18.0])
Define Model & Inference
まずは、NumPyro の Getting Started にある通りのモデルでやってみます。
def model(J, sigma, y=None):
mu = numpyro.sample('mu', dist.Normal(0, 5))
tau = numpyro.sample('tau', dist.HalfCauchy(5))
eta = numpyro.sample('eta', dist.Normal(0, 1), sample_shape=(J, ))
theta = numpyro.deterministic('theta', mu + tau * eta)
numpyro.sample('obs', dist.Normal(theta, sigma), obs=y)
%%time
nuts = numpyro.infer.NUTS(model, target_accept_prob=0.99, max_tree_depth=10)
mcmc = numpyro.infer.MCMC(nuts, num_warmup=500, num_samples=100000, num_chains=4, progress_bar=False)
mcmc.run(jax.random.PRNGKey(0), J, sigma, y)
idata = az.from_numpyro(mcmc)
CPU times: user 20.3 s, sys: 188 ms, total: 20.5 s
Wall time: 14.6 s
Chain は 4本計算しましたが、概ね 15秒くらいで計算できました。
az.plot_trace(idata);
az.summary(idata)
mean | sd | hdi_3% | hdi_97% | mcse_mean | mcse_sd | ess_bulk | ess_tail | r_hat | |
---|---|---|---|---|---|---|---|---|---|
eta[0] | 0.315 | 0.989 | -1.544 | 2.169 | 0.002 | 0.002 | 348062.0 | 289037.0 | 1.0 |
eta[1] | 0.098 | 0.940 | -1.691 | 1.864 | 0.002 | 0.001 | 387829.0 | 294075.0 | 1.0 |
eta[2] | -0.085 | 0.970 | -1.908 | 1.748 | 0.001 | 0.002 | 419545.0 | 305056.0 | 1.0 |
eta[3] | 0.063 | 0.942 | -1.726 | 1.830 | 0.001 | 0.001 | 419036.0 | 302751.0 | 1.0 |
eta[4] | -0.161 | 0.931 | -1.911 | 1.604 | 0.002 | 0.001 | 361756.0 | 287679.0 | 1.0 |
eta[5] | -0.070 | 0.942 | -1.866 | 1.689 | 0.001 | 0.001 | 400910.0 | 299947.0 | 1.0 |
eta[6] | 0.356 | 0.960 | -1.474 | 2.145 | 0.002 | 0.001 | 346002.0 | 286779.0 | 1.0 |
eta[7] | 0.075 | 0.975 | -1.755 | 1.918 | 0.002 | 0.002 | 379096.0 | 292705.0 | 1.0 |
mu | 4.387 | 3.321 | -1.967 | 10.544 | 0.006 | 0.004 | 352183.0 | 278018.0 | 1.0 |
tau | 3.592 | 3.227 | 0.000 | 9.294 | 0.006 | 0.004 | 220316.0 | 162144.0 | 1.0 |
theta[0] | 6.188 | 5.598 | -3.768 | 17.023 | 0.009 | 0.007 | 360746.0 | 322368.0 | 1.0 |
theta[1] | 4.942 | 4.672 | -3.781 | 13.981 | 0.007 | 0.005 | 438859.0 | 337950.0 | 1.0 |
theta[2] | 3.920 | 5.269 | -6.207 | 13.716 | 0.009 | 0.006 | 392005.0 | 334055.0 | 1.0 |
theta[3] | 4.749 | 4.764 | -4.203 | 13.968 | 0.007 | 0.006 | 438214.0 | 339686.0 | 1.0 |
theta[4] | 3.605 | 4.663 | -5.417 | 12.270 | 0.007 | 0.006 | 414163.0 | 332309.0 | 1.0 |
theta[5] | 4.039 | 4.828 | -5.303 | 13.073 | 0.007 | 0.006 | 422227.0 | 343072.0 | 1.0 |
theta[6] | 6.294 | 5.083 | -2.851 | 16.278 | 0.008 | 0.006 | 396408.0 | 327497.0 | 1.0 |
theta[7] | 4.838 | 5.296 | -5.165 | 14.943 | 0.009 | 0.006 | 384954.0 | 327053.0 | 1.0 |
ただ、PyStan で計算させたときと、計算結果が違うようなので、事前分布をもう少しフラットなものに置き換えます。PyStan での実行例は下をご覧頂けたらと思います。
def model_modified(J, sigma, y=None):
mu = numpyro.sample('mu', dist.Normal(0, 100))
tau = numpyro.sample('tau', dist.HalfNormal(100))
eta = numpyro.sample('eta', dist.Normal(0, 1), sample_shape=(J, ))
theta = numpyro.deterministic('theta', mu + tau * eta)
numpyro.sample('obs', dist.Normal(theta, sigma), obs=y)
%%time
nuts = numpyro.infer.NUTS(model_modified, target_accept_prob=0.99, max_tree_depth=10)
mcmc = numpyro.infer.MCMC(nuts, num_warmup=500, num_samples=100000, num_chains=4, progress_bar=False)
mcmc.run(jax.random.PRNGKey(0), J, sigma, y)
idata = az.from_numpyro(mcmc)
CPU times: user 21.3 s, sys: 60.2 ms, total: 21.3 s
Wall time: 14.4 s
計算時間はあまり変わりません。やはり 15秒くらいでした。
az.plot_trace(idata);
az.summary(idata)
mean | sd | hdi_3% | hdi_97% | mcse_mean | mcse_sd | ess_bulk | ess_tail | r_hat | |
---|---|---|---|---|---|---|---|---|---|
eta[0] | 0.389 | 0.940 | -1.409 | 2.139 | 0.002 | 0.001 | 357491.0 | 278624.0 | 1.0 |
eta[1] | -0.001 | 0.874 | -1.675 | 1.646 | 0.001 | 0.001 | 440918.0 | 288705.0 | 1.0 |
eta[2] | -0.193 | 0.933 | -1.944 | 1.578 | 0.001 | 0.002 | 460476.0 | 287100.0 | 1.0 |
eta[3] | -0.029 | 0.884 | -1.696 | 1.654 | 0.001 | 0.001 | 456832.0 | 290115.0 | 1.0 |
eta[4] | -0.353 | 0.875 | -2.008 | 1.317 | 0.001 | 0.001 | 366124.0 | 285362.0 | 1.0 |
eta[5] | -0.211 | 0.893 | -1.896 | 1.488 | 0.001 | 0.001 | 428520.0 | 280111.0 | 1.0 |
eta[6] | 0.346 | 0.888 | -1.345 | 2.019 | 0.001 | 0.001 | 388660.0 | 280306.0 | 1.0 |
eta[7] | 0.059 | 0.936 | -1.689 | 1.840 | 0.001 | 0.002 | 459742.0 | 286121.0 | 1.0 |
mu | 7.910 | 5.160 | -1.763 | 17.535 | 0.011 | 0.010 | 236141.0 | 191930.0 | 1.0 |
tau | 6.518 | 5.557 | 0.000 | 16.117 | 0.015 | 0.012 | 145392.0 | 163708.0 | 1.0 |
theta[0] | 11.342 | 8.275 | -3.133 | 27.980 | 0.015 | 0.011 | 305037.0 | 302050.0 | 1.0 |
theta[1] | 7.880 | 6.248 | -3.923 | 20.054 | 0.009 | 0.007 | 525761.0 | 348429.0 | 1.0 |
theta[2] | 6.130 | 7.769 | -9.565 | 20.486 | 0.012 | 0.010 | 414265.0 | 314637.0 | 1.0 |
theta[3] | 7.644 | 6.515 | -4.988 | 20.074 | 0.009 | 0.007 | 513369.0 | 341057.0 | 1.0 |
theta[4] | 5.137 | 6.337 | -7.308 | 16.689 | 0.010 | 0.008 | 417383.0 | 337156.0 | 1.0 |
theta[5] | 6.140 | 6.710 | -7.051 | 18.672 | 0.010 | 0.008 | 463469.0 | 342193.0 | 1.0 |
theta[6] | 10.645 | 6.781 | -1.396 | 24.202 | 0.011 | 0.008 | 389765.0 | 340423.0 | 1.0 |
theta[7] | 8.443 | 7.838 | -6.644 | 23.830 | 0.012 | 0.010 | 423002.0 | 307346.0 | 1.0 |
今度は PyStan で計算したときの計算結果に近いものが出てきました。
Summary
NumPyro の方では、Chain を 4本にしたときしか計算していませんが、いずれのケースでも 15秒程度で計算が完了しました。