diff --git a/api/onnx_web/chain/text_prompt.py b/api/onnx_web/chain/text_prompt.py
index bfa16237..1dc69ee4 100644
--- a/api/onnx_web/chain/text_prompt.py
+++ b/api/onnx_web/chain/text_prompt.py
@@ -31,10 +31,10 @@ class TextPromptStage(BaseStage):
         sources: StageResult,
         *,
         callback: Optional[ProgressCallback] = None,
-        prompt_filter: str = "Gustavosta/MagicPrompt-Stable-Diffusion",
+        prompt_filter: str,
         remove_tokens: Optional[str] = None,
         add_suffix: Optional[str] = None,
-        min_length: int = 75,
+        min_length: int = 150,
         **kwargs,
     ) -> StageResult:
         device = worker.device.torch_str()
diff --git a/api/onnx_web/diffusers/run.py b/api/onnx_web/diffusers/run.py
index fac07cef..9e0cc6f5 100644
--- a/api/onnx_web/diffusers/run.py
+++ b/api/onnx_web/diffusers/run.py
@@ -74,6 +74,10 @@ def add_prompt_filter(
             pipeline.stage(
                 TextPromptStage(),
                 StageParams(),
+                prompt_filter=experimental.prompt_editing.model,
+                remove_tokens=experimental.prompt_editing.remove_tokens,
+                add_suffix=experimental.prompt_editing.add_suffix,
+                # TODO: add min length to experimental params
             )
         else:
             logger.warning("prompt editing is not supported by the server")