from . import gptj
from . import gpt2
from . import fairseq
from . import gptneo
from . import alibi
from . import vit 
from . import resnet
from . import fast

# Registry consulted by get_model(): every key is the string name a caller
# passes in, and every value is the matching attribute taken from one of the
# submodules imported at the top of this file.
MODEL_MAP = {
    "gptj": gptj.GPTJModel,
    "gpt2": gpt2.GPT2Model,
    "gpt-fairseq": fairseq.GPTFairModel,
    "gpt-neo": gptneo.GPTNeoModel,
    "alibi": alibi.AlibiModel,
    "vit": vit.VisionTransformer,
    "resnet": resnet.ResNet,
}

def get_model(model_name: str):
    """Return the entry registered under ``model_name`` in ``MODEL_MAP``.

    Args:
        model_name: A key of ``MODEL_MAP`` (for example ``"gpt2"``).

    Returns:
        The value stored in ``MODEL_MAP`` for ``model_name``.

    Raises:
        KeyError: If ``model_name`` is not a key of ``MODEL_MAP``. The
            message names the rejected key and lists every registered name.
    """
    try:
        return MODEL_MAP[model_name]
    except KeyError:
        # Still a KeyError so existing handlers keep working, but with a
        # message that tells the caller which names are actually valid.
        available = ", ".join(sorted(MODEL_MAP))
        raise KeyError(
            f"Unknown model {model_name!r}; available models: {available}"
        ) from None
