Commit c1b0e089 authored by kurumuz's avatar kurumuz

vectoradjustprior support

parent d22bf3c7
...@@ -92,6 +92,7 @@ def init_config_model(): ...@@ -92,6 +92,7 @@ def init_config_model():
# Resolve where we get our model and data from. # Resolve where we get our model and data from.
config.model_path = os.getenv('MODEL_PATH', None) config.model_path = os.getenv('MODEL_PATH', None)
config.prior_path = os.getenv('PRIOR_PATH', None)
# Misc settings # Misc settings
config.model_alias = os.getenv('MODEL_ALIAS') config.model_alias = os.getenv('MODEL_ALIAS')
......
...@@ -99,6 +99,29 @@ def decode_image(image, model): ...@@ -99,6 +99,29 @@ def decode_image(image, model):
image = custom_to_pil(image) image = custom_to_pil(image)
return image return image
class VectorAdjustPrior(nn.Module):
def __init__(self, hidden_size, inter_dim=64):
super().__init__()
self.vector_proj = nn.Linear(hidden_size*2, inter_dim, bias=True)
self.out_proj = nn.Linear(hidden_size+inter_dim, hidden_size, bias=True)
def forward(self, z):
b, s = z.shape[0:2]
x1 = torch.mean(z, dim=1).repeat(s, 1)
x2 = z.reshape(b*s, -1)
x = torch.cat((x1, x2), dim=1)
x = self.vector_proj(x)
x = torch.cat((x2, x), dim=1)
x = self.out_proj(x)
x = x.reshape(b, s, -1)
return x
@classmethod
def load_model(cls, model_path, hidden_size=768, inter_dim=64):
model = cls(hidden_size=hidden_size, inter_dim=inter_dim)
model.load_state_dict(torch.load(model_path)["state_dict"])
return model
class StableInterface(nn.Module): class StableInterface(nn.Module):
def __init__(self, model, thresholder = None): def __init__(self, model, thresholder = None):
super().__init__() super().__init__()
...@@ -145,6 +168,8 @@ class StableDiffusionModel(nn.Module): ...@@ -145,6 +168,8 @@ class StableDiffusionModel(nn.Module):
'k_dpm_2_ancestral': K.sampling.sample_dpm_2_ancestral, 'k_dpm_2_ancestral': K.sampling.sample_dpm_2_ancestral,
'k_lms': K.sampling.sample_lms, 'k_lms': K.sampling.sample_lms,
} }
if config.prior_path:
self.prior = VectorAdjustPrior.load_model(config.prior_path, hidden_size=model_config['hidden_size'])
def from_folder(self, folder): def from_folder(self, folder):
folder = Path(folder) folder = Path(folder)
...@@ -213,6 +238,8 @@ class StableDiffusionModel(nn.Module): ...@@ -213,6 +238,8 @@ class StableDiffusionModel(nn.Module):
prompt = [request.prompt] * request.n_samples prompt = [request.prompt] * request.n_samples
prompt_condition = prompt_mixing(self.model, prompt[0], request.n_samples) prompt_condition = prompt_mixing(self.model, prompt[0], request.n_samples)
if self.prior and request.mitigate:
prompt_condition = self.prior(prompt_condition)
uc = None uc = None
if request.scale != 1.0: if request.scale != 1.0:
...@@ -253,7 +280,7 @@ class StableDiffusionModel(nn.Module): ...@@ -253,7 +280,7 @@ class StableDiffusionModel(nn.Module):
else: else:
start_code = start_code * sigmas[0] start_code = start_code * sigmas[0]
extra_args = {'cond': prompt_condition, 'uncond': uc, 'cond_scale': request.scale} extra_args = {'cond': prompt_condition, 'uncond': uc, 'cond_scale': request.scale}
samples = self.sampler_map[request.sampler](self.k_model, start_code, sigmas, extra_args=extra_args) samples = self.sampler_map[request.sampler](self.k_model, start_code, sigmas, extra_args=extra_args)
......
...@@ -74,6 +74,7 @@ class GenerationRequest(BaseModel): ...@@ -74,6 +74,7 @@ class GenerationRequest(BaseModel):
stage_two_seed: int = None stage_two_seed: int = None
strength: float = 0.69 strength: float = 0.69
noise: float = 0.667 noise: float = 0.667
mitigate: bool = False
class GenerationOutput(BaseModel): class GenerationOutput(BaseModel):
output: List[str] output: List[str]
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment