{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# DoWhy: Different estimation methods for causal inference\n", "This is a quick introduction to the DoWhy causal inference library.\n", "We will load a sample dataset and use different methods to estimate the causal effect of a (pre-specified) treatment variable on a (pre-specified) outcome variable.\n", "\n", "We will see that not all estimators return the correct effect for this dataset.\n", "\n", "First, let us load all required packages." ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "import numpy as np\n", "import pandas as pd\n", "import logging\n", "\n", "import dowhy\n", "from dowhy import CausalModel\n", "import dowhy.datasets " ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Now, let us load a dataset. For simplicity, we simulate a dataset with linear relationships between the common causes and the treatment, and between the common causes and the outcome. \n", "\n", "Beta is the true causal effect. " ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
|  | Z0 | Z1 | W0 | W1 | W2 | W3 | W4 | v0 | y |
|---|---|---|---|---|---|---|---|---|---|
| 0 | 0.0 | 0.521013 | -1.325391 | 2.116213 | -0.557359 | -1.865444 | 1.605934 | False | 10.352212 |
| 1 | 0.0 | 0.463769 | 1.069812 | 1.726722 | -0.610115 | 0.350867 | 0.882315 | True | 19.499769 |
| 2 | 0.0 | 0.125765 | -1.945924 | 1.484627 | -1.575032 | 0.520069 | 0.559467 | False | 2.205429 |
| 3 | 0.0 | 0.861999 | -0.221328 | 0.939332 | 0.065349 | 1.291141 | 0.056167 | True | 12.968877 |
| 4 | 0.0 | 0.893044 | -0.533583 | 2.024348 | -0.227786 | -2.962155 | 0.057679 | True | 14.847690 |
| ... | ... | ... | ... | ... | ... | ... | ... | ... | ... |
| 9995 | 0.0 | 0.019696 | -0.518280 | 2.516658 | -3.063094 | -1.035060 | 1.231945 | False | 7.816809 |
| 9996 | 0.0 | 0.692714 | -0.510823 | 0.612004 | -0.758462 | -1.786669 | 1.217035 | True | 15.047879 |
| 9997 | 0.0 | 0.164993 | 0.190624 | 0.365532 | -0.312707 | 0.794803 | 2.458445 | True | 21.472658 |
| 9998 | 0.0 | 0.382386 | -2.654195 | 1.469548 | -2.191149 | -0.681429 | 1.197488 | False | 2.936543 |
| 9999 | 0.0 | 0.713687 | -0.799016 | 0.810676 | -0.969194 | -0.467067 | 1.327113 | True | 15.621830 |

10000 rows × 9 columns
\n", "