Commit 9240d7ea authored by Wes Brown

Fix wandb reporting; do the first zero-eval *before* evaluating the first step.

parent 60893ad7
@@ -95,7 +95,7 @@ spec:
       - name: hypertrainer_image
         value: 'docker.io/gooseai/basedformer'
       - name: hypertrainer_tag
-        value: '3b75904'
+        value: '60893ad'
   templates:
     - name: main
@@ -379,6 +379,7 @@ train_loader = torch_data.DataLoader(train_dataset,
 wandb.init(project=train_config["project_id"],
            name=train_config["run_name"],
            config={**train_config, **model.config})
+print("wandb initialized")
 if last_cp:
     curr_step = opt.curr_step
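A side note on the config={**train_config, **model.config} argument above: Python dict unpacking lets later entries override earlier ones, so any key present in both dicts is recorded by wandb with the value from model.config. A minimal sketch with hypothetical keys, not taken from the training script:

    train_config = {"lr": 1e-4, "context_size": 2048}   # hypothetical values
    model_config = {"context_size": 1024, "n_layer": 24}

    # Later unpacking wins on key collisions, so "context_size"
    # comes from model_config in the merged dict.
    merged = {**train_config, **model_config}
    assert merged == {"lr": 1e-4, "context_size": 1024, "n_layer": 24}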
@@ -391,6 +392,8 @@ tokens_per_step = train_config['context_size'] * \
                   train_config['bs'] * \
                   train_config['gas']
+eval_fn(curr_step)
+
 with tqdm(total=total_steps, initial=curr_step) as t:
     for epoch in range(train_config['epochs']):
         for input_ids, labels in train_loader:
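The eval_fn(curr_step) call added before the loop is what the commit message refers to: the baseline ("zero") evaluation now runs once before any training step, instead of relying on the in-loop curr_step % eval_every == 0 check firing at step 0 (which the last hunk below now excludes). A minimal sketch of the resulting control flow, with hypothetical stand-ins for eval_fn and the config values:

    # Sketch of the eval scheduling after this commit (hypothetical eval_fn).
    def eval_fn(step: int) -> None:
        print(f"evaluating at step {step}")

    curr_step = 0
    eval_every = 100        # stands in for train_config["eval_every"]
    total_steps = 300

    eval_fn(curr_step)      # baseline eval once, before any training
    while curr_step < total_steps:
        # ... one optimizer step would happen here ...
        if curr_step % eval_every == 0 and curr_step != 0:
            eval_fn(curr_step)   # periodic evals; step 0 no longer double-fires
        curr_step += 1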
@@ -433,11 +436,12 @@ with tqdm(total=total_steps, initial=curr_step) as t:
             sec_per_step = (time.perf_counter() - timex)
             step_per_sec = (1. / sec_per_step)
             tokens_per_sec = step_per_sec * tokens_per_step
-            curr_tokens = tokens_per_step * curr_step
+            curr_tokens = tokens_per_step * (curr_step + 1)
             t.set_description(f"{step_per_sec:.2f} steps/s, "
                               f"{sec_per_step:.2f}s/step, "
                               f"{tokens_per_sec:.2f}tokens/s, "
-                              f"loss={loss:.4f}")
+                              f"loss={loss:.4f}, "
+                              f"{curr_tokens} tokens processed")
             wandb.log(
                 {
                     "train/epoch": float(curr_step) / float(epoch_steps),
@@ -446,29 +450,18 @@ with tqdm(total=total_steps, initial=curr_step) as t:
                     "train/sec_per_step": sec_per_step,
                     "train/step_per_sec": step_per_sec,
                     "train/lr": opt.curr_lr,
-                    "train/loss_scale": scaler.get_scale()
+                    "train/loss_scale": scaler.get_scale(),
+                    "train/tokens": curr_tokens,
                 },
                 step=curr_step)
-            wandb.log(
-                {
-                    "train_tokens/epoch": float(curr_step) / float(epoch_steps),
-                    "train_tokens/loss": loss,
-                    "train_tokens/tokens_per_sec": tokens_per_sec,
-                    "train_tokens/sec_per_step": sec_per_step,
-                    "train_tokens/step_per_sec": step_per_sec,
-                    "train_tokens/lr": opt.curr_lr,
-                    "train_tokens/loss_scale": scaler.get_scale()
-                },
-                step=curr_tokens)
             if train_config["do_save"] and \
                curr_step % train_config["save_every"] == 0 and \
                curr_step != 0:
                 hypernetwork_saver(f"step_{curr_step}")
                 print(f"\nSaved model at step {curr_step}")
-            if curr_step % train_config["eval_every"] == 0:
+            if curr_step % train_config["eval_every"] == 0 and curr_step != 0:
                 eval_fn(curr_step)
             curr_step += 1
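Dropping the second wandb.log(..., step=curr_tokens) call is likely the "fix wandb reporting" part: wandb expects the step argument to increase monotonically within a run, and interleaving calls keyed by curr_step with calls keyed by the much larger curr_tokens violates that, so rows can be warned about or silently dropped. Folding the token count in as an ordinary metric (train/tokens) keeps a single step axis. If a tokens-based x-axis is still wanted, wandb's define_metric API can plot other metrics against a logged metric; a sketch under that assumption (call placement and names are illustrative, not from this diff):

    import wandb

    wandb.init(project="example")               # hypothetical project name

    # Declare train/tokens as an x-axis and plot other train/* metrics against it.
    wandb.define_metric("train/tokens")
    wandb.define_metric("train/*", step_metric="train/tokens")

    wandb.log({"train/loss": 2.31, "train/tokens": 65536})   # illustrative values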