NumPyro の実行例 – Eight Schools

2022年12月3日

概要

有名な例題である Eight Schools を Google Colab 上で実行してみました。コードは、こちらのサイトを参考に書いています。

Install Packages

まずは、NumPyro をインストールします。インストール完了後に、ランタイムは再起動しておきます。

!pip install --upgrade jax==0.2.17 jaxlib==0.1.71+cuda111 -f https://storage.googleapis.com/jax-releases/jax_releases.html
!pip install numpyro==0.7.2
!pip install arviz==0.11.2

Import Packages

次に、必要なパッケージをインポートします。

import numpyro
import numpyro.distributions as dist

import jax
import jax.numpy as jnp
import arviz as az
numpyro.set_platform('cpu')
numpyro.set_host_device_count(4)

Define Data

データは、下のデータを利用しました。

J = 8
y = jnp.array([28.0, 8.0, -3.0, 7.0, -1.0, 1.0, 18.0, 12.0])
sigma = jnp.array([15.0, 10.0, 16.0, 11.0, 9.0, 11.0, 10.0, 18.0])

Define Model & Inference

まずは、NumPyro の Getting Started にある通りのモデルでやってみます。

def model(J, sigma, y=None):

    mu = numpyro.sample('mu', dist.Normal(0, 5))
    tau = numpyro.sample('tau', dist.HalfCauchy(5))
    
    eta = numpyro.sample('eta', dist.Normal(0, 1), sample_shape=(J, ))
    theta = numpyro.deterministic('theta', mu + tau * eta)
    
    numpyro.sample('obs', dist.Normal(theta, sigma), obs=y)
%%time

nuts = numpyro.infer.NUTS(model, target_accept_prob=0.99, max_tree_depth=10)
mcmc = numpyro.infer.MCMC(nuts, num_warmup=500, num_samples=100000, num_chains=4, progress_bar=False)

mcmc.run(jax.random.PRNGKey(0), J, sigma, y)

idata = az.from_numpyro(mcmc)
CPU times: user 20.3 s, sys: 188 ms, total: 20.5 s
Wall time: 14.6 s

Chain は 4本計算しましたが、概ね 15秒くらいで計算できました。

az.plot_trace(idata);

az.summary(idata)
meansdhdi_3%hdi_97%mcse_meanmcse_sdess_bulkess_tailr_hat
eta[0]0.3150.989-1.5442.1690.0020.002348062.0289037.01.0
eta[1]0.0980.940-1.6911.8640.0020.001387829.0294075.01.0
eta[2]-0.0850.970-1.9081.7480.0010.002419545.0305056.01.0
eta[3]0.0630.942-1.7261.8300.0010.001419036.0302751.01.0
eta[4]-0.1610.931-1.9111.6040.0020.001361756.0287679.01.0
eta[5]-0.0700.942-1.8661.6890.0010.001400910.0299947.01.0
eta[6]0.3560.960-1.4742.1450.0020.001346002.0286779.01.0
eta[7]0.0750.975-1.7551.9180.0020.002379096.0292705.01.0
mu4.3873.321-1.96710.5440.0060.004352183.0278018.01.0
tau3.5923.2270.0009.2940.0060.004220316.0162144.01.0
theta[0]6.1885.598-3.76817.0230.0090.007360746.0322368.01.0
theta[1]4.9424.672-3.78113.9810.0070.005438859.0337950.01.0
theta[2]3.9205.269-6.20713.7160.0090.006392005.0334055.01.0
theta[3]4.7494.764-4.20313.9680.0070.006438214.0339686.01.0
theta[4]3.6054.663-5.41712.2700.0070.006414163.0332309.01.0
theta[5]4.0394.828-5.30313.0730.0070.006422227.0343072.01.0
theta[6]6.2945.083-2.85116.2780.0080.006396408.0327497.01.0
theta[7]4.8385.296-5.16514.9430.0090.006384954.0327053.01.0

ただ、PyStan で計算させたときと、計算結果が違うようなので、事前分布をもう少しフラットなものに置き換えます。PyStan での実行例は下をご覧頂けたらと思います。

def model_modified(J, sigma, y=None):

    mu = numpyro.sample('mu', dist.Normal(0, 100))
    tau = numpyro.sample('tau', dist.HalfNormal(100))
    
    eta = numpyro.sample('eta', dist.Normal(0, 1), sample_shape=(J, ))
    theta = numpyro.deterministic('theta', mu + tau * eta)
    
    numpyro.sample('obs', dist.Normal(theta, sigma), obs=y)
%%time

nuts = numpyro.infer.NUTS(model_modified, target_accept_prob=0.99, max_tree_depth=10)
mcmc = numpyro.infer.MCMC(nuts, num_warmup=500, num_samples=100000, num_chains=4, progress_bar=False)

mcmc.run(jax.random.PRNGKey(0), J, sigma, y)

idata = az.from_numpyro(mcmc)
CPU times: user 21.3 s, sys: 60.2 ms, total: 21.3 s
Wall time: 14.4 s

計算時間はあまり変わりません。やはり 15秒くらいでした。

az.plot_trace(idata);

az.summary(idata)
meansdhdi_3%hdi_97%mcse_meanmcse_sdess_bulkess_tailr_hat
eta[0]0.3890.940-1.4092.1390.0020.001357491.0278624.01.0
eta[1]-0.0010.874-1.6751.6460.0010.001440918.0288705.01.0
eta[2]-0.1930.933-1.9441.5780.0010.002460476.0287100.01.0
eta[3]-0.0290.884-1.6961.6540.0010.001456832.0290115.01.0
eta[4]-0.3530.875-2.0081.3170.0010.001366124.0285362.01.0
eta[5]-0.2110.893-1.8961.4880.0010.001428520.0280111.01.0
eta[6]0.3460.888-1.3452.0190.0010.001388660.0280306.01.0
eta[7]0.0590.936-1.6891.8400.0010.002459742.0286121.01.0
mu7.9105.160-1.76317.5350.0110.010236141.0191930.01.0
tau6.5185.5570.00016.1170.0150.012145392.0163708.01.0
theta[0]11.3428.275-3.13327.9800.0150.011305037.0302050.01.0
theta[1]7.8806.248-3.92320.0540.0090.007525761.0348429.01.0
theta[2]6.1307.769-9.56520.4860.0120.010414265.0314637.01.0
theta[3]7.6446.515-4.98820.0740.0090.007513369.0341057.01.0
theta[4]5.1376.337-7.30816.6890.0100.008417383.0337156.01.0
theta[5]6.1406.710-7.05118.6720.0100.008463469.0342193.01.0
theta[6]10.6456.781-1.39624.2020.0110.008389765.0340423.01.0
theta[7]8.4437.838-6.64423.8300.0120.010423002.0307346.01.0

今度は PyStan で計算したときの計算結果に近いものが出てきました。

Summary

NumPyro の方では、Chain を 4本にしたときしか計算していませんが、いずれのケースでも 15秒程度で計算が完了しました。

NumPyro

Posted by 管理者