import json
import torch
from fairseq.models.transformer_lm import TransformerLanguageModel
import sys
import os
from pathlib import Path

# Switches read by hack_embs(): each one enables one direction of the
# <|endoftext|>/newline embedding-row copy.
copy_eot_to_newline = copy_newline_to_eot = True

# Path to a small pretrained model; it is loaded only so that its dictionary
# can be used to undo the fairseq token-order shuffle.
model_dir = 'pretrained/en_dense_lm_125m'

# checkpoint: not used in the code visible here.
# ckmap: exported tensor name -> shard file name; filled in by save().
checkpoint, ckmap = {}, {}

# Running shard counter; save() uses it to build the "b<N>.pt" file names.
ckid = 0

def save(params, name):
    """Store one tensor as its own shard and rewrite the shard index.

    The tensor is written to ``<sys.argv[2]>/lm/b<ckid>.pt`` and the global
    counter ``ckid`` is advanced.  ``ckmap`` (exported name -> shard file
    name) is re-saved to ``m.pt`` in the same directory on every call.

    Args:
        params: tensor to write; its ``shape`` is printed after saving.
        name: key under which the shard is recorded in ``ckmap``.
    """
    global ckid
    shard_name = f"b{ckid}.pt"
    ckmap[name] = shard_name
    ckid += 1

    out_dir = Path(f"{sys.argv[2]}/lm")
    out_dir.mkdir(parents=True, exist_ok=True)

    torch.save(params, out_dir / shard_name)
    # The index is rewritten after each shard, not once at the end.
    torch.save(ckmap, out_dir / "m.pt")
    print(name + ": " + str(params.shape))

def no_init(loading_code):
    """Run ``loading_code`` with parameter initialisation turned off.

    While ``loading_code`` runs, ``reset_parameters`` on ``torch.nn.Linear``,
    ``torch.nn.Embedding`` and ``torch.nn.LayerNorm`` is replaced by a no-op.
    The original methods are put back afterwards, including when
    ``loading_code`` raises, so the torch classes are never left patched.

    Args:
        loading_code: zero-argument callable to execute.

    Returns:
        Whatever ``loading_code()`` returns.

    Raises:
        Any exception raised by ``loading_code`` is propagated unchanged.
    """
    def dummy(self):
        return

    modules = [torch.nn.Linear, torch.nn.Embedding, torch.nn.LayerNorm]
    original = {mod: mod.reset_parameters for mod in modules}
    for mod in modules:
        mod.reset_parameters = dummy

    try:
        return loading_code()
    finally:
        # Always undo the monkey-patch, even if loading failed part-way.
        for mod in modules:
            mod.reset_parameters = original[mod]

# Reference model loaded from model_dir; only its dictionary (lm.tgt_dict) is
# used further down.  no_init() skips the parameter initialisation during
# construction.
lm = no_init(lambda: TransformerLanguageModel.from_pretrained(model_dir, bpe='gpt2').eval().cpu())

# Checkpoint to convert (first CLI argument).
# NOTE(review): torch.load unpickles the file -- only run this on checkpoints
# from a trusted source.
fairdict = torch.load(f"{sys.argv[1]}", map_location="cpu")

# Output directory (second CLI argument).  Parents are created as needed and an
# existing directory is fine; real failures (e.g. permissions) are reported
# here instead of being silently ignored.
os.makedirs(sys.argv[2], exist_ok=True)

# Architecture hyper-parameters recorded in the checkpoint's config.
hidden_dim = fairdict["cfg"]["model"]["decoder_embed_dim"]
num_heads = fairdict["cfg"]["model"]["decoder_attention_heads"]
num_layers = fairdict["cfg"]["model"]["decoder_layers"]

# From here on only the state dict itself is needed.
fairdict = fairdict["model"]

# Settings written next to the exported shards; the architecture sizes come
# from the source checkpoint, the remaining values are fixed.
model_config = {
    "n_layer": num_layers,
    "n_head": num_heads,
    "hidden_dim": hidden_dim,
    "vocab_dim": 51200,
    "eps": 1e-05,
    "n_tokens": 2049,
}
config = {
    "model_class": "gpt-fairseq",
    "model_path": ".",
    "model_config": model_config,
}

with open(f"{sys.argv[2]}/config.json", "w") as fh:
    json.dump(config, fh)

#print(lm)

