Numpyroでベイズ統計モデリング~単回帰モデル~
かれこれ数年付き合ってきたstan
ですが、
推定の遅さが仕事上かなりネックになっており、
メインで使うPPLを高速にMCMCを回せるNumpyro
に変えようと模索中。
以下、RとStanで始めるベイズ統計モデリングによるデータ分析入門のデータ&分析内容を題材に、Numpyro
で実装しながら練習する記録です。
※可視化用ライブラリの arviz
もちょっと使ってない間にだいぶ高機能になっており、その辺も試していく所存。
準備
環境はgoogle colabです。
import jax.numpy as jnp import numpy as np import jax.random as random import pandas as pd import matplotlib.pyplot as plt import seaborn as sns sns.set(style='darkgrid',palette='bright') #!pip install numpyro import numpyro import numpyro.distributions as dist from numpyro.infer import MCMC, NUTS numpyro.set_host_device_count(4) import arviz as az
データの準備・確認
データは以下サポートサイトより
file_bear_sales_2 = pd.read_csv("3-2-1-beer-sales-2.csv")
file_bear_sales_2.head()
元データはsales
とtemperature
の1変量による単回帰モデル用。
plt.figure(figsize=(10,5)) sns.scatterplot(x="temperature",y="sales",data=file_bear_sales_2) plt.show()
データ分布は以下のような形。
モデル
単回帰モデルなので非常に単純。
※はてぶのTexが全く表示されなかったり、よくわからないエラー起こるので作法が謎。
def model( sales, temperature ): Intercept = numpyro.sample("Intercept",dist.Normal(0,100)) beta = numpyro.sample("beta",dist.Normal(0,100)) sigma = numpyro.sample("sigma",dist.HalfNormal(100)) numpyro.sample("sales",dist.Normal(Intercept + beta * temperature, sigma),obs = sales)
Numpyro
のいいところはモデルの記述が直感的でわかりやすいところ。
後に取り組む階層モデルもわかりやすい。
一方で、時系列には jax.lax.scan
を使うのが肝要なんですが、一気に複雑になる。
ここもいずれ実装して理解を深めていきたい。
data_dict = { "temperature":file_bear_sales_2["temperature"].values, "sales":file_bear_sales_2["sales"].values } kernel = NUTS(model) sample_kwargs = dict( sampler=kernel, num_warmup=2000, num_samples=2000, num_chains=4, chain_method="parallel" ) mcmc = MCMC(**sample_kwargs) mcmc.run(random.PRNGKey(0), **data_dict)
MCMC回すのもそうなんですが、コンパイルいらないので、体感ものすごく早い。
4000サンプリングで1chain12秒。
並列処理、GPU利用も簡単なので、アクセラレータを最大限活かせる。
結果
mcmc.print_summary()
結果はほぼ本紙と同一になりました。
arviz
でトレースプロットを確認してみる。
idata = az.from_numpyro(mcmc) az.plot_trace(idata, figsize=(16,9)) plt.show()
r_hat
でも確認できるとおり、収束も問題なしです。
所感
Numpyro
は stan
と比べるとかなり高速にMCMCを回せるし、
比較的(TFPとかと比べれば)理解しやすい記述であることもポイントの一つ。
事前分布を必ず明示しなきゃいけないので、stan
のデフォルト無情報事前分布に慣れてるとちょっと面倒です(恥
時系列はいずれ取り組みたいが、ちょっとモデル記述がややこしくなるので、そこはやっぱり stan
のほうが分かりやすい(単に慣れてるだけ説が濃厚ですが。)
次は事後予測分布の可視化が焦点です。