Commit a70dfb64 authored by AUTOMATIC1111's avatar AUTOMATIC1111

change import statements for #14478

parent be5f1acc
...@@ -4,7 +4,7 @@ from functools import lru_cache ...@@ -4,7 +4,7 @@ from functools import lru_cache
import torch import torch
from modules import errors, shared from modules import errors, shared
from modules.torch_utils import get_param from modules import torch_utils
if sys.platform == "darwin": if sys.platform == "darwin":
from modules import mac_specific from modules import mac_specific
...@@ -132,7 +132,7 @@ patch_module_list = [ ...@@ -132,7 +132,7 @@ patch_module_list = [
def manual_cast_forward(self, *args, **kwargs): def manual_cast_forward(self, *args, **kwargs):
org_dtype = get_param(self).dtype org_dtype = torch_utils.get_param(self).dtype
self.to(dtype) self.to(dtype)
args = [arg.to(dtype) if isinstance(arg, torch.Tensor) else arg for arg in args] args = [arg.to(dtype) if isinstance(arg, torch.Tensor) else arg for arg in args]
kwargs = {k: v.to(dtype) if isinstance(v, torch.Tensor) else v for k, v in kwargs.items()} kwargs = {k: v.to(dtype) if isinstance(v, torch.Tensor) else v for k, v in kwargs.items()}
......
...@@ -10,8 +10,7 @@ import torch.hub ...@@ -10,8 +10,7 @@ import torch.hub
from torchvision import transforms from torchvision import transforms
from torchvision.transforms.functional import InterpolationMode from torchvision.transforms.functional import InterpolationMode
from modules import devices, paths, shared, lowvram, modelloader, errors from modules import devices, paths, shared, lowvram, modelloader, errors, torch_utils
from modules.torch_utils import get_param
blip_image_eval_size = 384 blip_image_eval_size = 384
clip_model_name = 'ViT-L/14' clip_model_name = 'ViT-L/14'
...@@ -132,7 +131,7 @@ class InterrogateModels: ...@@ -132,7 +131,7 @@ class InterrogateModels:
self.clip_model = self.clip_model.to(devices.device_interrogate) self.clip_model = self.clip_model.to(devices.device_interrogate)
self.dtype = get_param(self.clip_model).dtype self.dtype = torch_utils.get_param(self.clip_model).dtype
def send_clip_to_ram(self): def send_clip_to_ram(self):
if not shared.opts.interrogate_keep_models_in_memory: if not shared.opts.interrogate_keep_models_in_memory:
......
...@@ -6,7 +6,7 @@ import sgm.models.diffusion ...@@ -6,7 +6,7 @@ import sgm.models.diffusion
import sgm.modules.diffusionmodules.denoiser_scaling import sgm.modules.diffusionmodules.denoiser_scaling
import sgm.modules.diffusionmodules.discretizer import sgm.modules.diffusionmodules.discretizer
from modules import devices, shared, prompt_parser from modules import devices, shared, prompt_parser
from modules.torch_utils import get_param from modules import torch_utils
def get_learned_conditioning(self: sgm.models.diffusion.DiffusionEngine, batch: prompt_parser.SdConditioning | list[str]): def get_learned_conditioning(self: sgm.models.diffusion.DiffusionEngine, batch: prompt_parser.SdConditioning | list[str]):
...@@ -91,7 +91,7 @@ sgm.modules.GeneralConditioner.get_target_prompt_token_count = get_target_prompt ...@@ -91,7 +91,7 @@ sgm.modules.GeneralConditioner.get_target_prompt_token_count = get_target_prompt
def extend_sdxl(model): def extend_sdxl(model):
"""this adds a bunch of parameters to make SDXL model look a bit more like SD1.5 to the rest of the codebase.""" """this adds a bunch of parameters to make SDXL model look a bit more like SD1.5 to the rest of the codebase."""
dtype = get_param(model.model.diffusion_model).dtype dtype = torch_utils.get_param(model.model.diffusion_model).dtype
model.model.diffusion_model.dtype = dtype model.model.diffusion_model.dtype = dtype
model.model.conditioning_key = 'crossattn' model.model.conditioning_key = 'crossattn'
model.cond_stage_key = 'txt' model.cond_stage_key = 'txt'
......
...@@ -6,8 +6,7 @@ import torch ...@@ -6,8 +6,7 @@ import torch
import tqdm import tqdm
from PIL import Image from PIL import Image
from modules import images, shared from modules import images, shared, torch_utils
from modules.torch_utils import get_param
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
...@@ -18,7 +17,7 @@ def upscale_without_tiling(model, img: Image.Image): ...@@ -18,7 +17,7 @@ def upscale_without_tiling(model, img: Image.Image):
img = np.ascontiguousarray(np.transpose(img, (2, 0, 1))) / 255 img = np.ascontiguousarray(np.transpose(img, (2, 0, 1))) / 255
img = torch.from_numpy(img).float() img = torch.from_numpy(img).float()
param = get_param(model) param = torch_utils.get_param(model)
img = img.unsqueeze(0).to(device=param.device, dtype=param.dtype) img = img.unsqueeze(0).to(device=param.device, dtype=param.dtype)
with torch.no_grad(): with torch.no_grad():
......
...@@ -5,7 +5,7 @@ from transformers.models.xlm_roberta.configuration_xlm_roberta import XLMRoberta ...@@ -5,7 +5,7 @@ from transformers.models.xlm_roberta.configuration_xlm_roberta import XLMRoberta
from transformers import XLMRobertaModel,XLMRobertaTokenizer from transformers import XLMRobertaModel,XLMRobertaTokenizer
from typing import Optional from typing import Optional
from modules.torch_utils import get_param from modules import torch_utils
class BertSeriesConfig(BertConfig): class BertSeriesConfig(BertConfig):
...@@ -65,7 +65,7 @@ class BertSeriesModelWithTransformation(BertPreTrainedModel): ...@@ -65,7 +65,7 @@ class BertSeriesModelWithTransformation(BertPreTrainedModel):
self.post_init() self.post_init()
def encode(self,c): def encode(self,c):
device = get_param(self).device device = torch_utils.get_param(self).device
text = self.tokenizer(c, text = self.tokenizer(c,
truncation=True, truncation=True,
max_length=77, max_length=77,
......
...@@ -4,8 +4,7 @@ import torch ...@@ -4,8 +4,7 @@ import torch
from transformers.models.xlm_roberta.configuration_xlm_roberta import XLMRobertaConfig from transformers.models.xlm_roberta.configuration_xlm_roberta import XLMRobertaConfig
from transformers import XLMRobertaModel,XLMRobertaTokenizer from transformers import XLMRobertaModel,XLMRobertaTokenizer
from typing import Optional from typing import Optional
from modules import torch_utils
from modules.torch_utils import get_param
class BertSeriesConfig(BertConfig): class BertSeriesConfig(BertConfig):
...@@ -71,7 +70,7 @@ class BertSeriesModelWithTransformation(BertPreTrainedModel): ...@@ -71,7 +70,7 @@ class BertSeriesModelWithTransformation(BertPreTrainedModel):
self.post_init() self.post_init()
def encode(self,c): def encode(self,c):
device = get_param(self).device device = torch_utils.get_param(self).device
text = self.tokenizer(c, text = self.tokenizer(c,
truncation=True, truncation=True,
max_length=77, max_length=77,
......
...@@ -3,7 +3,7 @@ import types ...@@ -3,7 +3,7 @@ import types
import pytest import pytest
import torch import torch
from modules.torch_utils import get_param from modules import torch_utils
@pytest.mark.parametrize("wrapped", [True, False]) @pytest.mark.parametrize("wrapped", [True, False])
...@@ -14,6 +14,6 @@ def test_get_param(wrapped): ...@@ -14,6 +14,6 @@ def test_get_param(wrapped):
if wrapped: if wrapped:
# more or less how spandrel wraps a thing # more or less how spandrel wraps a thing
mod = types.SimpleNamespace(model=mod) mod = types.SimpleNamespace(model=mod)
p = get_param(mod) p = torch_utils.get_param(mod)
assert p.dtype == torch.float16 assert p.dtype == torch.float16
assert p.device == cpu assert p.device == cpu
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment