Conditional Average Treatment Effects (CATE) with DoWhy and EconML#
This is an experimental feature where we use EconML methods from DoWhy. Using EconML allows CATE estimation with a wide choice of methods.
All four steps of causal inference in DoWhy remain the same: model, identify, estimate, and refute. The key difference is that we now call EconML methods in the estimation step. We also include a simpler example using linear regression to build intuition for CATE estimators.
All datasets are generated using linear structural equations.
[1]:
%load_ext autoreload
%autoreload 2
[2]:
import numpy as np
import pandas as pd
import logging
import dowhy
from dowhy import CausalModel
import dowhy.datasets
import econml
import warnings
warnings.filterwarnings('ignore')
BETA = 10
[3]:
data = dowhy.datasets.linear_dataset(BETA, num_common_causes=4, num_samples=10000,
                                     num_instruments=2, num_effect_modifiers=2,
                                     num_treatments=1,
                                     treatment_is_binary=False,
                                     num_discrete_common_causes=2,
                                     num_discrete_effect_modifiers=0,
                                     one_hot_encode=False)
df = data['df']
print(df.head())
print("True causal estimate is", data["ate"])
X0 X1 Z0 Z1 W0 W1 W2 W3 v0 \
0 -0.309229 -1.580070 0.0 0.351641 -0.244653 0.229809 1 2 14.187684
1 1.182555 -0.759687 1.0 0.633429 0.308832 -0.411033 2 1 30.081103
2 -0.120545 0.997922 1.0 0.584607 -0.583494 1.444148 1 0 21.757118
3 -0.189336 1.553915 0.0 0.616707 0.830747 -0.274412 0 0 6.230958
4 -1.155143 1.908829 0.0 0.749289 0.575723 0.537886 3 2 29.004429
y
0 94.431111
1 462.026719
2 239.145803
3 70.781719
4 199.744559
True causal estimate is 9.759698502204927
[4]:
model = CausalModel(data=data["df"],
                    treatment=data["treatment_name"], outcome=data["outcome_name"],
                    graph=data["gml_graph"])
[5]:
model.view_model()
from IPython.display import Image, display
display(Image(filename="causal_model.png"))
[6]:
identified_estimand = model.identify_effect(proceed_when_unidentifiable=True)
print(identified_estimand)
Estimand type: EstimandType.NONPARAMETRIC_ATE
### Estimand : 1
Estimand name: backdoor
Estimand expression:
d
─────(E[y|W2,W0,W3,W1])
d[v₀]
Estimand assumption 1, Unconfoundedness: If U→{v0} and U→y then P(y|v0,W2,W0,W3,W1,U) = P(y|v0,W2,W0,W3,W1)
### Estimand : 2
Estimand name: iv
Estimand expression:
⎡ -1⎤
⎢ d ⎛ d ⎞ ⎥
E⎢─────────(y)⋅⎜─────────([v₀])⎟ ⎥
⎣d[Z₀ Z₁] ⎝d[Z₀ Z₁] ⎠ ⎦
Estimand assumption 1, As-if-random: If U→→y then ¬(U →→{Z0,Z1})
Estimand assumption 2, Exclusion: If we remove {Z0,Z1}→{v0}, then ¬({Z0,Z1}→y)
### Estimand : 3
Estimand name: frontdoor
No such variable(s) found!
Linear Model#
First, let us build some intuition using a linear model for estimating CATE. The effect modifiers (which lead to a heterogeneous treatment effect) can be modeled as interaction terms with the treatment; their values thus modulate the effect of the treatment.
Below is the estimated effect of changing the treatment from 0 to 1.
[7]:
linear_estimate = model.estimate_effect(identified_estimand,
                                        method_name="backdoor.linear_regression",
                                        control_value=0,
                                        treatment_value=1)
print(linear_estimate)
*** Causal Estimate ***
## Identified estimand
Estimand type: EstimandType.NONPARAMETRIC_ATE
### Estimand : 1
Estimand name: backdoor
Estimand expression:
d
─────(E[y|W2,W0,W3,W1])
d[v₀]
Estimand assumption 1, Unconfoundedness: If U→{v0} and U→y then P(y|v0,W2,W0,W3,W1,U) = P(y|v0,W2,W0,W3,W1)
## Realized estimand
b: y~v0+W2+W0+W3+W1+v0*X1+v0*X0
Target units:
## Estimate
Mean value: 9.759655771129744
### Conditional Estimates
__categorical__X1 __categorical__X0
(-2.8049999999999997, 0.1] (-4.375, -1.138] 0.398088
(-1.138, -0.554] 4.873899
(-0.554, -0.0303] 7.800511
(-0.0303, 0.533] 10.572449
(0.533, 3.292] 15.339292
(0.1, 0.696] (-4.375, -1.138] 1.726966
(-1.138, -0.554] 6.149392
(-0.554, -0.0303] 9.023480
(-0.0303, 0.533] 11.799896
(0.533, 3.292] 16.378770
(0.696, 1.21] (-4.375, -1.138] 2.299447
(-1.138, -0.554] 6.880426
(-0.554, -0.0303] 9.728293
(-0.0303, 0.533] 12.603546
(0.533, 3.292] 17.300397
(1.21, 1.806] (-4.375, -1.138] 3.033189
(-1.138, -0.554] 7.586417
(-0.554, -0.0303] 10.577888
(-0.0303, 0.533] 13.411810
(0.533, 3.292] 18.112063
(1.806, 4.792] (-4.375, -1.138] 4.246136
(-1.138, -0.554] 8.803440
(-0.554, -0.0303] 11.766477
(-0.0303, 0.533] 14.518407
(0.533, 3.292] 19.081754
dtype: float64
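To make the interaction-term intuition concrete, here is a minimal self-contained sketch on hypothetical simulated data (separate from the dataset above): the coefficient on the treatment plus the coefficient on the treatment-modifier interaction, scaled by the modifier's value, gives CATE(x).

```python
import numpy as np
from sklearn.linear_model import LinearRegression

# Hypothetical toy data: the effect modifier x enters through an
# interaction with the treatment t, so the true CATE(x) = 10 + 3x.
rng = np.random.default_rng(0)
n = 5000
w = rng.normal(size=n)                    # common cause
x = rng.normal(size=n)                    # effect modifier
t = 0.5 * w + rng.normal(size=n)          # continuous treatment
y = (10 + 3 * x) * t + 2 * w + rng.normal(size=n)

# Regress y on treatment, confounder, modifier, and the interaction term.
reg = LinearRegression().fit(np.column_stack([t, w, x, t * x]), y)

def cate(x0):
    # d y / d t at a given modifier value: coef on t plus coef on t*x times x0
    return reg.coef_[0] + reg.coef_[3] * x0

print(round(cate(0.0), 1), round(cate(1.0), 1))  # close to 10 and 13
```

This mirrors the conditional estimates table above: units with larger modifier values receive a larger estimated effect.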
EconML methods#
We now move to the more advanced methods from the EconML package for estimating CATE.
First, let us look at the double machine learning estimator. Method_name corresponds to the fully qualified name of the class that we want to use. For double ML, it is “econml.dml.DML”.
Target units defines the units over which the causal estimate is to be computed. This can be a lambda function filter on the original dataframe, a new Pandas dataframe, or a string corresponding to the three main kinds of target units (“ate”, “att” and “atc”). Below we show an example of a lambda function.
Method_params are passed directly to EconML. For details on allowed parameters, refer to the EconML documentation.
[8]:
from sklearn.preprocessing import PolynomialFeatures
from sklearn.linear_model import LassoCV
from sklearn.ensemble import GradientBoostingRegressor
dml_estimate = model.estimate_effect(identified_estimand, method_name="backdoor.econml.dml.DML",
                                     control_value=0,
                                     treatment_value=1,
                                     target_units=lambda df: df["X0"] > 1,  # condition used for CATE
                                     confidence_intervals=False,
                                     method_params={"init_params": {'model_y': GradientBoostingRegressor(),
                                                                    'model_t': GradientBoostingRegressor(),
                                                                    "model_final": LassoCV(fit_intercept=False),
                                                                    'featurizer': PolynomialFeatures(degree=1, include_bias=False)},
                                                    "fit_params": {}})
print(dml_estimate)
*** Causal Estimate ***
## Identified estimand
Estimand type: EstimandType.NONPARAMETRIC_ATE
### Estimand : 1
Estimand name: backdoor
Estimand expression:
d
─────(E[y|W2,W0,W3,W1])
d[v₀]
Estimand assumption 1, Unconfoundedness: If U→{v0} and U→y then P(y|v0,W2,W0,W3,W1,U) = P(y|v0,W2,W0,W3,W1)
## Realized estimand
b: y~v0+W2+W0+W3+W1 | X1,X0
Target units: Data subset defined by a function
## Estimate
Mean value: 19.129843183233174
Effect estimates: [[15.20183627]
[19.55021696]
[22.40585988]
 ...
[23.15317123]
[18.60192614]
[17.91032576]]
[9]:
print("True causal estimate is", data["ate"])
True causal estimate is 9.759698502204927
[10]:
dml_estimate = model.estimate_effect(identified_estimand, method_name="backdoor.econml.dml.DML",
                                     control_value=0,
                                     treatment_value=1,
                                     target_units=1,  # condition used for CATE
                                     confidence_intervals=False,
                                     method_params={"init_params": {'model_y': GradientBoostingRegressor(),
                                                                    'model_t': GradientBoostingRegressor(),
                                                                    "model_final": LassoCV(fit_intercept=False),
                                                                    'featurizer': PolynomialFeatures(degree=1, include_bias=True)},
                                                    "fit_params": {}})
print(dml_estimate)
*** Causal Estimate ***
## Identified estimand
Estimand type: EstimandType.NONPARAMETRIC_ATE
### Estimand : 1
Estimand name: backdoor
Estimand expression:
d
─────(E[y|W2,W0,W3,W1])
d[v₀]
Estimand assumption 1, Unconfoundedness: If U→{v0} and U→y then P(y|v0,W2,W0,W3,W1,U) = P(y|v0,W2,W0,W3,W1)
## Realized estimand
b: y~v0+W2+W0+W3+W1 | X1,X0
Target units:
## Estimate
Mean value: 9.701955486991544
Effect estimates: [[ 6.21092926]
[15.35562554]
[10.69245368]
...
[10.57874362]
[11.78838451]
[14.59987791]]
CATE Object and Confidence Intervals#
EconML provides its own methods to compute confidence intervals; the example below uses BootstrapInference.
[11]:
from sklearn.preprocessing import PolynomialFeatures
from sklearn.linear_model import LassoCV
from sklearn.ensemble import GradientBoostingRegressor
from econml.inference import BootstrapInference
dml_estimate = model.estimate_effect(identified_estimand,
                                     method_name="backdoor.econml.dml.DML",
                                     target_units="ate",
                                     confidence_intervals=True,
                                     method_params={"init_params": {'model_y': GradientBoostingRegressor(),
                                                                    'model_t': GradientBoostingRegressor(),
                                                                    "model_final": LassoCV(fit_intercept=False),
                                                                    'featurizer': PolynomialFeatures(degree=1, include_bias=True)},
                                                    "fit_params": {
                                                        'inference': BootstrapInference(n_bootstrap_samples=100, n_jobs=-1),
                                                    }})
print(dml_estimate)
*** Causal Estimate ***
## Identified estimand
Estimand type: EstimandType.NONPARAMETRIC_ATE
### Estimand : 1
Estimand name: backdoor
Estimand expression:
d
─────(E[y|W2,W0,W3,W1])
d[v₀]
Estimand assumption 1, Unconfoundedness: If U→{v0} and U→y then P(y|v0,W2,W0,W3,W1,U) = P(y|v0,W2,W0,W3,W1)
## Realized estimand
b: y~v0+W2+W0+W3+W1 | X1,X0
Target units: ate
## Estimate
Mean value: 9.707773674361558
Effect estimates: [[ 6.03511542]
[15.19392943]
[10.6961485 ]
...
[10.56333225]
[11.81163023]
[14.61231231]]
95.0% confidence interval: [[[ 5.89738649 15.18635435 10.72211926 ... 10.59480476 11.85233371
14.65175606]]
[[ 6.25148795 15.65375463 10.86867734 ... 10.73755151 12.02514294
14.93143688]]]
Can provide new inputs as target units and estimate CATE on them#
[12]:
test_cols = data['effect_modifier_names']  # only need effect modifiers' values
test_arr = [np.random.uniform(0, 1, 10) for _ in range(len(test_cols))]  # all variables are sampled uniformly, sample of 10
test_df = pd.DataFrame(np.array(test_arr).transpose(), columns=test_cols)
dml_estimate = model.estimate_effect(identified_estimand,
                                     method_name="backdoor.econml.dml.DML",
                                     target_units=test_df,
                                     confidence_intervals=False,
                                     method_params={"init_params": {'model_y': GradientBoostingRegressor(),
                                                                    'model_t': GradientBoostingRegressor(),
                                                                    "model_final": LassoCV(),
                                                                    'featurizer': PolynomialFeatures(degree=1, include_bias=True)},
                                                    "fit_params": {}})
print(dml_estimate.cate_estimates)
[[13.12252028]
[12.91695872]
[15.52526321]
[11.12202724]
[13.39670634]
[12.59673232]
[11.62382935]
[12.57653561]
[13.63902793]
[15.15168255]]
Can also retrieve the raw EconML estimator object for any further operations#
[13]:
print(dml_estimate._estimator_object)
<econml.dml.dml.DML object at 0x7f51bb689a60>
Works with any EconML method#
In addition to double machine learning, below we show example analyses using the DRLearner, an instrumental variable method (DMLIV), and metalearners.
Binary treatment, Binary outcome#
[14]:
data_binary = dowhy.datasets.linear_dataset(BETA, num_common_causes=4, num_samples=10000,
                                            num_instruments=1, num_effect_modifiers=2,
                                            treatment_is_binary=True, outcome_is_binary=True)
# convert boolean values to {0,1} numeric
data_binary['df'].v0 = data_binary['df'].v0.astype(int)
data_binary['df'].y = data_binary['df'].y.astype(int)
print(data_binary['df'])
model_binary = CausalModel(data=data_binary["df"],
                           treatment=data_binary["treatment_name"], outcome=data_binary["outcome_name"],
                           graph=data_binary["gml_graph"])
identified_estimand_binary = model_binary.identify_effect(proceed_when_unidentifiable=True)
X0 X1 Z0 W0 W1 W2 W3 v0 y
0 1.908358 1.837408 0.0 1.551148 0.718173 2.144841 0.105256 1 1
1 -1.518746 -0.530164 0.0 -0.278654 1.209751 2.339488 0.249613 1 1
2 0.757815 -0.144733 1.0 1.968192 -0.272409 0.308428 -1.058811 1 1
3 1.298266 -0.259942 0.0 -0.329655 1.133614 -0.734842 -1.756041 0 0
4 -0.382853 1.908866 0.0 1.452588 -1.018550 0.878970 0.079483 1 1
... ... ... ... ... ... ... ... .. ..
9995 0.145846 0.302591 0.0 1.592968 2.223186 0.100735 -1.631679 0 1
9996 0.412768 1.339241 1.0 1.812368 0.451433 1.035600 -0.365655 1 1
9997 -1.588935 0.483044 1.0 1.543930 -0.349024 0.671658 -1.133542 1 1
9998 -1.103233 0.345121 0.0 0.191568 -1.711238 1.413318 -1.634635 0 1
9999 0.808179 0.645542 0.0 1.144551 0.110358 -0.023063 0.632114 1 1
[10000 rows x 9 columns]
Using DRLearner estimator#
[15]:
from sklearn.linear_model import LogisticRegressionCV
# TODO: needs binary y
drlearner_estimate = model_binary.estimate_effect(identified_estimand_binary,
                                                  method_name="backdoor.econml.dr.LinearDRLearner",
                                                  confidence_intervals=False,
                                                  method_params={"init_params": {'model_propensity': LogisticRegressionCV(cv=3, solver='lbfgs', multi_class='auto')},
                                                                 "fit_params": {}})
print(drlearner_estimate)
print("True causal estimate is", data_binary["ate"])
*** Causal Estimate ***
## Identified estimand
Estimand type: EstimandType.NONPARAMETRIC_ATE
### Estimand : 1
Estimand name: backdoor
Estimand expression:
d
─────(E[y|W2,W0,W3,W1])
d[v₀]
Estimand assumption 1, Unconfoundedness: If U→{v0} and U→y then P(y|v0,W2,W0,W3,W1,U) = P(y|v0,W2,W0,W3,W1)
## Realized estimand
b: y~v0+W2+W0+W3+W1 | X1,X0
Target units: ate
## Estimate
Mean value: 0.17871127524721636
Effect estimates: [[0.27289129]
[0.09558651]
[0.19644977]
...
[0.10708254]
[0.12546481]
[0.20981968]]
True causal estimate is 0.1648
Instrumental Variable Method#
[16]:
dmliv_estimate = model.estimate_effect(identified_estimand,
                                       method_name="iv.econml.iv.dml.DMLIV",
                                       target_units=lambda df: df["X0"] > -1,
                                       confidence_intervals=False,
                                       method_params={"init_params": {'discrete_treatment': False,
                                                                      'discrete_instrument': False},
                                                      "fit_params": {}})
print(dmliv_estimate)
*** Causal Estimate ***
## Identified estimand
Estimand type: EstimandType.NONPARAMETRIC_ATE
### Estimand : 1
Estimand name: iv
Estimand expression:
⎡ -1⎤
⎢ d ⎛ d ⎞ ⎥
E⎢─────────(y)⋅⎜─────────([v₀])⎟ ⎥
⎣d[Z₀ Z₁] ⎝d[Z₀ Z₁] ⎠ ⎦
Estimand assumption 1, As-if-random: If U→→y then ¬(U →→{Z0,Z1})
Estimand assumption 2, Exclusion: If we remove {Z0,Z1}→{v0}, then ¬({Z0,Z1}→y)
## Realized estimand
b: y~v0+W2+W0+W3+W1 | X1,X0
Target units: Data subset defined by a function
## Estimate
Mean value: 12.03618464825759
Effect estimates: [[ 6.06278199]
[15.24163262]
[10.82819085]
...
[10.68475397]
[11.9558582 ]
[14.75439927]]
Metalearners#
[17]:
data_experiment = dowhy.datasets.linear_dataset(BETA, num_common_causes=5, num_samples=10000,
                                                num_instruments=2, num_effect_modifiers=5,
                                                treatment_is_binary=True, outcome_is_binary=False)
# convert boolean values to {0,1} numeric
data_experiment['df'].v0 = data_experiment['df'].v0.astype(int)
print(data_experiment['df'])
model_experiment = CausalModel(data=data_experiment["df"],
                               treatment=data_experiment["treatment_name"], outcome=data_experiment["outcome_name"],
                               graph=data_experiment["gml_graph"])
identified_estimand_experiment = model_experiment.identify_effect(proceed_when_unidentifiable=True)
X0 X1 X2 X3 X4 Z0 Z1 \
0 -1.803338 -0.222929 0.850613 0.052941 0.032794 1.0 0.835204
1 0.224042 0.460976 0.094785 -0.547756 -0.462117 1.0 0.428228
2 -0.861512 -0.699643 1.217972 -0.508407 -0.164014 1.0 0.403523
3 -0.470370 2.239789 1.200768 -0.279083 -0.899226 0.0 0.621538
4 -1.510403 1.721982 -0.949821 -0.445389 -2.325149 1.0 0.663460
... ... ... ... ... ... ... ...
9995 -1.957728 2.035616 -0.579188 -0.084068 1.089231 1.0 0.673998
9996 -0.348887 0.447439 0.559542 -1.080055 -0.295480 1.0 0.420495
9997 1.124305 2.668876 -1.503443 -1.011762 -0.702943 1.0 0.501797
9998 -1.868358 0.486675 -0.740070 1.132700 0.245923 1.0 0.927832
9999 -0.258291 1.028679 0.754417 1.484959 -0.882857 1.0 0.640461
W0 W1 W2 W3 W4 v0 y
0 1.096544 -0.582427 -0.294597 0.105514 -0.727263 1 7.907877
1 -0.027400 1.534856 0.478577 -0.891543 0.029324 1 12.110836
2 1.311323 -1.372459 0.764496 0.928932 -1.241142 1 9.877393
3 1.466195 -0.765241 0.005756 0.330176 0.642096 1 18.909254
4 0.023509 -0.683962 0.236290 -0.505506 -0.537863 1 5.871040
... ... ... ... ... ... .. ...
9995 0.834235 0.794972 0.237262 1.201480 -0.378222 1 16.845116
9996 -1.058951 0.459315 -0.163190 0.868038 -0.446979 1 11.947050
9997 -0.311409 0.052510 0.813563 0.746291 0.290983 1 17.572916
9998 -0.367764 1.182595 0.147166 0.049880 -1.159016 1 8.550200
9999 0.344271 -1.622375 2.221930 2.081351 0.561735 1 19.165823
[10000 rows x 14 columns]
[18]:
from sklearn.ensemble import RandomForestRegressor
metalearner_estimate = model_experiment.estimate_effect(identified_estimand_experiment,
                                                        method_name="backdoor.econml.metalearners.TLearner",
                                                        confidence_intervals=False,
                                                        method_params={"init_params": {'models': RandomForestRegressor()},
                                                                       "fit_params": {}})
print(metalearner_estimate)
print("True causal estimate is", data_experiment["ate"])
*** Causal Estimate ***
## Identified estimand
Estimand type: EstimandType.NONPARAMETRIC_ATE
### Estimand : 1
Estimand name: backdoor
Estimand expression:
d
─────(E[y|W4,W2,W0,W3,W1])
d[v₀]
Estimand assumption 1, Unconfoundedness: If U→{v0} and U→y then P(y|v0,W4,W2,W0,W3,W1,U) = P(y|v0,W4,W2,W0,W3,W1)
## Realized estimand
b: y~v0+X1+X0+X4+X2+X3+W4+W2+W0+W3+W1
Target units: ate
## Estimate
Mean value: 13.382504168673297
Effect estimates: [[ 9.93550584]
[11.40722635]
[11.04510408]
...
[18.26446741]
[ 8.32510091]
[20.3107527 ]]
True causal estimate is 10.921283595005184
Avoiding retraining the estimator#
Once an estimator is fitted, it can be reused to estimate effects on different data points. To do so, pass fit_estimator=False to estimate_effect. This works for any EconML estimator. We show an example for the T-learner below.
[19]:
# For metalearners, we need to provide all the features (except treatment and outcome)
metalearner_estimate = model_experiment.estimate_effect(identified_estimand_experiment,
                                                        method_name="backdoor.econml.metalearners.TLearner",
                                                        confidence_intervals=False,
                                                        fit_estimator=False,
                                                        target_units=data_experiment["df"].drop(["v0", "y", "Z0", "Z1"], axis=1)[9995:],
                                                        method_params={})
print(metalearner_estimate)
print("True causal estimate is", data_experiment["ate"])
*** Causal Estimate ***
## Identified estimand
Estimand type: EstimandType.NONPARAMETRIC_ATE
### Estimand : 1
Estimand name: backdoor
Estimand expression:
d
─────(E[y|W4,W2,W0,W3,W1])
d[v₀]
Estimand assumption 1, Unconfoundedness: If U→{v0} and U→y then P(y|v0,W4,W2,W0,W3,W1,U) = P(y|v0,W4,W2,W0,W3,W1)
## Realized estimand
b: y~v0+X1+X0+X4+X2+X3+W4+W2+W0+W3+W1
Target units: Data subset provided as a data frame
## Estimate
Mean value: 15.28987664334279
Effect estimates: [[16.88637146]
[12.66269074]
[18.26446741]
[ 8.32510091]
[20.3107527 ]]
True causal estimate is 10.921283595005184
Refuting the estimate#
Adding a random common cause variable#
[20]:
res_random = model.refute_estimate(identified_estimand, dml_estimate, method_name="random_common_cause")
print(res_random)
Refute: Add a random common cause
Estimated effect:13.167128354479889
New effect:13.157760749130695
p value:0.84
Adding an unobserved common cause variable#
[21]:
res_unobserved = model.refute_estimate(identified_estimand, dml_estimate, method_name="add_unobserved_common_cause",
                                       confounders_effect_on_treatment="linear", confounders_effect_on_outcome="linear",
                                       effect_strength_on_treatment=0.01, effect_strength_on_outcome=0.02)
print(res_unobserved)
Refute: Add an Unobserved Common Cause
Estimated effect:13.167128354479889
New effect:13.150097157648123
Replacing treatment with a random (placebo) variable#
[22]:
res_placebo = model.refute_estimate(identified_estimand, dml_estimate,
                                    method_name="placebo_treatment_refuter", placebo_type="permute",
                                    num_simulations=10  # at least 100 is good, setting to 10 for speed
                                    )
print(res_placebo)
Refute: Use a Placebo Treatment
Estimated effect:13.167128354479889
New effect:0.031705885126463565
p value:0.37153338198560326
Removing a random subset of the data#
[23]:
res_subset = model.refute_estimate(identified_estimand, dml_estimate,
                                   method_name="data_subset_refuter", subset_fraction=0.8,
                                   num_simulations=10)
print(res_subset)
Refute: Use a subset of data
Estimated effect:13.167128354479889
New effect:13.162692581469027
p value:0.4721241569582957
More refutation methods are planned, especially ones specific to CATE estimators.