novelai-storage / Basedformer

Commit 8d44445e authored May 09, 2022 by novelailab

config dataclass, not sure about this structure
i am sad

parent a1b9e387

Showing 1 changed file with 48 additions and 26 deletions

basedformer/gptj.py (+48, -26)

--- a/basedformer/gptj.py
+++ b/basedformer/gptj.py
@@ -14,6 +14,7 @@ import os
 from pathlib import Path
 import math
 from basedformer import lm_base
+from dataclasses import dataclass
 
 def fixed_pos_embedding(dim=None, seq_len=None, x=None):
     if x is None:
@@ -51,23 +52,23 @@ def _attn(query, key, value, causal_mask, masked_bias,
 class SelfAttention(nn.Module):
     # Code copied from HF, might want to sanity check later.
-    def __init__(self, hidden_dim, n_head, device, dtype):
+    def __init__(self, config):
         nn.Module.__init__(self)
         max_positions = 2049
         bias = torch.tril(torch.ones((max_positions, max_positions), dtype=torch.uint8, requires_grad=False)).view(
             1, 1, max_positions, max_positions).bool()
-        self.head_dim = hidden_dim // n_head
+        self.head_dim = config.hidden_dim // config.n_head
         self.rotary_dim = self.head_dim // 4
-        self.hidden_dim = hidden_dim
-        self.n_head = n_head
+        self.hidden_dim = config.hidden_dim
+        self.n_head = config.n_head
         self.register_buffer("scale_attn", torch.sqrt(torch.tensor(self.head_dim, requires_grad=False).float()))
         self.register_buffer("bias", bias)
         self.register_buffer("masked_bias", torch.tensor(-1e9, requires_grad=False)) #-1e10 is what mtj uses.
         attn_bias = False
-        self.k_proj = nn.Linear(self.hidden_dim, self.hidden_dim, bias=attn_bias, device=device, dtype=dtype)
-        self.v_proj = nn.Linear(self.hidden_dim, self.hidden_dim, bias=attn_bias, device=device, dtype=dtype)
-        self.q_proj = nn.Linear(self.hidden_dim, self.hidden_dim, bias=attn_bias, device=device, dtype=dtype)
-        self.out_proj = nn.Linear(self.hidden_dim, self.hidden_dim, bias=attn_bias, device=device, dtype=dtype)
+        self.k_proj = nn.Linear(self.hidden_dim, self.hidden_dim, bias=attn_bias, device=config.device, dtype=config.dtype)
+        self.v_proj = nn.Linear(self.hidden_dim, self.hidden_dim, bias=attn_bias, device=config.device, dtype=config.dtype)
+        self.q_proj = nn.Linear(self.hidden_dim, self.hidden_dim, bias=attn_bias, device=config.device, dtype=config.dtype)
+        self.out_proj = nn.Linear(self.hidden_dim, self.hidden_dim, bias=attn_bias, device=config.device, dtype=config.dtype)
         sin, cos = fixed_pos_embedding(dim=self.rotary_dim, seq_len=max_positions)
         self.register_buffer("sin", sin)
         self.register_buffer("cos", cos)
@@ -129,11 +130,11 @@ class SelfAttention(nn.Module):
         return x, None
 
 class FeedForward(nn.Module):
-    def __init__(self, dim, hidden_dim, activation, device, dtype):
+    def __init__(self, config):
         nn.Module.__init__(self)
-        self.ff1 = nn.Linear(dim, hidden_dim, device=device, dtype=dtype)
-        self.ff2 = nn.Linear(hidden_dim, dim, device=device, dtype=dtype)
-        self.activation = activation
+        self.ff1 = nn.Linear(config.hidden_dim, config.hidden_dim*4, device=config.device, dtype=config.dtype)
+        self.ff2 = nn.Linear(config.hidden_dim*4, config.hidden_dim, device=config.device, dtype=config.dtype)
+        self.activation = config.activation
 
     def forward(self, x, act_ck=False):
         x = self.ff1(x)
@@ -145,12 +146,11 @@ class FeedForward(nn.Module):
         return x
 
 class GPTJLayer(nn.Module):
-    def __init__(self, attn, ff, hidden_dim, n_head, eps, activation, device, dtype):
+    def __init__(self, attn, ff, config):
         nn.Module.__init__(self)
-        self.hidden_dim = hidden_dim
-        self.ln_preattn = nn.LayerNorm(hidden_dim, eps=eps, device=device, dtype=dtype)
-        self.ff = ff(dim=hidden_dim, hidden_dim=hidden_dim*4, activation=activation, device=device, dtype=dtype)
-        self.attn = attn(hidden_dim=hidden_dim, n_head=n_head, device=device, dtype=dtype)
+        self.ln_preattn = nn.LayerNorm(config.hidden_dim, eps=config.eps, device=config.device, dtype=config.type)
+        self.ff = ff(config)
+        self.attn = attn(config)
         self.tick = True
 
     def forward(self, x, layer_id=None, hypernetwork=None, act_ck=False, diff_hypernets=False, interleaving_layers=False, every_n=5, cache=False, kv=None):
@@ -190,16 +190,22 @@ class GPTJLayer(nn.Module):
         return x, kv
 
 class GPTJModel(nn.Module):
-    def __init__(self, hidden_dim, n_layer, n_head, vocab_dim, eps, activation=gelu_new, Layer=GPTJLayer, device="cuda", dtype=torch.float16, **kwargs):
+    def __init__(self, config, **kwargs):
         nn.Module.__init__(self)
-        self.n_layer = n_layer
-        self.hidden_dim = hidden_dim
-        self.vocab_embed = nn.Embedding(vocab_dim, self.hidden_dim, device=device, dtype=dtype)
-        self.ln_final = nn.LayerNorm(self.hidden_dim, eps=eps, device=device, dtype=dtype)
+        self.n_layer = config.n_layer
+        self.hidden_dim = config.hidden_dim
+        self.vocab_embed = nn.Embedding(config.vocab_dim, self.hidden_dim, device=config.device, dtype=config.dtype)
+        self.ln_final = nn.LayerNorm(self.hidden_dim, eps=config.eps, device=config.device, dtype=config.dtype)
         self.layers = nn.ModuleList([])
-        self.lm_head = nn.Linear(hidden_dim, vocab_dim, bias=True)
-        for _ in range(n_layer):
-            self.layers.append(Layer(attn=SelfAttention, ff=FeedForward, hidden_dim=hidden_dim, n_head=n_head, eps=eps, activation=activation, device=device, dtype=dtype))
+        self.lm_head = nn.Linear(config.hidden_dim, config.vocab_dim, bias=True)
+        for _ in range(config.n_layer):
+            self.layers.append(
+                config.Layer(
+                    attn=SelfAttention,
+                    ff=FeedForward,
+                    config=config,
+                )
+            )
 
     def forward(self, x, target=None, hypernetwork=None, act_ck=False, kv=None, cache=False):
         x, kv = self.get_embeds(x, hypernetwork=hypernetwork, act_ck=act_ck, kv=kv, cache=cache)
@@ -238,6 +244,22 @@ class GPTJModel(nn.Module):
         else:
             return x, None
 
+@dataclass
+class GPTJConfig:
+    n_layer: int = 6
+    n_head: int = 8
+    hidden_dim: int = 512
+    vocab_dim: int = 50400
+    eps: float = 1e-5
+    device: torch.device = torch.device('cuda')
+    dtype: torch.dtype = torch.float16
+    Layer = GPTJLayer
+    activation = gelu_new
+
+    def from_dict(self, config_dict):
+        for k, v in config_dict.items():
+            setattr(self, k, v)
+
 class GPTJBaseLM(lm_base.BaseLM):
     def __init__(self, config=None, lm=None):
         nn.Module.__init__(self)
@@ -252,6 +274,6 @@ def load_gpt_j(path="models/6b", state_dict=None):
         "vocab_dim": 50400,
         "eps": 1e-5
     }
-    config = DotMap(config)
+    config = GPTJConfig(**config)
     model = GPTJBaseLM.load(config, path, state_dict)
     return model
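
For orientation, a minimal usage sketch of the config-driven constructors this commit introduces. The import path and the call pattern below are assumptions inferred from the diff against basedformer/gptj.py, not code that ships with the commit:

import torch
from basedformer import gptj  # assumes basedformer/gptj.py imports as basedformer.gptj

# Fields default to the values declared in the new GPTJConfig dataclass
# (n_layer=6, n_head=8, hidden_dim=512, vocab_dim=50400, eps=1e-5, cuda/float16).
# Device and dtype are overridden here so the sketch also runs on a CPU-only box.
config = gptj.GPTJConfig(device=torch.device("cpu"), dtype=torch.float32)

# Bulk-override fields from a plain dict (illustrative values), mirroring how
# load_gpt_j now builds GPTJConfig(**config) instead of wrapping a DotMap.
config.from_dict({"n_layer": 28, "n_head": 16, "hidden_dim": 4096})

# Intended call pattern after this commit: the model takes the single config
# object and forwards it to every GPTJLayer / SelfAttention / FeedForward.
# Note that, as committed, GPTJLayer reads `config.type` for its LayerNorm
# dtype, so this call needs that attribute present (or renamed to `config.dtype`).
# model = gptj.GPTJModel(config)

Threading one config object through the module tree is what replaces the long positional signatures on SelfAttention, FeedForward, GPTJLayer and GPTJModel; the "not sure about this structure" in the commit message presumably refers to that choice.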