import torch
import transformers
import sys
from icecream import ic
import os
from pathlib import Path
"""
Original:

ln_f.weight
ln_f.bias
wte.weight
wpe.weight
h.0.ln_1.weight
h.0.ln_1.bias
h.0.attn.bias
h.0.attn.c_attn.weight
h.0.attn.c_attn.bias
h.0.attn.c_proj.weight
h.0.attn.c_proj.bias
h.0.ln_2.weight
h.0.ln_2.bias
h.0.mlp.c_fc.weight
h.0.mlp.c_fc.bias
h.0.mlp.c_proj.weight
h.0.mlp.c_proj.bias

Attention has biases, unlike GPT-J. The QKV matrices are also merged instead of separate. What is the order, though? Probably just Q, K, V.

"""
# Load the source checkpoint. map_location="cpu" keeps the conversion usable
# on hosts that lack the device the checkpoint was saved from.
x = torch.load("pretrained/gpt-neo-125m/pytorch_model.bin", map_location="cpu")
state_dict = x
ic(x.keys())

new_state_dict = {}
# Source module-name fragment -> module name used in the ported checkpoint.
# Fragments are matched as substrings of the full source key, so source
# parameters that match no fragment are not copied.
module_map = {
    "ln_1": "ln_preattn",
    "ln_2": "ln_postattn",
    "mlp.c_proj": "ff.ff2",
    "mlp.c_fc": "ff.ff1",
    "attn.attention.q_proj": "attn.q_proj",
    "attn.attention.k_proj": "attn.k_proj",
    "attn.attention.v_proj": "attn.v_proj",
    "attn.attention.out_proj": "attn.out_proj",
    "wte": "vocab_embed",
    "wpe": "pos_embed",
    'ln_f': 'ln_final',
}

print(type(state_dict))
for key, tensor in state_dict.items():
    dotlist = key.split('.')
    # Keys with more than three dot-separated parts are treated as per-layer
    # parameters whose layer index is the third part (presumably keys of the
    # form "transformer.h.<n>.<module>.<param>" -- confirm against the
    # checkpoint). Everything else is mapped without a layer prefix.
    prefix = f"layers.{dotlist[2]}." if len(dotlist) > 3 else ""
    for fragment, target in module_map.items():
        if fragment in key:
            # Build the destination name once so the stored key and the log
            # line cannot drift apart.
            new_key = f"{prefix}{target}.{dotlist[-1]}"
            new_state_dict[new_key] = tensor
            print(f"{key} -> {new_key}")

# The output head reuses the token-embedding tensor from the source checkpoint.
new_state_dict["lm_head.weight"] = state_dict["transformer.wte.weight"]

# Summarize the ported parameters and their shapes.
for k, v in new_state_dict.items():
    print(f"{k} -> {v.shape}")


def save(state_dict, path):
    """Write every entry of ``state_dict`` to its own file under ``path``.

    The directory (and any missing parents) is created first. The value at
    position ``i`` of ``state_dict`` is stored with ``torch.save`` as
    ``b{i}.pt``; an index mapping each key to the path string of its file is
    then stored as ``m.pt`` in the same directory.

    Args:
        state_dict: Mapping from parameter name to the object to serialize.
        path: Destination directory, as a string or path-like object.
    """
    out_dir = Path(path)
    out_dir.mkdir(parents=True, exist_ok=True)

    index = {}
    for position, (name, value) in enumerate(state_dict.items()):
        shard_file = f"{out_dir}/b{position}.pt"
        index[name] = shard_file
        torch.save(value, shard_file)

    # The index is written last, after all per-entry files.
    torch.save(index, f"{out_dir}/m.pt")

# Write the remapped parameters to the ported checkpoint directory: one file
# per entry plus an index file (see save()).
save(new_state_dict, "pretrained/gpt-neo-125m-ported/lm")
