Commit 5524301a authored by AUTOMATIC1111's avatar AUTOMATIC1111 Committed by GitHub

Merge pull request #9169 from space-nuko/extension-settings-backup

Extension settings backup/restore feature
parents c018eefe 78d0ee3b
...@@ -33,3 +33,4 @@ notification.mp3 ...@@ -33,3 +33,4 @@ notification.mp3
/test/stdout.txt /test/stdout.txt
/test/stderr.txt /test/stderr.txt
/cache.json* /cache.json*
/config_states/
...@@ -47,3 +47,25 @@ function install_extension_from_index(button, url){ ...@@ -47,3 +47,25 @@ function install_extension_from_index(button, url){
gradioApp().querySelector('#install_extension_button').click() gradioApp().querySelector('#install_extension_button').click()
} }
function config_state_confirm_restore(_, config_state_name, config_restore_type) {
if (config_state_name == "Current") {
return [false, config_state_name, config_restore_type];
}
let restored = "";
if (config_restore_type == "extensions") {
restored = "all saved extension versions";
} else if (config_restore_type == "webui") {
restored = "the webui version";
} else {
restored = "the webui version and all saved extension versions";
}
let confirmed = confirm("Are you sure you want to restore from this state?\nThis will reset " + restored + ".");
if (confirmed) {
restart_reload();
gradioApp().querySelectorAll('#extensions .extension_status').forEach(function(x){
x.innerHTML = "Loading..."
})
}
return [confirmed, config_state_name, config_restore_type];
}
"""
Supports saving and restoring webui and extensions from a known working set of commits
"""
import os
import sys
import traceback
import json
import time
import tqdm
from datetime import datetime
from collections import OrderedDict
import git
from modules import shared, extensions
from modules.paths_internal import extensions_dir, extensions_builtin_dir, script_path, config_states_dir
all_config_states = OrderedDict()
def list_config_states():
global all_config_states
all_config_states.clear()
os.makedirs(config_states_dir, exist_ok=True)
config_states = []
for filename in os.listdir(config_states_dir):
if filename.endswith(".json"):
path = os.path.join(config_states_dir, filename)
with open(path, "r", encoding="utf-8") as f:
j = json.load(f)
j["filepath"] = path
config_states.append(j)
config_states = list(sorted(config_states, key=lambda cs: cs["created_at"], reverse=True))
for cs in config_states:
timestamp = time.asctime(time.gmtime(cs["created_at"]))
name = cs.get("name", "Config")
full_name = f"{name}: {timestamp}"
all_config_states[full_name] = cs
return all_config_states
def get_webui_config():
webui_repo = None
try:
if os.path.exists(os.path.join(script_path, ".git")):
webui_repo = git.Repo(script_path)
except Exception:
print(f"Error reading webui git info from {script_path}:", file=sys.stderr)
print(traceback.format_exc(), file=sys.stderr)
webui_remote = None
webui_commit_hash = None
webui_commit_date = None
webui_branch = None
if webui_repo and not webui_repo.bare:
try:
webui_remote = next(webui_repo.remote().urls, None)
head = webui_repo.head.commit
webui_commit_date = webui_repo.head.commit.committed_date
webui_commit_hash = head.hexsha
webui_branch = webui_repo.active_branch.name
except Exception:
webui_remote = None
return {
"remote": webui_remote,
"commit_hash": webui_commit_hash,
"commit_date": webui_commit_date,
"branch": webui_branch,
}
def get_extension_config():
ext_config = {}
for ext in extensions.extensions:
entry = {
"name": ext.name,
"path": ext.path,
"enabled": ext.enabled,
"is_builtin": ext.is_builtin,
"remote": ext.remote,
"commit_hash": ext.commit_hash,
"commit_date": ext.commit_date,
"branch": ext.branch,
"have_info_from_repo": ext.have_info_from_repo
}
ext_config[ext.name] = entry
return ext_config
def get_config():
creation_time = datetime.now().timestamp()
webui_config = get_webui_config()
ext_config = get_extension_config()
return {
"created_at": creation_time,
"webui": webui_config,
"extensions": ext_config
}
def restore_webui_config(config):
print("* Restoring webui state...")
if "webui" not in config:
print("Error: No webui data saved to config")
return
webui_config = config["webui"]
if "commit_hash" not in webui_config:
print("Error: No commit saved to webui config")
return
webui_commit_hash = webui_config.get("commit_hash", None)
webui_repo = None
try:
if os.path.exists(os.path.join(script_path, ".git")):
webui_repo = git.Repo(script_path)
except Exception:
print(f"Error reading webui git info from {script_path}:", file=sys.stderr)
print(traceback.format_exc(), file=sys.stderr)
return
try:
webui_repo.git.fetch(all=True)
webui_repo.git.reset(webui_commit_hash, hard=True)
print(f"* Restored webui to commit {webui_commit_hash}.")
except Exception:
print(f"Error restoring webui to commit {webui_commit_hash}:", file=sys.stderr)
print(traceback.format_exc(), file=sys.stderr)
def restore_extension_config(config):
print("* Restoring extension state...")
if "extensions" not in config:
print("Error: No extension data saved to config")
return
ext_config = config["extensions"]
results = []
disabled = []
for ext in tqdm.tqdm(extensions.extensions):
if ext.is_builtin:
continue
ext.read_info_from_repo()
current_commit = ext.commit_hash
if ext.name not in ext_config:
ext.disabled = True
disabled.append(ext.name)
results.append((ext, current_commit[:8], False, "Saved extension state not found in config, marking as disabled"))
continue
entry = ext_config[ext.name]
if "commit_hash" in entry and entry["commit_hash"]:
try:
ext.fetch_and_reset_hard(entry["commit_hash"])
ext.read_info_from_repo()
if current_commit != entry["commit_hash"]:
results.append((ext, current_commit[:8], True, entry["commit_hash"][:8]))
except Exception as ex:
results.append((ext, current_commit[:8], False, ex))
else:
results.append((ext, current_commit[:8], False, "No commit hash found in config"))
if not entry.get("enabled", False):
ext.disabled = True
disabled.append(ext.name)
else:
ext.disabled = False
shared.opts.disabled_extensions = disabled
shared.opts.save(shared.config_filename)
print("* Finished restoring extensions. Results:")
for ext, prev_commit, success, result in results:
if success:
print(f" + {ext.name}: {prev_commit} -> {result}")
else:
print(f" ! {ext.name}: FAILURE ({result})")
...@@ -3,10 +3,11 @@ import sys ...@@ -3,10 +3,11 @@ import sys
import traceback import traceback
import time import time
from datetime import datetime
import git import git
from modules import shared from modules import shared
from modules.paths_internal import extensions_dir, extensions_builtin_dir from modules.paths_internal import extensions_dir, extensions_builtin_dir, script_path
extensions = [] extensions = []
...@@ -31,12 +32,15 @@ class Extension: ...@@ -31,12 +32,15 @@ class Extension:
self.status = '' self.status = ''
self.can_update = False self.can_update = False
self.is_builtin = is_builtin self.is_builtin = is_builtin
self.commit_hash = ''
self.commit_date = None
self.version = '' self.version = ''
self.branch = None
self.remote = None self.remote = None
self.have_info_from_repo = False self.have_info_from_repo = False
def read_info_from_repo(self): def read_info_from_repo(self):
if self.have_info_from_repo: if self.is_builtin or self.have_info_from_repo:
return return
self.have_info_from_repo = True self.have_info_from_repo = True
...@@ -56,10 +60,15 @@ class Extension: ...@@ -56,10 +60,15 @@ class Extension:
self.status = 'unknown' self.status = 'unknown'
self.remote = next(repo.remote().urls, None) self.remote = next(repo.remote().urls, None)
head = repo.head.commit head = repo.head.commit
ts = time.asctime(time.gmtime(repo.head.commit.committed_date)) self.commit_date = repo.head.commit.committed_date
self.version = f'{head.hexsha[:8]} ({ts})' ts = time.asctime(time.gmtime(self.commit_date))
if repo.active_branch:
except Exception: self.branch = repo.active_branch.name
self.commit_hash = head.hexsha
self.version = f'{self.commit_hash[:8]} ({ts})'
except Exception as ex:
print(f"Failed reading extension data from Git repository ({self.name}): {ex}", file=sys.stderr)
self.remote = None self.remote = None
def list_files(self, subdir, extension): def list_files(self, subdir, extension):
...@@ -82,18 +91,30 @@ class Extension: ...@@ -82,18 +91,30 @@ class Extension:
for fetch in repo.remote().fetch(dry_run=True): for fetch in repo.remote().fetch(dry_run=True):
if fetch.flags != fetch.HEAD_UPTODATE: if fetch.flags != fetch.HEAD_UPTODATE:
self.can_update = True self.can_update = True
self.status = "behind" self.status = "new commits"
return
try:
origin = repo.rev_parse('origin')
if repo.head.commit != origin:
self.can_update = True
self.status = "behind HEAD"
return return
except Exception:
self.can_update = False
self.status = "unknown (remote error)"
return
self.can_update = False self.can_update = False
self.status = "latest" self.status = "latest"
def fetch_and_reset_hard(self): def fetch_and_reset_hard(self, commit='origin'):
repo = git.Repo(self.path) repo = git.Repo(self.path)
# Fix: `error: Your local changes to the following files would be overwritten by merge`, # Fix: `error: Your local changes to the following files would be overwritten by merge`,
# because WSL2 Docker set 755 file permissions instead of 644, this results to the error. # because WSL2 Docker set 755 file permissions instead of 644, this results to the error.
repo.git.fetch(all=True) repo.git.fetch(all=True)
repo.git.reset('origin', hard=True) repo.git.reset(commit, hard=True)
self.have_info_from_repo = False
def list_extensions(): def list_extensions():
......
...@@ -20,3 +20,4 @@ data_path = cmd_opts_pre.data_dir ...@@ -20,3 +20,4 @@ data_path = cmd_opts_pre.data_dir
models_path = os.path.join(data_path, "models") models_path = os.path.join(data_path, "models")
extensions_dir = os.path.join(data_path, "extensions") extensions_dir = os.path.join(data_path, "extensions")
extensions_builtin_dir = os.path.join(script_path, "extensions-builtin") extensions_builtin_dir = os.path.join(script_path, "extensions-builtin")
config_states_dir = os.path.join(script_path, "config_states")
...@@ -449,6 +449,7 @@ options_templates.update(options_section(('postprocessing', "Postprocessing"), { ...@@ -449,6 +449,7 @@ options_templates.update(options_section(('postprocessing', "Postprocessing"), {
options_templates.update(options_section((None, "Hidden options"), { options_templates.update(options_section((None, "Hidden options"), {
"disabled_extensions": OptionInfo([], "Disable these extensions"), "disabled_extensions": OptionInfo([], "Disable these extensions"),
"disable_all_extensions": OptionInfo("none", "Disable all extensions (preserves the list of disabled extensions)", gr.Radio, {"choices": ["none", "extra", "all"]}), "disable_all_extensions": OptionInfo("none", "Disable all extensions (preserves the list of disabled extensions)", gr.Radio, {"choices": ["none", "extra", "all"]}),
"restore_config_state_file": OptionInfo("", "Config state file to restore from, under 'config-states/' folder"),
"sd_checkpoint_hash": OptionInfo("", "SHA256 hash of the current checkpoint"), "sd_checkpoint_hash": OptionInfo("", "SHA256 hash of the current checkpoint"),
})) }))
......
This diff is collapsed.
...@@ -5,6 +5,7 @@ import importlib ...@@ -5,6 +5,7 @@ import importlib
import signal import signal
import re import re
import warnings import warnings
import json
from fastapi import FastAPI from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware from fastapi.middleware.cors import CORSMiddleware
from fastapi.middleware.gzip import GZipMiddleware from fastapi.middleware.gzip import GZipMiddleware
...@@ -40,7 +41,7 @@ if ".dev" in torch.__version__ or "+git" in torch.__version__: ...@@ -40,7 +41,7 @@ if ".dev" in torch.__version__ or "+git" in torch.__version__:
torch.__long_version__ = torch.__version__ torch.__long_version__ = torch.__version__
torch.__version__ = re.search(r'[\d.]+[\d]', torch.__version__).group(0) torch.__version__ = re.search(r'[\d.]+[\d]', torch.__version__).group(0)
from modules import shared, devices, sd_samplers, upscaler, extensions, localization, ui_tempdir, ui_extra_networks from modules import shared, devices, sd_samplers, upscaler, extensions, localization, ui_tempdir, ui_extra_networks, config_states
import modules.codeformer_model as codeformer import modules.codeformer_model as codeformer
import modules.face_restoration import modules.face_restoration
import modules.gfpgan_model as gfpgan import modules.gfpgan_model as gfpgan
...@@ -150,6 +151,19 @@ def initialize(): ...@@ -150,6 +151,19 @@ def initialize():
localization.list_localizations(cmd_opts.localizations_dir) localization.list_localizations(cmd_opts.localizations_dir)
startup_timer.record("list extensions") startup_timer.record("list extensions")
config_state_file = shared.opts.restore_config_state_file
shared.opts.restore_config_state_file = ""
shared.opts.save(shared.config_filename)
if os.path.isfile(config_state_file):
print(f"*** About to restore extension state from file: {config_state_file}")
with open(config_state_file, "r", encoding="utf-8") as f:
config_state = json.load(f)
config_states.restore_extension_config(config_state)
startup_timer.record("restore extension config")
elif config_state_file:
print(f"!!! Config state backup not found: {config_state_file}")
if cmd_opts.ui_debug_mode: if cmd_opts.ui_debug_mode:
shared.sd_upscalers = upscaler.UpscalerLanczos().scalers shared.sd_upscalers = upscaler.UpscalerLanczos().scalers
modules.scripts.load_scripts() modules.scripts.load_scripts()
...@@ -344,6 +358,19 @@ def webui(): ...@@ -344,6 +358,19 @@ def webui():
extensions.list_extensions() extensions.list_extensions()
startup_timer.record("list extensions") startup_timer.record("list extensions")
config_state_file = shared.opts.restore_config_state_file
shared.opts.restore_config_state_file = ""
shared.opts.save(shared.config_filename)
if os.path.isfile(config_state_file):
print(f"*** About to restore extension state from file: {config_state_file}")
with open(config_state_file, "r", encoding="utf-8") as f:
config_state = json.load(f)
config_states.restore_extension_config(config_state)
startup_timer.record("restore extension config")
elif config_state_file:
print(f"!!! Config state backup not found: {config_state_file}")
localization.list_localizations(cmd_opts.localizations_dir) localization.list_localizations(cmd_opts.localizations_dir)
modelloader.forbid_loaded_nonbuiltin_upscalers() modelloader.forbid_loaded_nonbuiltin_upscalers()
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment