Commit 5768afc7 authored by Aarni Koskela

Add utility to inspect a model's parameters (to get dtype/device)

parent a84e8421
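The commit replaces the repeated next(model.parameters()) / next(iter(model.model.parameters())) idiom with a single get_param() helper in modules/torch_utils.py, which returns the first parameter of a module (unwrapping a .model descriptor if present) so callers can read its dtype and device. A minimal usage sketch, not part of the diff below; the Linear module is only an illustration:

    import torch
    from modules.torch_utils import get_param

    model = torch.nn.Linear(8, 8)     # stand-in for any module handled by the webui
    param = get_param(model)          # first parameter; also works on objects exposing .model
    print(param.dtype, param.device)  # e.g. torch.float32 cpu

    # Old idiom being replaced throughout the diff:
    # dtype = next(model.parameters()).dtype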
@@ -4,6 +4,7 @@ from functools import lru_cache
 import torch
 from modules import errors, shared
+from modules.torch_utils import get_param
 if sys.platform == "darwin":
     from modules import mac_specific
@@ -131,7 +132,7 @@ patch_module_list = [
 def manual_cast_forward(self, *args, **kwargs):
-    org_dtype = next(self.parameters()).dtype
+    org_dtype = get_param(self).dtype
     self.to(dtype)
     args = [arg.to(dtype) if isinstance(arg, torch.Tensor) else arg for arg in args]
     kwargs = {k: v.to(dtype) if isinstance(v, torch.Tensor) else v for k, v in kwargs.items()}
...
@@ -11,6 +11,7 @@ from torchvision import transforms
 from torchvision.transforms.functional import InterpolationMode
 from modules import devices, paths, shared, lowvram, modelloader, errors
+from modules.torch_utils import get_param
 blip_image_eval_size = 384
 clip_model_name = 'ViT-L/14'
@@ -131,7 +132,7 @@ class InterrogateModels:
         self.clip_model = self.clip_model.to(devices.device_interrogate)
-        self.dtype = next(self.clip_model.parameters()).dtype
+        self.dtype = get_param(self.clip_model).dtype
     def send_clip_to_ram(self):
         if not shared.opts.interrogate_keep_models_in_memory:
...
@@ -6,6 +6,7 @@ import sgm.models.diffusion
 import sgm.modules.diffusionmodules.denoiser_scaling
 import sgm.modules.diffusionmodules.discretizer
 from modules import devices, shared, prompt_parser
+from modules.torch_utils import get_param
 def get_learned_conditioning(self: sgm.models.diffusion.DiffusionEngine, batch: prompt_parser.SdConditioning | list[str]):
@@ -90,7 +91,7 @@ sgm.modules.GeneralConditioner.get_target_prompt_token_count = get_target_prompt
 def extend_sdxl(model):
     """this adds a bunch of parameters to make SDXL model look a bit more like SD1.5 to the rest of the codebase."""
-    dtype = next(model.model.diffusion_model.parameters()).dtype
+    dtype = get_param(model.model.diffusion_model).dtype
     model.model.diffusion_model.dtype = dtype
     model.model.conditioning_key = 'crossattn'
     model.cond_stage_key = 'txt'
...
from __future__ import annotations

import torch.nn


def get_param(model) -> torch.nn.Parameter:
    """
    Find the first parameter in a model or module.
    """
    if hasattr(model, "model") and hasattr(model.model, "parameters"):
        # Unpeel a model descriptor to get at the actual Torch module.
        model = model.model

    for param in model.parameters():
        return param

    raise ValueError(f"No parameters found in model {model!r}")
@@ -7,6 +7,7 @@ import tqdm
 from PIL import Image
 from modules import images, shared
+from modules.torch_utils import get_param
 logger = logging.getLogger(__name__)
@@ -17,8 +18,8 @@ def upscale_without_tiling(model, img: Image.Image):
     img = np.ascontiguousarray(np.transpose(img, (2, 0, 1))) / 255
     img = torch.from_numpy(img).float()
-    model_weight = next(iter(model.model.parameters()))
-    img = img.unsqueeze(0).to(device=model_weight.device, dtype=model_weight.dtype)
+    param = get_param(model)
+    img = img.unsqueeze(0).to(device=param.device, dtype=param.dtype)
     with torch.no_grad():
         output = model(img)
...
@@ -5,6 +5,9 @@ from transformers.models.xlm_roberta.configuration_xlm_roberta import XLMRoberta
 from transformers import XLMRobertaModel,XLMRobertaTokenizer
 from typing import Optional
+from modules.torch_utils import get_param
 class BertSeriesConfig(BertConfig):
     def __init__(self, vocab_size=30522, hidden_size=768, num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072, hidden_act="gelu", hidden_dropout_prob=0.1, attention_probs_dropout_prob=0.1, max_position_embeddings=512, type_vocab_size=2, initializer_range=0.02, layer_norm_eps=1e-12, pad_token_id=0, position_embedding_type="absolute", use_cache=True, classifier_dropout=None,project_dim=512, pooler_fn="average",learn_encoder=False,model_type='bert',**kwargs):
@@ -62,7 +65,7 @@ class BertSeriesModelWithTransformation(BertPreTrainedModel):
         self.post_init()
     def encode(self,c):
-        device = next(self.parameters()).device
+        device = get_param(self).device
         text = self.tokenizer(c,
             truncation=True,
             max_length=77,
...
@@ -5,6 +5,9 @@ from transformers.models.xlm_roberta.configuration_xlm_roberta import XLMRoberta
 from transformers import XLMRobertaModel,XLMRobertaTokenizer
 from typing import Optional
+from modules.torch_utils import get_param
 class BertSeriesConfig(BertConfig):
     def __init__(self, vocab_size=30522, hidden_size=768, num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072, hidden_act="gelu", hidden_dropout_prob=0.1, attention_probs_dropout_prob=0.1, max_position_embeddings=512, type_vocab_size=2, initializer_range=0.02, layer_norm_eps=1e-12, pad_token_id=0, position_embedding_type="absolute", use_cache=True, classifier_dropout=None,project_dim=512, pooler_fn="average",learn_encoder=False,model_type='bert',**kwargs):
@@ -68,7 +71,7 @@ class BertSeriesModelWithTransformation(BertPreTrainedModel):
         self.post_init()
     def encode(self,c):
-        device = next(self.parameters()).device
+        device = get_param(self).device
         text = self.tokenizer(c,
             truncation=True,
             max_length=77,
...
import types

import pytest
import torch

from modules.torch_utils import get_param


@pytest.mark.parametrize("wrapped", [True, False])
def test_get_param(wrapped):
    mod = torch.nn.Linear(1, 1)
    cpu = torch.device("cpu")
    mod.to(dtype=torch.float16, device=cpu)
    if wrapped:
        # more or less how spandrel wraps a thing
        mod = types.SimpleNamespace(model=mod)
    p = get_param(mod)
    assert p.dtype == torch.float16
    assert p.device == cpu