import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.cuda.amp as amp
import torch.optim as optim
from pathlib import Path
from lm_train import optimizer, utils
from torch.utils import data
from main import *
import yaml
import sys
from tqdm import tqdm
import time
import wandb
from lm_arch.gpt2 import GPT2Model

# Architecture hyperparameters. This dict is handed to GPT2Model.gpt2_init and
# is also merged into the config recorded by wandb.init further down.
model_config = {
    # Size settings.
    "n_layer": 12,
    "n_head": 12,
    "hidden_dim": 768,
    "vocab_dim": 50400,
    "eps": 1e-5,
    # NOTE(review): gelu_new and GPTLayer are neither defined nor explicitly
    # imported in this file; presumably they arrive through
    # `from main import *` -- confirm before touching that import.
    "activation": gelu_new,
    "Layer": GPTLayer,
}

# Training-run settings. Author's note: the small GPT needs a batch size of
# about 250; the values below give bs * gas = 12 * 10 = 120 sequences per
# optimizer step.
train_config = {
    # Dataset file and checkpoint directory.
    "data_path": "/home/xuser/diffusionstorage/datasets/OWT2-gpt2-full.map",
    "save_path": "/home/xuser/diffusionstorage/workspace/kuru/basedformer/models/owt2fp16amp2",
    # Run name reported to wandb.
    "run_name": "owt2-125m-fp16AMP-1024ctx-120bs-1e-4lr",
    # Learning-rate settings; presumably read by BasedOptimizer, which is
    # given this whole dict -- TODO confirm which keys it consumes.
    "lr": 1e-4,
    "end_lr": 1e-4,
    "warmup_steps": 100,
    # Micro-batch size and number of gradient-accumulation steps.
    "bs": 12,
    "gas": 10,
    # NOTE(review): "seed" is not read anywhere in the visible script.
    "seed": 69,
    # A checkpoint is written every `save_every` optimizer steps.
    "save_every": 500,
    # When True the training loop uses CUDA autocast and the GradScaler.
    "amp": True,
}
# Module-level shorthands used by the data loader and the training loop.
bs, gas = train_config["bs"], train_config["gas"]

# Make sure the checkpoint directory exists before training starts.
Path(train_config["save_path"]).mkdir(parents=True, exist_ok=True)

# Build the model on the GPU with fp32 parameters; the optimizer wrapper gets
# the full training config and the optimizer name.
model = GPT2Model.gpt2_init(model_config).cuda().float()
opt = optimizer.BasedOptimizer(model.parameters(), train_config, "adamw")

# TODO: Add load, add evals, and Data Parallel, outputting hidden states from the get_logits function.
# (FP16 AMP from the original TODO is already in place: the training loop uses
# autocast together with the GradScaler created at the end of this section.)

# NOTE(review): 2049 is presumably the per-sample token window -- confirm
# against utils.FbDataset.
train_dataset = utils.FbDataset(2049, train_config["data_path"])
# Every loader batch carries bs * gas sequences, which the training loop cuts
# into `gas` micro-batches of `bs` rows. drop_last=True discards a trailing
# partial batch; without it the final iteration would slice empty
# micro-batches out of a short batch and run the model / loss on them.
train_loader = data.DataLoader(
    train_dataset,
    batch_size=bs * gas,
    shuffle=False,
    num_workers=0,
    drop_last=True,
)
wandb.init(
    project="basedformer-tests",
    name=train_config["run_name"],
    config={**train_config, **model_config},
)

# Progress bar over the loader and the optimizer-step counter used for
# checkpoint scheduling.
t = tqdm(train_loader)
curr_step = 0

# Tie the scaler to the AMP switch: when AMP is off it becomes a pass-through
# and get_scale() reports 1.0 instead of an unused initial scale.
scaler = torch.cuda.amp.GradScaler(enabled=train_config["amp"])

# Training loop: one optimizer step per loader batch. Each batch is split into
# `gas` micro-batches of `bs` sequences whose gradients are accumulated before
# clipping and stepping.
use_amp = train_config["amp"]

for input_ids, labels in t:
    timex = time.perf_counter()
    input_ids = input_ids.cuda()
    labels = labels.cuda()
    loss = 0
    for x in range(gas):
        rows = slice(x * bs, (x + 1) * bs)
        # autocast(enabled=False) is a no-op, so a single code path serves
        # both the AMP and the full-precision configuration. Only the first
        # 1024 positions of every sequence are used.
        with torch.cuda.amp.autocast(enabled=use_amp):
            logits = model(input_ids[rows, :1024], hypernetwork=None, act_ck=False)
            logits = logits.view(-1, logits.shape[-1])
            gas_labels = labels[rows, :1024].contiguous().view(-1)
            gas_loss = F.cross_entropy(logits, gas_labels)

        # Divide by the number of accumulation steps so the accumulated
        # gradient is the mean over the whole batch rather than the sum;
        # otherwise the norm-1 clip below acts on a gradient `gas` times
        # larger than the one the averaged loss describes.
        micro_loss = gas_loss / gas
        if use_amp:
            scaler.scale(micro_loss).backward()
        else:
            micro_loss.backward()

        # Accumulate the unscaled value for logging.
        loss += gas_loss.item()

    # Mean micro-batch loss for this optimizer step.
    loss = loss / gas

    # Gradients must be unscaled before clipping so the threshold applies to
    # their true magnitude.
    if use_amp:
        scaler.unscale_(opt.optimizer)
    torch.nn.utils.clip_grad_norm_(model.parameters(), 1)

    if use_amp:
        opt.step(scaler=scaler)
        scaler.update()
    else:
        opt.step()
    opt.zero_grad()

    # NOTE: elapsed time is divided by bs * gas, so the "step" metrics below
    # are really per sequence; tokens/s assumes 1024 tokens per sequence.
    sec_per_step = (time.perf_counter() - timex) / (bs * gas)
    step_per_sec = 1. / sec_per_step
    tokens_per_sec = step_per_sec * 1024
    t.set_description(f"{step_per_sec:.2f} steps/s, {sec_per_step:.2f}s/step, {tokens_per_sec:.2f}tokens/s, loss={loss:.4f}")
    wandb.log({
        "train/loss": loss,
        "train/tokens_per_sec": tokens_per_sec,
        "train/sec_per_step": sec_per_step,
        "train/step_per_sec": step_per_sec,
        "train/lr": opt.curr_lr,
        "train/loss_scale": scaler.get_scale(),
    })

    # Periodic checkpoint, named after the optimizer-step count.
    curr_step += 1
    if curr_step % train_config["save_every"] == 0:
        model.save(train_config["save_path"] + f"/{curr_step}")
        print(f"Saved model at step {curr_step}")
