Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Support
Keyboard shortcuts
?
Submit feedback
Sign in / Register
Toggle navigation
B
Basedformer
Project overview
Project overview
Details
Activity
Releases
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Locked Files
Issues
0
Issues
0
List
Boards
Labels
Service Desk
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Security & Compliance
Security & Compliance
Dependency List
License Compliance
Packages
Packages
List
Container Registry
Analytics
Analytics
CI / CD
Code Review
Insights
Issues
Repository
Value Stream
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
novelai-storage
Basedformer
Commits
40e90836
Commit
40e90836
authored
May 13, 2022
by
novelailab
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
fix fairseq by taking out eot and newline mapping
parent
aa6444e9
Changes
5
Show whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
327 additions
and
64 deletions
+327
-64
basedformer/models/gptneo.py
basedformer/models/gptneo.py
+1
-1
basedformer/models/opt.py
basedformer/models/opt.py
+310
-0
run_pyfra.py
run_pyfra.py
+11
-6
scripts/comparefairseq.py
scripts/comparefairseq.py
+3
-55
scripts/fairseqport.py
scripts/fairseqport.py
+2
-2
No files found.
basedformer/models/gptneo.py
View file @
40e90836
...
...
@@ -171,7 +171,7 @@ class GPTNeoModel(base_lm.BaseModel):
base_lm
.
BaseModel
.
__init__
(
self
,
user_config
,
**
kwargs
)
self
.
pos_embed
=
nn
.
Embedding
(
self
.
config
.
n_tokens
,
self
.
config
.
hidden_dim
)
self
.
lm_head
=
nn
.
Linear
(
self
.
config
.
hidden_dim
,
self
.
config
.
vocab_dim
,
bias
=
False
)
#bias=False for
fairseq
models
#bias=False for
neo
models
def
get_embeds
(
self
,
x
,
hypernetwork
=
None
,
act_ck
=
False
,
kv
=
None
,
cache
=
False
):
if
kv
is
None
:
...
...
basedformer/models/opt.py
0 → 100644
View file @
40e90836
import
torch
import
torch.nn
as
nn
import
torch.nn.functional
as
F
from
basedformer.utils
import
*
from
torch.utils.checkpoint
import
checkpoint
as
ck
from
einops
import
rearrange
,
repeat
try
:
from
collections.abc
import
MutableMapping
except
ImportError
:
from
collections
import
MutableMapping
import
os
from
pathlib
import
Path
import
math
from
basedformer.models
import
base_lm
from
typing
import
Optional
,
Any
def
make_positions
(
tensor
,
padding_idx
:
int
,
onnx_trace
:
bool
=
False
):
"""Replace non-padding symbols with their position numbers.
Position numbers begin at padding_idx+1. Padding symbols are ignored.
"""
# The series of casts and type-conversions here are carefully
# balanced to both work with ONNX export and XLA. In particular XLA
# prefers ints, cumsum defaults to output longs, and ONNX doesn't know
# how to handle the dtype kwarg in cumsum.
mask
=
tensor
.
ne
(
torch
.
tensor
(
50257
,
requires_grad
=
False
))
.
int
()
return
(
torch
.
cumsum
(
mask
,
dim
=
1
)
.
type_as
(
mask
)
*
mask
)
.
long
()
+
padding_idx
class
SinusoidalPositionalEmbedding
(
nn
.
Module
):
"""This module produces sinusoidal positional embeddings of any length.
Padding symbols are ignored.
"""
def
__init__
(
self
,
embedding_dim
,
padding_idx
,
init_size
=
1024
):
super
()
.
__init__
()
self
.
embedding_dim
=
embedding_dim
self
.
padding_idx
=
padding_idx
if
padding_idx
is
not
None
else
0
self
.
weights
=
SinusoidalPositionalEmbedding
.
get_embedding
(
init_size
,
embedding_dim
,
padding_idx
)
self
.
onnx_trace
=
False
self
.
register_buffer
(
"_float_tensor"
,
torch
.
tensor
(
1.0
,
requires_grad
=
False
)
.
float
())
self
.
max_positions
=
int
(
1e5
)
# print(embedding_dim, padding_idx, init_size)
def
prepare_for_onnx_export_
(
self
):
self
.
onnx_trace
=
True
@
staticmethod
def
get_embedding
(
num_embeddings
:
int
,
embedding_dim
:
int
,
padding_idx
:
Optional
[
int
]
=
None
):
"""Build sinusoidal embeddings.
This matches the implementation in tensor2tensor, but differs slightly
from the description in Section 3.5 of "Attention Is All You Need".
"""
half_dim
=
embedding_dim
//
2
emb
=
math
.
log
(
10000
)
/
(
half_dim
-
1
)
emb
=
torch
.
exp
(
torch
.
arange
(
half_dim
,
dtype
=
torch
.
float
)
*
-
emb
)
emb
=
torch
.
arange
(
num_embeddings
,
dtype
=
torch
.
float
)
.
unsqueeze
(
1
)
*
emb
.
unsqueeze
(
0
)
emb
=
torch
.
cat
([
torch
.
sin
(
emb
),
torch
.
cos
(
emb
)],
dim
=
1
)
.
view
(
num_embeddings
,
-
1
)
if
embedding_dim
%
2
==
1
:
# zero pad
emb
=
torch
.
cat
([
emb
,
torch
.
zeros
(
num_embeddings
,
1
)],
dim
=
1
)
if
padding_idx
is
not
None
:
emb
[
padding_idx
,
:]
=
0
return
emb
def
forward
(
self
,
input
,
incremental_state
:
Optional
[
Any
]
=
None
,
timestep
:
Optional
[
torch
.
Tensor
]
=
None
,
positions
:
Optional
[
Any
]
=
None
,
offset
:
Optional
[
int
]
=
0
):
"""Input is expected to be of size [bsz x seqlen]."""
bspair
=
input
.
shape
bsz
,
seq_len
=
bspair
[
0
],
bspair
[
1
]
max_pos
=
self
.
padding_idx
+
1
+
seq_len
+
offset
# print("max_pos: " + str(max_pos))
if
self
.
weights
is
None
or
max_pos
>
self
.
weights
.
size
(
0
):
# print("recomputing embeddings")
# recompute/expand embeddings if needed
self
.
weights
=
SinusoidalPositionalEmbedding
.
get_embedding
(
max_pos
,
self
.
embedding_dim
,
self
.
padding_idx
+
offset
)
self
.
weights
=
self
.
weights
.
to
(
self
.
_float_tensor
)
if
incremental_state
is
not
None
:
# positions is the same for every token when decoding a single step
pos
=
timestep
.
view
(
-
1
)[
0
]
+
1
if
timestep
is
not
None
else
seq_len
if
self
.
onnx_trace
:
return
(
self
.
weights
.
index_select
(
index
=
self
.
padding_idx
+
pos
+
offset
,
dim
=
0
)
.
unsqueeze
(
1
)
.
repeat
(
bsz
,
1
,
1
)
)
return
self
.
weights
[
self
.
padding_idx
+
pos
+
offset
,
:]
.
expand
(
bsz
,
1
,
-
1
)
positions
=
make_positions
(
input
,
self
.
padding_idx
+
offset
,
onnx_trace
=
self
.
onnx_trace
)
if
self
.
onnx_trace
:
flat_embeddings
=
self
.
weights
.
detach
()
.
index_select
(
0
,
positions
.
view
(
-
1
))
embedding_shape
=
torch
.
cat
(
(
bsz
.
view
(
1
),
seq_len
.
view
(
1
),
torch
.
tensor
([
-
1
],
dtype
=
torch
.
long
))
)
embeddings
=
torch
.
onnx
.
operators
.
reshape_from_tensor_shape
(
flat_embeddings
,
embedding_shape
)
return
embeddings
return
(
self
.
weights
.
index_select
(
0
,
positions
.
view
(
-
1
))
.
view
(
bsz
,
seq_len
,
-
1
)
.
detach
()
)
def
PositionalEmbedding
(
num_embeddings
:
int
,
embedding_dim
:
int
,
padding_idx
:
int
,
):
m
=
SinusoidalPositionalEmbedding
(
embedding_dim
,
padding_idx
,
init_size
=
num_embeddings
+
padding_idx
+
1
,
)
return
m
def
_attn
(
query
,
key
,
value
,
causal_mask
,
masked_bias
,
attention_mask
=
None
,
scale_attn
=
None
,
fp32_attn
=
True
):
if
fp32_attn
:
attn_weights
=
torch
.
matmul
(
query
.
float
(),
key
.
transpose
(
-
1
,
-
2
)
.
float
())
else
:
attn_weights
=
torch
.
matmul
(
query
,
key
.
transpose
(
-
1
,
-
2
))
attn_weights
=
torch
.
where
(
causal_mask
,
attn_weights
,
masked_bias
.
to
(
attn_weights
.
dtype
))
attn_weights
=
attn_weights
/
scale_attn
.
to
(
attn_weights
.
dtype
)
if
attention_mask
is
not
None
:
attn_weights
=
attn_weights
+
attention_mask
attn_weights
=
F
.
softmax
(
attn_weights
,
dim
=-
1
)
attn_weights
=
attn_weights
.
to
(
value
.
dtype
)
attn_output
=
torch
.
matmul
(
attn_weights
,
value
)
.
to
(
value
.
dtype
)
return
attn_output
class
SelfAttention
(
nn
.
Module
):
# Code copied from HF, might want to sanity check later.
def
__init__
(
self
,
config
):
nn
.
Module
.
__init__
(
self
)
self
.
config
=
config
max_positions
=
2049
bias
=
torch
.
tril
(
torch
.
ones
((
max_positions
,
max_positions
),
dtype
=
torch
.
uint8
,
requires_grad
=
False
))
.
view
(
1
,
1
,
max_positions
,
max_positions
)
.
bool
()
self
.
head_dim
=
config
.
hidden_dim
//
config
.
n_head
self
.
rotary_dim
=
self
.
head_dim
//
4
self
.
hidden_dim
=
config
.
hidden_dim
self
.
n_head
=
config
.
n_head
device
=
config
.
device
dtype
=
config
.
dtype
self
.
register_buffer
(
"scale_attn"
,
torch
.
sqrt
(
torch
.
tensor
(
self
.
head_dim
,
requires_grad
=
False
)
.
float
()))
self
.
register_buffer
(
"bias"
,
bias
)
self
.
register_buffer
(
"masked_bias"
,
torch
.
tensor
(
-
1e9
,
requires_grad
=
False
))
#-1e10 is what mtj uses.
attn_bias
=
True
#fairseq has attn_bias
self
.
k_proj
=
nn
.
Linear
(
self
.
hidden_dim
,
self
.
hidden_dim
,
bias
=
attn_bias
,
device
=
device
,
dtype
=
dtype
)
self
.
v_proj
=
nn
.
Linear
(
self
.
hidden_dim
,
self
.
hidden_dim
,
bias
=
attn_bias
,
device
=
device
,
dtype
=
dtype
)
self
.
q_proj
=
nn
.
Linear
(
self
.
hidden_dim
,
self
.
hidden_dim
,
bias
=
attn_bias
,
device
=
device
,
dtype
=
dtype
)
self
.
out_proj
=
nn
.
Linear
(
self
.
hidden_dim
,
self
.
hidden_dim
,
bias
=
attn_bias
,
device
=
device
,
dtype
=
dtype
)
def
forward
(
self
,
x
,
kv
=
None
,
cache
=
False
):
B
,
S
,
H
=
x
.
shape
# batch, sequence, hidden_dim
# split heads into: [batch, head, sequence, head_dim]
query
=
self
.
q_proj
(
x
)
.
view
(
B
,
S
,
self
.
n_head
,
self
.
head_dim
)
.
transpose
(
1
,
2
)
key
=
self
.
k_proj
(
x
)
.
view
(
B
,
S
,
self
.
n_head
,
self
.
head_dim
)
.
transpose
(
1
,
2
)
value
=
self
.
v_proj
(
x
)
.
view
(
B
,
S
,
self
.
n_head
,
self
.
head_dim
)
.
transpose
(
1
,
2
)
if
kv
:
k
,
v
=
kv
# cat key and value (get the whole sequence, other than the last added token all are cached),
# so query can attend to it.
torch
.
cat
([
k
,
key
],
dim
=-
2
)
# cat key
torch
.
cat
([
v
,
value
],
dim
=-
2
)
# cat value
query_length
,
key_length
=
query
.
size
(
-
2
),
key
.
size
(
-
2
)
# seq_len, seq_len
causal_mask
=
self
.
bias
[:,
:,
key_length
-
query_length
:
key_length
,
:
key_length
]
x
=
_attn
(
query
,
key
,
value
,
causal_mask
,
self
.
masked_bias
,
None
,
self
.
scale_attn
,
self
.
config
.
fp32_attn
)
x
=
x
.
transpose
(
1
,
2
)
.
contiguous
()
.
view
(
B
,
S
,
H
)
x
=
self
.
out_proj
(
x
)
if
cache
:
return
x
,
(
key
,
value
)
else
:
return
x
,
None
class
FeedForward
(
nn
.
Module
):
def
__init__
(
self
,
config
):
nn
.
Module
.
__init__
(
self
)
self
.
ff1
=
nn
.
Linear
(
config
.
hidden_dim
,
config
.
hidden_dim
*
4
,
device
=
config
.
device
,
dtype
=
config
.
dtype
)
self
.
ff2
=
nn
.
Linear
(
config
.
hidden_dim
*
4
,
config
.
hidden_dim
,
device
=
config
.
device
,
dtype
=
config
.
dtype
)
self
.
activation
=
config
.
activation
def
forward
(
self
,
x
,
act_ck
=
False
):
x
=
self
.
ff1
(
x
)
if
act_ck
:
x
=
ck
(
self
.
activation
,
x
)
else
:
x
=
self
.
activation
(
x
)
x
=
self
.
ff2
(
x
)
return
x
class
GPTFairLayer
(
nn
.
Module
):
def
__init__
(
self
,
attn
,
ff
,
config
):
nn
.
Module
.
__init__
(
self
)
self
.
hidden_dim
=
config
.
hidden_dim
self
.
ln_preattn
=
nn
.
LayerNorm
(
config
.
hidden_dim
,
eps
=
config
.
eps
,
device
=
config
.
device
,
dtype
=
config
.
dtype
)
self
.
ln_postattn
=
nn
.
LayerNorm
(
config
.
hidden_dim
,
eps
=
config
.
eps
,
device
=
config
.
device
,
dtype
=
config
.
dtype
)
self
.
ff
=
ff
(
config
)
self
.
attn
=
attn
(
config
)
self
.
tick
=
True
def
forward
(
self
,
x
,
layer_id
=
None
,
hypernetwork
=
None
,
act_ck
=
False
,
cache
=
False
,
kv
=
None
):
residual
=
x
if
act_ck
:
x
=
ck
(
self
.
ln_preattn
,
x
)
attn_out
,
kv
=
ck
(
self
.
attn
,
x
,
kv
=
kv
,
cache
=
cache
)
else
:
x
=
self
.
ln_preattn
(
x
)
attn_out
,
kv
=
self
.
attn
(
x
,
kv
=
kv
,
cache
=
cache
)
x
=
residual
+
attn_out
residual
=
x
x
=
self
.
ln_postattn
(
x
)
ff_out
=
self
.
ff
(
x
,
act_ck
)
x
=
residual
+
ff_out
return
x
,
kv
class
GPTFairModel
(
base_lm
.
BaseModel
):
def
__init__
(
self
,
user_config
,
**
kwargs
):
self
.
default_config
=
{
'n_layer'
:
6
,
'n_head'
:
8
,
'n_tokens'
:
2049
,
'hidden_dim'
:
512
,
'vocab_dim'
:
50400
,
'fp32_attn'
:
True
,
#fairseq models are trained with fp32 attn
'eps'
:
1e-5
,
'device'
:
torch
.
device
(
'cuda'
),
'dtype'
:
torch
.
float16
,
'Layer'
:
GPTFairLayer
,
'activation'
:
F
.
gelu
,
'SelfAttention'
:
SelfAttention
,
'FeedForward'
:
FeedForward
,
}
base_lm
.
BaseModel
.
__init__
(
self
,
user_config
,
**
kwargs
)
# returns sinusoidal embeddings of shape: (1, n_tokens, 768)
self
.
register_buffer
(
"embed_scale"
,
torch
.
sqrt
(
torch
.
tensor
(
self
.
config
.
hidden_dim
,
requires_grad
=
False
)))
self
.
pos_embed
=
PositionalEmbedding
(
self
.
config
.
n_tokens
,
self
.
config
.
hidden_dim
,
1
)
self
.
lm_head
=
nn
.
Linear
(
self
.
config
.
hidden_dim
,
self
.
config
.
vocab_dim
,
bias
=
False
)
#bias=False for fairseq models
def
get_embeds
(
self
,
x
,
hypernetwork
=
None
,
act_ck
=
False
,
kv
=
None
,
cache
=
False
):
if
kv
is
None
:
kv
=
[
None
]
*
self
.
n_layer
past_length
=
0
else
:
past_length
=
kv
[
0
][
0
]
.
size
(
-
2
)
#get sequence dim of key
kv_new
=
[]
position_embeds
=
self
.
pos_embed
(
x
,
offset
=
past_length
)
input_embeds
=
self
.
vocab_embed
(
x
)
*
self
.
embed_scale
x
=
position_embeds
+
input_embeds
for
layer_id
,
layer
in
enumerate
(
self
.
layers
):
x
,
kvi
=
layer
(
x
,
layer_id
=
layer_id
,
hypernetwork
=
hypernetwork
,
act_ck
=
act_ck
,
kv
=
kv
[
layer_id
],
cache
=
cache
)
kv_new
.
append
(
kvi
)
x
=
self
.
ln_final
(
x
)
if
cache
:
return
x
,
kv_new
else
:
return
x
,
None
\ No newline at end of file
run_pyfra.py
View file @
40e90836
...
...
@@ -14,7 +14,7 @@ bash = False
config_obj
=
KubeConfig
()
config_obj
.
set_name
(
name
)
config_obj
.
set_gpu
(
gpu_name
=
GPU
.
RTX_A6000
,
amount
=
1
)
config_obj
.
set_ram
(
16
)
config_obj
.
set_ram
(
64
)
config_obj
.
set_cpu
(
4
)
config_obj
.
dry_run
(
dry
)
config_obj
.
print_information
()
...
...
@@ -31,18 +31,23 @@ if False:
env1
.
sh
(
'pip install /home/xuser/hugessd/pytorch/torch-1.10.1+cu113-cp38-cp38-linux_x86_64.whl'
)
env1
.
sh
(
'pip install einops numpy'
)
env1
.
sh
(
'pip install tqdm'
)
env1
.
sh
(
'pip install /home/xuser/diffusionstorage/workspace/finetune/pokepls/transformers-repo'
)
#
env1.sh('pip install /home/xuser/diffusionstorage/workspace/finetune/pokepls/transformers-repo')
env1
.
sh
(
'pip3 install einops==0.4.1 pyyaml wandb'
)
env1
.
sh
(
'wandb login 21a9442d42a35e15ce421f2b702ec58508b9adc4'
)
env1
.
sh
(
'pip3 install dotmap icecream'
)
path
.
sh
(
"pip3 install --editable ."
)
#path.sh("pip3 uninstall torch")
#path.sh("pip3 install torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cu113")
with
always_rerun
():
if
True
:
#env1.sh('pip3 uninstall transformers')
#env1.sh('pip3 install transformers')
path
.
sh
(
"python3 ../lm-evaluation-harness/main.py --model basedformer --batch_size 8 --model_args pretrained=/home/xuser/diffusionstorage/workspace/kuru/basedformer/pretrained/gpt-neo-125m-ported --device 0 --tasks lambada --no_cache"
)
#path.sh('pip3 install --editable ../lm-evaluation-harness/.')
#env1.sh('pip3 install pytest')
#env1.sh('pip install /home/xuser/hugessd/pytorch/torch-1.10.1+cu113-cp38-cp38-linux_x86_64.whl')
path
.
sh
(
'pip3 uninstall huggingface_hub'
)
path
.
sh
(
'pip3 install huggingface_hub'
)
#path.sh('pip3 uninstall transformers')
#path.sh('pip3 install transformers')
#path.sh("python3 ../lm-evaluation-harness/main.py --model gpt2 --batch_size 8 --model_args pretrained=EleutherAI/gpt-neo-125M --device 0 --tasks lambada --no_cache")
path
.
sh
(
"python3 ../lm-evaluation-harness/main.py --model basedformer --batch_size 8 --model_args pretrained=/home/xuser/diffusionstorage/workspace/kuru/basedformer/pretrained/fairseq_125m --device 0 --tasks lambada --no_cache"
)
#path.sh("python3 ../lm-evaluation-harness/main.py --batch_size 8")
else
:
...
...
scripts/comparefairseq.py
View file @
40e90836
...
...
@@ -11,67 +11,15 @@ from contextlib import contextmanager
import
torch.nn.functional
as
F
from
transformers
import
GPTNeoForCausalLM
from
icecream
import
ic
#replicating timeit magic function of ipython
def
timeit
(
func
,
r
=
1
,
n
=
5
,
quiet
=
False
,
function
=
None
,
do_tqdm
=
False
,
first
=
True
):
precision
=
'ns'
r_arr
=
np
.
empty
([
2
,
r
])
# [0] = mean, [1] = std
if
function
:
func
.
__name__
=
function
.
__name__
for
i
in
tqdm
(
range
(
r
))
if
do_tqdm
else
range
(
r
):
n_arr
=
np
.
empty
(
n
)
for
k
in
range
(
n
):
start
=
perf_counter_ns
()
func
()
n_arr
[
k
]
=
perf_counter_ns
()
-
start
if
not
first
:
# delete the first element from n_arr numpy array
n_arr
=
np
.
delete
(
n_arr
,
0
)
r_arr
[
0
,
i
]
=
np
.
mean
(
n_arr
)
r_arr
[
1
,
i
]
=
np
.
std
(
n_arr
)
best
=
r_arr
[:,
np
.
argmin
(
r_arr
[
0
])]
# [0] = mean, [1] = std
#check if best[0] bigger than 1ms in numpy
if
best
[
0
]
<
1e3
:
precision
=
'ns'
elif
best
[
0
]
>=
1e9
:
print
(
'b'
)
best
[
0
]
=
best
[
0
]
*
1e-9
best
[
1
]
=
best
[
1
]
*
1e-9
precision
=
's'
elif
best
[
0
]
>=
1e6
:
best
[
0
]
=
best
[
0
]
*
1e-6
best
[
1
]
=
best
[
1
]
*
1e-6
precision
=
'ms'
elif
best
[
0
]
>=
1e3
:
precision
=
'μs'
best
[
0
]
=
best
[
0
]
*
1e-3
best
[
1
]
=
best
[
1
]
*
1e-3
if
not
quiet
:
if
precision
==
'ns'
:
print
(
f
"{func.__name__}: {best[0]:.0f}{precision} ± {best[1]:.0f}{precision} per loop (mean ± std. dev. of {str(r)} runs, {str(n)} loops each)"
)
if
precision
==
'μs'
:
print
(
f
"{func.__name__}: {best[0]:.2f}{precision} ± {best[1]:.2f}{precision} per loop (mean ± std. dev. of {str(r)} runs, {str(n)} loops each)"
)
elif
precision
==
'ms'
:
print
(
f
"{func.__name__}: {best[0]:.2f}{precision} ± {best[1]:.2f}{precision} per loop (mean ± std. dev. of {str(r)} runs, {str(n)} loops each)"
)
elif
precision
==
's'
:
print
(
f
"{func.__name__}: {best[0]:.4f}{precision} ± {best[1]:.4f}{precision} per loop (mean ± std. dev. of {str(r)} runs, {str(n)} loops each)"
)
with
torch
.
no_grad
():
model_dir
=
'/home/xuser/diffusionstorage/
workspace/kuru/basedformer/pretrained/hf
_125m/'
model_dir
=
'/home/xuser/diffusionstorage/
models/fairseq/converted/en_dense_lm
_125m/'
hf_model
=
no_init
(
lambda
:
GPTNeoForCausalLM
.
from_pretrained
(
model_dir
))
.
cuda
()
.
half
()
.
eval
()
print
(
"Loaded hf model"
)
path
=
"/home/xuser/diffusionstorage/workspace/kuru/basedformer/pretrained/fairseq_125m"
based_model
=
lmu
.
load_from_path
(
path
)
.
cuda
()
.
half
()
.
eval
()
print
(
"Loaded based model"
)
x
=
torch
.
randint
(
0
,
5
0256
,
(
1
,
2048
))
.
cuda
()
.
long
()
x
=
torch
.
randint
(
0
,
5
1200
,
(
1
,
300
))
.
cuda
()
.
long
()
assert
torch
.
allclose
(
hf_model
.
transformer
.
wte
(
x
),
based_model
.
vocab_embed
(
x
))
hidden
=
hf_model
.
transformer
.
wte
(
x
)
...
...
@@ -85,7 +33,7 @@ with torch.no_grad():
ic
(
hf_model
.
transformer
.
h
[
layer
]
.
attn
(
hidden
)[
0
]
.
abs
()
.
mean
())
ic
(
based_model
.
layers
[
layer
]
.
attn
(
hidden
)[
0
]
.
abs
()
.
mean
())
ic
((
hf_model
.
transformer
.
h
[
layer
]
.
attn
(
hidden
)[
0
]
-
based_model
.
layers
[
layer
]
.
attn
(
hidden
)[
0
])
.
abs
()
.
mean
())
#
assert torch.allclose(hf_model.transformer.h[layer].attn(hidden)[0], based_model.layers[layer].attn(hidden)[0], rtol=1e-6)
assert
torch
.
allclose
(
hf_model
.
transformer
.
h
[
layer
]
.
attn
(
hidden
)[
0
],
based_model
.
layers
[
layer
]
.
attn
(
hidden
)[
0
],
rtol
=
1e-6
)
attn_out
=
hf_model
.
transformer
.
h
[
layer
]
.
attn
(
hidden
)[
0
]
hidden
=
residual
+
attn_out
residual
=
hidden
...
...
scripts/fairseqport.py
View file @
40e90836
...
...
@@ -110,12 +110,12 @@ with torch.no_grad():
wte
=
fairdict
[
"decoder.embed_tokens.weight"
]
.
clone
()
for
i
in
range
(
50260
):
wte
[
mapping
[
i
]]
=
fairdict
[
"decoder.embed_tokens.weight"
][
i
]
hack_embs
(
wte
)
#
hack_embs(wte)
save
(
wte
.
half
(),
"vocab_embed.weight"
)
lm_head
=
fairdict
[
"decoder.output_projection.weight"
]
.
clone
()
for
i
in
range
(
50260
):
lm_head
[
mapping
[
i
]]
=
fairdict
[
"decoder.output_projection.weight"
][
i
]
hack_embs
(
lm_head
)
#
hack_embs(lm_head)
save
(
lm_head
.
half
(),
"lm_head.weight"
)
save
(
torch
.
FloatTensor
(
1
),
"pos_embed._float_tensor"
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment