from torch import optim
import numpy as np
import torch
from dotmap import DotMap
import pickle
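# BasedOptimizer: wraps a torch optimizer and drives its learning rate with lr_schedule (linear warmup + cosine anneal).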

def lr_schedule(step, warmup_steps, anneal_steps, lr, end_lr):
    """Return the learning rate at ``step`` for linear warmup followed by cosine annealing.

    The rate rises linearly from 0 to ``lr`` over the first ``warmup_steps``
    steps, then follows a half-cosine from ``lr`` down to ``end_lr`` over the
    next ``anneal_steps`` steps, and stays at ``end_lr`` afterwards. Steps below
    0 are treated as 0.

    Args:
        step: current step count (a scalar, or a NumPy array of steps).
        warmup_steps: length of the linear warmup. ``<= 0`` disables warmup
            (the schedule starts at ``lr``) instead of dividing by zero.
        anneal_steps: length of the cosine anneal. ``<= 0`` drops straight to
            ``end_lr`` on the first step after warmup instead of dividing by zero.
        lr: peak learning rate, reached at the end of warmup.
        end_lr: final learning rate after annealing.

    Returns:
        The learning rate as a NumPy float (or an array when ``step`` is an array).
    """
    if warmup_steps > 0:
        warmup_percent = np.clip(step, 0, warmup_steps) / warmup_steps
    else:
        # No warmup phase: start at the peak rate rather than computing 0/0 -> NaN.
        warmup_percent = 1.0

    if anneal_steps > 0:
        anneal_percent = np.clip(step - warmup_steps, 0, anneal_steps) / anneal_steps
    else:
        # No anneal phase: hold lr through the end of warmup, then jump to end_lr.
        anneal_percent = np.where(np.greater(step, warmup_steps), 1.0, 0.0)

    # Cosine schedule for annealing: (1 - cos(pi * p)) / 2 goes from 0 to 1 as p goes 0 -> 1,
    # moving the rate from lr to end_lr.
    return lr * warmup_percent - (lr - end_lr) * (1 - np.cos(np.pi * anneal_percent)) / 2


class BasedOptimizer:
    """Wrap a torch optimizer and drive its learning rate with ``lr_schedule``.

    Hyperparameters start from the ``defaults`` dict in ``__init__`` and are
    overridden by every key in ``config``; each key becomes an instance
    attribute. After every ``step()`` the learning rate of all param groups is
    recomputed, until the warmup + anneal schedule has finished.

    Args:
        parameters: parameters (or param-group dicts) handed to the underlying optimizer.
        config: mapping of hyperparameter overrides (anything with ``.items()``).
        optimizer: one of ``"adamw"``, ``"adamw8bit"`` or ``"adafactor"``.

    Raises:
        ValueError: if ``optimizer`` is not one of the supported names.
        ImportError: if ``"adafactor"`` is requested and ``transformers`` is missing.
    """

    def __init__(self, parameters, config, optimizer):
        defaults = {
            "lr": 6e-4,
            "end_lr": 6e-4,
            "warmup_steps": 1,
            "anneal_steps": 1,
            "total_steps": None,
            "weight_decay": 0.01,
            "tokens": None,
            "epochs": None,
            "beta1": 0.9,
            "beta2": 0.95,
            "eps": 1e-4,
        }

        for k, v in defaults.items():
            setattr(self, k, v)

        # config overrides the defaults; extra keys are stored as attributes too.
        for k, v in config.items():
            setattr(self, k, v)

        # Set once the LR schedule has finished; after that step() stops touching the LR.
        self.max_lr = False
        self.curr_step = 0
        self.curr_lr = 0

        # Optimizers are created with lr=0; step() writes the scheduled value into param_groups.
        if optimizer == "adamw":
            self.optimizer = optim.AdamW(parameters, lr=0, weight_decay=self.weight_decay, betas=(self.beta1, self.beta2), eps=self.eps)
        elif optimizer == "adamw8bit":
            import bitsandbytes as bnb
            # NOTE(review): uses bnb.optim.Adam8bit although the option is named "adamw8bit";
            # confirm whether AdamW8bit (decoupled weight decay) was intended.
            self.optimizer = bnb.optim.Adam8bit(parameters, lr=0, weight_decay=self.weight_decay, betas=(self.beta1, self.beta2), eps=self.eps)
        elif optimizer == "adafactor":
            try:
                from transformers.optimization import Adafactor
            except ImportError as err:
                raise ImportError("Please install transformers for Adafactor") from err

            # NOTE(review): Adafactor is built with transformers' defaults; lr/weight_decay/betas/eps
            # from config are not passed, and the scheduled lr may be ignored if Adafactor
            # computes its own step size — confirm intent.
            self.optimizer = Adafactor(params=parameters)
        else:
            # Fail fast: previously self.optimizer was left unset and later calls crashed.
            raise ValueError(
                f"Unknown optimizer {optimizer!r}; expected 'adamw', 'adamw8bit' or 'adafactor'"
            )

    def step(self, scaler=None):
        """Take one optimizer step, then advance the LR schedule by one step.

        Args:
            scaler: optional grad scaler; when truthy, the step is delegated to
                ``scaler.step(self.optimizer)`` instead of ``self.optimizer.step()``.
        """
        if scaler:
            scaler.step(self.optimizer)
        else:
            self.optimizer.step()

        self.curr_step += 1
        self.curr_lr = lr_schedule(self.curr_step, self.warmup_steps, self.anneal_steps, self.lr, self.end_lr)

        if not self.max_lr:
            # The schedule is constant at end_lr once warmup + anneal is over; detect that by
            # step count. Comparing curr_lr == end_lr misfires when warmup passes through
            # end_lr (freezing the LR mid-warmup) and may never fire due to float rounding.
            if self.curr_step >= self.warmup_steps + self.anneal_steps:
                print("max lr reached.")
                self.max_lr = True

            # Still apply this step's value, including the final end_lr.
            for group in self.optimizer.param_groups:
                group['lr'] = self.curr_lr

    def zero_grad(self):
        """Clear gradients via the wrapped optimizer's ``zero_grad``."""
        self.optimizer.zero_grad()

    def get_current_lr(self):
        """Return the learning rate last computed by ``step()`` (0 before the first step)."""
        return self.curr_lr

    def print_info(self):
        """Print the schedule settings and current progress to stdout."""
        print(f"end_lr: {str(self.end_lr)}")
        print(f"warmup_steps: {str(self.warmup_steps)}")
        print(f"total_steps: {str(self.total_steps)}")
        print(f"weight_decay: {str(self.weight_decay)}")
        print(f"step: {str(self.curr_step)}")
        if self.curr_step != 0:
            print(f"curr_lr: {str(self.get_current_lr())}")

    def save(self, path):
        """Pickle this wrapper, including the wrapped optimizer, to ``path``.

        Only a single pickle file is written. The optimizer is pickled as an
        attribute of this object, so no separate ``state_dict`` file is needed.
        Writing one with ``torch.save`` to the same path would just be
        overwritten by the pickle.
        """
        with open(path, 'wb') as f:
            pickle.dump(self, f)

    @classmethod
    def load(cls, path):
        """Load a wrapper previously written by ``save``.

        SECURITY: ``pickle.load`` can execute arbitrary code; only load trusted files.

        NOTE(review): unpickling creates new tensor objects, so the restored
        optimizer is not attached to any live model's parameters in this
        process. Confirm how callers reconnect it.
        """
        with open(path, 'rb') as f:
            based_optimizer = pickle.load(f)
        # The file is a plain pickle (see save), not a torch.save archive, and the
        # optimizer came back with it, so there is nothing to torch.load here.
        return based_optimizer