Commit 17e92f46 authored by kurumuz's avatar kurumuz

vanilla check

parent f077911c
......@@ -213,6 +213,10 @@ class StableDiffusionModel(nn.Module):
@torch.autocast("cuda", enabled=True, dtype=torch.float16)
def sample(self, request):
if request.module is not None:
if request.module == "vanilla":
pass
else:
module = self.premodules[request.module]
CrossAttention.set_hypernetwork(module)
......
......@@ -78,6 +78,7 @@ class GenerationRequest(BaseModel):
noise: float = 0.667
mitigate: bool = False
module: str = None
masks: List[str] = None
class GenerationOutput(BaseModel):
output: List[str]
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment