俺のOneNote

俺のOneNote

データ分析が仕事な人のOneNote愛とか、分析小話とか。

Numpyroでベイズ統計モデリング~ポアソン回帰モデル~

RとStanで始めるベイズ統計モデリングによるデータ分析入門のNumpyro実装第5回。

今回はポアソン回帰モデルです。 リンク関数に対数関数、分布にポアソン分布を利用します。
ポアソン分布は、ある期間に平均 $\lambda$ 回起こる事象が、$X$ 回起こる確率の分布です。

パラメータは $\lambda$ の一つのみで、正の値しかとらないことが特徴です。

準備

import jax.numpy as jnp
import numpy as np
import jax.random as random
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
sns.set(style='darkgrid',palette='bright')

!pip install numpyro
import numpyro
import numpyro.distributions as dist
from numpyro.infer import MCMC, NUTS
numpyro.set_host_device_count(4)

import arviz as az


fish_num_climate = pd.read_csv("3-8-1-fish-num-1.csv")

上記のように、気温 temperature と天気 weather により、魚の釣果がどのようになったかのデータセットになっています。

color_dict = {"cloudy":"b","sunny":"r"}

g = sns.relplot(data=fish_num_climate,x="temperature",y="fish_num",hue="weather",palette=color_dict)
g.fig.set_figheight(6)
g.fig.set_figwidth(10)

散布図をみると、晴れ sunny よりも曇りcloudy のほうが釣果が多いようです。
また、目的変数が正の値しかとらないことや、temperature が低い場合はデータの分散が少なく、単純な正規分布を仮定したモデルは適さないことがうかがえます。

モデル

いつもどおりダミー変数処理を噛ませて、デザイン行列によりモデルを記述していきます。

fish_num_climate_2 = pd.get_dummies(fish_num_climate).drop("weather_cloudy",axis=1)
fish_num_climate_2["Intercept"] = 1

モデルでは、リンク関数である対数関数の逆関数として exp で線形予測子を変換しています。

あとはこれまで正規分布から発生させていたものを、ポアソン分布 dist.Poisson に変換するだけです。

# モデル
def model(
    N,
    C,
    X,
    fish_num,
  
):
  beta = numpyro.sample("beta",dist.Normal(0,100),sample_shape=(C,))

  with numpyro.plate("N",N):
    lambda_ = jnp.exp(jnp.dot(X,beta))
    numpyro.sample("fish_num",dist.Poisson(lambda_),obs = fish_num)

# データ準備、推論
X = fish_num_climate_2[["Intercept","weather_sunny","temperature"]].values
data_dict = {
    "N":X.shape[0],
    "C":X.shape[1],
    "X":X,
    "fish_num":fish_num_climate_2["fish_num"].values
}

kernel = NUTS(model)
sample_kwargs = dict(
    sampler=kernel, 
    num_warmup=2000, 
    num_samples=2000, 
    num_chains=4, 
    chain_method="parallel"
)
mcmc = MCMC(**sample_kwargs)
mcmc.run(random.PRNGKey(0), **data_dict)

解釈

結果は無事、書籍と同様になっています。

az.summary(mcmc)

beta[1]sunny の場合の係数です。
リンク関数を使用したモデルであるため、解釈としては「晴れの場合、釣果はexp(-0.59)倍となる」という表現が正しいです。

事後予測分布

続いて、sunny , cloudy 別の事後予測分布を描写してみます。

# 事後予測分布の取得
mcmc_samples=mcmc.get_samples()
predictive = numpyro.infer.Predictive(model, mcmc_samples)

pred_dict = {
    "N":X.shape[0],
    "C":X.shape[1],
    "X":X,
    "fish_num":None
}

ppc_samples = predictive(random.PRNGKey(0),**pred_dict)
idata_ppc = az.from_numpyro(mcmc, posterior_predictive=ppc_samples)
fish_num_pred =  idata_ppc.posterior_predictive['fish_num']

# データの準備
cloudy_index = fish_num_climate_2[fish_num_climate_2["weather_sunny"]==0].index
sunny_index = fish_num_climate_2[fish_num_climate_2["weather_sunny"]==1].index

cloudy_temp = fish_num_climate_2.iloc[cloudy_index,:]["temperature"]
sunny_temp = fish_num_climate_2.iloc[sunny_index,:]["temperature"]

cloudy_fush_num_pred = ppc_samples["fish_num"][:,list(cloudy_index)]
sunny_fush_num_pred = ppc_samples["fish_num"][:,list(sunny_index)]

# 可視化
ax = az.plot_hdi(cloudy_temp, cloudy_fush_num_pred, hdi_prob=0.99, plot_kwargs={"ls": "--"},smooth=False,color="b",figsize=(10,6))
az.plot_hdi(sunny_temp, sunny_fush_num_pred, hdi_prob=0.99, plot_kwargs={"ls": "--"},smooth=False,color="r")
sns.scatterplot(data=fish_num_climate, x="temperature",y="fish_num",hue="weather",palette=color_dict)
plt.show()

これで、99%ベイズ予測区間の描写ができました。 temperatureが高くなるほど分散が大きくなっており、sunnyよりcloudyのほうが釣果予測が高いことが分かります。

まとめ

データの特徴とその生成過程を、リンク関数や確率分布によって柔軟に構築できることが一般化線形モデルの強みであり、
ベイジアンモデリングを利用することにより、確率的にパラメータや予測区間を推定できることが解釈の上で有用なケースが多いです。

次は、二項分布を利用したロジスティック回帰モデルがテーマです。