俺のOneNote

俺のOneNote

データ分析が仕事な人のOneNote愛とか、分析小話とか。

Numpyroでベイズ統計モデリング~事後予測分布~

今回はRとStanで始めるベイズ統計モデリングによるデータ分析入門の実装勉強第2回。

第3部第3章の「モデルを用いた予測」+αの実装です。

データ、モデル推定

ここは前回と同様のため主要コードのみ掲載。

kopaprin.hatenadiary.jp

#ライブラリ
import jax.numpy as jnp
import numpy as np
import jax.random as random
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
sns.set(style='darkgrid',palette='bright')

!pip install numpyro
import numpyro
import numpyro.distributions as dist
from numpyro.infer import MCMC, NUTS
numpyro.set_host_device_count(4)

#データインポート
file_bear_sales_2 = pd.read_csv("3-2-1-beer-sales-2.csv")

#モデル定義
def model(
    N,
    sales,
    temperature
):
  Intercept = numpyro.sample("Intercept",dist.Normal(0,100))
  beta = numpyro.sample("beta",dist.Normal(0,100))
  sigma = numpyro.sample("sigma",dist.HalfNormal(100))

  with numpyro.plate("N",N):
    numpyro.sample("sales",dist.Normal(Intercept + beta * temperature, sigma),obs = sales)

#MCMCによる事後分布サンプリング
data_dict = {
    "N":len(file_bear_sales_2),
    "temperature":file_bear_sales_2["temperature"].values,
    "sales":file_bear_sales_2["sales"].values
}
kernel = NUTS(model)
sample_kwargs = dict(
    sampler=kernel, 
    num_warmup=2000, 
    num_samples=2000, 
    num_chains=4, 
    chain_method="parallel"
)
mcmc = MCMC(**sample_kwargs)
mcmc.run(random.PRNGKey(0), **data_dict)

事後予測分布の生成

事後分布のサンプリングからパラメータのMCMCサンプルを得るところから。
numpyro.infer.Predictive インスタンスの利用と、予測したい説明変数を生成し、作成したモデルで事後予測分布を取得する準備をします。
今回は本紙に則り、気温が11度~30度までの区間の事後予測分布を得ることを目的とします。

mcmc_samples=mcmc.get_samples()
predictive = numpyro.infer.Predictive(model, mcmc_samples)
temperature_pred = jnp.arange(11,31)

あとはpredictiveインスタンスの引数に乱数と予測したいモデルの説明変数を与えるだけです。
目的変数となる観測データ(observations)はNoneを指定し、モデルから出力することを明示します。

ppc_samples = predictive(random.PRNGKey(0),N = len(temperature_pred), temperature = temperature_pred, sales=None)

これで事後予測分布が取得できました。
arviz による可視化のため、 InferenceDataオブジェクトに変換します。

idata_ppc = az.from_numpyro(mcmc, posterior_predictive=ppc_samples)

可視化による事後予測分布のチェック

あとは arvizのプロットに必要な InferenceDataオブジェクトを渡すだけです。 軸ラベルなどが適当ですが、arvizのラベルガイドに基づけば柔軟に対応できそうです。 ※読み込み大変なので、ここではデフォルトでご容赦ください。

arviz-devs.github.io

まずは95%ベイズ予測区間の可視化です。

az.plot_forest(idata_ppc.posterior_predictive["sales"],
                  var_names=["sales"],
                  hdi_prob=0.95,
                  combined=True,
                  colors='b');

軸ラベルが分かりづらいですが、本紙と同じく、気温11度~30度までの各予測分布を可視化できました。

特定の事後予測分布を並列で可視化することも可能です。
ここでは、本紙と同じく気温11度と気温30度の事後予測分布を可視化します。

az.plot_forest(idata_ppc.posterior_predictive["sales"][:,:,[0,19]],
                  kind="ridgeplot",
                  var_names=["sales"],
                  hdi_prob=0.95,
                  ridgeplot_overlap=0.9,
                  ridgeplot_truncate = False,
                  ridgeplot_quantiles=[.5],
                  combined=True,
                  ridgeplot_alpha=0.5,
                  figsize=(8,5),
                  colors='b');

本紙のRにおける bayesplotとは微妙に見た目違いますが、同じものが出力できました。
arvizベイズモデリングの結果解釈でかなり有用になりそうです。

おまけ

本紙第3部第3章にはありませんが、回帰直線による予測区間の可視化も可能です。

sales_pred =  idata_ppc.posterior_predictive['sales']
az.plot_hdi(temperature_pred, sales_pred,fill_kwargs={'alpha': 0.3})
plt.plot(temperature_pred, sales_pred.mean(axis=0).mean(axis=0),color="orange")

sns.scatterplot(x=file_bear_sales_2["temperature"], y=file_bear_sales_2["sales"],s=50,color="gray")

おわりに

stanとは若干流儀が違いますが、かなり簡易にモデルを利用した予測分布を得ることができました。
予測分布取得用のモデルを追記する必要がない分、慣れればこっちのほうが使いやすそうな気もします。

あとは arviz が優秀。もうちょい使いこなせるようにしたいです。

次は第3部第6章、ダミー変数と分散分析モデルがテーマです。