Commit bef51aed authored by AUTOMATIC1111

Merge branch 'release_candidate'

parents cf2772fa 13984857
@@ -86,8 +86,6 @@ module.exports = {
         // imageviewer.js
         modalPrevImage: "readonly",
         modalNextImage: "readonly",
-        // token-counters.js
-        setupTokenCounters: "readonly",
         // localStorage.js
         localSet: "readonly",
         localGet: "readonly",
@@ -20,6 +20,12 @@ jobs:
           cache-dependency-path: |
             **/requirements*txt
             launch.py
+      - name: Cache models
+        id: cache-models
+        uses: actions/cache@v3
+        with:
+          path: models
+          key: "2023-12-30"
       - name: Install test dependencies
         run: pip install wait-for-it -r requirements-test.txt
         env:
@@ -33,6 +39,8 @@ jobs:
           TORCH_INDEX_URL: https://download.pytorch.org/whl/cpu
           WEBUI_LAUNCH_LIVE_OUTPUT: "1"
           PYTHONUNBUFFERED: "1"
+      - name: Print installed packages
+        run: pip freeze
       - name: Start test server
         run: >
           python -m coverage run
@@ -49,7 +57,7 @@ jobs:
           2>&1 | tee output.txt &
       - name: Run tests
         run: |
-          wait-for-it --service 127.0.0.1:7860 -t 600
+          wait-for-it --service 127.0.0.1:7860 -t 20
           python -m pytest -vv --junitxml=test/results.xml --cov . --cov-report=xml --verify-base-url test
       - name: Kill test server
         if: always()
@@ -37,3 +37,4 @@ notification.mp3
 /node_modules
 /package-lock.json
 /.coverage*
+/test/test_outputs
 # Stable Diffusion web UI
-A browser interface based on Gradio library for Stable Diffusion.
+A web interface for Stable Diffusion, implemented using Gradio library.

 ![](screenshot.png)
@@ -151,11 +151,12 @@ Licenses for borrowed code can be found in `Settings -> Licenses` screen, and al
 - Stable Diffusion - https://github.com/Stability-AI/stablediffusion, https://github.com/CompVis/taming-transformers
 - k-diffusion - https://github.com/crowsonkb/k-diffusion.git
-- GFPGAN - https://github.com/TencentARC/GFPGAN.git
-- CodeFormer - https://github.com/sczhou/CodeFormer
-- ESRGAN - https://github.com/xinntao/ESRGAN
-- SwinIR - https://github.com/JingyunLiang/SwinIR
-- Swin2SR - https://github.com/mv-lab/swin2sr
+- Spandrel - https://github.com/chaiNNer-org/spandrel implementing
+  - GFPGAN - https://github.com/TencentARC/GFPGAN.git
+  - CodeFormer - https://github.com/sczhou/CodeFormer
+  - ESRGAN - https://github.com/xinntao/ESRGAN
+  - SwinIR - https://github.com/JingyunLiang/SwinIR
+  - Swin2SR - https://github.com/mv-lab/swin2sr
 - LDSR - https://github.com/Hafiidz/latent-diffusion
 - MiDaS - https://github.com/isl-org/MiDaS
 - Ideas for optimizations - https://github.com/basujindal/stable-diffusion
model:
  target: sgm.models.diffusion.DiffusionEngine
  params:
    scale_factor: 0.13025
    disable_first_stage_autocast: True

    denoiser_config:
      target: sgm.modules.diffusionmodules.denoiser.DiscreteDenoiser
      params:
        num_idx: 1000

        weighting_config:
          target: sgm.modules.diffusionmodules.denoiser_weighting.EpsWeighting
        scaling_config:
          target: sgm.modules.diffusionmodules.denoiser_scaling.EpsScaling
        discretization_config:
          target: sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization

    network_config:
      target: sgm.modules.diffusionmodules.openaimodel.UNetModel
      params:
        adm_in_channels: 2816
        num_classes: sequential
        use_checkpoint: True
        in_channels: 9
        out_channels: 4
        model_channels: 320
        attention_resolutions: [4, 2]
        num_res_blocks: 2
        channel_mult: [1, 2, 4]
        num_head_channels: 64
        use_spatial_transformer: True
        use_linear_in_transformer: True
        transformer_depth: [1, 2, 10]  # note: the first is unused (due to attn_res starting at 2) 32, 16, 8 --> 64, 32, 16
        context_dim: 2048
        spatial_transformer_attn_type: softmax-xformers
        legacy: False

    conditioner_config:
      target: sgm.modules.GeneralConditioner
      params:
        emb_models:
          # crossattn cond
          - is_trainable: False
            input_key: txt
            target: sgm.modules.encoders.modules.FrozenCLIPEmbedder
            params:
              layer: hidden
              layer_idx: 11
          # crossattn and vector cond
          - is_trainable: False
            input_key: txt
            target: sgm.modules.encoders.modules.FrozenOpenCLIPEmbedder2
            params:
              arch: ViT-bigG-14
              version: laion2b_s39b_b160k
              freeze: True
              layer: penultimate
              always_return_pooled: True
              legacy: False
          # vector cond
          - is_trainable: False
            input_key: original_size_as_tuple
            target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
            params:
              outdim: 256  # multiplied by two
          # vector cond
          - is_trainable: False
            input_key: crop_coords_top_left
            target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
            params:
              outdim: 256  # multiplied by two
          # vector cond
          - is_trainable: False
            input_key: target_size_as_tuple
            target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
            params:
              outdim: 256  # multiplied by two

    first_stage_config:
      target: sgm.models.autoencoder.AutoencoderKLInferenceWrapper
      params:
        embed_dim: 4
        monitor: val/rec_loss
        ddconfig:
          attn_type: vanilla-xformers
          double_z: true
          z_channels: 4
          resolution: 256
          in_channels: 3
          out_ch: 3
          ch: 128
          ch_mult: [1, 2, 4, 4]
          num_res_blocks: 2
          attn_resolutions: []
          dropout: 0.0
        lossconfig:
          target: torch.nn.Identity
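This config follows the `target`/`params` convention common to CompVis/sgm codebases: each node names a class by dotted path plus the keyword arguments to construct it with. A minimal sketch of the instantiation helper such configs are fed through (illustrative only; the real helper lives in the sgm package):

```python
# Minimal sketch of the target/params instantiation convention used by
# sgm-style YAML configs (standalone illustration, not webui code).
import importlib

def instantiate_from_config(config: dict):
    """Import config['target'] and call it with config['params'] as kwargs."""
    module_path, cls_name = config["target"].rsplit(".", 1)
    cls = getattr(importlib.import_module(module_path), cls_name)
    return cls(**config.get("params", {}))

# e.g. instantiate_from_config({"target": "torch.nn.Identity"}) -> nn.Identity()
```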
@@ -3,6 +3,9 @@ import os
 from collections import namedtuple
 import enum

+import torch.nn as nn
+import torch.nn.functional as F
+
 from modules import sd_models, cache, errors, hashes, shared

 NetworkWeights = namedtuple('NetworkWeights', ['network_key', 'sd_key', 'w', 'sd_module'])
@@ -115,6 +118,29 @@ class NetworkModule:
         if hasattr(self.sd_module, 'weight'):
             self.shape = self.sd_module.weight.shape

+        self.ops = None
+        self.extra_kwargs = {}
+        if isinstance(self.sd_module, nn.Conv2d):
+            self.ops = F.conv2d
+            self.extra_kwargs = {
+                'stride': self.sd_module.stride,
+                'padding': self.sd_module.padding
+            }
+        elif isinstance(self.sd_module, nn.Linear):
+            self.ops = F.linear
+        elif isinstance(self.sd_module, nn.LayerNorm):
+            self.ops = F.layer_norm
+            self.extra_kwargs = {
+                'normalized_shape': self.sd_module.normalized_shape,
+                'eps': self.sd_module.eps
+            }
+        elif isinstance(self.sd_module, nn.GroupNorm):
+            self.ops = F.group_norm
+            self.extra_kwargs = {
+                'num_groups': self.sd_module.num_groups,
+                'eps': self.sd_module.eps
+            }
+
         self.dim = None
         self.bias = weights.w.get("bias")
         self.alpha = weights.w["alpha"].item() if "alpha" in weights.w else None
@@ -137,7 +163,7 @@ class NetworkModule:
     def finalize_updown(self, updown, orig_weight, output_shape, ex_bias=None):
         if self.bias is not None:
             updown = updown.reshape(self.bias.shape)
-            updown += self.bias.to(orig_weight.device, dtype=orig_weight.dtype)
+            updown += self.bias.to(orig_weight.device, dtype=updown.dtype)
             updown = updown.reshape(output_shape)

         if len(output_shape) == 4:
@@ -155,5 +181,10 @@ class NetworkModule:
         raise NotImplementedError()

     def forward(self, x, y):
-        raise NotImplementedError()
+        """A general forward implementation for all modules"""
+        if self.ops is None:
+            raise NotImplementedError()
+        else:
+            updown, ex_bias = self.calc_updown(self.sd_module.weight)
+            return y + self.ops(x, weight=updown, bias=ex_bias, **self.extra_kwargs)
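For orientation, a tiny self-contained sketch (hypothetical example, not repository code) of what this generic forward computes for a Linear layer: `calc_updown` yields the weight delta, the matching functional op applies it to the input, and the result is added onto the base layer's output `y`.

```python
# Minimal sketch: the generic forward adds the network delta's contribution
# on top of the original layer output, via the functional op for the layer type.
import torch
import torch.nn.functional as F

x = torch.randn(2, 16)               # input activations
base = torch.nn.Linear(16, 8)
y = base(x)                          # original layer output

up, down = torch.randn(8, 4), torch.randn(4, 16)
updown = up @ down                   # low-rank weight delta, shape (8, 16)

out = y + F.linear(x, updown)        # equivalent to x @ (W + updown).T + b
```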
@@ -18,9 +18,9 @@ class NetworkModuleFull(network.NetworkModule):
     def calc_updown(self, orig_weight):
         output_shape = self.weight.shape
-        updown = self.weight.to(orig_weight.device, dtype=orig_weight.dtype)
+        updown = self.weight.to(orig_weight.device)
         if self.ex_bias is not None:
-            ex_bias = self.ex_bias.to(orig_weight.device, dtype=orig_weight.dtype)
+            ex_bias = self.ex_bias.to(orig_weight.device)
         else:
             ex_bias = None
@@ -22,12 +22,12 @@ class NetworkModuleGLora(network.NetworkModule):
         self.w2b = weights.w["b2.weight"]

     def calc_updown(self, orig_weight):
-        w1a = self.w1a.to(orig_weight.device, dtype=orig_weight.dtype)
-        w1b = self.w1b.to(orig_weight.device, dtype=orig_weight.dtype)
-        w2a = self.w2a.to(orig_weight.device, dtype=orig_weight.dtype)
-        w2b = self.w2b.to(orig_weight.device, dtype=orig_weight.dtype)
+        w1a = self.w1a.to(orig_weight.device)
+        w1b = self.w1b.to(orig_weight.device)
+        w2a = self.w2a.to(orig_weight.device)
+        w2b = self.w2b.to(orig_weight.device)

         output_shape = [w1a.size(0), w1b.size(1)]
-        updown = ((w2b @ w1b) + ((orig_weight @ w2a) @ w1a))
+        updown = ((w2b @ w1b) + ((orig_weight.to(dtype=w1a.dtype) @ w2a) @ w1a))

         return self.finalize_updown(updown, orig_weight, output_shape)
@@ -27,16 +27,16 @@ class NetworkModuleHada(network.NetworkModule):
         self.t2 = weights.w.get("hada_t2")

     def calc_updown(self, orig_weight):
-        w1a = self.w1a.to(orig_weight.device, dtype=orig_weight.dtype)
-        w1b = self.w1b.to(orig_weight.device, dtype=orig_weight.dtype)
-        w2a = self.w2a.to(orig_weight.device, dtype=orig_weight.dtype)
-        w2b = self.w2b.to(orig_weight.device, dtype=orig_weight.dtype)
+        w1a = self.w1a.to(orig_weight.device)
+        w1b = self.w1b.to(orig_weight.device)
+        w2a = self.w2a.to(orig_weight.device)
+        w2b = self.w2b.to(orig_weight.device)

         output_shape = [w1a.size(0), w1b.size(1)]

         if self.t1 is not None:
             output_shape = [w1a.size(1), w1b.size(1)]
-            t1 = self.t1.to(orig_weight.device, dtype=orig_weight.dtype)
+            t1 = self.t1.to(orig_weight.device)
             updown1 = lyco_helpers.make_weight_cp(t1, w1a, w1b)
             output_shape += t1.shape[2:]
         else:
@@ -45,7 +45,7 @@ class NetworkModuleHada(network.NetworkModule):
             updown1 = lyco_helpers.rebuild_conventional(w1a, w1b, output_shape)

         if self.t2 is not None:
-            t2 = self.t2.to(orig_weight.device, dtype=orig_weight.dtype)
+            t2 = self.t2.to(orig_weight.device)
             updown2 = lyco_helpers.make_weight_cp(t2, w2a, w2b)
         else:
             updown2 = lyco_helpers.rebuild_conventional(w2a, w2b, output_shape)
@@ -17,7 +17,7 @@ class NetworkModuleIa3(network.NetworkModule):
         self.on_input = weights.w["on_input"].item()

     def calc_updown(self, orig_weight):
-        w = self.w.to(orig_weight.device, dtype=orig_weight.dtype)
+        w = self.w.to(orig_weight.device)

         output_shape = [w.size(0), orig_weight.size(1)]
         if self.on_input:
@@ -37,22 +37,22 @@ class NetworkModuleLokr(network.NetworkModule):
     def calc_updown(self, orig_weight):
         if self.w1 is not None:
-            w1 = self.w1.to(orig_weight.device, dtype=orig_weight.dtype)
+            w1 = self.w1.to(orig_weight.device)
         else:
-            w1a = self.w1a.to(orig_weight.device, dtype=orig_weight.dtype)
-            w1b = self.w1b.to(orig_weight.device, dtype=orig_weight.dtype)
+            w1a = self.w1a.to(orig_weight.device)
+            w1b = self.w1b.to(orig_weight.device)
             w1 = w1a @ w1b

         if self.w2 is not None:
-            w2 = self.w2.to(orig_weight.device, dtype=orig_weight.dtype)
+            w2 = self.w2.to(orig_weight.device)
         elif self.t2 is None:
-            w2a = self.w2a.to(orig_weight.device, dtype=orig_weight.dtype)
-            w2b = self.w2b.to(orig_weight.device, dtype=orig_weight.dtype)
+            w2a = self.w2a.to(orig_weight.device)
+            w2b = self.w2b.to(orig_weight.device)
             w2 = w2a @ w2b
         else:
-            t2 = self.t2.to(orig_weight.device, dtype=orig_weight.dtype)
-            w2a = self.w2a.to(orig_weight.device, dtype=orig_weight.dtype)
-            w2b = self.w2b.to(orig_weight.device, dtype=orig_weight.dtype)
+            t2 = self.t2.to(orig_weight.device)
+            w2a = self.w2a.to(orig_weight.device)
+            w2b = self.w2b.to(orig_weight.device)
             w2 = lyco_helpers.make_weight_cp(t2, w2a, w2b)

         output_shape = [w1.size(0) * w2.size(0), w1.size(1) * w2.size(1)]
@@ -61,13 +61,13 @@ class NetworkModuleLora(network.NetworkModule):
         return module

     def calc_updown(self, orig_weight):
-        up = self.up_model.weight.to(orig_weight.device, dtype=orig_weight.dtype)
-        down = self.down_model.weight.to(orig_weight.device, dtype=orig_weight.dtype)
+        up = self.up_model.weight.to(orig_weight.device)
+        down = self.down_model.weight.to(orig_weight.device)

         output_shape = [up.size(0), down.size(1)]
         if self.mid_model is not None:
             # cp-decomposition
-            mid = self.mid_model.weight.to(orig_weight.device, dtype=orig_weight.dtype)
+            mid = self.mid_model.weight.to(orig_weight.device)
             updown = lyco_helpers.rebuild_cp_decomposition(up, down, mid)
             output_shape += mid.shape[2:]
         else:
@@ -18,10 +18,10 @@ class NetworkModuleNorm(network.NetworkModule):
     def calc_updown(self, orig_weight):
         output_shape = self.w_norm.shape
-        updown = self.w_norm.to(orig_weight.device, dtype=orig_weight.dtype)
+        updown = self.w_norm.to(orig_weight.device)

         if self.b_norm is not None:
-            ex_bias = self.b_norm.to(orig_weight.device, dtype=orig_weight.dtype)
+            ex_bias = self.b_norm.to(orig_weight.device)
         else:
             ex_bias = None
 import torch
 import network
-from lyco_helpers import factorization
 from einops import rearrange
@@ -22,20 +21,28 @@ class NetworkModuleOFT(network.NetworkModule):
         self.org_module: list[torch.Module] = [self.sd_module]

         self.scale = 1.0
+        self.is_R = False
+        self.is_boft = False

-        # kohya-ss
+        # kohya-ss/New LyCORIS OFT/BOFT
         if "oft_blocks" in weights.w.keys():
-            self.is_kohya = True
             self.oft_blocks = weights.w["oft_blocks"]  # (num_blocks, block_size, block_size)
-            self.alpha = weights.w["alpha"]  # alpha is constraint
+            self.alpha = weights.w.get("alpha", None)  # alpha is constraint
             self.dim = self.oft_blocks.shape[0]  # lora dim
-        # LyCORIS
+        # Old LyCORIS OFT
         elif "oft_diag" in weights.w.keys():
-            self.is_kohya = False
+            self.is_R = True
             self.oft_blocks = weights.w["oft_diag"]
             # self.alpha is unused
             self.dim = self.oft_blocks.shape[1]  # (num_blocks, block_size, block_size)

+        # LyCORIS BOFT
+        if self.oft_blocks.dim() == 4:
+            self.is_boft = True
+        self.rescale = weights.w.get('rescale', None)
+        if self.rescale is not None:
+            self.rescale = self.rescale.reshape(-1, *[1]*(self.org_module[0].weight.dim() - 1))
+
         is_linear = type(self.sd_module) in [torch.nn.Linear, torch.nn.modules.linear.NonDynamicallyQuantizableLinear]
         is_conv = type(self.sd_module) in [torch.nn.Conv2d]
         is_other_linear = type(self.sd_module) in [torch.nn.MultiheadAttention]  # unsupported
@@ -47,36 +54,65 @@ class NetworkModuleOFT(network.NetworkModule):
         elif is_other_linear:
             self.out_dim = self.sd_module.embed_dim

-        if self.is_kohya:
-            self.constraint = self.alpha * self.out_dim
-            self.num_blocks = self.dim
-            self.block_size = self.out_dim // self.dim
-        else:
+        self.num_blocks = self.dim
+        self.block_size = self.out_dim // self.dim
+        self.constraint = (0 if self.alpha is None else self.alpha) * self.out_dim
+        if self.is_R:
             self.constraint = None
-            self.block_size, self.num_blocks = factorization(self.out_dim, self.dim)
+            self.block_size = self.dim
+            self.num_blocks = self.out_dim // self.dim
+        elif self.is_boft:
+            self.boft_m = self.oft_blocks.shape[0]
+            self.num_blocks = self.oft_blocks.shape[1]
+            self.block_size = self.oft_blocks.shape[2]
+            self.boft_b = self.block_size

     def calc_updown(self, orig_weight):
-        oft_blocks = self.oft_blocks.to(orig_weight.device, dtype=orig_weight.dtype)
-        eye = torch.eye(self.block_size, device=self.oft_blocks.device)
+        oft_blocks = self.oft_blocks.to(orig_weight.device)
+        eye = torch.eye(self.block_size, device=oft_blocks.device)

-        if self.is_kohya:
-            block_Q = oft_blocks - oft_blocks.transpose(1, 2)  # ensure skew-symmetric orthogonal matrix
-            norm_Q = torch.norm(block_Q.flatten())
-            new_norm_Q = torch.clamp(norm_Q, max=self.constraint)
-            block_Q = block_Q * ((new_norm_Q + 1e-8) / (norm_Q + 1e-8))
+        if not self.is_R:
+            block_Q = oft_blocks - oft_blocks.transpose(-1, -2)  # ensure skew-symmetric orthogonal matrix
+            if self.constraint != 0:
+                norm_Q = torch.norm(block_Q.flatten())
+                new_norm_Q = torch.clamp(norm_Q, max=self.constraint.to(oft_blocks.device))
+                block_Q = block_Q * ((new_norm_Q + 1e-8) / (norm_Q + 1e-8))
             oft_blocks = torch.matmul(eye + block_Q, (eye - block_Q).float().inverse())

-        R = oft_blocks.to(orig_weight.device, dtype=orig_weight.dtype)
+        R = oft_blocks.to(orig_weight.device)

-        # This errors out for MultiheadAttention, might need to be handled up-stream
-        merged_weight = rearrange(orig_weight, '(k n) ... -> k n ...', k=self.num_blocks, n=self.block_size)
-        merged_weight = torch.einsum(
-            'k n m, k n ... -> k m ...',
-            R,
-            merged_weight
-        )
-        merged_weight = rearrange(merged_weight, 'k m ... -> (k m) ...')
+        if not self.is_boft:
+            # This errors out for MultiheadAttention, might need to be handled up-stream
+            merged_weight = rearrange(orig_weight, '(k n) ... -> k n ...', k=self.num_blocks, n=self.block_size)
+            merged_weight = torch.einsum(
+                'k n m, k n ... -> k m ...',
+                R,
+                merged_weight
+            )
+            merged_weight = rearrange(merged_weight, 'k m ... -> (k m) ...')
+        else:
+            # TODO: determine correct value for scale
+            scale = 1.0
+            m = self.boft_m
+            b = self.boft_b
+            r_b = b // 2
+            inp = orig_weight
+            for i in range(m):
+                bi = R[i]  # b_num, b_size, b_size
+                if i == 0:
+                    # Apply multiplier/scale and rescale into first weight
+                    bi = bi * scale + (1 - scale) * eye
+                inp = rearrange(inp, "(c g k) ... -> (c k g) ...", g=2, k=2**i * r_b)
+                inp = rearrange(inp, "(d b) ... -> d b ...", b=b)
+                inp = torch.einsum("b i j, b j ... -> b i ...", bi, inp)
+                inp = rearrange(inp, "d b ... -> (d b) ...")
+                inp = rearrange(inp, "(c k g) ... -> (c g k) ...", g=2, k=2**i * r_b)
+            merged_weight = inp

-        updown = merged_weight.to(orig_weight.device, dtype=orig_weight.dtype) - orig_weight
+        # Rescale mechanism
+        if self.rescale is not None:
+            merged_weight = self.rescale.to(merged_weight) * merged_weight
+
+        updown = merged_weight.to(orig_weight.device) - orig_weight.to(merged_weight.dtype)
         output_shape = orig_weight.shape
         return self.finalize_updown(updown, orig_weight, output_shape)
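The `calc_updown` above orthogonalizes the learned blocks with a Cayley transform: for a skew-symmetric Q, R = (I + Q)(I - Q)^-1 is orthogonal, so multiplying the weight by R rotates it without rescaling it. A standalone sketch verifying that property (illustrative only, not webui code):

```python
# Minimal sketch of the Cayley transform used in OFT: skew-symmetric blocks
# always yield orthogonal rotation matrices, so R @ W preserves norms.
import torch

blocks = torch.randn(4, 8, 8)                      # (num_blocks, b, b)
Q = blocks - blocks.transpose(-1, -2)              # force skew-symmetry: Q^T = -Q
eye = torch.eye(8)
R = torch.matmul(eye + Q, torch.linalg.inv(eye - Q))  # Cayley transform, batched

# R is orthogonal per block: R @ R^T ~= I
err = (R @ R.transpose(-1, -2) - eye).abs().max()
print(f"max deviation from orthogonality: {err.item():.2e}")
```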
+import gradio as gr
 import logging
 import os
 import re
@@ -259,11 +260,11 @@ def load_networks(names, te_multipliers=None, unet_multipliers=None, dyn_dims=No
     loaded_networks.clear()

-    networks_on_disk = [available_network_aliases.get(name, None) for name in names]
+    networks_on_disk = [available_networks.get(name, None) if name.lower() in forbidden_network_aliases else available_network_aliases.get(name, None) for name in names]
     if any(x is None for x in networks_on_disk):
         list_available_networks()

-        networks_on_disk = [available_network_aliases.get(name, None) for name in names]
+        networks_on_disk = [available_networks.get(name, None) if name.lower() in forbidden_network_aliases else available_network_aliases.get(name, None) for name in names]

     failed_to_load_networks = []
@@ -314,7 +315,12 @@ def load_networks(names, te_multipliers=None, unet_multipliers=None, dyn_dims=No
                 emb_db.skipped_embeddings[name] = embedding

     if failed_to_load_networks:
-        sd_hijack.model_hijack.comments.append("Networks not found: " + ", ".join(failed_to_load_networks))
+        lora_not_found_message = f'Lora not found: {", ".join(failed_to_load_networks)}'
+        sd_hijack.model_hijack.comments.append(lora_not_found_message)
+        if shared.opts.lora_not_found_warning_console:
+            print(f'\n{lora_not_found_message}\n')
+        if shared.opts.lora_not_found_gradio_warning:
+            gr.Warning(lora_not_found_message)

     purge_networks_from_memory()
@@ -389,18 +395,26 @@ def network_apply_weights(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn
         if module is not None and hasattr(self, 'weight'):
             try:
                 with torch.no_grad():
-                    updown, ex_bias = module.calc_updown(self.weight)
+                    if getattr(self, 'fp16_weight', None) is None:
+                        weight = self.weight
+                        bias = self.bias
+                    else:
+                        weight = self.fp16_weight.clone().to(self.weight.device)
+                        bias = getattr(self, 'fp16_bias', None)
+                        if bias is not None:
+                            bias = bias.clone().to(self.bias.device)
+                    updown, ex_bias = module.calc_updown(weight)

-                    if len(self.weight.shape) == 4 and self.weight.shape[1] == 9:
+                    if len(weight.shape) == 4 and weight.shape[1] == 9:
                         # inpainting model. zero pad updown to make channel[1] 4 to 9
                         updown = torch.nn.functional.pad(updown, (0, 0, 0, 0, 0, 5))

-                    self.weight += updown
+                    self.weight.copy_((weight.to(dtype=updown.dtype) + updown).to(dtype=self.weight.dtype))
                     if ex_bias is not None and hasattr(self, 'bias'):
                         if self.bias is None:
-                            self.bias = torch.nn.Parameter(ex_bias)
+                            self.bias = torch.nn.Parameter(ex_bias).to(self.weight.dtype)
                         else:
-                            self.bias += ex_bias
+                            self.bias.copy_((bias + ex_bias).to(dtype=self.bias.dtype))
             except RuntimeError as e:
                 logging.debug(f"Network {net.name} layer {network_layer_name}: {e}")
                 extra_network_lora.errors[net.name] = extra_network_lora.errors.get(net.name, 0) + 1
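The `fp16_weight` branch above follows a master-copy pattern: when the layer keeps a higher-precision copy of its weights alongside a lower-precision live tensor, the delta is added to the master and the result is quantized once via `copy_`, rather than repeatedly mutating the low-precision tensor. A toy sketch of why that matters (hypothetical standalone code, not webui internals):

```python
# Sketch of the master-copy pattern: apply deltas to a high-precision copy
# and quantize once, instead of accumulating rounding error in the live tensor.
import torch

master = torch.randn(4, 4)                    # high-precision (fp32) master copy
live = master.to(torch.float16)               # low-precision weight actually used
delta = 0.01 * torch.randn(4, 4)

# naive: live += delta.to(torch.float16) re-rounds on every update
live = (master + delta).to(torch.float16)     # better: one rounding step total
```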
@@ -444,23 +458,23 @@ def network_apply_weights(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn
     self.network_current_names = wanted_names

-def network_forward(module, input, original_forward):
+def network_forward(org_module, input, original_forward):
     """
     Old way of applying Lora by executing operations during layer's forward.
     Stacking many loras this way results in big performance degradation.
     """

     if len(loaded_networks) == 0:
-        return original_forward(module, input)
+        return original_forward(org_module, input)

     input = devices.cond_cast_unet(input)

-    network_restore_weights_from_backup(module)
-    network_reset_cached_weight(module)
+    network_restore_weights_from_backup(org_module)
+    network_reset_cached_weight(org_module)

-    y = original_forward(module, input)
+    y = original_forward(org_module, input)

-    network_layer_name = getattr(module, 'network_layer_name', None)
+    network_layer_name = getattr(org_module, 'network_layer_name', None)
     for lora in loaded_networks:
         module = lora.modules.get(network_layer_name, None)
         if module is None:
 import os
 from modules import paths
+from modules.paths_internal import normalized_filepath

 def preload(parser):
-    parser.add_argument("--lora-dir", type=str, help="Path to directory with Lora networks.", default=os.path.join(paths.models_path, 'Lora'))
-    parser.add_argument("--lyco-dir-backcompat", type=str, help="Path to directory with LyCORIS networks (for backawards compatibility; can also use --lyco-dir).", default=os.path.join(paths.models_path, 'LyCORIS'))
+    parser.add_argument("--lora-dir", type=normalized_filepath, help="Path to directory with Lora networks.", default=os.path.join(paths.models_path, 'Lora'))
+    parser.add_argument("--lyco-dir-backcompat", type=normalized_filepath, help="Path to directory with LyCORIS networks (for backawards compatibility; can also use --lyco-dir).", default=os.path.join(paths.models_path, 'LyCORIS'))
@@ -39,6 +39,8 @@ shared.options_templates.update(shared.options_section(('extra_networks', "Extra
     "lora_show_all": shared.OptionInfo(False, "Always show all networks on the Lora page").info("otherwise, those detected as for incompatible version of Stable Diffusion will be hidden"),
     "lora_hide_unknown_for_versions": shared.OptionInfo([], "Hide networks of unknown versions for model versions", gr.CheckboxGroup, {"choices": ["SD1", "SD2", "SDXL"]}),
     "lora_in_memory_limit": shared.OptionInfo(0, "Number of Lora networks to keep cached in memory", gr.Number, {"precision": 0}),
+    "lora_not_found_warning_console": shared.OptionInfo(False, "Lora not found warning in console"),
+    "lora_not_found_gradio_warning": shared.OptionInfo(False, "Lora not found warning popup in webui"),
 }))
@@ -54,12 +54,13 @@ class LoraUserMetadataEditor(ui_extra_networks_user_metadata.UserMetadataEditor)
         self.slider_preferred_weight = None
         self.edit_notes = None

-    def save_lora_user_metadata(self, name, desc, sd_version, activation_text, preferred_weight, notes):
+    def save_lora_user_metadata(self, name, desc, sd_version, activation_text, preferred_weight, negative_text, notes):
         user_metadata = self.get_user_metadata(name)
         user_metadata["description"] = desc
         user_metadata["sd version"] = sd_version
         user_metadata["activation text"] = activation_text
         user_metadata["preferred weight"] = preferred_weight
+        user_metadata["negative text"] = negative_text
         user_metadata["notes"] = notes

         self.write_user_metadata(name, user_metadata)
@@ -127,6 +128,7 @@ class LoraUserMetadataEditor(ui_extra_networks_user_metadata.UserMetadataEditor)
             gr.HighlightedText.update(value=gradio_tags, visible=True if tags else False),
             user_metadata.get('activation text', ''),
             float(user_metadata.get('preferred weight', 0.0)),
+            user_metadata.get('negative text', ''),
             gr.update(visible=True if tags else False),
             gr.update(value=self.generate_random_prompt_from_tags(tags), visible=True if tags else False),
         ]
@@ -162,7 +164,7 @@ class LoraUserMetadataEditor(ui_extra_networks_user_metadata.UserMetadataEditor)
         self.taginfo = gr.HighlightedText(label="Training dataset tags")
         self.edit_activation_text = gr.Text(label='Activation text', info="Will be added to prompt along with Lora")
         self.slider_preferred_weight = gr.Slider(label='Preferred weight', info="Set to 0 to disable", minimum=0.0, maximum=2.0, step=0.01)
+        self.edit_negative_text = gr.Text(label='Negative prompt', info="Will be added to negative prompts")

         with gr.Row() as row_random_prompt:
             with gr.Column(scale=8):
                 random_prompt = gr.Textbox(label='Random prompt', lines=4, max_lines=4, interactive=False)
@@ -198,6 +200,7 @@ class LoraUserMetadataEditor(ui_extra_networks_user_metadata.UserMetadataEditor)
             self.taginfo,
             self.edit_activation_text,
             self.slider_preferred_weight,
+            self.edit_negative_text,
             row_random_prompt,
             random_prompt,
         ]
@@ -211,7 +214,9 @@ class LoraUserMetadataEditor(ui_extra_networks_user_metadata.UserMetadataEditor)
             self.select_sd_version,
             self.edit_activation_text,
             self.slider_preferred_weight,
+            self.edit_negative_text,
             self.edit_notes,
         ]

         self.setup_save_handler(self.button_save, self.save_lora_user_metadata, edited_components)
@@ -24,13 +24,16 @@ class ExtraNetworksPageLora(ui_extra_networks.ExtraNetworksPage):
         alias = lora_on_disk.get_alias()

+        search_terms = [self.search_terms_from_path(lora_on_disk.filename)]
+        if lora_on_disk.hash:
+            search_terms.append(lora_on_disk.hash)
         item = {
             "name": name,
             "filename": lora_on_disk.filename,
             "shorthash": lora_on_disk.shorthash,
             "preview": self.find_preview(path),
             "description": self.find_description(path),
-            "search_term": self.search_terms_from_path(lora_on_disk.filename) + " " + (lora_on_disk.hash or ""),
+            "search_terms": search_terms,
             "local_preview": f"{path}.{shared.opts.samples_format}",
             "metadata": lora_on_disk.metadata,
             "sort_keys": {'default': index, **self.get_sort_keys(lora_on_disk.filename)},
@@ -45,6 +48,11 @@ class ExtraNetworksPageLora(ui_extra_networks.ExtraNetworksPage):
         if activation_text:
             item["prompt"] += " + " + quote_js(" " + activation_text)

+        negative_prompt = item["user_metadata"].get("negative text")
+        item["negative_prompt"] = quote_js("")
+        if negative_prompt:
+            item["negative_prompt"] = quote_js('(' + negative_prompt + ':1)')
+
         sd_version = item["user_metadata"].get("sd version")
         if sd_version in network.SdVersion.__members__:
             item["sd_version"] = sd_version
 import sys

 import PIL.Image
-import numpy as np
-import torch
-from tqdm import tqdm

 import modules.upscaler
-from modules import devices, modelloader, script_callbacks, errors
-from scunet_model_arch import SCUNet
-from modules.modelloader import load_file_from_url
-from modules.shared import opts
+from modules import devices, errors, modelloader, script_callbacks, shared, upscaler_utils

 class UpscalerScuNET(modules.upscaler.Upscaler):
@@ -42,100 +35,37 @@ class UpscalerScuNET(modules.upscaler.Upscaler):
             scalers.append(scaler_data2)
         self.scalers = scalers

-    @staticmethod
-    @torch.no_grad()
-    def tiled_inference(img, model):
-        # test the image tile by tile
-        h, w = img.shape[2:]
-        tile = opts.SCUNET_tile
-        tile_overlap = opts.SCUNET_tile_overlap
-        if tile == 0:
-            return model(img)
-
-        device = devices.get_device_for('scunet')
-        assert tile % 8 == 0, "tile size should be a multiple of window_size"
-        sf = 1
-
-        stride = tile - tile_overlap
-        h_idx_list = list(range(0, h - tile, stride)) + [h - tile]
-        w_idx_list = list(range(0, w - tile, stride)) + [w - tile]
-        E = torch.zeros(1, 3, h * sf, w * sf, dtype=img.dtype, device=device)
-        W = torch.zeros_like(E, dtype=devices.dtype, device=device)
-
-        with tqdm(total=len(h_idx_list) * len(w_idx_list), desc="ScuNET tiles") as pbar:
-            for h_idx in h_idx_list:
-                for w_idx in w_idx_list:
-                    in_patch = img[..., h_idx: h_idx + tile, w_idx: w_idx + tile]
-                    out_patch = model(in_patch)
-                    out_patch_mask = torch.ones_like(out_patch)
-                    E[
-                        ..., h_idx * sf: (h_idx + tile) * sf, w_idx * sf: (w_idx + tile) * sf
-                    ].add_(out_patch)
-                    W[
-                        ..., h_idx * sf: (h_idx + tile) * sf, w_idx * sf: (w_idx + tile) * sf
-                    ].add_(out_patch_mask)
-                    pbar.update(1)
-        output = E.div_(W)
-
-        return output
-
     def do_upscale(self, img: PIL.Image.Image, selected_file):
         devices.torch_gc()
         try:
             model = self.load_model(selected_file)
         except Exception as e:
             print(f"ScuNET: Unable to load model from {selected_file}: {e}", file=sys.stderr)
             return img

-        device = devices.get_device_for('scunet')
-        tile = opts.SCUNET_tile
-        h, w = img.height, img.width
-        np_img = np.array(img)
-        np_img = np_img[:, :, ::-1]  # RGB to BGR
-        np_img = np_img.transpose((2, 0, 1)) / 255  # HWC to CHW
-        torch_img = torch.from_numpy(np_img).float().unsqueeze(0).to(device)  # type: ignore
-
-        if tile > h or tile > w:
-            _img = torch.zeros(1, 3, max(h, tile), max(w, tile), dtype=torch_img.dtype, device=torch_img.device)
-            _img[:, :, :h, :w] = torch_img  # pad image
-            torch_img = _img
-
-        torch_output = self.tiled_inference(torch_img, model).squeeze(0)
-        torch_output = torch_output[:, :h * 1, :w * 1]  # remove padding, if any
-        np_output: np.ndarray = torch_output.float().cpu().clamp_(0, 1).numpy()
-        del torch_img, torch_output
+        img = upscaler_utils.upscale_2(
+            img,
+            model,
+            tile_size=shared.opts.SCUNET_tile,
+            tile_overlap=shared.opts.SCUNET_tile_overlap,
+            scale=1,  # ScuNET is a denoising model, not an upscaler
+            desc='ScuNET',
+        )
         devices.torch_gc()
-
-        output = np_output.transpose((1, 2, 0))  # CHW to HWC
-        output = output[:, :, ::-1]  # BGR to RGB
-        return PIL.Image.fromarray((output * 255).astype(np.uint8))
+        return img

     def load_model(self, path: str):
         device = devices.get_device_for('scunet')
         if path.startswith("http"):
             # TODO: this doesn't use `path` at all?
-            filename = load_file_from_url(self.model_url, model_dir=self.model_download_path, file_name=f"{self.name}.pth")
+            filename = modelloader.load_file_from_url(self.model_url, model_dir=self.model_download_path, file_name=f"{self.name}.pth")
         else:
             filename = path
-        model = SCUNet(in_nc=3, config=[4, 4, 4, 4, 4, 4, 4], dim=64)
-        model.load_state_dict(torch.load(filename), strict=True)
-        model.eval()
-        for _, v in model.named_parameters():
-            v.requires_grad = False
-        model = model.to(device)
-
-        return model
+        return modelloader.load_spandrel_model(filename, device=device, expected_architecture='SCUNet')

 def on_ui_settings():
     import gradio as gr
-    from modules import shared

     shared.opts.add_option("SCUNET_tile", shared.OptionInfo(256, "Tile size for SCUNET upscalers.", gr.Slider, {"minimum": 0, "maximum": 512, "step": 16}, section=('upscaling', "Upscaling")).info("0 = no tiling"))
     shared.opts.add_option("SCUNET_tile_overlap", shared.OptionInfo(8, "Tile overlap for SCUNET upscalers.", gr.Slider, {"minimum": 0, "maximum": 64, "step": 1}, section=('upscaling', "Upscaling")).info("Low values = visible seam"))
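Both the removed `tiled_inference` and the `upscaler_utils.upscale_2` call that replaces it rest on the same overlap-averaging idea: output patches are accumulated into a sum buffer `E` and a hit-count buffer `W`, and `E / W` averages every pixel over all tiles that covered it, which hides tile seams. A condensed standalone sketch of that idea (hypothetical code; assumes a scale-1 model and a tile no larger than the image):

```python
# Condensed sketch of overlap-averaged tiling: overlapping patches are summed
# into E, per-pixel hit counts into W, and E / W averages the overlaps away.
import torch

def tiled_apply(img, model, tile=64, overlap=8):
    _, _, h, w = img.shape
    stride = tile - overlap
    E = torch.zeros_like(img)
    W = torch.zeros_like(img)
    for hi in list(range(0, h - tile, stride)) + [h - tile]:
        for wi in list(range(0, w - tile, stride)) + [w - tile]:
            patch = model(img[..., hi:hi + tile, wi:wi + tile])
            E[..., hi:hi + tile, wi:wi + tile] += patch
            W[..., hi:hi + tile, wi:wi + tile] += 1
    return E / W

# e.g. tiled_apply(torch.randn(1, 3, 128, 128), model=lambda t: t)
```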
+import logging
 import sys
-import platform

-import numpy as np
 import torch
 from PIL import Image
-from tqdm import tqdm

-from modules import modelloader, devices, script_callbacks, shared
-from modules.shared import opts, state
-from swinir_model_arch import SwinIR
-from swinir_model_arch_v2 import Swin2SR
+from modules import devices, modelloader, script_callbacks, shared, upscaler_utils
 from modules.upscaler import Upscaler, UpscalerData

 SWINIR_MODEL_URL = "https://github.com/JingyunLiang/SwinIR/releases/download/v0.0/003_realSR_BSRGAN_DFOWMFC_s64w8_SwinIR-L_x4_GAN.pth"

-device_swinir = devices.get_device_for('swinir')
+logger = logging.getLogger(__name__)

 class UpscalerSwinIR(Upscaler):
@@ -37,26 +32,28 @@ class UpscalerSwinIR(Upscaler):
             scalers.append(model_data)
         self.scalers = scalers

-    def do_upscale(self, img, model_file):
-        use_compile = hasattr(opts, 'SWIN_torch_compile') and opts.SWIN_torch_compile \
-            and int(torch.__version__.split('.')[0]) >= 2 and platform.system() != "Windows"
-        current_config = (model_file, opts.SWIN_tile)
+    def do_upscale(self, img: Image.Image, model_file: str) -> Image.Image:
+        current_config = (model_file, shared.opts.SWIN_tile)

-        if use_compile and self._cached_model_config == current_config:
+        if self._cached_model_config == current_config:
             model = self._cached_model
         else:
+            self._cached_model = None
             try:
                 model = self.load_model(model_file)
             except Exception as e:
                 print(f"Failed loading SwinIR model {model_file}: {e}", file=sys.stderr)
                 return img
-            model = model.to(device_swinir, dtype=devices.dtype)
-            if use_compile:
-                model = torch.compile(model)
-                self._cached_model = model
-                self._cached_model_config = current_config
-        img = upscale(img, model)
+            self._cached_model = model
+            self._cached_model_config = current_config
+
+        img = upscaler_utils.upscale_2(
+            img,
+            model,
+            tile_size=shared.opts.SWIN_tile,
+            tile_overlap=shared.opts.SWIN_tile_overlap,
+            scale=model.scale,
+            desc="SwinIR",
+        )
         devices.torch_gc()
         return img
@@ -69,115 +66,22 @@ class UpscalerSwinIR(Upscaler):
             )
         else:
             filename = path
-        if filename.endswith(".v2.pth"):
-            model = Swin2SR(
-                upscale=scale,
-                in_chans=3,
-                img_size=64,
-                window_size=8,
-                img_range=1.0,
-                depths=[6, 6, 6, 6, 6, 6],
-                embed_dim=180,
-                num_heads=[6, 6, 6, 6, 6, 6],
-                mlp_ratio=2,
-                upsampler="nearest+conv",
-                resi_connection="1conv",
-            )
-            params = None
-        else:
-            model = SwinIR(
-                upscale=scale,
-                in_chans=3,
-                img_size=64,
-                window_size=8,
-                img_range=1.0,
-                depths=[6, 6, 6, 6, 6, 6, 6, 6, 6],
-                embed_dim=240,
-                num_heads=[8, 8, 8, 8, 8, 8, 8, 8, 8],
-                mlp_ratio=2,
-                upsampler="nearest+conv",
-                resi_connection="3conv",
-            )
-            params = "params_ema"
-
-        pretrained_model = torch.load(filename)
-        if params is not None:
-            model.load_state_dict(pretrained_model[params], strict=True)
-        else:
-            model.load_state_dict(pretrained_model, strict=True)
-        return model
-
-
-def upscale(
-    img,
-    model,
-    tile=None,
-    tile_overlap=None,
-    window_size=8,
-    scale=4,
-):
-    tile = tile or opts.SWIN_tile
-    tile_overlap = tile_overlap or opts.SWIN_tile_overlap
-    img = np.array(img)
-    img = img[:, :, ::-1]
-    img = np.moveaxis(img, 2, 0) / 255
-    img = torch.from_numpy(img).float()
-    img = img.unsqueeze(0).to(device_swinir, dtype=devices.dtype)
-    with torch.no_grad(), devices.autocast():
-        _, _, h_old, w_old = img.size()
-        h_pad = (h_old // window_size + 1) * window_size - h_old
-        w_pad = (w_old // window_size + 1) * window_size - w_old
-        img = torch.cat([img, torch.flip(img, [2])], 2)[:, :, : h_old + h_pad, :]
-        img = torch.cat([img, torch.flip(img, [3])], 3)[:, :, :, : w_old + w_pad]
-        output = inference(img, model, tile, tile_overlap, window_size, scale)
-        output = output[..., : h_old * scale, : w_old * scale]
-        output = output.data.squeeze().float().cpu().clamp_(0, 1).numpy()
-        if output.ndim == 3:
-            output = np.transpose(
-                output[[2, 1, 0], :, :], (1, 2, 0)
-            )  # CHW-RGB to HCW-BGR
-        output = (output * 255.0).round().astype(np.uint8)  # float32 to uint8
-        return Image.fromarray(output, "RGB")
-
-
-def inference(img, model, tile, tile_overlap, window_size, scale):
-    # test the image tile by tile
-    b, c, h, w = img.size()
-    tile = min(tile, h, w)
-    assert tile % window_size == 0, "tile size should be a multiple of window_size"
-    sf = scale
-
-    stride = tile - tile_overlap
-    h_idx_list = list(range(0, h - tile, stride)) + [h - tile]
-    w_idx_list = list(range(0, w - tile, stride)) + [w - tile]
-    E = torch.zeros(b, c, h * sf, w * sf, dtype=devices.dtype, device=device_swinir).type_as(img)
-    W = torch.zeros_like(E, dtype=devices.dtype, device=device_swinir)
-
-    with tqdm(total=len(h_idx_list) * len(w_idx_list), desc="SwinIR tiles") as pbar:
-        for h_idx in h_idx_list:
-            if state.interrupted or state.skipped:
-                break
-
-            for w_idx in w_idx_list:
-                if state.interrupted or state.skipped:
-                    break
-
-                in_patch = img[..., h_idx: h_idx + tile, w_idx: w_idx + tile]
-                out_patch = model(in_patch)
-                out_patch_mask = torch.ones_like(out_patch)
-
-                E[
-                    ..., h_idx * sf: (h_idx + tile) * sf, w_idx * sf: (w_idx + tile) * sf
-                ].add_(out_patch)
-                W[
-                    ..., h_idx * sf: (h_idx + tile) * sf, w_idx * sf: (w_idx + tile) * sf
-                ].add_(out_patch_mask)
-                pbar.update(1)
-    output = E.div_(W)
-
-    return output
+        model_descriptor = modelloader.load_spandrel_model(
+            filename,
+            device=self._get_device(),
+            prefer_half=(devices.dtype == torch.float16),
+            expected_architecture="SwinIR",
+        )
+        if getattr(shared.opts, 'SWIN_torch_compile', False):
+            try:
+                model_descriptor.model.compile()
+            except Exception:
+                logger.warning("Failed to compile SwinIR model, fallback to JIT", exc_info=True)
+        return model_descriptor
+
+    def _get_device(self):
+        return devices.get_device_for('swinir')

 def on_ui_settings():
@@ -185,8 +89,7 @@ def on_ui_settings():
     shared.opts.add_option("SWIN_tile", shared.OptionInfo(192, "Tile size for all SwinIR.", gr.Slider, {"minimum": 16, "maximum": 512, "step": 16}, section=('upscaling', "Upscaling")))
     shared.opts.add_option("SWIN_tile_overlap", shared.OptionInfo(8, "Tile overlap, in pixels for SwinIR. Low values = visible seam.", gr.Slider, {"minimum": 0, "maximum": 48, "step": 1}, section=('upscaling', "Upscaling")))
-    if int(torch.__version__.split('.')[0]) >= 2 and platform.system() != "Windows":  # torch.compile() require pytorch 2.0 or above, and not on Windows
-        shared.opts.add_option("SWIN_torch_compile", shared.OptionInfo(False, "Use torch.compile to accelerate SwinIR.", gr.Checkbox, {"interactive": True}, section=('upscaling', "Upscaling")).info("Takes longer on first run"))
+    shared.opts.add_option("SWIN_torch_compile", shared.OptionInfo(False, "Use torch.compile to accelerate SwinIR.", gr.Checkbox, {"interactive": True}, section=('upscaling', "Upscaling")).info("Takes longer on first run"))

 script_callbacks.on_ui_settings(on_ui_settings)
@@ -218,6 +218,8 @@ onUiLoaded(async() => {
         canvas_hotkey_fullscreen: "KeyS",
         canvas_hotkey_move: "KeyF",
         canvas_hotkey_overlap: "KeyO",
+        canvas_hotkey_shrink_brush: "KeyQ",
+        canvas_hotkey_grow_brush: "KeyW",
         canvas_disabled_functions: [],
         canvas_show_tooltip: true,
         canvas_auto_expand: true,
@@ -227,6 +229,8 @@ onUiLoaded(async() => {
     const functionMap = {
         "Zoom": "canvas_hotkey_zoom",
         "Adjust brush size": "canvas_hotkey_adjust",
+        "Hotkey shrink brush": "canvas_hotkey_shrink_brush",
+        "Hotkey enlarge brush": "canvas_hotkey_grow_brush",
         "Moving canvas": "canvas_hotkey_move",
         "Fullscreen": "canvas_hotkey_fullscreen",
         "Reset Zoom": "canvas_hotkey_reset",
@@ -686,7 +690,9 @@ onUiLoaded(async() => {
         const hotkeyActions = {
             [hotkeysConfig.canvas_hotkey_reset]: resetZoom,
             [hotkeysConfig.canvas_hotkey_overlap]: toggleOverlap,
-            [hotkeysConfig.canvas_hotkey_fullscreen]: fitToScreen
+            [hotkeysConfig.canvas_hotkey_fullscreen]: fitToScreen,
+            [hotkeysConfig.canvas_hotkey_shrink_brush]: () => adjustBrushSize(elemId, 10),
+            [hotkeysConfig.canvas_hotkey_grow_brush]: () => adjustBrushSize(elemId, -10)
         };

         const action = hotkeyActions[event.code];
@@ -4,6 +4,8 @@ from modules import shared
 shared.options_templates.update(shared.options_section(('canvas_hotkey', "Canvas Hotkeys"), {
     "canvas_hotkey_zoom": shared.OptionInfo("Alt", "Zoom canvas", gr.Radio, {"choices": ["Shift","Ctrl", "Alt"]}).info("If you choose 'Shift' you cannot scroll horizontally, 'Alt' can cause a little trouble in firefox"),
     "canvas_hotkey_adjust": shared.OptionInfo("Ctrl", "Adjust brush size", gr.Radio, {"choices": ["Shift","Ctrl", "Alt"]}).info("If you choose 'Shift' you cannot scroll horizontally, 'Alt' can cause a little trouble in firefox"),
+    "canvas_hotkey_shrink_brush": shared.OptionInfo("Q", "Shrink the brush size"),
+    "canvas_hotkey_grow_brush": shared.OptionInfo("W", "Enlarge the brush size"),
     "canvas_hotkey_move": shared.OptionInfo("F", "Moving the canvas").info("To work correctly in firefox, turn off 'Automatically search the page text when typing' in the browser settings"),
     "canvas_hotkey_fullscreen": shared.OptionInfo("S", "Fullscreen Mode, maximizes the picture so that it fits into the screen and stretches it to its full width "),
     "canvas_hotkey_reset": shared.OptionInfo("R", "Reset zoom and canvas positon"),
@@ -11,5 +13,5 @@ shared.options_templates.update(shared.options_section(('canvas_hotkey', "Canvas
     "canvas_show_tooltip": shared.OptionInfo(True, "Enable tooltip on the canvas"),
     "canvas_auto_expand": shared.OptionInfo(True, "Automatically expands an image that does not fit completely in the canvas area, similar to manually pressing the S and R buttons"),
     "canvas_blur_prompt": shared.OptionInfo(False, "Take the focus off the prompt when working with a canvas"),
-    "canvas_disabled_functions": shared.OptionInfo(["Overlap"], "Disable function that you don't use", gr.CheckboxGroup, {"choices": ["Zoom","Adjust brush size", "Moving canvas","Fullscreen","Reset Zoom","Overlap"]}),
+    "canvas_disabled_functions": shared.OptionInfo(["Overlap"], "Disable function that you don't use", gr.CheckboxGroup, {"choices": ["Zoom","Adjust brush size","Hotkey enlarge brush","Hotkey shrink brush","Moving canvas","Fullscreen","Reset Zoom","Overlap"]}),
 }))
 import math
 import gradio as gr
-from modules import scripts, shared, ui_components, ui_settings, generation_parameters_copypaste
+from modules import scripts, shared, ui_components, ui_settings, infotext_utils
 from modules.ui_components import FormColumn
@@ -25,7 +25,7 @@ class ExtraOptionsSection(scripts.Script):
             extra_options = shared.opts.extra_options_img2img if is_img2img else shared.opts.extra_options_txt2img
             elem_id_tabname = "extra_options_" + ("img2img" if is_img2img else "txt2img")

-            mapping = {k: v for v, k in generation_parameters_copypaste.infotext_to_setting_name_mapping}
+            mapping = {k: v for v, k in infotext_utils.infotext_to_setting_name_mapping}

             with gr.Blocks() as interface:
                 with gr.Accordion("Options", open=False, elem_id=elem_id_tabname) if shared.opts.extra_options_accordion and extra_options else gr.Group(elem_id=elem_id_tabname):
-<div class='card' style={style} onclick={card_clicked} data-name="{name}" {sort_keys}>
+<div class="card" style="{style}" onclick="{card_clicked}" data-name="{name}" {sort_keys}>
     {background_image}
-    <div class="button-row">
-        {metadata_button}
-        {edit_button}
-    </div>
-    <div class='actions'>
-        <div class='additional'>
-            <span style="display:none" class='search_term{search_only}'>{search_term}</span>
-        </div>
-        <span class='name'>{name}</span>
-        <span class='description'>{description}</span>
+    <div class="button-row">{copy_path_button}{metadata_button}{edit_button}</div>
+    <div class="actions">
+        <div class="additional">{search_terms}</div>
+        <span class="name">{name}</span>
+        <span class="description">{description}</span>
     </div>
 </div>
<div class="copy-path-button card-button"
title="Copy path to clipboard"
onclick="extraNetworksCopyCardPath(event, '{filename}')"
data-clipboard-text="{filename}">
</div>
\ No newline at end of file
<div class="edit-button card-button"
title="Edit metadata"
onclick="extraNetworksEditUserMetadata(event, '{tabname}', '{extra_networks_tabname}', '{name}')">
</div>
\ No newline at end of file
<div class="metadata-button card-button"
title="Show internal metadata"
onclick="extraNetworksRequestMetadata(event, '{extra_networks_tabname}', '{name}')">
</div>
\ No newline at end of file
<div id='{tabname}_{extra_networks_tabname}_pane' class='extra-network-pane'>
<div class="extra-network-control" id="{tabname}_{extra_networks_tabname}_controls" style="display:none" >
<div class="extra-network-control--search">
<input
id="{tabname}_{extra_networks_tabname}_extra_search"
class="extra-network-control--search-text"
type="search"
placeholder="Filter files"
>
</div>
<div
id="{tabname}_{extra_networks_tabname}_extra_sort"
class="extra-network-control--sort"
data-sortmode="{data_sortmode}"
data-sortkey="{data_sortkey}"
title="Sort by path"
onclick="extraNetworksControlSortOnClick(event, '{tabname}', '{extra_networks_tabname}');"
>
<i class="extra-network-control--sort-icon"></i>
</div>
<div
id="{tabname}_{extra_networks_tabname}_extra_sort_dir"
class="extra-network-control--sort-dir"
data-sortdir="{data_sortdir}"
title="Sort ascending"
onclick="extraNetworksControlSortDirOnClick(event, '{tabname}', '{extra_networks_tabname}');"
>
<i class="extra-network-control--sort-dir-icon"></i>
</div>
<div
id="{tabname}_{extra_networks_tabname}_extra_tree_view"
class="extra-network-control--tree-view {tree_view_btn_extra_class}"
title="Enable Tree View"
onclick="extraNetworksControlTreeViewOnClick(event, '{tabname}', '{extra_networks_tabname}');"
>
<i class="extra-network-control--tree-view-icon"></i>
</div>
<div
id="{tabname}_{extra_networks_tabname}_extra_refresh"
class="extra-network-control--refresh"
title="Refresh page"
onclick="extraNetworksControlRefreshOnClick(event, '{tabname}', '{extra_networks_tabname}');"
>
<i class="extra-network-control--refresh-icon"></i>
</div>
</div>
<div class="extra-network-pane-content">
<div id='{tabname}_{extra_networks_tabname}_tree' class='extra-network-tree {tree_view_div_extra_class}'>
{tree_html}
</div>
<div id='{tabname}_{extra_networks_tabname}_cards' class='extra-network-cards'>
{items_html}
</div>
</div>
</div>
\ No newline at end of file
<span data-filterable-item-text hidden>{search_terms}</span>
<div class="tree-list-content {subclass}"
type="button"
onclick="extraNetworksTreeOnClick(event, '{tabname}', '{extra_networks_tabname}');{onclick_extra}"
data-path="{data_path}"
data-hash="{data_hash}"
>
<span class='tree-list-item-action tree-list-item-action--leading'>
{action_list_item_action_leading}
</span>
<span class="tree-list-item-visual tree-list-item-visual--leading">
{action_list_item_visual_leading}
</span>
<span class="tree-list-item-label tree-list-item-label--truncate">
{action_list_item_label}
</span>
<span class="tree-list-item-visual tree-list-item-visual--trailing">
{action_list_item_visual_trailing}
</span>
<span class="tree-list-item-action tree-list-item-action--trailing">
{action_list_item_action_trailing}
</span>
</div>
\ No newline at end of file
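These new `.html` files are plain templates whose `{placeholder}` fields are presumably filled server-side with Python's `str.format`, as elsewhere in the project. A sketch for the pane template (the path and sample values are illustrative):

```python
from pathlib import Path

template = Path("html/extra-networks-pane.html").read_text(encoding="utf8")
html = template.format(
    tabname="txt2img",
    extra_networks_tabname="lora",
    data_sortmode="alphanumeric",
    data_sortkey="name",
    data_sortdir="Ascending",
    tree_view_btn_extra_class="",
    tree_view_div_extra_class="",
    tree_html="",   # prerendered tree markup goes here
    items_html="",  # prerendered card markup goes here
)
```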
@@ -2,8 +2,11 @@
 function extensions_apply(_disabled_list, _update_list, disable_all) {
     var disable = [];
     var update = [];
+    const extensions_input = gradioApp().querySelectorAll('#extensions input[type="checkbox"]');
+    if (extensions_input.length == 0) {
+        throw Error("Extensions page not yet loaded.");
+    }
-    gradioApp().querySelectorAll('#extensions input[type="checkbox"]').forEach(function(x) {
+    extensions_input.forEach(function(x) {
         if (x.name.startsWith("enable_") && !x.checked) {
             disable.push(x.name.substring(7));
         }
...
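`substring(7)` strips the `"enable_"` prefix — 7 is exactly its length. The same slicing, spelled out (the helper name is hypothetical):

```python
PREFIX = "enable_"
assert len(PREFIX) == 7  # the magic number behind substring(7)

def extension_name(checkbox_name: str) -> str | None:
    # Only "enable_*" checkboxes name extensions; anything else is ignored.
    if checkbox_name.startswith(PREFIX):
        return checkbox_name[len(PREFIX):]
    return None

assert extension_name("enable_some-extension") == "some-extension"
```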
@@ -45,8 +45,15 @@ function formatTime(secs) {
     }
 }
 
+var originalAppTitle = undefined;
+
+onUiLoaded(function() {
+    originalAppTitle = document.title;
+});
+
 function setTitle(progress) {
-    var title = 'Stable Diffusion';
+    var title = originalAppTitle;
 
     if (opts.show_progress_in_title && progress) {
         title = '[' + progress.trim() + '] ' + title;
...
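The fix records whatever title the page actually loaded with instead of assuming it is "Stable Diffusion". The formatting rule itself, restated in Python to pin the behavior down (the function name is hypothetical):

```python
def format_title(original: str, progress: str, show_in_title: bool) -> str:
    # Prefix "[progress]" only when the option is on and progress text exists.
    if show_in_title and progress:
        return f"[{progress.strip()}] {original}"
    return original

assert format_title("my custom tab", "25%", True) == "[25%] my custom tab"
assert format_title("my custom tab", "", True) == "my custom tab"
```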
 (function() {
     const GRADIO_MIN_WIDTH = 320;
-    const GRID_TEMPLATE_COLUMNS = '1fr 16px 1fr';
     const PAD = 16;
     const DEBOUNCE_TIME = 100;
+    const DOUBLE_TAP_DELAY = 200; //ms
 
     const R = {
         tracking: false,
@@ -11,6 +11,7 @@
         leftCol: null,
         leftColStartWidth: null,
         screenX: null,
+        lastTapTime: null,
     };
 
     let resizeTimer;
@@ -23,21 +24,17 @@
     function displayResizeHandle(parent) {
         if (window.innerWidth < GRADIO_MIN_WIDTH * 2 + PAD * 4) {
             parent.style.display = 'flex';
-            if (R.handle != null) {
-                R.handle.style.opacity = '0';
-            }
+            parent.resizeHandle.style.display = "none";
             return false;
         } else {
             parent.style.display = 'grid';
-            if (R.handle != null) {
-                R.handle.style.opacity = '100';
-            }
+            parent.resizeHandle.style.display = "block";
             return true;
         }
     }
 
     function afterResize(parent) {
-        if (displayResizeHandle(parent) && parent.style.gridTemplateColumns != GRID_TEMPLATE_COLUMNS) {
+        if (displayResizeHandle(parent) && parent.style.gridTemplateColumns != parent.style.originalGridTemplateColumns) {
             const oldParentWidth = R.parentWidth;
             const newParentWidth = parent.offsetWidth;
             const widthL = parseInt(parent.style.gridTemplateColumns.split(' ')[0]);
@@ -52,6 +49,14 @@
     }
 
     function setup(parent) {
+        function onDoubleClick(evt) {
+            evt.preventDefault();
+            evt.stopPropagation();
+
+            parent.style.gridTemplateColumns = parent.style.originalGridTemplateColumns;
+        }
+
         const leftCol = parent.firstElementChild;
         const rightCol = parent.lastElementChild;
@@ -59,63 +64,97 @@
         parent.style.display = 'grid';
         parent.style.gap = '0';
-        parent.style.gridTemplateColumns = GRID_TEMPLATE_COLUMNS;
+        const gridTemplateColumns = `${parent.children[0].style.flexGrow}fr ${PAD}px ${parent.children[1].style.flexGrow}fr`;
+        parent.style.gridTemplateColumns = gridTemplateColumns;
+        parent.style.originalGridTemplateColumns = gridTemplateColumns;
 
         const resizeHandle = document.createElement('div');
         resizeHandle.classList.add('resize-handle');
         parent.insertBefore(resizeHandle, rightCol);
+        parent.resizeHandle = resizeHandle;
 
-        resizeHandle.addEventListener('mousedown', (evt) => {
-            if (evt.button !== 0) return;
-
-            evt.preventDefault();
-            evt.stopPropagation();
-
-            document.body.classList.add('resizing');
-
-            R.tracking = true;
-            R.parent = parent;
-            R.parentWidth = parent.offsetWidth;
-            R.handle = resizeHandle;
-            R.leftCol = leftCol;
-            R.leftColStartWidth = leftCol.offsetWidth;
-            R.screenX = evt.screenX;
+        ['mousedown', 'touchstart'].forEach((eventType) => {
+            resizeHandle.addEventListener(eventType, (evt) => {
+                if (eventType.startsWith('mouse')) {
+                    if (evt.button !== 0) return;
+                } else {
+                    if (evt.changedTouches.length !== 1) return;
+
+                    const currentTime = new Date().getTime();
+                    if (R.lastTapTime && currentTime - R.lastTapTime <= DOUBLE_TAP_DELAY) {
+                        onDoubleClick(evt);
+                        return;
+                    }
+
+                    R.lastTapTime = currentTime;
+                }
+
+                evt.preventDefault();
+                evt.stopPropagation();
+
+                document.body.classList.add('resizing');
+
+                R.tracking = true;
+                R.parent = parent;
+                R.parentWidth = parent.offsetWidth;
+                R.leftCol = leftCol;
+                R.leftColStartWidth = leftCol.offsetWidth;
+                if (eventType.startsWith('mouse')) {
+                    R.screenX = evt.screenX;
+                } else {
+                    R.screenX = evt.changedTouches[0].screenX;
+                }
+            });
         });
 
-        resizeHandle.addEventListener('dblclick', (evt) => {
-            evt.preventDefault();
-            evt.stopPropagation();
-
-            parent.style.gridTemplateColumns = GRID_TEMPLATE_COLUMNS;
-        });
+        resizeHandle.addEventListener('dblclick', onDoubleClick);
 
         afterResize(parent);
     }
 
-    window.addEventListener('mousemove', (evt) => {
-        if (evt.button !== 0) return;
-
-        if (R.tracking) {
-            evt.preventDefault();
-            evt.stopPropagation();
-
-            const delta = R.screenX - evt.screenX;
-            const leftColWidth = Math.max(Math.min(R.leftColStartWidth - delta, R.parent.offsetWidth - GRADIO_MIN_WIDTH - PAD), GRADIO_MIN_WIDTH);
-            setLeftColGridTemplate(R.parent, leftColWidth);
-        }
-    });
+    ['mousemove', 'touchmove'].forEach((eventType) => {
+        window.addEventListener(eventType, (evt) => {
+            if (eventType.startsWith('mouse')) {
+                if (evt.button !== 0) return;
+            } else {
+                if (evt.changedTouches.length !== 1) return;
+            }
+
+            if (R.tracking) {
+                if (eventType.startsWith('mouse')) {
+                    evt.preventDefault();
+                }
+                evt.stopPropagation();
+
+                let delta = 0;
+                if (eventType.startsWith('mouse')) {
+                    delta = R.screenX - evt.screenX;
+                } else {
+                    delta = R.screenX - evt.changedTouches[0].screenX;
+                }
+                const leftColWidth = Math.max(Math.min(R.leftColStartWidth - delta, R.parent.offsetWidth - GRADIO_MIN_WIDTH - PAD), GRADIO_MIN_WIDTH);
+                setLeftColGridTemplate(R.parent, leftColWidth);
+            }
+        });
+    });
 
-    window.addEventListener('mouseup', (evt) => {
-        if (evt.button !== 0) return;
-
-        if (R.tracking) {
-            evt.preventDefault();
-            evt.stopPropagation();
-
-            R.tracking = false;
-            document.body.classList.remove('resizing');
-        }
-    });
+    ['mouseup', 'touchend'].forEach((eventType) => {
+        window.addEventListener(eventType, (evt) => {
+            if (eventType.startsWith('mouse')) {
+                if (evt.button !== 0) return;
+            } else {
+                if (evt.changedTouches.length !== 1) return;
+            }
+
+            if (R.tracking) {
+                evt.preventDefault();
+                evt.stopPropagation();
+
+                R.tracking = false;
+                document.body.classList.remove('resizing');
+            }
+        });
+    });
...
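The `touchstart` branch adds double-tap detection with a 200 ms window (`DOUBLE_TAP_DELAY`). The timing logic, isolated as a small Python sketch — a variant that resets its state after a successful double tap; the class name is hypothetical:

```python
import time

DOUBLE_TAP_DELAY = 0.2  # seconds, mirroring the 200 ms constant above

class DoubleTapDetector:
    def __init__(self, delay: float = DOUBLE_TAP_DELAY):
        self.delay = delay
        self.last_tap = None

    def tap(self) -> bool:
        """Record a tap; return True when it completes a double tap."""
        now = time.monotonic()
        is_double = self.last_tap is not None and now - self.last_tap <= self.delay
        self.last_tap = None if is_double else now
        return is_double

d = DoubleTapDetector()
d.tap()         # first tap: False
print(d.tap())  # True if called again within 0.2 s
```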
@@ -55,8 +55,8 @@ onOptionsChanged(function() {
     });
 
     opts._categories.forEach(function(x) {
-        var section = x[0];
-        var category = x[1];
+        var section = localization[x[0]] ?? x[0];
+        var category = localization[x[1]] ?? x[1];
 
         var span = document.createElement('SPAN');
         span.textContent = category;
...
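JavaScript's `??` takes the right operand only when the left is null or undefined; for a lookup table, Python's `dict.get` with a default is the equivalent:

```python
localization = {"Saving images": "Speichern von Bildern"}  # illustrative entry

key = "Saving images"
section = localization.get(key, key)  # translated if known, the raw key otherwise
assert section == "Speichern von Bildern"
assert localization.get("Untranslated", "Untranslated") == "Untranslated"
```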
@@ -48,11 +48,6 @@ function setupTokenCounting(id, id_counter, id_button) {
     var counter = gradioApp().getElementById(id_counter);
     var textarea = gradioApp().querySelector(`#${id} > label > textarea`);
 
-    if (opts.disable_token_counters) {
-        counter.style.display = "none";
-        return;
-    }
-
     if (counter.parentElement == prompt.parentElement) {
         return;
     }
@@ -61,15 +56,32 @@ function setupTokenCounting(id, id_counter, id_button) {
     prompt.parentElement.style.position = "relative";
 
     var func = onEdit(id, textarea, 800, function() {
-        gradioApp().getElementById(id_button)?.click();
+        if (counter.classList.contains("token-counter-visible")) {
+            gradioApp().getElementById(id_button)?.click();
+        }
     });
     promptTokenCountUpdateFunctions[id] = func;
     promptTokenCountUpdateFunctions[id_button] = func;
 }
 
-function setupTokenCounters() {
-    setupTokenCounting('txt2img_prompt', 'txt2img_token_counter', 'txt2img_token_button');
-    setupTokenCounting('txt2img_neg_prompt', 'txt2img_negative_token_counter', 'txt2img_negative_token_button');
-    setupTokenCounting('img2img_prompt', 'img2img_token_counter', 'img2img_token_button');
-    setupTokenCounting('img2img_neg_prompt', 'img2img_negative_token_counter', 'img2img_negative_token_button');
+function toggleTokenCountingVisibility(id, id_counter, id_button) {
+    var counter = gradioApp().getElementById(id_counter);
+
+    counter.style.display = opts.disable_token_counters ? "none" : "block";
+    counter.classList.toggle("token-counter-visible", !opts.disable_token_counters);
 }
+
+function runCodeForTokenCounters(fun) {
+    fun('txt2img_prompt', 'txt2img_token_counter', 'txt2img_token_button');
+    fun('txt2img_neg_prompt', 'txt2img_negative_token_counter', 'txt2img_negative_token_button');
+    fun('img2img_prompt', 'img2img_token_counter', 'img2img_token_button');
+    fun('img2img_neg_prompt', 'img2img_negative_token_counter', 'img2img_negative_token_button');
+}
+
+onUiLoaded(function() {
+    runCodeForTokenCounters(setupTokenCounting);
+});
+
+onOptionsChanged(function() {
+    runCodeForTokenCounters(toggleTokenCountingVisibility);
+});
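`onEdit(id, textarea, 800, ...)` debounces typing before clicking the hidden count button. A rough Python analogue of that debounce, assuming a plain timer-reset scheme is all that is needed:

```python
import threading

def debounce(wait_s: float):
    """Delay calls to fn until wait_s seconds pass without a new call."""
    def decorator(fn):
        timer = None
        def wrapped(*args, **kwargs):
            nonlocal timer
            if timer is not None:
                timer.cancel()  # a newer call supersedes the pending one
            timer = threading.Timer(wait_s, fn, args, kwargs)
            timer.start()
        return wrapped
    return decorator

@debounce(0.8)  # 800 ms, matching the delay above
def update_token_count():
    print("recount tokens")
```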
@@ -119,9 +119,18 @@ function create_submit_args(args) {
     return res;
 }
 
+function setSubmitButtonsVisibility(tabname, showInterrupt, showSkip, showInterrupting) {
+    gradioApp().getElementById(tabname + '_interrupt').style.display = showInterrupt ? "block" : "none";
+    gradioApp().getElementById(tabname + '_skip').style.display = showSkip ? "block" : "none";
+    gradioApp().getElementById(tabname + '_interrupting').style.display = showInterrupting ? "block" : "none";
+}
+
 function showSubmitButtons(tabname, show) {
-    gradioApp().getElementById(tabname + '_interrupt').style.display = show ? "none" : "block";
-    gradioApp().getElementById(tabname + '_skip').style.display = show ? "none" : "block";
+    setSubmitButtonsVisibility(tabname, !show, !show, false);
+}
+
+function showSubmitInterruptingPlaceholder(tabname) {
+    setSubmitButtonsVisibility(tabname, false, true, true);
 }
 
 function showRestoreProgressButton(tabname, show) {
@@ -150,6 +159,14 @@ function submit() {
     return res;
 }
 
+function submit_txt2img_upscale() {
+    var res = submit(...arguments);
+
+    res[2] = selected_gallery_index();
+
+    return res;
+}
+
 function submit_img2img() {
     showSubmitButtons('img2img', false);
@@ -302,8 +319,6 @@ onAfterUiUpdate(function() {
     });
 
     json_elem.parentElement.style.display = "none";
-
-    setupTokenCounters();
 });
 
 onOptionsChanged(function() {
...
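Routing all three buttons through one helper makes the UI states explicit. Restated as data (the state names are mine, not the code's):

```python
# showSubmitButtons(tab, True)            -> idle: everything hidden
# showSubmitButtons(tab, False)           -> generating: interrupt + skip shown
# showSubmitInterruptingPlaceholder(tab)  -> after interrupt: placeholder + skip
STATES = {
    "idle":         dict(interrupt=False, skip=False, interrupting=False),
    "generating":   dict(interrupt=True,  skip=True,  interrupting=False),
    "interrupting": dict(interrupt=False, skip=True,  interrupting=True),
}
```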
@@ -107,6 +107,8 @@ StableDiffusionTxt2ImgProcessingAPI = PydanticModelGenerator(
         {"key": "send_images", "type": bool, "default": True},
         {"key": "save_images", "type": bool, "default": False},
         {"key": "alwayson_scripts", "type": dict, "default": {}},
+        {"key": "force_task_id", "type": str, "default": None},
+        {"key": "infotext", "type": str, "default": None},
     ]
 ).generate_model()
@@ -124,6 +126,8 @@ StableDiffusionImg2ImgProcessingAPI = PydanticModelGenerator(
         {"key": "send_images", "type": bool, "default": True},
         {"key": "save_images", "type": bool, "default": False},
         {"key": "alwayson_scripts", "type": dict, "default": {}},
+        {"key": "force_task_id", "type": str, "default": None},
+        {"key": "infotext", "type": str, "default": None},
     ]
 ).generate_model()
...
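With these two fields in the request models, a client can pin the task id it will poll for progress. A hypothetical call against a local instance (endpoint and port are the webui defaults; the id value is made up):

```python
import requests

payload = {
    "prompt": "a photo of a cat",
    "steps": 20,
    "force_task_id": "task(my-client-0001)",  # id the client will use to track progress
    "infotext": None,  # optionally, a generation-parameters string to seed fields from
}

r = requests.post("http://127.0.0.1:7860/sdapi/v1/txt2img", json=payload, timeout=600)
r.raise_for_status()
images = r.json()["images"]  # base64-encoded results
```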
@@ -62,16 +62,15 @@ def cache(subsection):
     if cache_data is None:
         with cache_lock:
             if cache_data is None:
-                if not os.path.isfile(cache_filename):
-                    cache_data = {}
-                else:
-                    try:
-                        with open(cache_filename, "r", encoding="utf8") as file:
-                            cache_data = json.load(file)
-                    except Exception:
-                        os.replace(cache_filename, os.path.join(script_path, "tmp", "cache.json"))
-                        print('[ERROR] issue occurred while trying to read cache.json, move current cache to tmp/cache.json and create new cache')
-                        cache_data = {}
+                try:
+                    with open(cache_filename, "r", encoding="utf8") as file:
+                        cache_data = json.load(file)
+                except FileNotFoundError:
+                    cache_data = {}
+                except Exception:
+                    os.replace(cache_filename, os.path.join(script_path, "tmp", "cache.json"))
+                    print('[ERROR] issue occurred while trying to read cache.json, move current cache to tmp/cache.json and create new cache')
+                    cache_data = {}
 
     s = cache_data.get(subsection, {})
     cache_data[subsection] = s
...
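The rewrite replaces the `isfile()` pre-check with try/except. Besides being flatter, it closes the window where the file could vanish between the check and the open (a TOCTOU race). The pattern in isolation:

```python
import json

def load_json_or_empty(path: str) -> dict:
    # EAFP: attempt the read and handle the specific failure,
    # rather than checking existence first and racing against deletion.
    try:
        with open(path, "r", encoding="utf8") as file:
            return json.load(file)
    except FileNotFoundError:
        return {}
```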
@@ -78,6 +78,7 @@ def wrap_gradio_call(func, extra_outputs=None, add_stats=False):
         shared.state.skipped = False
         shared.state.interrupted = False
+        shared.state.stopping_generation = False
         shared.state.job_count = 0
 
         if not add_stats:
...
@@ -107,8 +107,8 @@ def check_versions():
     import torch
     import gradio
 
-    expected_torch_version = "2.0.0"
-    expected_xformers_version = "0.0.20"
+    expected_torch_version = "2.1.2"
+    expected_xformers_version = "0.0.23.post1"
     expected_gradio_version = "3.41.2"
 
     if version.parse(torch.__version__) < version.parse(expected_torch_version):
...
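The bumped pins rely on PEP 440 ordering as implemented by `packaging.version`, which the surrounding code imports as `version`. Two comparisons worth noting:

```python
from packaging import version

assert version.parse("2.1.2") > version.parse("2.0.0")
# ".post1" is a post-release and sorts above the base version:
assert version.parse("0.0.23.post1") > version.parse("0.0.23")
```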
@@ -21,7 +21,10 @@ def calculate_sha256(filename):
 def sha256_from_cache(filename, title, use_addnet_hash=False):
     hashes = cache("hashes-addnet") if use_addnet_hash else cache("hashes")
-    ondisk_mtime = os.path.getmtime(filename)
+    try:
+        ondisk_mtime = os.path.getmtime(filename)
+    except FileNotFoundError:
+        return None
 
     if title not in hashes:
         return None
...
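Same EAFP guard as in cache.py, here protecting the mtime lookup for files deleted after being cached. A sketch of the freshness rule the function enforces — the `"mtime"` key is an assumption about the cache-entry layout, not confirmed by this diff:

```python
import os

def cached_hash_is_fresh(entry: dict, filename: str) -> bool:
    # Trust the cached hash only if the file still exists and is unchanged.
    try:
        return os.path.getmtime(filename) <= entry["mtime"]  # "mtime" key assumed
    except FileNotFoundError:
        return False
```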