名前はまだない

データ分析とかの備忘録か, 趣味の話か, はたまた

BART:Bayesian Additive Regression Treesについて

はじめに

条件付き平均処置効果を算出する因果推論コンペ等で良い成績を収めている手法としてBayesian Additive Regression Trees(BART)が有名です。

BARTについて名前以上のことを理解していないので、調べて簡単にまとめました。

arxiv.org

Bayesian Additive Regression Trees

BARTは因果推論のための機械学習モデルではなく、予測モデルの一つです。

未知の関数 fを近似するための決定木ベースのアンサンブルモデルです。他のアンサンブル手法と同様に、各木は弱学習器として機能し、特徴量と結果の関係の一部だけを表現します。 決定木は非常に解釈しやすく柔軟なモデルですが、過学習しやすいという欠点があります。BARTでは過学習を避けるために正則化事前分布を使用し、各木が共変量と予測変数の間の関係の一部しか説明できないようにしています。

はじめに、BARTでは以下のような関係も学習することを考えていきます。

 \displaystyle{
Y = f(X) + \epsilon,     \epsilon \sim N(0, \sigma^2)
}

BARTでは、上記の関係を次のような形で推定します。

 \displaystyle{
Y = h(X) + \epsilon = \sum_{j=1}^m g_k(x) + \epsilon,     \epsilon \sim N(0, \sigma^2)
}

式からわかるように決定木をm個アンサンブルすることによって予測を実行するモデルである。 GBDTは残差を最小にするような決定木を段階的に学習していくが、BARTは段階的な学習は行わない。(段階的ではないが順に学習はしていく)

そして、残差に正規分布を仮定しており、ツリーの合計モデルとそのモデルのパラメータの事前正規化の 2 つの部分で構成されます。

決定木のアンサンブル

一つの二分木Tを以下のように表現する。

 \displaystyle{
g(x;T_j, M_j)
}

ここで M={\mu_1, \mu_2, \ldots, \mu_b}は決定技における各枝ごとの目的変数の値である。 なおTは決定木の構造であり、特定の特徴量で2分割するときの閾値を保持している。

事前分布の設定

事前分布を設定することで個々の決定木が過剰に影響しないようにする正則化の作用がある。 これにより加法表現による表現力の獲得と過学習を回避することができる。

事前の独立性と対称性

正規化事前分布の指定を簡素化するために、 g(x;T_j,M_j)は互いに独立しており、また \sigmaとも独立している。 このことは次のように表現できる。

 \displaystyle{
p((T_1 , M_1), \ldots , (T_m , M_m ), \sigma ) = \left [ \prod_j p(T_j , M_j ) \right ] p( \sigma ) 
}
 \displaystyle{
= \left [ \prod_j p(M_j \mid T_j) p(T_j) \right ] p(\sigma)
}

また次のように説明できる。

 \displaystyle{
p(M_j \mid T_j) = \prod_j p(\mu_{ij} \mid T_j)
}

この関係が成り立つ場合、 p(T_j), p(M_j \mid T_j) , p(\sigma)を事前分布として指定する必要がある。

T_jの事前分布

決定木T_jの構造の決定には三つの要素が影響しているとして事前分布を考える。

1. ノードの深さ

決定木のノードの深さd。以下の分布により設定。

 \displaystyle{
\frac{\alpha}{(1+d)^\beta},  \alpha \in (0, 1), \beta \in [0, \infty ) 
}

正則化により各ツリーの複雑性を低く保つために、α=0.95 および β=2 を使用しています。 この場合は1, 2, 3, 4, 5つ以上の終端ノードを持つツリーに対する事前確率はそれぞれ0.05、0.55、0.28、0.09、0.03となる。

2. ノードの分割対象とする変数の分布

どの変数で分岐を行うかを選択する。可能性のある変数から一様(等確率)で選択。

3. 分割する値の分布

分割を行う場合の閾値。可能性のある分割値の離散集合に対する一様事前分布を用いる。

(Extremely Randomized Treesに近い状態と言える?)

\mu_jの事前分布

決定木の各枝の値\mu_jの事前分布は以下のように提案されている。

 \displaystyle{
\mu_{ij} \sim \mathcal{N}(0, \sigma_{\mu}^2)
}

