feat: add a way to select textual inversions
This commit is contained in:
parent
45f5fca383
commit
2e7de16778
|
@ -36,6 +36,7 @@ def blend_img2img(
|
|||
params.scheduler,
|
||||
job.get_device(),
|
||||
params.lpw,
|
||||
params.inversion,
|
||||
)
|
||||
if params.lpw:
|
||||
logger.debug("using LPW pipeline for img2img")
|
||||
|
|
|
@ -75,6 +75,7 @@ def blend_inpaint(
|
|||
params.scheduler,
|
||||
job.get_device(),
|
||||
params.lpw,
|
||||
params.inversion,
|
||||
)
|
||||
|
||||
if params.lpw:
|
||||
|
|
|
@ -42,6 +42,7 @@ def source_txt2img(
|
|||
params.scheduler,
|
||||
job.get_device(),
|
||||
params.lpw,
|
||||
params.inversion,
|
||||
)
|
||||
|
||||
if params.lpw:
|
||||
|
|
|
@ -79,6 +79,7 @@ def upscale_outpaint(
|
|||
params.scheduler,
|
||||
job.get_device(),
|
||||
params.lpw,
|
||||
params.inversion,
|
||||
)
|
||||
if params.lpw:
|
||||
logger.debug("using LPW pipeline for inpaint")
|
||||
|
|
|
@ -18,6 +18,7 @@ from diffusers import (
|
|||
LMSDiscreteScheduler,
|
||||
PNDMScheduler,
|
||||
StableDiffusionPipeline,
|
||||
OnnxRuntimeModel,
|
||||
)
|
||||
|
||||
try:
|
||||
|
@ -138,8 +139,9 @@ def load_pipeline(
|
|||
scheduler_type: Any,
|
||||
device: DeviceParams,
|
||||
lpw: bool,
|
||||
inversion: Optional[str],
|
||||
):
|
||||
pipe_key = (pipeline, model, device.device, device.provider, lpw)
|
||||
pipe_key = (pipeline, model, device.device, device.provider, lpw, inversion)
|
||||
scheduler_key = (scheduler_type, model)
|
||||
|
||||
cache_pipe = server.cache.get("diffusion", pipe_key)
|
||||
|
@ -182,6 +184,17 @@ def load_pipeline(
|
|||
sess_options=device.sess_options(),
|
||||
subfolder="scheduler",
|
||||
)
|
||||
|
||||
text_encoder = None
|
||||
if inversion is not None:
|
||||
logger.debug("loading text encoder from %s", inversion)
|
||||
text_encoder = OnnxRuntimeModel.from_pretrained(
|
||||
inversion,
|
||||
provider=device.ort_provider(),
|
||||
sess_options=device.sess_options(),
|
||||
subfolder="text_encoder",
|
||||
)
|
||||
|
||||
pipe = pipeline.from_pretrained(
|
||||
model,
|
||||
custom_pipeline=custom_pipeline,
|
||||
|
@ -190,6 +203,7 @@ def load_pipeline(
|
|||
revision="onnx",
|
||||
safety_checker=None,
|
||||
scheduler=scheduler,
|
||||
text_encoder=text_encoder,
|
||||
)
|
||||
|
||||
if not server.show_progress:
|
||||
|
|
|
@ -36,6 +36,7 @@ def run_txt2img_pipeline(
|
|||
params.scheduler,
|
||||
job.get_device(),
|
||||
params.lpw,
|
||||
params.inversion,
|
||||
)
|
||||
progress = job.get_progress_callback()
|
||||
|
||||
|
@ -109,6 +110,7 @@ def run_img2img_pipeline(
|
|||
params.scheduler,
|
||||
job.get_device(),
|
||||
params.lpw,
|
||||
params.inversion,
|
||||
)
|
||||
progress = job.get_progress_callback()
|
||||
if params.lpw:
|
||||
|
|
|
@ -19,6 +19,8 @@ logger = getLogger(__name__)
|
|||
def hash_value(sha, param: Param):
|
||||
if param is None:
|
||||
return
|
||||
elif isinstance(param, bool):
|
||||
sha.update(bytearray(pack("!B", param)))
|
||||
elif isinstance(param, float):
|
||||
sha.update(bytearray(pack("!f", param)))
|
||||
elif isinstance(param, int):
|
||||
|
@ -73,8 +75,12 @@ def make_output_name(
|
|||
hash_value(sha, params.prompt)
|
||||
hash_value(sha, params.negative_prompt)
|
||||
hash_value(sha, params.cfg)
|
||||
hash_value(sha, params.steps)
|
||||
hash_value(sha, params.seed)
|
||||
hash_value(sha, params.steps)
|
||||
hash_value(sha, params.lpw)
|
||||
hash_value(sha, params.eta)
|
||||
hash_value(sha, params.batch)
|
||||
hash_value(sha, params.inversion)
|
||||
hash_value(sha, size.width)
|
||||
hash_value(sha, size.height)
|
||||
|
||||
|
|
|
@ -157,6 +157,7 @@ class ImageParams:
|
|||
lpw: bool = False,
|
||||
eta: float = 0.0,
|
||||
batch: int = 1,
|
||||
inversion: str = None,
|
||||
) -> None:
|
||||
self.model = model
|
||||
self.scheduler = scheduler
|
||||
|
@ -168,6 +169,7 @@ class ImageParams:
|
|||
self.lpw = lpw or False
|
||||
self.eta = eta
|
||||
self.batch = batch
|
||||
self.inversion = inversion
|
||||
|
||||
def tojson(self) -> Dict[str, Optional[Param]]:
|
||||
return {
|
||||
|
@ -181,6 +183,7 @@ class ImageParams:
|
|||
"lpw": self.lpw,
|
||||
"eta": self.eta,
|
||||
"batch": self.batch,
|
||||
"inversion": self.inversion,
|
||||
}
|
||||
|
||||
def with_args(self, **kwargs):
|
||||
|
@ -195,6 +198,7 @@ class ImageParams:
|
|||
kwargs.get("lpw", self.lpw),
|
||||
kwargs.get("eta", self.eta),
|
||||
kwargs.get("batch", self.batch),
|
||||
kwargs.get("inversion", self.inversion),
|
||||
)
|
||||
|
||||
|
||||
|
|
|
@ -159,6 +159,8 @@ def pipeline_from_request() -> Tuple[DeviceParams, ImageParams, Size]:
|
|||
scheduler = get_from_map(
|
||||
request.args, "scheduler", pipeline_schedulers, get_config_value("scheduler")
|
||||
)
|
||||
inversion = get_not_empty(request.args, "inversion", get_config_value("inversion"))
|
||||
inversion_path = get_model_path(inversion)
|
||||
|
||||
# image params
|
||||
prompt = get_not_empty(request.args, "prompt", get_config_value("prompt"))
|
||||
|
@ -240,6 +242,7 @@ def pipeline_from_request() -> Tuple[DeviceParams, ImageParams, Size]:
|
|||
lpw=lpw,
|
||||
negative_prompt=negative_prompt,
|
||||
batch=batch,
|
||||
inversion=inversion_path,
|
||||
)
|
||||
size = Size(width, height)
|
||||
return (device, params, size)
|
||||
|
|
|
@ -32,6 +32,8 @@ export interface ModelParams {
|
|||
* Use the long prompt weighting pipeline.
|
||||
*/
|
||||
lpw: boolean;
|
||||
|
||||
inversion: string;
|
||||
}
|
||||
|
||||
/**
|
||||
|
@ -183,6 +185,7 @@ export interface ReadyResponse {
|
|||
export interface ModelsResponse {
|
||||
diffusion: Array<string>;
|
||||
correction: Array<string>;
|
||||
inversion: Array<string>;
|
||||
upscaling: Array<string>;
|
||||
}
|
||||
|
||||
|
@ -325,6 +328,7 @@ export function appendModelToURL(url: URL, params: ModelParams) {
|
|||
url.searchParams.append('upscaling', params.upscaling);
|
||||
url.searchParams.append('correction', params.correction);
|
||||
url.searchParams.append('lpw', String(params.lpw));
|
||||
url.searchParams.append('inversion', params.inversion);
|
||||
}
|
||||
|
||||
/**
|
||||
|
|
|
@ -54,6 +54,21 @@ export function ModelControl() {
|
|||
});
|
||||
}}
|
||||
/>
|
||||
<QueryList
|
||||
id='inversion'
|
||||
labels={MODEL_LABELS}
|
||||
name='Textual Inversion'
|
||||
query={{
|
||||
result: models,
|
||||
selector: (result) => result.inversion,
|
||||
}}
|
||||
value={params.inversion}
|
||||
onChange={(inversion) => {
|
||||
setModel({
|
||||
inversion,
|
||||
});
|
||||
}}
|
||||
/>
|
||||
<QueryList
|
||||
id='upscaling'
|
||||
labels={MODEL_LABELS}
|
||||
|
|
|
@ -487,6 +487,7 @@ export function createStateSlices(server: ServerParams) {
|
|||
platform: server.platform.default,
|
||||
upscaling: server.upscaling.default,
|
||||
correction: server.correction.default,
|
||||
inversion: server.inversion.default,
|
||||
lpw: false,
|
||||
},
|
||||
setModel(params) {
|
||||
|
|
Loading…
Reference in New Issue