Commit eaba3d73 authored by AUTOMATIC1111

send weights to target device instead of CPU memory

parent 57e59c14
modules/sd_disable_initialization.py
@@ -155,10 +155,16 @@ class LoadStateDictOnMeta(ReplaceHelper):
     ```
     """
 
-    def __init__(self, state_dict, device):
+    def __init__(self, state_dict, device, weight_dtype_conversion=None):
         super().__init__()
         self.state_dict = state_dict
         self.device = device
+        self.weight_dtype_conversion = weight_dtype_conversion or {}
+        self.default_dtype = self.weight_dtype_conversion.get('')
+
+    def get_weight_dtype(self, key):
+        key_first_term, _ = key.split('.', 1)
+        return self.weight_dtype_conversion.get(key_first_term, self.default_dtype)
 
     def __enter__(self):
         if shared.cmd_opts.disable_model_loading_ram_optimization:
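Note: the new per-key dtype lookup keys off the first dot-separated term of a parameter name, with the '' entry acting as the fallback. A minimal standalone sketch of that behavior, using the conversion map that load_model builds further down in this commit:

```python
import torch

weight_dtype_conversion = {'first_stage_model': None, '': torch.float16}
default_dtype = weight_dtype_conversion.get('')

def get_weight_dtype(key):
    key_first_term, _ = key.split('.', 1)
    return weight_dtype_conversion.get(key_first_term, default_dtype)

# VAE keys hit the explicit None entry, so .to(dtype=None) leaves them unconverted:
print(get_weight_dtype('first_stage_model.encoder.conv_in.weight'))   # None
# Any other prefix falls through to the '' default:
print(get_weight_dtype('model.diffusion_model.out.0.weight'))         # torch.float16
```

Passing dtype=None to Tensor.to leaves the dtype unchanged, which is how 'first_stage_model' opts out of the fp16 conversion.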
@@ -167,24 +173,24 @@ class LoadStateDictOnMeta(ReplaceHelper):
         sd = self.state_dict
         device = self.device
 
-        def load_from_state_dict(original, self, state_dict, prefix, *args, **kwargs):
+        def load_from_state_dict(original, module, state_dict, prefix, *args, **kwargs):
             used_param_keys = []
 
-            for name, param in self._parameters.items():
+            for name, param in module._parameters.items():
                 if param is None:
                     continue
 
                 key = prefix + name
                 sd_param = sd.pop(key, None)
                 if sd_param is not None:
-                    state_dict[key] = sd_param
+                    state_dict[key] = sd_param.to(dtype=self.get_weight_dtype(key))
                     used_param_keys.append(key)
 
                 if param.is_meta:
                     dtype = sd_param.dtype if sd_param is not None else param.dtype
-                    self._parameters[name] = torch.nn.parameter.Parameter(torch.zeros_like(param, device=device, dtype=dtype), requires_grad=param.requires_grad)
+                    module._parameters[name] = torch.nn.parameter.Parameter(torch.zeros_like(param, device=device, dtype=dtype), requires_grad=param.requires_grad)
 
-            for name in self._buffers:
+            for name in module._buffers:
                 key = prefix + name
 
                 sd_param = sd.pop(key, None)
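Note: renaming the inner argument from self to module is what lets the closure reach the outer LoadStateDictOnMeta instance as self for the get_weight_dtype call. For readers unfamiliar with meta tensors: a parameter on the meta device has shape and dtype but no storage, and the hunk above swaps each one for a real zero tensor on the target device just before the original loader copies weights in. A self-contained sketch of that materialization step (the Linear layer is illustrative only; requires PyTorch 2.x for torch.device as a context manager):

```python
import torch

# Build a layer on the meta device: shapes exist, but no memory is allocated.
with torch.device('meta'):
    layer = torch.nn.Linear(4, 4)

assert layer.weight.is_meta

# Materialize each meta parameter directly on the target device,
# mirroring what the patched _load_from_state_dict does above.
device = torch.device('cpu')  # stand-in for the webui's devices.device
for name, param in layer._parameters.items():
    if param is not None and param.is_meta:
        layer._parameters[name] = torch.nn.parameter.Parameter(
            torch.zeros_like(param, device=device, dtype=param.dtype),
            requires_grad=param.requires_grad,
        )

assert not layer.weight.is_meta
```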
@@ -192,12 +198,12 @@ class LoadStateDictOnMeta(ReplaceHelper):
                     state_dict[key] = sd_param
                     used_param_keys.append(key)
 
-            original(self, state_dict, prefix, *args, **kwargs)
+            original(module, state_dict, prefix, *args, **kwargs)
 
             for key in used_param_keys:
                 state_dict.pop(key, None)
 
-        def load_state_dict(original, self, state_dict, strict=True):
+        def load_state_dict(original, module, state_dict, strict=True):
             """torch makes a lot of copies of the dictionary with weights, so just deleting entries from state_dict does not help
             because the same values are stored in multiple copies of the dict. The trick used here is to give torch a dict with
             all weights on meta device, i.e. deleted, and then it doesn't matter how many copies torch makes.
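Note: the docstring's trick is worth seeing in isolation. Moving every value of the dict to the meta device frees the storage while keeping the dict's structure, so however many copies torch makes internally, none of them pin weight memory. A small sketch:

```python
import torch

state_dict = {'weight': torch.randn(1024, 1024), 'bias': torch.randn(1024)}

# Move every value to the meta device: shapes and dtypes survive,
# but the underlying storage is freed.
meta_sd = {k: v.to(device="meta", dtype=v.dtype) for k, v in state_dict.items()}

print(meta_sd['weight'].shape, meta_sd['weight'].dtype)  # torch.Size([1024, 1024]) torch.float32
print(meta_sd['weight'].is_meta)                         # True
# Copies of this dict are now cheap, no matter how many torch makes.
```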
@@ -212,7 +218,7 @@ class LoadStateDictOnMeta(ReplaceHelper):
             if state_dict == sd:
                 state_dict = {k: v.to(device="meta", dtype=v.dtype) for k, v in state_dict.items()}
 
-            original(self, state_dict, strict=strict)
+            original(module, state_dict, strict=strict)
 
         module_load_state_dict = self.replace(torch.nn.Module, 'load_state_dict', lambda *args, **kwargs: load_state_dict(module_load_state_dict, *args, **kwargs))
         module_load_from_state_dict = self.replace(torch.nn.Module, '_load_from_state_dict', lambda *args, **kwargs: load_from_state_dict(module_load_from_state_dict, *args, **kwargs))
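Note: the two self.replace calls rely on a late-binding closure: each lambda references the module_load_state_dict name, which already holds the original method by the time the lambda first runs. A rough stand-in for ReplaceHelper.replace (illustrative only, not the webui's actual implementation) showing the pattern:

```python
import torch

def replace(obj, field, func):
    # Stand-in for ReplaceHelper.replace: install func and hand back the original.
    # (The real helper also records the original so __exit__ can restore it.)
    original = getattr(obj, field)
    setattr(obj, field, func)
    return original

def load_state_dict(original, module, state_dict, strict=True):
    print(f"intercepted: {len(state_dict)} entries")
    return original(module, state_dict, strict=strict)

# The lambda closes over the *name*; by the time it is called,
# module_load_state_dict already holds the original method.
module_load_state_dict = replace(
    torch.nn.Module, 'load_state_dict',
    lambda *args, **kwargs: load_state_dict(module_load_state_dict, *args, **kwargs),
)

torch.nn.Linear(2, 2).load_state_dict(torch.nn.Linear(2, 2).state_dict())
setattr(torch.nn.Module, 'load_state_dict', module_load_state_dict)  # restore
```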
modules/sd_models.py
@@ -518,6 +518,13 @@ def send_model_to_cpu(m):
     devices.torch_gc()
 
 
+def model_target_device():
+    if shared.cmd_opts.lowvram or shared.cmd_opts.medvram:
+        return devices.cpu
+    else:
+        return devices.device
+
+
 def send_model_to_device(m):
     if shared.cmd_opts.lowvram or shared.cmd_opts.medvram:
         lowvram.setup_for_low_vram(m, shared.cmd_opts.medvram)
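Note: this helper encodes the commit's core idea. Under --lowvram/--medvram the lowvram module shuttles submodules to the GPU per call, so weights belong in CPU RAM; otherwise they can now land straight on the target device with no intermediate CPU copy. A standalone sketch of the same decision (CmdOpts is a stand-in for the webui's shared.cmd_opts):

```python
import torch

class CmdOpts:  # stand-in for shared.cmd_opts
    lowvram = False
    medvram = False

cmd_opts = CmdOpts()
gpu = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

def model_target_device():
    if cmd_opts.lowvram or cmd_opts.medvram:
        # lowvram mode streams modules to the GPU on demand; keep weights in RAM.
        return torch.device('cpu')
    # Otherwise load weights directly onto the GPU, skipping the CPU staging copy.
    return gpu
```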
@@ -579,7 +586,15 @@ def load_model(checkpoint_info=None, already_loaded_state_dict=None):
     timer.record("create model")
 
-    with sd_disable_initialization.LoadStateDictOnMeta(state_dict, devices.cpu):
+    if shared.cmd_opts.no_half:
+        weight_dtype_conversion = None
+    else:
+        weight_dtype_conversion = {
+            'first_stage_model': None,
+            '': torch.float16,
+        }
+
+    with sd_disable_initialization.LoadStateDictOnMeta(state_dict, device=model_target_device(), weight_dtype_conversion=weight_dtype_conversion):
         load_model_weights(sd_model, checkpoint_info, state_dict, timer)
 
     timer.record("load weights from state dict")
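Note: the payoff of converting at load time, sketched with a standalone tensor: unless --no-half is set, each non-VAE weight is halved from fp32 to fp16 as it is read out of the state dict, rather than being staged at full precision and converted afterwards.

```python
import torch

w = torch.randn(1024, 1024)                  # an fp32 weight as stored in a checkpoint
print(w.numel() * w.element_size())          # 4194304 bytes (4 MiB)

h = w.to(dtype=torch.float16)                # the conversion get_weight_dtype selects
print(h.numel() * h.element_size())          # 2097152 bytes (2 MiB), half the footprint

print(w.to(dtype=None).dtype)                # torch.float32: dtype=None leaves it as is,
                                             # which is how 'first_stage_model': None opts out
```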