{ "cells": [ { "cell_type": "markdown", "id": "a1b2c3d4-e5f6-7890-abcd-ef1234567890", "metadata": {}, "source": [ "# TabPFN Estimator - Example Notebook\n", "\n", "TabPFN (Prior-Data Fitted Network) [1, 2] is a foundation model for tabular data. It is pre-trained on a large collection of synthetic datasets and performs in-context learning: at inference time it conditions directly on the training data instead of fitting model parameters to it.\n", "Unlike traditional machine learning models, TabPFN requires no hyperparameter tuning, since it adapts to each dataset entirely at inference. This makes it especially practical for causal effect estimation workflows, where iterative model selection is costly.\n", "\n", "In this notebook, TabPFN serves as the outcome model in backdoor adjustment. Its competitive out-of-the-box performance on small tabular datasets makes it a practical alternative to conventional regression models, especially when the functional form of the outcome is unknown.\n", "\n", "TabPFN is best suited for datasets with up to 10,000 samples and up to 500 features. It is not intended for large datasets, time-series data, or classification outcomes with more than 10 classes. TabPFN can run on CPU, but GPU execution is recommended for larger ensemble sizes.\n", "\n", "## What We Demonstrate\n", "\n", "We walk through the standard DoWhy pipeline using `TabpfnEstimator` on a small synthetic dataset and show how to configure the estimator for three compute environments: CPU, single GPU, and multi-GPU.\n", "\n", "This notebook requires the `tabpfn` and `torch` packages (`pip install tabpfn torch`). TabPFN v2.5+ model weights are gated on Hugging Face: accept the terms at https://huggingface.co/Prior-Labs/tabpfn_2_5 and authenticate via `huggingface-cli login` or the `HF_TOKEN` environment variable."
] }, { "cell_type": "code", "execution_count": null, "id": "defaf42a", "metadata": {}, "outputs": [], "source": [ "import importlib.util\n", "import os\n", "\n", "HAS_TABPFN = importlib.util.find_spec(\"tabpfn\") is not None\n", "HAS_TORCH = importlib.util.find_spec(\"torch\") is not None\n", "\n", "# A Hugging Face token can come from the HF_TOKEN environment variable or from\n", "# `huggingface-cli login`, which stores it on disk; get_token() checks both.\n", "try:\n", "    from huggingface_hub import get_token\n", "\n", "    HAS_HF_TOKEN = bool(get_token())\n", "except ImportError:\n", "    HAS_HF_TOKEN = bool(os.environ.get(\"HF_TOKEN\"))\n", "HAS_TABPFN_TOKEN = bool(os.environ.get(\"TABPFN_TOKEN\"))\n", "\n", "TABPFN_AVAILABLE = HAS_TABPFN and HAS_TORCH and (HAS_HF_TOKEN or HAS_TABPFN_TOKEN)\n", "\n", "if TABPFN_AVAILABLE:\n", "    print(\"TabPFN prerequisites detected. Running TabPFN estimate cell.\")\n", "else:\n", "    missing = []\n", "    if not HAS_TABPFN:\n", "        missing.append(\"tabpfn\")\n", "    if not HAS_TORCH:\n", "        missing.append(\"torch\")\n", "    if not (HAS_HF_TOKEN or HAS_TABPFN_TOKEN):\n", "        missing.append(\"Hugging Face token (HF_TOKEN or huggingface-cli login) or TABPFN_TOKEN\")\n", "    print(\"Skipping TabPFN estimate cell; missing prerequisites: \" + \", \".join(missing))\n" ] }, { "cell_type": "code", "execution_count": null, "id": "b2c3d4e5-f6a7-8901-bcde-f12345678901", "metadata": {}, "outputs": [], "source": [ "import dowhy\n", "import dowhy.datasets\n", "from dowhy import CausalModel\n" ] }, { "cell_type": "markdown", "id": "c3d4e5f6-a7b8-9012-cdef-123456789012", "metadata": {}, "source": [ "## Generate Data\n", "\n", "We use a small synthetic dataset that is well within TabPFN's design constraints: 500 samples and two common causes.\n", "The dataset has a binary treatment and a continuous outcome, so TabPFN operates in regression mode automatically."
] }, { "cell_type": "code", "execution_count": null, "id": "d4e5f6a7-b8c9-0123-defa-234567890123", "metadata": {}, "outputs": [], "source": [ "data = dowhy.datasets.linear_dataset(\n", "    beta=10,\n", "    num_common_causes=2,\n", "    num_instruments=0,\n", "    num_treatments=1,\n", "    num_samples=500,  # can be increased (up to ~10,000) when running on a GPU\n", "    treatment_is_binary=True,\n", "    outcome_is_binary=False,\n", ")\n", "df = data[\"df\"]\n", "print(df.head())" ] }, { "cell_type": "markdown", "id": "e5f6a7b8-c9d0-1234-efab-345678901234", "metadata": {}, "source": [ "## Initialize Causal Model & Identify Estimand\n", "\n", "We construct a `CausalModel` from the synthetic dataset using the graph encoded in `data[\"gml_graph\"]`.\n", "DoWhy then analyzes the causal graph to produce an identified estimand, which specifies the statistical quantity that must be estimated to recover the causal effect." ] }, { "cell_type": "code", "execution_count": null, "id": "f6a7b8c9-d0e1-2345-fabc-456789012345", "metadata": {}, "outputs": [], "source": [ "model = CausalModel(\n", "    data=df,\n", "    treatment=data[\"treatment_name\"],\n", "    outcome=data[\"outcome_name\"],\n", "    graph=data[\"gml_graph\"],\n", ")" ] }, { "cell_type": "code", "execution_count": null, "id": "a7b8c9d0-e1f2-3456-abcd-567890123456", "metadata": {}, "outputs": [], "source": [ "model.view_model()" ] }, { "cell_type": "code", "execution_count": null, "id": "b8c9d0e1-f2a3-4567-bcde-678901234567", "metadata": {}, "outputs": [], "source": [ "identified_estimand = model.identify_effect(proceed_when_unidentifiable=True)\n", "print(identified_estimand)" ] }, { "cell_type": "markdown", "id": "c9d0e1f2-a3b4-5678-cdef-789012345678", "metadata": {}, "source": [ "## Estimate Effect\n", "\n", "TabPFN can run on CPU or GPU; the device is detected automatically based on availability.\n", "The relevant `method_params` depend on the compute environment, so choose the configuration that matches your hardware.\n", "All three configurations use the same estimator (`backdoor.tabpfn`); small numerical differences between them may result from the ensemble size (`n_estimators`) and TabPFN's internal stochasticity, not from a different estimation approach." ] }, { "cell_type": "markdown", "id": "d0e1f2a3-b4c5-6789-defa-890123456789", "metadata": {}, "source": [ "### CPU\n", "\n", "This is the default configuration. We use `n_estimators=4` to keep the runtime reasonable on CPU while still producing a reliable estimate.\n", "If the TabPFN prerequisites are unavailable in your environment, the estimate cell below is skipped." ] }, { "cell_type": "code", "execution_count": null, "id": "e1f2a3b4-c5d6-7890-efab-901234567890", "metadata": {}, "outputs": [], "source": [ "if TABPFN_AVAILABLE:\n", "    causal_estimate = model.estimate_effect(\n", "        identified_estimand,\n", "        method_name=\"backdoor.tabpfn\",\n", "        method_params={\n", "            \"model_type\": \"auto\",\n", "            \"n_estimators\": 4,  # small ensemble to keep CPU runtime low\n", "            \"use_multi_gpu\": False,\n", "        },\n", "    )\n", "    print(\"Causal Estimate is \" + str(causal_estimate.value))\n", "else:\n", "    print(\"TabPFN estimation skipped in this environment.\")\n" ] }, { "cell_type": "markdown", "id": "f2a3b4c5-d6e7-8901-fabc-012345678901", "metadata": {}, "source": [ "### Single GPU\n", "\n", "When a single GPU is available, TabPFN uses it automatically via `torch.cuda`; no additional parameter is needed. The extra compute makes a larger ensemble affordable: increase `n_estimators` to 8 or more for better accuracy (the example below uses 16)."
] }, { "cell_type": "code", "execution_count": null, "id": "a3b4c5d6-e7f8-9012-abcd-123456789012", "metadata": {}, "outputs": [], "source": [ "# Uncomment to run on a single GPU\n", "# causal_estimate = model.estimate_effect(\n", "#     identified_estimand,\n", "#     method_name=\"backdoor.tabpfn\",\n", "#     method_params={\n", "#         \"model_type\": \"auto\",\n", "#         \"n_estimators\": 16,\n", "#         \"use_multi_gpu\": False,  # a single GPU is picked up automatically\n", "#     },\n", "# )\n", "# print(\"Causal Estimate is \" + str(causal_estimate.value))" ] }, { "cell_type": "markdown", "id": "b4c5d6e7-f8a9-0123-bcde-234567890123", "metadata": {}, "source": [ "### Multi-GPU\n", "\n", "With multiple GPUs, TabPFN splits the training data across devices using multiprocessing and averages the resulting predictions. Set `use_multi_gpu=True` to enable this mode. Device IDs are auto-detected from `torch.cuda.device_count()`; alternatively, specify them explicitly with the `device_ids` parameter." ] }, { "cell_type": "code", "execution_count": null, "id": "c5d6e7f8-a9b0-1234-cdef-345678901234", "metadata": {}, "outputs": [], "source": [ "# Uncomment to run on multiple GPUs\n", "# causal_estimate = model.estimate_effect(\n", "#     identified_estimand,\n", "#     method_name=\"backdoor.tabpfn\",\n", "#     method_params={\n", "#         \"model_type\": \"auto\",\n", "#         \"n_estimators\": 16,\n", "#         \"use_multi_gpu\": True,\n", "#         # \"device_ids\": [0, 1],  # optional: specify GPU indices explicitly\n", "#     },\n", "# )\n", "# print(\"Causal Estimate is \" + str(causal_estimate.value))" ] }, { "cell_type": "markdown", "id": "d6e7f8a9-b0c1-2345-defa-456789012345", "metadata": {}, "source": [ "## References\n", "\n", "1. Hollmann, N., Müller, S., Eggensperger, K., and Hutter, F.\n", "   TabPFN: A Transformer That Solves Small Tabular Classification Problems in a Second.\n", "   ICLR (2023). https://arxiv.org/abs/2207.01848\n", "\n", "2. Hollmann, N., Müller, S., Purucker, L., Krishnakumar, A., Körfer, M., Hoo, S. B., Schirrmeister, R. T., and Hutter, F.\n", "   Accurate predictions on small data with a tabular foundation model.\n", "   Nature (2025). https://doi.org/10.1038/s41586-024-08328-6" ] } ], "metadata": { "kernelspec": { "display_name": "dowhy-dev", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.19" } }, "nbformat": 4, "nbformat_minor": 5 }