novelai-storage / Basedformer / Commits / c8d491e1

Commit c8d491e1, authored May 11, 2022 by novelailab
Parent: 44751bc6

eval harness works

Showing 8 changed files with 158 additions and 421 deletions (+158 −421).
.gitignore                   +3   −1
basedformer/__init__.py      +8   −0
basedformer/gptj.py          +4   −2
basedformer/lm_utils.py      +40  −9
basedformer/sampling.py      +0   −2
eval_tasks/eval_adapter.py   +80  −387
run_pyfra.py                 +17  −12
scripts/comparehf.py         +6   −8
.gitignore  (+3 −1)

@@ -132,4 +132,6 @@ models
 gptjconvert
 j6b_vanilla
 wandb
-*.map
\ No newline at end of file
+*.map
+pretrained
+lm_cache
\ No newline at end of file
basedformer/__init__.py  (+8 −0)

+from . import gptj
+
+MODEL_MAP = {
+    "gptj": (gptj.GPTJModel, gptj.GPTJConfig),
+}
+
+def get_model(model_name: str):
+    return MODEL_MAP[model_name]
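For reference, a minimal usage sketch of the new registry (not part of the commit): get_model returns the (model class, config class) pair registered under a name, so callers can build a model without importing gptj directly. The defaults mentioned in the comment come from GPTJConfig in the next file; everything else here is illustrative.

import basedformer

# look up the (model class, config class) pair registered under "gptj"
model_class, config_class = basedformer.get_model("gptj")
config = config_class()       # GPTJConfig defaults: n_layer=6, n_head=8, n_tokens=2048, hidden_dim=512, ...
model = model_class(config)   # constructs a GPTJModel from that config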
basedformer/gptj.py  (+4 −2)

@@ -13,7 +13,7 @@ except ImportError:
 import os
 from pathlib import Path
 import math
-from basedformer import lm_base
+from basedformer import lm_utils
 from dataclasses import dataclass

 def fixed_pos_embedding(dim=None, seq_len=None, x=None):
@@ -192,6 +192,7 @@ class GPTJLayer(nn.Module):
 class GPTJModel(nn.Module):
     def __init__(self, config, **kwargs):
         nn.Module.__init__(self)
+        self.config = config
         self.n_layer = config.n_layer
         self.hidden_dim = config.hidden_dim
         self.vocab_embed = nn.Embedding(config.vocab_dim, self.hidden_dim, device=config.device, dtype=config.dtype)
@@ -248,6 +249,7 @@ class GPTJModel(nn.Module):
 class GPTJConfig:
     n_layer: int = 6
     n_head: int = 8
+    n_tokens: int = 2048
     hidden_dim: int = 512
     vocab_dim: int = 50400
     eps: float = 1e-5
@@ -265,5 +267,5 @@ def load_gpt_j(path="models/6b", state_dict=None):
         "eps": 1e-5
     }
     config = GPTJConfig(**config)
-    model = lm_base.load(GPTJModel, config, path)
+    model = lm_utils._load_dict_model(GPTJModel, config, path)
     return model
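A short sketch of what the two small additions buy downstream (illustrative, not from the commit; it assumes GPTJConfig's remaining fields such as device and dtype keep workable defaults): the model now remembers its own config, so code that only holds the module can still read hyperparameters such as the new n_tokens context length.

from basedformer import gptj

config = gptj.GPTJConfig()      # n_tokens defaults to 2048 per the diff above
model = gptj.GPTJModel(config)

# self.config = config is the new line in __init__, so this now works:
assert model.config.n_tokens == 2048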
basedformer/lm_base.py → basedformer/lm_utils.py  (renamed, +40 −9)

 from basedformer import utils
+import basedformer
 import math
 import torch
 from torch import nn
...
@@ -35,7 +36,34 @@ def no_init(model_class, config):
     model = utils.no_init(lambda: model_class(config))
     return model

-def load(model_class, config, path=None, state_dict=None, strict=False):
+def save(model, path):
+    try: os.mkdir(path)
+    except: pass
+    checkpoint = {}
+    for i, x in enumerate(model.state_dict().items()):
+        checkpoint[x[0]] = f"{path}/b{i}.pt"
+        torch.save(x[1], f"{path}/b{i}.pt")
+    torch.save(checkpoint, f"{path}/m.pt")
+
+def load_from_path(config_folder=None, strict=False):
+    config_folder = Path(config_folder)
+    config = _load_config_file(config_folder / "config.json")
+    model_class = basedformer.get_model(config["model_class"])[0]
+    config_class = basedformer.get_model(config["model_class"])[1]
+    model_path = config["model_path"]
+    model_config = config["model_config"]
+    model_config = config_class(**model_config)
+    print(model_config)
+    if model_path == ".":
+        # model_path is the config_folder directory.
+        model_path = config_folder
+    model_path = Path(model_path) / "lm"
+    model = _load_dict_model(model_class, model_config, model_path, strict=strict)
+    return model
+
+def _load_dict_model(model_class, config, path=None, state_dict=None, strict=False):
     # I am kinda sad that we will not have a load function in lm object itself.
     # might be better to add load functions -- actually nope.
     if path:
...
@@ -45,13 +73,16 @@ def load(model_class, config, path=None, state_dict=None, strict=False):
     model.load_state_dict(state_dict, strict=strict)
     return model

-def save(model, path):
-    try: os.mkdir(path)
-    except: pass
-    checkpoint = {}
-    for i, x in enumerate(model.state_dict().items()):
-        checkpoint[x[0]] = f"{path}/b{i}.pt"
-        torch.save(x[1], f"{path}/b{i}.pt")
-    torch.save(checkpoint, f"{path}/m.pt")
+def _load_config_file(config_file):
+    if not config_file.exists():
+        raise FileNotFoundError(f"Config file not found at {config_file}")
+    with open(config_file) as f:
+        config = json.load(f)
+    return config
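To make the new checkpoint layout concrete, a hedged usage sketch (not from the commit): save() writes one b{i}.pt shard per state_dict entry plus an m.pt index, and load_from_path() expects a folder containing a config.json whose model_class key is resolved through basedformer.get_model and whose model_path of "." points at an lm/ subfolder next to the config. The exact model_config keys below are illustrative.

import json, os
from pathlib import Path
from basedformer import gptj, lm_utils

folder = Path("pretrained/gptj-small")
os.makedirs(folder / "lm", exist_ok=True)

# write shards b0.pt, b1.pt, ... and the m.pt index under <folder>/lm
model = gptj.GPTJModel(gptj.GPTJConfig())
lm_utils.save(model, str(folder / "lm"))

# config.json consumed by load_from_path(); "." means "weights live under <folder>/lm"
config = {
    "model_class": "gptj",
    "model_path": ".",
    "model_config": {"n_layer": 6, "n_head": 8, "hidden_dim": 512},
}
(folder / "config.json").write_text(json.dumps(config))

reloaded = lm_utils.load_from_path(folder)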
basedformer/sampling.py  (+0 −2)

@@ -2,7 +2,6 @@ from basedformer import gptj
 from basedformer.utils import *
 from transformers import AutoTokenizer
 from icecream import ic
-import functorch
 import time
 import sys
...
@@ -190,7 +189,6 @@ def generate(forward, prompt_tokens, tokens_to_generate=50, ops_list=[{"temp": 0
            "rep_pen": rep_pen,
        }
-    funcnomial = functorch.vmap(func_multinomial, randomness="different")
    for _ in range(tokens_to_generate):
        logits, kv = forward(in_tokens, cache=True, kv=kv)
        logits = logits[:, -1, :] #get the last token in the seq
...
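The removed funcnomial line had wrapped func_multinomial with functorch.vmap to sample each sequence in the batch independently. A minimal plain-PyTorch sketch of the same effect (an assumption about intent, not the code this commit actually uses): torch.multinomial already samples every row of a [batch, vocab] probability matrix independently, so no vmap is required.

import torch

def sample_multinomial(probs):
    # probs: [batch, vocab]; returns one sampled token id per row
    return torch.multinomial(probs, num_samples=1).squeeze(-1)

probs = torch.softmax(torch.randn(4, 50400), dim=-1)
tokens = sample_multinomial(probs)   # shape [4]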
eval_tasks/eval_adapter.py  (+80 −387)

Removed: the previous adapter, carried over from GPT-NeoX. It contained the best_download patch that restricts eval-harness downloads to the first local rank, the NeoX-era imports (lm_eval.models.gpt2.GPT2LM, lm_eval tasks/evaluator/utils/base, generate_samples_from_prompt, mpu, plus from basedformer import optimizer, utils, gptj, noemblm, gpt2), and a class EvalHarnessAdapter(GPT2LM) driven by neox_args. Its __init__ patched the NeoX tokenizer (encode/decode mapped to tokenize/detokenize), picked the device from local_rank, set _max_length = max_position_embeddings // 2 and _vocab_size = padded_vocab_size, tracked data/model/pipeline-parallel state through mpu, and bound forward_step_fn and generate_samples_from_prompt with functools.partial. Its methods — greedy_until, _loglikelihood_tokens (left-truncation, padding, log-softmax scoring, pipeline-parallel broadcast of the results), _dp_scatter, _dp_gather, _model_call, _model_generate (NotImplementedError), and run_eval (task download on the local main rank only, optional base.CachingLM cache under lm_cache/, results["config"] bookkeeping) — and the module-level run_eval_harness(model, forward_step_fn, neox_args, batch_size=None, eval_tasks=None, num_fewshot=0, bootstrap_iters=2) entry point are all deleted.

Added: a much smaller adapter built directly on lm_eval.base.BaseLM:

import torch
import torch.nn as nn
import transformers

import basedformer.sampling as sampling
from lm_eval.base import BaseLM
from basedformer import gptj


class BasedformerLM(BaseLM):
    def __init__(
        self,
        device="cuda",
        pretrained: nn.Module = None,
        tokenizer=None,
        batch_size=1,
    ):
        super().__init__()

        if device:
            if device not in ["cuda", "cpu"]:
                device = int(device)
            self._device = torch.device(device)
            print(f"Using device '{device}'")
        else:
            print("Device not specified")
            print(f"Cuda Available? {torch.cuda.is_available()}")
            self._device = (
                torch.device("cuda")
                if torch.cuda.is_available()
                else torch.device("cpu")
            )

        # TODO: update this to be less of a hack once subfolder is fixed in HF
        self.gpt2 = pretrained

        # pretrained tokenizer for neo is broken for now so just hard-coding this to gpt2
        self.tokenizer = transformers.AutoTokenizer.from_pretrained(
            "gpt2" if tokenizer is None else tokenizer,
        )

        assert isinstance(
            self.tokenizer,
            (
                transformers.GPT2Tokenizer,
                transformers.GPT2TokenizerFast,
                transformers.T5Tokenizer,
                transformers.T5TokenizerFast,
            ),
        ), "this tokenizer has not been checked for compatibility yet!"

        self.vocab_size = self.tokenizer.vocab_size

        if isinstance(
            self.tokenizer, (transformers.GPT2Tokenizer, transformers.GPT2TokenizerFast)
        ):
            assert self.tokenizer.encode("hello\n\nhello") == [
                31373,
                198,
                198,
                31373,
            ], self.tokenizer.encode("hello\n\nhello")

        # multithreading and batching
        self.batch_size_per_gpu = batch_size  # todo: adaptive batch size

        # TODO: fix multi-gpu
        # gpus = torch.cuda.device_count()
        # if gpus > 1:
        #     self.gpt2 = nn.DataParallel(self.gpt2)

    @property
    def eot_token_id(self):
        # we use EOT because end of *text* is more accurate for what we're doing than end of *sentence*
        return self.tokenizer.eos_token_id

    @property
    def max_length(self):
        return self.gpt2.n_tokens

    @property
    def max_gen_toks(self):
        return 256

    @property
    def batch_size(self):
        # TODO: fix multi-gpu
        return self.batch_size_per_gpu  # * gpus

    @property
    def device(self):
        # TODO: fix multi-gpu
        return self._device

    def tok_encode(self, string: str):
        return self.tokenizer.encode(string, add_special_tokens=False)

    def tok_decode(self, tokens):
        return self.tokenizer.decode(tokens)

    def _model_call(self, inps):
        """
        inps: a torch tensor of shape [batch, sequence]
        the size of sequence may vary from call to call

        returns: a torch tensor of shape [batch, sequence, vocab] with the
        logits returned from the model
        """
        with torch.no_grad():
            return self.gpt2(inps)[0][:, :, : self.vocab_size + 1]

    def _model_generate(self, context, max_length, eos_token_id):
        return sampling.generate_greedy(self.gpt2.forward, context, max_length)


if __name__ == '__main__':
    based_model = gptj.load_gpt_j().cuda().half().eval()
    adapter = BasedformerLM(pretrained=based_model, batch_size=1)
    adapter.run_eval(eval_tasks=['lambada', 'piqa'], num_fewshot=0, bootstrap_iters=2)
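run_pyfra.py below invokes the harness with --model basedformer, which implies the sibling lm-evaluation-harness checkout has this class registered; that wiring is not part of this commit. A hedged sketch of what it could look like, assuming the harness of this era keeps a MODEL_REGISTRY dict in lm_eval/models/__init__.py (the import paths and registry entries shown are assumptions):

# lm_eval/models/__init__.py in the separate lm-evaluation-harness checkout (assumed wiring)
from . import gpt2
from eval_tasks.eval_adapter import BasedformerLM  # assumes basedformer's eval_tasks is on sys.path

MODEL_REGISTRY = {
    "hf": gpt2.HFLM,
    "gpt2": gpt2.GPT2LM,
    "basedformer": BasedformerLM,
}

def get_model(model_name):
    return MODEL_REGISTRY[model_name]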
run_pyfra.py  (+17 −12)

@@ -24,20 +24,25 @@ config_obj.create_service(overwrite=True)
 remote = config_obj.get_pyfra_remote()
 env1 = remote.env('noname', python_version=None)
 path = env1.path('/home/xuser/diffusionstorage/workspace/kuru/basedformer')
-env1.sh('pip install /home/xuser/hugessd/pytorch/torch-1.10.1+cu113-cp38-cp38-linux_x86_64.whl')
-env1.sh('pip install einops numpy')
-env1.sh('pip install tqdm')
-env1.sh('pip install /home/xuser/diffusionstorage/workspace/finetune/pokepls/transformers-repo')
-env1.sh('pip3 install einops==0.4.1 pyyaml wandb')
-env1.sh('wandb login 21a9442d42a35e15ce421f2b702ec58508b9adc4')
-env1.sh('pip3 install dotmap icecream')
-path.sh("pip3 install --editable .")
+if False:
+    env1.sh('pip install /home/xuser/hugessd/pytorch/torch-1.10.1+cu113-cp38-cp38-linux_x86_64.whl')
+    env1.sh('pip install einops numpy')
+    env1.sh('pip install tqdm')
+    env1.sh('pip install /home/xuser/diffusionstorage/workspace/finetune/pokepls/transformers-repo')
+    env1.sh('pip3 install einops==0.4.1 pyyaml wandb')
+    env1.sh('wandb login 21a9442d42a35e15ce421f2b702ec58508b9adc4')
+    env1.sh('pip3 install dotmap icecream')
+    path.sh("pip3 install --editable .")
 #path.sh("pip3 uninstall torch")
 #path.sh("pip3 install torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cu113")

 with always_rerun():
-    if bash:
-        path.sh("bash")
+    if True:
+        path.sh("python3 ../lm-evaluation-harness/main.py --model basedformer --batch_size 8 --model_args pretrained=/home/xuser/diffusionstorage/workspace/kuru/basedformer/pretrained/gptj-6b --device 0 --tasks lambada")
+        #path.sh("python3 ../lm-evaluation-harness/main.py --batch_size 8")
     else:
         print(f"Running {sys.argv[1]}")
         path.sh(f'python3 {sys.argv[1]}')
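For comparison, the lambada run above can also be driven in-process rather than through main.py. This sketch reuses the model-loading pattern from the adapter's __main__ block and the evaluator.evaluate call shape visible in the removed run_eval; treat the exact evaluator signature as an assumption about the installed harness version, not code from this commit.

from lm_eval import tasks, evaluator
from basedformer import gptj
from eval_tasks.eval_adapter import BasedformerLM

based_model = gptj.load_gpt_j().cuda().half().eval()
lm = BasedformerLM(pretrained=based_model, batch_size=8)

results = evaluator.evaluate(
    lm=lm,
    task_dict=tasks.get_task_dict(["lambada"]),
    num_fewshot=0,
    bootstrap_iters=2,
)
print(results["results"])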
scripts/comparehf.py  (+6 −8)

 from basedformer import gptj
 from basedformer.utils import *
+import basedformer.lm_utils as lmu
 import time
 import torch
...
@@ -8,11 +9,7 @@ import numpy as np
 from tqdm import tqdm
 from contextlib import contextmanager
 import torch.nn.functional as F
-from transformers import (
-    AutoModelForCausalLM,
-    GPTNeoForCausalLM,
-    AutoConfig,
-)
+from transformers import GPTNeoForCausalLM

 #replicating timeit magic function of ipython
 def timeit(func, r=1, n=5, quiet=False, function=None, do_tqdm=False, first=True):
     precision = 'ns'
...
@@ -67,10 +64,11 @@ def timeit(func, r=1, n=5, quiet=False, function=None, do_tqdm=False, first=True
 with torch.no_grad():
-    based_model = gptj.load_gpt_j().cuda().half().eval()
-    print("Loaded based model")
-    hf_model = no_init(lambda: AutoModelForCausalLM.from_pretrained('/home/xuser/models/j6b_ckpt_14001')).cuda().half().eval()
+    hf_model = no_init(lambda: GPTNeoForCausalLM.from_pretrained('/home/xuser/models/j6b_ckpt_14001')).cuda().half().eval()
     print("Loaded hf model")
+    path = "/home/xuser/diffusionstorage/workspace/kuru/basedformer/pretrained/gptj-6b"
+    based_model = lmu.load_from_path(path).cuda().half().eval()
+    print("Loaded based model")

     x = torch.randint(0, 50256, (1, 2048)).cuda().long()
     assert torch.allclose(hf_model.transformer.wte(x), based_model.vocab_embed(x))
...
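The assert above only checks the embedding layer; a natural follow-up (not in this commit) is to compare full forward passes, where fp16 accumulation noise usually calls for a tolerance and the two heads may disagree on vocab padding. A hedged sketch, assuming hf_model returns a standard causal-LM output and reusing the [0] indexing the eval adapter applies to the basedformer model:

with torch.no_grad():
    hf_logits = hf_model(x).logits          # [1, 2048, hf_vocab]
    based_logits = based_model(x)[0]        # [1, 2048, based_vocab]
    v = min(hf_logits.shape[-1], based_logits.shape[-1])   # compare over the shared vocab slice
    diff = (hf_logits[..., :v].float() - based_logits[..., :v].float()).abs().max()
    print(f"max abs logit difference: {diff:.4f}")
    assert torch.allclose(hf_logits[..., :v].float(), based_logits[..., :v].float(), atol=1e-2)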