import torch
import transformers
import sys
from icecream import ic
import os
"""
Original:

ln_f.weight
ln_f.bias
wte.weight
wpe.weight
h.0.ln_1.weight
h.0.ln_1.bias
h.0.attn.bias
h.0.attn.c_attn.weight
h.0.attn.c_attn.bias
h.0.attn.c_proj.weight
h.0.attn.c_proj.bias
h.0.ln_2.weight
h.0.ln_2.bias
h.0.mlp.c_fc.weight
h.0.mlp.c_fc.bias
h.0.mlp.c_proj.weight
h.0.mlp.c_proj.bias

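attn has biases, unlike GPT-J. The Q, K and V matrices are also fused into a single c_attn parameter instead of being separate. c_attn is a Conv1D whose weight is stored as [in_features, 3 * hidden], so Q, K and V (in that order) are consecutive hidden-sized column blocks along the last dimension; c_attn.bias is the concatenation [q_bias, k_bias, v_bias].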

"""
# Load the original GPT-2 checkpoint; it is indexed below as a flat
# "parameter name -> tensor" mapping.
x = torch.load("models/gpt2_vanilla/pytorch_model.bin")
state_dict = x

# Debug output for the first layer's attention parameters.
# The hidden size is read from the tensor instead of being hard-coded to 768,
# so this sanity check also works for checkpoints with another width.
# assumes dim 0 of c_attn.weight is the hidden size - TODO confirm for
# checkpoints that are not stock GPT-2.
hidden_dim = x["h.0.attn.c_attn.weight"].shape[0]
# Shape-only check: the fused weight must hold a whole number of
# hidden x hidden matrices (expected: 3, one each for Q, K and V).
print(x["h.0.attn.c_attn.weight"].reshape(-1, hidden_dim, hidden_dim).shape)
ic(x["h.0.attn.c_attn.weight"].shape)
ic(x["h.0.attn.c_attn.bias"].shape)
ic(x["h.0.attn.c_proj.weight"].shape)
ic(x["h.0.attn.c_proj.bias"].shape)

# Converted parameters are collected here, keyed by their new names.
new_state_dict = {}

# GPT-2 module-path fragment -> module name used by the target model.
# Matching is done by substring against the full parameter key.
module_map = {
    # Layer norms inside each transformer block.
    "ln_1": "ln_preattn",
    "ln_2": "ln_postattn",
    # Feed-forward network.
    "mlp.c_proj": "ff.ff2",
    "mlp.c_fc": "ff.ff1",
    # Attention. The value for "attn.c_attn" is never written out: the
    # conversion loop special-cases that key and emits q_proj/k_proj/v_proj.
    "attn.c_proj": "attn.out_proj",
    "attn.c_attn": "attn.k_proj",
    # Embeddings and model-level modules.
    "wte": "vocab_embed",
    "wpe": "pos_embed",
    "ln_f": "ln_final",
    "lm_head": "lm_head",
}

# Rename every GPT-2 parameter into the target model's naming scheme and fix
# up tensor layouts where the two models disagree.
print(type(state_dict))
qkv_names = ("q_proj", "k_proj", "v_proj")
for key in state_dict.keys():
    dotlist = key.split('.')
    param = dotlist[-1]  # trailing component, e.g. "weight" or "bias"

    if len(dotlist) <= 3:
        # Parameters outside the transformer blocks (e.g. "wte.weight",
        # "ln_f.bias") are copied unchanged under their new module name.
        for old_name in module_map:
            if old_name in key:
                new_state_dict[f"{module_map[old_name]}.{param}"] = state_dict[key]
                print(f"{key} -> {module_map[old_name]}.{param}")
        continue

    # Per-layer parameters look like "h.<layer>.<module path>.<weight|bias>".
    # Keys that match no module_map entry (e.g. "h.0.attn.bias") are dropped.
    layer = dotlist[1]
    tensor = state_dict[key]
    for old_name in module_map:
        if old_name not in key:
            continue

        if old_name == "attn.c_attn":
            # Fused QKV projection. GPT-2's Conv1D stores the weight as
            # [in_features, 3 * hidden] with Q, K, V as consecutive column
            # blocks, so it must be split along dim 1. Reshaping to
            # (3, hidden, hidden) would instead cut the matrix along its rows
            # and scramble the three projections.
            if param == "weight":
                hidden_dim = tensor.shape[0]
                for proj, block in zip(qkv_names, tensor.split(hidden_dim, dim=1)):
                    # Transpose to [out, in]; .contiguous() gives each piece
                    # its own storage so torch.save does not serialise the
                    # whole fused tensor once per projection.
                    new_state_dict[f"layers.{layer}.attn.{proj}.weight"] = block.t().contiguous()
            elif param == "bias":
                hidden_dim = tensor.shape[0] // 3
                for proj, block in zip(qkv_names, tensor.split(hidden_dim)):
                    # clone() detaches the slice from the fused bias storage.
                    new_state_dict[f"layers.{layer}.attn.{proj}.bias"] = block.clone()
            continue

        new_key = f"layers.{layer}.{module_map[old_name]}.{param}"
        if len(tensor.shape) == 2:
            # 2-D per-layer weights are stored [in, out]; flip to [out, in].
            ic("transpose!")
            new_state_dict[new_key] = tensor.transpose(-1, -2)
        else:
            new_state_dict[new_key] = tensor
        print(f"{key} -> {new_key}")

# Summary of everything that will be written out.
for k, v in new_state_dict.items():
    print(f"{k} -> {v.shape}")


def save(state_dict, path):
    """Write each entry of *state_dict* to its own file under *path*.

    Entry number ``i`` (in iteration order) is saved to ``{path}/b{i}.pt``.
    An index mapping every parameter name to the file holding it is saved
    to ``{path}/m.pt``.

    Args:
        state_dict: Mapping of parameter name to the object to serialise.
        path: Output directory; created together with any missing parents.

    Raises:
        OSError: If the directory cannot be created or a file cannot be
            written.
    """
    # makedirs handles nested paths and an already-existing directory, and
    # reports real failures instead of silently swallowing them.
    os.makedirs(path, exist_ok=True)
    checkpoint = {}
    for i, (name, tensor) in enumerate(state_dict.items()):
        shard_path = f"{path}/b{i}.pt"
        checkpoint[name] = shard_path
        torch.save(tensor, shard_path)
    torch.save(checkpoint, f"{path}/m.pt")

# Write the converted parameters, one file per tensor plus an index file,
# into the output directory.
save(new_state_dict, "pretrained/gpt2/lm")
