novelai-storage / Basedformer / Commits

Commit 1f8d04b2
authored Apr 20, 2022 by novelailab
implement cached generation

parent 3a52a64a

Showing 6 changed files with 162 additions and 68 deletions (+162, -68)
basedformer/crosslm.py    +78  -48
basedformer/gptj.py       +32  -15
basedformer/lm_base.py     +1   -1
run_pyfra.py               +1   -1
scripts/comparehf.py       +3   -3
scripts/test_cache.py     +47   -0
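The new call pattern introduced by this commit (per the gptj.py and scripts/test_cache.py changes below) passes cache=True and feeds the returned per-layer kv list back into the next forward call. A minimal sketch of that loop, assuming a working basedformer install and a GPT-J checkpoint at the default models/6b path; the prompt token ids here are made up for illustration:

    import torch
    from basedformer import gptj

    # Load GPTJBaseLM and unwrap the raw GPTJModel, as scripts/test_cache.py does.
    model = gptj.load_gpt_j().cuda().half().eval()
    model = model.lm

    tokens = torch.LongTensor([[464, 3290, 318]]).cuda()  # hypothetical prompt ids

    with torch.no_grad():
        # First pass runs the whole prompt and returns a per-layer list of (key, value).
        logits, kv = model(tokens, cache=True, kv=None)
        next_token = logits[:, -1, :].topk(1)[1]
        # Later passes feed only the newly picked token; the cached keys/values
        # stand in for the rest of the context.
        logits, kv = model(next_token, cache=True, kv=kv)

This mirrors the generation loop added in scripts/test_cache.py at the bottom of the diff.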
basedformer/crosslm.py

@@ -30,34 +30,6 @@ def apply_rotary_pos_emb(x, sincos, offset=0):
     sin, cos = map(lambda t: repeat(t[offset:x.shape[1]+offset,:], "n d -> () n () (d j)", j=2), sincos)
     return (x * cos) + (rotate_every_two(x) * sin)
 
-def _split_heads(tensor, num_heads, attn_head_size, rotary):
-    """
-    Splits hidden_size dim into attn_head_size and num_heads
-    """
-    new_shape = tensor.size()[:-1] + (num_heads, attn_head_size)
-    tensor = tensor.view(*new_shape)
-    if rotary:
-        return tensor
-    if len(tensor.shape) == 5:
-        return tensor.permute(0, 1, 3, 2, 4) # (batch, blocks, head, block_length, head_features)
-    elif len(tensor.shape) == 4:
-        return tensor.permute(0, 2, 1, 3) # (batch, head, seq_length, head_features)
-    else:
-        raise ValueError(f"Input tensor rank should be one of [4, 5], but is: {len(tensor.shape)}")
-
-def _merge_heads(tensor, num_heads, attn_head_size):
-    """
-    Merges attn_head_size dim and num_attn_heads dim into hidden_size
-    """
-    if len(tensor.shape) == 5:
-        tensor = tensor.permute(0, 1, 3, 2, 4).contiguous()
-    elif len(tensor.shape) == 4:
-        tensor = tensor.permute(0, 2, 1, 3).contiguous()
-    else:
-        raise ValueError(f"Input tensor rank should be one of [4, 5], but is: {len(tensor.shape)}")
-    new_shape = tensor.size()[:-2] + (num_heads * attn_head_size,)
-    return tensor.view(new_shape)
-
 def _attn(query, key, value, causal_mask, masked_bias, attention_mask=None, scale_attn=None):

@@ -75,6 +47,54 @@ def _attn(query, key, value, causal_mask, masked_bias,
     return attn_output
 
+class CrossAttentionMod(nn.Module):
+    # Code copied from HF, might want to sanity check later.
+    def __init__(self, hidden_dim, n_head, device, dtype):
+        nn.Module.__init__(self)
+        max_positions = 2049
+        bias = torch.tril(torch.ones((max_positions, max_positions), dtype=torch.uint8, requires_grad=False)).view(1, 1, max_positions, max_positions).bool()
+        self.head_dim = hidden_dim // n_head
+        self.rotary_dim = self.head_dim // 4
+        self.hidden_dim = hidden_dim
+        self.n_head = n_head
+        self.register_buffer("scale_attn", torch.sqrt(torch.tensor(self.head_dim, requires_grad=False).float()))
+        self.register_buffer("bias", bias)
+        self.register_buffer("masked_bias", torch.tensor(-1e9, requires_grad=False)) #-1e10 is what mtj uses.
+        attn_bias = False
+        self.k_proj = nn.Linear(self.hidden_dim, self.hidden_dim, bias=attn_bias, device=device, dtype=dtype)
+        self.v_proj = nn.Linear(self.hidden_dim, self.hidden_dim, bias=attn_bias, device=device, dtype=dtype)
+        self.q_proj = nn.Linear(self.hidden_dim, self.hidden_dim, bias=attn_bias, device=device, dtype=dtype)
+        self.out_proj = nn.Linear(self.hidden_dim, self.hidden_dim, bias=attn_bias, device=device, dtype=dtype)
+
+    def forward(self, x, y, kv=None, cache=False):
+        B, S, H = x.shape # batch, sequence, hidden_dim
+        # split heads into: [batch, head, sequence, head_dim]
+        query = self.q_proj(x).view(B, S, self.n_head, self.head_dim).transpose(1, 2) #latent query
+        key = self.k_proj(y).view(B, S, self.n_head, self.head_dim).transpose(1, 2) #context key
+        value = self.v_proj(y).view(B, S, self.n_head, self.head_dim).transpose(1, 2) #context value
+
+        if kv:
+            k, v = kv
+            # cat key and value (get the whole sequence, other than the last added token all are cached),
+            # so query can attend to it.
+            torch.cat([k, key], dim=-2) # cat key
+            torch.cat([v, value], dim=-2) # cat value
+
+        query_length, key_length = query.size(-2), key.size(-2) # seq_len, seq_len
+        causal_mask = self.bias[:, :, key_length - query_length:key_length, :key_length]
+
+        x = _attn(query, key, value, causal_mask, self.masked_bias, None, self.scale_attn)
+
+        x = x.transpose(1, 2).contiguous().view(B, S, H)
+        x = self.out_proj(x)
+        if cache:
+            return x, (key, value)
+        else:
+            return x
+
 class SelfAttention(nn.Module):
     # Code copied from HF, might want to sanity check later.
     def __init__(self, hidden_dim, n_head, device, dtype):

@@ -98,14 +118,13 @@ class SelfAttention(nn.Module):
         self.register_buffer("sin", sin)
         self.register_buffer("cos", cos)
 
-    def forward(self, x):
-        query = self.q_proj(x)
-        key = self.k_proj(x)
-        value = self.v_proj(x)
-        query = _split_heads(query, self.n_head, self.head_dim, True)
-        key = _split_heads(key, self.n_head, self.head_dim, True)
-        value = _split_heads(value, self.n_head, self.head_dim, False)
+    def forward(self, x, kv=None, cache=False):
+        B, S, H = x.shape # batch, sequence, hidden_dim
+        # split heads into: [batch, head, sequence, head_dim]
+        # transpose q, k after rotary as rotary code accepts [b, s, h, h_d]
+        query = self.q_proj(x).view(B, S, self.n_head, self.head_dim)
+        key = self.k_proj(x).view(B, S, self.n_head, self.head_dim)
+        value = self.v_proj(x).view(B, S, self.n_head, self.head_dim).transpose(1, 2)
         offset = 0
         if self.rotary_dim < self.head_dim:

@@ -124,21 +143,32 @@ class SelfAttention(nn.Module):
         else:
             key = apply_rotary_pos_emb(key, (self.sin, self.cos), offset=offset).to(key.dtype)
             query = apply_rotary_pos_emb(query, (self.sin, self.cos), offset=offset).to(query.dtype)
 
-        key = key.permute(0, 2, 1, 3)
-        query = query.permute(0, 2, 1, 3)
+        query = query.transpose(1, 2)
+        key = key.transpose(1, 2)
+
+        if kv:
+            k, v = kv
+            # cat key and value (get the whole sequence, other than the last added token all are cached),
+            # so query can attend to it.
+            torch.cat([k, key], dim=-2) # cat key
+            torch.cat([v, value], dim=-2) # cat value
 
         query_length, key_length = query.size(-2), key.size(-2)
 
-        causal_mask = self.bias[:, :, key_length - query_length:key_length, :key_length]
+        #causal mask with generation in mind
+        causal_mask = self.bias[:, :, key_length - query_length:key_length, :key_length]
 
         x = _attn(query, key, value, causal_mask, self.masked_bias, None, self.scale_attn)
 
-        x = _merge_heads(x, self.n_head, self.head_dim)
+        x = x.transpose(1, 2).contiguous().view(B, S, H)
         x = self.out_proj(x)
-        return x
+
+        if cache:
+            return x, (key, value)
+        else:
+            return x
 
 class FeedForward(nn.Module):
     def __init__(self, dim, hidden_dim, activation, device, dtype):

@@ -156,7 +186,7 @@ class FeedForward(nn.Module):
         x = self.ff2(x)
         return x
 
-class CrossGPTLayer(nn.Module):
+class GPTJLayer(nn.Module):
     def __init__(self, attn, ff, hidden_dim, n_head, eps, activation, device, dtype):
         nn.Module.__init__(self)
         self.hidden_dim = hidden_dim

@@ -202,8 +232,8 @@ class CrossGPTLayer(nn.Module):
         return x
 
-class CrossGPTModel(nn.Module):
-    def __init__(self, hidden_dim, n_layer, n_head, vocab_dim, eps, activation=gelu_new, Layer=CrossGPTLayer, device="cuda", dtype=torch.float16, **kwargs):
+class GPTJModel(nn.Module):
+    def __init__(self, hidden_dim, n_layer, n_head, vocab_dim, eps, activation=gelu_new, Layer=GPTJLayer, device="cuda", dtype=torch.float16, **kwargs):
         nn.Module.__init__(self)
         self.n_layer = n_layer
         self.hidden_dim = hidden_dim

@@ -226,11 +256,11 @@ class CrossGPTModel(nn.Module):
         x = self.ln_final(x)
         return x
 
-class CrossGPTBaseLM(lm_base.BaseLM):
+class GPTJBaseLM(lm_base.BaseLM):
     def __init__(self, config=None, lm=None):
         nn.Module.__init__(self)
         lm_base.BaseLM.__init__(self, config, lm)
-        self.model_class = CrossGPTModel
+        self.model_class = GPTJModel
 
 def load_gpt_j(path="models/6b", state_dict=None):
     config = {

@@ -240,5 +270,5 @@ def load_gpt_j(path="models/6b", state_dict=None):
         "vocab_dim": 50400,
         "eps": 1e-5
     }
-    model = CrossGPTBaseLM.load(config, path, state_dict)
+    model = GPTJBaseLM.load(config, path, state_dict)
     return model
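Note that in this file the concatenation results are never bound: both CrossAttentionMod.forward and SelfAttention.forward call torch.cat([k, key], dim=-2) and torch.cat([v, value], dim=-2) without assigning the result, so the cached keys and values are not actually attended to here. The gptj.py hunk below assigns the concatenations back to key and value, which appears to be the intended behaviour.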
basedformer/gptj.py

+from typing import KeysView
 import torch
 import torch.nn as nn
 import torch.nn.functional as F

@@ -78,7 +79,11 @@ class SelfAttention(nn.Module):
         key = self.k_proj(x).view(B, S, self.n_head, self.head_dim)
         value = self.v_proj(x).view(B, S, self.n_head, self.head_dim).transpose(1, 2)
 
-        offset = 0
+        if kv:
+            offset = kv[0].shape[-2]
+        else:
+            offset = 0
 
         if self.rotary_dim < self.head_dim:
             k_rot = key[:, :, :, :self.rotary_dim]
             k_pass = key[:, :, :, self.rotary_dim:]

@@ -103,8 +108,8 @@ class SelfAttention(nn.Module):
             k, v = kv
             # cat key and value (get the whole sequence, other than the last added token all are cached),
             # so query can attend to it.
-            torch.cat([k, key], dim=-2) # cat key
-            torch.cat([v, value], dim=-2) # cat value
+            key = torch.cat([k, key], dim=-2) # cat key
+            value = torch.cat([v, value], dim=-2) # cat value
 
         query_length, key_length = query.size(-2), key.size(-2)
 
         #causal mask with generation in mind

@@ -120,7 +125,7 @@ class SelfAttention(nn.Module):
         if cache:
             return x, (key, value)
         else:
-            return x
+            return x, None
 
 class FeedForward(nn.Module):
     def __init__(self, dim, hidden_dim, activation, device, dtype):

@@ -147,16 +152,15 @@ class GPTJLayer(nn.Module):
         self.attn = attn(hidden_dim=hidden_dim, n_head=n_head, device=device, dtype=dtype)
         self.tick = True
 
-    def forward(self, x, layer_id=None, hypernetwork=None, act_ck=False, diff_hypernets=False, interleaving_layers=False, every_n=5):
+    def forward(self, x, layer_id=None, hypernetwork=None, act_ck=False, diff_hypernets=False, interleaving_layers=False, every_n=5, cache=False, kv=None):
         residual = x
 
         if act_ck:
             x = ck(self.ln_preattn, x)
-            attn_out = ck(self.attn, x)
+            attn_out, kv = ck(self.attn, x, kv=kv, cache=cache)
 
         else:
             x = self.ln_preattn(x)
-            attn_out = self.attn(x)
+            attn_out, kv = self.attn(x, kv=kv, cache=cache)
 
         if hypernetwork:
             if diff_hypernets:

@@ -182,7 +186,7 @@ class GPTJLayer(nn.Module):
         if hypernetwork and layer_id % every_n == 0:
             x = x + hyper_out
 
-        return x
+        return x, kv
 
 class GPTJModel(nn.Module):
     def __init__(self, hidden_dim, n_layer, n_head, vocab_dim, eps, activation=gelu_new, Layer=GPTJLayer, device="cuda", dtype=torch.float16, **kwargs):

@@ -196,17 +200,30 @@ class GPTJModel(nn.Module):
         for _ in range(n_layer):
             self.layers.append(Layer(attn=SelfAttention, ff=FeedForward, hidden_dim=hidden_dim, n_head=n_head, eps=eps, activation=activation, device=device, dtype=dtype))
 
-    def forward(self, x, hypernetwork=None, act_ck=False):
-        x = self.get_embeds(x, hypernetwork=hypernetwork, act_ck=act_ck)
+    def forward(self, x, hypernetwork=None, act_ck=False, kv=None, cache=False):
+        x, kv = self.get_embeds(x, hypernetwork=hypernetwork, act_ck=act_ck, kv=kv, cache=cache)
         x = self.lm_head(x)
-        return x.float()
+        if kv:
+            return x.float(), kv
+        else:
+            return x.float()
 
-    def get_embeds(self, x, hypernetwork=None, act_ck=False):
+    def get_embeds(self, x, hypernetwork=None, act_ck=False, kv=None, cache=False):
+        if kv is None:
+            kv = [None] * self.n_layer
+        kv_new = []
         x = self.vocab_embed(x)
         for layer_id, layer in enumerate(self.layers):
-            x = layer(x, layer_id=layer_id, hypernetwork=hypernetwork, act_ck=act_ck)
+            x, kvi = layer(x, layer_id=layer_id, hypernetwork=hypernetwork, act_ck=act_ck, kv=kv[layer_id], cache=cache)
+            kv_new.append(kvi)
         x = self.ln_final(x)
-        return x
+        if cache:
+            return x, kv_new
+        else:
+            return x, None
 
 class GPTJBaseLM(lm_base.BaseLM):
     def __init__(self, config=None, lm=None):
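Per the hunks above, the cache that GPTJModel threads through get_embeds is a plain Python list with one (key, value) pair per layer, each tensor shaped [batch, n_head, seq_len, head_dim], and the rotary offset for a newly fed token is taken from the cached length. A small sketch of how the shapes evolve across one cached step, with illustrative sizes (not necessarily the 6B config):

    import torch

    batch, n_head, head_dim = 1, 16, 256   # illustrative sizes only

    # After a prompt pass over, say, 8 tokens, each layer has cached:
    cached_key = torch.zeros(batch, n_head, 8, head_dim)
    cached_value = torch.zeros(batch, n_head, 8, head_dim)
    kv_layer = (cached_key, cached_value)

    # The next step feeds a single token, producing a key/value of length 1,
    # which SelfAttention concatenates onto the cache along the sequence dim:
    new_key = torch.zeros(batch, n_head, 1, head_dim)
    key = torch.cat([kv_layer[0], new_key], dim=-2)   # -> [1, 16, 9, 256]

    # The rotary offset for that token is the number of cached positions:
    offset = kv_layer[0].shape[-2]                    # 8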
basedformer/lm_base.py

@@ -76,7 +76,7 @@ class BaseLM(nn.Module):
             state_dict = utils.SplitCheckpoint(path, device="cuda")
 
         model = cls(config)
-        model.lm = model.model_class(**config)
+        model.lm = utils.no_init(lambda: model.model_class(**config))
         model.lm.load_state_dict(state_dict, strict=strict)
 
         return model
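The only change here routes model construction through utils.no_init, presumably so the freshly built GPTJModel skips random weight initialization before load_state_dict overwrites everything, which speeds up loading. The helper itself is not shown in this diff; a common way to implement such a wrapper (a hypothetical sketch, not necessarily what basedformer.utils does) is to temporarily stub out the torch.nn.init routines:

    import torch

    def no_init(build_fn):
        # Temporarily replace the in-place init functions used by nn.Linear etc.
        # with no-ops, call the constructor, then restore the originals.
        def noop(tensor, *args, **kwargs):
            return tensor
        names = ["kaiming_uniform_", "uniform_", "normal_"]
        originals = {name: getattr(torch.nn.init, name) for name in names}
        for name in names:
            setattr(torch.nn.init, name, noop)
        try:
            return build_fn()
        finally:
            for name, fn in originals.items():
                setattr(torch.nn.init, name, fn)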
run_pyfra.py

@@ -31,7 +31,7 @@ env1.sh('pip install tqdm')
 env1.sh('pip install /home/xuser/diffusionstorage/workspace/finetune/pokepls/transformers-repo')
 env1.sh('pip3 install einops==0.4.1 pyyaml wandb')
 env1.sh('wandb login 21a9442d42a35e15ce421f2b702ec58508b9adc4')
-env1.sh('pip3 install dotmap')
+env1.sh('pip3 install dotmap icecream')
 
 with always_rerun():
     if bash:
         path.sh("bash")
scripts/comparehf.py

@@ -81,11 +81,11 @@ with torch.no_grad():
         hidden = hf_model.transformer.h[layer].ln_1(hidden)
         assert torch.allclose(hf_model.transformer.h[layer].mlp(hidden), based_model.layers[layer].ff(hidden))
         hidden = hf_model.transformer.h[layer].mlp(hidden)
-        assert torch.allclose(hf_model.transformer.h[layer].attn(hidden)[0], based_model.layers[layer].attn(hidden))
+        assert torch.allclose(hf_model.transformer.h[layer].attn(hidden)[0], based_model.layers[layer].attn(hidden)[0])
         hidden = hf_model.transformer.h[layer].attn(hidden)[0]
-        assert torch.allclose(hf_model.transformer.h[layer](hidden)[0], based_model.layers[layer](hidden))
+        assert torch.allclose(hf_model.transformer.h[layer](hidden)[0], based_model.layers[layer](hidden)[0])
     assert torch.allclose(hf_model.transformer.ln_f(hidden), based_model.ln_final(hidden))
     hidden = hf_model.transformer.ln_f(hidden)
-    assert torch.allclose(hf_model.transformer(x)["last_hidden_state"], based_model.get_embeds(x))
+    assert torch.allclose(hf_model.transformer(x)["last_hidden_state"], based_model.get_embeds(x)[0])
     assert torch.allclose(hf_model(x)["logits"], based_model(x))
\ No newline at end of file
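The added [0] indexing follows from the new return signatures in gptj.py: after this commit SelfAttention, GPTJLayer, and GPTJModel.get_embeds return an (output, kv) tuple instead of a bare tensor, so the comparison script now unpacks the first element before comparing against the Hugging Face activations.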
scripts/test_cache.py (new file, 0 → 100644)

from basedformer import gptj
from basedformer.utils import *
from transformers import AutoTokenizer
from icecream import ic
import time
import sys

def print_top_k(logits, tokenizer, k):
    topk_ind = logits.topk(k)[1]
    for x in range(topk_ind.shape[0]):
        for y in range(topk_ind.shape[1]):
            print("\nToken " + str(y))
            for token in topk_ind[x, y, :].tolist():
                print(tokenizer.decode([token]), end=" | ")

def main():
    tokenizer = AutoTokenizer.from_pretrained('gpt2')
    prompt = """I fucked her with my huge donut, when she seen my donut she went"""
    tokens = tokenizer.encode(prompt)
    print("Prompt:")
    for x in range(len(tokens)):
        print(tokenizer.decode([tokens[x]]), end=" | ")

    print("\nGeneration:")
    tokens = torch.LongTensor(tokens).unsqueeze(0).cuda()
    t = time.perf_counter()
    model = gptj.load_gpt_j().cuda().half().eval()
    model = model.lm
    ic(time.perf_counter() - t)

    with torch.no_grad():
        kv = None
        tokens_to_generate = 50
        in_tokens = tokens
        accum_tokens = []
        for x in range(tokens_to_generate):
            logits, kv = model(in_tokens, cache=True, kv=kv)
            in_tokens = logits[:, -1, :].topk(1)[1]
            #in_tokens = torch.cat([in_tokens, logits[:, -1, :].topk(1)[1]], dim=1)
            print(tokenizer.decode(in_tokens.squeeze(1).tolist()[-1]), end=" | ")

    #accum_tokens = torch.cat(accum_tokens, dim=1)
    #accum_tokens = accum_tokens.squeeze(0).tolist()
    #print("\n Final token list")
    #print(tokenizer.decode(accum_tokens))

if __name__ == "__main__":
    main()
\ No newline at end of file
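Because every previous key/value pair is cached, each iteration after the first feeds only the single most recently picked token back into the model, which is what makes the cached loop cheap. The commented-out accum_tokens lines hint at collecting the generated ids for one final decode, but that path is left disabled in this commit; a hypothetical way to finish it, reusing the names defined in main() above, would be:

    # Sketch only: accumulate the greedy picks and decode once at the end.
    generated = []
    for _ in range(tokens_to_generate):
        logits, kv = model(in_tokens, cache=True, kv=kv)
        in_tokens = logits[:, -1, :].topk(1)[1]   # shape [1, 1]
        generated.append(in_tokens.item())
    print(tokenizer.decode(generated))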