lint(api): use more accurate worker name throughout
This commit is contained in:
parent
ada482c183
commit
b31f546516
|
@ -70,14 +70,14 @@ class ChainPipeline:
|
|||
|
||||
def run(
|
||||
self,
|
||||
job: WorkerContext,
|
||||
worker: WorkerContext,
|
||||
server: ServerContext,
|
||||
params: ImageParams,
|
||||
sources: List[Image.Image],
|
||||
callback: Optional[ProgressCallback],
|
||||
**kwargs
|
||||
) -> List[Image.Image]:
|
||||
return self(job, server, params, sources=sources, callback=callback, **kwargs)
|
||||
return self(worker, server, params, sources=sources, callback=callback, **kwargs)
|
||||
|
||||
def stage(self, callback: BaseStage, params: StageParams, **kwargs):
|
||||
self.stages.append((callback, params, kwargs))
|
||||
|
@ -85,7 +85,7 @@ class ChainPipeline:
|
|||
|
||||
def __call__(
|
||||
self,
|
||||
job: WorkerContext,
|
||||
worker: WorkerContext,
|
||||
server: ServerContext,
|
||||
params: ImageParams,
|
||||
sources: List[Image.Image],
|
||||
|
@ -151,10 +151,10 @@ class ChainPipeline:
|
|||
tile_mask: Image.Image,
|
||||
dims: Tuple[int, int, int],
|
||||
) -> Image.Image:
|
||||
for i in range(job.retries):
|
||||
for i in range(worker.retries):
|
||||
try:
|
||||
output_tile = stage_pipe.run(
|
||||
job,
|
||||
worker,
|
||||
server,
|
||||
stage_params,
|
||||
params,
|
||||
|
@ -175,8 +175,8 @@ class ChainPipeline:
|
|||
i,
|
||||
)
|
||||
server.cache.clear()
|
||||
run_gc([job.get_device()])
|
||||
job.retries = job.retries - (i + 1)
|
||||
run_gc([worker.get_device()])
|
||||
worker.retries = worker.retries - (i + 1)
|
||||
|
||||
raise RetryException("exhausted retries on tile")
|
||||
|
||||
|
@ -193,10 +193,10 @@ class ChainPipeline:
|
|||
stage_sources = stage_outputs
|
||||
else:
|
||||
logger.debug("image within tile size of %s, running stage", tile)
|
||||
for i in range(job.retries):
|
||||
for i in range(worker.retries):
|
||||
try:
|
||||
stage_outputs = stage_pipe.run(
|
||||
job,
|
||||
worker,
|
||||
server,
|
||||
stage_params,
|
||||
params,
|
||||
|
@ -213,10 +213,10 @@ class ChainPipeline:
|
|||
"error while running stage pipeline, retry %s of 3", i
|
||||
)
|
||||
server.cache.clear()
|
||||
run_gc([job.get_device()])
|
||||
job.retries = job.retries - (i + 1)
|
||||
run_gc([worker.get_device()])
|
||||
worker.retries = worker.retries - (i + 1)
|
||||
|
||||
if job.retries <= 0:
|
||||
if worker.retries <= 0:
|
||||
raise RetryException("exhausted retries on stage")
|
||||
|
||||
logger.debug(
|
||||
|
|
|
@ -20,7 +20,7 @@ class BlendImg2ImgStage(BaseStage):
|
|||
|
||||
def run(
|
||||
self,
|
||||
job: WorkerContext,
|
||||
worker: WorkerContext,
|
||||
server: ServerContext,
|
||||
_stage: StageParams,
|
||||
params: ImageParams,
|
||||
|
@ -45,7 +45,7 @@ class BlendImg2ImgStage(BaseStage):
|
|||
server,
|
||||
params,
|
||||
pipe_type,
|
||||
job.get_device(),
|
||||
worker.get_device(),
|
||||
inversions=inversions,
|
||||
loras=loras,
|
||||
)
|
||||
|
|
|
@ -14,7 +14,7 @@ logger = getLogger(__name__)
|
|||
class BlendLinearStage(BaseStage):
|
||||
def run(
|
||||
self,
|
||||
_job: WorkerContext,
|
||||
_worker: WorkerContext,
|
||||
_server: ServerContext,
|
||||
_stage: StageParams,
|
||||
_params: ImageParams,
|
||||
|
|
|
@ -16,7 +16,7 @@ logger = getLogger(__name__)
|
|||
class BlendMaskStage(BaseStage):
|
||||
def run(
|
||||
self,
|
||||
_job: WorkerContext,
|
||||
_worker: WorkerContext,
|
||||
server: ServerContext,
|
||||
_stage: StageParams,
|
||||
_params: ImageParams,
|
||||
|
|
|
@ -14,7 +14,7 @@ logger = getLogger(__name__)
|
|||
class CorrectCodeformerStage(BaseStage):
|
||||
def run(
|
||||
self,
|
||||
job: WorkerContext,
|
||||
worker: WorkerContext,
|
||||
_server: ServerContext,
|
||||
_stage: StageParams,
|
||||
_params: ImageParams,
|
||||
|
@ -30,6 +30,6 @@ class CorrectCodeformerStage(BaseStage):
|
|||
|
||||
upscale = upscale.with_args(**kwargs)
|
||||
|
||||
device = job.get_device()
|
||||
device = worker.get_device()
|
||||
pipe = CodeFormer(upscale=upscale.face_outscale).to(device.torch_str())
|
||||
return [pipe(source) for source in sources]
|
||||
|
|
|
@ -15,7 +15,7 @@ logger = getLogger(__name__)
|
|||
class PersistDiskStage(BaseStage):
|
||||
def run(
|
||||
self,
|
||||
_job: WorkerContext,
|
||||
_worker: WorkerContext,
|
||||
server: ServerContext,
|
||||
_stage: StageParams,
|
||||
params: ImageParams,
|
||||
|
|
|
@ -16,7 +16,7 @@ logger = getLogger(__name__)
|
|||
class PersistS3Stage(BaseStage):
|
||||
def run(
|
||||
self,
|
||||
_job: WorkerContext,
|
||||
_worker: WorkerContext,
|
||||
server: ServerContext,
|
||||
_stage: StageParams,
|
||||
_params: ImageParams,
|
||||
|
|
|
@ -14,7 +14,7 @@ logger = getLogger(__name__)
|
|||
class ReduceCropStage(BaseStage):
|
||||
def run(
|
||||
self,
|
||||
_job: WorkerContext,
|
||||
_worker: WorkerContext,
|
||||
_server: ServerContext,
|
||||
_stage: StageParams,
|
||||
_params: ImageParams,
|
||||
|
|
|
@ -14,7 +14,7 @@ logger = getLogger(__name__)
|
|||
class ReduceThumbnailStage(BaseStage):
|
||||
def run(
|
||||
self,
|
||||
_job: WorkerContext,
|
||||
_worker: WorkerContext,
|
||||
_server: ServerContext,
|
||||
_stage: StageParams,
|
||||
_params: ImageParams,
|
||||
|
|
|
@ -14,7 +14,7 @@ logger = getLogger(__name__)
|
|||
class SourceNoiseStage(BaseStage):
|
||||
def run(
|
||||
self,
|
||||
_job: WorkerContext,
|
||||
_worker: WorkerContext,
|
||||
_server: ServerContext,
|
||||
_stage: StageParams,
|
||||
_params: ImageParams,
|
||||
|
|
|
@ -16,7 +16,7 @@ logger = getLogger(__name__)
|
|||
class SourceS3Stage(BaseStage):
|
||||
def run(
|
||||
self,
|
||||
_job: WorkerContext,
|
||||
_worker: WorkerContext,
|
||||
_server: ServerContext,
|
||||
_stage: StageParams,
|
||||
_params: ImageParams,
|
||||
|
|
|
@ -25,7 +25,7 @@ class SourceTxt2ImgStage(BaseStage):
|
|||
|
||||
def run(
|
||||
self,
|
||||
job: WorkerContext,
|
||||
worker: WorkerContext,
|
||||
server: ServerContext,
|
||||
_stage: StageParams,
|
||||
params: ImageParams,
|
||||
|
@ -67,7 +67,7 @@ class SourceTxt2ImgStage(BaseStage):
|
|||
server,
|
||||
params,
|
||||
pipe_type,
|
||||
job.get_device(),
|
||||
worker.get_device(),
|
||||
inversions=inversions,
|
||||
loras=loras,
|
||||
)
|
||||
|
|
|
@ -16,7 +16,7 @@ logger = getLogger(__name__)
|
|||
class SourceURLStage(BaseStage):
|
||||
def run(
|
||||
self,
|
||||
_job: WorkerContext,
|
||||
_worker: WorkerContext,
|
||||
_server: ServerContext,
|
||||
_stage: StageParams,
|
||||
_params: ImageParams,
|
||||
|
|
|
@ -12,7 +12,7 @@ class BaseStage:
|
|||
|
||||
def run(
|
||||
self,
|
||||
job: WorkerContext,
|
||||
worker: WorkerContext,
|
||||
server: ServerContext,
|
||||
stage: StageParams,
|
||||
_params: ImageParams,
|
||||
|
|
|
@ -50,7 +50,7 @@ class UpscaleBSRGANStage(BaseStage):
|
|||
|
||||
def run(
|
||||
self,
|
||||
job: WorkerContext,
|
||||
worker: WorkerContext,
|
||||
server: ServerContext,
|
||||
stage: StageParams,
|
||||
_params: ImageParams,
|
||||
|
@ -67,7 +67,7 @@ class UpscaleBSRGANStage(BaseStage):
|
|||
return sources
|
||||
|
||||
logger.info("upscaling with BSRGAN model: %s", upscale.upscale_model)
|
||||
device = job.get_device()
|
||||
device = worker.get_device()
|
||||
bsrgan = self.load(server, stage, upscale, device)
|
||||
|
||||
outputs = []
|
||||
|
|
|
@ -16,7 +16,7 @@ logger = getLogger(__name__)
|
|||
class UpscaleHighresStage(BaseStage):
|
||||
def run(
|
||||
self,
|
||||
job: WorkerContext,
|
||||
worker: WorkerContext,
|
||||
server: ServerContext,
|
||||
stage: StageParams,
|
||||
params: ImageParams,
|
||||
|
@ -35,7 +35,7 @@ class UpscaleHighresStage(BaseStage):
|
|||
|
||||
return [
|
||||
chain(
|
||||
job,
|
||||
worker,
|
||||
server,
|
||||
params,
|
||||
source,
|
||||
|
|
|
@ -28,7 +28,7 @@ class UpscaleOutpaintStage(BaseStage):
|
|||
|
||||
def run(
|
||||
self,
|
||||
job: WorkerContext,
|
||||
worker: WorkerContext,
|
||||
server: ServerContext,
|
||||
stage: StageParams,
|
||||
params: ImageParams,
|
||||
|
@ -55,7 +55,7 @@ class UpscaleOutpaintStage(BaseStage):
|
|||
server,
|
||||
params,
|
||||
pipe_type,
|
||||
job.get_device(),
|
||||
worker.get_device(),
|
||||
inversions=inversions,
|
||||
loras=loras,
|
||||
)
|
||||
|
|
|
@ -73,7 +73,7 @@ class UpscaleRealESRGANStage(BaseStage):
|
|||
|
||||
def run(
|
||||
self,
|
||||
job: WorkerContext,
|
||||
worker: WorkerContext,
|
||||
server: ServerContext,
|
||||
stage: StageParams,
|
||||
_params: ImageParams,
|
||||
|
@ -89,7 +89,7 @@ class UpscaleRealESRGANStage(BaseStage):
|
|||
for source in sources:
|
||||
output = np.array(source)
|
||||
upsampler = self.load(
|
||||
server, upscale, job.get_device(), tile=stage.tile_size
|
||||
server, upscale, worker.get_device(), tile=stage.tile_size
|
||||
)
|
||||
|
||||
output, _ = upsampler.enhance(output, outscale=upscale.outscale)
|
||||
|
|
|
@ -14,7 +14,7 @@ logger = getLogger(__name__)
|
|||
class UpscaleSimpleStage(BaseStage):
|
||||
def run(
|
||||
self,
|
||||
_job: WorkerContext,
|
||||
_worker: WorkerContext,
|
||||
_server: ServerContext,
|
||||
_stage: StageParams,
|
||||
_params: ImageParams,
|
||||
|
|
|
@ -18,7 +18,7 @@ logger = getLogger(__name__)
|
|||
class UpscaleStableDiffusionStage(BaseStage):
|
||||
def run(
|
||||
self,
|
||||
job: WorkerContext,
|
||||
worker: WorkerContext,
|
||||
server: ServerContext,
|
||||
_stage: StageParams,
|
||||
params: ImageParams,
|
||||
|
@ -43,7 +43,7 @@ class UpscaleStableDiffusionStage(BaseStage):
|
|||
server,
|
||||
params,
|
||||
"upscale",
|
||||
job.get_device(),
|
||||
worker.get_device(),
|
||||
model=path.join(server.model_path, upscale.upscale_model),
|
||||
)
|
||||
generator = torch.manual_seed(params.seed)
|
||||
|
|
|
@ -50,7 +50,7 @@ class UpscaleSwinIRStage(BaseStage):
|
|||
|
||||
def run(
|
||||
self,
|
||||
job: WorkerContext,
|
||||
worker: WorkerContext,
|
||||
server: ServerContext,
|
||||
stage: StageParams,
|
||||
_params: ImageParams,
|
||||
|
@ -67,7 +67,7 @@ class UpscaleSwinIRStage(BaseStage):
|
|||
return sources
|
||||
|
||||
logger.info("correcting faces with SwinIR model: %s", upscale.upscale_model)
|
||||
device = job.get_device()
|
||||
device = worker.get_device()
|
||||
swinir = self.load(server, stage, upscale, device)
|
||||
|
||||
outputs = []
|
||||
|
|
|
@ -34,7 +34,7 @@ logger = getLogger(__name__)
|
|||
|
||||
|
||||
def run_txt2img_pipeline(
|
||||
job: WorkerContext,
|
||||
worker: WorkerContext,
|
||||
server: ServerContext,
|
||||
params: ImageParams,
|
||||
size: Size,
|
||||
|
@ -83,8 +83,8 @@ def run_txt2img_pipeline(
|
|||
|
||||
# run and save
|
||||
latents = get_latents_from_seed(params.seed, size, batch=params.batch)
|
||||
progress = job.get_progress_callback()
|
||||
images = chain.run(job, server, params, [], callback=progress, latents=latents)
|
||||
progress = worker.get_progress_callback()
|
||||
images = chain.run(worker, server, params, [], callback=progress, latents=latents)
|
||||
|
||||
_pairs, loras, inversions, _rest = parse_prompt(params)
|
||||
|
||||
|
@ -102,7 +102,7 @@ def run_txt2img_pipeline(
|
|||
)
|
||||
|
||||
# clean up
|
||||
run_gc([job.get_device()])
|
||||
run_gc([worker.get_device()])
|
||||
|
||||
# notify the user
|
||||
show_system_toast(f"finished txt2img job: {dest}")
|
||||
|
@ -110,7 +110,7 @@ def run_txt2img_pipeline(
|
|||
|
||||
|
||||
def run_img2img_pipeline(
|
||||
job: WorkerContext,
|
||||
worker: WorkerContext,
|
||||
server: ServerContext,
|
||||
params: ImageParams,
|
||||
outputs: List[str],
|
||||
|
@ -175,8 +175,8 @@ def run_img2img_pipeline(
|
|||
)
|
||||
|
||||
# run and append the filtered source
|
||||
progress = job.get_progress_callback()
|
||||
images = chain(job, server, params, [source], callback=progress)
|
||||
progress = worker.get_progress_callback()
|
||||
images = chain(worker, server, params, [source], callback=progress)
|
||||
|
||||
if source_filter is not None and source_filter != "none":
|
||||
images.append(source)
|
||||
|
@ -199,7 +199,7 @@ def run_img2img_pipeline(
|
|||
)
|
||||
|
||||
# clean up
|
||||
run_gc([job.get_device()])
|
||||
run_gc([worker.get_device()])
|
||||
|
||||
# notify the user
|
||||
show_system_toast(f"finished img2img job: {dest}")
|
||||
|
@ -207,7 +207,7 @@ def run_img2img_pipeline(
|
|||
|
||||
|
||||
def run_inpaint_pipeline(
|
||||
job: WorkerContext,
|
||||
worker: WorkerContext,
|
||||
server: ServerContext,
|
||||
params: ImageParams,
|
||||
size: Size,
|
||||
|
@ -353,8 +353,8 @@ def run_inpaint_pipeline(
|
|||
|
||||
# run and save
|
||||
latents = get_latents_from_seed(params.seed, size, batch=params.batch)
|
||||
progress = job.get_progress_callback()
|
||||
images = chain(job, server, params, [source], callback=progress, latents=latents)
|
||||
progress = worker.get_progress_callback()
|
||||
images = chain(worker, server, params, [source], callback=progress, latents=latents)
|
||||
|
||||
_pairs, loras, inversions, _rest = parse_prompt(params)
|
||||
for image, output in zip(images, outputs):
|
||||
|
@ -378,7 +378,7 @@ def run_inpaint_pipeline(
|
|||
|
||||
# clean up
|
||||
del image
|
||||
run_gc([job.get_device()])
|
||||
run_gc([worker.get_device()])
|
||||
|
||||
# notify the user
|
||||
show_system_toast(f"finished inpaint job: {dest}")
|
||||
|
@ -386,7 +386,7 @@ def run_inpaint_pipeline(
|
|||
|
||||
|
||||
def run_upscale_pipeline(
|
||||
job: WorkerContext,
|
||||
worker: WorkerContext,
|
||||
server: ServerContext,
|
||||
params: ImageParams,
|
||||
size: Size,
|
||||
|
@ -427,8 +427,8 @@ def run_upscale_pipeline(
|
|||
)
|
||||
|
||||
# run and save
|
||||
progress = job.get_progress_callback()
|
||||
images = chain(job, server, params, [source], callback=progress)
|
||||
progress = worker.get_progress_callback()
|
||||
images = chain(worker, server, params, [source], callback=progress)
|
||||
|
||||
_pairs, loras, inversions, _rest = parse_prompt(params)
|
||||
for image, output in zip(images, outputs):
|
||||
|
@ -445,7 +445,7 @@ def run_upscale_pipeline(
|
|||
|
||||
# clean up
|
||||
del image
|
||||
run_gc([job.get_device()])
|
||||
run_gc([worker.get_device()])
|
||||
|
||||
# notify the user
|
||||
show_system_toast(f"finished upscale job: {dest}")
|
||||
|
@ -453,7 +453,7 @@ def run_upscale_pipeline(
|
|||
|
||||
|
||||
def run_blend_pipeline(
|
||||
job: WorkerContext,
|
||||
worker: WorkerContext,
|
||||
server: ServerContext,
|
||||
params: ImageParams,
|
||||
size: Size,
|
||||
|
@ -477,15 +477,15 @@ def run_blend_pipeline(
|
|||
)
|
||||
|
||||
# run and save
|
||||
progress = job.get_progress_callback()
|
||||
images = chain(job, server, params, sources, callback=progress)
|
||||
progress = worker.get_progress_callback()
|
||||
images = chain(worker, server, params, sources, callback=progress)
|
||||
|
||||
for image, output in zip(images, outputs):
|
||||
dest = save_image(server, output, image, params, size, upscale=upscale)
|
||||
|
||||
# clean up
|
||||
del image
|
||||
run_gc([job.get_device()])
|
||||
run_gc([worker.get_device()])
|
||||
|
||||
# notify the user
|
||||
show_system_toast(f"finished blend job: {dest}")
|
||||
|
|
|
@ -8,7 +8,7 @@ logger = getLogger(__name__)
|
|||
|
||||
|
||||
def run_txt2txt_pipeline(
|
||||
job: WorkerContext,
|
||||
worker: WorkerContext,
|
||||
_server: ServerContext,
|
||||
params: ImageParams,
|
||||
_size: Size,
|
||||
|
@ -20,7 +20,7 @@ def run_txt2txt_pipeline(
|
|||
model = "EleutherAI/gpt-j-6B"
|
||||
tokens = 1024
|
||||
|
||||
device = job.get_device()
|
||||
device = worker.get_device()
|
||||
|
||||
pipe = GPTJForCausalLM.from_pretrained(model).to(device.torch_str())
|
||||
tokenizer = AutoTokenizer.from_pretrained(model)
|
||||
|
|
Loading…
Reference in New Issue