{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Conditional Average Treatment Effects (CATE) with DoWhy and EconML\n", "\n", "This is an experimental feature in which DoWhy calls estimators from [EconML](https://github.com/microsoft/econml). EconML provides several CATE estimation methods, such as double machine learning and meta-learners.\n", "\n", "All four steps of causal inference in DoWhy remain the same: model, identify, estimate, and refute. The key difference is that the estimation step now calls EconML methods. A simpler example using linear regression is also included to build intuition for how CATE estimators work.\n", "\n" ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "import os, sys\n", "sys.path.insert(1, os.path.abspath(\"../../../\")) # for dowhy source code" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "import numpy as np\n", "import pandas as pd\n", "import logging\n", "\n", "import dowhy\n", "from dowhy import CausalModel\n", "import dowhy.datasets\n", "\n", "import econml\n", "import warnings\n", "warnings.filterwarnings('ignore')" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
|   | X0 | X1 | Z0 | Z1 | W0 | W1 | W2 | W3 | v0 | y |
|---|---|---|---|---|---|---|---|---|---|---|
| 0 | -0.423069 | -0.783912 | 1.0 | 0.595267 | -1.719039 | 1.133622 | 2.154847 | 2.093053 | 24.336479 | 168.434975 |
| 1 | -0.671984 | -0.517552 | 1.0 | 0.840097 | -0.871639 | 1.240345 | 2.513988 | 1.144102 | 32.005832 | 244.808533 |
| 2 | -0.408754 | -1.224234 | 1.0 | 0.841418 | -2.992127 | 1.947631 | 1.178079 | 0.364007 | 20.560389 | 105.083930 |
| 3 | -1.797735 | -0.810235 | 1.0 | 0.469800 | -1.097293 | -0.831213 | 0.490426 | -0.229491 | 8.713406 | 34.353637 |
| 4 | -0.579162 | -3.119641 | 1.0 | 0.691993 | -0.117981 | 0.009197 | 0.169698 | 2.553771 | 17.375187 | -28.550935 |
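The introduction mentions a simpler linear-regression example for building intuition about CATE estimators. The sketch below illustrates that idea under stated assumptions: it does not call DoWhy or EconML and does not use the dataset in the table. Instead it simulates hypothetical data with the same column roles (effect modifiers like X0 and X1, common causes like W0 to W3, a continuous treatment like v0, and an outcome like y). The true effect function theta(x) = 10 + 2*x0 - 3*x1 and every other coefficient are invented for illustration. Adding treatment-by-modifier interaction terms to an ordinary least-squares fit lets the slope on the treatment vary with X, and that varying slope is the CATE.

```python
# Minimal sketch (assumption): recover a linear CATE with ordinary least squares.
# The data-generating process below is hypothetical, not the one DoWhy used.
import numpy as np

rng = np.random.default_rng(0)
n = 5000

# Hypothetical synthetic data with the same roles as the table columns
X = rng.normal(size=(n, 2))  # effect modifiers (like X0, X1)
W = rng.normal(size=(n, 4))  # common causes (like W0..W3)
T = W @ np.array([1.0, 0.5, -0.5, 0.2]) + rng.normal(size=n)  # treatment (like v0), confounded by W
theta = 10.0 + 2.0 * X[:, 0] - 3.0 * X[:, 1]  # true CATE: theta(x) = 10 + 2*x0 - 3*x1
y = (theta * T
     + W @ np.array([2.0, -1.0, 0.5, 1.5])
     + X @ np.array([0.5, 0.5])
     + rng.normal(size=n))  # outcome (like y)

# Design matrix: intercept, T, T*X0, T*X1, then X and W as controls.
# The interaction columns let the coefficient on T depend on X.
design = np.column_stack([np.ones(n), T, T * X[:, 0], T * X[:, 1], X, W])
coef, *_ = np.linalg.lstsq(design, y, rcond=None)
a, b, c = coef[1], coef[2], coef[3]

def cate(x0, x1):
    # Estimated effect of a one-unit increase in T for a unit with modifiers (x0, x1)
    return a + b * x0 + c * x1

print(round(a, 1), round(b, 1), round(c, 1))  # close to 10, 2, -3
print(cate(-0.42, -0.78))  # estimated CATE at modifier values similar to row 0
```

This works because the simulated outcome is linear and the regression is correctly specified. EconML estimators such as double machine learning are designed to relax that assumption by fitting the nuisance models (outcome and treatment given the controls) with flexible learners.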