import math

import torch
import torch.cuda.amp as amp
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

#Based Optimizer
class BasedOptimizer:
    """AdamW wrapper whose learning rate is set by hand on every step.

    The schedule has two phases:

    * a linear ramp from ``warmup_init`` towards the target rate, spread over
      ``warmup_steps`` steps counted from ``start_step``;
    * a half-cosine decay from the peak rate (``config["lr"]``) down to
      ``min_lr``, running from step ``warmup_steps`` to ``total_steps``.

    Recognised ``config`` keys (all optional): ``min_lr``, ``lr``,
    ``warmup_init``, ``warmup_steps``, ``total_steps``, ``weight_decay``,
    ``start_step``, ``betas`` and ``eps``.
    """

    def __init__(self, model, config, optimizer):
        """Build the wrapped AdamW optimizer and read the schedule settings.

        Args:
            model: Object exposing ``parameters()``; those parameters are
                handed to AdamW.
            config: Mapping-like object supporting ``in`` and ``[]``.
            optimizer: Accepted for interface compatibility and ignored;
                AdamW is always used.
        """
        def option(key, default):
            # ``in`` plus indexing (rather than ``.get``) keeps every
            # mapping-like config object that worked before working now.
            return config[key] if key in config else default

        self.min_lr = option("min_lr", 1e-06)
        self.warmup_end = option("lr", 5e-06)  # peak learning rate
        self.warmup_init = option("warmup_init", 0)
        self.warmup_steps = option("warmup_steps", 1)
        self.total_steps = option("total_steps", None)
        self.weight_decay = option("weight_decay", 0)
        self.start_step = option("start_step", 0)
        self.curr_step = self.start_step
        self.curr_lr = 0

        # Stored as a list because backward/step/zero_grad all treat
        # ``self.optimizers`` as a sequence of optimizers.
        self.optimizers = [
            optim.AdamW(
                model.parameters(),
                lr=self.warmup_init,
                weight_decay=self.weight_decay,
                # Fall back to AdamW's own defaults so that these two keys
                # are optional like every other setting.
                betas=option("betas", (0.9, 0.999)),
                eps=option("eps", 1e-08),
            )
        ]

    def get_current_lr(self):
        """Return the learning rate for ``self.curr_step``.

        When ``total_steps`` is unset, or leaves no room after the warmup
        (``total_steps <= warmup_steps``), no decay is applied and the rate
        stays at the peak once the ramp has finished.
        """
        if self.total_steps is None:
            decay_span = 0
        else:
            decay_span = self.total_steps - self.warmup_steps

        if decay_span > 0:
            progress = min(
                1.0, max(0, self.curr_step - self.warmup_steps) / decay_span
            )
            cosine_lr = self.min_lr + 0.5 * (self.warmup_end - self.min_lr) * (
                1 + math.cos(math.pi * progress)
            )
        else:
            # No usable decay window: avoid dividing by zero / None.
            cosine_lr = self.warmup_end

        if self.curr_step < self.warmup_steps:
            target_lr = self.warmup_end
        else:
            target_lr = cosine_lr

        # The ramp is measured from start_step, so a resumed run warms up
        # again. Clamping to [0, 1] stops the interpolation from
        # overshooting the target once the ramp is complete.
        ramp = (self.curr_step - self.start_step) / max(1, self.warmup_steps)
        ramp = min(1.0, max(0.0, ramp))
        # This form returns exactly warmup_init at ramp == 0 and exactly
        # target_lr at ramp == 1.
        return (1.0 - ramp) * self.warmup_init + ramp * target_lr

    def backward(self, loss):
        """Back-propagate ``loss`` through the autograd graph.

        When a gradient scaler is in use, pass the already-scaled loss.
        """
        loss.backward()

    def step(self, scaler=None):
        """Apply the scheduled learning rate, step, and advance the counter.

        Args:
            scaler: Optional gradient scaler. When given, each optimizer is
                stepped through ``scaler.step``; this method does not call
                ``scaler.update()``.
        """
        self.curr_lr = self.get_current_lr()
        for optimizer in self.optimizers:
            for group in optimizer.param_groups:
                group['lr'] = self.curr_lr
            if scaler:
                scaler.step(optimizer)
            else:
                optimizer.step()

        self.curr_step += 1

    def zero_grad(self):
        """Clear the gradients held by every wrapped optimizer."""
        for optimizer in self.optimizers:
            optimizer.zero_grad()

    def print_info(self):
        """Print the schedule settings, the step counter and the current rate."""
        print(f"min_lr: {self.min_lr}")
        print(f"warmup_end: {self.warmup_end}")
        print(f"warmup_init: {self.warmup_init}")
        print(f"warmup_steps: {self.warmup_steps}")
        print(f"start_step: {self.start_step}")
        print(f"total_steps: {self.total_steps}")
        print(f"weight_decay: {self.weight_decay}")
        print(f"step: {self.curr_step}")
        print(f"curr_lr: {self.get_current_lr()}")