fix(api): mark all convert methods as no_grad
This commit is contained in:
parent
b44e644f9e
commit
21fc7c5968
|
@ -1650,6 +1650,7 @@ def extract_checkpoint(
|
|||
logger.info(result_status)
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def convert_diffusion_original(
|
||||
ctx: ConversionContext,
|
||||
model: ModelDict,
|
||||
|
|
|
@ -11,6 +11,7 @@ from ..utils import ConversionContext
|
|||
logger = getLogger(__name__)
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def convert_diffusion_textual_inversion(
|
||||
context: ConversionContext, name: str, base_model: str, inversion: str
|
||||
):
|
||||
|
|
|
@ -74,14 +74,22 @@ class WorkerContext:
|
|||
|
||||
|
||||
class JobStatus:
|
||||
name: str
|
||||
device: str
|
||||
progress: int
|
||||
cancelled: bool
|
||||
finished: bool
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
name: str,
|
||||
device: DeviceParams,
|
||||
progress: int = 0,
|
||||
cancelled: bool = False,
|
||||
finished: bool = False,
|
||||
) -> None:
|
||||
self.name = name
|
||||
self.device = device.device
|
||||
self.progress = progress
|
||||
self.cancelled = cancelled
|
||||
self.finished = finished
|
||||
|
|
Loading…
Reference in New Issue