Commit e23f6880 authored by kurumuz's avatar kurumuz

s3 maybe works

parent be2c3628
...@@ -140,9 +140,17 @@ def init_config_model(): ...@@ -140,9 +140,17 @@ def init_config_model():
logger.info(f"MODEL: {config.model_name}") logger.info(f"MODEL: {config.model_name}")
#S3 auth stuff #S3 auth stuff
config.s3_access_key = os.environ['S3_ACCESS_KEY'] config.s3_access_key = os.environ('S3_ACCESS_KEY', None)
config.s3_secret_key = os.environ['S3_SECRET_KEY'] config.s3_secret_key = os.environ('S3_SECRET_KEY', None)
config.s3_bucket = os.environ['S3_BUCKET'] config.s3_bucket = os.environ('S3_BUCKET', None)
config.s3_file = os.environ('S3_FILE', None)
config.s3_endpoint = os.environ('S3_ENDPOINT', None)
try:
config.model_path = f"https://{config.s3_endpoint}/{config.s3_bucket}/{config.s3_file}"
logger.info(f"Path is set to S3 {config.model_path}")
except:
pass
# Resolve where we get our model and data from. # Resolve where we get our model and data from.
config.model_path = os.getenv('MODEL_PATH', None) config.model_path = os.getenv('MODEL_PATH', None)
......
...@@ -342,7 +342,12 @@ class StableDiffusionModel(nn.Module): ...@@ -342,7 +342,12 @@ class StableDiffusionModel(nn.Module):
model_config = requests.get(default_config, stream='True').raw model_config = requests.get(default_config, stream='True').raw
model_config = OmegaConf.load(model_config) model_config = OmegaConf.load(model_config)
print(f"Downloading model from {url}") print(f"Downloading model from {url}")
tensor_loader = web.CURLStreamFile(url) headers = web.get_s3_secret_headers(endpoint=self.config.s3_endpoint,
access_key=self.config.s3_access_key,
secret_key=self.config.s3_secret_key,
s3_file=self.config.s3_file
)
tensor_loader = web.CURLStreamFile(url, headers=headers)
if not default_config.is_file(): if not default_config.is_file():
raise Exception("Default config to load not found! Either give a folder on MODEL_PATH or specify a config to use with this checkpoint on DEFAULT_CONFIG") raise Exception("Default config to load not found! Either give a folder on MODEL_PATH or specify a config to use with this checkpoint on DEFAULT_CONFIG")
model_config = OmegaConf.load(default_config) model_config = OmegaConf.load(default_config)
......
...@@ -92,8 +92,6 @@ class CURLS3StreamFile(object): ...@@ -92,8 +92,6 @@ class CURLS3StreamFile(object):
subprocess_list.append(url) subprocess_list.append(url)
subprocess_list.extend(['-H', 'Accept-Encoding: identity']) subprocess_list.extend(['-H', 'Accept-Encoding: identity'])
subprocess_list.append('-s') subprocess_list.append('-s')
#subprocess_list.append('-o test124.yaml')
print(subprocess_list)
self._curl = subprocess.Popen(subprocess_list, self._curl = subprocess.Popen(subprocess_list,
stdout=subprocess.PIPE, stdout=subprocess.PIPE,
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment