import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.cuda.amp as amp
import torch.optim as optim
from pathlib import Path
from torch.utils import data
from basedformer import optimizer, utils, lm_utils
import yaml
import sys
from tqdm import tqdm
import time
import wandb
import numpy as np
import os
from icecream import ic
import torch.distributed as dist
import torch.multiprocessing as mp
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data.distributed import DistributedSampler
from dotmap import DotMap
import argparse
from torch.distributed.fsdp import (
   FullyShardedDataParallel,
   CPUOffload,
)
from torch.distributed.fsdp.wrap import (
   default_auto_wrap_policy,
)

def setup(rank, world_size):
    """Create the default NCCL process group and print whether it succeeded.

    ``rank`` and ``world_size`` are accepted for call-site compatibility but
    are not forwarded: ``init_process_group`` is called with only a backend,
    so rendezvous details come from its default ``env://`` initialisation
    (MASTER_ADDR / MASTER_PORT / RANK / WORLD_SIZE set by the launcher).
    """
    dist.init_process_group(backend="nccl")
    outcome = "Initialized" if dist.is_initialized() else "Failed to initialize"
    print(f"{outcome} process group")

def cleanup():
    """Destroy the default process group (the one created by ``setup``)."""
    dist.destroy_process_group(group=None)

def get_rank():
    """Return this process's global rank.

    Returns:
        ``dist.get_rank()`` when a process group is initialised, otherwise 0
        (the rank of a lone, non-distributed process) so callers can compare
        or do arithmetic with the result unconditionally.
    """
    # is_available() guards builds compiled without torch.distributed support.
    if dist.is_available() and dist.is_initialized():
        return dist.get_rank()
    return 0

def get_world():
    """Return the number of processes in the job.

    Returns:
        ``dist.get_world_size()`` when a process group is initialised,
        otherwise 1 (a single non-distributed process), so expressions such
        as ``bs * gas * world_size`` never see ``None``.
    """
    # is_available() guards builds compiled without torch.distributed support.
    if dist.is_available() and dist.is_initialized():
        return dist.get_world_size()
    return 1

def get_flops(args, model, iter_time_s, seq_len=2048):
    """Estimate achieved training throughput for one optimizer step, in TFLOP/s.

    Args:
        args: config exposing ``bs`` (micro-batch size) and ``gas``
            (gradient-accumulation steps) as attributes.
        model: object exposing ``total_params`` plus ``config.hidden_dim``
            and ``config.n_layer``.
        iter_time_s: wall-clock seconds the step took.
        seq_len: tokens per sequence; defaults to the 2048-token context
            used throughout this script.

    Returns:
        Estimated FLOPs per second divided by 1e12.
    """
    # 6 FLOPs per parameter per token — presumably the usual forward+backward
    # dense-transformer approximation.
    ff = model.total_params * 6
    # Attention cost grows with sequence length; the factor 60 is kept from
    # the original estimate. NOTE(review): confirm how 60 was derived.
    attn = seq_len * model.config.hidden_dim * model.config.n_layer * 60
    tokens_per_step = args.bs * args.gas * seq_len
    return tokens_per_step * (ff + attn) / iter_time_s / 1e12

def _contrastive_loss(hidden_states):
    """Mean absolute pairwise cosine similarity between hidden-state vectors.

    The states are first divided by their global max-abs value (computed
    without gradient) before norms/dot products are taken — presumably to keep
    fp16 matmuls from overflowing. Assumes the last two dims are
    (tokens, hidden) — TODO confirm against the model's output shape.
    """
    with torch.no_grad():
        peak = hidden_states.abs().amax()
    hs = hidden_states.div(peak)
    norm = hs.norm(dim=-1, keepdim=True)
    # Outer product of per-token norms -> denominator of the cosine matrix.
    norm = norm.matmul(norm.transpose(-1, -2))
    return torch.matmul(hs, hs.transpose(-2, -1)).div(norm).abs().mean()


def fsdp_train(args, model, train_loader, opt):
    """Run one pass over ``train_loader`` with gradient accumulation.

    Each loader batch is split into ``gas`` micro-batches of ``bs`` rows
    (truncated to 2048 tokens); each micro-batch is forwarded/backwarded
    under autocast, then one clipped optimizer step is taken. The step loss is
    summed across ranks with ``all_reduce`` and logged (averaged) to wandb
    from global rank 0; a checkpoint is written every ``save_every`` steps.

    Args:
        args: DotMap-style config; reads ``bs``, ``gas``, ``amp``, ``cast_to``,
            ``loss_scale``, ``contrastive_loss``, ``save_every``, ``save_path``.
        model: wrapped model (DDP in ``main``) exposing ``.module``; its
            forward returns ``(logits, hidden_states)``.
        train_loader: iterable of ``(input_ids, labels)`` tensor pairs.
        opt: optimizer wrapper exposing ``step``, ``optimizer`` and ``curr_lr``.
    """
    from contextlib import nullcontext

    bs = args["bs"]
    gas = args["gas"]
    use_scaler = bool(args["loss_scale"])
    global_rank = get_rank()
    rank = int(os.environ["LOCAL_RANK"])
    world_size = get_world()
    model.train()
    ddp_loss = torch.zeros(1).cuda()
    t = tqdm(train_loader) if rank == 0 else train_loader

    # A disabled scaler turns scale()/unscale_()/update() into no-ops and
    # reports a scale of 1.0, so the logged loss_scale is meaningful either way.
    scaler = torch.cuda.amp.GradScaler(enabled=use_scaler)
    can_skip_sync = hasattr(model, "no_sync")
    counter = 0
    for input_ids, labels in t:
        timex = time.perf_counter()
        input_ids = input_ids.to(rank)
        labels = labels.to(rank)
        loss = 0.0
        for x in range(gas):
            rows = slice(x * bs, (x + 1) * bs)
            # Only the last micro-batch all-reduces gradients; earlier ones
            # accumulate locally. Forward must run inside no_sync as well.
            is_last = x == gas - 1
            sync_ctx = model.no_sync() if (can_skip_sync and not is_last) else nullcontext()
            with sync_ctx:
                with torch.cuda.amp.autocast(enabled=args["amp"], dtype=args["cast_to"]):
                    logits, hidden_states = model(input_ids[rows, :2048], act_ck=True)
                    logits = logits.view(-1, logits.shape[-1])
                    gas_labels = labels[rows, :2048].contiguous().view(-1)
                    gas_loss = F.cross_entropy(logits, gas_labels)
                    if args.contrastive_loss:
                        gas_loss = gas_loss + _contrastive_loss(hidden_states) * args.contrastive_loss
                # Divide by gas so the accumulated gradient is the mean over the
                # whole bs*gas batch, consistent with the averaged logged loss.
                scaler.scale(gas_loss / gas).backward()
            loss += gas_loss.item()

        loss /= gas
        # Unscale before clipping so the norm threshold applies to true grads.
        scaler.unscale_(opt.optimizer)
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1)
        if use_scaler:
            opt.step(scaler=scaler)
        else:
            opt.step()
        scaler.update()
        model.zero_grad(set_to_none=True)

        sec_per_step = time.perf_counter() - timex
        flops = get_flops(args, model.module, sec_per_step)
        step_per_sec = 1.0 / sec_per_step
        batch_size = bs * gas * world_size
        tokens_per_sec = step_per_sec * 2048 * batch_size
        ddp_loss[0] = loss
        dist.all_reduce(ddp_loss, op=dist.ReduceOp.SUM)
        if global_rank == 0:
            wandb.log({
                "train_loss": ddp_loss[0].item() / world_size,
                "train/tokens_per_sec": tokens_per_sec,
                "train/sec_per_step": sec_per_step,
                "train/step_per_sec": step_per_sec,
                "train/lr": opt.curr_lr,
                "train/batch_size": batch_size,
                "train/loss_scale": scaler.get_scale(),
                "train/flops": flops,
            })

        if counter != 0 and counter % args["save_every"] == 0:
            # Single writer; other ranks wait so nobody races ahead of the save.
            if global_rank == 0:
                lm_utils.save(model.module, Path(args["save_path"]) / f"step_{counter}")
            dist.barrier()

        counter += 1

# NOTE: training the small GPT needs a global batch size of about 250 sequences (global batch = bs * gas * world_size).
def main(rank, global_rank, world_size, args):
    """Set up DDP, load the model, train it, and write the final checkpoint.

    Args:
        rank: local rank, used as the CUDA device index.
        global_rank: rank across all nodes; selects the data shard and gates
            wandb logging and checkpoint writing.
        world_size: total number of processes.
        args: DotMap-style config (see the ``__main__`` block).
    """
    bs = args["bs"]
    gas = args["gas"]
    torch.manual_seed(args["seed"])
    setup(rank, world_size)
    Path(args["save_path"]).mkdir(parents=True, exist_ok=True)

    model = lm_utils.load_from_path("pretrained/gpt-j-base").float().to(rank)
    ddp_model = DDP(model, device_ids=[rank], output_device=rank, gradient_as_bucket_view=True)
    utils.print_parameters(ddp_model)

    ic("model loaded")
    opt = optimizer.BasedOptimizer(ddp_model.parameters(), args, "zero1")
    # TODO: add checkpoint loading and evals.
    print(opt.curr_step)
    # 2049 tokens per sample; presumably 2048 inputs + 1 shifted label — confirm.
    train_dataset = utils.ShardedDataset(2049, args["data_path"], world_size=world_size, rank=global_rank)
    # One loader batch = gas micro-batches of bs rows, split up in fsdp_train.
    train_loader = data.DataLoader(train_dataset, batch_size=bs * gas, shuffle=False, num_workers=0)
    if global_rank == 0:
        wandb.init(project="basedformer-tests", name=args["run_name"], config={**args, **model.config})
    fsdp_train(args, ddp_model, train_loader, opt)
    # Exactly one process writes the final checkpoint; previously every rank
    # wrote to the same path concurrently. Others wait at the barrier.
    if global_rank == 0:
        lm_utils.save(ddp_model.module, Path(args["save_path"]) / "final")
    dist.barrier()
    cleanup()

if __name__ == "__main__":
    # Run hyper-parameters. Wrapped in a DotMap below so downstream code can
    # use both cfg["key"] and cfg.key access.
    train_config = dict(
        data_path="dataset/sigurd-1G.map",
        save_path="models/gptj-sigurd-1G-vanilla",
        do_save=True,
        run_name="gptj-sigurd-1G-vanilla",
        lr=6e-5,
        end_lr=3e-5,
        warmup_steps=100,
        anneal_steps=7850,
        bs=2,
        gas=2,
        seed=69,
        save_every=500,
        amp=True,
        loss_scale=True,
        cast_to=torch.float16,
        contrastive_loss=False,
    )

    # Process topology comes from the launcher's environment (e.g. torchrun).
    world_size = int(os.environ["WORLD_SIZE"])
    local_rank = int(os.environ["LOCAL_RANK"])
    global_rank = int(os.environ["RANK"])
    # Bind this process to its local GPU before any CUDA/NCCL work starts.
    torch.cuda.set_device(local_rank)
    main(local_rank, global_rank, world_size, DotMap(train_config))