import torch.nn as nn
import torch.nn.functional as F
from torch.utils import data
import wandb
from torch.utils.checkpoint import checkpoint as ck
from basedformer import optimizer, lm_utils, dataset
from basedformer.utils import *
from transformers import AutoTokenizer
from basedformer import sampling
from termcolor import colored
import argparse

# Device string every tensor and module in this script is moved to.
gpu = "cuda"
# Namespace providing `autocast`: the CUDA-specific module when running on
# "cuda", the generic `torch.amp` otherwise.
amp = torch.cuda.amp if gpu == "cuda" else torch.amp
# Gradient scaler used by the training loop when loss scaling is enabled.
scaler = torch.cuda.amp.GradScaler()

# Prompts that the training loop feeds to `sample` at every evaluation step.
prompts = [
    "<|endoftext|>",
    "The year was",
    "I grabbed my",
    "She lifted the",
    "He was known as the",
    ("The tavern was full again, so I ended up sharing a table with three "
     "very different creatures: a"),
    "I had been hiking in the wilderness when suddenly a",
    "She spread her",
    "The mercurial and beautiful woman laughed",
    "[ Author:",
    "[ Tags:",
    "***",
]

def _init_weights(module):
    if isinstance(module, nn.Linear):
        module.weight.data.normal_(mean=0.0, std=0.02)
        if module.bias is not None:
            module.bias.data.zero_()

    elif isinstance(module, nn.Embedding):
        module.weight.data.normal_(mean=0.0, std=0.02)

    elif isinstance(module, nn.LayerNorm):
        module.bias.data.zero_()
        module.weight.data.fill_(1.0)


def shift(x, amt, dim=-1):
    """Shift ``x`` by ``amt`` positions along ``dim``, filling with zeros.

    A positive ``amt`` moves elements towards higher indices: ``amt`` zeros
    are inserted at the front of the axis and the last ``amt`` elements are
    cropped. A negative ``amt`` shifts the other way. The output has the
    same shape as ``x``.

    ``dim`` may be negative or non-negative. Non-negative values are
    converted to their negative equivalent first, because ``F.pad`` lists
    padding pairs starting from the last axis; without the conversion a
    non-negative ``dim`` silently shifted the last axis instead.
    """
    if dim >= 0:
        dim -= x.dim()
    # One (0, 0) pair for every axis that comes after `dim`.
    untouched = (0, 0) * (-dim - 1)
    return F.pad(x, (*untouched, amt, -amt), value=0.)


class HyperNetworkGRU(nn.Module):
    """Hypernetwork with a GRU in a bottleneck of ``hidden_dim // 8`` units.

    ``config`` must provide ``"hidden_dim"`` and ``"n_layer"``.
    """

    def __init__(self, config):
        super().__init__()
        width = config["hidden_dim"]
        bottleneck = width // 8

        self.linear1 = nn.Linear(width, bottleneck)
        self.gru = nn.GRU(input_size=bottleneck,
                          hidden_size=bottleneck,
                          num_layers=1,
                          bidirectional=False,
                          batch_first=True)
        self.linear2 = nn.Linear(bottleneck, width)
        self.ln_1 = nn.LayerNorm(bottleneck, eps=1e-5)
        self.activation = gelu_new

        # Default initialisation for every sub-module first ...
        for submodule in self.modules():
            _init_weights(submodule)

        # ... then redraw all parameters (weights and biases) of the output
        # projection and of the GRU with a depth-scaled standard deviation.
        scaled_std = 0.02 / math.sqrt(2 * config["n_layer"])
        for layer in (self.linear2, self.gru):
            for param in layer.parameters():
                param.data.normal_(mean=0.0, std=scaled_std)

    def forward(self, x):
        """Run ``x`` through the bottleneck and return a bfloat16 tensor.

        The input is cast to float32, projected down, passed through the
        GRU (only its output sequence is kept), layer-normalised and
        projected back up. The activation is evaluated through
        ``torch.utils.checkpoint.checkpoint``.
        """
        hidden = self.linear1(x.float())
        hidden, _ = self.gru(hidden)
        hidden = self.linear2(self.ln_1(hidden))
        return ck(self.activation, hidden).bfloat16()


