novelai-storage / Basedformer / Commits

Commit c8d491e1
Authored May 11, 2022 by novelailab

eval harness works

Parent: 44751bc6

Showing 8 changed files with 158 additions and 421 deletions (+158, -421)
Files changed:

  .gitignore                    +3    -1
  basedformer/__init__.py       +8    -0
  basedformer/gptj.py           +4    -2
  basedformer/lm_utils.py       +40   -9
  basedformer/sampling.py       +0    -2
  eval_tasks/eval_adapter.py    +80   -387
  run_pyfra.py                  +17   -12
  scripts/comparehf.py          +6    -8
.gitignore

@@ -132,4 +132,6 @@ models
 gptjconvert
 j6b_vanilla
 wandb
-*.map
\ No newline at end of file
+*.map
+pretrained
+lm_cache
\ No newline at end of file
basedformer/__init__.py

+from . import gptj
+
+MODEL_MAP = {
+    "gptj": (gptj.GPTJModel, gptj.GPTJConfig),
+}
+
+def get_model(model_name: str):
+    return MODEL_MAP[model_name]
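For illustration only (not part of the commit), the new registry can be used like this; a minimal sketch:

import basedformer

# Resolve the (model class, config class) pair registered under "gptj" in MODEL_MAP.
# An unregistered name raises KeyError.
model_class, config_class = basedformer.get_model("gptj")
print(model_class.__name__, config_class.__name__)  # -> GPTJModel GPTJConfig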
basedformer/gptj.py

@@ -13,7 +13,7 @@ except ImportError:
 import os
 from pathlib import Path
 import math
-from basedformer import lm_base
+from basedformer import lm_utils
 from dataclasses import dataclass

 def fixed_pos_embedding(dim=None, seq_len=None, x=None):
@@ -192,6 +192,7 @@ class GPTJLayer(nn.Module):
 class GPTJModel(nn.Module):
     def __init__(self, config, **kwargs):
         nn.Module.__init__(self)
+        self.config = config
         self.n_layer = config.n_layer
         self.hidden_dim = config.hidden_dim
         self.vocab_embed = nn.Embedding(config.vocab_dim, self.hidden_dim, device=config.device, dtype=config.dtype)
@@ -248,6 +249,7 @@ class GPTJModel(nn.Module):
 class GPTJConfig:
     n_layer: int = 6
     n_head: int = 8
+    n_tokens: int = 2048
     hidden_dim: int = 512
     vocab_dim: int = 50400
     eps: float = 1e-5
@@ -265,5 +267,5 @@ def load_gpt_j(path="models/6b", state_dict=None):
         "eps": 1e-5
     }
     config = GPTJConfig(**config)
-    model = lm_base.load(GPTJModel, config, path)
+    model = lm_utils._load_dict_model(GPTJModel, config, path)
     return model
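A hedged sketch of constructing the config and model with the field added here (n_tokens); the small sizes and the device/dtype values are assumptions for a quick CPU test, not values from the repo:

import torch
from basedformer import gptj, lm_utils

# Illustrative small GPTJConfig; n_tokens is the field introduced in this commit.
config = gptj.GPTJConfig(
    n_layer=2,
    n_head=4,
    n_tokens=2048,
    hidden_dim=256,
    vocab_dim=50400,
    device="cpu",          # assumed field; read as config.device in GPTJModel.__init__
    dtype=torch.float32,   # assumed field; read as config.dtype in GPTJModel.__init__
)

# Build the module without running weight initialization, via lm_utils.no_init.
model = lm_utils.no_init(gptj.GPTJModel, config)
print(sum(p.numel() for p in model.parameters()), "parameters")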
basedformer/lm_base.py → basedformer/lm_utils.py

 from basedformer import utils
+import basedformer
 import math
 import torch
 from torch import nn

@@ -35,7 +36,34 @@ def no_init(model_class, config):
     model = utils.no_init(lambda: model_class(config))
     return model

-def load(model_class, config, path=None, state_dict=None, strict=False):
+def save(model, path):
+    try:
+        os.mkdir(path)
+    except: pass
+    checkpoint = {}
+    for i, x in enumerate(model.state_dict().items()):
+        checkpoint[x[0]] = f"{path}/b{i}.pt"
+        torch.save(x[1], f"{path}/b{i}.pt")
+    torch.save(checkpoint, f"{path}/m.pt")
+
+def load_from_path(config_folder=None, strict=False):
+    config_folder = Path(config_folder)
+    config = _load_config_file(config_folder / "config.json")
+    model_class = basedformer.get_model(config["model_class"])[0]
+    config_class = basedformer.get_model(config["model_class"])[1]
+    model_path = config["model_path"]
+    model_config = config["model_config"]
+    model_config = config_class(**model_config)
+    print(model_config)
+    if model_path == ".":
+        # model_path is the config_folder directory.
+        model_path = config_folder
+    model_path = Path(model_path) / "lm"
+    model = _load_dict_model(model_class, model_config, model_path, strict=strict)
+    return model
+
+def _load_dict_model(model_class, config, path=None, state_dict=None, strict=False):
     # I am kinda sad that we will not have a load function in lm object itself.
     # might be better to add load functions -- actually nope.
     if path:

@@ -45,13 +73,16 @@ def load(model_class, config, path=None, state_dict=None, strict=False):
     model.load_state_dict(state_dict, strict=strict)
     return model

-def save(model, path):
-    try:
-        os.mkdir(path)
-    except: pass
-    checkpoint = {}
-    for i, x in enumerate(model.state_dict().items()):
-        checkpoint[x[0]] = f"{path}/b{i}.pt"
-        torch.save(x[1], f"{path}/b{i}.pt")
-    torch.save(checkpoint, f"{path}/m.pt")
+def _load_config_file(config_file):
+    if not config_file.exists():
+        raise FileNotFoundError(f"Config file not found at {config_file}")
+    with open(config_file) as f:
+        config = json.load(f)
+    return config
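For context, not part of the diff: load_from_path() expects a checkpoint folder with a config.json describing the registered model class, its config kwargs, and where the sharded weights written by save() live. A hypothetical layout and loading call, inferred from the code above (the field values are assumptions):

from basedformer import lm_utils

# Assumed folder layout, inferred from save() / load_from_path():
#   pretrained/gptj-6b/config.json      -- {"model_class": "gptj",
#                                           "model_path": ".",   # "." = weights under <folder>/lm
#                                           "model_config": {"n_layer": 28, "n_head": 16,
#                                                            "hidden_dim": 4096,
#                                                            "vocab_dim": 50400, "eps": 1e-5}}
#   pretrained/gptj-6b/lm/m.pt          -- index mapping parameter names to shard files (written by save())
#   pretrained/gptj-6b/lm/b0.pt, ...    -- one tensor file per parameter

model = lm_utils.load_from_path("pretrained/gptj-6b").cuda().half().eval()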
basedformer/sampling.py

@@ -2,7 +2,6 @@ from basedformer import gptj
 from basedformer.utils import *
 from transformers import AutoTokenizer
 from icecream import ic
-import functorch
 import time
 import sys

@@ -190,7 +189,6 @@ def generate(forward, prompt_tokens, tokens_to_generate=50, ops_list=[{"temp": 0
         "rep_pen": rep_pen,
     }
-    funcnomial = functorch.vmap(func_multinomial, randomness="different")
     for _ in range(tokens_to_generate):
         logits, kv = forward(in_tokens, cache=True, kv=kv)
         logits = logits[:, -1, :] #get the last token in the seq
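The removed functorch.vmap wrapper aside, the decoding loop shown in this hunk follows the usual incremental pattern; a standalone sketch of that pattern (this is not the repository's generate(); only the forward(tokens, cache=True, kv=kv) signature is taken from the diff):

import torch

def sample_loop(forward, prompt_tokens, tokens_to_generate=50, temp=0.8):
    # Incremental decoding: run the prompt once, then feed back one token per step
    # while reusing the kv cache, sampling with plain torch.multinomial.
    in_tokens = prompt_tokens
    kv = None
    generated = []
    for _ in range(tokens_to_generate):
        logits, kv = forward(in_tokens, cache=True, kv=kv)
        logits = logits[:, -1, :]                              # last position only
        probs = torch.softmax(logits / temp, dim=-1)
        next_token = torch.multinomial(probs, num_samples=1)   # no functorch.vmap needed
        generated.append(next_token)
        in_tokens = next_token
    return torch.cat(generated, dim=-1)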
eval_tasks/eval_adapter.py

This diff is collapsed and not shown (+80, -387 per the summary above).
run_pyfra.py

@@ -24,20 +24,25 @@ config_obj.create_service(overwrite=True)
 remote = config_obj.get_pyfra_remote()
 env1 = remote.env('noname', python_version=None)
 path = env1.path('/home/xuser/diffusionstorage/workspace/kuru/basedformer')
-env1.sh('pip install /home/xuser/hugessd/pytorch/torch-1.10.1+cu113-cp38-cp38-linux_x86_64.whl')
-env1.sh('pip install einops numpy')
-env1.sh('pip install tqdm')
-env1.sh('pip install /home/xuser/diffusionstorage/workspace/finetune/pokepls/transformers-repo')
-env1.sh('pip3 install einops==0.4.1 pyyaml wandb')
-env1.sh('wandb login 21a9442d42a35e15ce421f2b702ec58508b9adc4')
-env1.sh('pip3 install dotmap icecream')
-path.sh("pip3 install --editable .")
-#path.sh("pip3 uninstall torch")
-#path.sh("pip3 install torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cu113")
+if False:
+    env1.sh('pip install /home/xuser/hugessd/pytorch/torch-1.10.1+cu113-cp38-cp38-linux_x86_64.whl')
+    env1.sh('pip install einops numpy')
+    env1.sh('pip install tqdm')
+    env1.sh('pip install /home/xuser/diffusionstorage/workspace/finetune/pokepls/transformers-repo')
+    env1.sh('pip3 install einops==0.4.1 pyyaml wandb')
+    env1.sh('wandb login 21a9442d42a35e15ce421f2b702ec58508b9adc4')
+    env1.sh('pip3 install dotmap icecream')
+    path.sh("pip3 install --editable .")
+    #path.sh("pip3 uninstall torch")
+    #path.sh("pip3 install torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cu113")

 with always_rerun():
-    if bash:
-        path.sh("bash")
+    if True:
+        path.sh("python3 ../lm-evaluation-harness/main.py --model basedformer --batch_size 8 --model_args pretrained=/home/xuser/diffusionstorage/workspace/kuru/basedformer/pretrained/gptj-6b --device 0 --tasks lambada")
+        #path.sh("python3 ../lm-evaluation-harness/main.py --batch_size 8")
     else:
         print(f"Running {sys.argv[1]}")
         path.sh(f'python3 {sys.argv[1]}')
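The same evaluation can be launched without pyfra; a minimal sketch using only the standard library, with the arguments copied from the diff (the lm-evaluation-harness checkout and the pretrained path are assumed to exist on the target machine):

import subprocess

subprocess.run(
    [
        "python3", "../lm-evaluation-harness/main.py",
        "--model", "basedformer",
        "--batch_size", "8",
        "--model_args", "pretrained=/home/xuser/diffusionstorage/workspace/kuru/basedformer/pretrained/gptj-6b",
        "--device", "0",
        "--tasks", "lambada",
    ],
    check=True,
)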
scripts/comparehf.py

 from basedformer import gptj
 from basedformer.utils import *
+import basedformer.lm_utils as lmu
 import time
 import torch

@@ -8,11 +9,7 @@ import numpy as np
 from tqdm import tqdm
 from contextlib import contextmanager
 import torch.nn.functional as F
-from transformers import (
-    AutoModelForCausalLM,
-    GPTNeoForCausalLM,
-    AutoConfig,
-)
+from transformers import GPTNeoForCausalLM

 #replicating timeit magic function of ipython
 def timeit(func, r=1, n=5, quiet=False, function=None, do_tqdm=False, first=True):
     precision = 'ns'

@@ -67,10 +64,11 @@ def timeit(func, r=1, n=5, quiet=False, function=None, do_tqdm=False, first=True
 with torch.no_grad():
-    based_model = gptj.load_gpt_j().cuda().half().eval()
-    print("Loaded based model")
-    hf_model = no_init(lambda: AutoModelForCausalLM.from_pretrained('/home/xuser/models/j6b_ckpt_14001')).cuda().half().eval()
+    hf_model = no_init(lambda: GPTNeoForCausalLM.from_pretrained('/home/xuser/models/j6b_ckpt_14001')).cuda().half().eval()
     print("Loaded hf model")
+    path = "/home/xuser/diffusionstorage/workspace/kuru/basedformer/pretrained/gptj-6b"
+    based_model = lmu.load_from_path(path).cuda().half().eval()
+    print("Loaded based model")

     x = torch.randint(0, 50256, (1, 2048)).cuda().long()
     assert torch.allclose(hf_model.transformer.wte(x), based_model.vocab_embed(x))
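Building on the assertion above, a hedged extension of the comparison that reports the worst-case difference and checks against an explicit fp16 tolerance (the atol value is an assumption, not from the script):

hf_emb = hf_model.transformer.wte(x)
based_emb = based_model.vocab_embed(x)

# Report the largest absolute deviation between the two embedding lookups.
max_diff = (hf_emb - based_emb).abs().max().item()
print(f"max abs difference: {max_diff:.3e}")

# Exact allclose defaults are strict for fp16; compare under an explicit tolerance.
assert torch.allclose(hf_emb, based_emb, atol=1e-3), "embedding mismatch"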