Source code for dowhy.causal_prediction.algorithms.erm

import torch
from torch import nn
from torch.nn import functional as F

from dowhy.causal_prediction.algorithms.base_algorithm import PredictionAlgorithm


[docs]class ERM(PredictionAlgorithm): def __init__(self, model, optimizer="Adam", lr=1e-3, weight_decay=0.0, betas=(0.9, 0.999), momentum=0.9): super().__init__(model, optimizer, lr, weight_decay, betas, momentum) """Class for ERM Algorithm. :param model: Networks used for training. `model` type expected is torch.nn.Sequential(featurizer, classifier) where featurizer and classifier are of type torch.nn.Module. :param optimizer: Optimization algorithm used for training. Currently supports "Adam" and "SGD". :param lr: learning rate for ERM :param weight_decay: Value of weight decay for optimizer :param betas: Adam configuration parameters (beta1, beta2), exponential decay rate for the first moment and second-moment estimates, respectively. :param momentum: Value of momentum for SGD optimzer :returns: an instance of PredictionAlgorithm class """
[docs] def training_step(self, train_batch, batch_idx): """ Override `training_step` from PredictionAlgorithm class for ERM-specific training loop. """ x = torch.cat([x for x, y, _ in train_batch]) y = torch.cat([y for x, y, _ in train_batch]) out = self.model(x) loss = F.cross_entropy(out, y) acc = (torch.argmax(out, dim=1) == y).float().mean() metrics = {"train_acc": acc, "train_loss": loss} self.log_dict(metrics, on_step=False, on_epoch=True, prog_bar=True) return loss