Commit 17e92f46 authored by kurumuz's avatar kurumuz

vanilla check

parent f077911c
...@@ -213,6 +213,10 @@ class StableDiffusionModel(nn.Module): ...@@ -213,6 +213,10 @@ class StableDiffusionModel(nn.Module):
@torch.autocast("cuda", enabled=True, dtype=torch.float16) @torch.autocast("cuda", enabled=True, dtype=torch.float16)
def sample(self, request): def sample(self, request):
if request.module is not None: if request.module is not None:
if request.module == "vanilla":
pass
else:
module = self.premodules[request.module] module = self.premodules[request.module]
CrossAttention.set_hypernetwork(module) CrossAttention.set_hypernetwork(module)
......
...@@ -78,6 +78,7 @@ class GenerationRequest(BaseModel): ...@@ -78,6 +78,7 @@ class GenerationRequest(BaseModel):
noise: float = 0.667 noise: float = 0.667
mitigate: bool = False mitigate: bool = False
module: str = None module: str = None
masks: List[str] = None
class GenerationOutput(BaseModel): class GenerationOutput(BaseModel):
output: List[str] output: List[str]
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment