from main import *

# Load the source checkpoint onto the CPU.
state_dict = SplitCheckpoint(
    "j6b_vanilla",
    device="cpu",
)


# ORIGINAL parameter names in the source checkpoint (per-layer names are
# shown for one example layer, transformer.h.9):
#
#   transformer.ln_f.weight
#   transformer.ln_f.bias
#   lm_head.weight
#   lm_head.bias
#   transformer.h.9.ln_1.weight
#   transformer.h.9.ln_1.bias
#   transformer.h.9.mlp.c_proj.weight
#   transformer.h.9.mlp.c_proj.bias
#   transformer.h.9.mlp.c_fc.weight
#   transformer.h.9.mlp.c_fc.bias
#   transformer.h.9.attn.attention.out_proj.weight
#   transformer.h.9.attn.attention.k_proj.weight
#   transformer.h.9.attn.attention.v_proj.weight
#   transformer.h.9.attn.attention.q_proj.weight
#   transformer.wte.weight

# Rename table: source module path (matched as a substring of the full
# parameter name) -> module path in the converted checkpoint.
new_state_dict = {}
module_map = {
    "ln_1": "ln_preattn",
    "mlp.c_proj": "ff.ff2",
    "mlp.c_fc": "ff.ff1",
    "attn.attention.out_proj": "attn.out_proj",
    "attn.attention.k_proj": "attn.k_proj",
    "attn.attention.v_proj": "attn.v_proj",
    "attn.attention.q_proj": "attn.q_proj",
    "wte": "vocab_embed",
    'ln_f': 'ln_final',
    'lm_head': 'lm_head',
}

print(type(state_dict))
unmapped_keys = []
for key in state_dict.keys():
    dotlist = key.split('.')
    # Take only the first matching entry so a key can never be emitted under
    # two different names. Keys that match nothing are collected and reported
    # below instead of silently vanishing from the converted checkpoint.
    source = next((x for x in module_map if x in key), None)
    if source is None:
        unmapped_keys.append(key)
        continue
    if len(dotlist) > 3:
        # Per-layer key, e.g. "transformer.h.9.ln_1.weight": component 2 is
        # the layer index and the last component is the parameter name.
        layer = dotlist[2]
        new_key = f"layers.{layer}.{module_map[source]}.{dotlist[-1]}"
    else:
        # Top-level key, e.g. "transformer.ln_f.weight" or "lm_head.weight".
        new_key = f"{module_map[source]}.{dotlist[-1]}"
    new_state_dict[new_key] = state_dict[key]
    print(f"{key} -> {new_key}")

if unmapped_keys:
    print(f"WARNING: {len(unmapped_keys)} key(s) matched no module_map entry "
          "and were NOT converted:")
    for key in unmapped_keys:
        print(f"  {key}")

#print(new_state_dict)

def save(state_dict, path):
    """Write ``state_dict`` to ``path`` as one file per entry plus an index.

    Each value is written with ``torch.save`` to ``{path}/b{i}.pt``, where
    ``i`` is its position in ``state_dict`` iteration order. ``{path}/m.pt``
    receives a dict mapping every key to the file path holding its value.

    Args:
        state_dict: Mapping of parameter name -> object to save.
        path: Output directory. It and any missing parent directories are
            created; an already-existing directory is reused.

    Raises:
        OSError: If the directory cannot be created or a file cannot be
            written (previously a failed ``mkdir`` was silently ignored).
    """
    # makedirs creates missing parents (e.g. "models/") and tolerates an
    # existing directory without hiding real errors such as permissions.
    os.makedirs(path, exist_ok=True)
    checkpoint = {}
    for i, (name, value) in enumerate(state_dict.items()):
        shard_path = f"{path}/b{i}.pt"
        torch.save(value, shard_path)
        checkpoint[name] = shard_path
    torch.save(checkpoint, f"{path}/m.pt")

# Write the converted tensors under models/6b_vanilla (one file per tensor
# plus an m.pt index, as produced by save()).
save(new_state_dict, path="models/6b_vanilla")