Commit 94c0ad6f authored by Wes Brown's avatar Wes Brown

Set default to `2049` and correctly calculate the tokens per second.

parent 53bcc538
@@ -233,7 +233,7 @@ parser.add_argument('--output_path', type=str, help="Root path of all output",
 parser.add_argument('--no_resume', type=bool, default=False,
                     help="Do not resume from last checkpoint")
 parser.add_argument("--context_size", type=int, help="Dataset context sizes",
-                    default=2048)
+                    default=2049)
 parser.add_argument("--project_id", type=str, help="Project ID for reporting",
                     default="hypernetwork-training")
 parser.add_argument("--logs", type=str, help="log directory location",
@@ -336,7 +336,6 @@ for input_ids, labels in t:
         logits, _ = model(input_ids[x * bs:(x + 1) * bs, :].to(gpu),
                           hypernetwork=hypernetwork,
                           act_ck=True)
-        # print(tokenizer.decode(input_ids[x*bs:(x+1)*bs, :][0]))
         logits = logits.view(-1, logits.shape[-1])
         gas_labels = labels[x * bs:(x + 1) * bs, :].contiguous()
         gas_labels = gas_labels.view(-1)
@@ -364,7 +363,7 @@ for input_ids, labels in t:
         opt.zero_grad()
         sec_per_step = (time.perf_counter() - timex)
         step_per_sec = (1. / sec_per_step)
-        tokens_per_sec = (step_per_sec * 2048) * bs * gas
+        tokens_per_sec = (step_per_sec * train_config["context_size"]) * bs * gas
         t.set_description(f"{step_per_sec:.2f} steps/s, {sec_per_step:.2f}s/step,"
                           + f"{tokens_per_sec:.2f}tokens/s, loss={loss:.4f}")
wandb.log( wandb.log(
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment