Commit 67259a5b authored by novelailab

change name

parent 02d93202
-from lm_arch import utils
+from basedformer import utils
 import math
 import torch
 from torch import nn
-from lm_arch import gptj
+from basedformer import gptj
 import os
 #Having common BaseLM functionality in this class instead of the torch LM itself makes sense.
@@ -77,9 +77,7 @@ def load_gpt_j(path="models/6b", state_dict=None):
         "n_head": 16,
         "hidden_dim": 4096,
         "vocab_dim": 50400,
-        "eps": 1e-5,
-        "activation": gptj.gelu_new,
-        "Layer": gptj.GPTJLayer
+        "eps": 1e-5
     }
     model = BaseLM.load(gptj.GPTJModel, config, path, state_dict)
     return model
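A minimal loading sketch under the new package name, for orientation only. It assumes lm_base.load_gpt_j keeps the signature shown in the hunk above and that a converted checkpoint actually exists at the default "models/6b" path; nothing beyond that is taken from the diff.

    from basedformer import lm_base

    # Hypothetical usage; "models/6b" is the default checkpoint path shown in the hunk above.
    model = lm_base.load_gpt_j(path="models/6b")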
@@ -3,7 +3,6 @@ import torch
 import torch.nn as nn
 import torch.nn.functional as F
 from pathlib import Path
-from lm_train import utils
 from torch.utils import data
 import math
 import sys
@@ -13,17 +12,18 @@ import wandb
 import numpy as np
 from torch.utils.checkpoint import checkpoint as ck
 from math import log2, ceil
-from lm_arch import gptj, lm_base, optimizer
-from lm_arch import util
+from basedformer import gptj, lm_base, optimizer
+from basedformer.utils import *
 def _init_weights(module):
+    """Initialize the weights."""
     if isinstance(module, nn.Linear):
         module.weight.data.normal_(mean=0.0, std=0.02)
         if module.bias is not None:
             module.bias.data.zero_()
     elif isinstance(module, nn.Embedding):
         module.weight.data.normal_(mean=0.0, std=0.02)
     elif isinstance(module, nn.LayerNorm):
         module.bias.data.zero_()
         module.weight.data.fill_(1.0)
@@ -145,25 +145,12 @@ class HyperNetworkSingle(nn.Module):
         x = x.mul(torch.sigmoid(x))
         return x.bfloat16()
-model_config = {
-    "n_layer": 12,
-    "n_head": 12,
-    "hidden_dim": 768,
-    "vocab_dim": 50400,
-    "eps": 1e-5,
-    "activation": gelu_new,
-    "Layer": GPTLayer
-}
 model_config = {
     "n_layer": 28,
     "n_head": 16,
     "hidden_dim": 4096,
     "vocab_dim": 50400,
     "eps": 1e-5,
-    "activation": gelu_new,
-    "Layer": GPTLayer
 }
 # we need 250 batch size to train the small GPT.
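The _init_weights helper kept above is the usual GPT-style initializer (normal weights with std 0.02, zeroed biases, unit LayerNorm gain), and the remaining model_config matches GPT-J-6B's published shape (28 layers, 16 heads, 4096 hidden, 50400 vocab). A self-contained sketch of how such an initializer is typically applied via torch's nn.Module.apply, using a toy module rather than anything from this repository:

    import torch.nn as nn

    toy = nn.Sequential(
        nn.Embedding(50400, 768),
        nn.Linear(768, 768),
        nn.LayerNorm(768),
    )
    # nn.Module.apply visits every submodule recursively, so each Linear,
    # Embedding and LayerNorm above gets reinitialized by _init_weights.
    toy.apply(_init_weights)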
...
-from lm_arch import lm_base
-from lm_arch.utils import *
+from basedformer import lm_base
+from basedformer.utils import *
 import time
 import torch
...
@@ -5,7 +5,7 @@ from time import perf_counter, perf_counter_ns
 import numpy as np
 from tqdm import tqdm
 from contextlib import contextmanager
-from lm_arch.hypernet import *
+from basedformer.hypernet import *
 import sys
 #replicating timeit magic function of ipython
 def timeit(func, r=1, n=5, quiet=False, function=None, do_tqdm=False, first=True):
...
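The body of the timeit helper is not part of this diff; a hedged usage sketch based only on the signature above, assuming r and n mean rounds and calls per round in the spirit of IPython's %timeit:

    import torch
    from basedformer.utils import timeit

    def bench():
        a = torch.randn(1024, 1024)
        b = torch.randn(1024, 1024)
        return a @ b

    # Hypothetical call; only the signature comes from the diff, the argument
    # semantics and output format are assumptions.
    timeit(bench, r=3, n=5)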
+import setuptools
+setuptools.setup(
+    name="basedformer",
+    version="0.1",
+    author="Eren Doğan",
+    description="Modular and minimal transformer codebase for experimentation.",
+    packages=setuptools.find_packages(),
+    include_package_data=True,
+    python_requires='>=3.7',
+    package_data={'basedformer': ['*.json']},
+    install_requires=['dotmap',
+                      'numpy']
+)
\ No newline at end of file
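With setup.py added, a quick smoke test one might run after the rename, assuming the package has been installed from the repository root (for example with an editable install, pip install -e .):

    # Hypothetical check that the new import root resolves where lm_arch used to.
    from basedformer import gptj, lm_base, utils

    print(gptj.__file__)
    print(lm_base.__file__)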
@@ -6,7 +6,7 @@ import torch.optim as optim
 from pathlib import Path
 from lm_train import utils
 from torch.utils import data
-from lm_arch import lm_base, optimizer
+from basedformer import lm_base, optimizer
 import yaml
 import sys
 from tqdm import tqdm
...