from curses import meta
from torch import optim
import numpy as np
import torch
from dotmap import DotMap
import pickle
import os
from pathlib import Path
from torch.distributed.optim import ZeroRedundancyOptimizer
#Based Optimizer

def _phase_fraction(elapsed, span):
    """Return how much of a schedule phase of length ``span`` is done, in [0, 1]."""
    if span <= 0:
        # Zero-length phase: complete as soon as it has started. Dividing by
        # ``span`` here would be 0/0 and yield a NaN learning rate.
        return 1.0 if elapsed >= 0 else 0.0
    return np.clip(elapsed, 0, span) / span


def lr_schedule(step, warmup_steps, anneal_steps, lr, end_lr, cosine_warmup=False):
    """Compute the learning rate at ``step`` for a warmup + cosine-anneal schedule.

    The rate rises from 0 to ``lr`` over ``warmup_steps`` (linearly, or along a
    half cosine when ``cosine_warmup`` is true). It then follows a half cosine
    from ``lr`` down to ``end_lr`` over the next ``anneal_steps``, and stays at
    ``end_lr`` afterwards.

    A phase length of 0 skips that phase: ``warmup_steps=0`` starts directly at
    ``lr``, and ``anneal_steps=0`` drops to ``end_lr`` once warmup is over.

    Args:
        step: current step count.
        warmup_steps: length of the warmup phase in steps.
        anneal_steps: length of the anneal phase in steps (starts after warmup).
        lr: peak learning rate reached at the end of warmup.
        end_lr: final learning rate after annealing.
        cosine_warmup: use a half-cosine ramp instead of a linear one for warmup.

    Returns:
        The learning rate for ``step``.
    """
    warmup_percent = _phase_fraction(step, warmup_steps)
    anneal_percent = _phase_fraction(step - warmup_steps, anneal_steps)

    if cosine_warmup:
        main_lr = lr * (1 - np.cos(np.pi * warmup_percent)) / 2
    else:
        main_lr = lr * warmup_percent

    # Amount subtracted from the peak; reaches (lr - end_lr) when annealing ends.
    anneal_lr = (lr - end_lr) * (1 - np.cos(np.pi * anneal_percent)) / 2
    return main_lr - anneal_lr

