From 7b0095a66583d34bbfa61f0a7d70fa5ead1f5c5e Mon Sep 17 00:00:00 2001 From: Sean Sube Date: Sat, 22 Apr 2023 10:05:58 -0500 Subject: [PATCH] feat(api): add support for negative embeds (#348) --- api/onnx_web/diffusers/run.py | 9 +++++++++ api/onnx_web/params.py | 5 +++++ 2 files changed, 14 insertions(+) diff --git a/api/onnx_web/diffusers/run.py b/api/onnx_web/diffusers/run.py index 12587815..75af59bd 100644 --- a/api/onnx_web/diffusers/run.py +++ b/api/onnx_web/diffusers/run.py @@ -35,6 +35,15 @@ def parse_prompt(params: ImageParams) -> Tuple[List[Tuple[str, float]], List[Tup prompt, inversions = get_inversions_from_prompt(prompt) params.prompt = prompt + if params.input_negative_prompt is not None: + neg_prompt, neg_loras = get_loras_from_prompt(params.input_negative_prompt) + neg_prompt, neg_inversions = get_inversions_from_prompt(neg_prompt) + params.negative_prompt = neg_prompt + + # TODO: check whether these need to be * -1 + loras.extend(neg_loras) + inversions.extend(neg_inversions) + return loras, inversions diff --git a/api/onnx_web/params.py b/api/onnx_web/params.py index ac2c0645..eaaa774c 100644 --- a/api/onnx_web/params.py +++ b/api/onnx_web/params.py @@ -174,6 +174,7 @@ class ImageParams: batch: int control: Optional[NetworkModel] input_prompt: str + input_negative_prompt: str def __init__( self, @@ -189,6 +190,7 @@ class ImageParams: batch: int = 1, control: Optional[NetworkModel] = None, input_prompt: Optional[str] = None, + input_negative_prompt: Optional[str] = None, ) -> None: self.model = model self.pipeline = pipeline @@ -202,6 +204,7 @@ class ImageParams: self.batch = batch self.control = control self.input_prompt = input_prompt or prompt + self.input_negative_prompt = input_negative_prompt or negative_prompt def lpw(self): return self.pipeline == "lpw" @@ -220,6 +223,7 @@ class ImageParams: "batch": self.batch, "control": self.control.name if self.control is not None else "", "input_prompt": self.input_prompt, + "input_negative_prompt": self.input_negative_prompt, } def with_args(self, **kwargs): @@ -236,6 +240,7 @@ class ImageParams: kwargs.get("batch", self.batch), kwargs.get("control", self.control), kwargs.get("input_prompt", self.input_prompt), + kwargs.get("input_negative_prompt", self.input_negative_prompt), )