class HyperNetwork(nn.Module):
    """Two-layer hypernetwork with a ``hidden_dim // 4`` bottleneck.

    ``config`` must provide ``"hidden_dim"`` and ``"n_layer"``.
    """

    def __init__(self, config):
        super().__init__()
        width = config["hidden_dim"]
        inner = width // 4

        self.linear = nn.Linear(width, inner, bias=True)
        self.linear2 = nn.Linear(inner, width, bias=True)
        self.activation = F.gelu

        for submodule in self.modules():
            _init_weights(submodule)

        # Redraw the output projection (weight and bias) with a depth-scaled
        # standard deviation.
        scaled_std = 0.02 / math.sqrt(2 * config["n_layer"])
        for param in self.linear2.parameters():
            param.data.normal_(mean=0.0, std=scaled_std)

    def forward(self, x):
        """Return ``y * sigmoid(y)`` as bfloat16.

        ``y`` is ``linear2(gelu(linear(x)))`` computed in float32; the GELU
        is evaluated through ``torch.utils.checkpoint.checkpoint``.
        """
        squeezed = self.linear(x.float())
        expanded = self.linear2(ck(self.activation, squeezed))
        gated = expanded.mul(torch.sigmoid(expanded))
        return gated.bfloat16()


class HyperNetworkSingle(nn.Module):
    """Single-layer hypernetwork: one square linear map plus a sigmoid gate.

    ``config`` must provide ``"hidden_dim"`` and ``"n_layer"``.
    """

    def __init__(self, config):
        super().__init__()
        width = config["hidden_dim"]
        self.linear = nn.Linear(width, width, bias=True)
        # Stored on the module but not applied in `forward`.
        self.activation = gelu_new

        for submodule in self.modules():
            _init_weights(submodule)

        # Redraw weight and bias with a depth-scaled standard deviation.
        scaled_std = 0.02 / math.sqrt(2 * config["n_layer"])
        for param in self.linear.parameters():
            param.data.normal_(mean=0.0, std=scaled_std)

    def forward(self, x):
        """Return ``y * sigmoid(y)`` as bfloat16, where ``y = linear(x)``.

        The linear layer is evaluated on a float32 copy of ``x``.
        """
        projected = self.linear(x.float())
        gated = projected.mul(torch.sigmoid(projected))
        return gated.bfloat16()


# Tokenizer loaded under the 'gpt2' name; `sample` uses it to encode prompts
# and to decode the generated token ids back into text.
tokenizer = AutoTokenizer.from_pretrained('gpt2')


@torch.no_grad()
def sample(prompt, n_tokens, bsz, hypernetwork=None, step=0):
    """Generate text for ``prompt`` with and without the hypernetwork.

    Args:
        prompt: Text to continue.
        n_tokens: Forwarded to ``sampling.generate`` as the token count.
        bsz: Number of copies of the prompt generated in one batch.
        hypernetwork: Passed to the first generation run; the second
            ("vanilla") run always uses ``hypernetwork=None``.
        step: Value stored in the first column of every returned row.

    Returns:
        A list with one ``[step, prompt, hypernetwork_text, vanilla_text]``
        row per batch element.

    Relies on the module-level ``model``, ``tokenizer`` and ``gpu``. Calls
    ``torch.seed()``, which reseeds the global RNG non-deterministically.
    """
    torch.seed()

    prompt_ids = torch.LongTensor(tokenizer.encode(prompt))
    prompt_ids = prompt_ids.unsqueeze(0).to(gpu)
    batch = torch.cat([prompt_ids] * bsz, dim=0)

    # The same settings dict is shared by every batch element.
    sampler_settings = {
        "rep_pen": {"penalty": 3},
        "tfs": 0.8,
        "temp": 0.8,
    }
    ops_list = [sampler_settings] * bsz

    hyper_ids = sampling.generate(model.forward,
                                  batch,
                                  n_tokens,
                                  ops_list=ops_list,
                                  hypernetwork=hypernetwork,
                                  non_deterministic=True)
    vanilla_ids = sampling.generate(model.forward,
                                    batch,
                                    n_tokens,
                                    ops_list=ops_list,
                                    hypernetwork=None)

    hyper_texts = tokenizer.batch_decode(hyper_ids.cpu().numpy())
    vanilla_texts = tokenizer.batch_decode(vanilla_ids.cpu().numpy())

    return [[step, prompt, str(text), str(vanilla_texts[idx])]
            for idx, text in enumerate(hyper_texts)]


