{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Conditional Average Treatment Effects (CATE) with DoWhy and EconML\n", "\n", "This is an experimental feature that lets DoWhy call estimators from [EconML](https://github.com/microsoft/econml). EconML provides several methods for estimating conditional average treatment effects (CATE).\n", "\n", "All four steps of causal inference in DoWhy remain the same: model, identify, estimate, and refute. The key difference is that the estimation step now calls EconML methods. A simpler example using linear regression is also included to build intuition for how CATE estimators work.\n", "\n" ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "import os, sys\n", "sys.path.insert(1, os.path.abspath(\"../../../\")) # for dowhy source code" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "import numpy as np\n", "import pandas as pd\n", "import logging\n", "\n", "import dowhy\n", "from dowhy import CausalModel\n", "import dowhy.datasets\n", "\n", "import econml\n", "import warnings\n", "warnings.filterwarnings('ignore')" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
 | X0 | X1 | Z0 | Z1 | W0 | W1 | W2 | W3 | v0 | y\n",
"---|---|---|---|---|---|---|---|---|---|---\n",
"0 | 0.279018 | -0.955066 | 0.0 | 0.228005 | 2.128748 | -0.064190 | 2 | 1 | 14.761720 | 122.169174\n",
"1 | -0.284234 | 1.382251 | 1.0 | 0.396490 | 0.175366 | -0.115139 | 1 | 2 | 23.566119 | 345.906350\n",
"2 | 1.411501 | 2.756870 | 0.0 | 0.765155 | 0.649999 | 0.861213 | 2 | 3 | 24.094062 | 477.367085\n",
"3 | -0.149004 | 1.887996 | 0.0 | 0.484696 | 0.693805 | 0.009433 | 0 | 0 | 7.781959 | 124.409057\n",
"4 | -1.033017 | 1.886337 | 0.0 | 0.989248 | -1.287000 | -0.845054 | 3 | 2 | 21.967755 | 356.001269\n", "