ここで \sigma_{ \mu} = \frac{ 0.5 }{ k \sqrt{m} } と設定する。

この事前分布により各決定木の出力が過剰な大きさにならず、正則化の役目を果たすとしている。

また、Chipmanはkは1~3の間にすると良い結果が得られるとしている。

\sigmaの事前分布

 \displaystyle{
\sigma^2 \sim \frac{\nu \lambda}{\chi_{\nu}^2}
}

基本的に、この目的のために自由度  \nuとスケール  \lambda の事前分布をデータに基づいた粗い過大評価  \hat \sigma 使用して調整します。  \hat \sigma の自然な選択肢は次の2つです:

  • ナイーブな仕様:y の標本標準偏差
  • 線形モデルの仕様:y の最小二乗線形回帰からの残差標準偏差

その後、適切な形状を得るために  \nu の値を3から10の間で選び、事前分布の \sigma に関する第 q 分位点が \hat \sigma に位置するように  \lambda の値を選びます。

Chipmanはデフォルト設定として( \nu, q) = (3, 0.90)を推奨している。 また  \nu \lt 3  \sigmaの過剰適合につながるため推奨されていない。

モデルの推定

MCMCのサンプリングによる分布推定を用いて、観測値から以下の事後分布を推定する。

 \displaystyle{
p((T_1,M_1),...,(T_m,M_m), \sigma | y)
}

ここで表記を簡単にするために、 T_{(j)}  T_j を除いたすべてのツリーの集合とし、同様にM_ { ( j ) } を定義します。したがって、T_{ ( j ) } は m−1 個のツリーの集合となり、M_{ ( j ) } はそれに対応する終端ノードのパラメータです。ここでのサンプリングでは (T_ { ( j ) }, M_{ ( j ) }, \sigma )を条件づけて ( T_j, M_j, \sigma )で連続して m 回サンプリングすることを意味します。

 \displaystyle{
(T_j, M_j )|T_{(j)},M_{(j)},\sigma,y,
}

また\sigmaに関しては、観測値とツリー情報に従ってサンプリングすることになる。

 \displaystyle{
\sigma | T_1, ...,T_m, M_1, ..., M_m,  y
}

\sigmaのサンプリングは逆ガンマ分布からのサンプルであり、一般的な方法で容易に得ることができます。 一方で ( T_j, M_j)のサンプリングは、通常の分布関数のサンプリングと異なり難しいため、以下の関係を考える。

 \displaystyle{
R_j = Y -  \sum_{m \neq j}^m g(x; T_k, M_k) 
}

これはj番目のツリーを除外したフィットに基づく部分残差のn次元ベクトルです。

この部分残差を用いるとサンプリングは、次にように修正することができます。

 \displaystyle{
(T_j, M_j )| R_j , \sigma
}

実際には T_j, M_jの各サンプリングを、次の2つの連続したステップで実行することになります。

 \displaystyle{
T_j| R_j , \sigma \
}
 \displaystyle{
M_j | T_j, R_j , \sigma
}

また T_jのサンプリングは、やや複雑ですがCGM98のMetropolis-Hastings (MH) アルゴリズムを使用して実行します。 このアルゴリズムでは現在のツリーに基づいて次の4つの操作のいずれかを使用して新しいツリーを提案します。操作とそれに対応する提案確率は以下の通りです

  • 終端ノードの成長(0.25)
  • 終端ノードのペアの剪定(0.25)
  • 非終端規則の変更(0.40)
  • 親子間での規則の交換(0.10)

因果効果の推定

条件付き処置効果の推定は、処置Tを行う場合と行わない場合の結果変数yの期待値の差をとることにより行う。

 \displaystyle{
\tau(x) = \text{E}(y ~ \vert~ T = 1, X = x) - \text{E}(y ~ \vert~ T = 0, X = x)
}

www.researchgate.net

またBARTの学習に用いるの共変量に傾向スコアを追加して因果効果を推定する手法も提案されている。

この方法で推定される条件付き平均処置効果の精度が改善するとされている。

arxiv.org

実行

RでBARTを実行するパッケージとして、BARTパッケージが代表的である模様。

BARTパッケージを用いたコードは以下のページを参考にした。

cran.r-project.org

はじめにサンプルデータの生成。