def hack_embs(embs):
    """Copy the rows at indices 50256 and 198 of ``embs`` onto each other, in place.

    Each direction is controlled by a module-level flag:
    ``copy_eot_to_newline`` writes row 50256 into row 198, and
    ``copy_newline_to_eot`` writes row 198 into row 50256.  Both rows are
    snapshotted first, so enabling both flags swaps them.

    Args:
        embs: indexable tensor with rows at indices 198 and 50256; modified
            in place.
    """
    eot_id, newline_id = 50256, 198

    # Clone before writing so the second copy still sees the original row.
    eot_row = embs[eot_id].clone()
    newline_row = embs[newline_id].clone()

    if copy_eot_to_newline:
        embs[newline_id] = eot_row
    if copy_newline_to_eot:
        embs[eot_id] = newline_row

# gpt2 compatible input/output embedding layers
# l1[k] is a row index in the fairseq dictionary and l2[k] is the GPT-2 token
# id that row is moved to.
l1 = []
l2 = []

# GPT-2 ids (0..50255) that have not been found in the fairseq dictionary yet.
check = {i: True for i in range(50256)}

for i, s in enumerate(lm.tgt_dict.symbols):
    # Only symbols that are the canonical decimal spelling of an integer are
    # token ids; anything else is skipped.
    try:
        token_id = int(s)
    except (TypeError, ValueError):
        continue
    # '50256' is excluded here; id 50256 is assigned to the eos word below.
    if str(token_id) == s and s != '50256':
        l2.append(token_id)
        l1.append(i)
        # Tolerate ids outside 0..50255 or seen twice, as before.
        check.pop(token_id, None)

# The four special symbols take ids 50256..50259, in this order.
for i, s in enumerate([lm.tgt_dict.eos_word, lm.tgt_dict.pad_word, lm.tgt_dict.bos_word, lm.tgt_dict.unk_word]):
    l2.append(50256 + i)
    l1.append(lm.tgt_dict.indices[s])

# fairseq row index -> GPT-2 token id, built from the first 50260 collected pairs.
mapping = {l1[i]: l2[i] for i in range(50260)}


def _remap_rows(weight):
    """Return a clone of ``weight`` in which row ``mapping[i]`` holds original row ``i``."""
    remapped = weight.clone()
    for src_row in range(50260):
        remapped[mapping[src_row]] = weight[src_row]
    return remapped


with torch.no_grad():
    # Both matrices get the same row permutation and are stored in half
    # precision.  hack_embs() is deliberately not applied to either of them.
    wte = _remap_rows(fairdict["decoder.embed_tokens.weight"])
    save(wte.half(), "vocab_embed.weight")

    lm_head = _remap_rows(fairdict["decoder.output_projection.weight"])
    save(lm_head.half(), "lm_head.weight")

# One-element float32 placeholder stored under "pos_embed._float_tensor".
# An explicit zero is used instead of torch.FloatTensor(1), whose contents are
# uninitialised memory, so the exported shard is reproducible.
# NOTE(review): presumably the consumer only needs this entry to exist with
# the right shape and dtype -- confirm against the loader.
save(torch.zeros(1, dtype=torch.float32), "pos_embed._float_tensor")

# Not populated anywhere in this section.
new_state_dict = {}

# Translate every remaining fairseq state-dict entry to its exported name and
# save it in half precision.
for y in fairdict:
    dotlist = y.split(".")

    # Start from "not exported" for every key.  Without this reset, a key that
    # matches no rule would reuse the previous key's target name and overwrite
    # that tensor (or hit a NameError if it came first).
    trans_to = "Passed"

    if y in ("decoder.embed_tokens.weight", "decoder.output_projection.weight"):
        # Saved earlier with remapped rows as vocab_embed.weight / lm_head.weight.
        continue

    if len(dotlist) >= 4 and dotlist[1] == "layers":
        # Keys of the form decoder.layers.<layer_id>.<submodule>...<param>.
        layer_id = dotlist[2]

        if dotlist[-2] in ["k_proj", "v_proj", "q_proj", "out_proj"]:
            trans_to = f"layers.{layer_id}.attn.{dotlist[-2]}.{dotlist[-1]}"

        if dotlist[-2] == "self_attn_layer_norm":
            trans_to = f"layers.{layer_id}.ln_preattn.{dotlist[-1]}"

        if dotlist[3] == "fc1":
            trans_to = f"layers.{layer_id}.ff.ff1.{dotlist[-1]}"

        if dotlist[3] == "fc2":
            trans_to = f"layers.{layer_id}.ff.ff2.{dotlist[-1]}"

        if dotlist[3] == "final_layer_norm":
            trans_to = f"layers.{layer_id}.ln_postattn.{dotlist[-1]}"

    elif len(dotlist) >= 2 and dotlist[1] == "layer_norm":
        trans_to = f"ln_final.{dotlist[-1]}"

    # Everything else (e.g. "decoder.version") stays "Passed" and is only logged.
    if trans_to != "Passed":
        save(fairdict[y].half(), trans_to)
    print(f"{trans_to} < {y}")