1
0
Fork 0

pass prompt editing params to stage

This commit is contained in:
Sean Sube 2024-02-17 15:54:11 -06:00
parent 0bbdc3b96b
commit cdef20ffb6
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
2 changed files with 6 additions and 2 deletions

View File

@ -31,10 +31,10 @@ class TextPromptStage(BaseStage):
sources: StageResult, sources: StageResult,
*, *,
callback: Optional[ProgressCallback] = None, callback: Optional[ProgressCallback] = None,
prompt_filter: str = "Gustavosta/MagicPrompt-Stable-Diffusion", prompt_filter: str,
remove_tokens: Optional[str] = None, remove_tokens: Optional[str] = None,
add_suffix: Optional[str] = None, add_suffix: Optional[str] = None,
min_length: int = 75, min_length: int = 150,
**kwargs, **kwargs,
) -> StageResult: ) -> StageResult:
device = worker.device.torch_str() device = worker.device.torch_str()

View File

@ -74,6 +74,10 @@ def add_prompt_filter(
pipeline.stage( pipeline.stage(
TextPromptStage(), TextPromptStage(),
StageParams(), StageParams(),
prompt_filter=experimental.prompt_editing.model,
remove_tokens=experimental.prompt_editing.remove_tokens,
add_suffix=experimental.prompt_editing.add_suffix,
# TODO: add min length to experimental params
) )
else: else:
logger.warning("prompt editing is not supported by the server") logger.warning("prompt editing is not supported by the server")