novelai-storage / Basedformer / Commits / 39568281

Commit 39568281, authored Apr 14, 2022 by novelailab
simplify attention on gpt2
parent 86e815ab

Showing 4 changed files with 280 additions and 44 deletions (+280 -44)
basedformer/crosslm.py    +244  -0
basedformer/gpt2.py       +14   -37
basedformer/noemblm.py    +17   -2
train.py                  +5    -5
basedformer/crosslm.py  (new file, 0 → 100644)
import torch
import torch.nn as nn
import torch.nn.functional as F
from basedformer.utils import *
from torch.utils.checkpoint import checkpoint as ck
from einops import rearrange, repeat
try:
    from collections.abc import MutableMapping
except ImportError:
    from collections import MutableMapping
import os
from pathlib import Path
import math
from basedformer import lm_base

def fixed_pos_embedding(dim=None, seq_len=None, x=None):
    if x is None:
        x = torch.empty(0)
    inv_freq = 1. / (10000 ** (torch.arange(0, dim, 2) / dim)).to(x.dtype).to(x.device)
    sinusoid_inp = torch.einsum('i , j -> i j', torch.arange(seq_len).to(x.device), inv_freq).float()
    return torch.sin(sinusoid_inp), torch.cos(sinusoid_inp)

def rotate_every_two(x):
    x1 = x[:, :, :, ::2]
    x2 = x[:, :, :, 1::2]
    x = torch.stack((-x2, x1), dim=-1)
    return rearrange(x, '... d j -> ... (d j)')

def apply_rotary_pos_emb(x, sincos, offset=0):
    sin, cos = map(lambda t: repeat(t[offset:x.shape[1] + offset, :], "n d -> () n () (d j)", j=2), sincos)
    return (x * cos) + (rotate_every_two(x) * sin)

def _split_heads(tensor, num_heads, attn_head_size, rotary):
    """
    Splits hidden_size dim into attn_head_size and num_heads
    """
    new_shape = tensor.size()[:-1] + (num_heads, attn_head_size)
    tensor = tensor.view(*new_shape)
    if rotary:
        return tensor
    if len(tensor.shape) == 5:
        return tensor.permute(0, 1, 3, 2, 4)  # (batch, blocks, head, block_length, head_features)
    elif len(tensor.shape) == 4:
        return tensor.permute(0, 2, 1, 3)  # (batch, head, seq_length, head_features)
    else:
        raise ValueError(f"Input tensor rank should be one of [4, 5], but is: {len(tensor.shape)}")

def _merge_heads(tensor, num_heads, attn_head_size):
    """
    Merges attn_head_size dim and num_attn_heads dim into hidden_size
    """
    if len(tensor.shape) == 5:
        tensor = tensor.permute(0, 1, 3, 2, 4).contiguous()
    elif len(tensor.shape) == 4:
        tensor = tensor.permute(0, 2, 1, 3).contiguous()
    else:
        raise ValueError(f"Input tensor rank should be one of [4, 5], but is: {len(tensor.shape)}")
    new_shape = tensor.size()[:-2] + (num_heads * attn_head_size,)
    return tensor.view(new_shape)

def _attn(query, key, value, causal_mask, masked_bias, attention_mask=None, scale_attn=None):
    attn_weights = torch.matmul(query, key.transpose(-1, -2))
    attn_weights = torch.where(causal_mask, attn_weights, masked_bias.to(attn_weights.dtype))
    attn_weights = attn_weights / scale_attn

    if attention_mask is not None:
        attn_weights = attn_weights + attention_mask

    attn_weights = F.softmax(attn_weights, dim=-1)
    attn_weights = attn_weights.to(value.dtype)
    attn_output = torch.matmul(attn_weights, value).to(value.dtype)
    return attn_output

class SelfAttention(nn.Module):
    # Code copied from HF, might want to sanity check later.
    def __init__(self, hidden_dim, n_head, device, dtype):
        nn.Module.__init__(self)
        max_positions = 2049
        bias = torch.tril(torch.ones((max_positions, max_positions), dtype=torch.uint8, requires_grad=False)).view(1, 1, max_positions, max_positions).bool()
        self.head_dim = hidden_dim // n_head
        self.rotary_dim = self.head_dim // 4
        self.hidden_dim = hidden_dim
        self.n_head = n_head
        self.register_buffer("scale_attn", torch.sqrt(torch.tensor(self.head_dim, requires_grad=False).float()))
        self.register_buffer("bias", bias)
        self.register_buffer("masked_bias", torch.tensor(-1e9, requires_grad=False))  #-1e10 is what mtj uses.
        attn_bias = False
        self.k_proj = nn.Linear(self.hidden_dim, self.hidden_dim, bias=attn_bias, device=device, dtype=dtype)
        self.v_proj = nn.Linear(self.hidden_dim, self.hidden_dim, bias=attn_bias, device=device, dtype=dtype)
        self.q_proj = nn.Linear(self.hidden_dim, self.hidden_dim, bias=attn_bias, device=device, dtype=dtype)
        self.out_proj = nn.Linear(self.hidden_dim, self.hidden_dim, bias=attn_bias, device=device, dtype=dtype)
        sin, cos = fixed_pos_embedding(dim=self.rotary_dim, seq_len=max_positions)
        self.register_buffer("sin", sin)
        self.register_buffer("cos", cos)

    def forward(self, x):
        query = self.q_proj(x)
        key = self.k_proj(x)
        value = self.v_proj(x)

        query = _split_heads(query, self.n_head, self.head_dim, True)
        key = _split_heads(key, self.n_head, self.head_dim, True)
        value = _split_heads(value, self.n_head, self.head_dim, False)

        offset = 0
        if self.rotary_dim < self.head_dim:
            k_rot = key[:, :, :, :self.rotary_dim]
            k_pass = key[:, :, :, self.rotary_dim:]

            q_rot = query[:, :, :, :self.rotary_dim]
            q_pass = query[:, :, :, self.rotary_dim:]

            k_rot = apply_rotary_pos_emb(k_rot, (self.sin, self.cos), offset=offset).to(k_rot.dtype)
            q_rot = apply_rotary_pos_emb(q_rot, (self.sin, self.cos), offset=offset).to(q_rot.dtype)

            key = torch.cat([k_rot, k_pass], dim=-1)
            query = torch.cat([q_rot, q_pass], dim=-1)
        else:
            key = apply_rotary_pos_emb(key, (self.sin, self.cos), offset=offset).to(key.dtype)
            query = apply_rotary_pos_emb(query, (self.sin, self.cos), offset=offset).to(query.dtype)

        key = key.permute(0, 2, 1, 3)
        query = query.permute(0, 2, 1, 3)

        query_length, key_length = query.size(-2), key.size(-2)
        causal_mask = self.bias[:, :, key_length - query_length:key_length, :key_length]

        x = _attn(query, key, value, causal_mask, self.masked_bias, None, self.scale_attn)

        x = _merge_heads(x, self.n_head, self.head_dim)
        x = self.out_proj(x)
        return x

class FeedForward(nn.Module):
    def __init__(self, dim, hidden_dim, activation, device, dtype):
        nn.Module.__init__(self)
        self.ff1 = nn.Linear(dim, hidden_dim, device=device, dtype=dtype)
        self.ff2 = nn.Linear(hidden_dim, dim, device=device, dtype=dtype)
        self.activation = activation

    def forward(self, x, act_ck=False):
        x = self.ff1(x)
        if act_ck:
            x = ck(self.activation, x)
        else:
            x = self.activation(x)
        x = self.ff2(x)
        return x

class CrossGPTLayer(nn.Module):
    def __init__(self, attn, ff, hidden_dim, n_head, eps, activation, device, dtype):
        nn.Module.__init__(self)
        self.hidden_dim = hidden_dim
        self.ln_preattn = nn.LayerNorm(hidden_dim, eps=eps, device=device, dtype=dtype)
        self.ff = ff(dim=hidden_dim, hidden_dim=hidden_dim * 4, activation=activation, device=device, dtype=dtype)
        self.attn = attn(hidden_dim=hidden_dim, n_head=n_head, device=device, dtype=dtype)
        self.tick = True

    def forward(self, x, layer_id=None, hypernetwork=None, act_ck=False, diff_hypernets=False, interleaving_layers=False, every_n=5):
        residual = x

        if act_ck:
            x = ck(self.ln_preattn, x)
            attn_out = ck(self.attn, x)
        else:
            x = self.ln_preattn(x)
            attn_out = self.attn(x)

        if hypernetwork:
            if diff_hypernets:
                if interleaving_layers and layer_id % every_n == 0:
                    if self.tick:
                        hyper_out = hypernetwork[0](x)
                        self.tick = False
                    else:
                        hyper_out = hypernetwork[1](x)
                        self.tick = True
                elif layer_id % every_n == 0:
                    hyper_out = hypernetwork[(layer_id // every_n) - 1](x)
            else:
                if layer_id % every_n == 0:
                    hyper_out = hypernetwork(x)

        ff_out = self.ff(x, act_ck)
        #order of addition matters, i had no idea... fixed a bug here.
        x = attn_out + ff_out + residual
        #x = residual + attn_out + ff_out -> doesn't match.
        if hypernetwork and layer_id % every_n == 0:
            x = x + hyper_out
        return x

class CrossGPTModel(nn.Module):
    def __init__(self, hidden_dim, n_layer, n_head, vocab_dim, eps, activation=gelu_new, Layer=CrossGPTLayer, device="cuda", dtype=torch.float16, **kwargs):
        nn.Module.__init__(self)
        self.n_layer = n_layer
        self.hidden_dim = hidden_dim
        self.vocab_embed = nn.Embedding(vocab_dim, self.hidden_dim, device=device, dtype=dtype)
        self.ln_final = nn.LayerNorm(self.hidden_dim, eps=eps, device=device, dtype=dtype)
        self.layers = nn.ModuleList([])
        self.lm_head = nn.Linear(hidden_dim, vocab_dim, bias=True)
        for _ in range(n_layer):
            self.layers.append(Layer(attn=SelfAttention, ff=FeedForward, hidden_dim=hidden_dim, n_head=n_head, eps=eps, activation=activation, device=device, dtype=dtype))

    def forward(self, x, hypernetwork=None, act_ck=False):
        x = self.get_embeds(x, hypernetwork=hypernetwork, act_ck=act_ck)
        x = self.lm_head(x)
        return x.float()

    def get_embeds(self, x, hypernetwork=None, act_ck=False):
        x = self.vocab_embed(x)
        for layer_id, layer in enumerate(self.layers):
            x = layer(x, layer_id=layer_id, hypernetwork=hypernetwork, act_ck=act_ck)
        x = self.ln_final(x)
        return x

class CrossGPTBaseLM(lm_base.BaseLM):
    def __init__(self, config=None, lm=None):
        nn.Module.__init__(self)
        lm_base.BaseLM.__init__(self, config, lm)
        self.model_class = CrossGPTModel

def load_gpt_j(path="models/6b", state_dict=None):
    config = {
        "n_layer": 28,
        "n_head": 16,
        "hidden_dim": 4096,
        "vocab_dim": 50400,
        "eps": 1e-5
    }
    model = CrossGPTBaseLM.load(config, path, state_dict)
    return model
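For orientation, here is a minimal smoke test of the new module. It is not part of the commit: it assumes the file above is importable as basedformer.crosslm, passes the activation explicitly so the sketch does not depend on gelu_new from basedformer.utils, and uses deliberately tiny dimensions (the real config in load_gpt_j above is hidden_dim=4096, n_layer=28, n_head=16, vocab_dim=50400).

    import torch
    import torch.nn.functional as F
    from basedformer.crosslm import CrossGPTModel  # assumption: this commit is installed as a package

    # Tiny configuration purely for illustration.
    model = CrossGPTModel(hidden_dim=64, n_layer=2, n_head=4, vocab_dim=256,
                          eps=1e-5, activation=F.gelu, device="cpu", dtype=torch.float32)
    tokens = torch.randint(0, 256, (1, 16))  # (batch, seq) of token ids
    logits = model(tokens)                   # (1, 16, 256) float32 logits
    print(logits.shape)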
basedformer/gpt2.py
...
@@ -38,34 +38,6 @@ def shift_tokens(x, amt, eps = 1e-5):
 def shift(x, amt, dim = -1):
     return F.pad(x, (*((0, 0) * (-dim - 1)), amt, -amt), value = 0.)
 
-def _split_heads(tensor, num_heads, attn_head_size, rotary):
-    """
-    Splits hidden_size dim into attn_head_size and num_heads
-    """
-    new_shape = tensor.size()[:-1] + (num_heads, attn_head_size)
-    tensor = tensor.view(*new_shape)
-    if rotary:
-        return tensor
-    if len(tensor.shape) == 5:
-        return tensor.permute(0, 1, 3, 2, 4)  # (batch, blocks, head, block_length, head_features)
-    elif len(tensor.shape) == 4:
-        return tensor.permute(0, 2, 1, 3)  # (batch, head, seq_length, head_features)
-    else:
-        raise ValueError(f"Input tensor rank should be one of [4, 5], but is: {len(tensor.shape)}")
-
-def _merge_heads(tensor, num_heads, attn_head_size):
-    """
-    Merges attn_head_size dim and num_attn_heads dim into hidden_size
-    """
-    if len(tensor.shape) == 5:
-        tensor = tensor.permute(0, 1, 3, 2, 4).contiguous()
-    elif len(tensor.shape) == 4:
-        tensor = tensor.permute(0, 2, 1, 3).contiguous()
-    else:
-        raise ValueError(f"Input tensor rank should be one of [4, 5], but is: {len(tensor.shape)}")
-    new_shape = tensor.size()[:-2] + (num_heads * attn_head_size,)
-    return tensor.view(new_shape)
-
 def _attn(query, key, value, causal_mask, masked_bias, attention_mask=None, scale_attn=None):
...
@@ -103,14 +75,19 @@ class SelfAttention(nn.Module):
         self.q_proj = nn.Linear(self.hidden_dim, self.hidden_dim, bias=attn_bias, device=device, dtype=dtype)
         self.out_proj = nn.Linear(self.hidden_dim, self.hidden_dim, bias=attn_bias, device=device, dtype=dtype)
 
-    def forward(self, x):
-        query = self.q_proj(x)
-        key = self.k_proj(x)
-        value = self.v_proj(x)
-
-        query = _split_heads(query, self.n_head, self.head_dim, True)
-        key = _split_heads(key, self.n_head, self.head_dim, True)
-        value = _split_heads(value, self.n_head, self.head_dim, False)
+    def forward(self, x, kv=None):
+        B, S, H = x.shape  # batch, sequence, hidden_dim
+        # split heads into: [batch, head, sequence, head_dim]
+        query = self.q_proj(x).view(B, self.n_head, S, self.head_dim)
+        key = self.k_proj(x).view(B, self.n_head, S, self.head_dim)
+        value = self.v_proj(x).view(B, self.n_head, S, self.head_dim)
+
+        if kv:
+            k, v = kv
+            # cat key and value (get the whole sequence, other than the last added token all are cached),
+            # so query can attend to it.
+            torch.cat([k, key], dim=-2)  # cat key
+            torch.cat([v, value], dim=-2)  # cat value
 
         key = key.permute(0, 2, 1, 3)
         query = query.permute(0, 2, 1, 3)
...
@@ -122,7 +99,7 @@ class SelfAttention(nn.Module):
             query, key, value, causal_mask, self.masked_bias, None, self.scale_attn)
 
-        x = _merge_heads(x, self.n_head, self.head_dim)
+        x = x.contiguous().view(B, S, H)
         x = self.out_proj(x)
         return x
...
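As an aside, a minimal sketch (not the repository's code) of the kv-cache concatenation described by the comments in the new forward: cached keys and values are concatenated along the sequence dimension and reassigned, so the current query can attend to the whole sequence. The helper name append_kv and the (key_cache, value_cache) tuple layout are assumptions for illustration only.

    import torch

    def append_kv(kv, key, value):
        # kv is an assumed (key_cache, value_cache) tuple of shape
        # (batch, head, cached_seq, head_dim); key/value hold only the new tokens.
        if kv is not None:
            k, v = kv
            key = torch.cat([k, key], dim=-2)      # extend keys along the sequence dim
            value = torch.cat([v, value], dim=-2)  # extend values along the sequence dim
        return key, value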
basedformer/noemblm.py
...
@@ -31,6 +31,14 @@ def token_shift_no_mix(x, window_size=1):
     x = torch.cat(time_shifts + current_x, dim=-1)
     return x
 
+def token_shift_fast(x, n_tokens=1):
+    size = x.size()[-1] // (n_tokens + 1)
+    seq_len = x.size()[-2]
+    padded_x = nn.functional.pad(x[:, :, :size], (0, 0, n_tokens, 0))
+    token_shifts = [padded_x[:, offset:(offset + seq_len)] for offset in range(n_tokens)]
+    current_x = [x[:, :, len(token_shifts) * size:]]
+    x = torch.cat(token_shifts + current_x, dim=-1)
+    return x
+
 def _split_heads(tensor, num_heads, attn_head_size, rotary):
     """
...
@@ -135,7 +143,7 @@ class FeedForward(nn.Module):
         else:
             x = self.activation(x)
 
-        x = token_shift_no_mix(x)
+        x = token_shift_fast(x)
         x = self.ff2(x)
         return x
...
@@ -148,6 +156,7 @@ class GPTNELayer(nn.Module):
         self.ff = ff(dim=hidden_dim, hidden_dim=hidden_dim*4, activation=activation, device=device, dtype=dtype)
         self.attn = attn(hidden_dim=hidden_dim, n_head=n_head, device=device, dtype=dtype)
         self.tick = True
+        self.residual_gate = True
 
     def forward(self, x, layer_id=None, hypernetwork=None, act_ck=False, diff_hypernets=False, interleaving_layers=False, every_n=5):
         residual = x
...
@@ -179,7 +188,13 @@ class GPTNELayer(nn.Module):
         ff_out = self.ff(x, act_ck)
         #order of addition matters, i had no idea... fixed a bug here.
-        x = attn_out + ff_out + residual
+        if layer_id % 1 == 0:
+            x = attn_out + ff_out + residual
+        else:
+            x = (attn_out + ff_out) * residual
         #x = residual + attn_out + ff_out -> doesn't match.
         if hypernetwork and layer_id % every_n == 0:
             x = x + hyper_out
...
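To make the newly added token_shift_fast concrete, a small worked example, assuming x is shaped (batch, seq, hidden) and that the function is importable from basedformer.noemblm after this commit: with n_tokens=1, the first hidden // 2 features at each position are replaced by the previous position's features (zeros at position 0), while the remaining features stay in place.

    import torch
    from basedformer.noemblm import token_shift_fast  # assumption: exposed by this commit

    x = torch.arange(2 * 3 * 4, dtype=torch.float32).view(2, 3, 4)  # (batch=2, seq=3, hidden=4)
    shifted = token_shift_fast(x, n_tokens=1)
    # shifted[:, t, :2] == x[:, t - 1, :2] for t > 0 (zeros at t == 0)
    # shifted[:, t, 2:] == x[:, t, 2:]
    print(shifted[0])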
train.py
...
@@ -27,11 +27,11 @@ train_config = {
     #"data_path": "/home/xuser/diffusionstorage/datasets/enwik9-gpt2-2049.map",
     "data_path": "/home/xuser/diffusionstorage/datasets/OWT2-gpt2-full.map",
     #"data_path": "/home/xuser/diffusionstorage/datasets/sigurd/map/sigurd_v5_fs_2049.map",
-    "save_path": "/home/xuser/diffusionstorage/workspace/kuru/basedformer/models/owt2gptj-nopos-tokenshiftnomix-3L-16A-1024H",
-    "do_save": True,
-    "run_name": "gptj-nopos-tokenshiftnomix-owt2-72M-3L-16A-1024H-fp16AMP-512ctx-16bs-1e-4lrinit",
-    "lr": 1e-4,
-    "end_lr": 1e-4,
+    "save_path": "/home/xuser/diffusionstorage/workspace/kuru/basedformer/models/owt2gptj-nopos-tokenshift-superhighlr-residualgate-3L-16A-1024H",
+    "do_save": False,
+    "run_name": "gptj-nopos-tokenshift-superhighlr-owt2-72M-3L-16A-1024H-fp16AMP-512ctx-16bs-1e-4lrinit",
+    "lr": 1e-3,
+    "end_lr": 1e-3,
     "warmup_steps": 100,
     "bs": 16,
     "gas": 1,
...