import torch.nn as nn
import torch
import torch.nn.functional as F
import math
import torch.utils.checkpoint as ck

def gelu_new(x):
    """Tanh approximation of GELU.

    Computes ``0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x**3)))``
    element-wise on the tensor ``x``.
    """
    cubic_term = 0.044715 * torch.pow(x, 3.0)
    tanh_arg = math.sqrt(2.0 / math.pi) * (x + cubic_term)
    return 0.5 * x * (1.0 + torch.tanh(tanh_arg))

def _init_weights(module):
    """Initialize the weights of a single module in place.

    - ``nn.Linear``: weight ~ N(0, 0.02); bias (if present) set to zero.
    - ``nn.Embedding``: weight ~ N(0, 0.02).
    - ``nn.LayerNorm``: weight filled with 1.0 and bias set to zero. Each one is
      skipped when the layer has no such parameter, for example with
      ``elementwise_affine=False``.

    Any other module type is left untouched, so this function is safe to call on
    every entry of ``model.modules()``.
    """
    if isinstance(module, nn.Linear):
        module.weight.data.normal_(mean=0.0, std=0.02)
        if module.bias is not None:
            module.bias.data.zero_()
    elif isinstance(module, nn.Embedding):
        module.weight.data.normal_(mean=0.0, std=0.02)
    elif isinstance(module, nn.LayerNorm):
        # LayerNorm's affine parameters are None when the layer was built without
        # them; touching them unconditionally used to raise AttributeError.
        if module.bias is not None:
            module.bias.data.zero_()
        if module.weight is not None:
            module.weight.data.fill_(1.0)

class HyperNetworkGRU(nn.Module):
    """Bottleneck GRU hypernetwork.

    The pipeline is: project ``hidden_dim -> hidden_dim // 8``, run a single-layer
    unidirectional GRU, apply LayerNorm, project back to ``hidden_dim``, then
    apply ``self.activation`` (``gelu_new``) under activation checkpointing.

    Config keys read:
        ``config["hidden_dim"]``: model width.
        ``config["n_layer"]``: used only to scale the init std of ``linear2``
            and the GRU.
    """

    def __init__(self, config):
        super().__init__()
        embed_dim = config["hidden_dim"]
        bottleneck = embed_dim // 8
        self.linear1 = nn.Linear(embed_dim, bottleneck)
        self.gru = nn.GRU(bottleneck, bottleneck, num_layers=1, bidirectional=False, batch_first=True)
        self.linear2 = nn.Linear(bottleneck, embed_dim)
        self.ln_1 = nn.LayerNorm(bottleneck, eps=1e-5)
        self.activation = gelu_new

        for module in self.modules():
            _init_weights(module)

        # Re-draw every tensor of the output projection and the GRU (weights AND
        # biases) from N(0, 0.02 / sqrt(2 * n_layer)), which shrinks the init
        # as depth grows.
        scaled_std = 0.02 / math.sqrt(2 * config["n_layer"])
        for param in self.linear2.parameters():
            param.data.normal_(mean=0.0, std=scaled_std)
        for param in self.gru.parameters():
            param.data.normal_(mean=0.0, std=scaled_std)

    def forward(self, x):
        """Apply the hypernetwork to ``x``.

        Args:
            x: tensor whose last dimension is ``hidden_dim``. Because the GRU uses
                ``batch_first=True``, a 3-D input is read as (batch, seq, feature).
                NOTE(review): assumes callers pass that layout; confirm.

        Returns:
            A bfloat16 tensor with the same leading dimensions as ``x`` and
            ``hidden_dim`` features. The math is done in fp32 regardless of the
            input dtype.
        """
        x = x.float()
        x = self.linear1(x)
        x, _ = self.gru(x)  # keep the per-step outputs, drop the final hidden state
        x = self.ln_1(x)
        x = self.linear2(x)
        # Recompute the activation during backward instead of storing it.
        # ``ck`` is the torch.utils.checkpoint module (not callable), so the
        # function must be ``ck.checkpoint``. Passing use_reentrant explicitly
        # avoids the deprecation path in recent PyTorch.
        x = ck.checkpoint(self.activation, x, use_reentrant=False)
        return x.bfloat16()

class HyperNetwork(nn.Module):
    """Two-layer MLP hypernetwork with a 4x bottleneck and an x*sigmoid(x) output gate.

    The pipeline is: ``hidden_dim -> hidden_dim // 4`` linear, then ``gelu_new``
    (activation-checkpointed), then a linear back to ``hidden_dim``, then
    ``x * sigmoid(x)``.

    Config keys read:
        ``config["hidden_dim"]``: model width.
        ``config["n_layer"]``: used only to scale the init std of ``linear2``.
    """

    def __init__(self, config):
        super().__init__()
        embed_dim = config["hidden_dim"]
        self.linear = nn.Linear(embed_dim, embed_dim // 4, bias=True)
        self.linear2 = nn.Linear(embed_dim // 4, embed_dim, bias=True)
        self.activation = gelu_new
        # ``ceil``/``log2`` were used unqualified but never imported (NameError
        # on construction); use the math-module versions. Evaluates to 10.
        # 2048 is presumably the context length (TODO confirm). forward() does
        # not currently use this value (its token-shift step is disabled).
        self.num_shifts = math.ceil(math.log2(2048)) - 1
        for module in self.modules():
            _init_weights(module)

        # Re-draw the output projection (weight and bias) from
        # N(0, 0.02 / sqrt(2 * n_layer)).
        for param in self.linear2.parameters():
            param.data.normal_(mean=0.0, std=(0.02 / math.sqrt(2 * config["n_layer"])))

    def forward(self, x):
        """Apply the hypernetwork to ``x``.

        Args:
            x: tensor whose last dimension is ``hidden_dim``. Any dtype is
                accepted; it is upcast to fp32.

        Returns:
            A bfloat16 tensor with the same shape as ``x``.
        """
        x = x.float()
        x = self.linear(x)
        # ``ck`` is the torch.utils.checkpoint module (not callable); the
        # activation is recomputed in backward instead of being stored.
        x = ck.checkpoint(self.activation, x, use_reentrant=False)
        x = self.linear2(x)
        x = x.mul(torch.sigmoid(x))  # x * sigmoid(x) gate (SiLU)
        return x.bfloat16()

class HyperNetworkSingle(nn.Module):
    """Single-layer hypernetwork: one ``hidden_dim -> hidden_dim`` linear map
    followed by an ``x * sigmoid(x)`` gate.

    The computation runs in fp32 and the result is returned as bfloat16.

    Config keys read:
        ``config["hidden_dim"]``: model width.
        ``config["n_layer"]``: used only to scale the init std of ``linear``.
    """

    def __init__(self, config):
        super().__init__()
        width = config["hidden_dim"]
        self.linear = nn.Linear(width, width, bias=True)
        # Assigned like in the sibling hypernetworks; forward() does not use it.
        self.activation = gelu_new
        self.apply(_init_weights)

        # Re-draw weight and bias from N(0, 0.02 / sqrt(2 * n_layer)).
        depth_scaled_std = 0.02 / math.sqrt(2 * config["n_layer"])
        for tensor in self.linear.parameters():
            tensor.data.normal_(mean=0.0, std=depth_scaled_std)

    def forward(self, x):
        """Return ``silu(linear(x))`` as bfloat16, computed in fp32."""
        hidden = self.linear(x.float())
        gated = hidden * torch.sigmoid(hidden)
        return gated.bfloat16()