# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
""" 
    The MNIST_MLP architecture is borrowed from OoD-Bench:
        @inproceedings{ye2022ood,
         title={OoD-Bench: Quantifying and Understanding Two Dimensions of Out-of-Distribution Generalization},
         author={Ye, Nanyang and Li, Kaican and Bai, Haoyue and Yu, Runpeng and Hong, Lanqing and Zhou, Fengwei and Li, Zhenguo and Zhu, Jun},
         booktitle={CVPR},
         year={2022}
        }
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.models


class Identity(nn.Module):
    """An identity layer"""
    def __init__(self):
        super(Identity, self).__init__()

    def forward(self, x):
        return x


class MLP(nn.Module):
    """Just an MLP"""
    def __init__(self, n_inputs, n_outputs, mlp_width, mlp_depth, mlp_dropout):
        super(MLP, self).__init__()
        self.input = nn.Linear(n_inputs, mlp_width)
        self.dropout = nn.Dropout(mlp_dropout)
        self.hiddens = nn.ModuleList([nn.Linear(mlp_width, mlp_width) for _ in range(mlp_depth - 2)])
        self.output = nn.Linear(mlp_width, n_outputs)
        self.n_outputs = n_outputs

    def forward(self, x):
        x = self.input(x)
        x = self.dropout(x)
        x = F.relu(x)
        for hidden in self.hiddens:
            x = hidden(x)
            x = self.dropout(x)
            x = F.relu(x)
        x = self.output(x)
        return x
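

# Example usage of MLP (illustrative sketch; the hyperparameter values below are
# assumptions for demonstration, not taken from this file):
#
#     mlp = MLP(n_inputs=784, n_outputs=10, mlp_width=256, mlp_depth=4, mlp_dropout=0.1)
#     out = mlp(torch.randn(32, 784))   # -> torch.Size([32, 10])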


class MNIST_MLP(nn.Module):
    """MLP encoder for MNIST-scale inputs (architecture borrowed from OoD-Bench; see module docstring)."""
    def __init__(self, input_shape):
        super(MNIST_MLP, self).__init__()
        self.hdim = hdim = 390
        self.encoder = nn.Sequential(
            nn.Linear(input_shape[0] * input_shape[1] * input_shape[2], hdim),
            nn.ReLU(True),
            nn.Linear(hdim, hdim),
            nn.ReLU(True),
        )
        self.n_outputs = hdim
        for m in self.encoder:
            if isinstance(m, nn.Linear):
                gain = nn.init.calculate_gain("relu")
                nn.init.xavier_uniform_(m.weight, gain=gain)
                nn.init.zeros_(m.bias)

    def forward(self, x):
        x = x.view(x.size(0), -1)
        return self.encoder(x)
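

# Example usage of MNIST_MLP (illustrative; input_shape follows the
# (channels, height, width) convention this constructor flattens over):
#
#     enc = MNIST_MLP(input_shape=(1, 28, 28))
#     feats = enc(torch.randn(16, 1, 28, 28))   # -> torch.Size([16, 390])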


class ResNet(torch.nn.Module):
    """ResNet with the softmax chopped off and the batchnorm frozen"""
    def __init__(self, input_shape, resnet18=True, resnet_dropout=0.0):
        super(ResNet, self).__init__()
        if resnet18:
            self.network = torchvision.models.resnet18(pretrained=True)
            self.n_outputs = 512
        else:
            self.network = torchvision.models.resnet50(pretrained=True)
            self.n_outputs = 2048
        # adapt number of channels
        nc = input_shape[0]
        if nc != 3:
            tmp = self.network.conv1.weight.data.clone()
            self.network.conv1 = nn.Conv2d(nc, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
            for i in range(nc):
                self.network.conv1.weight.data[:, i, :, :] = tmp[:, i % 3, :, :]
        # save memory
        del self.network.fc
        self.network.fc = Identity()
        self.freeze_bn()
        self.dropout = nn.Dropout(resnet_dropout)

    def forward(self, x):
        """Encode x into a feature vector of size n_outputs."""
        return self.dropout(self.network(x))

    def train(self, mode=True):
        """
        Override the default train() to freeze the BN parameters
        """
        super().train(mode)
        self.freeze_bn()

    def freeze_bn(self):
        for m in self.network.modules():
            if isinstance(m, nn.BatchNorm2d):
                m.eval()
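

# Example usage of ResNet (illustrative sketch; constructing the model downloads
# the pretrained ImageNet weights on first use):
#
#     featurizer = ResNet(input_shape=(3, 224, 224), resnet18=True, resnet_dropout=0.1)
#     feats = featurizer(torch.randn(8, 3, 224, 224))   # -> torch.Size([8, 512])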


class MNIST_CNN(nn.Module):
    """
    Hand-tuned architecture for MNIST.
    Weirdness I've noticed so far with this architecture:
    - adding a linear layer after the mean-pool in features hurts
        RotatedMNIST-100 generalization severely.
    """
    n_outputs = 128
    def __init__(self, input_shape):
        super(MNIST_CNN, self).__init__()
        self.conv1 = nn.Conv2d(input_shape[0], 64, 3, 1, padding=1)
        self.conv2 = nn.Conv2d(64, 128, 3, stride=2, padding=1)
        self.conv3 = nn.Conv2d(128, 128, 3, 1, padding=1)
        self.conv4 = nn.Conv2d(128, 128, 3, 1, padding=1)
        # note: despite the "bn" names, these are GroupNorm layers (8 groups each)
        self.bn0 = nn.GroupNorm(8, 64)
        self.bn1 = nn.GroupNorm(8, 128)
        self.bn2 = nn.GroupNorm(8, 128)
        self.bn3 = nn.GroupNorm(8, 128)
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))

    def forward(self, x):
        x = self.conv1(x)
        x = F.relu(x)
        x = self.bn0(x)
        x = self.conv2(x)
        x = F.relu(x)
        x = self.bn1(x)
        x = self.conv3(x)
        x = F.relu(x)
        x = self.bn2(x)
        x = self.conv4(x)
        x = F.relu(x)
        x = self.bn3(x)
        x = self.avgpool(x)
        x = x.view(len(x), -1)
        return x
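

# Example usage of MNIST_CNN (illustrative; a 2-channel 28x28 input is just an
# example, not a requirement of the class):
#
#     cnn = MNIST_CNN(input_shape=(2, 28, 28))
#     feats = cnn(torch.randn(8, 2, 28, 28))   # -> torch.Size([8, 128])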


class ContextNet(nn.Module):
    """Convolutional network producing a single-channel map with the same spatial size as its input."""
    def __init__(self, input_shape):
        super(ContextNet, self).__init__()
        # Keep same dimensions
        padding = (5 - 1) // 2
        self.context_net = nn.Sequential(
            nn.Conv2d(input_shape[0], 64, 5, padding=padding),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.Conv2d(64, 64, 5, padding=padding),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.Conv2d(64, 1, 5, padding=padding),
        )

    def forward(self, x):
        return self.context_net(x)
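

# Example usage of ContextNet (illustrative; the padding keeps the spatial size,
# so only the channel dimension changes):
#
#     ctx = ContextNet(input_shape=(2, 28, 28))
#     out = ctx(torch.randn(8, 2, 28, 28))   # -> torch.Size([8, 1, 28, 28])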


def Classifier(in_features, out_features, is_nonlinear=False):
    """Return a linear classifier head, or a small 3-layer MLP head when is_nonlinear=True."""
    if is_nonlinear:
        return torch.nn.Sequential(
            torch.nn.Linear(in_features, in_features // 2),
            torch.nn.ReLU(),
            torch.nn.Linear(in_features // 2, in_features // 4),
            torch.nn.ReLU(),
            torch.nn.Linear(in_features // 4, out_features),
        )
    else:
        return torch.nn.Linear(in_features, out_features)
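

# Example usage of Classifier (illustrative feature/class counts):
#
#     clf = Classifier(in_features=512, out_features=7, is_nonlinear=True)
#     logits = clf(torch.randn(8, 512))   # -> torch.Size([8, 7])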