import torch.nn as nn
import torch.nn.functional as F
import wandb
from torch.utils import data as torch_data
from torch.utils.checkpoint import checkpoint as ck
from basedformer import optimizer, lm_utils, dataset
from basedformer.utils import *
from transformers import AutoTokenizer
from basedformer import sampling
from termcolor import colored
from typing import Callable, List
import argparse

# Target device and the autocast namespace used by the training loop.
gpu = "cuda"
# NOTE(review): only "cuda" is exercised by this script. The torch.amp
# fallback would need an explicit device_type wherever autocast is entered.
amp = torch.cuda.amp if gpu == "cuda" else torch.amp
# Gradient scaler; losses are only routed through it when --loss_scale is set.
scaler = torch.cuda.amp.GradScaler()

# Fixed prompts sampled at every evaluation to eyeball hypernetwork progress.
prompts = [
    "<|endoftext|>",
    " The year was",
    " I grabbed my",
    " He was known as the",
    " The tavern was full again, so I ended up sharing a table with three very different creatures: a",
    " She spread her",
    " The mercurial and beautiful",
    "<|endoftext|>[ Author:",
    "***",
    "----",
    "> You look around.\n",
    "John:",
    "Jane:",
]


def _init_weights(module):
    """Initialise ``module``'s parameters in place; other module types are a no-op.

    - ``nn.Linear``: weight ~ N(0, 0.02); bias (if present) set to 0.
    - ``nn.Embedding``: weight ~ N(0, 0.02).
    - ``nn.LayerNorm``: weight set to 1, bias set to 0. Either parameter is
      skipped when the layer was built without it (e.g.
      ``elementwise_affine=False``), which previously raised AttributeError.

    Because unmatched modules are ignored, this is safe to call on every
    entry of ``Module.modules()``.
    """
    if isinstance(module, nn.Linear):
        module.weight.data.normal_(mean=0.0, std=0.02)
        if module.bias is not None:
            module.bias.data.zero_()

    elif isinstance(module, nn.Embedding):
        module.weight.data.normal_(mean=0.0, std=0.02)

    elif isinstance(module, nn.LayerNorm):
        # Non-affine LayerNorms have weight/bias == None.
        if module.bias is not None:
            module.bias.data.zero_()
        if module.weight is not None:
            module.weight.data.fill_(1.0)


def shift(x, amt, dim=-1):
    """Shift ``x`` by ``amt`` positions along ``dim``, filling vacated slots with 0.

    A positive ``amt`` moves values toward higher indices: the first ``amt``
    entries become 0 and the last ``amt`` values are dropped. A negative
    ``amt`` shifts the other way. The output has the same shape as ``x``.

    ``dim`` may be negative or non-negative. Before this was normalised, a
    non-negative ``dim`` silently shifted the last dimension instead.
    """
    if dim >= 0:
        dim -= x.dim()
    # F.pad lists (left, right) pairs starting from the LAST dimension, so
    # every dimension after `dim` gets a (0, 0) pair. The negative right pad
    # crops the values pushed off the end.
    pad = (0, 0) * (-dim - 1) + (amt, -amt)
    return F.pad(x, pad, value=0.)


class HyperNetworkGRU(nn.Module):
    """Recurrent bottleneck hypernetwork.

    Data path in ``forward``: Linear(hidden_dim -> hidden_dim // 8) -> GRU ->
    LayerNorm -> Linear(hidden_dim // 8 -> hidden_dim) -> ``gelu_new``.
    Computation runs in float32 and the result is returned as bfloat16.
    """

    def __init__(self, config):
        """Build the layers and initialise their weights.

        Args:
            config: Mapping providing ``"hidden_dim"`` (model width) and
                ``"n_layer"`` (used to shrink the init std of the GRU and
                output projection).
        """
        super().__init__()
        embed_dim = config["hidden_dim"]
        self.linear1 = nn.Linear(embed_dim, embed_dim // 8)
        # Single-layer, unidirectional GRU; batch_first=True means it expects
        # (batch, seq, feature) input -- assumes callers pass
        # (batch, seq, hidden_dim) activations, TODO confirm.
        self.gru = nn.GRU(embed_dim // 8,
                          embed_dim // 8,
                          num_layers=1,
                          bidirectional=False,
                          batch_first=True)
        self.linear2 = nn.Linear(embed_dim // 8, embed_dim)
        self.ln_1 = nn.LayerNorm(embed_dim // 8, eps=1e-5)
        self.activation = gelu_new

        # Linear / LayerNorm defaults; nn.GRU is not matched by _init_weights
        # and is initialised by the explicit loop below.
        for module in self.modules():
            _init_weights(module)

        # Output projection (weight and bias) re-drawn with a depth-scaled std.
        for param in self.linear2.parameters():
            param.data.normal_(mean=0.0,
                               std=(0.02 / math.sqrt(2 * config["n_layer"])))

        # Every GRU weight and bias uses the same depth-scaled std.
        for param in self.gru.parameters():
            param.data.normal_(mean=0.0,
                               std=(0.02 / math.sqrt(2 * config["n_layer"])))

        # NOTE(review): these containers re-register modules that are already
        # attributes, so state_dict() presumably lists each parameter under
        # two keys (e.g. "linear1.weight" and "linear_gru.0.weight").
        # Changing this would change the checkpoint format -- confirm first.
        self.linear_gru = nn.Sequential(
            self.linear1,
            self.gru)
        self.layernorm_linear = nn.Sequential(
            self.ln_1,
            self.linear2)

    def forward(self, x):
        """Apply the hypernetwork to ``x``; returns a bfloat16 tensor."""
        x = x.float()
        # linear1 then GRU; nn.GRU returns (output, h_n), keep the per-step output.
        x = self.linear_gru.forward(x)[0]
        # LayerNorm + Linear, then the activation under activation checkpointing.
        x = ck(self.activation,
               self.layernorm_linear.forward(x))
        return x.bfloat16()


class HyperNetwork(nn.Module):
    """Bottleneck MLP hypernetwork with a SiLU-gated output.

    ``forward`` computes, in float32,
    Linear(d -> d // 4) -> GELU (checkpointed) -> Linear(d // 4 -> d) -> x * sigmoid(x)
    and returns the result as bfloat16, where ``d = config["hidden_dim"]``.
    """

    def __init__(self, config):
        super().__init__()
        width = config["hidden_dim"]
        bottleneck = width // 4

        self.linear = nn.Linear(width, bottleneck, bias=True)
        self.linear2 = nn.Linear(bottleneck, width, bias=True)
        self.activation = F.gelu

        # Module.apply visits the two Linear layers in registration order
        # (then self, which _init_weights ignores), matching modules() order.
        self.apply(_init_weights)

        # Depth-scaled re-initialisation of the output projection.
        out_std = 0.02 / math.sqrt(2 * config["n_layer"])
        for p in self.linear2.parameters():
            p.data.normal_(mean=0.0, std=out_std)

    def forward(self, x):
        hidden = self.linear(x.float())
        hidden = ck(self.activation, hidden)
        out = self.linear2(hidden)
        gated = out * torch.sigmoid(out)
        return gated.bfloat16()


class HyperNetworkSingle(nn.Module):
    """Single full-width linear layer followed by SiLU gating (x * sigmoid(x)).

    ``forward`` runs in float32 and returns bfloat16. ``self.activation`` is
    stored but not used by ``forward``.
    """

    def __init__(self, config):
        super().__init__()
        width = config["hidden_dim"]
        depth = config["n_layer"]

        self.linear = nn.Linear(width, width, bias=True)
        self.activation = gelu_new

        # Module.apply reaches the Linear layer (then self, which
        # _init_weights ignores), in the same order as modules().
        self.apply(_init_weights)

        # Overwrite weight and bias with a depth-scaled normal draw.
        scaled_std = 0.02 / math.sqrt(2 * depth)
        for p in self.linear.parameters():
            p.data.normal_(mean=0.0, std=scaled_std)

    def forward(self, x):
        projected = self.linear(x.float())
        gated = projected * torch.sigmoid(projected)
        return gated.bfloat16()


# GPT-2 BPE tokenizer used to encode prompts and decode samples.
tokenizer = AutoTokenizer.from_pretrained("gpt2")


@torch.no_grad()
def sample(prompt, n_tokens, bsz, hypernetwork=None, step=0, run_name="",
           generate_vanilla=False):
    """Sample ``bsz`` continuations of ``prompt`` from the module-level ``model``.

    Args:
        prompt: Text prompt, encoded with the module-level ``tokenizer``.
        n_tokens: Number of tokens to generate per sample.
        bsz: Number of samples; the prompt is repeated ``bsz`` times.
        hypernetwork: Passed through to ``sampling.generate``.
        step: Training step recorded in every returned row.
        run_name: Run name recorded in every returned row.
        generate_vanilla: If True, also sample with ``hypernetwork=None`` and
            add those texts under "Vanilla Model".

    Returns:
        One dict per sample with keys "Run", "Step", "Prompt",
        "Generated Text" and, when vanilla samples were produced,
        "Vanilla Model".
    """
    # Samples should differ between evals, so the RNG is reseeded from a
    # non-deterministic source. Doing that on the global RNG would also
    # discard the --seed state that training relies on (e.g. DataLoader
    # shuffling). fork_rng restores the prior CPU/CUDA RNG state on exit.
    with torch.random.fork_rng():
        torch.seed()
        prompt_ids = torch.LongTensor(tokenizer.encode(prompt)).to(gpu)
        batch = prompt_ids.unsqueeze(0).repeat(bsz, 1)

        ops = {
            "rep_pen": {"penalty": 3},
            "tfs": 0.8,
            "temp": 0.8,
        }
        ops_list = [ops] * bsz
        generated = sampling.generate(model.forward,
                                      batch,
                                      n_tokens,
                                      ops_list=ops_list,
                                      hypernetwork=hypernetwork,
                                      non_deterministic=True)
        texts = tokenizer.batch_decode(generated.cpu().numpy())

        vanilla_texts = None
        if generate_vanilla:
            # NOTE(review): unlike the call above, this one does not pass
            # non_deterministic=True -- confirm whether that is intended.
            vanilla = sampling.generate(model.forward,
                                        batch,
                                        n_tokens,
                                        ops_list=ops_list,
                                        hypernetwork=None)
            vanilla_texts = tokenizer.batch_decode(vanilla.cpu().numpy())

    rows = []
    for i, text in enumerate(texts):
        row = {"Run": run_name,
               "Step": step,
               "Prompt": prompt,
               "Generated Text": str(text)}
        if vanilla_texts:
            row["Vanilla Model"] = vanilla_texts[i]
        rows.append(row)
    return rows


def report_wandb(data):
    """Log generated-sample rows to Weights & Biases as a single table.

    Args:
        data: List of row dicts as produced by ``sample``. Every row must
            have a "Step" key. Rows may come from several evaluations.

    The table is logged at the largest "Step" present. The caller passes
    rows accumulated across evals, so using the first row's step would
    target an already-passed step, and W&B ignores such logs. Columns are
    the union of row keys in first-seen order; a key missing from a row is
    logged as None. Empty input logs nothing.
    """
    if not data:
        return
    columns = list(dict.fromkeys(key for row in data for key in row))
    step = max(row["Step"] for row in data)
    # Align each row to `columns` explicitly instead of relying on each
    # dict's own insertion order.
    table_rows = [[row.get(col) for col in columns] for row in data]
    wandb.log({"Generations": wandb.Table(data=table_rows, columns=columns)},
              step=step)


def print_colored_bars(color):
    """Print a fixed-width rule of '=' characters in the given termcolor color."""
    rule = "======================================================"
    print(colored(rule, color))


def report_console(data: List[dict]):
    """Print one prompt's samples to stdout.

    The prompt (taken from the first row) appears between two blue rules.
    Each row's "Generated Text" follows, preceded by a red rule.
    """
    print_colored_bars("blue")
    prompt_text = data[0]["Prompt"]
    print(colored(prompt_text, "white"))
    print_colored_bars("blue")
    for row in data:
        print_colored_bars("red")
        print(colored(row["Generated Text"], "green"))


def make_eval_function(hypernetwork: "HyperNetworkSingle", config: dict) -> \
        Callable[[int], None]:
    """Build the periodic sampling callback used during training.

    Args:
        hypernetwork: Hypernetwork passed to ``sample`` for every prompt.
        config: Training config. Reads ``sample_vanilla`` (the key this
            script fills from --sample_vanilla; ``generate_vanilla`` is
            accepted as an alias), ``run_name``, ``context_size``, ``bs``,
            ``gas``, ``num_samples`` and ``num_tokens``. Missing keys fall
            back to the defaults below.

    Returns:
        ``eval_function(curr_step)``. It samples every entry of the
        module-level ``prompts``, prints each result, and logs all rows
        collected so far (across calls) to wandb.
    """
    # Rows accumulate across calls, so each wandb log carries the full
    # sampling history.
    sample_data = {'rows': []}
    # train_config stores the CLI flag under 'sample_vanilla'. Reading only
    # 'generate_vanilla' would make --sample_vanilla a no-op.
    gen_vanilla = bool(config.get('sample_vanilla') or
                       config.get('generate_vanilla', False))
    run_name = config.get('run_name', '')
    tokens_step = config.get('context_size', 2049) * \
                  config.get('bs', 1) * \
                  config.get('gas', 1)
    num_samples = config.get('num_samples', 3)
    num_tokens = config.get('num_tokens', 500)

    def eval_function(curr_step: int) -> None:
        curr_tokens_step = tokens_step * (curr_step + 1)
        print()
        print_colored_bars('yellow')
        print(f"Step: {curr_step} @ {curr_tokens_step} tokens processed")
        for prompt in prompts:
            sampled = sample(prompt, num_tokens, num_samples,
                             run_name=run_name,
                             hypernetwork=hypernetwork,
                             step=curr_step,
                             generate_vanilla=gen_vanilla)
            report_console(sampled)
            sample_data['rows'].extend(sampled)
        print_colored_bars("red")
        report_wandb(sample_data['rows'])

    return eval_function


def make_hypernet_saver(hypernetwork: HyperNetworkSingle, config: dict) \
        -> Callable[[str], None]:
    """Return a callback that checkpoints the hypernetwork and optimizer.

    The callback writes ``hyper.pt`` (the hypernetwork state_dict) and an
    ``opt`` entry into ``<config["save_path"]>/<id>/``, creating the folder
    if needed. The optimizer is the module-level ``opt``, looked up when the
    callback runs.
    """
    def hypernet_saver(id: str) -> None:
        target_dir = Path(config["save_path"]).joinpath(id)
        target_dir.mkdir(exist_ok=True, parents=True)
        weights_file = target_dir / "hyper.pt"
        torch.save(hypernetwork.state_dict(), weights_file)
        opt.save(target_dir / "opt")

    return hypernet_saver


def _str2bool(value):
    """Parse a command-line boolean such as "true"/"false"/"1"/"0"/"yes"/"no".

    ``type=bool`` is unusable with argparse: every non-empty string,
    including "False" and "0", converts to True.
    """
    if isinstance(value, bool):
        return value
    text = value.strip().lower()
    if text in ("1", "true", "t", "yes", "y", "on"):
        return True
    if text in ("0", "false", "f", "no", "n", "off"):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean, got {value!r}")


# Boolean flags accept a bare "--flag" (-> True) as well as an explicit
# "--flag true/false", which keeps existing "--flag True" invocations working.
_BOOL_FLAG = dict(type=_str2bool, nargs="?", const=True, default=False)

parser = argparse.ArgumentParser(description='Hypernetwork Finetuner')
parser.add_argument('--run_name', type=str, help='the run name to use',
                    required=True)
parser.add_argument('--model', type=str, help='the model to train against',
                    required=True)
parser.add_argument('--dataset', type=str, help='pre-tokenized dataset to use',
                    required=True)
parser.add_argument("--output", type=str, help='output path',
                    default='')
parser.add_argument('--optimizer', type=str, help='the optimizer to use',
                    default='adamw')
parser.add_argument('--lr', type=float, help='learning rate', default=2e-4)
parser.add_argument('--end_lr', type=float, help='end learning rate',
                    default=2e-4)
parser.add_argument('--warmup', type=int, help='warmup steps', default=10)
parser.add_argument('--bs', type=int, help='batch size', default=4)
parser.add_argument('--gas', type=int, help='gas', default=1)
parser.add_argument('--seed', type=int, help="Random seed value",
                    default=42)
parser.add_argument("--save_steps", type=int,
                    help='# of steps between checkpoint saves',
                    default=300)
parser.add_argument("--amp", help='enable amp', **_BOOL_FLAG)
parser.add_argument('--loss_scale', help='whether to scale loss',
                    **_BOOL_FLAG)
parser.add_argument("--eval_every", type=int,
                    help='evaluate hypernetwork every x steps',
                    default=100)
# NOTE(review): --output_path is not read anywhere in the code shown here.
parser.add_argument('--output_path', type=str, help="Root path of all output",
                    default="./")
parser.add_argument('--no_resume', help="Do not resume from last checkpoint",
                    **_BOOL_FLAG)
parser.add_argument("--context_size", type=int, help="Dataset context sizes",
                    default=2049)
parser.add_argument("--project_id", type=str, help="Project ID for reporting",
                    default="hypernetwork-training")
parser.add_argument("--logs", type=str, help="log directory location",
                    default="./logs")
parser.add_argument("--masked", help="masked softmax fusion", **_BOOL_FLAG)
parser.add_argument("--sample_vanilla", help="sample vanilla model",
                    **_BOOL_FLAG)
parser.add_argument("--sample_tokens", type=int, default=500,
                    help="number of tokens to sample")
parser.add_argument("--sample_num", type=int, default=3,
                    help="number of samples per prompt")
parser.add_argument("--shuffle", help="shuffle dataset contexts", **_BOOL_FLAG)
# A default is required: the training loop multiplies by and iterates over
# this value, so None would crash.
parser.add_argument("--epochs", type=int, default=1,
                    help="number of epochs to train for")
args = parser.parse_args()
# Default the output directory to ./<run_name> when --output is empty.
if not args.output:
    args.output = f'./{args.run_name}'

# Flat training configuration consumed by the rest of the script (and
# logged to wandb).
train_config = dict(
    project_id=args.project_id,
    data_path=args.dataset,
    save_path=args.output,
    lm_path=args.model,
    optimizer=args.optimizer,
    masked_softmax_fusion=args.masked,
    do_save=args.save_steps != 0,
    run_name=args.run_name,
    lr=args.lr,
    end_lr=args.end_lr,
    warmup_steps=args.warmup,
    bs=args.bs,
    gas=args.gas,
    seed=args.seed,
    save_every=args.save_steps,
    amp=args.amp,
    loss_scale=args.loss_scale,
    eval_every=args.eval_every,
    context_size=args.context_size,
    sample_vanilla=args.sample_vanilla,
    num_samples=args.sample_num,
    num_tokens=args.sample_tokens,
    shuffle=args.shuffle,
    epochs=args.epochs,
    logs=args.logs,
)

torch.manual_seed(train_config["seed"])
bs, gas = train_config["bs"], train_config["gas"]

Path(train_config["save_path"]).mkdir(parents=True, exist_ok=True)

# Base LM in bf16 with all parameters frozen.
model = lm_utils.load_from_path(train_config["lm_path"]).to(gpu).bfloat16()
for param in model.parameters():
    param.requires_grad = False

# Re-enable grads on base-model parameters whose name contains "ln" or
# "vocab_embed". NOTE(review): these are not given to the optimizer below,
# so they are never updated. Presumably this keeps activations requiring
# grad so gradients reach the hypernetwork through activation checkpointing
# (act_ck=True) -- confirm.
for name, p in model.named_parameters():
    if "ln" in name or "vocab_embed" in name:
        p.requires_grad = True

hypernetwork = HyperNetworkSingle(model.config).to(gpu).float()
for param in hypernetwork.parameters():
    param.requires_grad = True
hypernetwork_saver = make_hypernet_saver(hypernetwork, train_config)
eval_fn = make_eval_function(hypernetwork, train_config)


def _step_checkpoints(root):
    """Return names of ``step_<N>`` checkpoint directories under ``root``, sorted by N.

    Other entries are ignored. That includes the ``epoch-<k>`` and ``final``
    directories the saver also writes, which previously broke the numeric
    sort with a ValueError. A missing ``root`` yields an empty list.
    """
    root = Path(root)
    if not root.is_dir():
        return []
    found = []
    for entry in root.iterdir():
        prefix, _, number = entry.name.rpartition("_")
        if prefix == "step" and number.isdecimal() and entry.is_dir():
            found.append((int(number), entry.name))
    return [entry_name for _, entry_name in sorted(found)]


cp_list = _step_checkpoints(train_config["save_path"])
last_cp = Path(train_config["save_path"]) / cp_list[-1] if cp_list else None

if last_cp and not args.no_resume:
    print("Loading from step {}".format(cp_list[-1].split("_")[-1]))
    hypernetwork.load_state_dict(torch.load(last_cp / "hyper.pt"))
    opt = optimizer.BasedOptimizer.load(hypernetwork.parameters(),
                                        last_cp / "opt")
else:
    opt = optimizer.BasedOptimizer(hypernetwork.parameters(),
                                   train_config,
                                   train_config["optimizer"])

# TODO: Data Parallel; output hidden states from the get_logits function.
train_dataset = dataset.ShardedDataset(train_config["context_size"],
                                       train_config["data_path"])
# NOTE(review): `skip` is set to the optimizer's step count; whether
# ShardedDataset interprets it as steps or as samples (bs * gas per step)
# is not visible here -- confirm.
if last_cp:
    train_dataset.skip = opt.curr_step

# Each loader batch holds all gas micro-batches of one optimizer step.
train_loader = torch_data.DataLoader(
    train_dataset,
    batch_size=bs * gas,
    shuffle=train_config["shuffle"],
    num_workers=0,
)
wandb.init(
    project=train_config["project_id"],
    name=train_config["run_name"],
    config={**train_config, **model.config},
)
print("wandb initialized")

curr_step = opt.curr_step if last_cp else 0

epoch_steps = len(train_loader)
total_steps = epoch_steps * train_config['epochs']
tokens_per_step = (train_config['context_size']
                   * train_config['bs']
                   * train_config['gas'])

# Baseline samples before the first optimizer step (or at the resumed step).
eval_fn(curr_step)

with tqdm(total=total_steps, initial=curr_step) as t:
    # NOTE(review): on resume the epoch loop still starts at epoch 0 and runs
    # `epochs` full passes, so training can run past `total_steps` unless
    # `train_dataset.skip` compensates -- confirm.
    for epoch in range(train_config['epochs']):
        for input_ids, labels in train_loader:
            timex = time.perf_counter()
            input_ids = input_ids.to(gpu)
            labels = labels.to(gpu)

            # Gradient accumulation: a loader batch holds up to bs * gas rows,
            # processed as micro-batches of `bs` rows. `split` never yields
            # empty chunks, so a short final batch just has fewer
            # micro-batches instead of feeding empty tensors to the model.
            micro_inputs = input_ids.split(bs)
            micro_labels = labels.split(bs)
            n_micro = len(micro_inputs)

            loss = 0.0
            for mb_inputs, mb_labels in zip(micro_inputs, micro_labels):
                with amp.autocast(enabled=train_config["amp"],
                                  dtype=torch.float16):
                    logits, _ = model(mb_inputs,
                                      hypernetwork=hypernetwork,
                                      act_ck=True)
                    logits = logits.reshape(-1, logits.shape[-1])
                    flat_labels = mb_labels.reshape(-1)
                    # Targets equal to 50256 (GPT-2 <|endoftext|>) become -100,
                    # F.cross_entropy's default ignore_index. masked_fill
                    # returns a copy, so the loaded batch is not mutated.
                    gas_labels = flat_labels.masked_fill(flat_labels == 50256,
                                                         -100)
                    gas_loss = F.cross_entropy(logits, gas_labels)

                # Scale each micro-batch so the accumulated gradient is the
                # mean over the step rather than the sum. Otherwise its
                # magnitude grows with --gas and interacts with clipping.
                step_loss = gas_loss / n_micro
                if train_config["loss_scale"]:
                    scaler.scale(step_loss).backward()
                else:
                    step_loss.backward()

                loss += gas_loss.item()

            loss = loss / n_micro
            if train_config["loss_scale"]:
                scaler.unscale_(opt.optimizer)
            torch.nn.utils.clip_grad_norm_(hypernetwork.parameters(), 1)
            if train_config["loss_scale"]:
                opt.step(scaler=scaler)
                scaler.update()
            else:
                opt.step()

            opt.zero_grad()
            sec_per_step = (time.perf_counter() - timex)
            step_per_sec = (1. / sec_per_step)
            tokens_per_sec = step_per_sec * tokens_per_step
            curr_tokens = tokens_per_step * (curr_step + 1)
            t.set_description(f"{step_per_sec:.2f} steps/s, "
                              f"{sec_per_step:.2f}s/step, "
                              f"{tokens_per_sec:.2f}tokens/s, "
                              f"loss={loss:.4f}, "
                              f"{curr_tokens} tokens processed")
            wandb.log(
                {
                    "train/epoch": float(curr_step) / float(epoch_steps),
                    "train/loss": loss,
                    "train/tokens_per_sec": tokens_per_sec,
                    "train/sec_per_step": sec_per_step,
                    "train/step_per_sec": step_per_sec,
                    "train/lr": opt.curr_lr,
                    "train/loss_scale": scaler.get_scale(),
                    "train/tokens": curr_tokens,
                },
                step=curr_step)

            # do_save is False when save_every == 0, which also guards the modulo.
            if train_config["do_save"] and \
                    curr_step % train_config["save_every"] == 0 and \
                    curr_step != 0:
                hypernetwork_saver(f"step_{curr_step}")
                print(f"\nSaved model at step {curr_step}")

            # eval_every == 0 disables periodic evals instead of dividing by zero.
            if train_config["eval_every"] and \
                    curr_step % train_config["eval_every"] == 0 and \
                    curr_step != 0:
                eval_fn(curr_step)

            curr_step += 1
            t.update(1)
        if train_config["epochs"] > 1:
            hypernetwork_saver(f"epoch-{epoch}")

eval_fn(curr_step)
hypernetwork_saver("final")