Commit 55a00bbf authored by Wes Brown

Epoch support, and mask `<|endoftext|>`

parent eebb1fa8
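One half of this commit is the single added line `gas_labels[gas_labels == 50256] = -100` in the training loop below: 50256 is the `<|endoftext|>` token id in the GPT-2 BPE vocabulary, and -100 is the default `ignore_index` of `F.cross_entropy`, so any position carrying that label simply stops contributing to the loss. The following standalone sketch is not part of the commit; it only illustrates the mechanism with made-up tensors.

# Minimal sketch (not from this repo) of why assigning -100 to a label
# position excludes it from the loss: F.cross_entropy ignores targets
# equal to its ignore_index, which defaults to -100.
import torch
import torch.nn.functional as F

logits = torch.randn(4, 50257)                 # 4 token positions, GPT-2-sized vocab
labels = torch.tensor([464, 50256, 318, 50256])

masked = labels.clone()
masked[masked == 50256] = -100                 # mask the <|endoftext|> positions

full_loss = F.cross_entropy(logits, labels)    # all 4 positions contribute
masked_loss = F.cross_entropy(logits, masked)  # only the 2 unmasked positions contribute
print(full_loss.item(), masked_loss.item())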
@@ -28,7 +28,8 @@ prompts = ["<|endoftext|>",
            "The mercurial and beautiful",
            "<|endoftext|>[ Author:",
            "<|endoftext|>[ Genre:",
-           "***"]
+           "***",
+           "----"]

 def _init_weights(module):
@@ -285,6 +286,7 @@ parser.add_argument("--logs", type=str, help="log directory location",
 parser.add_argument("--masked", type=bool, help="masked softmax fusion")
 parser.add_argument("--sample_vanilla", type=bool, help="sample vanilla model")
 parser.add_argument("--shuffle", type=bool, help="shuffle dataset contexts")
+parser.add_argument("--epochs", type=int, help="number of epochs to train for")
 parser.set_defaults(loss_scale=False, amp=False, no_resume=False, masked=False,
                     sample_vanilla=False, shuffle=False)
 args = parser.parse_args()
@@ -312,6 +314,7 @@ train_config = {
     "context_size": args.context_size,
     "sample_vanilla": args.sample_vanilla,
     "shuffle": args.shuffle,
+    "epochs": args.epochs,
 }
 torch.manual_seed(train_config["seed"])
 bs = train_config["bs"]
@@ -368,70 +371,79 @@ if last_cp:
 else:
     curr_step = 0
-t = tqdm(train_loader, initial=curr_step)
-
-for input_ids, labels in t:
-    timex = time.perf_counter()
-    input_ids = input_ids.to(gpu)
-    labels = labels.to(gpu)
-    loss = 0
-    for x in range(train_config["gas"]):
-        with amp.autocast(enabled=train_config["amp"],
-                          dtype=torch.float16):
-            logits, _ = model(input_ids[x * bs:(x + 1) * bs, :].to(gpu),
-                              hypernetwork=hypernetwork,
-                              act_ck=True)
-            logits = logits.view(-1, logits.shape[-1])
-            gas_labels = labels[x * bs:(x + 1) * bs, :].contiguous()
-            gas_labels = gas_labels.view(-1)
-            gas_loss = F.cross_entropy(logits, gas_labels)
-
-        if train_config["loss_scale"]:
-            scaler.scale(gas_loss).backward()
-        else:
-            gas_loss.backward()
-
-        loss += gas_loss.item()
-
-    loss = loss / gas
-    if train_config["loss_scale"]:
-        scaler.unscale_(opt.optimizer)
-    torch.nn.utils.clip_grad_norm_(hypernetwork.parameters(), 1)
-    if train_config["loss_scale"]:
-        opt.step(scaler=scaler)
-    else:
-        opt.step()
-
-    if train_config["loss_scale"]:
-        scaler.update()
-
-    opt.zero_grad()
-    sec_per_step = (time.perf_counter() - timex)
-    step_per_sec = (1. / sec_per_step)
-    tokens_per_sec = (step_per_sec * train_config["context_size"]) * bs * gas
-    t.set_description(f"{step_per_sec:.2f} steps/s, {sec_per_step:.2f}s/step,"
-                      + f"{tokens_per_sec:.2f}tokens/s, loss={loss:.4f}")
-    wandb.log(
-        {
-            "train/loss": loss,
-            "train/tokens_per_sec": tokens_per_sec,
-            "train/sec_per_step": sec_per_step,
-            "train/step_per_sec": step_per_sec,
-            "train/lr": opt.curr_lr,
-            "train/loss_scale": scaler.get_scale()
-        },
-        step=curr_step)
-
-    if train_config["do_save"] and \
-            curr_step % train_config["save_every"] == 0 and \
-            curr_step != 0:
-        hypernetwork_saver(f"step_{curr_step}")
-        print(f"\nSaved model at step {curr_step}")
-
-    if curr_step % train_config["eval_every"] == 0:
-        eval_fn(curr_step)
-
-    curr_step += 1
-
+epoch_steps = len(train_loader)
+total_steps = epoch_steps * train_config['epochs']
+
+with tqdm(total=total_steps, initial=curr_step) as t:
+    for epoch in range(train_config['epochs']):
+        for input_ids, labels in train_loader:
+            timex = time.perf_counter()
+            input_ids = input_ids.to(gpu)
+            labels = labels.to(gpu)
+            loss = 0
+            for x in range(train_config["gas"]):
+                with amp.autocast(enabled=train_config["amp"],
+                                  dtype=torch.float16):
+                    logits, _ = model(input_ids[x * bs:(x + 1) * bs, :].to(gpu),
+                                      hypernetwork=hypernetwork,
+                                      act_ck=True)
+                    logits = logits.view(-1, logits.shape[-1])
+                    gas_labels = labels[x * bs:(x + 1) * bs, :].contiguous()
+                    gas_labels = gas_labels.view(-1)
+                    gas_labels[gas_labels == 50256] = -100
+                    gas_loss = F.cross_entropy(logits, gas_labels)
+
+                if train_config["loss_scale"]:
+                    scaler.scale(gas_loss).backward()
+                else:
+                    gas_loss.backward()
+
+                loss += gas_loss.item()
+
+            loss = loss / gas
+            if train_config["loss_scale"]:
+                scaler.unscale_(opt.optimizer)
+            torch.nn.utils.clip_grad_norm_(hypernetwork.parameters(), 1)
+            if train_config["loss_scale"]:
+                opt.step(scaler=scaler)
+            else:
+                opt.step()
+
+            if train_config["loss_scale"]:
+                scaler.update()
+
+            opt.zero_grad()
+            sec_per_step = (time.perf_counter() - timex)
+            step_per_sec = (1. / sec_per_step)
+            tokens_per_sec = (step_per_sec * train_config["context_size"]) * \
+                bs * gas
+            t.set_description(f"{step_per_sec:.2f} steps/s, "
+                              f"{sec_per_step:.2f}s/step, "
+                              f"{tokens_per_sec:.2f}tokens/s, "
+                              f"loss={loss:.4f}")
+            wandb.log(
+                {
+                    "train/epoch": float(curr_step) / float(epoch_steps),
+                    "train/loss": loss,
+                    "train/tokens_per_sec": tokens_per_sec,
+                    "train/sec_per_step": sec_per_step,
+                    "train/step_per_sec": step_per_sec,
+                    "train/lr": opt.curr_lr,
+                    "train/loss_scale": scaler.get_scale()
+                },
+                step=curr_step)
+
+            if train_config["do_save"] and \
+                    curr_step % train_config["save_every"] == 0 and \
+                    curr_step != 0:
+                hypernetwork_saver(f"step_{curr_step}")
+                print(f"\nSaved model at step {curr_step}")
+
+            if curr_step % train_config["eval_every"] == 0:
+                eval_fn(curr_step)
+
+            curr_step += 1
+            t.update(1)
 eval_fn(curr_step)
 hypernetwork_saver("final")