class BasedOptimizer:
    """Wrapper around a torch optimizer whose learning rate follows ``lr_schedule``.

    Hyperparameters come from ``config`` (any mapping with ``.items()``), layered
    over the defaults in ``init_config``. Each ``step`` call advances
    ``curr_step``, recomputes ``curr_lr``, and writes it into the wrapped
    optimizer's param groups until the schedule has finished. ``save`` and
    ``load`` persist and restore both the optimizer state and this object's
    attributes.
    """

    def __init__(self, parameters, config, optimizer, init=True):
        """Store the inputs and build the underlying optimizer.

        Args:
            parameters: parameters (or param groups) passed to the torch optimizer.
            config: mapping of hyperparameter overrides; each key becomes an attribute.
            optimizer: backend name: "adamw", "adamw8bit", "zero1" or "adafactor".
            init: when False, nothing is initialised and no attributes are set.
        """
        if init:
            self.config = config
            self.optimizer_name = optimizer
            self.parameters = parameters
            self.init_config()
            self.init_optimizer()

    def init_config(self):
        """Set hyperparameter attributes: defaults first, then ``self.config`` overrides."""
        defaults = {
            "lr": 6e-4,
            "end_lr": 6e-4,
            "warmup_steps": 1,
            "anneal_steps": 1,
            "total_steps": None,
            "weight_decay": 0.01,
            "tokens": None,
            "epochs": None,
            "beta1": 0.9,
            "beta2": 0.95,
            "eps": 1e-4,
            "cosine_warmup": False,
            "max_lr": False,
            "curr_step": 0,
            "curr_lr": 0,
        }

        for key, value in defaults.items():
            setattr(self, key, value)

        # Config entries win over defaults. ``load`` passes saved metadata here,
        # which restores curr_step/curr_lr/max_lr as well.
        for key, value in self.config.items():
            setattr(self, key, value)

    def init_optimizer(self):
        """Create ``self.optimizer`` for ``self.optimizer_name``.

        The learning rate starts at 0; ``step`` sets it from the schedule.

        Raises:
            ImportError: if "adafactor" is requested without ``transformers``.
            ValueError: if ``optimizer_name`` is not a supported backend.
        """
        name = self.optimizer_name
        betas = (self.beta1, self.beta2)

        if name == "adamw":
            self.optimizer = optim.AdamW(self.parameters, lr=0, weight_decay=self.weight_decay, betas=betas, eps=self.eps)

        elif name == "adamw8bit":
            import bitsandbytes as bnb
            # NOTE(review): this builds bnb's Adam8bit, not AdamW8bit; confirm how
            # bitsandbytes applies weight_decay for this class.
            self.optimizer = bnb.optim.Adam8bit(self.parameters, lr=0, weight_decay=self.weight_decay, betas=betas, eps=self.eps)

        elif name == "zero1":
            import bitsandbytes as bnb
            # Optimizer state is sharded across ranks; each shard uses Adam8bit.
            self.optimizer = ZeroRedundancyOptimizer(
                self.parameters,
                optimizer_class=bnb.optim.Adam8bit,
                lr=0,
                weight_decay=self.weight_decay,
                betas=betas,
                eps=self.eps,
            )

        elif name == "adafactor":
            try:
                from transformers.optimization import Adafactor
            except ImportError as err:
                raise ImportError("Please install transformers for Adafactor") from err

            # NOTE(review): built with library defaults only. weight_decay, betas
            # and eps from the config are not forwarded, and with Adafactor's
            # default relative-step mode the lr written by ``step`` is presumably
            # ignored. Confirm this is intended.
            self.optimizer = Adafactor(params=self.parameters)

        else:
            # Fail early instead of leaving ``self.optimizer`` unset, which would
            # only surface later as an AttributeError in step/save.
            raise ValueError(f"Unknown optimizer: {name!r}")

    def step(self, dry_run=False, scaler=None):
        """Apply one optimizer update (unless ``dry_run``) and advance the LR schedule.

        Args:
            dry_run: if True, skip the parameter update but still advance the schedule.
            scaler: optional object whose ``step(optimizer)`` is called instead of
                ``optimizer.step()`` (e.g. a gradient scaler).
        """
        if not dry_run:
            if scaler:
                scaler.step(self.optimizer)
            else:
                self.optimizer.step()

        self.curr_step += 1
        self.curr_lr = lr_schedule(
            self.curr_step,
            self.warmup_steps,
            self.anneal_steps,
            self.lr,
            self.end_lr,
            cosine_warmup=getattr(self, "cosine_warmup", False),
        )

        if self.max_lr:
            # Schedule already finished; the param-group LR stays as last written.
            return

        # The schedule is flat once warmup + anneal are over, so detect the end by
        # step count. Testing ``curr_lr == end_lr`` could match mid-warmup, when
        # the ramp passes through end_lr, and freeze the LR below its peak.
        if self.curr_step >= self.warmup_steps + self.anneal_steps:
            print("max lr reached.")
            self.max_lr = True

        for group in self.optimizer.param_groups:
            group['lr'] = self.curr_lr

    def zero_grad(self):
        """Clear gradients through the wrapped optimizer."""
        self.optimizer.zero_grad()

    def get_current_lr(self):
        """Return the learning rate most recently computed by ``step``."""
        return self.curr_lr

    def print_info(self):
        """Print the main schedule settings and the current progress."""
        print(f"end_lr: {str(self.end_lr)}")
        print(f"warmup_steps: {str(self.warmup_steps)}")
        print(f"total_steps: {str(self.total_steps)}")
        print(f"weight_decay: {str(self.weight_decay)}")
        print(f"step: {str(self.curr_step)}")
        if self.curr_step != 0:
            print(f"curr_lr: {str(self.get_current_lr())}")

    def save(self, path: Path):
        """Write the optimizer state and scheduler metadata into directory ``path``.

        Creates ``path`` if needed and writes two files:

        - ``opt_states.pt``: the wrapped optimizer's ``state_dict()``.
        - ``opt_metadata.pkl``: this object's attributes. The optimizer and
          parameters are excluded, because ``load`` takes fresh parameters and
          rebuilds the optimizer.

        NOTE(review): for the "zero1" backend, torch's ZeroRedundancyOptimizer
        expects ``consolidate_state_dict()`` (a collective call) before
        ``state_dict()``. Confirm callers do this on all ranks.
        """
        path = Path(path)
        path.mkdir(parents=True, exist_ok=True)
        torch.save(self.optimizer.state_dict(), path / "opt_states.pt")

        metadata = self.__dict__.copy()
        # Drop the live optimizer and the parameter handles from the dict.
        del metadata["optimizer"]
        del metadata["parameters"]
        with open(path / "opt_metadata.pkl", 'wb') as f:
            pickle.dump(metadata, f)

    @classmethod
    def load(cls, parameters, path):
        """Rebuild a BasedOptimizer from a directory written by ``save``.

        Args:
            parameters: parameters for the newly built optimizer.
            path: directory containing ``opt_metadata.pkl`` and ``opt_states.pt``.

        Returns:
            The restored instance. If the optimizer state cannot be loaded, a
            message is printed and freshly initialised optimizer state is kept,
            with the param-group LR set to the restored ``curr_lr``.
        """
        path = Path(path)
        # SECURITY: pickle.load can execute arbitrary code from the file. Only
        # load checkpoints from trusted sources.
        with open(path / "opt_metadata.pkl", 'rb') as f:
            metadata = pickle.load(f)

        based_optimizer = cls(parameters, metadata, metadata["optimizer_name"])
        try:
            based_optimizer.optimizer.load_state_dict(torch.load(path / "opt_states.pt"))
        except Exception as err:
            print("Couldn't load the optimizer, initializing the optimizer states. Honk!!!")
            print(f"Reason: {err!r}")
            # The fresh optimizer was built with lr=0, and ``step`` stops writing
            # param groups once ``max_lr`` is set. Restore the scheduled LR so
            # training does not resume at lr=0.
            for group in based_optimizer.optimizer.param_groups:
                group['lr'] = based_optimizer.curr_lr
        return based_optimizer