Numpyroでベイズ統計モデリング~GLMM~
RとStanで始めるベイズ統計モデリングによるデータ分析入門のNumpyro実装第7回。
今回は一般可線形混合モデル(GLMM)にベイズ推定を利用した、いわゆる階層ベイズモデルです。
準備
import jax.numpy as jnp import numpy as np import jax.random as random import pandas as pd import matplotlib.pyplot as plt import seaborn as sns sns.set(style='darkgrid',palette='bright') !pip install numpyro import numpyro import numpyro.distributions as dist from numpyro.infer import MCMC, NUTS numpyro.set_host_device_count(4) import arviz as az
df = pd.read_csv("/4-1-1-fish-num-2.csv")
今回のデータは過去ポアソン回帰で利用したデータと似ていますが、各レコードに一意のid列があります。
それぞれのデータの背景として、計測されていないものが理由で目的変数が変化することを想定したモデルにする必要があるということです。
color_dict = {"cloudy":"b","sunny":"r"} g = sns.relplot(data=df,x="temperature",y="fish_num",hue="weather",palette=color_dict) g.fig.set_figheight(6) g.fig.set_figwidth(10)
単純に可視化しても、通常のポアソン分布よりも分散が大きそうであることがわかります。
確認
これを、過去と同様に単純なポアソン回帰にあてはめて推定してみます。
def model_1( N, C, X, fish_num, ): beta = numpyro.sample("beta",dist.Normal(0,100),sample_shape=(C,)) with numpyro.plate("N",N): lambda_ = jnp.exp(jnp.dot(X,beta)) numpyro.sample("fish_num",dist.Poisson(lambda_),obs = fish_num)
※モデル以外のサンプリング・可視化コードは過去同様なので省略
99%予測区間よりも外側に多くのデータが位置し、モデルの精度があまり高くないことがわかります。
モデル
ここで利用したモデルは以下のような以下のような単純なモデルとなっています。
※ここで、 $x_1$ は'weather'のダミー変数、 $x_2$ は'temperature' です。
このモデルを改良し、線形予測子にランダム効果を加えた以下のようなモデルにとします。
前述のモデルはすべての事象に共通した効果量であるため固定効果と呼ばれ、$r_i$ のような何らかの確率分布に従いランダムに変化する係数であるため、ランダム効果と呼ばれます。
このような固定効果とランダム効果を組み合わせたモデルを混合モデルと呼び、これまでの一般化線形モデルにランダム効果を加えたモデルを一般化線形混合モデル(GLMM)と呼びます。
これもNumpyroで以下のとおり実装します。
def model_2( N, C, X, fish_num, ): beta = numpyro.sample("beta",dist.Normal(0,100),sample_shape=(C,)) sigma_r = numpyro.sample("sigma_r",dist.HalfNormal(100)) r = numpyro.sample("r",dist.Normal(0,sigma_r),sample_shape=(N,)) with numpyro.plate("N",N): lambda_ = jnp.exp(jnp.dot(X,beta)+r) numpyro.sample("fish_num",dist.Poisson(lambda_),obs = fish_num)
X_1 = df_2[["Intercept","weather_sunny","temperature"]].values data_dict_1 = { "N":X_1.shape[0], "C":X_1.shape[1], "X":X_1, "fish_num":df["fish_num"].values } kernel_2 = NUTS(model_2) sample_kwargs_2 = dict( sampler=kernel_2, num_warmup=2000, num_samples=2000, num_chains=4, chain_method="parallel" ) mcmc_2 = MCMC(**sample_kwargs_2) mcmc_2.run(random.PRNGKey(0), **data_dict_1)
解釈
az.plot_trace(mcmc_2,compact=True,figsize=(14,8)) plt.show()
トレースプロットは以下のとおりできちんと収束していることがわかります。
'sigma_r' がランダム効果の分散で、それによって 'r' がサンプルごとに異なっていることがわかります。
az.summary(mcmc_2,var_names=["beta","sigma_r"])
各パラメータも書籍と同様になっており、無事サンプルごとのランダム効果を取り入れたGLMMが推定できました。