import torch
import torch.nn as nn

# Regression targets: [amplitude, time]
# NOTE(review): WINDOW_SIZE is not referenced in the visible part of this
# file; presumably it is the number of samples per input window -- confirm
# against the data pipeline.
WINDOW_SIZE = 9
# Default width of the models' output layer: one value per regression target
# listed above (amplitude, time).
OUT_DIM = 2


class MLP(nn.Module):
    """Fully connected regressor with two 64-unit hidden layers.

    The stack is Linear -> ReLU -> Linear -> ReLU -> Linear -> Sigmoid, so
    every output component lies in the open interval (0, 1).

    Args:
        input_dim: Size of the last dimension of the input tensor.
        out_dim: Number of values predicted per sample (defaults to OUT_DIM).
    """

    def __init__(self, input_dim, out_dim=OUT_DIM):
        super().__init__()
        hidden = 64
        # Layers are created in forward order and kept at the same positions
        # inside the Sequential, so parameter names stay net.0, net.2, net.4.
        layers = [
            nn.Linear(input_dim, hidden),
            nn.ReLU(inplace=True),
            nn.Linear(hidden, hidden),
            nn.ReLU(inplace=True),
            nn.Linear(hidden, out_dim),
            nn.Sigmoid(),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Map ``x`` of shape (..., input_dim) to predictions (..., out_dim)."""
        prediction = self.net(x)
        return prediction


class CNN1D(nn.Module):
    """Small 1-D convolutional regressor with a global-average-pooled head.

    Two same-length convolutions (kernel 5, padding 2) with a max-pool in
    between extract 32 feature channels; these are averaged over the length
    axis and passed through a 64-unit dense layer and a sigmoid output, so
    every output component lies in the open interval (0, 1).

    Args:
        input_dim: Accepted for signature parity with MLP but not used here;
            the averaging over the length axis makes the dense head
            independent of the input length.
        out_dim: Number of values predicted per sample (defaults to OUT_DIM).
    """

    def __init__(self, input_dim, out_dim=OUT_DIM):
        super().__init__()
        kernel, pad = 5, 2  # padding 2 with kernel 5 keeps the length unchanged
        narrow, wide, dense = 16, 32, 64

        # Feature extractor, stage 1: conv -> ReLU -> halve the length.
        self.conv1 = nn.Conv1d(1, narrow, kernel_size=kernel, padding=pad)
        self.act1 = nn.ReLU(inplace=True)
        self.pool1 = nn.MaxPool1d(kernel_size=2)

        # Feature extractor, stage 2: conv -> ReLU (no pooling).
        self.conv2 = nn.Conv1d(narrow, wide, kernel_size=kernel, padding=pad)
        self.act2 = nn.ReLU(inplace=True)

        # Regression head applied to the length-averaged features.
        self.fc1 = nn.Linear(wide, dense)
        self.act3 = nn.ReLU(inplace=True)
        self.out = nn.Linear(dense, out_dim)
        self.out_act = nn.Sigmoid()

    def forward(self, x):
        """Run the network on a single-channel signal.

        Args:
            x: Tensor with one input channel, expected as (N, 1, W).

        Returns:
            Tensor of shape (N, out_dim) with values in (0, 1).
        """
        feats = self.conv1(x)        # (N, 16, W)
        feats = self.act1(feats)
        feats = self.pool1(feats)    # (N, 16, W // 2)

        feats = self.conv2(feats)    # (N, 32, W // 2)
        feats = self.act2(feats)

        pooled = feats.mean(dim=-1)  # global average over length -> (N, 32)

        hidden = self.act3(self.fc1(pooled))
        logits = self.out(hidden)
        return self.out_act(logits)