1
0
Fork 0

apply lint

This commit is contained in:
Sean Sube 2023-09-12 19:17:03 -05:00
parent 7d8819ef87
commit 0a5f725efa
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
10 changed files with 44 additions and 32 deletions

View File

@ -92,7 +92,6 @@ class ChainPipeline:
return steps
def __call__(
self,
worker: WorkerContext,

View File

@ -50,8 +50,8 @@ class BlendGridStage(BaseStage):
return [*sources, output]
def outputs(
self,
params: ImageParams,
sources: int,
self,
params: ImageParams,
sources: int,
) -> int:
return sources + 1
return sources + 1

View File

@ -105,15 +105,15 @@ class BlendImg2ImgStage(BaseStage):
return outputs
def steps(
self,
params: ImageParams,
*args,
self,
params: ImageParams,
*args,
) -> int:
return params.steps # TODO: multiply by strength
return params.steps # TODO: multiply by strength
def outputs(
self,
params: ImageParams,
sources: int,
self,
params: ImageParams,
sources: int,
) -> int:
return sources + 1

View File

@ -44,8 +44,8 @@ class SourceNoiseStage(BaseStage):
return outputs
def outputs(
self,
params: ImageParams,
sources: int,
self,
params: ImageParams,
sources: int,
) -> int:
return sources + 1
return sources + 1

View File

@ -51,8 +51,8 @@ class SourceS3Stage(BaseStage):
return outputs
def outputs(
self,
params: ImageParams,
sources: int,
self,
params: ImageParams,
sources: int,
) -> int:
return sources + 1 # TODO: len(source_keys)
return sources + 1 # TODO: len(source_keys)

View File

@ -139,4 +139,4 @@ class SourceTxt2ImgStage(BaseStage):
params: ImageParams,
sources: int,
) -> int:
return sources + 1
return sources + 1

View File

@ -44,8 +44,8 @@ class SourceURLStage(BaseStage):
return outputs
def outputs(
self,
params: ImageParams,
sources: int,
self,
params: ImageParams,
sources: int,
) -> int:
return sources + 1
return sources + 1

View File

@ -50,7 +50,12 @@ def needs_tile(
source: Optional[Image.Image] = None,
) -> bool:
tile = min(max_tile, stage_tile)
logger.trace("checking image tile dimensions: %s, %s, %s", tile, source.width > tile or source.height > tile, size.width > tile or size.height > tile)
logger.trace(
"checking image tile dimensions: %s, %s, %s",
tile,
source.width > tile or source.height > tile,
size.width > tile or size.height > tile,
)
if source is not None:
return source.width > tile or source.height > tile

View File

@ -376,7 +376,9 @@ def load_pipeline(
provider=device.ort_provider("vae"),
sess_options=device.sess_options(),
)
components["vae_decoder_session"]._model_path = vae_decoder # "#\\not a real path on any system"
components[
"vae_decoder_session"
]._model_path = vae_decoder # "#\\not a real path on any system"
logger.debug("loading VAE encoder from %s", vae_encoder)
components["vae_encoder_session"] = OnnxRuntimeModel.load_model(
@ -384,7 +386,9 @@ def load_pipeline(
provider=device.ort_provider("vae"),
sess_options=device.sess_options(),
)
components["vae_encoder_session"]._model_path = vae_encoder # "#\\not a real path on any system"
components[
"vae_encoder_session"
]._model_path = vae_encoder # "#\\not a real path on any system"
else:
logger.debug("loading VAE decoder from %s", vae_decoder)
@ -440,13 +444,13 @@ def load_pipeline(
if "vae_decoder_session" in components:
pipe.vae_decoder = ORTModelVaeDecoder(
components["vae_decoder_session"],
pipe, # TODO: find the right class to provide here. ORTModel is missing the dict json method
pipe, # TODO: find the right class to provide here. ORTModel is missing the dict json method
)
if "vae_encoder_session" in components:
pipe.vae_encoder = ORTModelVaeEncoder(
components["vae_encoder_session"],
pipe, # TODO: find the right class to provide here. ORTModel is missing the dict json method
pipe, # TODO: find the right class to provide here. ORTModel is missing the dict json method
)
if not server.show_progress:

View File

@ -387,7 +387,9 @@ def chain(server: ServerContext, pool: DevicePoolExecutor):
validate(data, schema)
# get defaults from the regular parameters
device, base_params, base_size = pipeline_from_request(server, data=data.get("defaults", None))
device, base_params, base_size = pipeline_from_request(
server, data=data.get("defaults", None)
)
# start building the pipeline
pipeline = ChainPipeline()
@ -452,7 +454,9 @@ def chain(server: ServerContext, pool: DevicePoolExecutor):
logger.info("running chain pipeline with %s stages", len(pipeline.stages))
output = make_output_name(server, "chain", base_params, base_size, count=len(pipeline.stages))
output = make_output_name(
server, "chain", base_params, base_size, count=len(pipeline.stages)
)
job_name = output[0]
# build and run chain pipeline