1
0
Fork 0

lint(api): use more accurate worker name throughout

This commit is contained in:
Sean Sube 2023-07-15 18:54:54 -05:00
parent ada482c183
commit b31f546516
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
23 changed files with 63 additions and 63 deletions

View File

@ -70,14 +70,14 @@ class ChainPipeline:
def run(
self,
job: WorkerContext,
worker: WorkerContext,
server: ServerContext,
params: ImageParams,
sources: List[Image.Image],
callback: Optional[ProgressCallback],
**kwargs
) -> List[Image.Image]:
return self(job, server, params, sources=sources, callback=callback, **kwargs)
return self(worker, server, params, sources=sources, callback=callback, **kwargs)
def stage(self, callback: BaseStage, params: StageParams, **kwargs):
self.stages.append((callback, params, kwargs))
@ -85,7 +85,7 @@ class ChainPipeline:
def __call__(
self,
job: WorkerContext,
worker: WorkerContext,
server: ServerContext,
params: ImageParams,
sources: List[Image.Image],
@ -151,10 +151,10 @@ class ChainPipeline:
tile_mask: Image.Image,
dims: Tuple[int, int, int],
) -> Image.Image:
for i in range(job.retries):
for i in range(worker.retries):
try:
output_tile = stage_pipe.run(
job,
worker,
server,
stage_params,
params,
@ -175,8 +175,8 @@ class ChainPipeline:
i,
)
server.cache.clear()
run_gc([job.get_device()])
job.retries = job.retries - (i + 1)
run_gc([worker.get_device()])
worker.retries = worker.retries - (i + 1)
raise RetryException("exhausted retries on tile")
@ -193,10 +193,10 @@ class ChainPipeline:
stage_sources = stage_outputs
else:
logger.debug("image within tile size of %s, running stage", tile)
for i in range(job.retries):
for i in range(worker.retries):
try:
stage_outputs = stage_pipe.run(
job,
worker,
server,
stage_params,
params,
@ -213,10 +213,10 @@ class ChainPipeline:
"error while running stage pipeline, retry %s of 3", i
)
server.cache.clear()
run_gc([job.get_device()])
job.retries = job.retries - (i + 1)
run_gc([worker.get_device()])
worker.retries = worker.retries - (i + 1)
if job.retries <= 0:
if worker.retries <= 0:
raise RetryException("exhausted retries on stage")
logger.debug(

View File

@ -20,7 +20,7 @@ class BlendImg2ImgStage(BaseStage):
def run(
self,
job: WorkerContext,
worker: WorkerContext,
server: ServerContext,
_stage: StageParams,
params: ImageParams,
@ -45,7 +45,7 @@ class BlendImg2ImgStage(BaseStage):
server,
params,
pipe_type,
job.get_device(),
worker.get_device(),
inversions=inversions,
loras=loras,
)

View File

@ -14,7 +14,7 @@ logger = getLogger(__name__)
class BlendLinearStage(BaseStage):
def run(
self,
_job: WorkerContext,
_worker: WorkerContext,
_server: ServerContext,
_stage: StageParams,
_params: ImageParams,

View File

@ -16,7 +16,7 @@ logger = getLogger(__name__)
class BlendMaskStage(BaseStage):
def run(
self,
_job: WorkerContext,
_worker: WorkerContext,
server: ServerContext,
_stage: StageParams,
_params: ImageParams,

View File

@ -14,7 +14,7 @@ logger = getLogger(__name__)
class CorrectCodeformerStage(BaseStage):
def run(
self,
job: WorkerContext,
worker: WorkerContext,
_server: ServerContext,
_stage: StageParams,
_params: ImageParams,
@ -30,6 +30,6 @@ class CorrectCodeformerStage(BaseStage):
upscale = upscale.with_args(**kwargs)
device = job.get_device()
device = worker.get_device()
pipe = CodeFormer(upscale=upscale.face_outscale).to(device.torch_str())
return [pipe(source) for source in sources]

View File

@ -15,7 +15,7 @@ logger = getLogger(__name__)
class PersistDiskStage(BaseStage):
def run(
self,
_job: WorkerContext,
_worker: WorkerContext,
server: ServerContext,
_stage: StageParams,
params: ImageParams,

View File

@ -16,7 +16,7 @@ logger = getLogger(__name__)
class PersistS3Stage(BaseStage):
def run(
self,
_job: WorkerContext,
_worker: WorkerContext,
server: ServerContext,
_stage: StageParams,
_params: ImageParams,

View File

@ -14,7 +14,7 @@ logger = getLogger(__name__)
class ReduceCropStage(BaseStage):
def run(
self,
_job: WorkerContext,
_worker: WorkerContext,
_server: ServerContext,
_stage: StageParams,
_params: ImageParams,

View File

@ -14,7 +14,7 @@ logger = getLogger(__name__)
class ReduceThumbnailStage(BaseStage):
def run(
self,
_job: WorkerContext,
_worker: WorkerContext,
_server: ServerContext,
_stage: StageParams,
_params: ImageParams,

View File

@ -14,7 +14,7 @@ logger = getLogger(__name__)
class SourceNoiseStage(BaseStage):
def run(
self,
_job: WorkerContext,
_worker: WorkerContext,
_server: ServerContext,
_stage: StageParams,
_params: ImageParams,

View File

@ -16,7 +16,7 @@ logger = getLogger(__name__)
class SourceS3Stage(BaseStage):
def run(
self,
_job: WorkerContext,
_worker: WorkerContext,
_server: ServerContext,
_stage: StageParams,
_params: ImageParams,

View File

@ -25,7 +25,7 @@ class SourceTxt2ImgStage(BaseStage):
def run(
self,
job: WorkerContext,
worker: WorkerContext,
server: ServerContext,
_stage: StageParams,
params: ImageParams,
@ -67,7 +67,7 @@ class SourceTxt2ImgStage(BaseStage):
server,
params,
pipe_type,
job.get_device(),
worker.get_device(),
inversions=inversions,
loras=loras,
)

View File

@ -16,7 +16,7 @@ logger = getLogger(__name__)
class SourceURLStage(BaseStage):
def run(
self,
_job: WorkerContext,
_worker: WorkerContext,
_server: ServerContext,
_stage: StageParams,
_params: ImageParams,

View File

@ -12,7 +12,7 @@ class BaseStage:
def run(
self,
job: WorkerContext,
worker: WorkerContext,
server: ServerContext,
stage: StageParams,
_params: ImageParams,

View File

@ -50,7 +50,7 @@ class UpscaleBSRGANStage(BaseStage):
def run(
self,
job: WorkerContext,
worker: WorkerContext,
server: ServerContext,
stage: StageParams,
_params: ImageParams,
@ -67,7 +67,7 @@ class UpscaleBSRGANStage(BaseStage):
return sources
logger.info("upscaling with BSRGAN model: %s", upscale.upscale_model)
device = job.get_device()
device = worker.get_device()
bsrgan = self.load(server, stage, upscale, device)
outputs = []

View File

@ -16,7 +16,7 @@ logger = getLogger(__name__)
class UpscaleHighresStage(BaseStage):
def run(
self,
job: WorkerContext,
worker: WorkerContext,
server: ServerContext,
stage: StageParams,
params: ImageParams,
@ -35,7 +35,7 @@ class UpscaleHighresStage(BaseStage):
return [
chain(
job,
worker,
server,
params,
source,

View File

@ -28,7 +28,7 @@ class UpscaleOutpaintStage(BaseStage):
def run(
self,
job: WorkerContext,
worker: WorkerContext,
server: ServerContext,
stage: StageParams,
params: ImageParams,
@ -55,7 +55,7 @@ class UpscaleOutpaintStage(BaseStage):
server,
params,
pipe_type,
job.get_device(),
worker.get_device(),
inversions=inversions,
loras=loras,
)

View File

@ -73,7 +73,7 @@ class UpscaleRealESRGANStage(BaseStage):
def run(
self,
job: WorkerContext,
worker: WorkerContext,
server: ServerContext,
stage: StageParams,
_params: ImageParams,
@ -89,7 +89,7 @@ class UpscaleRealESRGANStage(BaseStage):
for source in sources:
output = np.array(source)
upsampler = self.load(
server, upscale, job.get_device(), tile=stage.tile_size
server, upscale, worker.get_device(), tile=stage.tile_size
)
output, _ = upsampler.enhance(output, outscale=upscale.outscale)

View File

@ -14,7 +14,7 @@ logger = getLogger(__name__)
class UpscaleSimpleStage(BaseStage):
def run(
self,
_job: WorkerContext,
_worker: WorkerContext,
_server: ServerContext,
_stage: StageParams,
_params: ImageParams,

View File

@ -18,7 +18,7 @@ logger = getLogger(__name__)
class UpscaleStableDiffusionStage(BaseStage):
def run(
self,
job: WorkerContext,
worker: WorkerContext,
server: ServerContext,
_stage: StageParams,
params: ImageParams,
@ -43,7 +43,7 @@ class UpscaleStableDiffusionStage(BaseStage):
server,
params,
"upscale",
job.get_device(),
worker.get_device(),
model=path.join(server.model_path, upscale.upscale_model),
)
generator = torch.manual_seed(params.seed)

View File

@ -50,7 +50,7 @@ class UpscaleSwinIRStage(BaseStage):
def run(
self,
job: WorkerContext,
worker: WorkerContext,
server: ServerContext,
stage: StageParams,
_params: ImageParams,
@ -67,7 +67,7 @@ class UpscaleSwinIRStage(BaseStage):
return sources
logger.info("correcting faces with SwinIR model: %s", upscale.upscale_model)
device = job.get_device()
device = worker.get_device()
swinir = self.load(server, stage, upscale, device)
outputs = []

View File

@ -34,7 +34,7 @@ logger = getLogger(__name__)
def run_txt2img_pipeline(
job: WorkerContext,
worker: WorkerContext,
server: ServerContext,
params: ImageParams,
size: Size,
@ -83,8 +83,8 @@ def run_txt2img_pipeline(
# run and save
latents = get_latents_from_seed(params.seed, size, batch=params.batch)
progress = job.get_progress_callback()
images = chain.run(job, server, params, [], callback=progress, latents=latents)
progress = worker.get_progress_callback()
images = chain.run(worker, server, params, [], callback=progress, latents=latents)
_pairs, loras, inversions, _rest = parse_prompt(params)
@ -102,7 +102,7 @@ def run_txt2img_pipeline(
)
# clean up
run_gc([job.get_device()])
run_gc([worker.get_device()])
# notify the user
show_system_toast(f"finished txt2img job: {dest}")
@ -110,7 +110,7 @@ def run_txt2img_pipeline(
def run_img2img_pipeline(
job: WorkerContext,
worker: WorkerContext,
server: ServerContext,
params: ImageParams,
outputs: List[str],
@ -175,8 +175,8 @@ def run_img2img_pipeline(
)
# run and append the filtered source
progress = job.get_progress_callback()
images = chain(job, server, params, [source], callback=progress)
progress = worker.get_progress_callback()
images = chain(worker, server, params, [source], callback=progress)
if source_filter is not None and source_filter != "none":
images.append(source)
@ -199,7 +199,7 @@ def run_img2img_pipeline(
)
# clean up
run_gc([job.get_device()])
run_gc([worker.get_device()])
# notify the user
show_system_toast(f"finished img2img job: {dest}")
@ -207,7 +207,7 @@ def run_img2img_pipeline(
def run_inpaint_pipeline(
job: WorkerContext,
worker: WorkerContext,
server: ServerContext,
params: ImageParams,
size: Size,
@ -353,8 +353,8 @@ def run_inpaint_pipeline(
# run and save
latents = get_latents_from_seed(params.seed, size, batch=params.batch)
progress = job.get_progress_callback()
images = chain(job, server, params, [source], callback=progress, latents=latents)
progress = worker.get_progress_callback()
images = chain(worker, server, params, [source], callback=progress, latents=latents)
_pairs, loras, inversions, _rest = parse_prompt(params)
for image, output in zip(images, outputs):
@ -378,7 +378,7 @@ def run_inpaint_pipeline(
# clean up
del image
run_gc([job.get_device()])
run_gc([worker.get_device()])
# notify the user
show_system_toast(f"finished inpaint job: {dest}")
@ -386,7 +386,7 @@ def run_inpaint_pipeline(
def run_upscale_pipeline(
job: WorkerContext,
worker: WorkerContext,
server: ServerContext,
params: ImageParams,
size: Size,
@ -427,8 +427,8 @@ def run_upscale_pipeline(
)
# run and save
progress = job.get_progress_callback()
images = chain(job, server, params, [source], callback=progress)
progress = worker.get_progress_callback()
images = chain(worker, server, params, [source], callback=progress)
_pairs, loras, inversions, _rest = parse_prompt(params)
for image, output in zip(images, outputs):
@ -445,7 +445,7 @@ def run_upscale_pipeline(
# clean up
del image
run_gc([job.get_device()])
run_gc([worker.get_device()])
# notify the user
show_system_toast(f"finished upscale job: {dest}")
@ -453,7 +453,7 @@ def run_upscale_pipeline(
def run_blend_pipeline(
job: WorkerContext,
worker: WorkerContext,
server: ServerContext,
params: ImageParams,
size: Size,
@ -477,15 +477,15 @@ def run_blend_pipeline(
)
# run and save
progress = job.get_progress_callback()
images = chain(job, server, params, sources, callback=progress)
progress = worker.get_progress_callback()
images = chain(worker, server, params, sources, callback=progress)
for image, output in zip(images, outputs):
dest = save_image(server, output, image, params, size, upscale=upscale)
# clean up
del image
run_gc([job.get_device()])
run_gc([worker.get_device()])
# notify the user
show_system_toast(f"finished blend job: {dest}")

View File

@ -8,7 +8,7 @@ logger = getLogger(__name__)
def run_txt2txt_pipeline(
job: WorkerContext,
worker: WorkerContext,
_server: ServerContext,
params: ImageParams,
_size: Size,
@ -20,7 +20,7 @@ def run_txt2txt_pipeline(
model = "EleutherAI/gpt-j-6B"
tokens = 1024
device = job.get_device()
device = worker.get_device()
pipe = GPTJForCausalLM.from_pretrained(model).to(device.torch_str())
tokenizer = AutoTokenizer.from_pretrained(model)