library(BART)
library(tidyverse)
library(tidybayes)
library(tidytreatment)

sim <- simulate_su_hill_data(n = 200,  
                             treatment_linear = FALSE,  
                             omega = 0, 
                             add_categorical = TRUE,
                             coef_categorical_treatment = c(0,0,1),
                             coef_categorical_nontreatment = c(-1,0,-1)
)

dat <- sim$data

生成プロセスを確認する。

> sim$formulas
$treatment_assignment
expression(0.4*x5 + 0.2*x6 + 0.4*x7 + 0.2*x8 + 0.4*x9 + 0.2*x10 + 0.8*I(x5^2) + 0.8*I(x6^2) + 0.5*I(x5 * x6) + 0.3*I(x5 * x6 * x7) + 0.8*I(x7^2) + 0.2*I(x7^3) + 0.4*I(x8^2) + 0.3*I(x7 * x8) + 0.8*I(x9^2) + 0.5*I(x9 * x10))

$response_treatment
expression(0.5*x5 + 2*x7 + 0.5*x9 + 2*x10 + 0.4*I(x5^2) + 0.8*I(x6^2) + 0.5*I(x7^2) + 0.5*I(x8^2) + 0.5*I(x9^2) + 0.7*I(x9 * x10) + 1*I(c1=='3'))

$response_nontreatment
expression(0.5*x5 + 2*x7 + 0.5*x9 + 2*x10 + 0.4*I(x5^2) + 0.8*I(x6^2) + 0.5*I(x7^2) + 0.5*I(x8^2) + 0.5*I(x9^2) + 0.7*I(x9 * x10) + -1*I(c1=='1') + -1*I(c1=='3'))

$generic
~x1 + x2 + I(x1^2) + I(x2^2) + I(x2 * x6) + x5 + x6 + x7 + x8 + 
    x9 + x10 + I(x5^2) + I(x6^2) + I(x5 * x6) + I(x5 * x6 * x7) + 
    I(x7^2) + I(x7^3) + I(x8^2) + I(x7 * x8) + I(x9^2) + I(x9 * 
    x10)
<environment: 0x7fad347e6870>

モデル推定

4段階で因果効果を推定する。

Hahn、Murray、Carvalho (2020)の手順に従った以下の方法で因果効果の推定を行って行きます。

  1. 変数選択モデルの作成:共変量(処置変数を除く)に対する結果を回帰する
  2. 変数選択モデルから結果変数に影響がある共変量のサブセットを選択
  3. 傾向スコアモデルの作成:ステップ2で選択された共変量のみを使用して傾向スコアを推定するプロビット/ロジットモデルを作成
  4. 処置効果モデルを適合する:ステップ3の元の共変量と傾向スコアを使用する
# STEP 1 VS Model: 共変量を用いた回帰モデル
var_select_bart <- wbart(x.train = select(dat,-y,-z), 
                         y.train = pull(dat, y),
                         sparse = TRUE, # 変数の選択
                         nskip = 1000, # burn inのステップ数
                         ndpost = 3000 # サンプリング回数
                         )
# STEP 2: 変数選択
var_select <- covar_ranking %>% 
  filter(avg_inclusion >= quantile(avg_inclusion, 0.5)) %>%
  pull(variable)
var_select <- unique(gsub("c1[1-3]$","c1", var_select))
> var_select
[1] "c1"  "x5"  "x6"  "x7"  "x8"  "x9"  "x10"
# STEP 3 PS Model: 傾向スコアモデルの作成
prop_bart <- pbart(
  x.train = select(dat, all_of(var_select)), 
  y.train = pull(dat, z), 
  nskip = 1000,
  ndpost = 3000
) 

dat$ps <-  prop_bart$prob.train.mean
# Step 4 TE Model: BART + Psocre
te_model <- wbart(
  x.train = select(dat,-y),
  y.train = pull(dat, y),
  nskip = 1000,
  ndpost = 2000, 
  keepevery = 100 
)

推定結果の処理

処置効果の推定結果を抽出して、ヒストグラムを表示する。 この時表示するのは、全ての観測サンプルに対して実施したサンプリングの値の全てが混ざって状態である。

posterior_treat_eff <- treatment_effects(te_model, treatment = "z", newdata = dat) 

