Commit ec37f8a4, authored by AUTOMATIC1111, committed by GitHub

Merge branch 'master' into features-to-readme

parents 003d2c7f 85cb5918
---
name: Bug report
about: Create a report to help us improve
title: ''
labels: bug-report
assignees: ''
---
**Describe the bug**
A clear and concise description of what the bug is.
**To Reproduce**
Steps to reproduce the behavior:
1. Go to '...'
2. Click on '....'
3. Scroll down to '....'
4. See error
**Expected behavior**
A clear and concise description of what you expected to happen.
**Screenshots**
If applicable, add screenshots to help explain your problem.
**Desktop (please complete the following information):**
- OS: [e.g. Windows, Linux]
- Browser [e.g. chrome, safari]
- Commit revision [looks like this: e68484500f76a33ba477d5a99340ab30451e557b; can be seen when launching webui.bat, or obtained manually by running `git rev-parse HEAD`]
**Additional context**
Add any other context about the problem here.
name: Bug Report
description: You think something is broken in the UI
title: "[Bug]: "
labels: ["bug-report"]
body:
- type: checkboxes
attributes:
label: Is there an existing issue for this?
description: Please search to see if an issue already exists for the bug you encountered, and that it hasn't been fixed in a recent build/commit.
options:
- label: I have searched the existing issues and checked the recent builds/commits
required: true
- type: markdown
attributes:
value: |
*Please fill this form with as much information as possible, don't forget to fill "What OS..." and "What browsers" and *provide screenshots if possible**
- type: textarea
id: what-did
attributes:
label: What happened?
description: Tell us what happened in a very clear and simple way
validations:
required: true
- type: textarea
id: steps
attributes:
label: Steps to reproduce the problem
description: Please provide us with precise step by step information on how to reproduce the bug
value: |
1. Go to ....
2. Press ....
3. ...
validations:
required: true
- type: textarea
id: what-should
attributes:
label: What should have happened?
description: tell what you think the normal behavior should be
validations:
required: true
- type: input
id: commit
attributes:
label: Commit where the problem happens
description: Which commit are you running ? (copy the **Commit hash** shown in the cmd/terminal when you launch the UI)
validations:
required: true
- type: dropdown
id: platforms
attributes:
label: What platforms do you use to access UI ?
multiple: true
options:
- Windows
- Linux
- MacOS
- iOS
- Android
- Other/Cloud
- type: dropdown
id: browsers
attributes:
label: What browsers do you use to access the UI ?
multiple: true
options:
- Mozilla Firefox
- Google Chrome
- Brave
- Apple Safari
- Microsoft Edge
- type: textarea
id: cmdargs
attributes:
label: Command Line Arguments
description: Are you using any launching parameters/command line arguments (modified webui-user.py) ? If yes, please write them below
render: Shell
- type: textarea
id: misc
attributes:
label: Additional information, context and logs
description: Please provide us with any relevant additional info, context or log output.
blank_issues_enabled: false
contact_links:
- name: WebUI Community Support
url: https://github.com/AUTOMATIC1111/stable-diffusion-webui/discussions
about: Please ask and answer questions here.
---
name: Feature request
about: Suggest an idea for this project
title: ''
labels: 'suggestion'
assignees: ''
---
**Is your feature request related to a problem? Please describe.**
A clear and concise description of what the problem is. Ex. I'm always frustrated when [...]
**Describe the solution you'd like**
A clear and concise description of what you want to happen.
**Describe alternatives you've considered**
A clear and concise description of any alternative solutions or features you've considered.
**Additional context**
Add any other context or screenshots about the feature request here.
name: Feature request
description: Suggest an idea for this project
title: "[Feature Request]: "
labels: ["suggestion"]
body:
- type: checkboxes
attributes:
label: Is there an existing issue for this?
description: Please search to see if an issue already exists for the feature you want, and that it's not implemented in a recent build/commit.
options:
- label: I have searched the existing issues and checked the recent builds/commits
required: true
- type: markdown
attributes:
value: |
*Please fill this form with as much information as possible, provide screenshots and/or illustrations of the feature if possible*
- type: textarea
id: feature
attributes:
label: What would your feature do ?
description: Tell us about your feature in a very clear and simple way, and what problem it would solve
validations:
required: true
- type: textarea
id: workflow
attributes:
label: Proposed workflow
description: Please provide us with step by step information on how you'd like the feature to be accessed and used
value: |
1. Go to ....
2. Press ....
3. ...
validations:
required: true
- type: textarea
id: misc
attributes:
label: Additional information
description: Add any other context or screenshots about the feature request here.
...@@ -27,3 +27,4 @@ __pycache__ ...@@ -27,3 +27,4 @@ __pycache__
notification.mp3 notification.mp3
/SwinIR /SwinIR
/textual_inversion /textual_inversion
.vscode
\ No newline at end of file
...@@ -82,8 +82,8 @@ Check the [custom scripts](https://github.com/AUTOMATIC1111/stable-diffusion-web ...@@ -82,8 +82,8 @@ Check the [custom scripts](https://github.com/AUTOMATIC1111/stable-diffusion-web
- Use VAEs - Use VAEs
- Estimated completion time in progress bar - Estimated completion time in progress bar
- API - API
- Support for dedicated inpainting model by RunwayML. - Support for dedicated [inpainting model](https://github.com/runwayml/stable-diffusion#inpainting-with-stable-diffusion) by RunwayML.
- Aesthetic Gradients, a way to generate images with a specific aesthetic by using clip images embds (implementation of [https://github.com/vicgalle/stable-diffusion-aesthetic-gradients](https://github.com/vicgalle/stable-diffusion-aesthetic-gradients))
## Installation and Running ## Installation and Running
......
...@@ -523,7 +523,6 @@ Affandi,0.7170285,nudity ...@@ -523,7 +523,6 @@ Affandi,0.7170285,nudity
Diane Arbus,0.655138,digipa-high-impact Diane Arbus,0.655138,digipa-high-impact
Joseph Ducreux,0.65247905,digipa-high-impact Joseph Ducreux,0.65247905,digipa-high-impact
Berthe Morisot,0.7165984,fineart Berthe Morisot,0.7165984,fineart
Hilma AF Klint,0.71643853,scribbles
Hilma af Klint,0.71643853,scribbles Hilma af Klint,0.71643853,scribbles
Filippino Lippi,0.7163017,fineart Filippino Lippi,0.7163017,fineart
Leonid Afremov,0.7163005,fineart Leonid Afremov,0.7163005,fineart
...@@ -738,14 +737,12 @@ Abraham Mignon,0.60605425,fineart ...@@ -738,14 +737,12 @@ Abraham Mignon,0.60605425,fineart
Albert Bloch,0.69573116,nudity Albert Bloch,0.69573116,nudity
Charles Dana Gibson,0.67155975,fineart Charles Dana Gibson,0.67155975,fineart
Alexandre-Évariste Fragonard,0.6507174,fineart Alexandre-Évariste Fragonard,0.6507174,fineart
Alexandre-Évariste Fragonard,0.6507174,fineart
Ernst Fuchs,0.6953538,nudity Ernst Fuchs,0.6953538,nudity
Alfredo Jaar,0.6952965,digipa-high-impact Alfredo Jaar,0.6952965,digipa-high-impact
Judy Chicago,0.6952246,weird Judy Chicago,0.6952246,weird
Frans van Mieris the Younger,0.6951849,fineart Frans van Mieris the Younger,0.6951849,fineart
Aertgen van Leyden,0.6951305,fineart Aertgen van Leyden,0.6951305,fineart
Emily Carr,0.69512105,fineart Emily Carr,0.69512105,fineart
Frances Macdonald,0.6950408,scribbles
Frances MacDonald,0.6950408,scribbles Frances MacDonald,0.6950408,scribbles
Hannah Höch,0.69495845,scribbles Hannah Höch,0.69495845,scribbles
Gillis Rombouts,0.58770025,fineart Gillis Rombouts,0.58770025,fineart
...@@ -895,7 +892,6 @@ Richard McGuire,0.6820089,scribbles ...@@ -895,7 +892,6 @@ Richard McGuire,0.6820089,scribbles
Anni Albers,0.65708244,digipa-high-impact Anni Albers,0.65708244,digipa-high-impact
Aleksey Savrasov,0.65207493,fineart Aleksey Savrasov,0.65207493,fineart
Wayne Barlowe,0.6537874,fineart Wayne Barlowe,0.6537874,fineart
Giorgio De Chirico,0.6815907,fineart
Giorgio de Chirico,0.6815907,fineart Giorgio de Chirico,0.6815907,fineart
Ernest Procter,0.6815795,fineart Ernest Procter,0.6815795,fineart
Adriaen Brouwer,0.6815058,fineart Adriaen Brouwer,0.6815058,fineart
...@@ -1241,7 +1237,6 @@ Betty Churcher,0.65387225,fineart ...@@ -1241,7 +1237,6 @@ Betty Churcher,0.65387225,fineart
Claes Corneliszoon Moeyaert,0.65386075,fineart Claes Corneliszoon Moeyaert,0.65386075,fineart
David Bomberg,0.6537477,fineart David Bomberg,0.6537477,fineart
Abraham Bosschaert,0.6535562,fineart Abraham Bosschaert,0.6535562,fineart
Giuseppe De Nittis,0.65354455,fineart
Giuseppe de Nittis,0.65354455,fineart Giuseppe de Nittis,0.65354455,fineart
John La Farge,0.65342575,fineart John La Farge,0.65342575,fineart
Frits Thaulow,0.65341854,fineart Frits Thaulow,0.65341854,fineart
...@@ -1522,7 +1517,6 @@ Gertrude Harvey,0.5903887,fineart ...@@ -1522,7 +1517,6 @@ Gertrude Harvey,0.5903887,fineart
Grant Wood,0.6266253,fineart Grant Wood,0.6266253,fineart
Fyodor Vasilyev,0.5234919,digipa-med-impact Fyodor Vasilyev,0.5234919,digipa-med-impact
Cagnaccio di San Pietro,0.6261671,fineart Cagnaccio di San Pietro,0.6261671,fineart
Cagnaccio Di San Pietro,0.6261671,fineart
Doris Boulton-Maude,0.62593174,fineart Doris Boulton-Maude,0.62593174,fineart
Adolf Hirémy-Hirschl,0.5946784,fineart Adolf Hirémy-Hirschl,0.5946784,fineart
Harold von Schmidt,0.6256755,fineart Harold von Schmidt,0.6256755,fineart
...@@ -2411,7 +2405,6 @@ Hermann Feierabend,0.5346168,digipa-high-impact ...@@ -2411,7 +2405,6 @@ Hermann Feierabend,0.5346168,digipa-high-impact
Antonio Donghi,0.4610982,digipa-low-impact Antonio Donghi,0.4610982,digipa-low-impact
Adonna Khare,0.4858036,digipa-med-impact Adonna Khare,0.4858036,digipa-med-impact
James Stokoe,0.5015107,digipa-med-impact James Stokoe,0.5015107,digipa-med-impact
Art & Language,0.5341332,digipa-high-impact
Agustín Fernández,0.53403986,fineart Agustín Fernández,0.53403986,fineart
Germán Londoño,0.5338712,fineart Germán Londoño,0.5338712,fineart
Emmanuelle Moureaux,0.5335641,digipa-high-impact Emmanuelle Moureaux,0.5335641,digipa-high-impact
......
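The hunks above drop rows whose artist name repeats a neighboring row, often differing only in capitalization (for example `Hilma AF Klint` next to `Hilma af Klint`). A minimal sketch of how such rows could be found, assuming `name,score,category` rows as shown in the diff; the function name `find_case_duplicates` is hypothetical and not part of the repository:

```python
import csv
import io
from collections import defaultdict


def find_case_duplicates(csv_text):
    """Group artist names that are equal when compared case-insensitively."""
    groups = defaultdict(list)
    for row in csv.reader(io.StringIO(csv_text)):
        if not row:
            continue  # skip blank lines
        # casefold() gives a case-insensitive key; the original spelling is kept
        groups[row[0].casefold()].append(row[0])
    # Keep only the names that occur more than once
    return {key: names for key, names in groups.items() if len(names) > 1}


sample = (
    "Hilma AF Klint,0.71643853,scribbles\n"
    "Hilma af Klint,0.71643853,scribbles\n"
    "Emily Carr,0.69512105,fineart\n"
)
duplicates = find_case_duplicates(sample)
```

Exact repeats (such as a row listed twice verbatim) land in the same group, so one pass reports both kinds of duplicate.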
...@@ -9,9 +9,38 @@ addEventListener('keydown', (event) => { ...@@ -9,9 +9,38 @@ addEventListener('keydown', (event) => {
let minus = "ArrowDown" let minus = "ArrowDown"
if (event.key != plus && event.key != minus) return; if (event.key != plus && event.key != minus) return;
selectionStart = target.selectionStart; let selectionStart = target.selectionStart;
selectionEnd = target.selectionEnd; let selectionEnd = target.selectionEnd;
if(selectionStart == selectionEnd) return; // If the user hasn't selected anything, let's select their current parenthesis block
if (selectionStart === selectionEnd) {
// Find opening parenthesis around current cursor
const before = target.value.substring(0, selectionStart);
let beforeParen = before.lastIndexOf("(");
if (beforeParen == -1) return;
let beforeParenClose = before.lastIndexOf(")");
while (beforeParenClose !== -1 && beforeParenClose > beforeParen) {
beforeParen = before.lastIndexOf("(", beforeParen - 1);
beforeParenClose = before.lastIndexOf(")", beforeParenClose - 1);
}
// Find closing parenthesis around current cursor
const after = target.value.substring(selectionStart);
let afterParen = after.indexOf(")");
if (afterParen == -1) return;
let afterParenOpen = after.indexOf("(");
while (afterParenOpen !== -1 && afterParen > afterParenOpen) {
afterParen = after.indexOf(")", afterParen + 1);
afterParenOpen = after.indexOf("(", afterParenOpen + 1);
}
if (beforeParen === -1 || afterParen === -1) return;
// Set the selection to the text between the parenthesis
const parenContent = target.value.substring(beforeParen + 1, selectionStart + afterParen);
const lastColon = parenContent.lastIndexOf(":");
selectionStart = beforeParen + 1;
selectionEnd = selectionStart + lastColon;
target.setSelectionRange(selectionStart, selectionEnd);
}
event.preventDefault(); event.preventDefault();
......
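The new branch above selects the parenthesis block around the cursor when nothing is selected, then narrows the selection to the text before the last colon. The same idea can be sketched in Python. This is not a port of the JavaScript: it swaps the `lastIndexOf`/`indexOf` loops for a depth counter, and it assumes that a block without a colon should be selected whole. The name `enclosing_paren_span` is hypothetical.

```python
def enclosing_paren_span(text, cursor):
    """Return (start, end) of the weighted text inside the parentheses
    that enclose `cursor`, or None when the cursor is not inside any."""
    # Scan left for the nearest "(" that is not closed before the cursor
    depth = 0
    start = -1
    for i in range(cursor - 1, -1, -1):
        if text[i] == ")":
            depth += 1
        elif text[i] == "(":
            if depth == 0:
                start = i
                break
            depth -= 1
    if start == -1:
        return None

    # Scan right for the ")" that closes that block
    depth = 0
    end = -1
    for j in range(cursor, len(text)):
        if text[j] == "(":
            depth += 1
        elif text[j] == ")":
            if depth == 0:
                end = j
                break
            depth -= 1
    if end == -1:
        return None

    content = text[start + 1:end]
    colon = content.rfind(":")
    if colon == -1:
        colon = len(content)  # assumption: no weight suffix, select everything
    return start + 1, start + 1 + colon


span = enclosing_paren_span("a (red:1.2) b", 4)
nested = enclosing_paren_span("((a:1.1) b:1.3)", 9)
```

For the nested prompt the outer block is chosen, so the selection covers `(a:1.1) b` and leaves the outer weight `:1.3` untouched.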
...@@ -91,6 +91,8 @@ titles = { ...@@ -91,6 +91,8 @@ titles = {
"Weighted sum": "Result = A * (1 - M) + B * M", "Weighted sum": "Result = A * (1 - M) + B * M",
"Add difference": "Result = A + (B - C) * M", "Add difference": "Result = A + (B - C) * M",
"Learning rate": "how fast should the training go. Low values will take longer to train, high values may fail to converge (not generate accurate results) and/or may break the embedding (This has happened if you see Loss: nan in the training info textbox. If this happens, you need to manually restore your embedding from an older not-broken backup).\n\nYou can set a single numeric value, or multiple learning rates using the syntax:\n\n rate_1:max_steps_1, rate_2:max_steps_2, ...\n\nEG: 0.005:100, 1e-3:1000, 1e-5\n\nWill train with rate of 0.005 for first 100 steps, then 1e-3 until 1000 steps, then 1e-5 for all remaining steps.",
} }
......
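The "Learning rate" tooltip added above describes a `rate:max_steps` schedule syntax. A minimal parsing sketch follows, assuming steps are counted from 1 and that each `max_steps` bound is inclusive. Both function names are hypothetical and this is not the repository's own parser.

```python
def parse_learning_rate_schedule(text):
    """Parse 'rate_1:max_steps_1, rate_2:max_steps_2, ...' into a list of
    (rate, max_steps) pairs; a bare rate gets max_steps=None (no limit)."""
    schedule = []
    for chunk in text.split(","):
        chunk = chunk.strip()
        if not chunk:
            continue
        if ":" in chunk:
            rate, max_steps = chunk.split(":", 1)
            schedule.append((float(rate), int(max_steps)))
        else:
            schedule.append((float(chunk), None))
    return schedule


def rate_at_step(schedule, step):
    """Return the learning rate that applies at a given training step."""
    for rate, max_steps in schedule:
        if max_steps is None or step <= max_steps:
            return rate
    # Past the last bound: keep using the final rate (an assumption)
    return schedule[-1][0]


schedule = parse_learning_rate_schedule("0.005:100, 1e-3:1000, 1e-5")
```

With the tooltip's example, `rate_at_step(schedule, 50)` yields 0.005, step 500 yields 1e-3, and any step beyond 1000 yields 1e-5.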
// various functions for interaction with ui.py not large enough to warrant putting them in separate files // various functions for interaction with ui.py not large enough to warrant putting them in separate files
function set_theme(theme){
gradioURL = window.location.href
if (!gradioURL.includes('?__theme=')) {
window.location.replace(gradioURL + '?__theme=' + theme);
}
}
function selected_gallery_index(){ function selected_gallery_index(){
var buttons = gradioApp().querySelectorAll('[style="display: block;"].tabitem .gallery-item') var buttons = gradioApp().querySelectorAll('[style="display: block;"].tabitem .gallery-item')
var button = gradioApp().querySelector('[style="display: block;"].tabitem .gallery-item.\\!ring-2') var button = gradioApp().querySelector('[style="display: block;"].tabitem .gallery-item.\\!ring-2')
......
...@@ -86,7 +86,24 @@ def git_clone(url, dir, name, commithash=None): ...@@ -86,7 +86,24 @@ def git_clone(url, dir, name, commithash=None):
if commithash is not None: if commithash is not None:
run(f'"{git}" -C {dir} checkout {commithash}', None, "Couldn't checkout {name}'s hash: {commithash}") run(f'"{git}" -C {dir} checkout {commithash}', None, "Couldn't checkout {name}'s hash: {commithash}")
def version_check(commit):
try:
import requests
commits = requests.get('https://api.github.com/repos/AUTOMATIC1111/stable-diffusion-webui/branches/master').json()
if commit != "<none>" and commits['commit']['sha'] != commit:
print("--------------------------------------------------------")
print("| You are not up to date with the most recent release. |")
print("| Consider running `git pull` to update. |")
print("--------------------------------------------------------")
elif commits['commit']['sha'] == commit:
print("You are up to date with the most recent release.")
else:
print("Not a git clone, can't perform version check.")
except Exception as e:
print("version check failed", e)
def prepare_enviroment(): def prepare_enviroment():
torch_command = os.environ.get('TORCH_COMMAND', "pip install torch==1.12.1+cu113 torchvision==0.13.1+cu113 --extra-index-url https://download.pytorch.org/whl/cu113") torch_command = os.environ.get('TORCH_COMMAND', "pip install torch==1.12.1+cu113 torchvision==0.13.1+cu113 --extra-index-url https://download.pytorch.org/whl/cu113")
requirements_file = os.environ.get('REQS_FILE', "requirements_versions.txt") requirements_file = os.environ.get('REQS_FILE', "requirements_versions.txt")
...@@ -110,13 +127,14 @@ def prepare_enviroment(): ...@@ -110,13 +127,14 @@ def prepare_enviroment():
codeformer_commit_hash = os.environ.get('CODEFORMER_COMMIT_HASH', "c5b4593074ba6214284d6acd5f1719b6c5d739af") codeformer_commit_hash = os.environ.get('CODEFORMER_COMMIT_HASH', "c5b4593074ba6214284d6acd5f1719b6c5d739af")
blip_commit_hash = os.environ.get('BLIP_COMMIT_HASH', "48211a1594f1321b00f14c9f7a5b4813144b2fb9") blip_commit_hash = os.environ.get('BLIP_COMMIT_HASH', "48211a1594f1321b00f14c9f7a5b4813144b2fb9")
args = shlex.split(commandline_args) sys.argv += shlex.split(commandline_args)
args, skip_torch_cuda_test = extract_arg(args, '--skip-torch-cuda-test') sys.argv, skip_torch_cuda_test = extract_arg(sys.argv, '--skip-torch-cuda-test')
args, reinstall_xformers = extract_arg(args, '--reinstall-xformers') sys.argv, reinstall_xformers = extract_arg(sys.argv, '--reinstall-xformers')
xformers = '--xformers' in args sys.argv, update_check = extract_arg(sys.argv, '--update-check')
deepdanbooru = '--deepdanbooru' in args xformers = '--xformers' in sys.argv
ngrok = '--ngrok' in args deepdanbooru = '--deepdanbooru' in sys.argv
ngrok = '--ngrok' in sys.argv
try: try:
commit = run(f"{git} rev-parse HEAD").strip() commit = run(f"{git} rev-parse HEAD").strip()
...@@ -125,7 +143,7 @@ def prepare_enviroment(): ...@@ -125,7 +143,7 @@ def prepare_enviroment():
print(f"Python {sys.version}") print(f"Python {sys.version}")
print(f"Commit hash: {commit}") print(f"Commit hash: {commit}")
if not is_installed("torch") or not is_installed("torchvision"): if not is_installed("torch") or not is_installed("torchvision"):
run(f'"{python}" -m {torch_command}', "Installing torch and torchvision", "Couldn't install torch") run(f'"{python}" -m {torch_command}', "Installing torch and torchvision", "Couldn't install torch")
...@@ -138,9 +156,15 @@ def prepare_enviroment(): ...@@ -138,9 +156,15 @@ def prepare_enviroment():
if not is_installed("clip"): if not is_installed("clip"):
run_pip(f"install {clip_package}", "clip") run_pip(f"install {clip_package}", "clip")
if (not is_installed("xformers") or reinstall_xformers) and xformers and platform.python_version().startswith("3.10"): if (not is_installed("xformers") or reinstall_xformers) and xformers:
if platform.system() == "Windows": if platform.system() == "Windows":
run_pip(f"install -U -I --no-deps {xformers_windows_package}", "xformers") if platform.python_version().startswith("3.10"):
run_pip(f"install -U -I --no-deps {xformers_windows_package}", "xformers")
else:
print("Installation of xformers is not supported in this version of Python.")
print("You can also check this and build manually: https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Xformers#building-xformers-on-windows-by-duckness")
if not is_installed("xformers"):
exit(0)
elif platform.system() == "Linux": elif platform.system() == "Linux":
run_pip("install xformers", "xformers") run_pip("install xformers", "xformers")
...@@ -163,9 +187,10 @@ def prepare_enviroment(): ...@@ -163,9 +187,10 @@ def prepare_enviroment():
run_pip(f"install -r {requirements_file}", "requirements for Web UI") run_pip(f"install -r {requirements_file}", "requirements for Web UI")
sys.argv += args if update_check:
version_check(commit)
if "--exit" in args:
if "--exit" in sys.argv:
print("Exiting because of --exit argument") print("Exiting because of --exit argument")
exit(0) exit(0)
......
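The hunks above call `extract_arg` to pull launcher-only flags such as `--update-check` out of `sys.argv`, but the helper's body is not part of this diff. A minimal sketch of a helper with that call shape, written as an assumption about its behavior rather than a copy of the repository's code:

```python
def extract_arg(args, name):
    """Return (args without `name`, whether `name` was present).

    Assumed behavior: the flag takes no value and may appear more than once.
    """
    return [x for x in args if x != name], name in args


# Example argument list; the values are illustrative only
argv = ["launch.py", "--xformers", "--update-check"]
argv, update_check = extract_arg(argv, "--update-check")
```

Removing the flag before the main program parses its arguments keeps launcher-only options from reaching an argument parser that does not know them, while flags such as `--xformers` stay in place.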
from modules.api.processing import StableDiffusionProcessingAPI
from modules.processing import StableDiffusionProcessingTxt2Img, process_images
from modules.sd_samplers import all_samplers
from modules.extras import run_pnginfo
import modules.shared as shared
import uvicorn
from fastapi import Body, APIRouter, HTTPException
from fastapi.responses import JSONResponse
from pydantic import BaseModel, Field, Json
import json
import io
import base64
sampler_to_index = lambda name: next(filter(lambda row: name.lower() == row[1].name.lower(), enumerate(all_samplers)), None)
class TextToImageResponse(BaseModel):
images: list[str] = Field(default=None, title="Image", description="The generated image in base64 format.")
parameters: Json
info: Json
class Api:
def __init__(self, app, queue_lock):
self.router = APIRouter()
self.app = app
self.queue_lock = queue_lock
self.app.add_api_route("/sdapi/v1/txt2img", self.text2imgapi, methods=["POST"])
def text2imgapi(self, txt2imgreq: StableDiffusionProcessingAPI ):
sampler_index = sampler_to_index(txt2imgreq.sampler_index)
if sampler_index is None:
raise HTTPException(status_code=404, detail="Sampler not found")
populate = txt2imgreq.copy(update={ # Override __init__ params
"sd_model": shared.sd_model,
"sampler_index": sampler_index[0],
"do_not_save_samples": True,
"do_not_save_grid": True
}
)
p = StableDiffusionProcessingTxt2Img(**vars(populate))
# Override object param
with self.queue_lock:
processed = process_images(p)
b64images = []
for i in processed.images:
buffer = io.BytesIO()
i.save(buffer, format="png")
b64images.append(base64.b64encode(buffer.getvalue()))
return TextToImageResponse(images=b64images, parameters=json.dumps(vars(txt2imgreq)), info=json.dumps(processed.info))
def img2imgapi(self):
raise NotImplementedError
def extrasapi(self):
raise NotImplementedError
def pnginfoapi(self):
raise NotImplementedError
def launch(self, server_name, port):
self.app.include_router(self.router)
uvicorn.run(self.app, host=server_name, port=port)
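The `/sdapi/v1/txt2img` route above returns each generated image as a base64 string in the `images` field. A client-side sketch of building a request body and decoding such a response, using only the standard library. The `prompt` and `steps` field names are assumptions about the generated request model, both helper names are hypothetical, and the response below is fabricated for illustration rather than real server output.

```python
import base64
import json


def build_txt2img_body(prompt, sampler_index="Euler", **extra):
    """Serialize a request body; field names other than sampler_index are assumed."""
    payload = {"prompt": prompt, "sampler_index": sampler_index, **extra}
    return json.dumps(payload).encode("utf-8")


def decode_images(response):
    """Turn the base64 strings in a response dict back into raw image bytes."""
    return [base64.b64decode(item) for item in response.get("images", [])]


# Stand-in for a parsed JSON response; real output would hold PNG data
fake_response = {
    "images": [base64.b64encode(b"\x89PNG fake bytes").decode("ascii")]
}
decoded = decode_images(fake_response)
body = build_txt2img_body("a red apple", steps=20)
```

The decoded bytes can then be written to a `.png` file or opened with an imaging library.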
from inflection import underscore
from typing import Any, Dict, Optional
from pydantic import BaseModel, Field, create_model
from modules.processing import StableDiffusionProcessingTxt2Img
import inspect
API_NOT_ALLOWED = [
"self",
"kwargs",
"sd_model",
"outpath_samples",
"outpath_grids",
"sampler_index",
"do_not_save_samples",
"do_not_save_grid",
"extra_generation_params",
"overlay_images",
"do_not_reload_embeddings",
"seed_enable_extras",
"prompt_for_display",
"sampler_noise_scheduler_override",
"ddim_discretize"
]
class ModelDef(BaseModel):
"""Assistance Class for Pydantic Dynamic Model Generation"""
field: str
field_alias: str
field_type: Any
field_value: Any
class PydanticModelGenerator:
"""
Takes in created classes and stubs them out in a way FastAPI/Pydantic is happy about:
source_data is a snapshot of the default values produced by the class
params are the names of the actual keys required by __init__
"""
def __init__(
self,
model_name: str = None,
class_instance = None,
additional_fields = None,
):
def field_type_generator(k, v):
# field_type = str if not overrides.get(k) else overrides[k]["type"]
# print(k, v.annotation, v.default)
field_type = v.annotation
return Optional[field_type]
def merge_class_params(class_):
all_classes = list(filter(lambda x: x is not object, inspect.getmro(class_)))
parameters = {}
for classes in all_classes:
parameters = {**parameters, **inspect.signature(classes.__init__).parameters}
return parameters
self._model_name = model_name
self._class_data = merge_class_params(class_instance)
self._model_def = [
ModelDef(
field=underscore(k),
field_alias=k,
field_type=field_type_generator(k, v),
field_value=v.default
)
for (k,v) in self._class_data.items() if k not in API_NOT_ALLOWED
]
for fields in additional_fields:
self._model_def.append(ModelDef(
field=underscore(fields["key"]),
field_alias=fields["key"],
field_type=fields["type"],
field_value=fields["default"]))
def generate_model(self):
"""
Creates a pydantic BaseModel
from the json and overrides provided at initialization
"""
fields = {
d.field: (d.field_type, Field(default=d.field_value, alias=d.field_alias)) for d in self._model_def
}
DynamicModel = create_model(self._model_name, **fields)
DynamicModel.__config__.allow_population_by_field_name = True
DynamicModel.__config__.allow_mutation = True
return DynamicModel
StableDiffusionProcessingAPI = PydanticModelGenerator(
"StableDiffusionProcessingTxt2Img",
StableDiffusionProcessingTxt2Img,
[{"key": "sampler_index", "type": str, "default": "Euler"}]
).generate_model()
\ No newline at end of file
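`PydanticModelGenerator` above builds its field list by merging the `__init__` parameters of every class in the method resolution order, then filtering out the names in `API_NOT_ALLOWED`. The merging step can be sketched with the standard library alone. `Base`, `Child` and `EXCLUDED` are hypothetical stand-ins, not classes from the repository.

```python
import inspect


class Base:
    def __init__(self, prompt="", steps=50):
        pass


class Child(Base):
    def __init__(self, enable_hr=False, **kwargs):
        pass


def merge_init_params(cls):
    """Collect __init__ parameters across the MRO, skipping `object`.

    Classes later in the MRO (the bases) overwrite earlier entries on a
    name clash, mirroring the dict-merge order used in the diff.
    """
    params = {}
    for klass in inspect.getmro(cls):
        if klass is object:
            continue
        params.update(inspect.signature(klass.__init__).parameters)
    return params


EXCLUDED = {"self", "kwargs"}  # stand-in for API_NOT_ALLOWED
defaults = {
    name: param.default
    for name, param in merge_init_params(Child).items()
    if name not in EXCLUDED
}
```

Each surviving name and default can then be turned into a model field, which is what the `ModelDef` entries carry into `create_model`.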
...@@ -157,8 +157,7 @@ def get_deepbooru_tags_from_model(model, tags, pil_image, threshold, deepbooru_o ...@@ -157,8 +157,7 @@ def get_deepbooru_tags_from_model(model, tags, pil_image, threshold, deepbooru_o
# sort by reverse by likelihood and normal for alpha, and format tag text as requested # sort by reverse by likelihood and normal for alpha, and format tag text as requested
unsorted_tags_in_theshold.sort(key=lambda y: y[sort_ndx], reverse=(not alpha_sort)) unsorted_tags_in_theshold.sort(key=lambda y: y[sort_ndx], reverse=(not alpha_sort))
for weight, tag in unsorted_tags_in_theshold: for weight, tag in unsorted_tags_in_theshold:
# note: tag_outformat will still have a colon if include_ranks is True tag_outformat = tag
tag_outformat = tag.replace(':', ' ')
if use_spaces: if use_spaces:
tag_outformat = tag_outformat.replace('_', ' ') tag_outformat = tag_outformat.replace('_', ' ')
if use_escape: if use_escape:
......
...@@ -39,9 +39,12 @@ def run_extras(extras_mode, resize_mode, image, image_folder, input_dir, output_ ...@@ -39,9 +39,12 @@ def run_extras(extras_mode, resize_mode, image, image_folder, input_dir, output_
if input_dir == '': if input_dir == '':
return outputs, "Please select an input directory.", '' return outputs, "Please select an input directory.", ''
image_list = [file for file in [os.path.join(input_dir, x) for x in os.listdir(input_dir)] if os.path.isfile(file)] image_list = [file for file in [os.path.join(input_dir, x) for x in sorted(os.listdir(input_dir))] if os.path.isfile(file)]
for img in image_list: for img in image_list:
image = Image.open(img) try:
image = Image.open(img)
except Exception:
continue
imageArr.append(image) imageArr.append(image)
imageNameArr.append(img) imageNameArr.append(img)
else: else:
...@@ -118,10 +121,14 @@ def run_extras(extras_mode, resize_mode, image, image_folder, input_dir, output_ ...@@ -118,10 +121,14 @@ def run_extras(extras_mode, resize_mode, image, image_folder, input_dir, output_
while len(cached_images) > 2: while len(cached_images) > 2:
del cached_images[next(iter(cached_images.keys()))] del cached_images[next(iter(cached_images.keys()))]
if opts.use_original_name_batch and image_name != None:
basename = os.path.splitext(os.path.basename(image_name))[0]
else:
basename = ''
images.save_image(image, path=outpath, basename="", seed=None, prompt=None, extension=opts.samples_format, info=info, short_filename=True, images.save_image(image, path=outpath, basename=basename, seed=None, prompt=None, extension=opts.samples_format, info=info, short_filename=True,
no_prompt=True, grid=False, pnginfo_section_name="extras", existing_info=existing_pnginfo, no_prompt=True, grid=False, pnginfo_section_name="extras", existing_info=existing_pnginfo, forced_filename=None)
forced_filename=image_name if opts.use_original_name_batch else None)
if opts.enable_pnginfo: if opts.enable_pnginfo:
image.info = existing_pnginfo image.info = existing_pnginfo
...@@ -216,8 +223,11 @@ def run_modelmerger(primary_model_name, secondary_model_name, teritary_model_nam ...@@ -216,8 +223,11 @@ def run_modelmerger(primary_model_name, secondary_model_name, teritary_model_nam
if theta_func1: if theta_func1:
for key in tqdm.tqdm(theta_1.keys()): for key in tqdm.tqdm(theta_1.keys()):
if 'model' in key: if 'model' in key:
t2 = theta_2.get(key, torch.zeros_like(theta_1[key])) if key in theta_2:
theta_1[key] = theta_func1(theta_1[key], t2) t2 = theta_2.get(key, torch.zeros_like(theta_1[key]))
theta_1[key] = theta_func1(theta_1[key], t2)
else:
theta_1[key] = torch.zeros_like(theta_1[key])
del theta_2, teritary_model del theta_2, teritary_model
for key in tqdm.tqdm(theta_0.keys()): for key in tqdm.tqdm(theta_0.keys()):
......
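The model-merger hunk above changes what happens when a key is missing from the third checkpoint: the difference for that key is set to zero instead of being computed against a zero tensor. The two merge formulas from the tooltips ("Weighted sum" and "Add difference") and that fallback can be sketched on plain floats. This simplification assumes dictionaries of numbers rather than tensors, and `merge_add_difference` is a hypothetical name.

```python
def weighted_sum(a, b, m):
    """Result = A * (1 - M) + B * M"""
    return a * (1 - m) + b * m


def add_difference(a, b, c, m):
    """Result = A + (B - C) * M"""
    return a + (b - c) * m


def merge_add_difference(theta_a, theta_b, theta_c, m):
    """Apply add_difference per key; a key absent from B or C contributes
    a zero difference, so A's value passes through unchanged."""
    merged = {}
    for key, a in theta_a.items():
        if key in theta_b and key in theta_c:
            merged[key] = add_difference(a, theta_b[key], theta_c[key], m)
        else:
            merged[key] = a
    return merged


merged = merge_add_difference(
    {"model.w": 1.0, "model.b": 2.0},
    {"model.w": 3.0, "model.b": 5.0},
    {"model.w": 2.0},  # "model.b" is missing from the third model
    0.5,
)
```

Here `model.w` becomes 1.0 + (3.0 - 2.0) * 0.5, while `model.b` keeps the value from the first model.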
...@@ -4,13 +4,22 @@ import gradio as gr ...@@ -4,13 +4,22 @@ import gradio as gr
from modules.shared import script_path from modules.shared import script_path
from modules import shared from modules import shared
re_param_code = r"\s*([\w ]+):\s*([^,]+)(?:,|$)" re_param_code = r'\s*([\w ]+):\s*("(?:\\|\"|[^\"])+"|[^,]*)(?:,|$)'
re_param = re.compile(re_param_code) re_param = re.compile(re_param_code)
re_params = re.compile(r"^(?:" + re_param_code + "){3,}$") re_params = re.compile(r"^(?:" + re_param_code + "){3,}$")
re_imagesize = re.compile(r"^(\d+)x(\d+)$") re_imagesize = re.compile(r"^(\d+)x(\d+)$")
type_of_gr_update = type(gr.update()) type_of_gr_update = type(gr.update())
def quote(text):
if ',' not in str(text):
return text
text = str(text)
text = text.replace('\\', '\\\\')
text = text.replace('"', '\\"')
return f'"{text}"'
def parse_generation_parameters(x: str): def parse_generation_parameters(x: str):
"""parses generation parameters string, the one you see in text field under the picture in UI: """parses generation parameters string, the one you see in text field under the picture in UI:
``` ```
...@@ -45,11 +54,8 @@ Steps: 20, Sampler: Euler a, CFG scale: 7, Seed: 965400086, Size: 512x512, Model ...@@ -45,11 +54,8 @@ Steps: 20, Sampler: Euler a, CFG scale: 7, Seed: 965400086, Size: 512x512, Model
else: else:
prompt += ("" if prompt == "" else "\n") + line prompt += ("" if prompt == "" else "\n") + line
if len(prompt) > 0: res["Prompt"] = prompt
res["Prompt"] = prompt res["Negative prompt"] = negative_prompt
if len(negative_prompt) > 0:
res["Negative prompt"] = negative_prompt
for k, v in re_param.findall(lastline): for k, v in re_param.findall(lastline):
m = re_imagesize.match(v) m = re_imagesize.match(v)
...@@ -86,7 +92,12 @@ def connect_paste(button, paste_fields, input_comp, js=None): ...@@ -86,7 +92,12 @@ def connect_paste(button, paste_fields, input_comp, js=None):
else: else:
try: try:
valtype = type(output.value) valtype = type(output.value)
val = valtype(v)
if valtype == bool and v == "False":
val = False
else:
val = valtype(v)
res.append(gr.update(value=val)) res.append(gr.update(value=val))
except Exception: except Exception:
res.append(gr.update()) res.append(gr.update())
......
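The new `quote` helper and the widened `re_param_code` pattern above let a parameter value contain commas by wrapping it in double quotes. A round-trip sketch follows. `quote` and the pattern are copied from the diff, while `unquote` is a hypothetical inverse that leans on `json.loads` and is an assumption, not code shown in this commit.

```python
import json
import re

re_param_code = r'\s*([\w ]+):\s*("(?:\\|\"|[^\"])+"|[^,]*)(?:,|$)'
re_param = re.compile(re_param_code)


def quote(text):
    # Values without a comma are written as-is
    if ',' not in str(text):
        return text
    text = str(text)
    text = text.replace('\\', '\\\\')  # escape backslashes first
    text = text.replace('"', '\\"')    # then escape double quotes
    return f'"{text}"'


def unquote(text):
    """Hypothetical inverse of quote(): decode a double-quoted value."""
    if len(text) < 2 or text[0] != '"' or text[-1] != '"':
        return text
    try:
        return json.loads(text)
    except json.JSONDecodeError:
        return text


line = f'Steps: 20, Model: {quote("a, b")}, Seed: 1'
parsed = {key: unquote(value) for key, value in re_param.findall(line)}
```

Without the quoted alternative in the pattern, the value `a, b` would be cut at its first comma.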
import html import html
import os import os
import re
import gradio as gr import gradio as gr
...@@ -9,11 +10,21 @@ from modules import sd_hijack, shared, devices ...@@ -9,11 +10,21 @@ from modules import sd_hijack, shared, devices
from modules.hypernetworks import hypernetwork from modules.hypernetworks import hypernetwork
def create_hypernetwork(name, enable_sizes): def create_hypernetwork(name, enable_sizes, overwrite_old, layer_structure=None, add_layer_norm=False, activation_func=None):
fn = os.path.join(shared.cmd_opts.hypernetwork_dir, f"{name}.pt") fn = os.path.join(shared.cmd_opts.hypernetwork_dir, f"{name}.pt")
assert not os.path.exists(fn), f"file {fn} already exists" if not overwrite_old:
assert not os.path.exists(fn), f"file {fn} already exists"
hypernet = modules.hypernetworks.hypernetwork.Hypernetwork(name=name, enable_sizes=[int(x) for x in enable_sizes])
if type(layer_structure) == str:
layer_structure = [float(x.strip()) for x in layer_structure.split(",")]
hypernet = modules.hypernetworks.hypernetwork.Hypernetwork(
name=name,
enable_sizes=[int(x) for x in enable_sizes],
layer_structure=layer_structure,
add_layer_norm=add_layer_norm,
activation_func=activation_func,
)
hypernet.save(fn) hypernet.save(fn)
shared.reload_hypernetworks() shared.reload_hypernetworks()
......
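`create_hypernetwork` above accepts `layer_structure` either as a list or as a comma-separated string coming from the UI. The conversion can be isolated as below. The helper name and the example value `"1, 2, 1"` are illustrative assumptions.

```python
def parse_layer_structure(value):
    """Convert a comma-separated string such as '1, 2, 1' to a list of
    floats; any non-string value (a list or None) is returned unchanged."""
    if isinstance(value, str):
        return [float(x.strip()) for x in value.split(",")]
    return value


layers = parse_layer_structure("1, 2, 1")
```

Using `isinstance` rather than `type(...) == str` also accepts `str` subclasses, which is the more common idiom for this kind of check.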
...@@ -56,7 +56,7 @@ def process_batch(p, input_dir, output_dir, args): ...@@ -56,7 +56,7 @@ def process_batch(p, input_dir, output_dir, args):
processed_image.save(os.path.join(output_dir, filename)) processed_image.save(os.path.join(output_dir, filename))
def img2img(mode: int, prompt: str, negative_prompt: str, prompt_style: str, prompt_style2: str, init_img, init_img_with_mask, init_img_inpaint, init_mask_inpaint, mask_mode, steps: int, sampler_index: int, mask_blur: int, inpainting_fill: int, restore_faces: bool, tiling: bool, n_iter: int, batch_size: int, cfg_scale: float, denoising_strength: float, seed: int, subseed: int, subseed_strength: float, seed_resize_from_h: int, seed_resize_from_w: int, seed_enable_extras: bool, height: int, width: int, resize_mode: int, inpaint_full_res: bool, inpaint_full_res_padding: int, inpainting_mask_invert: int, img2img_batch_input_dir: str, img2img_batch_output_dir: str, *args): def img2img(mode: int, prompt: str, negative_prompt: str, prompt_style: str, prompt_style2: str, init_img, init_img_with_mask, init_img_inpaint, init_mask_inpaint, mask_mode, steps: int, sampler_index: int, mask_blur: int, inpainting_fill: int, restore_faces: bool, tiling: bool, n_iter: int, batch_size: int, cfg_scale: float, denoising_strength: float, seed: int, subseed: int, subseed_strength: float, seed_resize_from_h: int, seed_resize_from_w: int, seed_enable_extras: bool, height: int, width: int, resize_mode: int, inpaint_full_res: bool, inpaint_full_res_padding: int, inpainting_mask_invert: int, img2img_batch_input_dir: str, img2img_batch_output_dir: str, aesthetic_lr=0, aesthetic_weight=0, aesthetic_steps=0, aesthetic_imgs=None, aesthetic_slerp=False, aesthetic_imgs_text="", aesthetic_slerp_angle=0.15, aesthetic_text_negative=False, *args):
is_inpaint = mode == 1 is_inpaint = mode == 1
is_batch = mode == 2 is_batch = mode == 2
...@@ -109,6 +109,8 @@ def img2img(mode: int, prompt: str, negative_prompt: str, prompt_style: str, pro ...@@ -109,6 +109,8 @@ def img2img(mode: int, prompt: str, negative_prompt: str, prompt_style: str, pro
inpainting_mask_invert=inpainting_mask_invert, inpainting_mask_invert=inpainting_mask_invert,
) )
shared.aesthetic_clip.set_aesthetic_params(p, float(aesthetic_lr), float(aesthetic_weight), int(aesthetic_steps), aesthetic_imgs, aesthetic_slerp, aesthetic_imgs_text, aesthetic_slerp_angle, aesthetic_text_negative)
if shared.cmd_opts.enable_console_prompts: if shared.cmd_opts.enable_console_prompts:
print(f"\nimg2img: {prompt}", file=shared.progress_print_out) print(f"\nimg2img: {prompt}", file=shared.progress_print_out)
......
...@@ -28,9 +28,11 @@ class InterrogateModels: ...@@ -28,9 +28,11 @@ class InterrogateModels:
clip_preprocess = None clip_preprocess = None
categories = None categories = None
dtype = None dtype = None
running_on_cpu = None
def __init__(self, content_dir): def __init__(self, content_dir):
self.categories = [] self.categories = []
self.running_on_cpu = devices.device_interrogate == torch.device("cpu")
if os.path.exists(content_dir): if os.path.exists(content_dir):
for filename in os.listdir(content_dir): for filename in os.listdir(content_dir):
...@@ -53,7 +55,11 @@ class InterrogateModels: ...@@ -53,7 +55,11 @@ class InterrogateModels:
def load_clip_model(self): def load_clip_model(self):
import clip import clip
model, preprocess = clip.load(clip_model_name) if self.running_on_cpu:
model, preprocess = clip.load(clip_model_name, device="cpu")
else:
model, preprocess = clip.load(clip_model_name)
model.eval() model.eval()
model = model.to(devices.device_interrogate) model = model.to(devices.device_interrogate)
...@@ -62,14 +68,14 @@ class InterrogateModels: ...@@ -62,14 +68,14 @@ class InterrogateModels:
def load(self): def load(self):
if self.blip_model is None: if self.blip_model is None:
self.blip_model = self.load_blip_model() self.blip_model = self.load_blip_model()
if not shared.cmd_opts.no_half: if not shared.cmd_opts.no_half and not self.running_on_cpu:
self.blip_model = self.blip_model.half() self.blip_model = self.blip_model.half()
self.blip_model = self.blip_model.to(devices.device_interrogate) self.blip_model = self.blip_model.to(devices.device_interrogate)
if self.clip_model is None: if self.clip_model is None:
self.clip_model, self.clip_preprocess = self.load_clip_model() self.clip_model, self.clip_preprocess = self.load_clip_model()
if not shared.cmd_opts.no_half: if not shared.cmd_opts.no_half and not self.running_on_cpu:
self.clip_model = self.clip_model.half() self.clip_model = self.clip_model.half()
self.clip_model = self.clip_model.to(devices.device_interrogate) self.clip_model = self.clip_model.to(devices.device_interrogate)
......
...@@ -275,7 +275,7 @@ re_attention = re.compile(r""" ...@@ -275,7 +275,7 @@ re_attention = re.compile(r"""
def parse_prompt_attention(text): def parse_prompt_attention(text):
""" """
Parses a string with attention tokens and returns a list of pairs: text and its assoicated weight. Parses a string with attention tokens and returns a list of pairs: text and its associated weight.
Accepted tokens are: Accepted tokens are:
(abc) - increases attention to abc by a multiplier of 1.1 (abc) - increases attention to abc by a multiplier of 1.1
(abc:3.12) - increases attention to abc by a multiplier of 3.12 (abc:3.12) - increases attention to abc by a multiplier of 3.12
......
...@@ -96,6 +96,7 @@ def wrap_call(func, filename, funcname, *args, default=None, **kwargs): ...@@ -96,6 +96,7 @@ def wrap_call(func, filename, funcname, *args, default=None, **kwargs):
class ScriptRunner: class ScriptRunner:
def __init__(self): def __init__(self):
self.scripts = [] self.scripts = []
self.titles = []
def setup_ui(self, is_img2img): def setup_ui(self, is_img2img):
for script_class, path in scripts_data: for script_class, path in scripts_data:
...@@ -107,9 +108,10 @@ class ScriptRunner: ...@@ -107,9 +108,10 @@ class ScriptRunner:
self.scripts.append(script) self.scripts.append(script)
titles = [wrap_call(script.title, script.filename, "title") or f"{script.filename} [error]" for script in self.scripts] self.titles = [wrap_call(script.title, script.filename, "title") or f"{script.filename} [error]" for script in self.scripts]
dropdown = gr.Dropdown(label="Script", choices=["None"] + titles, value="None", type="index") dropdown = gr.Dropdown(label="Script", choices=["None"] + self.titles, value="None", type="index")
dropdown.save_to_config = True
inputs = [dropdown] inputs = [dropdown]
for script in self.scripts: for script in self.scripts:
...@@ -139,6 +141,15 @@ class ScriptRunner: ...@@ -139,6 +141,15 @@ class ScriptRunner:
return [ui.gr_show(True if i == 0 else args_from <= i < args_to) for i in range(len(inputs))] return [ui.gr_show(True if i == 0 else args_from <= i < args_to) for i in range(len(inputs))]
def init_field(title):
if title == 'None':
return
script_index = self.titles.index(title)
script = self.scripts[script_index]
for i in range(script.args_from, script.args_to):
inputs[i].visible = True
dropdown.init_field = init_field
dropdown.change( dropdown.change(
fn=select_script, fn=select_script,
inputs=[dropdown], inputs=[dropdown],
......
...@@ -19,6 +19,7 @@ attention_CrossAttention_forward = ldm.modules.attention.CrossAttention.forward ...@@ -19,6 +19,7 @@ attention_CrossAttention_forward = ldm.modules.attention.CrossAttention.forward
diffusionmodules_model_nonlinearity = ldm.modules.diffusionmodules.model.nonlinearity diffusionmodules_model_nonlinearity = ldm.modules.diffusionmodules.model.nonlinearity
diffusionmodules_model_AttnBlock_forward = ldm.modules.diffusionmodules.model.AttnBlock.forward diffusionmodules_model_AttnBlock_forward = ldm.modules.diffusionmodules.model.AttnBlock.forward
def apply_optimizations(): def apply_optimizations():
undo_optimizations() undo_optimizations()
...@@ -167,11 +168,11 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module): ...@@ -167,11 +168,11 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):
remade_tokens = remade_tokens[:last_comma] remade_tokens = remade_tokens[:last_comma]
length = len(remade_tokens) length = len(remade_tokens)
rem = int(math.ceil(length / 75)) * 75 - length rem = int(math.ceil(length / 75)) * 75 - length
remade_tokens += [id_end] * rem + reloc_tokens remade_tokens += [id_end] * rem + reloc_tokens
multipliers = multipliers[:last_comma] + [1.0] * rem + reloc_mults multipliers = multipliers[:last_comma] + [1.0] * rem + reloc_mults
if embedding is None: if embedding is None:
remade_tokens.append(token) remade_tokens.append(token)
multipliers.append(weight) multipliers.append(weight)
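The padding arithmetic in the hunk above rounds the token list up to the next multiple of 75 and pads the multipliers in step. A minimal standalone sketch of just that step, assuming a hypothetical helper name `pad_to_chunk` and an illustrative `id_end` value; the comma-relocation logic (`last_comma`, `reloc_tokens`, `reloc_mults`) is deliberately left out:

```python
import math

CHUNK = 75  # chunk length used by the hunk above


def pad_to_chunk(tokens, multipliers, id_end, chunk=CHUNK):
    """Pad tokens with id_end and multipliers with 1.0 up to the next multiple of chunk.

    Hypothetical helper for illustration only; it is not part of the commit.
    """
    length = len(tokens)
    # Same expression as in the diff: distance to the next multiple of `chunk`.
    rem = int(math.ceil(length / chunk)) * chunk - length
    return tokens + [id_end] * rem, multipliers + [1.0] * rem


# 80 tokens need 70 padding entries to reach 150 (two chunks of 75).
# id_end=49407 is only an illustrative end-of-text id.
tokens, mults = pad_to_chunk(list(range(80)), [1.0] * 80, id_end=49407)
print(len(tokens), len(mults))
```

A list whose length is already a multiple of 75 (including an empty list) gets no padding, because `rem` evaluates to 0.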
...@@ -223,7 +224,6 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module): ...@@ -223,7 +224,6 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):
return batch_multipliers, remade_batch_tokens, used_custom_terms, hijack_comments, hijack_fixes, token_count return batch_multipliers, remade_batch_tokens, used_custom_terms, hijack_comments, hijack_fixes, token_count
def process_text_old(self, text): def process_text_old(self, text):
id_start = self.wrapped.tokenizer.bos_token_id id_start = self.wrapped.tokenizer.bos_token_id
id_end = self.wrapped.tokenizer.eos_token_id id_end = self.wrapped.tokenizer.eos_token_id
...@@ -280,7 +280,7 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module): ...@@ -280,7 +280,7 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):
token_count = len(remade_tokens) token_count = len(remade_tokens)
remade_tokens = remade_tokens + [id_end] * (maxlen - 2 - len(remade_tokens)) remade_tokens = remade_tokens + [id_end] * (maxlen - 2 - len(remade_tokens))
remade_tokens = [id_start] + remade_tokens[0:maxlen-2] + [id_end] remade_tokens = [id_start] + remade_tokens[0:maxlen - 2] + [id_end]
cache[tuple_tokens] = (remade_tokens, fixes, multipliers) cache[tuple_tokens] = (remade_tokens, fixes, multipliers)
multipliers = multipliers + [1.0] * (maxlen - 2 - len(multipliers)) multipliers = multipliers + [1.0] * (maxlen - 2 - len(multipliers))
...@@ -290,7 +290,7 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module): ...@@ -290,7 +290,7 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):
hijack_fixes.append(fixes) hijack_fixes.append(fixes)
batch_multipliers.append(multipliers) batch_multipliers.append(multipliers)
return batch_multipliers, remade_batch_tokens, used_custom_terms, hijack_comments, hijack_fixes, token_count return batch_multipliers, remade_batch_tokens, used_custom_terms, hijack_comments, hijack_fixes, token_count
def forward(self, text): def forward(self, text):
use_old = opts.use_old_emphasis_implementation use_old = opts.use_old_emphasis_implementation
if use_old: if use_old:
...@@ -302,11 +302,11 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module): ...@@ -302,11 +302,11 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):
if len(used_custom_terms) > 0: if len(used_custom_terms) > 0:
self.hijack.comments.append("Used embeddings: " + ", ".join([f'{word} [{checksum}]' for word, checksum in used_custom_terms])) self.hijack.comments.append("Used embeddings: " + ", ".join([f'{word} [{checksum}]' for word, checksum in used_custom_terms]))
if use_old: if use_old:
self.hijack.fixes = hijack_fixes self.hijack.fixes = hijack_fixes
return self.process_tokens(remade_batch_tokens, batch_multipliers) return self.process_tokens(remade_batch_tokens, batch_multipliers)
z = None z = None
i = 0 i = 0
while max(map(len, remade_batch_tokens)) != 0: while max(map(len, remade_batch_tokens)) != 0:
...@@ -320,7 +320,7 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module): ...@@ -320,7 +320,7 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):
if fix[0] == i: if fix[0] == i:
fixes.append(fix[1]) fixes.append(fix[1])
self.hijack.fixes.append(fixes) self.hijack.fixes.append(fixes)
tokens = [] tokens = []
multipliers = [] multipliers = []
for j in range(len(remade_batch_tokens)): for j in range(len(remade_batch_tokens)):
...@@ -332,20 +332,20 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module): ...@@ -332,20 +332,20 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):
multipliers.append([1.0] * 75) multipliers.append([1.0] * 75)
z1 = self.process_tokens(tokens, multipliers) z1 = self.process_tokens(tokens, multipliers)
z1 = shared.aesthetic_clip(z1, remade_batch_tokens)
z = z1 if z is None else torch.cat((z, z1), axis=-2) z = z1 if z is None else torch.cat((z, z1), axis=-2)
remade_batch_tokens = rem_tokens remade_batch_tokens = rem_tokens
batch_multipliers = rem_multipliers batch_multipliers = rem_multipliers
i += 1 i += 1
return z return z
def process_tokens(self, remade_batch_tokens, batch_multipliers): def process_tokens(self, remade_batch_tokens, batch_multipliers):
if not opts.use_old_emphasis_implementation: if not opts.use_old_emphasis_implementation:
remade_batch_tokens = [[self.wrapped.tokenizer.bos_token_id] + x[:75] + [self.wrapped.tokenizer.eos_token_id] for x in remade_batch_tokens] remade_batch_tokens = [[self.wrapped.tokenizer.bos_token_id] + x[:75] + [self.wrapped.tokenizer.eos_token_id] for x in remade_batch_tokens]
batch_multipliers = [[1.0] + x[:75] + [1.0] for x in batch_multipliers] batch_multipliers = [[1.0] + x[:75] + [1.0] for x in batch_multipliers]
tokens = torch.asarray(remade_batch_tokens).to(device) tokens = torch.asarray(remade_batch_tokens).to(device)
outputs = self.wrapped.transformer(input_ids=tokens, output_hidden_states=-opts.CLIP_stop_at_last_layers) outputs = self.wrapped.transformer(input_ids=tokens, output_hidden_states=-opts.CLIP_stop_at_last_layers)
...@@ -385,8 +385,8 @@ class EmbeddingsWithFixes(torch.nn.Module): ...@@ -385,8 +385,8 @@ class EmbeddingsWithFixes(torch.nn.Module):
for fixes, tensor in zip(batch_fixes, inputs_embeds): for fixes, tensor in zip(batch_fixes, inputs_embeds):
for offset, embedding in fixes: for offset, embedding in fixes:
emb = embedding.vec emb = embedding.vec
emb_len = min(tensor.shape[0]-offset-1, emb.shape[0]) emb_len = min(tensor.shape[0] - offset - 1, emb.shape[0])
tensor = torch.cat([tensor[0:offset+1], emb[0:emb_len], tensor[offset+1+emb_len:]]) tensor = torch.cat([tensor[0:offset + 1], emb[0:emb_len], tensor[offset + 1 + emb_len:]])
vecs.append(tensor) vecs.append(tensor)
......
...@@ -181,7 +181,7 @@ def einsum_op_cuda(q, k, v): ...@@ -181,7 +181,7 @@ def einsum_op_cuda(q, k, v):
mem_free_torch = mem_reserved - mem_active mem_free_torch = mem_reserved - mem_active
mem_free_total = mem_free_cuda + mem_free_torch mem_free_total = mem_free_cuda + mem_free_torch
# Divide factor of safety as there's copying and fragmentation # Divide factor of safety as there's copying and fragmentation
return self.einsum_op_tensor_mem(q, k, v, mem_free_total / 3.3 / (1 << 20)) return einsum_op_tensor_mem(q, k, v, mem_free_total / 3.3 / (1 << 20))
def einsum_op(q, k, v): def einsum_op(q, k, v):
if q.device.type == 'cuda': if q.device.type == 'cuda':
...@@ -296,10 +296,16 @@ def xformers_attnblock_forward(self, x): ...@@ -296,10 +296,16 @@ def xformers_attnblock_forward(self, x):
try: try:
h_ = x h_ = x
h_ = self.norm(h_) h_ = self.norm(h_)
q1 = self.q(h_).contiguous() q = self.q(h_)
k1 = self.k(h_).contiguous() k = self.k(h_)
v = self.v(h_).contiguous() v = self.v(h_)
out = xformers.ops.memory_efficient_attention(q1, k1, v) b, c, h, w = q.shape
q, k, v = map(lambda t: rearrange(t, 'b c h w -> b (h w) c'), (q, k, v))
q = q.contiguous()
k = k.contiguous()
v = v.contiguous()
out = xformers.ops.memory_efficient_attention(q, k, v)
out = rearrange(out, 'b (h w) c -> b c h w', h=h)
out = self.proj_out(out) out = self.proj_out(out)
return x + out return x + out
except NotImplementedError: except NotImplementedError:
......
...@@ -9,6 +9,7 @@ from ldm.util import instantiate_from_config ...@@ -9,6 +9,7 @@ from ldm.util import instantiate_from_config
from modules import shared, modelloader, devices from modules import shared, modelloader, devices
from modules.paths import models_path from modules.paths import models_path
from modules.sd_hijack_inpainting import do_inpainting_hijack, should_hijack_inpainting
model_dir = "Stable-diffusion" model_dir = "Stable-diffusion"
model_path = os.path.abspath(os.path.join(models_path, model_dir)) model_path = os.path.abspath(os.path.join(models_path, model_dir))
...@@ -20,7 +21,7 @@ checkpoints_loaded = collections.OrderedDict() ...@@ -20,7 +21,7 @@ checkpoints_loaded = collections.OrderedDict()
try: try:
# this silences the annoying "Some weights of the model checkpoint were not used when initializing..." message at start. # this silences the annoying "Some weights of the model checkpoint were not used when initializing..." message at start.
from transformers import logging from transformers import logging, CLIPModel
logging.set_verbosity_error() logging.set_verbosity_error()
except Exception: except Exception:
...@@ -122,9 +123,34 @@ def select_checkpoint(): ...@@ -122,9 +123,34 @@ def select_checkpoint():
return checkpoint_info return checkpoint_info
chckpoint_dict_replacements = {
'cond_stage_model.transformer.embeddings.': 'cond_stage_model.transformer.text_model.embeddings.',
'cond_stage_model.transformer.encoder.': 'cond_stage_model.transformer.text_model.encoder.',
'cond_stage_model.transformer.final_layer_norm.': 'cond_stage_model.transformer.text_model.final_layer_norm.',
}
def transform_checkpoint_dict_key(k):
for text, replacement in chckpoint_dict_replacements.items():
if k.startswith(text):
k = replacement + k[len(text):]
return k
def get_state_dict_from_checkpoint(pl_sd): def get_state_dict_from_checkpoint(pl_sd):
if "state_dict" in pl_sd: if "state_dict" in pl_sd:
return pl_sd["state_dict"] pl_sd = pl_sd["state_dict"]
sd = {}
for k, v in pl_sd.items():
new_key = transform_checkpoint_dict_key(k)
if new_key is not None:
sd[new_key] = v
pl_sd.clear()
pl_sd.update(sd)
return pl_sd return pl_sd
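The key remapping added above can be tried in isolation on a plain dictionary. The sketch below copies the prefix table from the hunk; the names `REPLACEMENTS`, `transform_key`, and `remap_state_dict` are hypothetical, and the string values stand in for tensors:

```python
# Prefix table copied from the hunk above: old key prefix -> new key prefix.
REPLACEMENTS = {
    'cond_stage_model.transformer.embeddings.': 'cond_stage_model.transformer.text_model.embeddings.',
    'cond_stage_model.transformer.encoder.': 'cond_stage_model.transformer.text_model.encoder.',
    'cond_stage_model.transformer.final_layer_norm.': 'cond_stage_model.transformer.text_model.final_layer_norm.',
}


def transform_key(k):
    """Rewrite a state-dict key if it starts with one of the old prefixes."""
    for text, replacement in REPLACEMENTS.items():
        if k.startswith(text):
            k = replacement + k[len(text):]
    return k


def remap_state_dict(sd):
    """Rebuild the dict with remapped keys, updating it in place as the diff does."""
    remapped = {transform_key(k): v for k, v in sd.items()}
    sd.clear()
    sd.update(remapped)
    return sd


# Placeholder values; a real checkpoint would hold tensors here.
sd = {
    'cond_stage_model.transformer.encoder.layers.0.weight': 'w0',
    'model.diffusion_model.out.bias': 'b0',
}
print(sorted(remap_state_dict(sd)))
```

Keys that match no prefix pass through unchanged, and a key that has already been remapped no longer matches any old prefix, so applying the function twice gives the same result.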
...@@ -141,7 +167,7 @@ def load_model_weights(model, checkpoint_info): ...@@ -141,7 +167,7 @@ def load_model_weights(model, checkpoint_info):
print(f"Global Step: {pl_sd['global_step']}") print(f"Global Step: {pl_sd['global_step']}")
sd = get_state_dict_from_checkpoint(pl_sd) sd = get_state_dict_from_checkpoint(pl_sd)
model.load_state_dict(sd, strict=False) missing, extra = model.load_state_dict(sd, strict=False)
if shared.cmd_opts.opt_channelslast: if shared.cmd_opts.opt_channelslast:
model.to(memory_format=torch.channels_last) model.to(memory_format=torch.channels_last)
...@@ -178,14 +204,26 @@ def load_model_weights(model, checkpoint_info): ...@@ -178,14 +204,26 @@ def load_model_weights(model, checkpoint_info):
model.sd_checkpoint_info = checkpoint_info model.sd_checkpoint_info = checkpoint_info
def load_model(): def load_model(checkpoint_info=None):
from modules import lowvram, sd_hijack from modules import lowvram, sd_hijack
checkpoint_info = select_checkpoint() checkpoint_info = checkpoint_info or select_checkpoint()
if checkpoint_info.config != shared.cmd_opts.config: if checkpoint_info.config != shared.cmd_opts.config:
print(f"Loading config from: {checkpoint_info.config}") print(f"Loading config from: {checkpoint_info.config}")
sd_config = OmegaConf.load(checkpoint_info.config) sd_config = OmegaConf.load(checkpoint_info.config)
if should_hijack_inpainting(checkpoint_info):
# Hardcoded config for now...
sd_config.model.target = "ldm.models.diffusion.ddpm.LatentInpaintDiffusion"
sd_config.model.params.use_ema = False
sd_config.model.params.conditioning_key = "hybrid"
sd_config.model.params.unet_config.params.in_channels = 9
# Create a "fake" config with a different name so that we know to unload it when switching models.
checkpoint_info = checkpoint_info._replace(config=checkpoint_info.config.replace(".yaml", "-inpainting.yaml"))
do_inpainting_hijack()
sd_model = instantiate_from_config(sd_config.model) sd_model = instantiate_from_config(sd_config.model)
load_model_weights(sd_model, checkpoint_info) load_model_weights(sd_model, checkpoint_info)
...@@ -209,9 +247,9 @@ def reload_model_weights(sd_model, info=None): ...@@ -209,9 +247,9 @@ def reload_model_weights(sd_model, info=None):
if sd_model.sd_model_checkpoint == checkpoint_info.filename: if sd_model.sd_model_checkpoint == checkpoint_info.filename:
return return
if sd_model.sd_checkpoint_info.config != checkpoint_info.config: if sd_model.sd_checkpoint_info.config != checkpoint_info.config or should_hijack_inpainting(checkpoint_info) != should_hijack_inpainting(sd_model.sd_checkpoint_info):
checkpoints_loaded.clear() checkpoints_loaded.clear()
shared.sd_model = load_model() shared.sd_model = load_model(checkpoint_info)
return shared.sd_model return shared.sd_model
if shared.cmd_opts.lowvram or shared.cmd_opts.medvram: if shared.cmd_opts.lowvram or shared.cmd_opts.medvram:
......
...@@ -3,6 +3,7 @@ import datetime ...@@ -3,6 +3,7 @@ import datetime
import json import json
import os import os
import sys import sys
from collections import OrderedDict
import gradio as gr import gradio as gr
import tqdm import tqdm
...@@ -30,6 +31,7 @@ parser.add_argument("--no-half-vae", action='store_true', help="do not switch th ...@@ -30,6 +31,7 @@ parser.add_argument("--no-half-vae", action='store_true', help="do not switch th
parser.add_argument("--no-progressbar-hiding", action='store_true', help="do not hide progressbar in gradio UI (we hide it because it slows down ML if you have hardware acceleration in browser)") parser.add_argument("--no-progressbar-hiding", action='store_true', help="do not hide progressbar in gradio UI (we hide it because it slows down ML if you have hardware acceleration in browser)")
parser.add_argument("--max-batch-count", type=int, default=16, help="maximum batch count value for the UI") parser.add_argument("--max-batch-count", type=int, default=16, help="maximum batch count value for the UI")
parser.add_argument("--embeddings-dir", type=str, default=os.path.join(script_path, 'embeddings'), help="embeddings directory for textual inversion (default: embeddings)") parser.add_argument("--embeddings-dir", type=str, default=os.path.join(script_path, 'embeddings'), help="embeddings directory for textual inversion (default: embeddings)")
parser.add_argument("--aesthetic_embeddings-dir", type=str, default=os.path.join(models_path, 'aesthetic_embeddings'), help="aesthetic_embeddings directory(default: aesthetic_embeddings)")
parser.add_argument("--hypernetwork-dir", type=str, default=os.path.join(models_path, 'hypernetworks'), help="hypernetwork directory") parser.add_argument("--hypernetwork-dir", type=str, default=os.path.join(models_path, 'hypernetworks'), help="hypernetwork directory")
parser.add_argument("--localizations-dir", type=str, default=os.path.join(script_path, 'localizations'), help="localizations directory") parser.add_argument("--localizations-dir", type=str, default=os.path.join(script_path, 'localizations'), help="localizations directory")
parser.add_argument("--allow-code", action='store_true', help="allow custom script execution from webui") parser.add_argument("--allow-code", action='store_true', help="allow custom script execution from webui")
...@@ -70,12 +72,14 @@ parser.add_argument("--gradio-img2img-tool", type=str, help='gradio image upload ...@@ -70,12 +72,14 @@ parser.add_argument("--gradio-img2img-tool", type=str, help='gradio image upload
parser.add_argument("--opt-channelslast", action='store_true', help="change memory type for stable diffusion to channels last") parser.add_argument("--opt-channelslast", action='store_true', help="change memory type for stable diffusion to channels last")
parser.add_argument("--styles-file", type=str, help="filename to use for styles", default=os.path.join(script_path, 'styles.csv')) parser.add_argument("--styles-file", type=str, help="filename to use for styles", default=os.path.join(script_path, 'styles.csv'))
parser.add_argument("--autolaunch", action='store_true', help="open the webui URL in the system's default browser upon launch", default=False) parser.add_argument("--autolaunch", action='store_true', help="open the webui URL in the system's default browser upon launch", default=False)
parser.add_argument("--theme", type=str, help="launches the UI with light or dark theme", default=None)
parser.add_argument("--use-textbox-seed", action='store_true', help="use textbox for seeds in UI (no up/down, but possible to input long seeds)", default=False) parser.add_argument("--use-textbox-seed", action='store_true', help="use textbox for seeds in UI (no up/down, but possible to input long seeds)", default=False)
parser.add_argument("--disable-console-progressbars", action='store_true', help="do not output progressbars to console", default=False) parser.add_argument("--disable-console-progressbars", action='store_true', help="do not output progressbars to console", default=False)
parser.add_argument("--enable-console-prompts", action='store_true', help="print prompts to console when generating with txt2img and img2img", default=False) parser.add_argument("--enable-console-prompts", action='store_true', help="print prompts to console when generating with txt2img and img2img", default=False)
parser.add_argument('--vae-path', type=str, help='Path to Variational Autoencoders model', default=None) parser.add_argument('--vae-path', type=str, help='Path to Variational Autoencoders model', default=None)
parser.add_argument("--disable-safe-unpickle", action='store_true', help="disable checking pytorch models for malicious code", default=False) parser.add_argument("--disable-safe-unpickle", action='store_true', help="disable checking pytorch models for malicious code", default=False)
parser.add_argument("--api", action='store_true', help="use api=True to launch the api with the webui")
parser.add_argument("--nowebui", action='store_true', help="use api=True to launch the api instead of the webui")
cmd_opts = parser.parse_args() cmd_opts = parser.parse_args()
restricted_opts = [ restricted_opts = [
...@@ -104,6 +108,21 @@ os.makedirs(cmd_opts.hypernetwork_dir, exist_ok=True) ...@@ -104,6 +108,21 @@ os.makedirs(cmd_opts.hypernetwork_dir, exist_ok=True)
hypernetworks = hypernetwork.list_hypernetworks(cmd_opts.hypernetwork_dir) hypernetworks = hypernetwork.list_hypernetworks(cmd_opts.hypernetwork_dir)
loaded_hypernetwork = None loaded_hypernetwork = None
os.makedirs(cmd_opts.aesthetic_embeddings_dir, exist_ok=True)
aesthetic_embeddings = {}
def update_aesthetic_embeddings():
global aesthetic_embeddings
aesthetic_embeddings = {f.replace(".pt", ""): os.path.join(cmd_opts.aesthetic_embeddings_dir, f) for f in
os.listdir(cmd_opts.aesthetic_embeddings_dir) if f.endswith(".pt")}
aesthetic_embeddings = OrderedDict(**{"None": None}, **aesthetic_embeddings)
update_aesthetic_embeddings()
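For illustration, the directory scan performed by `update_aesthetic_embeddings` can be reproduced against a temporary folder. This is a sketch under assumptions: `list_embeddings` is a hypothetical standalone name, the `sorted()` call is added here only to make the order deterministic, and the file names are made up:

```python
import os
import tempfile
from collections import OrderedDict


def list_embeddings(directory):
    """Map embedding names to .pt paths, with a leading "None" entry."""
    found = {
        f.replace(".pt", ""): os.path.join(directory, f)
        for f in sorted(os.listdir(directory))  # sorted() is an addition for determinism
        if f.endswith(".pt")
    }
    # Same construction as in the diff; it would raise TypeError if a file were named "None.pt".
    return OrderedDict(**{"None": None}, **found)


with tempfile.TemporaryDirectory() as tmp:
    # Create two fake embedding files and one unrelated file.
    for name in ("style_a.pt", "style_b.pt", "notes.txt"):
        open(os.path.join(tmp, name), "w").close()
    print(list(list_embeddings(tmp)))
```

Only files ending in `.pt` are listed, and the `"None"` entry always comes first so it can serve as the default choice.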
def reload_hypernetworks(): def reload_hypernetworks():
global hypernetworks global hypernetworks
...@@ -135,7 +154,7 @@ class State: ...@@ -135,7 +154,7 @@ class State:
self.job_no += 1 self.job_no += 1
self.sampling_step = 0 self.sampling_step = 0
self.current_image_sampling_step = 0 self.current_image_sampling_step = 0
def get_job_timestamp(self): def get_job_timestamp(self):
return datetime.datetime.now().strftime("%Y%m%d%H%M%S") # shouldn't this return job_timestamp? return datetime.datetime.now().strftime("%Y%m%d%H%M%S") # shouldn't this return job_timestamp?
...@@ -293,6 +312,7 @@ options_templates.update(options_section(('ui', "User interface"), { ...@@ -293,6 +312,7 @@ options_templates.update(options_section(('ui', "User interface"), {
"do_not_show_images": OptionInfo(False, "Do not show any images in results for web"), "do_not_show_images": OptionInfo(False, "Do not show any images in results for web"),
"add_model_hash_to_info": OptionInfo(True, "Add model hash to generation information"), "add_model_hash_to_info": OptionInfo(True, "Add model hash to generation information"),
"add_model_name_to_info": OptionInfo(False, "Add model name to generation information"), "add_model_name_to_info": OptionInfo(False, "Add model name to generation information"),
"disable_weights_auto_swap": OptionInfo(False, "When reading generation parameters from text into UI (from PNG info or pasted text), do not change the selected model/checkpoint."),
"font": OptionInfo("", "Font for image grids that have text"), "font": OptionInfo("", "Font for image grids that have text"),
"js_modal_lightbox": OptionInfo(True, "Enable full page image viewer"), "js_modal_lightbox": OptionInfo(True, "Enable full page image viewer"),
"js_modal_lightbox_initially_zoomed": OptionInfo(True, "Show images zoomed in by default in full page image viewer"), "js_modal_lightbox_initially_zoomed": OptionInfo(True, "Show images zoomed in by default in full page image viewer"),
...@@ -384,6 +404,11 @@ sd_upscalers = [] ...@@ -384,6 +404,11 @@ sd_upscalers = []
sd_model = None sd_model = None
clip_model = None
from modules.aesthetic_clip import AestheticCLIP
aesthetic_clip = AestheticCLIP()
progress_print_out = sys.stdout progress_print_out = sys.stdout
......
...@@ -45,7 +45,7 @@ class StyleDatabase: ...@@ -45,7 +45,7 @@ class StyleDatabase:
if not os.path.exists(path): if not os.path.exists(path):
return return
with open(path, "r", encoding="utf8", newline='') as file: with open(path, "r", encoding="utf-8-sig", newline='') as file:
reader = csv.DictReader(file) reader = csv.DictReader(file)
for row in reader: for row in reader:
# Support loading old CSV format with "name, text"-columns # Support loading old CSV format with "name, text"-columns
...@@ -79,7 +79,7 @@ class StyleDatabase: ...@@ -79,7 +79,7 @@ class StyleDatabase:
def save_styles(self, path: str) -> None: def save_styles(self, path: str) -> None:
# Write to temporary file first, so we don't nuke the file if something goes wrong # Write to temporary file first, so we don't nuke the file if something goes wrong
fd, temp_path = tempfile.mkstemp(".csv") fd, temp_path = tempfile.mkstemp(".csv")
with os.fdopen(fd, "w", encoding="utf8", newline='') as file: with os.fdopen(fd, "w", encoding="utf-8-sig", newline='') as file:
# _fields is actually part of the public API: typing.NamedTuple is a replacement for collections.NamedTuple, # _fields is actually part of the public API: typing.NamedTuple is a replacement for collections.NamedTuple,
# and collections.NamedTuple has explicit documentation for accessing _fields. Same goes for _asdict() # and collections.NamedTuple has explicit documentation for accessing _fields. Same goes for _asdict()
writer = csv.DictWriter(file, fieldnames=PromptStyle._fields) writer = csv.DictWriter(file, fieldnames=PromptStyle._fields)
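The switch from `utf8` to `utf-8-sig` matters when a styles file starts with a byte order mark, as files saved by some spreadsheet tools do. A small sketch of the difference using in-memory data; the helper name `header_fields` and the column names are illustrative assumptions, not taken from the commit:

```python
import csv
import io

# CSV bytes that start with a UTF-8 byte order mark (BOM).
raw = "name,prompt\r\nphoto,a photo of a cat\r\n".encode("utf-8-sig")


def header_fields(data, encoding):
    """Decode CSV bytes with the given codec and return the header names."""
    text = io.StringIO(data.decode(encoding), newline="")
    return csv.DictReader(text).fieldnames


# Plain utf-8 keeps the BOM, so the first column name starts with '\ufeff'.
print(header_fields(raw, "utf-8"))
# utf-8-sig strips the BOM, so the column is found under its expected name.
print(header_fields(raw, "utf-8-sig"))
```

Reading with `utf-8-sig` also works for files without a BOM, so the reading side stays compatible with styles files written before this change.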
......
...@@ -83,7 +83,7 @@ class PersonalizedBase(Dataset): ...@@ -83,7 +83,7 @@ class PersonalizedBase(Dataset):
self.dataset.append(entry) self.dataset.append(entry)
assert len(self.dataset) > 1, "No images have been found in the dataset." assert len(self.dataset) > 0, "No images have been found in the dataset."
self.length = len(self.dataset) * repeats // batch_size self.length = len(self.dataset) * repeats // batch_size
self.initial_indexes = np.arange(len(self.dataset)) self.initial_indexes = np.arange(len(self.dataset))
...@@ -91,7 +91,7 @@ class PersonalizedBase(Dataset): ...@@ -91,7 +91,7 @@ class PersonalizedBase(Dataset):
self.shuffle() self.shuffle()
def shuffle(self): def shuffle(self):
self.indexes = self.initial_indexes[torch.randperm(self.initial_indexes.shape[0])] self.indexes = self.initial_indexes[torch.randperm(self.initial_indexes.shape[0]).numpy()]
def create_text(self, filename_text): def create_text(self, filename_text):
text = random.choice(self.lines) text = random.choice(self.lines)
......
...@@ -5,6 +5,7 @@ import zlib ...@@ -5,6 +5,7 @@ import zlib
from PIL import Image, PngImagePlugin, ImageDraw, ImageFont from PIL import Image, PngImagePlugin, ImageDraw, ImageFont
from fonts.ttf import Roboto from fonts.ttf import Roboto
import torch import torch
from modules.shared import opts
class EmbeddingEncoder(json.JSONEncoder): class EmbeddingEncoder(json.JSONEncoder):
...@@ -133,7 +134,7 @@ def caption_image_overlay(srcimage, title, footerLeft, footerMid, footerRight, t ...@@ -133,7 +134,7 @@ def caption_image_overlay(srcimage, title, footerLeft, footerMid, footerRight, t
from math import cos from math import cos
image = srcimage.copy() image = srcimage.copy()
fontsize = 32
if textfont is None: if textfont is None:
try: try:
textfont = ImageFont.truetype(opts.font or Roboto, fontsize) textfont = ImageFont.truetype(opts.font or Roboto, fontsize)
...@@ -150,7 +151,7 @@ def caption_image_overlay(srcimage, title, footerLeft, footerMid, footerRight, t ...@@ -150,7 +151,7 @@ def caption_image_overlay(srcimage, title, footerLeft, footerMid, footerRight, t
image = Image.alpha_composite(image.convert('RGBA'), gradient.resize(image.size)) image = Image.alpha_composite(image.convert('RGBA'), gradient.resize(image.size))
draw = ImageDraw.Draw(image) draw = ImageDraw.Draw(image)
fontsize = 32
font = ImageFont.truetype(textfont, fontsize) font = ImageFont.truetype(textfont, fontsize)
padding = 10 padding = 10
......
@@ -11,7 +11,7 @@ if cmd_opts.deepdanbooru:
     import modules.deepbooru as deepbooru
-def preprocess(process_src, process_dst, process_width, process_height, process_flip, process_split, process_caption, process_caption_deepbooru=False):
+def preprocess(process_src, process_dst, process_width, process_height, preprocess_txt_action, process_flip, process_split, process_caption, process_caption_deepbooru=False):
     try:
         if process_caption:
             shared.interrogator.load()
@@ -21,7 +21,7 @@ def preprocess(process_src, process_dst, process_width, process_height, process_
             db_opts[deepbooru.OPT_INCLUDE_RANKS] = False
             deepbooru.create_deepbooru_process(opts.interrogate_deepbooru_score_threshold, db_opts)
-        preprocess_work(process_src, process_dst, process_width, process_height, process_flip, process_split, process_caption, process_caption_deepbooru)
+        preprocess_work(process_src, process_dst, process_width, process_height, preprocess_txt_action, process_flip, process_split, process_caption, process_caption_deepbooru)
     finally:
@@ -33,7 +33,7 @@ def preprocess(process_src, process_dst, process_width, process_height, process_
-def preprocess_work(process_src, process_dst, process_width, process_height, process_flip, process_split, process_caption, process_caption_deepbooru=False):
+def preprocess_work(process_src, process_dst, process_width, process_height, preprocess_txt_action, process_flip, process_split, process_caption, process_caption_deepbooru=False):
     width = process_width
     height = process_height
     src = os.path.abspath(process_src)
@@ -48,7 +48,7 @@ def preprocess_work(process_src, process_dst, process_width, process_height, pro
     shared.state.textinfo = "Preprocessing..."
     shared.state.job_count = len(files)
-    def save_pic_with_caption(image, index):
+    def save_pic_with_caption(image, index, existing_caption=None):
         caption = ""
         if process_caption:
@@ -66,17 +66,26 @@ def preprocess_work(process_src, process_dst, process_width, process_height, pro
         basename = f"{index:05}-{subindex[0]}-{filename_part}"
         image.save(os.path.join(dst, f"{basename}.png"))
+        if preprocess_txt_action == 'prepend' and existing_caption:
+            caption = existing_caption + ' ' + caption
+        elif preprocess_txt_action == 'append' and existing_caption:
+            caption = caption + ' ' + existing_caption
+        elif preprocess_txt_action == 'copy' and existing_caption:
+            caption = existing_caption
+        caption = caption.strip()
         if len(caption) > 0:
             with open(os.path.join(dst, f"{basename}.txt"), "w", encoding="utf8") as file:
                 file.write(caption)
         subindex[0] += 1
-    def save_pic(image, index):
-        save_pic_with_caption(image, index)
+    def save_pic(image, index, existing_caption=None):
+        save_pic_with_caption(image, index, existing_caption=existing_caption)
         if process_flip:
-            save_pic_with_caption(ImageOps.mirror(image), index)
+            save_pic_with_caption(ImageOps.mirror(image), index, existing_caption=existing_caption)
     for index, imagefile in enumerate(tqdm.tqdm(files)):
         subindex = [0]
@@ -86,6 +95,13 @@ def preprocess_work(process_src, process_dst, process_width, process_height, pro
         except Exception:
             continue
+        existing_caption = None
+        try:
+            existing_caption = open(os.path.splitext(filename)[0] + '.txt', 'r').read()
+        except Exception as e:
+            print(e)
         if shared.state.interrupted:
             break
@@ -97,20 +113,20 @@ def preprocess_work(process_src, process_dst, process_width, process_height, pro
             img = img.resize((width, height * img.height // img.width))
             top = img.crop((0, 0, width, height))
-            save_pic(top, index)
+            save_pic(top, index, existing_caption=existing_caption)
             bot = img.crop((0, img.height - height, width, img.height))
-            save_pic(bot, index)
+            save_pic(bot, index, existing_caption=existing_caption)
         elif process_split and is_wide:
             img = img.resize((width * img.width // img.height, height))
             left = img.crop((0, 0, width, height))
-            save_pic(left, index)
+            save_pic(left, index, existing_caption=existing_caption)
             right = img.crop((img.width - width, 0, img.width, height))
-            save_pic(right, index)
+            save_pic(right, index, existing_caption=existing_caption)
         else:
             img = images.resize_image(1, img, width, height)
-            save_pic(img, index)
+            save_pic(img, index, existing_caption=existing_caption)
         shared.state.nextjob()
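Review note: the caption handling added above can be read in isolation as a small pure function plus a sidecar-file reader. The sketch below is only an illustration of that logic, not code from the commit. `merge_caption` and `read_sidecar_caption` are hypothetical names, and the reader uses a `with` block and returns `None` for a missing file, which is an assumption about the desired behaviour rather than what the diff does (the diff prints the exception).

```python
import os


def merge_caption(generated, existing, action):
    """Combine a generated caption with one loaded from an existing .txt file.

    Hypothetical helper mirroring the prepend/append/copy branches in the diff.
    Any other action value leaves the generated caption untouched.
    """
    if existing:
        if action == 'prepend':
            generated = existing + ' ' + generated
        elif action == 'append':
            generated = generated + ' ' + existing
        elif action == 'copy':
            generated = existing
    return generated.strip()


def read_sidecar_caption(image_path):
    """Return the text of '<image stem>.txt' next to the image, or None if absent.

    Hypothetical helper; assumes UTF-8 captions.
    """
    txt_path = os.path.splitext(image_path)[0] + '.txt'
    try:
        with open(txt_path, 'r', encoding='utf8') as file:
            return file.read()
    except FileNotFoundError:
        return None
```

With these helpers, `merge_caption("a cat", "photo", "prepend")` yields `"photo a cat"`, and an image with no sidecar file keeps only its generated caption.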
@@ -153,7 +153,7 @@ class EmbeddingDatabase:
         return None, None
-def create_embedding(name, num_vectors_per_token, init_text='*'):
+def create_embedding(name, num_vectors_per_token, overwrite_old, init_text='*'):
     cond_model = shared.sd_model.cond_stage_model
     embedding_layer = cond_model.wrapped.transformer.text_model.embeddings
@@ -165,7 +165,8 @@ def create_embedding(name, num_vectors_per_token, init_text='*'):
         vec[i] = embedded[i * int(embedded.shape[0]) // num_vectors_per_token]
     fn = os.path.join(shared.cmd_opts.embeddings_dir, f"{name}.pt")
-    assert not os.path.exists(fn), f"file {fn} already exists"
+    if not overwrite_old:
+        assert not os.path.exists(fn), f"file {fn} already exists"
     embedding = Embedding(vec, name)
     embedding.step = 0
@@ -275,6 +276,7 @@ def train_embedding(embedding_name, learn_rate, batch_size, data_root, log_direc
         loss.backward()
         optimizer.step()
         epoch_num = embedding.step // len(ds)
         epoch_step = embedding.step - (epoch_num * len(ds)) + 1
......
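Review note: the `overwrite_old` guard above relies on `assert`, and Python removes assert statements when run with `-O`. A minimal sketch of the same check with an explicit exception follows; `ensure_writable` is a hypothetical name, and raising `FileExistsError` is an assumption about how a caller might prefer to see the failure, not something the commit does.

```python
import os


def ensure_writable(fn, overwrite_old):
    """Refuse to clobber an existing file unless overwrite_old is set.

    Hypothetical helper; same condition as the diff, but it still fires
    when the interpreter is started with -O.
    """
    if not overwrite_old and os.path.exists(fn):
        raise FileExistsError(f"file {fn} already exists")
    return fn
```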
@@ -7,8 +7,8 @@ import modules.textual_inversion.preprocess
 from modules import sd_hijack, shared
-def create_embedding(name, initialization_text, nvpt):
-    filename = modules.textual_inversion.textual_inversion.create_embedding(name, nvpt, init_text=initialization_text)
+def create_embedding(name, initialization_text, nvpt, overwrite_old):
+    filename = modules.textual_inversion.textual_inversion.create_embedding(name, nvpt, overwrite_old, init_text=initialization_text)
     sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings()
......
 import modules.scripts
-from modules.processing import StableDiffusionProcessing, Processed, StableDiffusionProcessingTxt2Img, StableDiffusionProcessingImg2Img, process_images
+from modules.processing import StableDiffusionProcessing, Processed, StableDiffusionProcessingTxt2Img, \
+    StableDiffusionProcessingImg2Img, process_images
 from modules.shared import opts, cmd_opts
 import modules.shared as shared
 import modules.processing as processing
 from modules.ui import plaintext_to_html
-def txt2img(prompt: str, negative_prompt: str, prompt_style: str, prompt_style2: str, steps: int, sampler_index: int, restore_faces: bool, tiling: bool, n_iter: int, batch_size: int, cfg_scale: float, seed: int, subseed: int, subseed_strength: float, seed_resize_from_h: int, seed_resize_from_w: int, seed_enable_extras: bool, height: int, width: int, enable_hr: bool, denoising_strength: float, firstphase_width: int, firstphase_height: int, *args):
+def txt2img(prompt: str, negative_prompt: str, prompt_style: str, prompt_style2: str, steps: int, sampler_index: int, restore_faces: bool, tiling: bool, n_iter: int, batch_size: int, cfg_scale: float, seed: int, subseed: int, subseed_strength: float, seed_resize_from_h: int, seed_resize_from_w: int, seed_enable_extras: bool, height: int, width: int, enable_hr: bool, denoising_strength: float, firstphase_width: int, firstphase_height: int, aesthetic_lr=0, aesthetic_weight=0, aesthetic_steps=0, aesthetic_imgs=None, aesthetic_slerp=False, aesthetic_imgs_text="", aesthetic_slerp_angle=0.15, aesthetic_text_negative=False, *args):
     p = StableDiffusionProcessingTxt2Img(
         sd_model=shared.sd_model,
         outpath_samples=opts.outdir_samples or opts.outdir_txt2img_samples,
@@ -35,6 +36,8 @@ def txt2img(prompt: str, negative_prompt: str, prompt_style: str, prompt_style2:
         firstphase_height=firstphase_height if enable_hr else None,
     )
+    shared.aesthetic_clip.set_aesthetic_params(p, float(aesthetic_lr), float(aesthetic_weight), int(aesthetic_steps), aesthetic_imgs, aesthetic_slerp, aesthetic_imgs_text, aesthetic_slerp_angle, aesthetic_text_negative)
     if cmd_opts.enable_console_prompts:
         print(f"\ntxt2img: {prompt}", file=shared.progress_print_out)
@@ -53,4 +56,3 @@ def txt2img(prompt: str, negative_prompt: str, prompt_style: str, prompt_style2:
         processed.images = []
     return processed.images, generation_info_js, plaintext_to_html(processed.info)
This diff is collapsed.
@@ -23,3 +23,4 @@ resize-right
 torchdiffeq
 kornia
 lark
+inflection
@@ -22,3 +22,4 @@ resize-right==0.0.2
 torchdiffeq==0.2.3
 kornia==0.6.7
 lark==1.1.2
+inflection==0.5.1
@@ -89,6 +89,7 @@ def apply_checkpoint(p, x, xs):
     if info is None:
         raise RuntimeError(f"Unknown checkpoint: {x}")
     modules.sd_models.reload_model_weights(shared.sd_model, info)
+    p.sd_model = shared.sd_model
 def confirm_checkpoints(p, xs):
......
@@ -34,9 +34,10 @@
 .performance {
     font-size: 0.85em;
     color: #444;
-    display: flex;
-    justify-content: space-between;
-    white-space: nowrap;
+}
+
+.performance p{
+    display: inline-block;
 }
 .performance .time {
@@ -44,8 +45,6 @@
 }
 .performance .vram {
-    margin-left: 0;
-    text-align: right;
 }
 #txt2img_generate, #img2img_generate {
@@ -478,7 +477,7 @@ input[type="range"]{
     padding: 0;
 }
-#refresh_sd_model_checkpoint, #refresh_sd_hypernetwork, #refresh_train_hypernetwork_name, #refresh_train_embedding_name, #refresh_localization{
+#refresh_sd_model_checkpoint, #refresh_sd_hypernetwork, #refresh_train_hypernetwork_name, #refresh_train_embedding_name, #refresh_localization, #refresh_aesthetic_embeddings{
     max-width: 2.5em;
     min-width: 2.5em;
     height: 2.4em;
......
@@ -33,7 +33,7 @@ goto :launch
 :skip_venv
 :launch
-%PYTHON% launch.py
+%PYTHON% launch.py %*
 pause
 exit /b
......
@@ -4,7 +4,7 @@ import time
 import importlib
 import signal
 import threading
+from fastapi import FastAPI
 from fastapi.middleware.gzip import GZipMiddleware
 from modules.paths import script_path
@@ -31,7 +31,6 @@ from modules.paths import script_path
 from modules.shared import cmd_opts
 import modules.hypernetworks.hypernetwork
 queue_lock = threading.Lock()
@@ -87,10 +86,6 @@ def initialize():
     shared.opts.onchange("sd_hypernetwork", wrap_queued_call(lambda: modules.hypernetworks.hypernetwork.load_hypernetwork(shared.opts.sd_hypernetwork)))
     shared.opts.onchange("sd_hypernetwork_strength", modules.hypernetworks.hypernetwork.apply_strength)
-def webui():
-    initialize()
     # make the program just exit at ctrl+c without waiting for anything
     def sigint_handler(sig, frame):
         print(f'Interrupted with signal {sig} in {frame}')
@@ -98,10 +93,38 @@ def webui():
     signal.signal(signal.SIGINT, sigint_handler)
+def create_api(app):
+    from modules.api.api import Api
+    api = Api(app, queue_lock)
+    return api
+def wait_on_server(demo=None):
     while 1:
+        time.sleep(0.5)
+        if demo and getattr(demo, 'do_restart', False):
+            time.sleep(0.5)
+            demo.close()
+            time.sleep(0.5)
+            break
+def api_only():
+    initialize()
+    app = FastAPI()
+    app.add_middleware(GZipMiddleware, minimum_size=1000)
+    api = create_api(app)
+    api.launch(server_name="0.0.0.0" if cmd_opts.listen else "127.0.0.1", port=cmd_opts.port if cmd_opts.port else 7861)
+def webui():
+    launch_api = cmd_opts.api
+    initialize()
+    while 1:
         demo = modules.ui.create_ui(wrap_gradio_gpu_call=wrap_gradio_gpu_call)
         app, local_url, share_url = demo.launch(
             share=cmd_opts.share,
             server_name="0.0.0.0" if cmd_opts.listen else None,
@@ -111,16 +134,13 @@ def webui():
             inbrowser=cmd_opts.autolaunch,
             prevent_thread_lock=True
         )
         app.add_middleware(GZipMiddleware, minimum_size=1000)
-        while 1:
-            time.sleep(0.5)
-            if getattr(demo, 'do_restart', False):
-                time.sleep(0.5)
-                demo.close()
-                time.sleep(0.5)
-                break
+        if (launch_api):
+            create_api(app)
+        wait_on_server(demo)
         sd_samplers.set_samplers()
@@ -133,5 +153,10 @@ def webui():
         print('Restarting Gradio')
+task = []
 if __name__ == "__main__":
-    webui()
+    if cmd_opts.nowebui:
+        api_only()
+    else:
+        webui()
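Review note: the extracted `wait_on_server` loop polls a `do_restart` flag on the UI object and closes it once the flag is set. The sketch below reproduces that polling pattern so it can be exercised without Gradio. `FakeDemo` and the `poll_interval` parameter are assumptions added for illustration; the commit hard-codes 0.5 seconds and passes the real Gradio object.

```python
import time


class FakeDemo:
    """Hypothetical stand-in for the UI object; only what the loop touches."""

    def __init__(self):
        self.do_restart = False
        self.closed = False

    def close(self):
        self.closed = True


def wait_on_server(demo=None, poll_interval=0.5):
    """Block until demo.do_restart becomes true, then close the demo.

    With demo=None the loop never exits on its own, so the process
    then depends on a signal handler (or similar) to terminate.
    """
    while True:
        time.sleep(poll_interval)
        if demo and getattr(demo, 'do_restart', False):
            time.sleep(poll_interval)
            demo.close()
            time.sleep(poll_interval)
            break
```

Setting `do_restart = True` on the demo before or during the wait makes the call return, after which the caller can rebuild the UI.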
@@ -138,4 +138,4 @@ fi
 printf "\n%s\n" "${delimiter}"
 printf "Launching launch.py..."
 printf "\n%s\n" "${delimiter}"
-"${python_cmd}" "${LAUNCH_SCRIPT}"
+"${python_cmd}" "${LAUNCH_SCRIPT}" "$@"