def report_wandb(data):
    """Log the rows in ``data`` to wandb as a table under "Generations".

    Each row must have four entries matching the columns
    Step / Prompt / Generated Text / Vanilla Model.
    """
    table = wandb.Table(
        data=data,
        columns=["Step", "Prompt", "Generated Text", "Vanilla Model"])
    wandb.log({"Generations": table})


def report_console(data):
    """Print the hypernetwork generation of every row in ``data``.

    ``data`` is a list of rows laid out like the ones ``sample`` returns:
    ``[step, prompt, generated_text, vanilla_text]``. The generated text
    (index 2 of each row) is printed in green, with red separator lines
    between entries and after the last one.

    The previous implementation iterated over ``data[2]``, i.e. the fields
    of the third row, and raised ``IndexError`` for fewer than three rows.
    """
    separator = colored(
        "======================================================", "red")
    for row in data:
        print(separator)
        print(colored(row[2], "green"))
    print(separator)


def make_hypernet_saver(train_config, hypernetwork):
    """Build a callback that checkpoints ``hypernetwork`` and the optimizer.

    The returned ``hypernet_saver(id)`` creates
    ``train_config["save_path"] / id`` (including parents), writes the
    hypernetwork state dict to ``hyper.pt`` inside it and asks the
    optimizer to save itself to ``opt`` in the same folder.
    """
    def hypernet_saver(id: str):
        target_dir = Path(train_config["save_path"]).joinpath(id)
        target_dir.mkdir(parents=True, exist_ok=True)
        torch.save(hypernetwork.state_dict(), target_dir.joinpath("hyper.pt"))
        # `opt` is the module-level optimizer; it is looked up at call time,
        # so it only has to exist once the saver is actually invoked.
        opt.save(target_dir.joinpath("opt"))

    return hypernet_saver

def _str2bool(value):
    """Convert a command-line string to a bool.

    ``type=bool`` must not be used with argparse: ``bool("False")`` is
    ``True`` because every non-empty string is truthy. This converter
    accepts the usual spellings (case-insensitive) and rejects anything
    else with an argparse error.
    """
    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in ("true", "t", "yes", "y", "1", "on"):
        return True
    if lowered in ("false", "f", "no", "n", "0", "off"):
        return False
    raise argparse.ArgumentTypeError(
        f"expected a boolean value, got {value!r}")


# Boolean options accept both `--flag` (meaning True) and `--flag <value>`,
# so existing invocations such as `--amp True` keep working.
parser = argparse.ArgumentParser(description='Hypernetwork Finetuner')
parser.add_argument('--run_name', type=str, help='the run name to use',
                    required=True)
parser.add_argument('--model', type=str, help='the model to train against',
                    required=True)
parser.add_argument('--dataset', type=str, help='pre-tokenized dataset to use',
                    required=True)
parser.add_argument("--output", type=str, help='output path',
                    default='')
parser.add_argument('--optimizer', type=str, help='the optimizer to use',
                    default='adamw')
parser.add_argument('--lr', type=float, help='learning rate', default=2e-4)
parser.add_argument('--end_lr', type=float, help='end learning rate',
                    default=2e-4)
parser.add_argument('--warmup', type=int, help='warmup steps')
parser.add_argument('--bs', type=int, help='batch size', default=4)
parser.add_argument('--gas', type=int, help='gas', default=1)
parser.add_argument('--seed', type=int, help="Random seed value",
                    default=42)
parser.add_argument("--save_steps", type=int,
                    help='# of steps between checkpoint saves',
                    default=300)
parser.add_argument("--amp", type=_str2bool, nargs='?', const=True,
                    help='enable amp', default=False)
parser.add_argument('--loss_scale', type=_str2bool, nargs='?', const=True,
                    help='whether to scale loss', default=False)
parser.add_argument("--eval_every", type=int,
                    help='evaluate hypernetwork every x steps',
                    default=100)
parser.add_argument('--output_path', type=str, help="Root path of all output",
                    default="./")
