Commit 67259a5b authored by novelailab

change name

parent 02d93202
from lm_arch import utils
from basedformer import utils
import math
import torch
from torch import nn
from lm_arch import gptj
from basedformer import gptj
import os
# Having common BaseLM functionality in this class instead of the torch LM itself makes sense.
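An editor's sketch of the split that the comment above describes, assuming (this is an illustration, not the actual lm_base code) that the wrapper only owns checkpoint/config plumbing while the torch module stays a plain nn.Module; the class name and signature are hypothetical, chosen to match the BaseLM.load call visible further down in this diff.
import torch

class BaseLMSketch:
    """Hypothetical stand-in for lm_base.BaseLM; names and behaviour are assumptions."""
    @classmethod
    def load(cls, model_class, config, path=None, state_dict=None):
        # Build the bare torch module from its config dict ...
        model = model_class(config)
        # ... and keep the checkpoint handling here instead of inside the module.
        if state_dict is None and path is not None:
            state_dict = torch.load(path, map_location="cpu")
        if state_dict is not None:
            model.load_state_dict(state_dict, strict=False)
        return model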
@@ -77,9 +77,7 @@ def load_gpt_j(path="models/6b", state_dict=None):
"n_head": 16,
"hidden_dim": 4096,
"vocab_dim": 50400,
"eps": 1e-5,
"activation": gptj.gelu_new,
"Layer": gptj.GPTJLayer
"eps": 1e-5
}
model = BaseLM.load(gptj.GPTJModel, config, path, state_dict)
return model
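A hedged usage sketch (editor's addition, not part of the commit): the checkpoint layout under models/6b and the model(tokens) call convention are assumptions inferred from the surrounding code, not shown in this diff.
# Requires the 6B checkpoint to be present under models/6b.
model = load_gpt_j(path="models/6b").eval()      # assumed to behave like an nn.Module
tokens = torch.zeros(1, 16, dtype=torch.long)    # dummy input ids
with torch.no_grad():
    logits = model(tokens)                       # assumed forward signature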
@@ -3,7 +3,6 @@ import torch
import torch.nn as nn
import torch.nn.functional as F
from pathlib import Path
from lm_train import utils
from torch.utils import data
import math
import sys
@@ -13,17 +12,18 @@ import wandb
import numpy as np
from torch.utils.checkpoint import checkpoint as ck
from math import log2, ceil
from lm_arch import gptj, lm_base, optimizer
from lm_arch import util
from basedformer import gptj, lm_base, optimizer
from basedformer.utils import *
def _init_weights(module):
"""Initialize the weights."""
if isinstance(module, nn.Linear):
module.weight.data.normal_(mean=0.0, std=0.02)
if module.bias is not None:
module.bias.data.zero_()
elif isinstance(module, nn.Embedding):
module.weight.data.normal_(mean=0.0, std=0.02)
elif isinstance(module, nn.LayerNorm):
module.bias.data.zero_()
module.weight.data.fill_(1.0)
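For context, an editor's usage note: _init_weights above is written for nn.Module.apply, which walks a module tree and calls the function on every submodule (this relies on the torch.nn import already present in this file).
probe = nn.Sequential(nn.Linear(768, 768), nn.LayerNorm(768))
probe.apply(_init_weights)   # re-initializes the Linear and LayerNorm children in place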
@@ -145,25 +145,12 @@ class HyperNetworkSingle(nn.Module):
x = x.mul(torch.sigmoid(x))
return x.bfloat16()
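An editor's aside on the activation above: x * sigmoid(x) is SiLU (swish), so the fragment matches torch.nn.functional.silu up to the bfloat16 cast; a quick sanity check in float32:
_x = torch.randn(4, 8)
assert torch.allclose(_x.mul(torch.sigmoid(_x)), torch.nn.functional.silu(_x))  # same activation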
model_config = {
"n_layer": 12,
"n_head": 12,
"hidden_dim": 768,
"vocab_dim": 50400,
"eps": 1e-5,
"activation": gelu_new,
"Layer": GPTLayer
}
model_config = {
"n_layer": 28,
"n_head": 16,
"hidden_dim": 4096,
"vocab_dim": 50400,
"eps": 1e-5,
"activation": gelu_new,
"Layer": GPTLayer
}
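A rough back-of-the-envelope check (editor's arithmetic, ignoring biases and LayerNorm parameters) that the second config above corresponds to the ~6B checkpoint loaded earlier:
h, n_layer, vocab = 4096, 28, 50400
per_layer = 12 * h * h                         # ~4h^2 attention + ~8h^2 MLP weights
total = n_layer * per_layer + 2 * vocab * h    # plus token embedding and lm_head
print(total / 1e9)                             # ~6.05 billion parameters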
# We need a batch size of 250 to train the small GPT.
......
from lm_arch import lm_base
from lm_arch.utils import *
from basedformer import lm_base
from basedformer.utils import *
import time
import torch
......
@@ -5,7 +5,7 @@ from time import perf_counter, perf_counter_ns
import numpy as np
from tqdm import tqdm
from contextlib import contextmanager
from lm_arch.hypernet import *
from basedformer.hypernet import *
import sys
# Replicating the timeit magic function of IPython.
def timeit(func, r=1, n=5, quiet=False, function=None, do_tqdm=False, first=True):
......
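A hedged usage sketch for the timeit helper above (editor's addition): only its signature is visible in this diff, so the print/return behaviour is assumed to mirror IPython's %timeit; the benchmarked callable is arbitrary.
import torch
bench = lambda: torch.randn(1024, 1024) @ torch.randn(1024, 1024)
timeit(bench, r=3, n=10, do_tqdm=False)   # 3 rounds of 10 calls each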
import setuptools
setuptools.setup(
name="basedformer",
version="0.1",
author="Eren Doğan",
description="Modular and minimal transformer codebase for experimentation.",
packages=setuptools.find_packages(),
include_package_data=True,
python_requires='>=3.7',
package_data={'basedformer': ['*.json']},
install_requires=['dotmap',
'numpy']
)
\ No newline at end of file
@@ -6,7 +6,7 @@ import torch.optim as optim
from pathlib import Path
from lm_train import utils
from torch.utils import data
from lm_arch import lm_base, optimizer
from basedformer import lm_base, optimizer
import yaml
import sys
from tqdm import tqdm
......