Commit 53bcc538 authored by Wes Brown

Fix shuffling and use provided context size on CLI.

parent 6c1a2d67
@@ -263,6 +263,7 @@ train_config = {
     "amp": args.amp,
     "loss_scale": args.loss_scale,
     "eval_every": args.eval_every,
+    "context_size": args.context_size,
 }
 torch.manual_seed(train_config["seed"])
 bs = train_config["bs"]
@@ -303,13 +304,14 @@ else:
 # TODO: Add load, add evals, add FP16 AMP, and Data Parallel, outputting hidden
 # states from the get_logits function.
 print(opt.curr_step)
-train_dataset = dataset.ShardedDataset(2049, train_config["data_path"])
+train_dataset = dataset.ShardedDataset(train_config["context_size"],
+                                       train_config["data_path"])
 if last_cp:
     train_dataset.skip = opt.curr_step * bs * gas
 train_loader = data.DataLoader(train_dataset,
                                batch_size=bs * gas,
-                               shuffle=False,
+                               shuffle=True,
                                num_workers=0)
 wandb.init(project="hypernetwork-tests",
            name=train_config["run_name"],
...
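
The second hunk assumes the script's CLI already exposes a --context_size flag; its definition sits outside this diff. A minimal, hypothetical argparse sketch consistent with the change might look like the following, where the flag name and default are assumptions (2049 presumably being a 2048-token context plus one extra token for shifted next-token targets):

import argparse

parser = argparse.ArgumentParser()
# Hypothetical flag assumed by this commit; the default mirrors the
# previously hard-coded 2049 (presumably 2048 context tokens plus one
# extra token for the shifted next-token targets).
parser.add_argument("--context_size", type=int, default=2049,
                    help="Sequence length passed to ShardedDataset.")
args = parser.parse_args()

One interaction worth noting: the resume path keeps train_dataset.skip = opt.curr_step * bs * gas (completed steps times batch size times what is presumably the gradient-accumulation count), which only skips the exact samples consumed before the checkpoint if the shuffled order is reproducible across runs; here that rests on the torch.manual_seed(train_config["seed"]) call earlier in the script, since DataLoader's default random sampler seeds itself from the global torch RNG when no explicit generator is supplied.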