Commit 0bde514b authored by Wes Brown's avatar Wes Brown

Add defaults, project id, logs.

parent 2db358ff
...@@ -291,9 +291,9 @@ parser.add_argument("--logs", type=str, help="log directory location", ...@@ -291,9 +291,9 @@ parser.add_argument("--logs", type=str, help="log directory location",
default="./logs") default="./logs")
parser.add_argument("--masked", type=bool, help="masked softmax fusion") parser.add_argument("--masked", type=bool, help="masked softmax fusion")
parser.add_argument("--sample_vanilla", type=bool, help="sample vanilla model") parser.add_argument("--sample_vanilla", type=bool, help="sample vanilla model")
parser.add_argument("--sample_tokens", type=int, parser.add_argument("--sample_tokens", type=int, default=500,
help="number of tokens to sample") help="number of tokens to sample")
parser.add_argument("--sample_num", type=int, parser.add_argument("--sample_num", type=int, default=3,
help="number of samples per prompt") help="number of samples per prompt")
parser.add_argument("--shuffle", type=bool, help="shuffle dataset contexts") parser.add_argument("--shuffle", type=bool, help="shuffle dataset contexts")
parser.add_argument("--epochs", type=int, help="number of epochs to train for") parser.add_argument("--epochs", type=int, help="number of epochs to train for")
...@@ -304,6 +304,7 @@ if args.output == '': ...@@ -304,6 +304,7 @@ if args.output == '':
args.output = f'./{args.run_name}' args.output = f'./{args.run_name}'
# we need 250 batch size to train the small GPT. # we need 250 batch size to train the small GPT.
train_config = { train_config = {
"project_id": args.project_id,
"data_path": args.dataset, "data_path": args.dataset,
"save_path": args.output, "save_path": args.output,
"lm_path": args.model, "lm_path": args.model,
...@@ -327,6 +328,7 @@ train_config = { ...@@ -327,6 +328,7 @@ train_config = {
"num_tokens": args.sample_tokens, "num_tokens": args.sample_tokens,
"shuffle": args.shuffle, "shuffle": args.shuffle,
"epochs": args.epochs, "epochs": args.epochs,
"logs": args.logs,
} }
torch.manual_seed(train_config["seed"]) torch.manual_seed(train_config["seed"])
bs = train_config["bs"] bs = train_config["bs"]
...@@ -374,7 +376,7 @@ train_loader = torch_data.DataLoader(train_dataset, ...@@ -374,7 +376,7 @@ train_loader = torch_data.DataLoader(train_dataset,
batch_size=bs * gas, batch_size=bs * gas,
shuffle=train_config["shuffle"], shuffle=train_config["shuffle"],
num_workers=0) num_workers=0)
wandb.init(project="hypernetwork-tests", wandb.init(project=train_config["project_id"],
name=train_config["run_name"], name=train_config["run_name"],
config={**train_config, **model.config}) config={**train_config, **model.config})
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment