Source code for dowhy.interpreters.propensity_balance_interpreter

import numpy as np
import pandas as pd

from dowhy.interpreters.visual_interpreter import VisualInterpreter
from dowhy.causal_estimators.propensity_score_estimator import PropensityScoreEstimator
from dowhy.causal_estimators.propensity_score_stratification_estimator import PropensityScoreStratificationEstimator
from dowhy.causal_estimator import CausalEstimate

class PropensityBalanceInterpreter(VisualInterpreter):
    """Visual interpreter that plots covariate balance (standardized mean
    differences) before and after propensity score stratification.

    Only estimates whose estimator is one of ``SUPPORTED_ESTIMATORS`` are
    accepted.
    """

    SUPPORTED_ESTIMATORS = [PropensityScoreStratificationEstimator, ]

    def __init__(self, estimate, **kwargs):
        """Validate that ``estimate`` can be interpreted by this class.

        :param estimate: a CausalEstimate whose ``estimator`` attribute is an
            instance of one of ``SUPPORTED_ESTIMATORS``.
        :raises ValueError: if ``estimate`` is not a CausalEstimate, or its
            estimator is not a supported class.
        """
        super().__init__(estimate, **kwargs)
        if not isinstance(estimate, CausalEstimate):
            error_msg = "The interpreter method expects a CausalEstimate object."
            self.logger.error(error_msg)
            raise ValueError(error_msg)
        self.estimator = self.estimate.estimator
        if not any(isinstance(self.estimator, est_class)
                   for est_class in PropensityBalanceInterpreter.SUPPORTED_ESTIMATORS):
            error_msg = "The interpreter method only supports propensity score stratification estimator."
            self.logger.error(error_msg)
            raise ValueError(error_msg)

    def interpret(self):
        """Balance plot that shows the change in standardized mean differences
        for each covariate after propensity score stratification.

        Draws, on a new matplotlib figure, one line per common cause going
        from its unadjusted standardized mean difference to its
        propensity-adjusted one.

        :returns: DataFrame with columns ``common_cause_id``,
            ``std_mean_diff`` and ``sample`` (``"Unadjusted"`` or
            ``"PropensityAdjusted"``) holding the plotted values.
        """
        treatment_names = self.estimator._treatment_name
        cols = (self.estimator._observed_common_causes_names
                + treatment_names
                + ["strata", "propensity_score"])
        df_long = self._to_long_format(self.estimator._data[cols])

        # First, size-weighted standardized mean differences within strata.
        mean_diff_strata = self._stratified_std_mean_diff(df_long, treatment_names)
        # Second, standardized mean differences ignoring strata.
        mean_diff_overall = self._unadjusted_std_mean_diff(df_long, treatment_names)

        # Third, concatenating them and plotting.
        mean_diff_strata["sample"] = "PropensityAdjusted"
        mean_diff_overall["sample"] = "Unadjusted"
        plot_df = pd.concat([mean_diff_overall, mean_diff_strata])
        self._plot_balance(plot_df)
        return plot_df

    @staticmethod
    def _to_long_format(df):
        """Reshape common-cause columns with stub ``W`` (``W0``, ``W1``, ...)
        into one row per (original row, common cause), ``W`` as float64."""
        # NOTE(review): only columns named W<digits> are reshaped; common
        # causes with other names are not handled here.
        # A fresh RangeIndex guarantees a uniquely-valued column named
        # "index" for wide_to_long, regardless of the data's own index name
        # or duplicates. That column is not used by any later computation.
        indexed = df.reset_index(drop=True).reset_index()
        return (pd.wide_to_long(indexed, stubnames=["W"], i="index",
                                j="common_cause_id")
                .reset_index()
                .astype({"W": "float64"}))

    @staticmethod
    def _stratified_std_mean_diff(df_long, treatment_names):
        """Per common cause, sum over strata of
        (mean difference / stratum std) * (stratum size / total size)."""
        # Only the first treatment column selects rows; assumes a single
        # binary (True/False) treatment.
        treatment_col = treatment_names[0]
        # String aliases (not numpy callables) avoid pandas' FutureWarning
        # and pin std to ddof=1, the current result.
        group_means = (df_long.groupby(treatment_names + ["common_cause_id", "strata"])
                       .agg(mean_w=("W", "mean")))
        # Range of the group means within each (common cause, stratum),
        # broadcast back to every row; for a binary treatment this is the
        # absolute treated-vs-control difference.
        mean_diff = (group_means.groupby(["common_cause_id", "strata"])
                     .transform(lambda x: x.max() - x.min())
                     .reset_index())
        # Keep one row per group (the treated one), using the real treatment
        # column name rather than a hard-coded "v0".
        mean_diff = mean_diff[mean_diff[treatment_col].eq(True)]

        size_by_w_strata = (df_long.groupby(["common_cause_id", "strata"])
                            .agg(size=("propensity_score", "size"))
                            .reset_index())
        size_by_w = (df_long.groupby(["common_cause_id"])
                     .agg(size=("propensity_score", "size"))
                     .reset_index())
        # After this merge: size_x = rows in the stratum,
        # size_y = rows for that common cause overall.
        sizes = pd.merge(size_by_w_strata, size_by_w, on="common_cause_id")
        mean_diff_strata = pd.merge(mean_diff, sizes, on=["common_cause_id", "strata"])

        stddev_by_w_strata = (df_long.groupby(["common_cause_id", "strata"])
                              .agg(stddev=("W", "std"))
                              .reset_index())
        mean_diff_strata = pd.merge(mean_diff_strata, stddev_by_w_strata,
                                    on=["common_cause_id", "strata"])
        mean_diff_strata["scaled_mean"] = (
            (mean_diff_strata["mean_w"] / mean_diff_strata["stddev"])
            * (mean_diff_strata["size_x"] / mean_diff_strata["size_y"]))
        return (mean_diff_strata.groupby("common_cause_id")
                .agg(std_mean_diff=("scaled_mean", "sum"))
                .reset_index())

    @staticmethod
    def _unadjusted_std_mean_diff(df_long, treatment_names):
        """Per common cause, mean difference between treatment groups divided
        by the overall std of that common cause (no stratification)."""
        # Only the first treatment column selects rows; assumes a single
        # binary (True/False) treatment.
        treatment_col = treatment_names[0]
        group_means = (df_long.groupby(treatment_names + ["common_cause_id"])
                       .agg(mean_w=("W", "mean")))
        mean_diff = (group_means.groupby("common_cause_id")
                     .transform(lambda x: x.max() - x.min())
                     .reset_index())
        mean_diff = mean_diff[mean_diff[treatment_col].eq(True)]
        stddev_overall = (df_long.groupby(["common_cause_id"])
                          .agg(stddev=("W", "std"))
                          .reset_index())
        mean_diff = pd.merge(mean_diff, stddev_overall, on=["common_cause_id"])
        mean_diff["std_mean_diff"] = mean_diff["mean_w"] / mean_diff["stddev"]
        # Explicit copy so the caller can add columns without SettingWithCopy.
        return mean_diff[["common_cause_id", "std_mean_diff"]].copy()

    @staticmethod
    def _plot_balance(plot_df):
        """Draw one line per common cause across the samples in ``plot_df``."""
        import matplotlib.pyplot as plt
        # "seaborn-white" was renamed "seaborn-v0_8-white" in matplotlib 3.6,
        # and the old name was removed in 3.8; use whichever exists.
        style = ("seaborn-v0_8-white" if "seaborn-v0_8-white" in plt.style.available
                 else "seaborn-white")
        plt.style.use(style)
        _, ax = plt.subplots(1, 1)
        for label, subdf in plot_df.groupby("common_cause_id"):
            subdf.plot(kind="line", x="sample", y="std_mean_diff", ax=ax, label=label)
        plt.legend(title="Common causes")
        plt.ylabel("Standardized mean difference between treatment and control")
        plt.xlabel("")
        plt.xticks(rotation=45)