Commit fb134d28 authored by novelailab

zero2 works

parent 9d27a5cc
@@ -95,6 +95,7 @@ class SelfAttention(nn.Module):
         sin, cos = fixed_pos_embedding(dim=self.rotary_dim, seq_len=max_positions)
         self.register_buffer("sin", sin)
         self.register_buffer("cos", cos)
+        if self.config.masked_softmax_fusion:
             self.fused_softmax = FusedScaleMaskSoftmax(
                 input_in_fp16=False,
                 input_in_bf16=True,
@@ -104,6 +105,8 @@ class SelfAttention(nn.Module):
                 attn_mask_type="causal",
                 scaled_masked_softmax_fusion=True,
             )
+        else:
+            self.fused_softmax = None

     def forward(self, x, kv=None, cache=False):
         B, S, H = x.shape # batch, sequence, hidden_dim
@@ -242,7 +245,7 @@ class GPTJModel(base_lm.BaseModel):
             'activation': gelu_new,
             'SelfAttention': SelfAttention,
             'FeedForward': FeedForward,
-            'masked_softmax_fusion': False,
+            'masked_softmax_fusion': True,
         }
         base_lm.BaseModel.__init__(self, user_config, **kwargs)
         if self.config.masked_softmax_fusion:
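The first hunks make the Megatron-style fused scale-mask-softmax optional behind the `masked_softmax_fusion` config flag, which the GPTJModel defaults now enable. As a mental model for the `fused_softmax = None` fallback path, here is a minimal eager-mode sketch of what the fused kernel computes; the repo's actual fallback is not shown in this diff, and the fp32 upcast and mask handling below are assumptions.

```python
import torch

def causal_masked_softmax(scores: torch.Tensor, scale: float = 1.0) -> torch.Tensor:
    # scores: [batch, heads, q_len, k_len] raw attention logits.
    # The fused kernel applies scale, causal mask and softmax in one pass;
    # this unfused equivalent upcasts to fp32 for numerical stability.
    orig_dtype = scores.dtype
    q_len, k_len = scores.shape[-2], scores.shape[-1]
    causal = torch.tril(torch.ones(q_len, k_len, dtype=torch.bool, device=scores.device))
    x = scores.float() * scale
    x = x.masked_fill(~causal, float("-inf"))
    return torch.softmax(x, dim=-1).to(orig_dtype)
```

Here `scores` would be the raw `q @ k.transpose(-2, -1)` logits before normalization.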
@@ -77,6 +77,10 @@ class BasedOptimizer:
                 eps=self.eps,
             )
+        elif self.optimizer_name == "zero2":
+            from apex.contrib.optimizers.distributed_fused_adam import DistributedFusedAdam
+            self.optimizer = DistributedFusedAdam(self.parameters, lr=0, weight_decay=self.weight_decay, betas=(self.beta1, self.beta2), eps=self.eps, grad_sync_dtype=torch.float32)
         elif self.optimizer_name == "adafactor":
             try:
                 from transformers.optimization import Adafactor
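This hunk adds the "zero2" branch the commit message refers to: apex's `DistributedFusedAdam` partitions optimizer state across data-parallel ranks and reduce-scatters gradients rather than all-reducing them, which is where the ZeRO-2 style memory savings come from. A standalone sketch of the same construction follows, with argument values mirroring the diff; the helper name and its signature are illustrative, not part of the repo.

```python
import torch
from apex.contrib.optimizers.distributed_fused_adam import DistributedFusedAdam

def build_zero2_optimizer(params, weight_decay, beta1, beta2, eps):
    # Requires apex built with contrib extensions and torch.distributed already
    # initialized (e.g. via torchrun). The optimizer shards its state across
    # ranks and reduce-scatters gradients instead of all-reducing them.
    return DistributedFusedAdam(
        params,
        lr=0,                           # as in the diff: the LR schedule sets it later
        weight_decay=weight_decay,
        betas=(beta1, beta2),
        eps=eps,
        grad_sync_dtype=torch.float32,  # keep gradient reduction in fp32
    )
```

`lr=0` matches the diff; the learning rate is presumably driven by BasedOptimizer's own warmup/anneal schedule rather than fixed at construction time.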
@@ -5,7 +5,7 @@ import torch.cuda.amp as amp
 import torch.optim as optim
 from pathlib import Path
 from torch.utils import data
-from basedformer import optimizer, utils, lm_utils
+from basedformer import optimizer, utils, lm_utils, dataset
 import yaml
 import sys
 from tqdm import tqdm
@@ -16,17 +16,11 @@ import os
 from icecream import ic
 import torch.distributed as dist
 import torch.multiprocessing as mp
-from torch.nn.parallel import DistributedDataParallel as DDP
+#from torch.nn.parallel import DistributedDataParallel as DDP
+from apex.parallel.distributed import DistributedDataParallel as DDP
 from torch.utils.data.distributed import DistributedSampler
 from dotmap import DotMap
 import argparse
-from torch.distributed.fsdp import (
-    FullyShardedDataParallel,
-    CPUOffload,
-)
-from torch.distributed.fsdp.wrap import (
-    default_auto_wrap_policy,
-)

 def setup(rank, world_size):
     #os.environ['MASTER_ADDR'] = 'localhost'
@@ -97,14 +91,19 @@ def fsdp_train(args, model, train_loader, opt):
                 norm = norm.matmul(norm.transpose(-1,-2))
                 contrastive_loss = torch.matmul(hs, hs.transpose(-2, -1)).div(norm).abs().mean()
                 gas_loss += contrastive_loss * args.contrastive_loss

             if args["loss_scale"]:
-                scaler.scale(gas_loss).backward()
+                with opt.optimizer.no_sync():
+                    scaler.scale(gas_loss).backward()
             else:
-                gas_loss.backward()
+                with opt.optimizer.no_sync():
+                    gas_loss.backward()
             loss += gas_loss.item()

         loss = loss / gas
+        opt.optimizer.grad_sync()
         if args["loss_scale"]:
             scaler.unscale_(opt.optimizer)
         torch.nn.utils.clip_grad_norm_(model.parameters(), 1)
@@ -116,10 +115,10 @@ def fsdp_train(args, model, train_loader, opt):
         if args["loss_scale"]:
             scaler.update()
-        #opt.zero_grad()
-        model.zero_grad(set_to_none=True)
+        opt.zero_grad()
+        #model.zero_grad(set_to_none=True)
         sec_per_step = (time.perf_counter() - timex)
-        flops = get_flops(args, model.module, sec_per_step)
+        flops = get_flops(args, model, sec_per_step)
         step_per_sec = (1. / sec_per_step)
         tokens_per_sec = (step_per_sec * 2048) * bs * gas * world_size
         batch_size = bs * gas * world_size
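The training-loop change is the core of the ZeRO-2 wiring: each micro-batch backward runs under `opt.optimizer.no_sync()` so gradients only accumulate locally, and a single `opt.optimizer.grad_sync()` reduce-scatters them once per optimizer step. Below is a condensed sketch of that pattern against the raw optimizer; the model, data, and loss-scaling plumbing are stand-ins, while the `no_sync()`/`grad_sync()` method names are taken from the diff.

```python
import torch

def accumulate_and_step(model, opt, scaler, micro_batches, loss_scale=True):
    # opt is the raw DistributedFusedAdam (opt.optimizer in the training script).
    for x, y in micro_batches:
        loss = model(x, y)  # stand-in: forward pass returning a scalar loss
        # no_sync(): gradients from this micro-batch stay local, no communication.
        with opt.no_sync():
            if loss_scale:
                scaler.scale(loss).backward()
            else:
                loss.backward()
    # One reduce-scatter for the whole accumulation window.
    opt.grad_sync()
    if loss_scale:
        scaler.unscale_(opt)
    torch.nn.utils.clip_grad_norm_(model.parameters(), 1)
    if loss_scale:
        scaler.step(opt)
        scaler.update()
    else:
        opt.step()
    opt.zero_grad()
```

This also explains the `get_flops(args, model, ...)` change: with the DDP wrapper gone, there is no `model.module` attribute to unwrap.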
@@ -153,15 +152,17 @@ def main(rank, global_rank, world_size, args):
     setup(rank, world_size)
     Path(args["save_path"]).mkdir(parents=True, exist_ok=True)
-    model = lm_utils.load_from_path("pretrained/gpt-j-base").float().to(rank)
-    fsdp_model = DDP(model, device_ids=[rank], output_device=rank, gradient_as_bucket_view=True)
+    model = lm_utils.load_from_path("/home/xuser/nvme1/pretrained/gpt-j-base").half().to(rank)
+    #fsdp_model = DDP(model, device_ids=[rank], output_device=rank, gradient_as_bucket_view=True)
+    #fsdp_model = DDP(model)
+    fsdp_model = model
     utils.print_parameters(fsdp_model)
     ic("model loaded")
-    opt = optimizer.BasedOptimizer(fsdp_model.parameters(), args, "zero1")
+    opt = optimizer.BasedOptimizer(fsdp_model.parameters(), args, "zero2")
     # TODO: Add load, add evals, add FP16 AMP, and Data Parallel, outputting hidden states from the get_logits function.
     print(opt.curr_step)
-    train_dataset = utils.ShardedDataset(2049, args["data_path"], world_size=world_size, rank=global_rank)
+    train_dataset = dataset.ShardedDataset(2049, args["data_path"], world_size=world_size, rank=global_rank)
     train_loader = data.DataLoader(train_dataset, batch_size=bs*gas, shuffle=False, num_workers=0, )
     if global_rank == 0:
         wandb.init(project="basedformer-tests", name=args["run_name"], config={**args, **model.config})
@@ -172,21 +173,21 @@ def main(rank, global_rank, world_size, args):

 if __name__ == "__main__":
     train_config = {
-        "data_path": "dataset/sigurd-1G.map",
+        "data_path": "/home/xuser/nvme1/dataset/sigurd-1G.map",
         "save_path": "models/gptj-sigurd-1G-vanilla",
-        "do_save": True,
+        "do_save": False,
         "run_name": "gptj-sigurd-1G-vanilla",
         "lr": 6e-5,
         "end_lr": 3e-5,
         "warmup_steps": 100,
         "anneal_steps": 7850,
         "bs": 2,
-        "gas": 2,
+        "gas": 8,
         "seed": 69,
         "save_every": 500,
-        "amp": True,
+        "amp": False,
         "loss_scale": True,
-        "cast_to": torch.float16,
+        "cast_to": torch.bfloat16,
         "contrastive_loss": False,
     }
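The config changes (gas 2→8, AMP off, bf16 casting, absolute data paths, saving disabled) reshape the effective batch for this test run. As a quick arithmetic check of the per-step token count implied by the `tokens_per_sec` logging above, assuming an illustrative world size of 8 (the launcher sets the real value at runtime):

```python
bs, gas, seq_len = 2, 8, 2048   # from the updated train_config
world_size = 8                  # illustrative assumption, not from the diff
sequences_per_step = bs * gas * world_size       # 128 sequences per optimizer step
tokens_per_step = sequences_per_step * seq_len   # 262,144 tokens per optimizer step
print(sequences_per_step, tokens_per_step)
```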