Commit 0f28aee9 authored by Aarni Koskela's avatar Aarni Koskela

Refactor gradio auth

parent 674e80c6
...@@ -7,6 +7,7 @@ import re ...@@ -7,6 +7,7 @@ import re
import warnings import warnings
import json import json
from threading import Thread from threading import Thread
from typing import Iterable
from fastapi import FastAPI, Response from fastapi import FastAPI, Response
from fastapi.middleware.cors import CORSMiddleware from fastapi.middleware.cors import CORSMiddleware
...@@ -178,6 +179,32 @@ def validate_tls_options(): ...@@ -178,6 +179,32 @@ def validate_tls_options():
startup_timer.record("TLS") startup_timer.record("TLS")
def get_gradio_auth_creds() -> Iterable[tuple[str, ...]]:
"""
Convert the gradio_auth and gradio_auth_path commandline arguments into
an iterable of (username, password) tuples.
"""
def process_credential_line(s) -> tuple[str, ...] | None:
s = s.strip()
if not s:
return None
return tuple(s.split(':', 1))
if cmd_opts.gradio_auth:
for cred in cmd_opts.gradio_auth.split(','):
cred = process_credential_line(cred)
if cred:
yield cred
if cmd_opts.gradio_auth_path:
with open(cmd_opts.gradio_auth_path, 'r', encoding="utf8") as file:
for line in file.readlines():
for cred in line.strip().split(','):
cred = process_credential_line(cred)
if cred:
yield cred
def configure_sigint_handler(): def configure_sigint_handler():
# make the program just exit at ctrl+c without waiting for anything # make the program just exit at ctrl+c without waiting for anything
def sigint_handler(sig, frame): def sigint_handler(sig, frame):
...@@ -316,13 +343,7 @@ def webui(): ...@@ -316,13 +343,7 @@ def webui():
if not cmd_opts.no_gradio_queue: if not cmd_opts.no_gradio_queue:
shared.demo.queue(64) shared.demo.queue(64)
gradio_auth_creds = [] gradio_auth_creds = list(get_gradio_auth_creds()) or None
if cmd_opts.gradio_auth:
gradio_auth_creds += [x.strip() for x in cmd_opts.gradio_auth.strip('"').replace('\n', '').split(',') if x.strip()]
if cmd_opts.gradio_auth_path:
with open(cmd_opts.gradio_auth_path, 'r', encoding="utf8") as file:
for line in file.readlines():
gradio_auth_creds += [x.strip() for x in line.split(',') if x.strip()]
# this restores the missing /docs endpoint # this restores the missing /docs endpoint
if launch_api and not hasattr(FastAPI, 'original_setup'): if launch_api and not hasattr(FastAPI, 'original_setup'):
...@@ -343,7 +364,7 @@ def webui(): ...@@ -343,7 +364,7 @@ def webui():
ssl_certfile=cmd_opts.tls_certfile, ssl_certfile=cmd_opts.tls_certfile,
ssl_verify=cmd_opts.disable_tls_verify, ssl_verify=cmd_opts.disable_tls_verify,
debug=cmd_opts.gradio_debug, debug=cmd_opts.gradio_debug,
auth=[tuple(cred.split(':')) for cred in gradio_auth_creds] if gradio_auth_creds else None, auth=gradio_auth_creds,
inbrowser=cmd_opts.autolaunch, inbrowser=cmd_opts.autolaunch,
prevent_thread_lock=True, prevent_thread_lock=True,
allowed_paths=cmd_opts.gradio_allowed_path, allowed_paths=cmd_opts.gradio_allowed_path,
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment