import os
from typing import OrderedDict
import torch

model_dir = 'pretrained/resnet/'
# (uses_conv3, blocks_per_stage); (False, (2, 2, 2, 2)) is the ResNet-18 layout.
net_conf = (False, (2, 2, 2, 2))

# Batch-norm entries that must be present in the source checkpoint.
_BN_REQUIRED = ("weight", "bias")
# Batch-norm buffers copied when the source contains them. Dropping
# running_mean/running_var would leave the converted layers without the
# pretrained statistics used in eval mode (assuming the target layers are
# standard BatchNorm modules -- NOTE(review): confirm).
_BN_OPTIONAL = ("running_mean", "running_var", "num_batches_tracked")


def _copy_bn(dst, dst_prefix, src, src_prefix):
    """Copy one batch-norm layer's entries from *src* into *dst* under a new prefix.

    ``weight``/``bias`` are mandatory (``KeyError`` if missing); the running
    statistics are copied only when the source checkpoint contains them.
    """
    for suffix in _BN_REQUIRED:
        dst[dst_prefix + suffix] = src[src_prefix + suffix]
    for suffix in _BN_OPTIONAL:
        key = src_prefix + suffix
        if key in src:
            dst[dst_prefix + suffix] = src[key]


def convert_state_dict(weights, conf=net_conf):
    """Rename a ResNet checkpoint's keys to the ``layerin`` / ``resblocks`` layout.

    Args:
        weights: source state dict with keys ``conv1.weight``, ``bn1.*`` and
            ``layer{stage}.{block}.conv{n}.weight`` / ``layer{stage}.{block}.bn{n}.*``
            (stages numbered from 1, blocks from 0).
        conf: ``(uses_conv3, blocks_per_stage)``. When ``uses_conv3`` is truthy,
            every block also carries a ``conv3``/``bn3`` pair (bottleneck-style).

    Returns:
        A new dict: the stem conv goes to ``layerin.0``, the stem batch norm to
        ``layerin.2``, and the residual blocks, in stage order, to
        ``resblocks.0``, ``resblocks.1``, ...

    Raises:
        KeyError: if a required source key is missing.

    NOTE(review): source keys this function never reads -- e.g. any
    ``downsample.*`` or ``fc.*`` entries -- are dropped; confirm the target
    model does not expect them.
    """
    uses_conv3, blocks_per_stage = conf[0], conf[1]
    convs = ("conv1", "conv2", "conv3") if uses_conv3 else ("conv1", "conv2")
    bns = ("bn1", "bn2", "bn3") if uses_conv3 else ("bn1", "bn2")

    out = {}
    out['layerin.0.weight'] = weights['conv1.weight']
    _copy_bn(out, 'layerin.2.', weights, 'bn1.')

    # Target blocks are numbered consecutively across all stages.
    block_idx = 0
    for stage, num_blocks in enumerate(blocks_per_stage, 1):
        for k in range(num_blocks):
            src = f"layer{stage}.{k}."
            dst = f"resblocks.{block_idx}."
            for conv in convs:
                out[dst + conv + ".weight"] = weights[src + conv + ".weight"]
            for bn in bns:
                _copy_bn(out, dst + bn + ".", weights, src + bn + ".")
            block_idx += 1
    return out


if __name__ == "__main__":
    # map_location lets a GPU-saved checkpoint be converted on a CPU-only machine.
    weights: OrderedDict = torch.load(
        os.path.join(model_dir, 'resnet_18.pth'), map_location="cpu"
    )
    new_state_dict = convert_state_dict(weights, net_conf)

    out_path = os.path.join(model_dir, 'modified', 'resnet_18.pth')
    # torch.save does not create missing parent directories.
    os.makedirs(os.path.dirname(out_path), exist_ok=True)
    torch.save(new_state_dict, out_path)