PyStan の実行例 – Eight Schools
概要
有名な例題である Eight Schools を Google Colab 上で実行してみました。コードはこちらのサイトを参考に書いています。
Install Package
まずは、PyStan をインストールします。
!pip install pystan==2.19.1.1
Import Packages
次に、必要なパッケージをインポートします。
import pystan
import arviz as az
Define Data
データは、下のデータを利用しました。
data = {'J': 8,
'y': [28, 8, -3, 7, -1, 1, 18, 12],
'sigma': [15, 10, 16, 11, 9, 11, 10, 18]}
Define Model & Inference
モデルを定義して、推論を行います。
model_code = """
data {
int<lower=0> J; // number of schools
real y[J]; // estimated treatment effects
real<lower=0> sigma[J]; // standard error of effect estimates
}
parameters {
real mu; // population treatment effect
real<lower=0> tau; // standard deviation in treatment effects
vector[J] eta; // unscaled deviation from mu by school
}
transformed parameters {
vector[J] theta = mu + tau * eta; // school treatment effects
}
model {
target += normal_lpdf(eta | 0, 1); // prior log-density
target += normal_lpdf(y | theta, sigma); // log-likelihood
}
"""
%%time
sm = pystan.StanModel(model_code=model_code)
INFO:pystan:COMPILING THE C++ CODE FOR MODEL anon_model_67cb7d0f2cb7720776cbeb52007d2dbb NOW.
CPU times: user 1.86 s, sys: 132 ms, total: 1.99 s
Wall time: 1min 10s
Google Colab だと、コードのコンパイルに1分くらいかかりました。手元のマシンで計算すると、もっと速くなる可能性はあるかもしれません。次に、Chain を 1本だけ計算させた場合と Chain を 4本計算させた場合の両方で時間を測ってみます(Chain を並列に計算させる方法がわからなかったため ^^;)
%%time
mcmc = sm.sampling(data=data, iter=100000, warmup=500, chains=1, control=dict(adapt_delta=0.99, max_treedepth=10))
WARNING:pystan:2 of 99500 iterations ended with a divergence (0.00201 %).
WARNING:pystan:Try running with adapt_delta larger than 0.99 to remove the divergences.
CPU times: user 7.44 s, sys: 130 ms, total: 7.57 s
Wall time: 7.54 s
%%time
mcmc = sm.sampling(data=data, iter=100000, warmup=500, chains=4, control=dict(adapt_delta=0.99, max_treedepth=10))
WARNING:pystan:10 of 398000 iterations ended with a divergence (0.00251 %).
WARNING:pystan:Try running with adapt_delta larger than 0.99 to remove the divergences.
CPU times: user 7.45 s, sys: 434 ms, total: 7.88 s
Wall time: 38.3 s
Chain 1本あたりで 8~10秒くらいの時間がかかっているようです。それにしても、Stan は WARNING が細かく、親切設計ですね ^^
idata = az.from_pystan(mcmc)
az.plot_trace(idata);
az.summary(idata)
mean | sd | hdi_3% | hdi_97% | mcse_mean | mcse_sd | ess_bulk | ess_tail | r_hat | |
---|---|---|---|---|---|---|---|---|---|
mu | 7.930 | 5.217 | -1.818 | 17.534 | 0.015 | 0.022 | 212292.0 | 176914.0 | 1.0 |
tau | 6.581 | 5.728 | 0.000 | 16.263 | 0.019 | 0.026 | 144668.0 | 174137.0 | 1.0 |
eta[0] | 0.390 | 0.938 | -1.369 | 2.172 | 0.002 | 0.001 | 361783.0 | 281559.0 | 1.0 |
eta[1] | 0.001 | 0.874 | -1.674 | 1.643 | 0.001 | 0.001 | 389953.0 | 288434.0 | 1.0 |
eta[2] | -0.194 | 0.927 | -1.922 | 1.573 | 0.001 | 0.002 | 396820.0 | 288937.0 | 1.0 |
eta[3] | -0.031 | 0.883 | -1.676 | 1.672 | 0.001 | 0.001 | 357902.0 | 281880.0 | 1.0 |
eta[4] | -0.351 | 0.878 | -1.995 | 1.342 | 0.002 | 0.001 | 332011.0 | 273770.0 | 1.0 |
eta[5] | -0.210 | 0.893 | -1.910 | 1.474 | 0.001 | 0.001 | 415465.0 | 286695.0 | 1.0 |
eta[6] | 0.342 | 0.886 | -1.358 | 2.009 | 0.001 | 0.001 | 354287.0 | 272854.0 | 1.0 |
eta[7] | 0.056 | 0.933 | -1.711 | 1.810 | 0.001 | 0.002 | 410448.0 | 290257.0 | 1.0 |
theta[0] | 11.384 | 8.345 | -2.935 | 28.463 | 0.016 | 0.011 | 305154.0 | 302281.0 | 1.0 |
theta[1] | 7.894 | 6.282 | -4.210 | 19.890 | 0.009 | 0.007 | 509134.0 | 350682.0 | 1.0 |
theta[2] | 6.134 | 7.729 | -9.339 | 20.493 | 0.013 | 0.010 | 389201.0 | 319471.0 | 1.0 |
theta[3] | 7.645 | 6.536 | -4.839 | 20.317 | 0.009 | 0.007 | 486618.0 | 348020.0 | 1.0 |
theta[4] | 5.120 | 6.355 | -7.239 | 16.810 | 0.010 | 0.008 | 409586.0 | 336711.0 | 1.0 |
theta[5] | 6.142 | 6.727 | -7.249 | 18.535 | 0.010 | 0.008 | 466074.0 | 348823.0 | 1.0 |
theta[6] | 10.652 | 6.785 | -1.628 | 24.027 | 0.011 | 0.008 | 368395.0 | 334347.0 | 1.0 |
theta[7] | 8.443 | 7.879 | -6.834 | 23.830 | 0.013 | 0.011 | 387466.0 | 300924.0 | 1.0 |
Summary
Google Colab 上でやっているためかもしれませんが、Stan のコードをコンパイルするのに1分くらいの時間がかかりました。サンプルはかなり多く発生させましたが、Chain 1本あたり 8~10秒くらいで走りました。