parser.add_argument('--no_resume', type=_str2bool, nargs='?', const=True,
                    default=False,
                    help="Do not resume from last checkpoint")
parser.add_argument("--context_size", type=int, help="Dataset context sizes",
                    default=2048)
parser.add_argument("--project_id", type=str, help="Project ID for reporting",
                    default="hypernetwork-training")
parser.add_argument("--logs", type=str, help="log directory location",
                    default="./logs")
parser.add_argument("--masked", type=_str2bool, nargs='?', const=True,
                    help="masked softmax fusion")
parser.set_defaults(loss_scale=False, amp=False, no_resume=False, masked=False)
args = parser.parse_args()
# Without an explicit --output, fall back to a folder named after the run.
if args.output == '':
    args.output = f'./{args.run_name}'

# we need 250 batch size to train the small GPT.
# Run configuration assembled from the parsed command-line arguments.
train_config = {
    "data_path": args.dataset,
    # Checkpoints are written to the output folder (--output, defaulting to
    # ./<run_name>) rather than into the pretrained model's own folder.
    "save_path": args.output,
    "lm_path": args.model,
    "optimizer": args.optimizer,
    "masked_softmax_fusion": args.masked,
    # --save_steps 0 disables periodic checkpoints; the training loop checks
    # this flag before taking `curr_step % save_every`.
    "do_save": args.save_steps != 0,
    "run_name": args.run_name,
    "lr": args.lr,
    "end_lr": args.end_lr,
    "warmup_steps": args.warmup,
    "bs": args.bs,
    "gas": args.gas,
    "seed": args.seed,
    "save_every": args.save_steps,
    "amp": args.amp,
    "loss_scale": args.loss_scale,
    "eval_every": args.eval_every,
}
torch.manual_seed(train_config["seed"])
bs = train_config["bs"]
gas = train_config["gas"]

# Make sure the checkpoint folder exists before it is listed for resuming.
Path(train_config["save_path"]).mkdir(parents=True, exist_ok=True)

# Base language model, loaded from the path given by --model, moved to the
# training device and cast to bfloat16. All of its parameters start frozen.
model = lm_utils.load_from_path(train_config["lm_path"]).to(gpu).bfloat16()
for param in model.parameters():
    param.requires_grad = False

# Gradients are re-enabled for parameters whose name contains "ln" or
# "vocab_embed".
# NOTE(review): these parameters are not handed to the optimizer below, so
# they are never updated; presumably this only keeps the autograd graph
# connected through activation checkpointing -- confirm.
for name, p in model.named_parameters():
    if "ln" in name or "vocab_embed" in name:
        p.requires_grad = True

# The hypernetwork is the only module being trained; it is kept in float32.
hypernetwork = HyperNetworkSingle(model.config).to(gpu).float()
for param in hypernetwork.parameters():
    param.requires_grad = True
hypernetwork_saver = make_hypernet_saver(train_config, hypernetwork)

# Checkpoint folders end in "_<step>" (e.g. "step_300"). Entries without a
# numeric suffix -- such as the "final" folder written at the end of a run --
# are ignored instead of crashing the integer sort.
cp_list = sorted(
    (entry for entry in os.listdir(train_config["save_path"])
     if entry.split("_")[-1].isdecimal()),
    key=lambda entry: int(entry.split("_")[-1]))
last_cp = Path(train_config["save_path"]) / cp_list[-1] if cp_list else None
print(last_cp)

# Resume only when a checkpoint exists AND --no_resume was not requested.
resuming = last_cp is not None and not args.no_resume

if resuming:
    print("Loading from step {}".format(cp_list[-1].split("_")[-1]))
    hypernetwork.load_state_dict(torch.load(last_cp / "hyper.pt"))
    opt = optimizer.BasedOptimizer.load(hypernetwork.parameters(),
                                        last_cp / "opt")

else:
    opt = optimizer.BasedOptimizer(hypernetwork.parameters(),
                                   train_config,
                                   train_config["optimizer"])

