from re import A
import torch
import torch.nn as nn
import torch.nn.functional as F
from pathlib import Path
from torch.utils import data
import math
import sys
from tqdm import tqdm
import time
import wandb
import numpy as np
from torch.utils.checkpoint import checkpoint as ck
from math import log2, ceil
from basedformer import gptj, optimizer
from basedformer.utils import *
import glob

def _init_weights(module):
    if isinstance(module, nn.Linear):
        module.weight.data.normal_(mean=0.0, std=0.02)
        if module.bias is not None:
            module.bias.data.zero_()

    elif isinstance(module, nn.Embedding):
        module.weight.data.normal_(mean=0.0, std=0.02)

    elif isinstance(module, nn.LayerNorm):
        module.bias.data.zero_()
        module.weight.data.fill_(1.0)

def shift_tokens(x, amt, eps = 1e-5):
    """Replace feature chunks with running means of earlier tokens.

    The last dimension of ``x`` is split into ``amt + 1`` chunks. For chunk
    ``i < amt`` with offset ``o = 2**i``, position ``t`` receives the mean of
    that chunk over positions ``t - 2*o + 1 .. t - o`` (a window of ``o``
    tokens ending ``o`` steps back; positions before the sequence start are
    simply excluded from the mean). The final chunk is passed through as-is.

    Args:
        x: tensor with the sequence on dim 1 and features on the last dim.
            The denominators are broadcast as ``(1, seq, 1)``, so ``x`` is
            expected to be 3-D, presumably ``(batch, seq, features)``.
        amt: number of shifted chunks (offsets ``1, 2, 4, ..., 2**(amt-1)``).
        eps: added to every denominator to guard against division by zero.

    Returns:
        The shifted chunks concatenated with the untouched last chunk along
        the last dimension.
    """
    seq_len, device = x.shape[1], x.device

    cumsum = x.cumsum(dim = 1)
    *_, x_pass = x.chunk(amt + 1, dim = -1)
    *cumsum_chunks, _ = cumsum.chunk(amt + 1, dim = -1)

    offsets = (2 ** torch.arange(amt)).tolist()

    # cumsum[t] covers positions 0..t, i.e. t + 1 tokens, so the running count
    # must start at 1. Starting at 0 gave the first non-empty window
    # (t == offset) a count of 0, dividing x[0] by eps alone.
    counts = torch.arange(1, seq_len + 1, device = device)

    shifts = []
    for cumsum_chunk, offset in zip(cumsum_chunks, offsets):
        # Window sum as a difference of shifted prefix sums.
        window_sum = shift(cumsum_chunk, offset, dim = -2) - shift(cumsum_chunk, 2 * offset, dim = -2)
        window_len = shift(counts, offset, dim = -1) - shift(counts, 2 * offset, dim = -1)
        shifts.append(window_sum / (window_len[None, :, None] + eps))

    return torch.cat((*shifts, x_pass), dim = -1)

def discounted_cumsum(t, gamma):
    """Apply ``discounted_cumsum_left`` along the middle axis of a 3-D tensor.

    The tensor is rearranged from ``(b, n, d)`` to ``(b*d, n)`` rows, passed to
    ``discounted_cumsum_left`` from the optional ``torch_discounted_cumsum``
    package, and rearranged back to ``(b, n, d)``.

    Args:
        t: 3-D tensor laid out as ``(b, n, d)``.
        gamma: discount factor forwarded to ``discounted_cumsum_left``.

    Returns:
        Tensor of shape ``(b, n, d)``.

    Raises:
        ImportError: if ``torch_discounted_cumsum`` is not installed.
    """
    try:
        from torch_discounted_cumsum import discounted_cumsum_left
    except ImportError as err:
        # Fail fast with the install hint; falling through would only surface
        # later as a confusing NameError on discounted_cumsum_left.
        raise ImportError('unable to import torch_discounted_cumsum - please run `pip install torch-discounted-cumsum`') from err

    b, n, d = t.shape
    t = rearrange(t, 'b n d -> (b d) n')
    t = discounted_cumsum_left(t, gamma)
    t = rearrange(t, '(b d) n -> b n d', b = b)
    return t

def shift(x, amt, dim = -1):
    """Shift ``x`` by ``amt`` positions along ``dim``, zero-filling the gap.

    Positive ``amt`` moves elements toward higher indices (the last ``amt``
    elements are dropped); negative ``amt`` moves them toward lower indices.
    The output has the same shape as ``x``.

    Args:
        x: input tensor.
        amt: number of positions to shift.
        dim: axis to shift along; negative and non-negative indices are both
            accepted.
    """
    # F.pad's tuple lists (left, right) pairs starting from the LAST dim, so
    # the target axis is addressed via its negative index. Without this
    # normalisation a non-negative dim silently shifted the last axis.
    if dim >= 0:
        dim -= x.ndim
    return F.pad(x, (*((0, 0) * (-dim - 1)), amt, -amt), value = 0.)

class HyperNetworkGRU(nn.Module):
    """Bottleneck hypernetwork with a GRU over the sequence.

    Pipeline: project ``hidden_dim -> hidden_dim // 8``, run a single-layer
    unidirectional batch-first GRU, LayerNorm, project back to ``hidden_dim``,
    then apply ``gelu_new`` under activation checkpointing. Computation runs
    in float32 and the result is returned as bfloat16.

    ``config`` must provide ``"hidden_dim"`` and ``"n_layer"``.
    """

    def __init__(self, config):
        super().__init__()
        width = config["hidden_dim"]
        bottleneck = width // 8
        self.linear1 = nn.Linear(width, bottleneck)
        self.gru = nn.GRU(bottleneck, bottleneck, num_layers=1, bidirectional=False, batch_first=True)
        self.linear2 = nn.Linear(bottleneck, width)
        self.ln_1 = nn.LayerNorm(bottleneck, eps=1e-5)
        self.activation = gelu_new

        # Base init for every submodule first ...
        for submodule in self.modules():
            _init_weights(submodule)

        # ... then redraw every parameter (weights and biases) of the
        # up-projection, followed by the GRU, with a depth-scaled std.
        scaled_std = 0.02 / math.sqrt(2 * config["n_layer"])
        for tensor in (*self.linear2.parameters(), *self.gru.parameters()):
            tensor.data.normal_(mean=0.0, std=scaled_std)

    def forward(self, x):
        hidden = self.linear1(x.float())
        hidden, _ = self.gru(hidden)
        hidden = self.linear2(self.ln_1(hidden))
        return ck(self.activation, hidden).bfloat16()

class HyperNetwork(nn.Module):
    """Two-layer bottleneck hypernetwork.

    Pipeline: ``hidden_dim -> hidden_dim // 4`` linear, ``gelu_new`` under
    activation checkpointing, linear back to ``hidden_dim``, then the SiLU
    form ``h * sigmoid(h)``. Computation runs in float32 and the result is
    returned as bfloat16.

    ``config`` must provide ``"hidden_dim"`` and ``"n_layer"``.
    """

    def __init__(self, config):
        super().__init__()
        width = config["hidden_dim"]
        self.linear = nn.Linear(width, width // 4, bias=True)
        self.linear2 = nn.Linear(width // 4, width, bias=True)
        self.activation = gelu_new
        # Offset count for shift_tokens over a 2048-token context; forward()
        # does not currently apply shift_tokens.
        self.num_shifts = ceil(log2(2048)) - 1

        for submodule in self.modules():
            _init_weights(submodule)

        # Depth-scaled redraw of the output projection (weight and bias).
        scaled_std = 0.02 / math.sqrt(2 * config["n_layer"])
        for tensor in self.linear2.parameters():
            tensor.data.normal_(mean=0.0, std=scaled_std)

    def forward(self, x):
        hidden = self.linear(x.float())
        hidden = ck(self.activation, hidden)
        hidden = self.linear2(hidden)
        return (hidden * torch.sigmoid(hidden)).bfloat16()

class HyperNetworkSingle(nn.Module):
    """Single square linear layer followed by ``h * sigmoid(h)`` (SiLU form).

    Computation runs in float32 and the result is returned as bfloat16.
    ``config`` must provide ``"hidden_dim"`` and ``"n_layer"``.
    """

    def __init__(self, config):
        super().__init__()
        width = config["hidden_dim"]
        self.linear = nn.Linear(width, width, bias=True)
        # Not used by forward(); kept as an attribute.
        self.activation = gelu_new

        for submodule in self.modules():
            _init_weights(submodule)

        # Depth-scaled redraw of the layer's weight and bias.
        scaled_std = 0.02 / math.sqrt(2 * config["n_layer"])
        for tensor in self.linear.parameters():
            tensor.data.normal_(mean=0.0, std=scaled_std)

    def forward(self, x):
        hidden = self.linear(x.float())
        return (hidden * torch.sigmoid(hidden)).bfloat16()

# Shape parameters of the base GPT-J model; the hypernetworks read
# "hidden_dim" and "n_layer" from this dict.
model_config = dict(
    n_layer=28,
    n_head=16,
    hidden_dim=4096,
    vocab_dim=50400,
    eps=1e-5,
)

# Training hyper-parameters. Each optimizer step consumes bs * gas sequences,
# processed as `gas` micro-batches of `bs` sequences (gradient accumulation).
train_config = {
    "data_path": "/home/xuser/diffusionstorage/datasets/enwik9-gpt2-2049.map",
    #"data_path": "/home/xuser/diffusionstorage/datasets/OWT2-gpt2-full.map",
    #"data_path": "/home/xuser/diffusionstorage/datasets/sigurd/map/sigurd_v5_fs_2049.map",
    "save_path": "/home/xuser/diffusionstorage/workspace/kuru/basedformer/models/hypernetwork-gptj-2048-enwik9-bs16-save",
    "do_save": True,
    "run_name": "gpt-j-enwik9-6b-postln-bf16-2e-4-4bsz-every5layersavetest",
    "lr": 2e-4,
    "end_lr": 2e-4,
    "warmup_steps": 50,
    "bs": 1,
    "gas": 4,
    "seed": 69,
    "save_every": 300,
    "amp": False,
    "loss_scale": False,
    # Exit the process right after the first checkpoint is written (e.g. to
    # test save/resume). Set to False to keep training past the first save.
    "exit_after_save": True,
}
torch.manual_seed(train_config["seed"])
bs = train_config["bs"]
gas = train_config["gas"]

save_root = Path(train_config["save_path"])
save_root.mkdir(parents=True, exist_ok=True)

#model = GPTModel.gpt2_init(model_config).cuda().float()
model = gptj.load_gpt_j().lm.cuda().bfloat16()
for param in model.parameters():
    param.requires_grad = False

# NOTE(review): these base-model params are made trainable, yet they are not
# given to the optimizer (only hypernetwork.parameters() is) and are not
# saved, so they are never updated. Presumably this is so activations entering
# the checkpointed layers (act_ck=True) require grad and gradients reach the
# hypernetwork; confirm before removing.
for name, p in model.named_parameters():
    if "ln" in name or "vocab_embed" in name:
        p.requires_grad = True

hypernetwork = HyperNetworkSingle(model_config).cuda().float()
#hypernetwork = nn.ModuleList([HyperNetworkSingle(model_config).cuda().float() for _ in range(model_config["n_layer"] // 5)])
#hypernetwork = nn.ModuleList([HyperNetworkSingle(model_config).cuda().float() for _ in range(2)])
for param in hypernetwork.parameters():
    param.requires_grad = True


def _is_complete_checkpoint(folder):
    """Return True for a ``step_<N>`` folder holding both files written at save time."""
    step = folder.name.split("_")[-1]
    return (
        folder.is_dir()
        and folder.name.startswith("step_")
        and step.isdigit()
        and (folder / "hyper.pt").is_file()
        and (folder / "opt").exists()
    )


# Resume from the newest complete checkpoint. Stray files or half-written
# folders in save_path are ignored instead of crashing the int() sort key.
cp_list = sorted(
    (folder for folder in save_root.iterdir() if _is_complete_checkpoint(folder)),
    key=lambda folder: int(folder.name.split("_")[-1]),
)
last_cp = cp_list[-1] if cp_list else None
print(last_cp)

if last_cp is not None:
    print("Loading from step {}".format(last_cp.name.split("_")[-1]))
    hypernetwork.load_state_dict(torch.load(last_cp / "hyper.pt"))
    opt = optimizer.BasedOptimizer.load(hypernetwork.parameters(), last_cp / "opt")
else:
    opt = optimizer.BasedOptimizer(hypernetwork.parameters(), train_config, "adamw")

# TODO: Add load, add evals, add FP16 AMP, and Data Parallel, outputting hidden states from the get_logits function.
print(opt.curr_step)
train_dataset = FbDataset(2049, train_config["data_path"])
if last_cp is not None:
    # Each optimizer step draws bs * gas sequences; skip those already seen
    # (assumes FbDataset.skip counts sequences; confirm against FbDataset).
    train_dataset.skip = opt.curr_step * bs * gas

train_loader = data.DataLoader(train_dataset, batch_size=bs * gas, shuffle=False, num_workers=0)
wandb.init(project="hypernetwork-tests", name=train_config["run_name"], config={**train_config, **model_config})

curr_step = opt.curr_step if last_cp is not None else 0

t = tqdm(train_loader, initial=curr_step)

# A disabled GradScaler is a no-op: scale() returns its input unchanged and
# get_scale() returns 1.0, so the logged loss_scale reflects reality.
scaler = torch.cuda.amp.GradScaler(enabled=train_config["loss_scale"])

for input_ids, labels in t:
    step_start = time.perf_counter()
    input_ids = input_ids.cuda()
    labels = labels.cuda()
    loss = 0
    for micro in range(gas):
        rows = slice(micro * bs, (micro + 1) * bs)
        with torch.cuda.amp.autocast(enabled=train_config["amp"], dtype=torch.float16):
            logits = model(input_ids[rows, :], hypernetwork=hypernetwork, act_ck=True)
            logits = logits.view(-1, logits.shape[-1])
            gas_labels = labels[rows, :].contiguous().view(-1)
            gas_loss = F.cross_entropy(logits, gas_labels)

        # Divide by gas so the accumulated gradient is the MEAN over the
        # micro-batches (matching the averaged loss logged below) rather than
        # their sum, which would inflate the norm seen by clip_grad_norm_.
        scaler.scale(gas_loss / gas).backward()
        loss += gas_loss.item()

    loss = loss / gas
    if train_config["loss_scale"]:
        scaler.unscale_(opt.optimizer)
    torch.nn.utils.clip_grad_norm_(hypernetwork.parameters(), 1)
    if train_config["loss_scale"]:
        opt.step(scaler=scaler)
        scaler.update()
    else:
        opt.step()

    opt.zero_grad()
    # The unfrozen base-model params above are presumably not tracked by opt,
    # so drop their gradients explicitly instead of letting them accumulate.
    model.zero_grad(set_to_none=True)

    sec_per_step = time.perf_counter() - step_start
    step_per_sec = 1. / sec_per_step
    tokens_per_sec = (step_per_sec * 2048) * bs * gas
    t.set_description(f"{step_per_sec:.2f} steps/s, {sec_per_step:.2f}s/step, {tokens_per_sec:.2f}tokens/s, loss={loss:.4f}")
    wandb.log(
        {
            "train/loss": loss,
            "train/tokens_per_sec": tokens_per_sec,
            "train/sec_per_step": sec_per_step,
            "train/step_per_sec": step_per_sec,
            "train/lr": opt.curr_lr,
            "train/loss_scale": scaler.get_scale()
        },
        step=curr_step)

    if train_config["do_save"] and curr_step != 0 and curr_step % train_config["save_every"] == 0:
        save_folder = save_root / f"step_{curr_step}"
        save_folder.mkdir(parents=True, exist_ok=True)
        torch.save(hypernetwork.state_dict(), save_folder / "hyper.pt")
        opt.save(save_folder / "opt")
        print(f"Saved model at step {curr_step}")
        if train_config.get("exit_after_save", True):
            sys.exit(0)

    curr_step += 1
            