novelai-storage / Basedformer / Commits

Commit 94ad5ad7
authored Jul 05, 2022 by novelailab

update hypertrain with API changes

parent 5a4b10c7

Showing 1 changed file with 10 additions and 13 deletions:
hypertrain.py (+10, -13)
--- a/hypertrain.py
+++ b/hypertrain.py
@@ -12,7 +12,7 @@ import wandb
 import numpy as np
 from torch.utils.checkpoint import checkpoint as ck
 from math import log2, ceil
-from basedformer import optimizer, lm_utils
+from basedformer import optimizer, lm_utils, dataset
 from basedformer.utils import *
 import glob
 from transformers import AutoTokenizer
@@ -191,10 +191,8 @@ def sample(prompt, n_tokens, bsz, hypernetwork=None):
 # we need 250 batch size to train the small GPT.
 train_config = {
     "data_path": "/home/xuser/diffusionstorage/datasets/enwik9-gpt2-2049.map",
-    #"data_path": "/home/xuser/diffusionstorage/datasets/sigurd/map/sigurd_v7_infilling.map",
-    ##"data_path": "/home/xuser/diffusionstorage/datasets/OWT2-gpt2-full.map",
-    #"data_path": "/home/xuser/diffusionstorage/datasets/sigurd/map/sigurd_v5_fs_2049.map",
     "save_path": "/home/xuser/diffusionstorage/workspace/kuru/basedformer/models/hypernetwork-fairseq-6b-2048-enwik9-again",
+    "lm_path": "/home/xuser/nvme1/pretrained/sigurdv4",
     "do_save": True,
     "run_name": "fairseq-6b-enwik9-6b-postln-bf16-2e-4-4bsz-every5layer",
     "lr": 2e-4,
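The lm_path key points at the same sigurdv4 checkpoint that the model-loading hunk below switches to. A minimal sketch of how that key could be consumed, reusing the file's own lm_utils.load_from_path call and the train_config dict from this hunk; note that the hunk below still hardcodes the path rather than reading it from config, so this is an assumed intent, not what the diff does:

# Sketch only: lm_utils.load_from_path and the .cuda().bfloat16() chain are
# taken from the diff; reading the path from train_config is an assumption.
from basedformer import lm_utils

model = lm_utils.load_from_path(train_config["lm_path"]).cuda().bfloat16()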
@@ -215,7 +213,7 @@ gas = train_config["gas"]
 Path(train_config["save_path"]).mkdir(parents=True, exist_ok=True)
 #model = GPTModel.gpt2_init(model_config).cuda().float()
-model = lm_utils.load_from_path("pretrained/fairseq_6_7b").cuda().bfloat16()
+model = lm_utils.load_from_path("/home/xuser/nvme1/pretrained/sigurdv4").cuda().bfloat16()
 for param in model.parameters():
     param.requires_grad = False
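The requires_grad = False loop is the core of the hypernetwork setup: the base LM stays frozen and only the hypernetwork is optimized. A minimal sketch of that pattern, with a hypothetical Hypernetwork adapter and a small stand-in layer in place of the real 6B model; only the freezing loop and the 2e-4 learning rate come from the diff:

import torch
import torch.nn as nn

class Hypernetwork(nn.Module):
    # Hypothetical stand-in for the module hypertrain.py optimizes:
    # a small residual adapter over the frozen LM's hidden states.
    def __init__(self, hidden_dim):
        super().__init__()
        self.down = nn.Linear(hidden_dim, 128)
        self.up = nn.Linear(128, hidden_dim)

    def forward(self, h):
        return h + self.up(torch.relu(self.down(h)))

base = nn.TransformerEncoderLayer(d_model=1024, nhead=8)  # stand-in for the 6B LM
for param in base.parameters():
    param.requires_grad = False  # freeze the base model, exactly as the hunk does

hypernetwork = Hypernetwork(1024)
opt = torch.optim.AdamW(hypernetwork.parameters(), lr=2e-4)  # lr from train_config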
@@ -243,7 +241,7 @@ else:
 # TODO: Add load, add evals, add FP16 AMP, and Data Parallel, outputting hidden states from the get_logits function.
 print(opt.curr_step)
-train_dataset = FbDataset(2049, train_config["data_path"])
+train_dataset = dataset.ShardedDataset(2049, train_config["data_path"])
 if last_cp:
     train_dataset.skip = opt.curr_step * bs * gas
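ShardedDataset replaces the old FbDataset with the same (block_size, path) constructor and a writable skip attribute used to fast-forward past rows already consumed before the checkpoint. A sketch of what such a dataset could look like, assuming the .map files are flat memory-mapped arrays of uint16 token ids; the class name and storage layout are guesses, only the constructor shape and skip come from the diff:

import numpy as np
import torch
from torch.utils.data import Dataset

class MapDataset(Dataset):
    # Hypothetical stand-in for basedformer's ShardedDataset: a flat
    # memory-mapped file of uint16 token ids served in fixed-size rows.
    def __init__(self, block_size, path):
        self.block = block_size
        self.data = np.memmap(path, dtype=np.uint16, mode="r")
        self.skip = 0  # rows to drop when resuming, mirroring train_dataset.skip

    def __len__(self):
        return len(self.data) // self.block - self.skip

    def __getitem__(self, i):
        start = (i + self.skip) * self.block
        row = torch.from_numpy(self.data[start:start + self.block].astype(np.int64))
        return row[:-1], row[1:]  # (input_ids, labels): 2049 tokens -> 2048 + 2048

The resume arithmetic then follows directly: each optimizer step consumes bs * gas rows, hence skip = opt.curr_step * bs * gas when restarting from a saved optimizer state.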
@@ -309,13 +307,12 @@ for input_ids, labels in t:
             },
             step=curr_step)
-        if train_config["do_save"]:
-            if curr_step % train_config["save_every"] == 0 and curr_step != 0:
-                save_folder = Path(train_config["save_path"]) / f"step_{curr_step}"
-                save_folder.mkdir(parents=True, exist_ok=True)
-                torch.save(hypernetwork.state_dict(), save_folder / "hyper.pt")
-                opt.save(save_folder / "opt")
-                print(f"Saved model at step {curr_step}")
+        if train_config["do_save"] and curr_step % train_config["save_every"] == 0 and curr_step != 0:
+            save_folder = Path(train_config["save_path"]) / f"step_{curr_step}"
+            save_folder.mkdir(parents=True, exist_ok=True)
+            torch.save(hypernetwork.state_dict(), save_folder / "hyper.pt")
+            opt.save(save_folder / "opt")
+            print(f"Saved model at step {curr_step}")
         if curr_step % train_config["eval_every"] == 0:
             sample("<|endoftext|>", 500, 3, hypernetwork=hypernetwork)
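The only behavioural content here is the flattened guard: the nested do_save / save_every conditions become a single if, and the save block itself is untouched. The same block, extracted as a helper with nothing assumed beyond what it already does (basedformer's opt.save call is left out so the sketch runs with plain torch):

from pathlib import Path
import torch

def save_checkpoint(module, save_path, curr_step):
    # Hypothetical helper mirroring the save block in the hunk above.
    save_folder = Path(save_path) / f"step_{curr_step}"
    save_folder.mkdir(parents=True, exist_ok=True)
    torch.save(module.state_dict(), save_folder / "hyper.pt")
    print(f"Saved model at step {curr_step}")

# Equivalent to the new flattened guard:
# if train_config["do_save"] and curr_step % train_config["save_every"] == 0 and curr_step != 0:
#     save_checkpoint(hypernetwork, train_config["save_path"], curr_step)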