Commit 6aa2089c authored by novelailab

add SplitCheckpoint support.

parent da39346c
@@ -2,6 +2,12 @@ import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange, repeat
try:
    from collections.abc import MutableMapping
except ImportError:
    from collections import MutableMapping
import os
from pathlib import Path
def no_init(loading_code):
    def dummy(self):
@@ -19,6 +25,46 @@ def no_init(loading_code):
    return result
SPLIT_WEIGHTS_NAME = "m.pt"
class SplitCheckpoint(MutableMapping):
    # Lazily reads a checkpoint that was split into one file per parameter,
    # indexed by m.pt (a dict mapping state_dict keys to tensor files).
    def __init__(self, name_or_path, device="cpu", subfolder=None):
        self.device = device
        localpath = Path(name_or_path)
        if subfolder is not None:
            localpath = localpath / subfolder
        if os.path.isfile(localpath):
            # Path points directly at the index file.
            self.chkpt_dir = localpath.parent
            self.remote = False
        elif os.path.isfile(localpath / SPLIT_WEIGHTS_NAME):
            # Path points at the checkpoint directory containing the index file.
            self.chkpt_dir = localpath
            self.remote = False
        self.checkpoint = self._load(SPLIT_WEIGHTS_NAME, None)

    def _load(self, name, shape, **kwparams):
        # `shape` is accepted for interface compatibility but unused here.
        path = str(self.chkpt_dir / name)
        return torch.load(path, **kwparams)

    def __len__(self):
        return len(self.checkpoint)

    def __getitem__(self, key):
        name = self.checkpoint[key]
        if isinstance(name, tuple):
            return self._load(name[0].split('/')[-1], name[1], map_location=self.device)
        else:
            return self._load(name.split('/')[-1], None, map_location=self.device)

    def __setitem__(self, key, value):
        # Read-only mapping: writes are ignored.
        return

    def __delitem__(self, key):
        # Read-only mapping: deletions are ignored.
        return

    def keys(self):
        return self.checkpoint.keys()

    def __iter__(self):
        # Note: yields (key, tensor) pairs rather than keys alone.
        for key in self.checkpoint:
            yield (key, self.__getitem__(key))

    def __copy__(self):
        return SplitCheckpoint(self.chkpt_dir, device=self.device)

    def copy(self):
        return SplitCheckpoint(self.chkpt_dir, device=self.device)
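# Example usage (not part of this commit): a minimal sketch of reading a split checkpoint,
# assuming one already exists in the hypothetical directory "checkpoints/gptj-split"
# (one b{i}.pt tensor file per parameter plus an m.pt index mapping each state_dict
# key to the file that holds its tensor).
def _example_read_split_checkpoint():
    ckpt = SplitCheckpoint("checkpoints/gptj-split", device="cpu")
    for key in ckpt.keys():
        tensor = ckpt[key]  # each access lazily loads a single b{i}.pt file from disk
        print(key, tuple(tensor.shape))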
# TODO: Might change with non-einsum functions?
def fixed_pos_embedding(dim=None, seq_len=None, x=None):
@@ -187,13 +233,16 @@ class GPTModel(nn.Module):
        return x
    def load(self, path):
        # Load a split checkpoint lazily instead of a single monolithic state dict.
        state_dict = SplitCheckpoint(path, device="cuda")
        self.load_state_dict(state_dict)
    def save(self, path):
        # Write one file per parameter plus an index (m.pt) mapping names to files.
        os.makedirs(path, exist_ok=True)
        checkpoint = {}
        for i, (name, tensor) in enumerate(self.state_dict().items()):
            checkpoint[name] = f"{path}/b{i}.pt"
            torch.save(tensor, f"{path}/b{i}.pt")
        # Persist the index so SplitCheckpoint can locate each parameter file on load.
        torch.save(checkpoint, f"{path}/{SPLIT_WEIGHTS_NAME}")
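# Example round trip (not part of this commit): a minimal sketch, assuming `model` is an
# already-constructed GPTModel and the checkpoint directory name below is hypothetical.
#
#   model.save("checkpoints/gptj-split")  # writes one b{i}.pt per parameter plus the m.pt index
#   model.load("checkpoints/gptj-split")  # reloads them lazily through SplitCheckpoint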
# TODO: Do we want to have the LM head as a separate class, or just a function? I think we might be better off with a function here,
# and maybe also for the self-attention; we can just write a function that gets fed the q, k, v.
@@ -210,7 +259,7 @@ def load_gpt_j(path):
"n_head": 16,
"hidden_dim": 4096,
"vocab_dim": 50400,
"eps": 1e-4,
"eps": 1e-5,
"activation": nn.GELU,
"Layer": GPTLayer
}
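# Hypothetical usage (not part of this commit): load_gpt_j is assumed to build and return
# the GPT-J model described by the config above; the checkpoint path is illustrative only.
#
#   model = load_gpt_j("models/6b")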