CmdStanPyの試運転

先日CmdStanPyをインストールしてみました。(詳細は下記参照)
publicjournal.hatenablog.com
RのStanやPyStanと結構違うので、どういうものかと一通り使ってみました。ひとまず“Hello, World”を参照にしながら動作確認をしました。

同じような例題が書籍「Pythonによるベイズ統計モデリング」のP48~57にあるので、同書と同じことをCmdStanPyを使ってやってみたいと思います。

例題

こんな例題を考えます。

コイン投げ問題を考える。コインを10回投げて裏と表を記録する。表を1、裏を0として記録したところ、[1,0,1,0,0,0,0,0,1,0]となった。このデータをもとに、コインが偏っているかを判断する。コインの偏りを調べるために、次のモデルを用いる。


\begin{aligned}
\theta  \sim Beta \big(1,1\big) \\
y  \sim Bin  \big(n=10, p=\theta\big) 
\end{aligned}

プログラムの全貌

プログラムは次の通り。

Stanのファイル

Stanのファイルは次の通り

解説

統計モデルの作成

上記のStanファイルにそって統計モデルを作成します。

モデルオブジェクトの生成

先ほどのStanファイルを読み込みます。
PyStanならばpystan.stan(model_code=XXX, data=XXXX, iter=1000, chains=4)などと書いていましたが、CmdStanPyならばStanファイルを指定して次の通りになります。

### Stanファイルを読み込んでオブジェクトを生成する
stan_file_path = 'sample_model.stan'
model = CmdStanModel(stan_file=stan_file_path)

model_codeでモデルをベタ書きできるのかは調査中です。

サンプリングの実行

sample()ハミルトニアンモンテカルロ法のノーUターンサンプラが走るそうです。

### MCMCを実行して事後分布をサンプリングする
iter_sampling = 1000 # サンプリングの数.デフォルト1000
chains = 4 # 並列サンプリングの数.デフォルト4

# データの作成
front_and_back_list = [1,0,1,0,0,0,0,0,1,0]
data = {
    "N" : len(front_and_back_list),
    "y" : front_and_back_list
    }
# 実行
fit_sm = model.sample(data=data,
                      iter_sampling=iter_sampling,
                      iter_warmup = iter_warmup, 
                      chains = chains,
                      seed=1234)

sample()の引数として、よく使うのは次のものかと思います。

  • data:データの辞書を渡す。これは他のStanと同じ。
  • iter_sampling :サンプル数を指定する。他のStanのiterと同じっぽい。
  • iter_warmup :バーンインの数を指定する。他のStanのwarmupと同じっぽい。
  • chains:並列して走るチェーンの数。これは他のStanと同じ。
  • seed:seed値を設定。これは他のStanと同じ。

下記を参考にしました。

またノーUターンサンプラは下記に詳細が載っています。

診断と結果の解釈

ここら辺はもっと便利なものがあるのかどうか分かりませんが、ひとまずPyMC3の収束診断っぽいものを作ってみました。

### 事後分布の診断
la = fit_sm.stan_variables()
plt.figure(figsize=(10,3))
bins=15
for i, k in enumerate(la.keys()):
    boxplot_list =[]
    for j in range(chains):
        x = la[k][j*iter_sampling:(j+1)*iter_sampling]
        hist, bin_edges = np.histogram(x, bins=bins) # 度数分布に変換
        #print(len(hist),len(bin_edges))
        plt.subplot(len(la.keys()), 3, 3*i+1) 
        plt.plot(bin_edges[:bins], hist, alpha=0.6, lw=0.6)
        plt.subplot(len(la.keys()), 3, 3*i+2) 
        plt.plot(x, alpha=0.5, lw=0.8)
        boxplot_list.append(x)
    plt.subplot(len(la.keys()), 3, 3*i+3) 
    plt.boxplot(boxplot_list)
plt.show()

### 結果を解釈する
samples = az.from_cmdstanpy(posterior = fit_sm)
print(az.summary(samples))

f:id:shu10038:20220227182108j:plain

また結果の要約についてもarvizというライブラリを使うと比較的簡単に出来るようです。

Data variables:
    theta    float64 1.002
        mean     sd  hdi_3%  hdi_97%  mcse_mean  mcse_sd  ess_bulk  ess_tail  r_hat
theta  0.335  0.133   0.095    0.577      0.003    0.002    1462.0    1934.0    1.0

arviz-devs.github.io