import os
import torch
from dotmap import DotMap
from finetune import main
if __name__ == "__main__":
    # Settings for this fine-tuning run. How each key is interpreted is
    # decided by finetune.main, which is not visible from this script.
    train_config = dict(
        data_path="dataset/sigurd-1G.map",
        save_path="models/gptj-sigurd-1G-contrastive-0.3weight",
        do_save=True,
        run_name="gptj-sigurd-1G-contrastive0.3weight",
        lr=6e-5,
        end_lr=3e-5,
        warmup_steps=100,
        anneal_steps=7850,
        bs=2,
        gas=2,
        seed=69,
        save_every=500,
        amp=True,
        loss_scale=True,
        cast_to=torch.float16,
        contrastive_loss=0.3,
    )

    # All three variables are mandatory: a missing one raises KeyError and a
    # non-integer value raises ValueError.
    # NOTE(review): presumably exported by the distributed launcher
    # (e.g. torchrun) -- confirm against how this script is started.
    world_size, rank, global_rank = (
        int(os.environ[var]) for var in ("WORLD_SIZE", "LOCAL_RANK", "RANK")
    )

    # Select the CUDA device whose index is this process's LOCAL_RANK before
    # handing control to the training entry point.
    torch.cuda.set_device(rank)
    main(rank, global_rank, world_size, DotMap(train_config))