# TODO: Add load, add evals, add FP16 AMP, and Data Parallel, outputting hidden
#       states from the get_logits function.
print(opt.curr_step)
# NOTE(review): 2049 is presumably the context size plus one token for the
# shifted labels -- confirm against ShardedDataset before wiring it to
# --context_size.
train_dataset = dataset.ShardedDataset(2049, train_config["data_path"])
if resuming:
    # Skip the samples consumed before the checkpoint was written.
    train_dataset.skip = opt.curr_step * bs * gas

# Each loader batch holds `gas` micro-batches of `bs` rows.
train_loader = data.DataLoader(train_dataset,
                               batch_size=bs * gas,
                               shuffle=False,
                               num_workers=0)
wandb.init(project=args.project_id,
           name=train_config["run_name"],
           config={**train_config, **model.config})

curr_step = opt.curr_step if resuming else 0

t = tqdm(train_loader, initial=curr_step)
sample_data = []

# Main training loop: one optimizer step per loader batch, with the batch
# processed as up to `gas` micro-batches of `bs` rows each.
for input_ids, labels in t:
    step_start = time.perf_counter()
    input_ids = input_ids.to(gpu)
    labels = labels.to(gpu)
    loss = 0
    micro_batches = 0
    for x in range(gas):
        micro_ids = input_ids[x * bs:(x + 1) * bs, :]
        if micro_ids.shape[0] == 0:
            # The last loader batch can be shorter than bs * gas; an empty
            # slice would produce a NaN loss.
            break
        with amp.autocast(enabled=train_config["amp"],
                          dtype=torch.float16):
            logits, _ = model(micro_ids,
                              hypernetwork=hypernetwork,
                              act_ck=True)
            logits = logits.view(-1, logits.shape[-1])
            gas_labels = labels[x * bs:(x + 1) * bs, :].contiguous()
            gas_labels = gas_labels.view(-1)
            gas_loss = F.cross_entropy(logits, gas_labels)

        # NOTE(review): the micro-batch loss is not divided by `gas` before
        # backward(), so accumulated gradients are the sum (not the mean)
        # over micro-batches. Left unchanged to keep training dynamics.
        if train_config["loss_scale"]:
            scaler.scale(gas_loss).backward()
        else:
            gas_loss.backward()

        loss += gas_loss.item()
        micro_batches += 1

    # Reported loss is the mean over the micro-batches actually processed.
    loss = loss / max(micro_batches, 1)
    if train_config["loss_scale"]:
        # Gradients must be unscaled before their norm is clipped.
        scaler.unscale_(opt.optimizer)
    torch.nn.utils.clip_grad_norm_(hypernetwork.parameters(), 1)
    if train_config["loss_scale"]:
        opt.step(scaler=scaler)
    else:
        opt.step()

    if train_config["loss_scale"]:
        scaler.update()

    opt.zero_grad()
    sec_per_step = (time.perf_counter() - step_start)
    step_per_sec = (1. / sec_per_step)
    # Count the token ids actually present in this batch instead of assuming
    # a fixed sequence length and a full batch.
    tokens_per_sec = step_per_sec * input_ids.numel()
    t.set_description(f"{step_per_sec:.2f} steps/s, {sec_per_step:.2f}s/step,"
                      + f"{tokens_per_sec:.2f}tokens/s, loss={loss:.4f}")
    wandb.log(
        {
            "train/loss": loss,
            "train/tokens_per_sec": tokens_per_sec,
            "train/sec_per_step": sec_per_step,
            "train/step_per_sec": step_per_sec,
            "train/lr": opt.curr_lr,
            "train/loss_scale": scaler.get_scale()
        },
        step=curr_step)

    if train_config["do_save"] and \
            curr_step % train_config["save_every"] == 0 and \
            curr_step != 0:
        hypernetwork_saver(f"step_{curr_step}")
        print(f"\nSaved model at step {curr_step}")

    if curr_step % train_config["eval_every"] == 0 and curr_step != 0:
        for prompt in prompts:
            # Pass the current step so the "Step" column of the logged table
            # is not stuck at the default of 0.
            sampled = sample(prompt, 500, 3, hypernetwork=hypernetwork,
                             step=curr_step)
            print(f"PROMPT:\n{prompt}")
            report_console(sampled)
            sample_data.extend(sampled)
        # The table logged each time contains every generation so far.
        report_wandb(sample_data)

    curr_step += 1

hypernetwork_saver("final")