apply lint
This commit is contained in:
parent
7d8819ef87
commit
0a5f725efa
|
@ -92,7 +92,6 @@ class ChainPipeline:
|
|||
|
||||
return steps
|
||||
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
worker: WorkerContext,
|
||||
|
|
|
@ -50,8 +50,8 @@ class BlendGridStage(BaseStage):
|
|||
return [*sources, output]
|
||||
|
||||
def outputs(
|
||||
self,
|
||||
params: ImageParams,
|
||||
sources: int,
|
||||
self,
|
||||
params: ImageParams,
|
||||
sources: int,
|
||||
) -> int:
|
||||
return sources + 1
|
||||
return sources + 1
|
||||
|
|
|
@ -105,15 +105,15 @@ class BlendImg2ImgStage(BaseStage):
|
|||
return outputs
|
||||
|
||||
def steps(
|
||||
self,
|
||||
params: ImageParams,
|
||||
*args,
|
||||
self,
|
||||
params: ImageParams,
|
||||
*args,
|
||||
) -> int:
|
||||
return params.steps # TODO: multiply by strength
|
||||
return params.steps # TODO: multiply by strength
|
||||
|
||||
def outputs(
|
||||
self,
|
||||
params: ImageParams,
|
||||
sources: int,
|
||||
self,
|
||||
params: ImageParams,
|
||||
sources: int,
|
||||
) -> int:
|
||||
return sources + 1
|
||||
|
|
|
@ -44,8 +44,8 @@ class SourceNoiseStage(BaseStage):
|
|||
return outputs
|
||||
|
||||
def outputs(
|
||||
self,
|
||||
params: ImageParams,
|
||||
sources: int,
|
||||
self,
|
||||
params: ImageParams,
|
||||
sources: int,
|
||||
) -> int:
|
||||
return sources + 1
|
||||
return sources + 1
|
||||
|
|
|
@ -51,8 +51,8 @@ class SourceS3Stage(BaseStage):
|
|||
return outputs
|
||||
|
||||
def outputs(
|
||||
self,
|
||||
params: ImageParams,
|
||||
sources: int,
|
||||
self,
|
||||
params: ImageParams,
|
||||
sources: int,
|
||||
) -> int:
|
||||
return sources + 1 # TODO: len(source_keys)
|
||||
return sources + 1 # TODO: len(source_keys)
|
||||
|
|
|
@ -139,4 +139,4 @@ class SourceTxt2ImgStage(BaseStage):
|
|||
params: ImageParams,
|
||||
sources: int,
|
||||
) -> int:
|
||||
return sources + 1
|
||||
return sources + 1
|
||||
|
|
|
@ -44,8 +44,8 @@ class SourceURLStage(BaseStage):
|
|||
return outputs
|
||||
|
||||
def outputs(
|
||||
self,
|
||||
params: ImageParams,
|
||||
sources: int,
|
||||
self,
|
||||
params: ImageParams,
|
||||
sources: int,
|
||||
) -> int:
|
||||
return sources + 1
|
||||
return sources + 1
|
||||
|
|
|
@ -50,7 +50,12 @@ def needs_tile(
|
|||
source: Optional[Image.Image] = None,
|
||||
) -> bool:
|
||||
tile = min(max_tile, stage_tile)
|
||||
logger.trace("checking image tile dimensions: %s, %s, %s", tile, source.width > tile or source.height > tile, size.width > tile or size.height > tile)
|
||||
logger.trace(
|
||||
"checking image tile dimensions: %s, %s, %s",
|
||||
tile,
|
||||
source.width > tile or source.height > tile,
|
||||
size.width > tile or size.height > tile,
|
||||
)
|
||||
|
||||
if source is not None:
|
||||
return source.width > tile or source.height > tile
|
||||
|
|
|
@ -376,7 +376,9 @@ def load_pipeline(
|
|||
provider=device.ort_provider("vae"),
|
||||
sess_options=device.sess_options(),
|
||||
)
|
||||
components["vae_decoder_session"]._model_path = vae_decoder # "#\\not a real path on any system"
|
||||
components[
|
||||
"vae_decoder_session"
|
||||
]._model_path = vae_decoder # "#\\not a real path on any system"
|
||||
|
||||
logger.debug("loading VAE encoder from %s", vae_encoder)
|
||||
components["vae_encoder_session"] = OnnxRuntimeModel.load_model(
|
||||
|
@ -384,7 +386,9 @@ def load_pipeline(
|
|||
provider=device.ort_provider("vae"),
|
||||
sess_options=device.sess_options(),
|
||||
)
|
||||
components["vae_encoder_session"]._model_path = vae_encoder # "#\\not a real path on any system"
|
||||
components[
|
||||
"vae_encoder_session"
|
||||
]._model_path = vae_encoder # "#\\not a real path on any system"
|
||||
|
||||
else:
|
||||
logger.debug("loading VAE decoder from %s", vae_decoder)
|
||||
|
@ -440,13 +444,13 @@ def load_pipeline(
|
|||
if "vae_decoder_session" in components:
|
||||
pipe.vae_decoder = ORTModelVaeDecoder(
|
||||
components["vae_decoder_session"],
|
||||
pipe, # TODO: find the right class to provide here. ORTModel is missing the dict json method
|
||||
pipe, # TODO: find the right class to provide here. ORTModel is missing the dict json method
|
||||
)
|
||||
|
||||
if "vae_encoder_session" in components:
|
||||
pipe.vae_encoder = ORTModelVaeEncoder(
|
||||
components["vae_encoder_session"],
|
||||
pipe, # TODO: find the right class to provide here. ORTModel is missing the dict json method
|
||||
pipe, # TODO: find the right class to provide here. ORTModel is missing the dict json method
|
||||
)
|
||||
|
||||
if not server.show_progress:
|
||||
|
|
|
@ -387,7 +387,9 @@ def chain(server: ServerContext, pool: DevicePoolExecutor):
|
|||
validate(data, schema)
|
||||
|
||||
# get defaults from the regular parameters
|
||||
device, base_params, base_size = pipeline_from_request(server, data=data.get("defaults", None))
|
||||
device, base_params, base_size = pipeline_from_request(
|
||||
server, data=data.get("defaults", None)
|
||||
)
|
||||
|
||||
# start building the pipeline
|
||||
pipeline = ChainPipeline()
|
||||
|
@ -452,7 +454,9 @@ def chain(server: ServerContext, pool: DevicePoolExecutor):
|
|||
|
||||
logger.info("running chain pipeline with %s stages", len(pipeline.stages))
|
||||
|
||||
output = make_output_name(server, "chain", base_params, base_size, count=len(pipeline.stages))
|
||||
output = make_output_name(
|
||||
server, "chain", base_params, base_size, count=len(pipeline.stages)
|
||||
)
|
||||
job_name = output[0]
|
||||
|
||||
# build and run chain pipeline
|
||||
|
|
Loading…
Reference in New Issue