posterior_treat_eff %>% 
  ggplot() +
  geom_histogram(aes(x = cte), binwidth = 0.1, colour = "white") + 
  ggtitle("処置効果のヒストグラム") + xlab("CATE")+
  theme_bw(base_family = "HiraKakuPro-W3")

次に各個体から取得した処置効果の2000件のサンプリングの中央値のヒストグラムを表示する。

posterior_treat_eff %>% 
  summarise(cte_hat = median(cte)) %>%
  ggplot() +
  geom_histogram(aes(x = cte_hat), binwidth = 0.1, colour = "white") + 
  ggtitle("処置効果のヒストグラム (各対象における中央値)") + xlab("CATE")+
  theme_bw(base_family = "HiraKakuPro-W3")

今回の場合は二峰性となっていた。

特定の共変量を条件付けた場合の因果効果の分布も確認することができる。

posterior_fitted <- fitted_draws(te_model, value = "fit", include_newdata = FALSE)

treatment_var_and_c1 <- 
  dat %>% 
  select(z,c1) %>%
  mutate(.row = 1:n(), z = as.factor(z))

posterior_fitted %>%
  left_join(treatment_var_and_c1, by = ".row") %>%
  ggplot() + 
  stat_halfeye(aes(x = z, y = fit)) + 
  facet_wrap(~c1, labeller = as_labeller( function(x) paste("c1 =",x) ) ) +
  xlab("処置") + ylab("CATE") +
  theme_bw() + ggtitle("C1と処置変数を条件付けた場合の結果変数の事後分布")

次に各個体の処置効果の信頼区間を表示する。

posterior_treat_eff %>% 
  select(-z) %>%
  point_interval() %>%
  arrange(cte) %>%
  mutate(.orow = 1:n()) %>% 
  ggplot() + 
  geom_interval(aes(x = .orow, y= cte, ymin = .lower, ymax = .upper)) +
  geom_point(aes(x = .orow, y = cte), shape = "circle open", alpha = 0.5) + 
  ggtitle("CATE分布(95%信頼区間)") +
  ylab("CATE")+
  theme_bw() + coord_flip() + scale_colour_brewer() +
  theme(axis.title.y = element_blank(), 
        axis.text.y = element_blank(), 
        axis.ticks.y = element_blank(),
        legend.position = "none")

処理効果推定における変数重要度の表示

特定の変数が処置変数と合わせてツリーに(平均的に)何回含まれたかを数えることで重要度としています。

treatment_interactions <-
  covariate_with_treatment_importance(te_model, treatment = "z")

treatment_interactions %>% 
  ggplot() + 
  geom_bar(aes(x = reorder(variable, avg_inclusion), y = avg_inclusion), stat = "identity") +
  ggtitle("処置Zとの相互作用が強い変数") +
  ylab("Inclusion counts") + xlab("") + 
  theme(axis.text.x = element_text(angle = 45, hjust=1)) +
  theme_bw(base_family = "HiraKakuPro-W3")

モデルの診断

推定したモデルにおける結果変量の残差を確認

res <- residual_draws(te_model, response = pull(dat, y), include_newdata = FALSE)

res %>%   
  point_interval(.residual, y, .width = c(0.95) ) %>%
  select(-y.lower, -y.upper) %>%
  ggplot() + 
  geom_pointinterval(aes(x = y, y = .residual, ymin = .residual.lower,  ymax = .residual.upper), alpha = 0.2) +
  scale_fill_brewer() +
  ggtitle("観測値と残差の傾向") + 
  theme_bw(base_family = "HiraKakuPro-W3")

観測値とモデル出力の相関を確認。

res %>% summarise(.fitted = mean(.fitted), y = first(y)) %>% 
  ggplot(aes(x = y, y = .fitted)) +
  geom_point() + 
  geom_smooth(method = "lm") + 
  ggtitle("観測値とモデルの出力") +
  theme_bw(base_family = "HiraKakuPro-W3")

Q-Qプロットの確認

res %>% summarise(.residual = mean(.residual)) %>%
  ggplot(aes(sample = .residual)) + 
  geom_qq() + 
  geom_qq_line() + 
  ggtitle("Q-Qプロット") +
  theme_bw(base_family = "HiraKakuPro-W3")

参考