Commit 0879cf91 authored by novelailab's avatar novelailab

fix formatting

parent 4141d527
...@@ -41,12 +41,15 @@ class BasedOptimizer: ...@@ -41,12 +41,15 @@ class BasedOptimizer:
if optimizer == "adamw": if optimizer == "adamw":
self.optimizer = optim.AdamW(parameters, lr=0, weight_decay=self.weight_decay, betas=(self.beta1, self.beta2), eps=self.eps) self.optimizer = optim.AdamW(parameters, lr=0, weight_decay=self.weight_decay, betas=(self.beta1, self.beta2), eps=self.eps)
elif optimizer == "adamw8bit": elif optimizer == "adamw8bit":
import bitsandbytes as bnb import bitsandbytes as bnb
self.optimizer = bnb.optim.Adam8bit(parameters, lr=0, weight_decay=self.weight_decay, betas=(self.beta1, self.beta2), eps=self.eps) self.optimizer = bnb.optim.Adam8bit(parameters, lr=0, weight_decay=self.weight_decay, betas=(self.beta1, self.beta2), eps=self.eps)
elif optimizer == "adafactor": elif optimizer == "adafactor":
try: try:
from transformers.optimization import Adafactor from transformers.optimization import Adafactor
except ImportError: except ImportError:
raise ImportError("Please install transformers for Adafactor") raise ImportError("Please install transformers for Adafactor")
...@@ -55,6 +58,7 @@ class BasedOptimizer: ...@@ -55,6 +58,7 @@ class BasedOptimizer:
def step(self, scaler=None): def step(self, scaler=None):
if scaler: if scaler:
scaler.step(self.optimizer) scaler.step(self.optimizer)
else: else:
self.optimizer.step() self.optimizer.step()
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment