"...svn:/svn.code.sf.net/p/irrlicht/code/trunk@643" did not exist on "c00ce1d372a86762ca0fad9d3c6a342aa3feda89"
Commit 840dd7f4 authored by novelailab

cleanup

parent 1c9d3a31
from torch import nn
import torch
import os
import math
from lm_arch.utils import *
# Can access and change every module from here, as the Layer class and the ff and attn classes are all passed into GPTModel.
class GPTModel(nn.Module):
    def __init__(self, hidden_dim=512, n_layer=12, n_head=4, vocab_dim=50400, eps=1e-4, activation=nn.GELU(), Layer=None, SelfAttention=None, FeedForward=None, device="cuda", dtype=torch.float16):
        super(GPTModel, self).__init__()
        self.hidden_dim = hidden_dim
        self.vocab_embed = nn.Embedding(vocab_dim, self.hidden_dim, device=device, dtype=dtype)
        self.ln_final = nn.LayerNorm(self.hidden_dim, eps=eps, device=device, dtype=dtype)
        self.layers = nn.ModuleList([])
        self.lm_head = nn.Linear(hidden_dim, vocab_dim, bias=True)
        for _ in range(n_layer):
            self.layers.append(Layer(attn=SelfAttention, ff=FeedForward, hidden_dim=hidden_dim, n_head=n_head, eps=eps, activation=activation, device=device, dtype=dtype))

    #TODO: Decouple more, maybe even init everything here, not sure. Not modular enough yet.
    #TODO: Do we want to pass a config object everywhere? I don't exactly like that but passing a lot of variables is a bit ugly too.
    def forward(self, x, hypernetwork=None, act_ck=False):
        x = self.vocab_embed(x)
        for layer in self.layers:
            x = layer(x, hypernetwork, act_ck)
        x = self.ln_final(x)
        return x

    def get_logits(self, x, hypernetwork=None, act_ck=False):
        x = self.forward(x, hypernetwork=hypernetwork, act_ck=act_ck)
        x = self.lm_head(x)
        return x.float()

    @classmethod
    def load(cls, config, path=None, state_dict=None):
        if path:
            state_dict = SplitCheckpoint(path, device="cuda")

        model = no_init(lambda: cls(**config))
        model.load_state_dict(state_dict, strict=False)
        return model

    @classmethod
    def init(cls, config):
        model = cls(**config)
        return model

    @classmethod
    def neox_init(cls, config):
        model = cls(**config)
        modules = [*model.layers[:-1], model.vocab_embed, model.ln_final, model.lm_head]
        init = small_init_method(config["hidden_dim"])
        for module in modules:
            for param in module.parameters():
                init(param)

        last_layer = model.layers[-1]
        last_layer_init = wang_init_method(config["n_layer"], config["hidden_dim"])
        for param in last_layer.parameters():
            last_layer_init(param)
        return model

    @classmethod
    def simple_init(cls, config):
        model = cls(**config)
        state = model.state_dict()
        for k in state:
            state[k] = state[k] / math.sqrt(2 * config["n_layer"])
        model.load_state_dict(state)
        return model

    def save(self, path):
        try: os.mkdir(path)
        except: pass
        checkpoint = {}
        for i, x in enumerate(self.state_dict().items()):
            checkpoint[x[0]] = f"{path}/b{i}.pt"
            torch.save(x[1], f"{path}/b{i}.pt")
        torch.save(checkpoint, f"{path}/m.pt")
def wang_init_method(n_layers, dim):
    std = 2 / n_layers / math.sqrt(dim)
    def init_(tensor):
        return torch.nn.init.normal_(tensor, mean=0.0, std=std)
    return init_

# Stolen from NeoX. For the 20B run, wang_init was used on the output layer and small_init on the rest of the layers.
def small_init_method(dim):
    """Fills the input Tensor with values according to the method described in Transformers without Tears: Improving
    the Normalization of Self-Attention - Nguyen, T. & Salazar, J. (2019), using a normal distribution."""
    std = math.sqrt(2 / (5 * dim))
    def init_(tensor):
        return torch.nn.init.normal_(tensor, mean=0.0, std=std)
    return init_
\ No newline at end of file
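
For context, a minimal usage sketch of the GPTModel class above. The import paths and the config values are assumptions for illustration (the diff does not show the new file's module path or the configs the training code passes in); GPTJLayer, SelfAttention and FeedForward come from the gptj file changed below, and SplitCheckpoint / no_init come from lm_arch.utils.

# Hedged usage sketch: module paths and config values below are assumptions, not taken from the diff.
import torch
from torch import nn
from lm_arch.gptj import GPTJLayer, SelfAttention, FeedForward  # assumed import path
from lm_arch.model import GPTModel                              # assumed import path

config = {
    "hidden_dim": 512, "n_layer": 12, "n_head": 4, "vocab_dim": 50400,
    "eps": 1e-4, "activation": nn.GELU(), "Layer": GPTJLayer,
    "SelfAttention": SelfAttention, "FeedForward": FeedForward,
    "device": "cuda", "dtype": torch.float16,   # mirrors the constructor defaults
}

model = GPTModel.neox_init(config)            # small_init on most modules, wang_init on the last layer
logits = model.get_logits(torch.zeros(1, 16, dtype=torch.long, device="cuda"))

model.save("ckpt_dir")                        # one b{i}.pt file per tensor plus an m.pt name->path index
reloaded = GPTModel.load(config, path="ckpt_dir")   # SplitCheckpoint reads the split files; no_init presumably skips the random init
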
@@ -79,7 +79,7 @@ def _attn(query, key, value, causal_mask, masked_bias,
 class SelfAttention(nn.Module):
     # Code copied from HF, might want to sanity check later.
     def __init__(self, hidden_dim, n_head, device, dtype):
-        super(SelfAttention, self).__init__()
+        super().__init__(self)
         max_positions = 2049
         bias = torch.tril(torch.ones((max_positions, max_positions), dtype=torch.uint8, requires_grad=False)).view(
             1, 1, max_positions, max_positions).bool()
@@ -143,7 +143,7 @@ class SelfAttention(nn.Module):
 class FeedForward(nn.Module):
     def __init__(self, dim, hidden_dim, activation, device, dtype):
-        super(FeedForward, self).__init__()
+        super().__init__(self)
         self.ff1 = nn.Linear(dim, hidden_dim, device=device, dtype=dtype)
         self.ff2 = nn.Linear(hidden_dim, dim, device=device, dtype=dtype)
         self.activation = activation
@@ -159,14 +159,16 @@ class FeedForward(nn.Module):
 class GPTJLayer(nn.Module):
     def __init__(self, attn, ff, hidden_dim, n_head, eps, activation, device, dtype):
-        super(GPTJLayer, self).__init__()
+        super().__init__(self)
         self.hidden_dim = hidden_dim
         self.ln_preattn = nn.LayerNorm(hidden_dim, eps=eps, device=device, dtype=dtype)
         self.ff = ff(dim=hidden_dim, hidden_dim=hidden_dim*4, activation=activation, device=device, dtype=dtype)
         self.attn = attn(hidden_dim=hidden_dim, n_head=n_head, device=device, dtype=dtype)
+        self.tick = True

-    def forward(self, x, hypernetwork=None, act_ck=False):
+    def forward(self, x, layer_id=None, hypernetwork=None, act_ck=False, diff_hypernets=False, interleaving_layers=False, every_n=5):
         residual = x

         if act_ck:
             x = ck(self.ln_preattn, x)
             attn_out = ck(self.attn, x)
@@ -175,17 +177,35 @@ class GPTJLayer(nn.Module):
             x = self.ln_preattn(x)
             attn_out = self.attn(x)

-        ff_out = self.ff(x, act_ck)
-        x = residual + attn_out + ff_out
-
         if hypernetwork:
-            hyper_out = hypernetwork(x)
+            if diff_hypernets:
+                if interleaving_layers and layer_id % every_n == 0:
+                    if self.tick:
+                        hyper_out = hypernetwork[0](x)
+                        self.tick = False
+                    else:
+                        hyper_out = hypernetwork[1](x)
+                        self.tick = True
+                elif layer_id % every_n == 0:
+                    hyper_out = hypernetwork[(layer_id // every_n) - 1](x)
+            else:
+                if layer_id % every_n == 0:
+                    hyper_out = hypernetwork(x)
+
+        ff_out = self.ff(x, act_ck)
+
+        #order of addition matters, i had no idea... fixed a bug here.
+        x = attn_out + ff_out + residual
+        #x = residual + attn_out + ff_out -> doesn't match.
+
+        if hypernetwork and layer_id % every_n == 0:
             x = x + hyper_out

         return x


-class GPTModel(nn.Module):
+class GPTJModel(nn.Module):
     def __init__(self, hidden_dim, n_layer, n_head, vocab_dim, eps, activation, Layer, device, dtype):
-        super(GPTModel, self).__init__()
+        super().__init__(self)
         self.n_layer = n_layer
         self.hidden_dim = hidden_dim
         self.vocab_embed = nn.Embedding(vocab_dim, self.hidden_dim, device=device, dtype=dtype)
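
The "order of addition matters" comment above is consistent with float16 addition not being associative: the model runs in torch.float16 by default, so summing the residual, attention and feed-forward outputs in a different order can round differently and stop matching a reference implementation bit-for-bit. A tiny standalone demonstration (the values are chosen only to expose the rounding):

import torch

# In float16 the spacing between representable numbers around 2048 is 2.0,
# so the grouping of the additions changes the rounded result.
a = torch.tensor([2048.0], dtype=torch.float16)
b = torch.tensor([1.0], dtype=torch.float16)
c = torch.tensor([1.0], dtype=torch.float16)

print((a + b) + c)   # tensor([2048.], dtype=torch.float16)
print(a + (b + c))   # tensor([2050.], dtype=torch.float16)
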
@@ -194,25 +214,6 @@ class GPTModel(nn.Module):
         self.lm_head = nn.Linear(hidden_dim, vocab_dim, bias=True)
         for _ in range(n_layer):
             self.layers.append(Layer(attn=SelfAttention, ff=FeedForward, hidden_dim=hidden_dim, n_head=n_head, eps=eps, activation=activation, device=device, dtype=dtype))

-    #TODO: Decouple more, maybe even init everything here, not sure. Not modular enough yet.
-    #TODO: Do we want to pass a config object everywhere? I don't exactly like that but passing a lot of variables is a bit ugly too.
-    def _init_weights(self, module):
-        """Initialize the weights."""
-        if isinstance(module, nn.Linear):
-            module.weight.data.normal_(mean=0.0, std=0.02)
-            if module.bias is not None:
-                module.bias.data.zero_()
-        elif isinstance(module, nn.Embedding):
-            module.weight.data.normal_(mean=0.0, std=0.02)
-        elif isinstance(module, nn.LayerNorm):
-            module.bias.data.zero_()
-            module.weight.data.fill_(1.0)
-
-        for name, p in module.named_parameters():
-            if ("ff2" in name or "out_proj" in name) and "weight" in name:
-                # Special Scaled Initialization --> There are 2 Layer Norms per Transformer Block
-                p.data.normal_(mean=0.0, std=(0.02 / math.sqrt(2 * self.n_layer)))
-
     def forward(self, x, hypernetwork=None, act_ck=False):
         x = self.get_embeds(x, hypernetwork=hypernetwork, act_ck=act_ck)
...
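
For a rough sense of scale, the initializations touched by this commit give noticeably different standard deviations: the removed _init_weights uses the GPT-2 style 0.02, scaled by 1/sqrt(2*n_layer) for ff2/out_proj weights, while the new file's small_init uses sqrt(2/(5*dim)) and wang_init uses 2/(n_layer*sqrt(dim)). A quick comparison with example sizes (the hidden_dim and n_layer below are illustrative, not taken from this repo's configs):

import math

hidden_dim, n_layer = 4096, 28   # assumed example sizes, purely for illustration

gpt2_scaled_std = 0.02 / math.sqrt(2 * n_layer)        # removed _init_weights, ff2/out_proj weights
small_init_std  = math.sqrt(2 / (5 * hidden_dim))      # small_init_method(hidden_dim)
wang_init_std   = 2 / n_layer / math.sqrt(hidden_dim)  # wang_init_method(n_layer, hidden_dim)

print(round(gpt2_scaled_std, 5), round(small_init_std, 5), round(wang_init_std, 5))
# 0.00267 0.00988 0.00112
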
@@ -4,6 +4,7 @@ import torch
 from torch import nn
 import os

+#Having common BaseLM functionality in this class instead of the torch LM itself makes sense.
 class BaseLM(nn.Module):
     def __init__(self, config=None, lm=None):
         self.config = config
@@ -57,6 +58,8 @@ class BaseLM(nn.Module):
     def save(self, path):
         if self.lm is None:
             print("No LM object to save. Please first init a model.")
+            return
+
         try: os.mkdir(path)
         except: pass
         checkpoint = {}
...
@@ -6,7 +6,6 @@ except ImportError:
 from pathlib import Path
 import os

 def no_init(loading_code):
     def dummy(self):
         return
...
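
no_init is the helper that GPTModel.load wraps around the constructor (no_init(lambda: cls(**config))), presumably so parameters are not randomly initialized right before a checkpoint overwrites them. Its body is mostly outside this hunk; the sketch below shows the usual way such a helper is written (a hypothetical re-implementation for illustration, not necessarily this repo's exact code): temporarily swap the reset_parameters methods for a no-op while the constructor runs.

# Hypothetical sketch of a no-init wrapper; the name and module list are assumptions.
import torch
from torch import nn

def no_init_sketch(loading_code):
    def dummy(self):
        return  # skip the usual random initialization

    modules = [nn.Linear, nn.Embedding, nn.LayerNorm]
    originals = {m: m.reset_parameters for m in modules}
    for m in modules:
        m.reset_parameters = dummy
    try:
        result = loading_code()   # parameters are allocated but left uninitialized
    finally:
        for m, fn in originals.items():
            m.reset_parameters = fn
    return result

layer = no_init_sketch(lambda: nn.Linear(8, 8))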