1
0
Fork 0

fix(api): move prompt splitting logging to debug level

This commit is contained in:
Sean Sube 2023-03-08 22:55:58 -06:00
parent 6c47542d46
commit 21c60709bc
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
1 changed file with 7 additions and 7 deletions

View File

@@ -48,7 +48,7 @@ def expand_prompt(
     )
     groups_count = ceil(tokens.input_ids.shape[1] / MAX_TOKENS_PER_GROUP)
-    logger.info("splitting %s into %s groups", tokens.input_ids.shape, groups_count)
+    logger.debug("splitting %s into %s groups", tokens.input_ids.shape, groups_count)
     groups = []
     # np.array_split(tokens.input_ids, groups_count, axis=1)
@@ -57,19 +57,19 @@ def expand_prompt(
         group_end = min(
             group_start + MAX_TOKENS_PER_GROUP, tokens.input_ids.shape[1]
         )  # or should this be 1?
-        logger.info("building group for token slice [%s : %s]", group_start, group_end)
+        logger.debug("building group for token slice [%s : %s]", group_start, group_end)
         groups.append(tokens.input_ids[:, group_start:group_end])

     # encode each chunk
-    logger.info("group token shapes: %s", [t.shape for t in groups])
+    logger.debug("group token shapes: %s", [t.shape for t in groups])
     group_embeds = []
     for group in groups:
-        logger.info("encoding group: %s", group.shape)
+        logger.debug("encoding group: %s", group.shape)
         embeds = self.text_encoder(input_ids=group.astype(np.int32))[0]
         group_embeds.append(embeds)

     # concat those embeds
-    logger.info("group embeds shape: %s", [t.shape for t in group_embeds])
+    logger.debug("group embeds shape: %s", [t.shape for t in group_embeds])
     prompt_embeds = np.concatenate(group_embeds, axis=1)
     prompt_embeds = np.repeat(prompt_embeds, num_images_per_prompt, axis=0)
@@ -105,7 +105,7 @@ def expand_prompt(
             input_ids=uncond_input.input_ids.astype(np.int32)
         )[0]
         negative_padding = tokens.input_ids.shape[1] - negative_prompt_embeds.shape[1]
-        logger.info(
+        logger.debug(
             "padding negative prompt to match input: %s, %s, %s extra tokens",
             tokens.input_ids.shape,
             negative_prompt_embeds.shape,
@@ -126,5 +126,5 @@ def expand_prompt(
     # to avoid doing two forward passes
     prompt_embeds = np.concatenate([negative_prompt_embeds, prompt_embeds])
-    logger.info("expanded prompt shape: %s", prompt_embeds.shape)
+    logger.debug("expanded prompt shape: %s", prompt_embeds.shape)
     return prompt_embeds