Commit 03650022 authored by Sj-Si

Merge changes from dev

parents 0726a6e1 cb5b335a
@@ -20,6 +20,12 @@ jobs:
         cache-dependency-path: |
           **/requirements*txt
           launch.py
+      - name: Cache models
+        id: cache-models
+        uses: actions/cache@v3
+        with:
+          path: models
+          key: "2023-12-30"
       - name: Install test dependencies
         run: pip install wait-for-it -r requirements-test.txt
         env:

@@ -33,6 +39,8 @@ jobs:
           TORCH_INDEX_URL: https://download.pytorch.org/whl/cpu
           WEBUI_LAUNCH_LIVE_OUTPUT: "1"
           PYTHONUNBUFFERED: "1"
+      - name: Print installed packages
+        run: pip freeze
       - name: Start test server
         run: >
           python -m coverage run

@@ -49,7 +57,7 @@ jobs:
           2>&1 | tee output.txt &
       - name: Run tests
         run: |
-          wait-for-it --service 127.0.0.1:7860 -t 600
+          wait-for-it --service 127.0.0.1:7860 -t 20
           python -m pytest -vv --junitxml=test/results.xml --cov . --cov-report=xml --verify-base-url test
       - name: Kill test server
         if: always()
...
@@ -37,3 +37,4 @@ notification.mp3
 /node_modules
 /package-lock.json
 /.coverage*
+/test/test_outputs
 # Stable Diffusion web UI
-A browser interface based on Gradio library for Stable Diffusion.
+A web interface for Stable Diffusion, implemented using Gradio library.
 ![](screenshot.png)

@@ -151,11 +151,12 @@ Licenses for borrowed code can be found in `Settings -> Licenses` screen, and al
 - Stable Diffusion - https://github.com/Stability-AI/stablediffusion, https://github.com/CompVis/taming-transformers
 - k-diffusion - https://github.com/crowsonkb/k-diffusion.git
-- GFPGAN - https://github.com/TencentARC/GFPGAN.git
-- CodeFormer - https://github.com/sczhou/CodeFormer
-- ESRGAN - https://github.com/xinntao/ESRGAN
-- SwinIR - https://github.com/JingyunLiang/SwinIR
-- Swin2SR - https://github.com/mv-lab/swin2sr
+- Spandrel - https://github.com/chaiNNer-org/spandrel implementing
+  - GFPGAN - https://github.com/TencentARC/GFPGAN.git
+  - CodeFormer - https://github.com/sczhou/CodeFormer
+  - ESRGAN - https://github.com/xinntao/ESRGAN
+  - SwinIR - https://github.com/JingyunLiang/SwinIR
+  - Swin2SR - https://github.com/mv-lab/swin2sr
 - LDSR - https://github.com/Hafiidz/latent-diffusion
 - MiDaS - https://github.com/isl-org/MiDaS
 - Ideas for optimizations - https://github.com/basujindal/stable-diffusion
...
model:
target: sgm.models.diffusion.DiffusionEngine
params:
scale_factor: 0.13025
disable_first_stage_autocast: True
denoiser_config:
target: sgm.modules.diffusionmodules.denoiser.DiscreteDenoiser
params:
num_idx: 1000
weighting_config:
target: sgm.modules.diffusionmodules.denoiser_weighting.EpsWeighting
scaling_config:
target: sgm.modules.diffusionmodules.denoiser_scaling.EpsScaling
discretization_config:
target: sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization
network_config:
target: sgm.modules.diffusionmodules.openaimodel.UNetModel
params:
adm_in_channels: 2816
num_classes: sequential
use_checkpoint: True
in_channels: 9
out_channels: 4
model_channels: 320
attention_resolutions: [4, 2]
num_res_blocks: 2
channel_mult: [1, 2, 4]
num_head_channels: 64
use_spatial_transformer: True
use_linear_in_transformer: True
transformer_depth: [1, 2, 10] # note: the first is unused (due to attn_res starting at 2) 32, 16, 8 --> 64, 32, 16
context_dim: 2048
spatial_transformer_attn_type: softmax-xformers
legacy: False
conditioner_config:
target: sgm.modules.GeneralConditioner
params:
emb_models:
# crossattn cond
- is_trainable: False
input_key: txt
target: sgm.modules.encoders.modules.FrozenCLIPEmbedder
params:
layer: hidden
layer_idx: 11
# crossattn and vector cond
- is_trainable: False
input_key: txt
target: sgm.modules.encoders.modules.FrozenOpenCLIPEmbedder2
params:
arch: ViT-bigG-14
version: laion2b_s39b_b160k
freeze: True
layer: penultimate
always_return_pooled: True
legacy: False
# vector cond
- is_trainable: False
input_key: original_size_as_tuple
target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
params:
outdim: 256 # multiplied by two
# vector cond
- is_trainable: False
input_key: crop_coords_top_left
target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
params:
outdim: 256 # multiplied by two
# vector cond
- is_trainable: False
input_key: target_size_as_tuple
target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
params:
outdim: 256 # multiplied by two
first_stage_config:
target: sgm.models.autoencoder.AutoencoderKLInferenceWrapper
params:
embed_dim: 4
monitor: val/rec_loss
ddconfig:
attn_type: vanilla-xformers
double_z: true
z_channels: 4
resolution: 256
in_channels: 3
out_ch: 3
ch: 128
ch_mult: [1, 2, 4, 4]
num_res_blocks: 2
attn_resolutions: []
dropout: 0.0
lossconfig:
target: torch.nn.Identity
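Note on the new config above: `in_channels: 9` is what marks this UNet as an inpainting model. A hedged sketch of how the nine input channels are conventionally assembled for Stable Diffusion inpainting (names here are illustrative, not from the config; this also explains the pad-to-9 logic in the networks.py hunk further below):

    # 4 noisy-latent channels + 4 channels of the VAE-encoded masked image
    # + 1 downscaled mask channel = 9, matching in_channels above.
    import torch

    noisy_latent = torch.randn(1, 4, 128, 128)   # z_t
    masked_latent = torch.randn(1, 4, 128, 128)  # VAE(masked image)
    mask = torch.ones(1, 1, 128, 128)            # resized inpaint mask

    unet_input = torch.cat([noisy_latent, masked_latent, mask], dim=1)
    assert unet_input.shape[1] == 9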
@@ -3,6 +3,9 @@ import os
 from collections import namedtuple
 import enum
+import torch.nn as nn
+import torch.nn.functional as F
 from modules import sd_models, cache, errors, hashes, shared
 NetworkWeights = namedtuple('NetworkWeights', ['network_key', 'sd_key', 'w', 'sd_module'])

@@ -115,6 +118,29 @@ class NetworkModule:
         if hasattr(self.sd_module, 'weight'):
             self.shape = self.sd_module.weight.shape
+        self.ops = None
+        self.extra_kwargs = {}
+        if isinstance(self.sd_module, nn.Conv2d):
+            self.ops = F.conv2d
+            self.extra_kwargs = {
+                'stride': self.sd_module.stride,
+                'padding': self.sd_module.padding
+            }
+        elif isinstance(self.sd_module, nn.Linear):
+            self.ops = F.linear
+        elif isinstance(self.sd_module, nn.LayerNorm):
+            self.ops = F.layer_norm
+            self.extra_kwargs = {
+                'normalized_shape': self.sd_module.normalized_shape,
+                'eps': self.sd_module.eps
+            }
+        elif isinstance(self.sd_module, nn.GroupNorm):
+            self.ops = F.group_norm
+            self.extra_kwargs = {
+                'num_groups': self.sd_module.num_groups,
+                'eps': self.sd_module.eps
+            }
         self.dim = None
         self.bias = weights.w.get("bias")
         self.alpha = weights.w["alpha"].item() if "alpha" in weights.w else None

@@ -137,7 +163,7 @@ class NetworkModule:
     def finalize_updown(self, updown, orig_weight, output_shape, ex_bias=None):
         if self.bias is not None:
             updown = updown.reshape(self.bias.shape)
-            updown += self.bias.to(orig_weight.device, dtype=orig_weight.dtype)
+            updown += self.bias.to(orig_weight.device, dtype=updown.dtype)
             updown = updown.reshape(output_shape)
         if len(output_shape) == 4:

@@ -155,5 +181,10 @@ class NetworkModule:
         raise NotImplementedError()
     def forward(self, x, y):
-        raise NotImplementedError()
+        """A general forward implementation for all modules"""
+        if self.ops is None:
+            raise NotImplementedError()
+        else:
+            updown, ex_bias = self.calc_updown(self.sd_module.weight)
+            return y + self.ops(x, weight=updown, bias=ex_bias, **self.extra_kwargs)
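The new general `forward` gives every network type a default dynamic path: compute the delta weights and run them as a second functional op whose result is added onto the layer's own output. A minimal hedged sketch of the same idea for a Linear layer (all names illustrative):

    import torch
    import torch.nn.functional as F

    x = torch.randn(2, 16)
    layer = torch.nn.Linear(16, 8)
    y = layer(x)                         # original layer output

    updown = torch.randn(8, 16) * 0.01   # stands in for calc_updown(...)
    y_patched = y + F.linear(x, weight=updown, bias=None)
    assert y_patched.shape == y.shape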
@@ -18,9 +18,9 @@ class NetworkModuleFull(network.NetworkModule):
     def calc_updown(self, orig_weight):
         output_shape = self.weight.shape
-        updown = self.weight.to(orig_weight.device, dtype=orig_weight.dtype)
+        updown = self.weight.to(orig_weight.device)
         if self.ex_bias is not None:
-            ex_bias = self.ex_bias.to(orig_weight.device, dtype=orig_weight.dtype)
+            ex_bias = self.ex_bias.to(orig_weight.device)
         else:
             ex_bias = None
...
@@ -22,12 +22,12 @@ class NetworkModuleGLora(network.NetworkModule):
         self.w2b = weights.w["b2.weight"]
     def calc_updown(self, orig_weight):
-        w1a = self.w1a.to(orig_weight.device, dtype=orig_weight.dtype)
-        w1b = self.w1b.to(orig_weight.device, dtype=orig_weight.dtype)
-        w2a = self.w2a.to(orig_weight.device, dtype=orig_weight.dtype)
-        w2b = self.w2b.to(orig_weight.device, dtype=orig_weight.dtype)
+        w1a = self.w1a.to(orig_weight.device)
+        w1b = self.w1b.to(orig_weight.device)
+        w2a = self.w2a.to(orig_weight.device)
+        w2b = self.w2b.to(orig_weight.device)
         output_shape = [w1a.size(0), w1b.size(1)]
-        updown = ((w2b @ w1b) + ((orig_weight @ w2a) @ w1a))
+        updown = ((w2b @ w1b) + ((orig_weight.to(dtype=w1a.dtype) @ w2a) @ w1a))
         return self.finalize_updown(updown, orig_weight, output_shape)
@@ -27,16 +27,16 @@ class NetworkModuleHada(network.NetworkModule):
         self.t2 = weights.w.get("hada_t2")
     def calc_updown(self, orig_weight):
-        w1a = self.w1a.to(orig_weight.device, dtype=orig_weight.dtype)
-        w1b = self.w1b.to(orig_weight.device, dtype=orig_weight.dtype)
-        w2a = self.w2a.to(orig_weight.device, dtype=orig_weight.dtype)
-        w2b = self.w2b.to(orig_weight.device, dtype=orig_weight.dtype)
+        w1a = self.w1a.to(orig_weight.device)
+        w1b = self.w1b.to(orig_weight.device)
+        w2a = self.w2a.to(orig_weight.device)
+        w2b = self.w2b.to(orig_weight.device)
         output_shape = [w1a.size(0), w1b.size(1)]
         if self.t1 is not None:
             output_shape = [w1a.size(1), w1b.size(1)]
-            t1 = self.t1.to(orig_weight.device, dtype=orig_weight.dtype)
+            t1 = self.t1.to(orig_weight.device)
             updown1 = lyco_helpers.make_weight_cp(t1, w1a, w1b)
             output_shape += t1.shape[2:]
         else:

@@ -45,7 +45,7 @@ class NetworkModuleHada(network.NetworkModule):
             updown1 = lyco_helpers.rebuild_conventional(w1a, w1b, output_shape)
         if self.t2 is not None:
-            t2 = self.t2.to(orig_weight.device, dtype=orig_weight.dtype)
+            t2 = self.t2.to(orig_weight.device)
             updown2 = lyco_helpers.make_weight_cp(t2, w2a, w2b)
         else:
             updown2 = lyco_helpers.rebuild_conventional(w2a, w2b, output_shape)
...
@@ -17,7 +17,7 @@ class NetworkModuleIa3(network.NetworkModule):
         self.on_input = weights.w["on_input"].item()
     def calc_updown(self, orig_weight):
-        w = self.w.to(orig_weight.device, dtype=orig_weight.dtype)
+        w = self.w.to(orig_weight.device)
         output_shape = [w.size(0), orig_weight.size(1)]
         if self.on_input:
...
@@ -37,22 +37,22 @@ class NetworkModuleLokr(network.NetworkModule):
     def calc_updown(self, orig_weight):
         if self.w1 is not None:
-            w1 = self.w1.to(orig_weight.device, dtype=orig_weight.dtype)
+            w1 = self.w1.to(orig_weight.device)
         else:
-            w1a = self.w1a.to(orig_weight.device, dtype=orig_weight.dtype)
-            w1b = self.w1b.to(orig_weight.device, dtype=orig_weight.dtype)
+            w1a = self.w1a.to(orig_weight.device)
+            w1b = self.w1b.to(orig_weight.device)
             w1 = w1a @ w1b
         if self.w2 is not None:
-            w2 = self.w2.to(orig_weight.device, dtype=orig_weight.dtype)
+            w2 = self.w2.to(orig_weight.device)
         elif self.t2 is None:
-            w2a = self.w2a.to(orig_weight.device, dtype=orig_weight.dtype)
-            w2b = self.w2b.to(orig_weight.device, dtype=orig_weight.dtype)
+            w2a = self.w2a.to(orig_weight.device)
+            w2b = self.w2b.to(orig_weight.device)
             w2 = w2a @ w2b
         else:
-            t2 = self.t2.to(orig_weight.device, dtype=orig_weight.dtype)
-            w2a = self.w2a.to(orig_weight.device, dtype=orig_weight.dtype)
-            w2b = self.w2b.to(orig_weight.device, dtype=orig_weight.dtype)
+            t2 = self.t2.to(orig_weight.device)
+            w2a = self.w2a.to(orig_weight.device)
+            w2b = self.w2b.to(orig_weight.device)
             w2 = lyco_helpers.make_weight_cp(t2, w2a, w2b)
         output_shape = [w1.size(0) * w2.size(0), w1.size(1) * w2.size(1)]
...
@@ -61,13 +61,13 @@ class NetworkModuleLora(network.NetworkModule):
         return module
     def calc_updown(self, orig_weight):
-        up = self.up_model.weight.to(orig_weight.device, dtype=orig_weight.dtype)
-        down = self.down_model.weight.to(orig_weight.device, dtype=orig_weight.dtype)
+        up = self.up_model.weight.to(orig_weight.device)
+        down = self.down_model.weight.to(orig_weight.device)
         output_shape = [up.size(0), down.size(1)]
         if self.mid_model is not None:
             # cp-decomposition
-            mid = self.mid_model.weight.to(orig_weight.device, dtype=orig_weight.dtype)
+            mid = self.mid_model.weight.to(orig_weight.device)
             updown = lyco_helpers.rebuild_cp_decomposition(up, down, mid)
             output_shape += mid.shape[2:]
         else:
...
@@ -18,10 +18,10 @@ class NetworkModuleNorm(network.NetworkModule):
     def calc_updown(self, orig_weight):
         output_shape = self.w_norm.shape
-        updown = self.w_norm.to(orig_weight.device, dtype=orig_weight.dtype)
+        updown = self.w_norm.to(orig_weight.device)
         if self.b_norm is not None:
-            ex_bias = self.b_norm.to(orig_weight.device, dtype=orig_weight.dtype)
+            ex_bias = self.b_norm.to(orig_weight.device)
         else:
             ex_bias = None
...
@@ -56,7 +56,7 @@ class NetworkModuleOFT(network.NetworkModule):
         self.block_size, self.num_blocks = factorization(self.out_dim, self.dim)
     def calc_updown(self, orig_weight):
-        oft_blocks = self.oft_blocks.to(orig_weight.device, dtype=orig_weight.dtype)
+        oft_blocks = self.oft_blocks.to(orig_weight.device)
         eye = torch.eye(self.block_size, device=self.oft_blocks.device)
         if self.is_kohya:

@@ -66,7 +66,7 @@ class NetworkModuleOFT(network.NetworkModule):
             block_Q = block_Q * ((new_norm_Q + 1e-8) / (norm_Q + 1e-8))
             oft_blocks = torch.matmul(eye + block_Q, (eye - block_Q).float().inverse())
-        R = oft_blocks.to(orig_weight.device, dtype=orig_weight.dtype)
+        R = oft_blocks.to(orig_weight.device)
         # This errors out for MultiheadAttention, might need to be handled up-stream
         merged_weight = rearrange(orig_weight, '(k n) ... -> k n ...', k=self.num_blocks, n=self.block_size)

@@ -77,6 +77,6 @@ class NetworkModuleOFT(network.NetworkModule):
         )
         merged_weight = rearrange(merged_weight, 'k m ... -> (k m) ...')
-        updown = merged_weight.to(orig_weight.device, dtype=orig_weight.dtype) - orig_weight
+        updown = merged_weight.to(orig_weight.device) - orig_weight.to(merged_weight.dtype)
         output_shape = orig_weight.shape
         return self.finalize_updown(updown, orig_weight, output_shape)
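The `calc_updown` hunks above all make the same change: tensors are moved to the target device without forcing `orig_weight`'s dtype, so the delta is computed in the network's own (often half-precision) dtype and only cast at apply time. A small hedged illustration of the difference:

    import torch

    orig_weight = torch.randn(8, 8, dtype=torch.float32)
    lora_weight = torch.randn(8, 4, dtype=torch.float16)

    moved = lora_weight.to(orig_weight.device)                            # keeps float16
    forced = lora_weight.to(orig_weight.device, dtype=orig_weight.dtype)  # old behavior: float32
    assert moved.dtype == torch.float16 and forced.dtype == torch.float32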
+import gradio as gr
 import logging
 import os
 import re

@@ -314,7 +315,12 @@ def load_networks(names, te_multipliers=None, unet_multipliers=None, dyn_dims=No
             emb_db.skipped_embeddings[name] = embedding
     if failed_to_load_networks:
-        sd_hijack.model_hijack.comments.append("Networks not found: " + ", ".join(failed_to_load_networks))
+        lora_not_found_message = f'Lora not found: {", ".join(failed_to_load_networks)}'
+        sd_hijack.model_hijack.comments.append(lora_not_found_message)
+        if shared.opts.lora_not_found_warning_console:
+            print(f'\n{lora_not_found_message}\n')
+        if shared.opts.lora_not_found_gradio_warning:
+            gr.Warning(lora_not_found_message)
     purge_networks_from_memory()
@@ -389,18 +395,26 @@ def network_apply_weights(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn
         if module is not None and hasattr(self, 'weight'):
             try:
                 with torch.no_grad():
-                    updown, ex_bias = module.calc_updown(self.weight)
+                    if getattr(self, 'fp16_weight', None) is None:
+                        weight = self.weight
+                        bias = self.bias
+                    else:
+                        weight = self.fp16_weight.clone().to(self.weight.device)
+                        bias = getattr(self, 'fp16_bias', None)
+                        if bias is not None:
+                            bias = bias.clone().to(self.bias.device)
+                    updown, ex_bias = module.calc_updown(weight)
-                    if len(self.weight.shape) == 4 and self.weight.shape[1] == 9:
+                    if len(weight.shape) == 4 and weight.shape[1] == 9:
                         # inpainting model. zero pad updown to make channel[1] 4 to 9
                         updown = torch.nn.functional.pad(updown, (0, 0, 0, 0, 0, 5))
-                    self.weight += updown
+                    self.weight.copy_((weight.to(dtype=updown.dtype) + updown).to(dtype=self.weight.dtype))
                     if ex_bias is not None and hasattr(self, 'bias'):
                         if self.bias is None:
-                            self.bias = torch.nn.Parameter(ex_bias)
+                            self.bias = torch.nn.Parameter(ex_bias).to(self.weight.dtype)
                         else:
-                            self.bias += ex_bias
+                            self.bias.copy_((bias + ex_bias).to(dtype=self.bias.dtype))
             except RuntimeError as e:
                 logging.debug(f"Network {net.name} layer {network_layer_name}: {e}")
                 extra_network_lora.errors[net.name] = extra_network_lora.errors.get(net.name, 0) + 1
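The hunk above applies the delta to an optional higher-precision copy of the weights (`fp16_weight`, when present) and then copies the result back into the layer's storage dtype in place, rather than accumulating directly into possibly lower-precision storage. A hedged sketch of that apply-then-cast-back pattern (`stored` and `master` are illustrative names):

    import torch

    stored = torch.nn.Parameter(torch.randn(8, 8).half())  # layer's storage dtype
    master = stored.detach().float()                       # higher-precision copy
    updown = torch.randn(8, 8) * 1e-3                      # LoRA delta

    with torch.no_grad():
        stored.copy_((master + updown).to(dtype=stored.dtype))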
@@ -444,23 +458,23 @@ def network_apply_weights(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn
     self.network_current_names = wanted_names

-def network_forward(module, input, original_forward):
+def network_forward(org_module, input, original_forward):
     """
     Old way of applying Lora by executing operations during layer's forward.
     Stacking many loras this way results in big performance degradation.
     """
     if len(loaded_networks) == 0:
-        return original_forward(module, input)
+        return original_forward(org_module, input)
     input = devices.cond_cast_unet(input)
-    network_restore_weights_from_backup(module)
-    network_reset_cached_weight(module)
-    y = original_forward(module, input)
-    network_layer_name = getattr(module, 'network_layer_name', None)
+    network_restore_weights_from_backup(org_module)
+    network_reset_cached_weight(org_module)
+    y = original_forward(org_module, input)
+    network_layer_name = getattr(org_module, 'network_layer_name', None)
     for lora in loaded_networks:
         module = lora.modules.get(network_layer_name, None)
         if module is None:
...
@@ -39,6 +39,8 @@ shared.options_templates.update(shared.options_section(('extra_networks', "Extra
     "lora_show_all": shared.OptionInfo(False, "Always show all networks on the Lora page").info("otherwise, those detected as for incompatible version of Stable Diffusion will be hidden"),
     "lora_hide_unknown_for_versions": shared.OptionInfo([], "Hide networks of unknown versions for model versions", gr.CheckboxGroup, {"choices": ["SD1", "SD2", "SDXL"]}),
     "lora_in_memory_limit": shared.OptionInfo(0, "Number of Lora networks to keep cached in memory", gr.Number, {"precision": 0}),
+    "lora_not_found_warning_console": shared.OptionInfo(False, "Lora not found warning in console"),
+    "lora_not_found_gradio_warning": shared.OptionInfo(False, "Lora not found warning popup in webui"),
 }))
...
@@ -54,12 +54,13 @@ class LoraUserMetadataEditor(ui_extra_networks_user_metadata.UserMetadataEditor)
         self.slider_preferred_weight = None
         self.edit_notes = None
-    def save_lora_user_metadata(self, name, desc, sd_version, activation_text, preferred_weight, notes):
+    def save_lora_user_metadata(self, name, desc, sd_version, activation_text, preferred_weight, negative_text, notes):
         user_metadata = self.get_user_metadata(name)
         user_metadata["description"] = desc
         user_metadata["sd version"] = sd_version
         user_metadata["activation text"] = activation_text
         user_metadata["preferred weight"] = preferred_weight
+        user_metadata["negative text"] = negative_text
         user_metadata["notes"] = notes
         self.write_user_metadata(name, user_metadata)

@@ -127,6 +128,7 @@ class LoraUserMetadataEditor(ui_extra_networks_user_metadata.UserMetadataEditor)
             gr.HighlightedText.update(value=gradio_tags, visible=True if tags else False),
             user_metadata.get('activation text', ''),
             float(user_metadata.get('preferred weight', 0.0)),
+            user_metadata.get('negative text', ''),
             gr.update(visible=True if tags else False),
             gr.update(value=self.generate_random_prompt_from_tags(tags), visible=True if tags else False),
         ]

@@ -162,7 +164,7 @@ class LoraUserMetadataEditor(ui_extra_networks_user_metadata.UserMetadataEditor)
         self.taginfo = gr.HighlightedText(label="Training dataset tags")
         self.edit_activation_text = gr.Text(label='Activation text', info="Will be added to prompt along with Lora")
         self.slider_preferred_weight = gr.Slider(label='Preferred weight', info="Set to 0 to disable", minimum=0.0, maximum=2.0, step=0.01)
+        self.edit_negative_text = gr.Text(label='Negative prompt', info="Will be added to negative prompts")
         with gr.Row() as row_random_prompt:
             with gr.Column(scale=8):
                 random_prompt = gr.Textbox(label='Random prompt', lines=4, max_lines=4, interactive=False)

@@ -198,6 +200,7 @@ class LoraUserMetadataEditor(ui_extra_networks_user_metadata.UserMetadataEditor)
             self.taginfo,
             self.edit_activation_text,
             self.slider_preferred_weight,
+            self.edit_negative_text,
             row_random_prompt,
             random_prompt,
         ]

@@ -211,7 +214,9 @@ class LoraUserMetadataEditor(ui_extra_networks_user_metadata.UserMetadataEditor)
             self.select_sd_version,
             self.edit_activation_text,
             self.slider_preferred_weight,
+            self.edit_negative_text,
             self.edit_notes,
         ]
         self.setup_save_handler(self.button_save, self.save_lora_user_metadata, edited_components)
@@ -48,6 +48,11 @@ class ExtraNetworksPageLora(ui_extra_networks.ExtraNetworksPage):
         if activation_text:
             item["prompt"] += " + " + quote_js(" " + activation_text)
+        negative_prompt = item["user_metadata"].get("negative text")
+        item["negative_prompt"] = quote_js("")
+        if negative_prompt:
+            item["negative_prompt"] = quote_js('(' + negative_prompt + ':1)')
         sd_version = item["user_metadata"].get("sd version")
         if sd_version in network.SdVersion.__members__:
             item["sd_version"] = sd_version
...
 import sys
 import PIL.Image
-import numpy as np
-import torch
-from tqdm import tqdm
 import modules.upscaler
-from modules import devices, modelloader, script_callbacks, errors
+from modules import devices, errors, modelloader, script_callbacks, shared, upscaler_utils
-from scunet_model_arch import SCUNet
-from modules.modelloader import load_file_from_url
-from modules.shared import opts
 class UpscalerScuNET(modules.upscaler.Upscaler):

@@ -42,100 +35,37 @@ class UpscalerScuNET(modules.upscaler.Upscaler):
         scalers.append(scaler_data2)
         self.scalers = scalers
-    @staticmethod
-    @torch.no_grad()
-    def tiled_inference(img, model):
-        # test the image tile by tile
-        h, w = img.shape[2:]
-        tile = opts.SCUNET_tile
-        tile_overlap = opts.SCUNET_tile_overlap
-        if tile == 0:
-            return model(img)
-        device = devices.get_device_for('scunet')
-        assert tile % 8 == 0, "tile size should be a multiple of window_size"
-        sf = 1
-        stride = tile - tile_overlap
-        h_idx_list = list(range(0, h - tile, stride)) + [h - tile]
-        w_idx_list = list(range(0, w - tile, stride)) + [w - tile]
-        E = torch.zeros(1, 3, h * sf, w * sf, dtype=img.dtype, device=device)
-        W = torch.zeros_like(E, dtype=devices.dtype, device=device)
-        with tqdm(total=len(h_idx_list) * len(w_idx_list), desc="ScuNET tiles") as pbar:
-            for h_idx in h_idx_list:
-                for w_idx in w_idx_list:
-                    in_patch = img[..., h_idx: h_idx + tile, w_idx: w_idx + tile]
-                    out_patch = model(in_patch)
-                    out_patch_mask = torch.ones_like(out_patch)
-                    E[
-                        ..., h_idx * sf: (h_idx + tile) * sf, w_idx * sf: (w_idx + tile) * sf
-                    ].add_(out_patch)
-                    W[
-                        ..., h_idx * sf: (h_idx + tile) * sf, w_idx * sf: (w_idx + tile) * sf
-                    ].add_(out_patch_mask)
-                    pbar.update(1)
-        output = E.div_(W)
-        return output
     def do_upscale(self, img: PIL.Image.Image, selected_file):
         devices.torch_gc()
         try:
             model = self.load_model(selected_file)
         except Exception as e:
             print(f"ScuNET: Unable to load model from {selected_file}: {e}", file=sys.stderr)
             return img
-        device = devices.get_device_for('scunet')
-        tile = opts.SCUNET_tile
-        h, w = img.height, img.width
-        np_img = np.array(img)
-        np_img = np_img[:, :, ::-1]  # RGB to BGR
-        np_img = np_img.transpose((2, 0, 1)) / 255  # HWC to CHW
-        torch_img = torch.from_numpy(np_img).float().unsqueeze(0).to(device)  # type: ignore
+        img = upscaler_utils.upscale_2(
+            img,
+            model,
+            tile_size=shared.opts.SCUNET_tile,
+            tile_overlap=shared.opts.SCUNET_tile_overlap,
+            scale=1,  # ScuNET is a denoising model, not an upscaler
+            desc='ScuNET',
+        )
-        if tile > h or tile > w:
-            _img = torch.zeros(1, 3, max(h, tile), max(w, tile), dtype=torch_img.dtype, device=torch_img.device)
-            _img[:, :, :h, :w] = torch_img  # pad image
-            torch_img = _img
-        torch_output = self.tiled_inference(torch_img, model).squeeze(0)
-        torch_output = torch_output[:, :h * 1, :w * 1]  # remove padding, if any
-        np_output: np.ndarray = torch_output.float().cpu().clamp_(0, 1).numpy()
-        del torch_img, torch_output
         devices.torch_gc()
+        return img
-        output = np_output.transpose((1, 2, 0))  # CHW to HWC
-        output = output[:, :, ::-1]  # BGR to RGB
-        return PIL.Image.fromarray((output * 255).astype(np.uint8))
     def load_model(self, path: str):
         device = devices.get_device_for('scunet')
         if path.startswith("http"):
             # TODO: this doesn't use `path` at all?
-            filename = load_file_from_url(self.model_url, model_dir=self.model_download_path, file_name=f"{self.name}.pth")
+            filename = modelloader.load_file_from_url(self.model_url, model_dir=self.model_download_path, file_name=f"{self.name}.pth")
         else:
             filename = path
-        model = SCUNet(in_nc=3, config=[4, 4, 4, 4, 4, 4, 4], dim=64)
-        model.load_state_dict(torch.load(filename), strict=True)
-        model.eval()
-        for _, v in model.named_parameters():
-            v.requires_grad = False
-        model = model.to(device)
-        return model
+        return modelloader.load_spandrel_model(filename, device=device, expected_architecture='SCUNet')
 def on_ui_settings():
     import gradio as gr
-    from modules import shared
     shared.opts.add_option("SCUNET_tile", shared.OptionInfo(256, "Tile size for SCUNET upscalers.", gr.Slider, {"minimum": 0, "maximum": 512, "step": 16}, section=('upscaling', "Upscaling")).info("0 = no tiling"))
     shared.opts.add_option("SCUNET_tile_overlap", shared.OptionInfo(8, "Tile overlap for SCUNET upscalers.", gr.Slider, {"minimum": 0, "maximum": 64, "step": 1}, section=('upscaling', "Upscaling")).info("Low values = visible seam"))
...
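The removed `tiled_inference` (and presumably the shared `upscaler_utils.upscale_2` that replaces it) uses the standard overlap-tile scheme: run the model per tile, accumulate outputs and a coverage mask, then divide so overlapping regions are averaged rather than double-counted. A minimal hedged sketch for a scale-1 (denoising) model, assuming the image is at least one tile in each dimension:

    import torch

    def tiled_apply(img, model, tile=256, overlap=8):
        _, _, h, w = img.shape
        stride = tile - overlap
        E = torch.zeros_like(img)  # accumulated model outputs
        W = torch.zeros_like(img)  # per-pixel coverage count
        for hi in list(range(0, h - tile, stride)) + [h - tile]:
            for wi in list(range(0, w - tile, stride)) + [w - tile]:
                E[..., hi:hi + tile, wi:wi + tile] += model(img[..., hi:hi + tile, wi:wi + tile])
                W[..., hi:hi + tile, wi:wi + tile] += 1
        return E / W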
+import logging
 import sys
-import platform
-import numpy as np
 import torch
 from PIL import Image
-from tqdm import tqdm
-from modules import modelloader, devices, script_callbacks, shared
+from modules import devices, modelloader, script_callbacks, shared, upscaler_utils
-from modules.shared import opts, state
-from swinir_model_arch import SwinIR
-from swinir_model_arch_v2 import Swin2SR
 from modules.upscaler import Upscaler, UpscalerData
 SWINIR_MODEL_URL = "https://github.com/JingyunLiang/SwinIR/releases/download/v0.0/003_realSR_BSRGAN_DFOWMFC_s64w8_SwinIR-L_x4_GAN.pth"
-device_swinir = devices.get_device_for('swinir')
+logger = logging.getLogger(__name__)
 class UpscalerSwinIR(Upscaler):

@@ -37,26 +32,28 @@ class UpscalerSwinIR(Upscaler):
         scalers.append(model_data)
         self.scalers = scalers
-    def do_upscale(self, img, model_file):
-        use_compile = hasattr(opts, 'SWIN_torch_compile') and opts.SWIN_torch_compile \
-            and int(torch.__version__.split('.')[0]) >= 2 and platform.system() != "Windows"
-        current_config = (model_file, opts.SWIN_tile)
+    def do_upscale(self, img: Image.Image, model_file: str) -> Image.Image:
+        current_config = (model_file, shared.opts.SWIN_tile)
-        if use_compile and self._cached_model_config == current_config:
+        if self._cached_model_config == current_config:
             model = self._cached_model
         else:
+            self._cached_model = None
             try:
                 model = self.load_model(model_file)
             except Exception as e:
                 print(f"Failed loading SwinIR model {model_file}: {e}", file=sys.stderr)
                 return img
-            model = model.to(device_swinir, dtype=devices.dtype)
-            if use_compile:
-                model = torch.compile(model)
-                self._cached_model = model
-                self._cached_model_config = current_config
-        img = upscale(img, model)
+            self._cached_model = model
+            self._cached_model_config = current_config
+        img = upscaler_utils.upscale_2(
+            img,
+            model,
+            tile_size=shared.opts.SWIN_tile,
+            tile_overlap=shared.opts.SWIN_tile_overlap,
+            scale=model.scale,
+            desc="SwinIR",
+        )
         devices.torch_gc()
         return img

@@ -69,115 +66,22 @@ class UpscalerSwinIR(Upscaler):
             )
         else:
             filename = path
-        if filename.endswith(".v2.pth"):
-            model = Swin2SR(
-                upscale=scale,
-                in_chans=3,
-                img_size=64,
-                window_size=8,
-                img_range=1.0,
-                depths=[6, 6, 6, 6, 6, 6],
-                embed_dim=180,
-                num_heads=[6, 6, 6, 6, 6, 6],
-                mlp_ratio=2,
-                upsampler="nearest+conv",
-                resi_connection="1conv",
-            )
-            params = None
-        else:
-            model = SwinIR(
-                upscale=scale,
-                in_chans=3,
-                img_size=64,
-                window_size=8,
-                img_range=1.0,
-                depths=[6, 6, 6, 6, 6, 6, 6, 6, 6],
-                embed_dim=240,
-                num_heads=[8, 8, 8, 8, 8, 8, 8, 8, 8],
-                mlp_ratio=2,
-                upsampler="nearest+conv",
-                resi_connection="3conv",
-            )
-            params = "params_ema"
-        pretrained_model = torch.load(filename)
-        if params is not None:
-            model.load_state_dict(pretrained_model[params], strict=True)
-        else:
-            model.load_state_dict(pretrained_model, strict=True)
-        return model
+        model_descriptor = modelloader.load_spandrel_model(
+            filename,
+            device=self._get_device(),
+            prefer_half=(devices.dtype == torch.float16),
+            expected_architecture="SwinIR",
+        )
+        if getattr(shared.opts, 'SWIN_torch_compile', False):
+            try:
+                model_descriptor.model.compile()
+            except Exception:
+                logger.warning("Failed to compile SwinIR model, fallback to JIT", exc_info=True)
+        return model_descriptor
+    def _get_device(self):
+        return devices.get_device_for('swinir')
-def upscale(
-    img,
-    model,
-    tile=None,
-    tile_overlap=None,
-    window_size=8,
-    scale=4,
-):
-    tile = tile or opts.SWIN_tile
-    tile_overlap = tile_overlap or opts.SWIN_tile_overlap
-    img = np.array(img)
-    img = img[:, :, ::-1]
-    img = np.moveaxis(img, 2, 0) / 255
-    img = torch.from_numpy(img).float()
-    img = img.unsqueeze(0).to(device_swinir, dtype=devices.dtype)
-    with torch.no_grad(), devices.autocast():
-        _, _, h_old, w_old = img.size()
-        h_pad = (h_old // window_size + 1) * window_size - h_old
-        w_pad = (w_old // window_size + 1) * window_size - w_old
-        img = torch.cat([img, torch.flip(img, [2])], 2)[:, :, : h_old + h_pad, :]
-        img = torch.cat([img, torch.flip(img, [3])], 3)[:, :, :, : w_old + w_pad]
-        output = inference(img, model, tile, tile_overlap, window_size, scale)
-        output = output[..., : h_old * scale, : w_old * scale]
-        output = output.data.squeeze().float().cpu().clamp_(0, 1).numpy()
-        if output.ndim == 3:
-            output = np.transpose(
-                output[[2, 1, 0], :, :], (1, 2, 0)
-            )  # CHW-RGB to HCW-BGR
-        output = (output * 255.0).round().astype(np.uint8)  # float32 to uint8
-        return Image.fromarray(output, "RGB")
-def inference(img, model, tile, tile_overlap, window_size, scale):
-    # test the image tile by tile
-    b, c, h, w = img.size()
-    tile = min(tile, h, w)
-    assert tile % window_size == 0, "tile size should be a multiple of window_size"
-    sf = scale
-    stride = tile - tile_overlap
-    h_idx_list = list(range(0, h - tile, stride)) + [h - tile]
-    w_idx_list = list(range(0, w - tile, stride)) + [w - tile]
-    E = torch.zeros(b, c, h * sf, w * sf, dtype=devices.dtype, device=device_swinir).type_as(img)
-    W = torch.zeros_like(E, dtype=devices.dtype, device=device_swinir)
-    with tqdm(total=len(h_idx_list) * len(w_idx_list), desc="SwinIR tiles") as pbar:
-        for h_idx in h_idx_list:
-            if state.interrupted or state.skipped:
-                break
-            for w_idx in w_idx_list:
-                if state.interrupted or state.skipped:
-                    break
-                in_patch = img[..., h_idx: h_idx + tile, w_idx: w_idx + tile]
-                out_patch = model(in_patch)
-                out_patch_mask = torch.ones_like(out_patch)
-                E[
-                    ..., h_idx * sf: (h_idx + tile) * sf, w_idx * sf: (w_idx + tile) * sf
-                ].add_(out_patch)
-                W[
-                    ..., h_idx * sf: (h_idx + tile) * sf, w_idx * sf: (w_idx + tile) * sf
-                ].add_(out_patch_mask)
-                pbar.update(1)
-    output = E.div_(W)
-    return output
 def on_ui_settings():

@@ -185,8 +89,7 @@ def on_ui_settings():
     shared.opts.add_option("SWIN_tile", shared.OptionInfo(192, "Tile size for all SwinIR.", gr.Slider, {"minimum": 16, "maximum": 512, "step": 16}, section=('upscaling', "Upscaling")))
     shared.opts.add_option("SWIN_tile_overlap", shared.OptionInfo(8, "Tile overlap, in pixels for SwinIR. Low values = visible seam.", gr.Slider, {"minimum": 0, "maximum": 48, "step": 1}, section=('upscaling', "Upscaling")))
-    if int(torch.__version__.split('.')[0]) >= 2 and platform.system() != "Windows":  # torch.compile() require pytorch 2.0 or above, and not on Windows
-        shared.opts.add_option("SWIN_torch_compile", shared.OptionInfo(False, "Use torch.compile to accelerate SwinIR.", gr.Checkbox, {"interactive": True}, section=('upscaling', "Upscaling")).info("Takes longer on first run"))
+    shared.opts.add_option("SWIN_torch_compile", shared.OptionInfo(False, "Use torch.compile to accelerate SwinIR.", gr.Checkbox, {"interactive": True}, section=('upscaling', "Upscaling")).info("Takes longer on first run"))
 script_callbacks.on_ui_settings(on_ui_settings)
 import math
 import gradio as gr
-from modules import scripts, shared, ui_components, ui_settings, generation_parameters_copypaste
+from modules import scripts, shared, ui_components, ui_settings, infotext_utils
 from modules.ui_components import FormColumn

@@ -25,7 +25,7 @@ class ExtraOptionsSection(scripts.Script):
             extra_options = shared.opts.extra_options_img2img if is_img2img else shared.opts.extra_options_txt2img
             elem_id_tabname = "extra_options_" + ("img2img" if is_img2img else "txt2img")
-            mapping = {k: v for v, k in generation_parameters_copypaste.infotext_to_setting_name_mapping}
+            mapping = {k: v for v, k in infotext_utils.infotext_to_setting_name_mapping}
             with gr.Blocks() as interface:
                 with gr.Accordion("Options", open=False, elem_id=elem_id_tabname) if shared.opts.extra_options_accordion and extra_options else gr.Group(elem_id=elem_id_tabname):
...
@@ -183,8 +183,10 @@ onUiLoaded(setupExtraNetworks);
 var re_extranet = /<([^:^>]+:[^:]+):[\d.]+>(.*)/;
 var re_extranet_g = /<([^:^>]+:[^:]+):[\d.]+>/g;
+var re_extranet_neg = /\(([^:^>]+:[\d.]+)\)/;
+var re_extranet_g_neg = /\(([^:^>]+:[\d.]+)\)/g;
-function tryToRemoveExtraNetworkFromPrompt(textarea, text) {
-    var m = text.match(re_extranet);
+function tryToRemoveExtraNetworkFromPrompt(textarea, text, isNeg) {
+    var m = text.match(isNeg ? re_extranet_neg : re_extranet);
     var replaced = false;
     var newTextareaText;
     if (m) {

@@ -192,8 +194,8 @@ function tryToRemoveExtraNetworkFromPrompt(textarea, text) {
         var extraTextAfterNet = m[2];
         var partToSearch = m[1];
         var foundAtPosition = -1;
-        newTextareaText = textarea.value.replaceAll(re_extranet_g, function(found, net, pos) {
-            m = found.match(re_extranet);
+        newTextareaText = textarea.value.replaceAll(isNeg ? re_extranet_g_neg : re_extranet_g, function(found, net, pos) {
+            m = found.match(isNeg ? re_extranet_neg : re_extranet);
             if (m[1] == partToSearch) {
                 replaced = true;
                 foundAtPosition = pos;

@@ -203,7 +205,7 @@ function tryToRemoveExtraNetworkFromPrompt(textarea, text) {
         });
         if (foundAtPosition >= 0) {
-            if (newTextareaText.substr(foundAtPosition, extraTextAfterNet.length) == extraTextAfterNet) {
+            if (extraTextAfterNet && newTextareaText.substr(foundAtPosition, extraTextAfterNet.length) == extraTextAfterNet) {
                 newTextareaText = newTextareaText.substr(0, foundAtPosition) + newTextareaText.substr(foundAtPosition + extraTextAfterNet.length);
             }
             if (newTextareaText.substr(foundAtPosition - extraTextBeforeNet.length, extraTextBeforeNet.length) == extraTextBeforeNet) {

@@ -228,14 +230,23 @@ function tryToRemoveExtraNetworkFromPrompt(textarea, text) {
     return false;
 }
-function cardClicked(tabname, textToAdd, allowNegativePrompt) {
-    var textarea = allowNegativePrompt ? activePromptTextarea[tabname] : gradioApp().querySelector("#" + tabname + "_prompt > label > textarea");
-    if (!tryToRemoveExtraNetworkFromPrompt(textarea, textToAdd)) {
-        textarea.value = textarea.value + opts.extra_networks_add_text_separator + textToAdd;
+function updatePromptArea(text, textArea, isNeg) {
+    if (!tryToRemoveExtraNetworkFromPrompt(textArea, text, isNeg)) {
+        textArea.value = textArea.value + opts.extra_networks_add_text_separator + text;
     }
-    updateInput(textarea);
+    updateInput(textArea);
+}
+function cardClicked(tabname, textToAdd, textToAddNegative, allowNegativePrompt) {
+    if (textToAddNegative.length > 0) {
+        updatePromptArea(textToAdd, gradioApp().querySelector("#" + tabname + "_prompt > label > textarea"));
+        updatePromptArea(textToAddNegative, gradioApp().querySelector("#" + tabname + "_neg_prompt > label > textarea"), true);
+    } else {
+        var textarea = allowNegativePrompt ? activePromptTextarea[tabname] : gradioApp().querySelector("#" + tabname + "_prompt > label > textarea");
+        updatePromptArea(textToAdd, textarea);
+    }
 }
 function saveCardPreview(event, tabname, filename) {
...
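The two new regexes target negative-prompt entries of the form `(name:1)` rather than positive-prompt tags like `<lora:name:0.8>`. A quick hedged check of the same patterns, written in Python for brevity (example strings are illustrative):

    import re

    re_extranet = re.compile(r'<([^:^>]+:[^:]+):[\d.]+>(.*)')  # positive: <lora:myLora:0.8>
    re_extranet_neg = re.compile(r'\(([^:^>]+:[\d.]+)\)')      # negative: (myNegative:1)

    assert re_extranet.match('<lora:myLora:0.8> tail').group(1) == 'lora:myLora'
    assert re_extranet_neg.match('(myNegative:1)').group(1) == 'myNegative:1'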
@@ -150,6 +150,14 @@ function submit() {
     return res;
 }
+function submit_txt2img_upscale() {
+    var res = submit(...arguments);
+    res[2] = selected_gallery_index();
+    return res;
+}
 function submit_img2img() {
     showSubmitButtons('img2img', false);
...
@@ -107,6 +107,8 @@ StableDiffusionTxt2ImgProcessingAPI = PydanticModelGenerator(
     {"key": "send_images", "type": bool, "default": True},
     {"key": "save_images", "type": bool, "default": False},
     {"key": "alwayson_scripts", "type": dict, "default": {}},
+    {"key": "force_task_id", "type": str, "default": None},
+    {"key": "infotext", "type": str, "default": None},
 ]
 ).generate_model()

@@ -124,6 +126,8 @@ StableDiffusionImg2ImgProcessingAPI = PydanticModelGenerator(
     {"key": "send_images", "type": bool, "default": True},
     {"key": "save_images", "type": bool, "default": False},
     {"key": "alwayson_scripts", "type": dict, "default": {}},
+    {"key": "force_task_id", "type": str, "default": None},
+    {"key": "infotext", "type": str, "default": None},
 ]
 ).generate_model()
...
@@ -62,16 +62,15 @@ def cache(subsection):
     if cache_data is None:
         with cache_lock:
             if cache_data is None:
-                if not os.path.isfile(cache_filename):
-                    cache_data = {}
-                else:
-                    try:
-                        with open(cache_filename, "r", encoding="utf8") as file:
-                            cache_data = json.load(file)
-                    except Exception:
-                        os.replace(cache_filename, os.path.join(script_path, "tmp", "cache.json"))
-                        print('[ERROR] issue occurred while trying to read cache.json, move current cache to tmp/cache.json and create new cache')
-                        cache_data = {}
+                try:
+                    with open(cache_filename, "r", encoding="utf8") as file:
+                        cache_data = json.load(file)
+                except FileNotFoundError:
+                    cache_data = {}
+                except Exception:
+                    os.replace(cache_filename, os.path.join(script_path, "tmp", "cache.json"))
+                    print('[ERROR] issue occurred while trying to read cache.json, move current cache to tmp/cache.json and create new cache')
+                    cache_data = {}
     s = cache_data.get(subsection, {})
     cache_data[subsection] = s
...
@@ -78,6 +78,7 @@ def wrap_gradio_call(func, extra_outputs=None, add_stats=False):
         shared.state.skipped = False
         shared.state.interrupted = False
+        shared.state.stopping_generation = False
         shared.state.job_count = 0
         if not add_stats:
...
@@ -77,7 +77,9 @@ parser.add_argument("--port", type=int, help="launch gradio with given server po
 parser.add_argument("--show-negative-prompt", action='store_true', help="does not do anything", default=False)
 parser.add_argument("--ui-config-file", type=str, help="filename to use for ui configuration", default=os.path.join(data_path, 'ui-config.json'))
 parser.add_argument("--hide-ui-dir-config", action='store_true', help="hide directory configuration from webui", default=False)
-parser.add_argument("--freeze-settings", action='store_true', help="disable editing settings", default=False)
+parser.add_argument("--freeze-settings", action='store_true', help="disable editing of all settings globally", default=False)
+parser.add_argument("--freeze-settings-in-sections", type=str, help='disable editing settings in specific sections of the settings page by specifying a comma-delimited list like "saving-images,upscaling". The list of setting names can be found in the modules/shared_options.py file', default=None)
+parser.add_argument("--freeze-specific-settings", type=str, help='disable editing of individual settings by specifying a comma-delimited list like "samples_save,samples_format". The list of setting names can be found in the config.json file', default=None)
 parser.add_argument("--ui-settings-file", type=str, help="filename to use for ui settings", default=os.path.join(data_path, 'config.json'))
 parser.add_argument("--gradio-debug", action='store_true', help="launch gradio with --debug option")
 parser.add_argument("--gradio-auth", type=str, help='set gradio authentication like "username:password"; or comma-delimit multiple like "u1:p1,u2:p2,u3:p3"', default=None)
...
import os from __future__ import annotations
import cv2 import logging
import torch
import modules.face_restoration
import modules.shared
from modules import shared, devices, modelloader, errors
from modules.paths import models_path
--- removed: old basicsr/facelib-based implementation ---

import os

import cv2
import torch

import modules.face_restoration
import modules.shared
from modules import shared, devices, modelloader, errors
from modules.paths import models_path

# codeformer people made a choice to include modified basicsr library to their project which makes
# it utterly impossible to use it alongside with other libraries that also use basicsr, like GFPGAN.
# I am making a choice to include some files from codeformer to work around this issue.
model_dir = "Codeformer"
model_path = os.path.join(models_path, model_dir)
model_url = 'https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/codeformer.pth'

codeformer = None


def setup_model(dirname):
    os.makedirs(model_path, exist_ok=True)

    path = modules.paths.paths.get("CodeFormer", None)
    if path is None:
        return

    try:
        from torchvision.transforms.functional import normalize
        from modules.codeformer.codeformer_arch import CodeFormer
        from basicsr.utils import img2tensor, tensor2img
        from facelib.utils.face_restoration_helper import FaceRestoreHelper
        from facelib.detection.retinaface import retinaface

        net_class = CodeFormer

        class FaceRestorerCodeFormer(modules.face_restoration.FaceRestoration):
            def name(self):
                return "CodeFormer"

            def __init__(self, dirname):
                self.net = None
                self.face_helper = None
                self.cmd_dir = dirname

            def create_models(self):
                if self.net is not None and self.face_helper is not None:
                    self.net.to(devices.device_codeformer)
                    return self.net, self.face_helper

                model_paths = modelloader.load_models(model_path, model_url, self.cmd_dir, download_name='codeformer-v0.1.0.pth', ext_filter=['.pth'])

                if len(model_paths) != 0:
                    ckpt_path = model_paths[0]
                else:
                    print("Unable to load codeformer model.")
                    return None, None

                net = net_class(dim_embd=512, codebook_size=1024, n_head=8, n_layers=9, connect_list=['32', '64', '128', '256']).to(devices.device_codeformer)
                checkpoint = torch.load(ckpt_path)['params_ema']
                net.load_state_dict(checkpoint)
                net.eval()

                if hasattr(retinaface, 'device'):
                    retinaface.device = devices.device_codeformer
                face_helper = FaceRestoreHelper(1, face_size=512, crop_ratio=(1, 1), det_model='retinaface_resnet50', save_ext='png', use_parse=True, device=devices.device_codeformer)

                self.net = net
                self.face_helper = face_helper

                return net, face_helper

            def send_model_to(self, device):
                self.net.to(device)
                self.face_helper.face_det.to(device)
                self.face_helper.face_parse.to(device)

            def restore(self, np_image, w=None):
                np_image = np_image[:, :, ::-1]

                original_resolution = np_image.shape[0:2]

                self.create_models()
                if self.net is None or self.face_helper is None:
                    return np_image

                self.send_model_to(devices.device_codeformer)

                self.face_helper.clean_all()
                self.face_helper.read_image(np_image)
                self.face_helper.get_face_landmarks_5(only_center_face=False, resize=640, eye_dist_threshold=5)
                self.face_helper.align_warp_face()

                for cropped_face in self.face_helper.cropped_faces:
                    cropped_face_t = img2tensor(cropped_face / 255., bgr2rgb=True, float32=True)
                    normalize(cropped_face_t, (0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True)
                    cropped_face_t = cropped_face_t.unsqueeze(0).to(devices.device_codeformer)

                    try:
                        with torch.no_grad():
                            output = self.net(cropped_face_t, w=w if w is not None else shared.opts.code_former_weight, adain=True)[0]
                            restored_face = tensor2img(output, rgb2bgr=True, min_max=(-1, 1))
                        del output
                        devices.torch_gc()
                    except Exception:
                        errors.report('Failed inference for CodeFormer', exc_info=True)
                        restored_face = tensor2img(cropped_face_t, rgb2bgr=True, min_max=(-1, 1))

                    restored_face = restored_face.astype('uint8')
                    self.face_helper.add_restored_face(restored_face)

                self.face_helper.get_inverse_affine(None)

                restored_img = self.face_helper.paste_faces_to_input_image()
                restored_img = restored_img[:, :, ::-1]

                if original_resolution != restored_img.shape[0:2]:
                    restored_img = cv2.resize(restored_img, (0, 0), fx=original_resolution[1]/restored_img.shape[1], fy=original_resolution[0]/restored_img.shape[0], interpolation=cv2.INTER_LINEAR)

                self.face_helper.clean_all()

                if shared.opts.face_restoration_unload:
                    self.send_model_to(devices.cpu)

                return restored_img

        global codeformer
        codeformer = FaceRestorerCodeFormer(dirname)
        shared.face_restorers.append(codeformer)
    except Exception:
        errors.report("Error setting up CodeFormer", exc_info=True)

    # sys.path = stored_sys_path

--- added: new face_restoration_utils + spandrel-based implementation ---

from __future__ import annotations

import logging

import torch

from modules import (
    devices,
    errors,
    face_restoration,
    face_restoration_utils,
    modelloader,
    shared,
)

logger = logging.getLogger(__name__)

model_url = 'https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/codeformer.pth'
model_download_name = 'codeformer-v0.1.0.pth'

# used by e.g. postprocessing_codeformer.py
codeformer: face_restoration.FaceRestoration | None = None


class FaceRestorerCodeFormer(face_restoration_utils.CommonFaceRestoration):
    def name(self):
        return "CodeFormer"

    def load_net(self) -> torch.Module:
        for model_path in modelloader.load_models(
            model_path=self.model_path,
            model_url=model_url,
            command_path=self.model_path,
            download_name=model_download_name,
            ext_filter=['.pth'],
        ):
            return modelloader.load_spandrel_model(
                model_path,
                device=devices.device_codeformer,
                expected_architecture='CodeFormer',
            ).model
        raise ValueError("No codeformer model found")

    def get_device(self):
        return devices.device_codeformer

    def restore(self, np_image, w: float | None = None):
        if w is None:
            w = getattr(shared.opts, "code_former_weight", 0.5)

        def restore_face(cropped_face_t):
            assert self.net is not None
            return self.net(cropped_face_t, w=w, adain=True)[0]

        return self.restore_with_helper(np_image, restore_face)


def setup_model(dirname: str) -> None:
    global codeformer
    try:
        codeformer = FaceRestorerCodeFormer(dirname)
        shared.face_restorers.append(codeformer)
    except Exception:
        errors.report("Error setting up CodeFormer", exc_info=True)
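The two versions are easiest to compare at the call site. Here is a minimal usage sketch; only setup_model, the module-level codeformer handle, and restore(np_image, w) come from the diff above, while the module path and the stand-in image are assumptions for illustration:

# Minimal sketch, not part of the diff: assumes the webui package layout
# (modules.codeformer_model) and a models/Codeformer directory.
import numpy as np

from modules import codeformer_model

codeformer_model.setup_model("models/Codeformer")

restorer = codeformer_model.codeformer
if restorer is not None:
    image = np.zeros((512, 512, 3), dtype=np.uint8)  # stand-in for a real photo
    # w is the CodeFormer fidelity weight in [0, 1]; passing None falls back
    # to the code_former_weight setting in both the old and new versions.
    restored = restorer.restore(image, w=0.5)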
@@ -107,8 +107,8 @@ def check_versions():
     import torch
     import gradio

-    expected_torch_version = "2.0.0"
-    expected_xformers_version = "0.0.20"
+    expected_torch_version = "2.1.2"
+    expected_xformers_version = "0.0.23.post1"
     expected_gradio_version = "3.41.2"

     if version.parse(torch.__version__) < version.parse(expected_torch_version):
...
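version.parse here is the packaging library's PEP 440 parser, which is what makes comparisons against CUDA-suffixed builds behave sensibly. A small self-contained illustration (the helper name is made up):

from packaging import version  # same parser the launch-time check uses

EXPECTED_TORCH = "2.1.2"

def torch_is_outdated(installed: str) -> bool:
    # PEP 440: local version labels like "+cu121" sort *after* the bare
    # release, so "2.1.2+cu121" does not trip the warning.
    return version.parse(installed) < version.parse(EXPECTED_TORCH)

assert torch_is_outdated("2.0.1+cu118")
assert not torch_is_outdated("2.1.2+cu121")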
@@ -32,11 +32,12 @@ class ExtensionMetadata:
         self.config = configparser.ConfigParser()

         filepath = os.path.join(path, self.filename)
-        if os.path.isfile(filepath):
-            try:
-                self.config.read(filepath)
-            except Exception:
-                errors.report(f"Error reading {self.filename} for extension {canonical_name}.", exc_info=True)
+        # `self.config.read()` will quietly swallow OSErrors (which FileNotFoundError is),
+        # so no need to check whether the file exists beforehand.
+        try:
+            self.config.read(filepath)
+        except Exception:
+            errors.report(f"Error reading {self.filename} for extension {canonical_name}.", exc_info=True)

         self.canonical_name = self.config.get("Extension", "Name", fallback=canonical_name)
         self.canonical_name = canonical_name.lower().strip()
@@ -206,7 +206,7 @@ def parse_prompts(prompts):
     return res, extra_data


-def get_user_metadata(filename):
+def get_user_metadata(filename, lister=None):
     if filename is None:
         return {}
@@ -215,7 +215,8 @@ def get_user_metadata(filename):
     metadata = {}

     try:
-        if os.path.isfile(metadata_filename):
+        exists = lister.exists(metadata_filename) if lister else os.path.exists(metadata_filename)
+        if exists:
             with open(metadata_filename, "r", encoding="utf8") as file:
                 metadata = json.load(file)
     except Exception as e:
...
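The new lister parameter suggests that callers scanning whole directories can swap the per-file os.path.exists for a cached lookup. The real lister class is not shown in this excerpt, so here is a hypothetical minimal stand-in with the same exists() shape:

import os

class DirectoryLister:
    """Hypothetical stand-in: snapshot one directory's contents once and
    answer exists() from memory instead of stat()-ing every file."""

    def __init__(self, directory: str):
        self._directory = os.path.abspath(directory)
        try:
            self._names = set(os.listdir(self._directory))
        except OSError:
            self._names = set()

    def exists(self, path: str) -> bool:
        head, tail = os.path.split(os.path.abspath(path))
        return head == self._directory and tail in self._names

# e.g. get_user_metadata(path, lister=DirectoryLister("models/Lora"))
# would answer every existence check from one directory listing.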
import os
import sys

from modules import modelloader, devices
from modules.shared import opts
from modules.upscaler import Upscaler, UpscalerData
from modules.upscaler_utils import upscale_with_model


class UpscalerHAT(Upscaler):
    def __init__(self, dirname):
        self.name = "HAT"
        self.scalers = []
        self.user_path = dirname
        super().__init__()
        for file in self.find_models(ext_filter=[".pt", ".pth"]):
            name = modelloader.friendly_name(file)
            scale = 4  # TODO: scale might not be 4, but we can't know without loading the model
            scaler_data = UpscalerData(name, file, upscaler=self, scale=scale)
            self.scalers.append(scaler_data)

    def do_upscale(self, img, selected_model):
        try:
            model = self.load_model(selected_model)
        except Exception as e:
            print(f"Unable to load HAT model {selected_model}: {e}", file=sys.stderr)
            return img
        model.to(devices.device_esrgan)  # TODO: should probably be device_hat
        return upscale_with_model(
            model,
            img,
            tile_size=opts.ESRGAN_tile,  # TODO: should probably be HAT_tile
            tile_overlap=opts.ESRGAN_tile_overlap,  # TODO: should probably be HAT_tile_overlap
        )

    def load_model(self, path: str):
        if not os.path.isfile(path):
            raise FileNotFoundError(f"Model file {path} not found")
        return modelloader.load_spandrel_model(
            path,
            device=devices.device_esrgan,  # TODO: should probably be device_hat
            expected_architecture='HAT',
        )
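One way the scale TODO above could plausibly be resolved is to ask spandrel itself, since load_spandrel_model already identifies the architecture. A hedged sketch against spandrel's public ModelLoader API, at the cost of fully loading each model during registration, which is presumably why the hard-coded 4 was kept:

import spandrel

def detect_scale(path: str) -> int:
    # ModelLoader().load_from_file() returns an ImageModelDescriptor;
    # its .scale attribute is the model's native upscaling factor.
    return spandrel.ModelLoader().load_from_file(path).scale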