import os
import sys

import torch
import transformers
"""
Original GPT-2 state-dict keys (per-layer keys shown for layer 0 only):

ln_f.weight
ln_f.bias
wte.weight
wpe.weight
h.0.ln_1.weight
h.0.ln_1.bias
h.0.attn.bias
h.0.attn.c_attn.weight
h.0.attn.c_attn.bias
h.0.attn.c_proj.weight
h.0.attn.c_proj.bias
h.0.ln_2.weight
h.0.ln_2.bias
h.0.mlp.c_fc.weight
h.0.mlp.c_fc.bias
h.0.mlp.c_proj.weight
h.0.mlp.c_proj.bias

Unlike GPT-J, the attention projections have biases. The Q, K and V matrices are also merged into a single tensor (c_attn) instead of being stored separately — but in what order?

"""
# Checkpoint this script converts.
SOURCE_CHECKPOINT = "models/gpt2_vanilla/pytorch_model.bin"
# While True, the script only inspects the merged QKV tensor and exits before
# converting anything (the conversion below is still work in progress).
INSPECT_ONLY = True

# map_location="cpu" lets tensors that were saved from a GPU load on a CPU-only machine.
state_dict = torch.load(SOURCE_CHECKPOINT, map_location="cpu")

if INSPECT_ONLY:
    c_attn = state_dict["h.0.attn.c_attn.weight"]
    # NOTE(review): in transformers' GPT-2, c_attn is a Conv1D whose weight is
    # (n_embd, 3 * n_embd) and whose output is split into q, k, v along the last
    # dim, so q/k/v are the three column blocks. A row-major reshape(-1, 768, 768)
    # would instead slice groups of input rows and mix all three; chunking along
    # dim 1 also avoids hard-coding n_embd (confirm against GPT2Attention).
    q_k_v = c_attn.chunk(3, dim=1)
    print(torch.stack(q_k_v).shape)
    sys.exit(0)

new_state_dict = {}
# Substring of an original key -> module path in the converted checkpoint.
# NOTE(review): the "attn.attention.*" entries use GPT-J-style names; none of the
# GPT-2 keys listed in the notes above contain them, so the attention weights
# (merged c_attn and attn.c_proj) are currently not converted. ln_2 and wpe have
# no entry either. GPT-2's Conv1D weights are also stored transposed relative to
# nn.Linear (per transformers' Conv1D — verify); if the target modules are
# nn.Linear, those weights will need `.t()`.
module_map = {
                "ln_1": "ln_preattn",
                "mlp.c_proj": "ff.ff2",
                "mlp.c_fc": "ff.ff1",
                "attn.attention.out_proj": "attn.out_proj",
                "attn.attention.k_proj": "attn.k_proj",
                "attn.attention.v_proj": "attn.v_proj",
                "attn.attention.q_proj": "attn.q_proj",
                "wte": "vocab_embed",
                'ln_f': 'ln_final',
                'lm_head': 'lm_head',
                }

print(type(state_dict))
unmapped = []
for key, tensor in state_dict.items():
    parts = key.split('.')
    # Accept both "h.<i>...." (as in the notes) and "transformer.h.<i>...." keys.
    if parts[0] == "transformer":
        parts = parts[1:]
    # Per-layer keys carry the layer index immediately after "h".
    prefix = f"layers.{parts[1]}." if parts[0] == "h" else ""
    matched = False
    for pattern, target in module_map.items():
        if pattern in key:
            new_key = f"{prefix}{target}.{parts[-1]}"
            new_state_dict[new_key] = tensor
            print(f"{key} -> {new_key}")
            matched = True
    if not matched:
        unmapped.append(key)

# Report what was dropped instead of silently omitting it.
for key in unmapped:
    print(f"unmapped (skipped): {key}")

def save(state_dict, path):
    """Write ``state_dict`` as one file per entry plus an index file under ``path``.

    Each value is written with ``torch.save`` to ``"{path}/b{i}.pt"``, where ``i``
    is the entry's position in ``state_dict``'s iteration order. Then
    ``"{path}/m.pt"`` receives a dict mapping every key to the file holding its
    value. Existing files with those names are overwritten.

    Args:
        state_dict: mapping of parameter name -> object to save.
        path: output directory; it and any missing parents are created if absent.

    Raises:
        OSError: if the directory cannot be created (e.g. permissions, or ``path``
            exists as a regular file).
    """
    # exist_ok tolerates an existing directory, but real failures still surface
    # instead of being swallowed and resurfacing later as a failed torch.save.
    os.makedirs(path, exist_ok=True)
    checkpoint = {}
    for i, (name, value) in enumerate(state_dict.items()):
        shard_path = f"{path}/b{i}.pt"
        checkpoint[name] = shard_path
        torch.save(value, shard_path)
    torch.save(checkpoint, f"{path}/m.pt")

# NOTE(review): this script reads models/gpt2_vanilla but writes to a
# "6b_vanilla" directory, a name that looks carried over from a GPT-J 6B
# conversion. save() overwrites b{i}.pt and m.pt in place, so running this could
# clobber an existing checkpoint there; confirm the intended output directory.
OUTPUT_DIR = "models/6b_vanilla"
save(new_state_dict, OUTPUT_DIR)