pass prompt editing params to stage
This commit is contained in:
parent
0bbdc3b96b
commit
cdef20ffb6
|
@ -31,10 +31,10 @@ class TextPromptStage(BaseStage):
|
||||||
sources: StageResult,
|
sources: StageResult,
|
||||||
*,
|
*,
|
||||||
callback: Optional[ProgressCallback] = None,
|
callback: Optional[ProgressCallback] = None,
|
||||||
prompt_filter: str = "Gustavosta/MagicPrompt-Stable-Diffusion",
|
prompt_filter: str,
|
||||||
remove_tokens: Optional[str] = None,
|
remove_tokens: Optional[str] = None,
|
||||||
add_suffix: Optional[str] = None,
|
add_suffix: Optional[str] = None,
|
||||||
min_length: int = 75,
|
min_length: int = 150,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
) -> StageResult:
|
) -> StageResult:
|
||||||
device = worker.device.torch_str()
|
device = worker.device.torch_str()
|
||||||
|
|
|
@ -74,6 +74,10 @@ def add_prompt_filter(
|
||||||
pipeline.stage(
|
pipeline.stage(
|
||||||
TextPromptStage(),
|
TextPromptStage(),
|
||||||
StageParams(),
|
StageParams(),
|
||||||
|
prompt_filter=experimental.prompt_editing.model,
|
||||||
|
remove_tokens=experimental.prompt_editing.remove_tokens,
|
||||||
|
add_suffix=experimental.prompt_editing.add_suffix,
|
||||||
|
# TODO: add min length to experimental params
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
logger.warning("prompt editing is not supported by the server")
|
logger.warning("prompt editing is not supported by the server")
|
||||||
|
|
Loading…
Reference in New Issue