33 changes: 21 additions & 12 deletions CrossAttentionPatch.py
@@ -6,7 +6,7 @@

class CrossAttentionPatch:
# forward for patching
def __init__(self, ipadapter=None, number=0, weight=1.0, cond=None, cond_alt=None, uncond=None, weight_type="linear", mask=None, sigma_start=0.0, sigma_end=1.0, unfold_batch=False, embeds_scaling='V only'):
def __init__(self, ipadapter=None, number=0, weight=1.0, cond=None, cond_alt=None, uncond=None, weight_type="linear", mask=None, sigma_start=0.0, sigma_end=1.0, unfold_batch=False, image_schedule=None, embeds_scaling='V only'):
self.weights = [weight]
self.ipadapters = [ipadapter]
self.conds = [cond]
@@ -17,14 +17,15 @@ def __init__(self, ipadapter=None, number=0, weight=1.0, cond=None, cond_alt=Non
self.sigma_starts = [sigma_start]
self.sigma_ends = [sigma_end]
self.unfold_batch = [unfold_batch]
self.image_schedule = [image_schedule]
self.embeds_scaling = [embeds_scaling]
self.number = number
self.layers = 11 if '101_to_k_ip' in ipadapter.ip_layers.to_kvs else 16 # TODO: check if this is a valid condition to detect all models

self.k_key = str(self.number*2+1) + "_to_k_ip"
self.v_key = str(self.number*2+1) + "_to_v_ip"

def set_new_condition(self, ipadapter=None, number=0, weight=1.0, cond=None, cond_alt=None, uncond=None, weight_type="linear", mask=None, sigma_start=0.0, sigma_end=1.0, unfold_batch=False, embeds_scaling='V only'):
def set_new_condition(self, ipadapter=None, number=0, weight=1.0, cond=None, cond_alt=None, uncond=None, weight_type="linear", mask=None, sigma_start=0.0, sigma_end=1.0, unfold_batch=False, image_schedule=None, embeds_scaling='V only'):
self.weights.append(weight)
self.ipadapters.append(ipadapter)
self.conds.append(cond)
@@ -35,6 +36,7 @@ def set_new_condition(self, ipadapter=None, number=0, weight=1.0, cond=None, con
self.sigma_starts.append(sigma_start)
self.sigma_ends.append(sigma_end)
self.unfold_batch.append(unfold_batch)
self.image_schedule.append(image_schedule)
self.embeds_scaling.append(embeds_scaling)

def __call__(self, q, k, v, extra_options):
@@ -54,7 +56,7 @@ def __call__(self, q, k, v, extra_options):
out = optimized_attention(q, k, v, extra_options["n_heads"])
_, _, oh, ow = extra_options["original_shape"]

for weight, cond, cond_alt, uncond, ipadapter, mask, weight_type, sigma_start, sigma_end, unfold_batch, embeds_scaling in zip(self.weights, self.conds, self.conds_alt, self.unconds, self.ipadapters, self.masks, self.weight_types, self.sigma_starts, self.sigma_ends, self.unfold_batch, self.embeds_scaling):
for weight, cond, cond_alt, uncond, ipadapter, mask, weight_type, sigma_start, sigma_end, unfold_batch, image_schedule, embeds_scaling in zip(self.weights, self.conds, self.conds_alt, self.unconds, self.ipadapters, self.masks, self.weight_types, self.sigma_starts, self.sigma_ends, self.unfold_batch, self.image_schedule, self.embeds_scaling):
if sigma <= sigma_start and sigma >= sigma_end:
if weight_type == 'ease in':
weight = weight * (0.05 + 0.95 * (1 - t_idx / self.layers))
@@ -94,16 +96,23 @@ def __call__(self, q, k, v, extra_options):
elif weight == 0:
continue

# if image length matches or exceeds full_length get sub_idx images
if cond.shape[0] >= ad_params["full_length"]:
cond = torch.Tensor(cond[ad_params["sub_idxs"]])
uncond = torch.Tensor(uncond[ad_params["sub_idxs"]])
# otherwise get sub_idxs images
if image_schedule is not None:
# Use the image_schedule as a lookup table to get the embedded image corresponding to each sub_idx
# If image_schedule isn't long enough then use the last image
cond_idxs = [image_schedule[i if i < len(image_schedule) else -1] for i in ad_params["sub_idxs"]]
cond = torch.Tensor(cond[cond_idxs])
uncond = torch.Tensor(uncond[cond_idxs])
else:
cond = tensor_to_size(cond, ad_params["full_length"])
uncond = tensor_to_size(uncond, ad_params["full_length"])
cond = cond[ad_params["sub_idxs"]]
uncond = uncond[ad_params["sub_idxs"]]
# if image length matches or exceeds full_length get sub_idx images
if cond.shape[0] >= ad_params["full_length"]:
cond = torch.Tensor(cond[ad_params["sub_idxs"]])
uncond = torch.Tensor(uncond[ad_params["sub_idxs"]])
# otherwise get sub_idxs images
else:
cond = tensor_to_size(cond, ad_params["full_length"])
uncond = tensor_to_size(uncond, ad_params["full_length"])
cond = cond[ad_params["sub_idxs"]]
uncond = uncond[ad_params["sub_idxs"]]
else:
if isinstance(weight, torch.Tensor):
weight = tensor_to_size(weight, batch_prompt)
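The scheduling branch above boils down to an index lookup: each AnimateDiff sub_idx is mapped through image_schedule to pick one of the encoded reference images, and any frame index past the end of the schedule falls back to its last entry. A minimal sketch of that lookup (not the PR code itself; the embedding shape and the helper name select_scheduled_embeds are made up for illustration):

```python
import torch

def select_scheduled_embeds(cond, image_schedule, sub_idxs):
    # Clamp to the last schedule entry when a sub_idx runs past the schedule,
    # mirroring the `i if i < len(image_schedule) else -1` lookup in the patch above.
    cond_idxs = [image_schedule[i] if i < len(image_schedule) else image_schedule[-1]
                 for i in sub_idxs]
    return cond[cond_idxs]

# Toy run: 3 encoded images, a schedule that switches at frames 0/40/80,
# and an AnimateDiff context window covering frames 38..41.
cond = torch.randn(3, 257, 1280)               # hypothetical embedding shape
schedule = [0] * 40 + [1] * 40 + [2] * 40      # what "0:0, 40:1, 80:2" expands to over 120 frames
print(select_scheduled_embeds(cond, schedule, [38, 39, 40, 41]).shape)  # torch.Size([4, 257, 1280])
```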
42 changes: 41 additions & 1 deletion IPAdapterPlus.py
@@ -163,6 +163,7 @@ def ipadapter_execute(model,
pos_embed=None,
neg_embed=None,
unfold_batch=False,
image_schedule=None,
embeds_scaling='V only',
layer_weights=None):
device = model_management.get_torch_device()
@@ -374,6 +375,7 @@ def ipadapter_execute(model,
"sigma_start": sigma_start,
"sigma_end": sigma_end,
"unfold_batch": unfold_batch,
"image_schedule": image_schedule,
"embeds_scaling": embeds_scaling,
}

@@ -634,7 +636,7 @@ def INPUT_TYPES(s):
FUNCTION = "apply_ipadapter"
CATEGORY = "ipadapter"

def apply_ipadapter(self, model, ipadapter, start_at=0.0, end_at=1.0, weight=1.0, weight_style=1.0, weight_composition=1.0, expand_style=False, weight_type="linear", combine_embeds="concat", weight_faceidv2=None, image=None, image_style=None, image_composition=None, image_negative=None, clip_vision=None, attn_mask=None, insightface=None, embeds_scaling='V only', layer_weights=None, ipadapter_params=None):
def apply_ipadapter(self, model, ipadapter, start_at=0.0, end_at=1.0, weight=1.0, weight_style=1.0, weight_composition=1.0, expand_style=False, weight_type="linear", combine_embeds="concat", weight_faceidv2=None, image=None, image_style=None, image_composition=None, image_negative=None, clip_vision=None, image_schedule=None, attn_mask=None, insightface=None, embeds_scaling='V only', layer_weights=None, ipadapter_params=None):
is_sdxl = isinstance(model.model, (comfy.model_base.SDXL, comfy.model_base.SDXLRefiner, comfy.model_base.SDXL_instructpix2pix))

if 'ipadapter' in ipadapter:
@@ -688,6 +690,7 @@ def apply_ipadapter(self, model, ipadapter, start_at=0.0, end_at=1.0, weight=1.0
"end_at": end_at if not isinstance(end_at, list) else end_at[i],
"attn_mask": attn_mask if not isinstance(attn_mask, list) else attn_mask[i],
"unfold_batch": self.unfold_batch,
"image_schedule": image_schedule,
"embeds_scaling": embeds_scaling,
"insightface": insightface if insightface is not None else ipadapter['insightface']['model'] if 'insightface' in ipadapter else None,
"layer_weights": layer_weights,
@@ -719,6 +722,7 @@ def INPUT_TYPES(s):
"image_negative": ("IMAGE",),
"attn_mask": ("MASK",),
"clip_vision": ("CLIP_VISION",),
"image_schedule": ("INT", {"default": None, "forceInput": True} ),
}
}

@@ -771,6 +775,7 @@ def INPUT_TYPES(s):
"image_negative": ("IMAGE",),
"attn_mask": ("MASK",),
"clip_vision": ("CLIP_VISION",),
"image_schedule": ("INT", {"default": None, "forceInput": True} ),
}
}

@@ -963,6 +968,7 @@ def INPUT_TYPES(s):
"image_negative": ("IMAGE",),
"attn_mask": ("MASK",),
"clip_vision": ("CLIP_VISION",),
"image_schedule": ("INT", {"default": None, "forceInput": True} ),
}
}

@@ -1330,6 +1336,37 @@ def load(self, embeds):
path = folder_paths.get_annotated_filepath(embeds)
return (torch.load(path).cpu(), )

defaultValue="""0:0,
40:1,
80:2,
"""
class IPAdapterImageSchedule:
@classmethod
def INPUT_TYPES(s):
return {"required": {"text": ("STRING", {"multiline": True, "default": defaultValue}),
"max_frames": ("INT", {"default": 120.0, "min": 1.0, "max": 999999.0, "step": 1.0}),
"print_output": ("BOOLEAN", {"default": False})}}

RETURN_TYPES = ("INT",)
FUNCTION = "schedule"

CATEGORY = "ipadapter/utils"

def schedule(self, text, max_frames, print_output):
frames = [0] * max_frames
for item in text.split(","):
item = item.strip()
if ":" in item:
parts = item.split(":")
if len(parts) == 2:
start_frame = int(parts[0])
value = int(parts[1])
for i in range(start_frame, max_frames):
frames[i] = value
if print_output is True:
print("ValueSchedule: ", frames)
return (frames, )

class IPAdapterWeights:
@classmethod
def INPUT_TYPES(s):
@@ -1521,8 +1558,10 @@ def combine(self, params_1, params_2, params_3=None, params_4=None, params_5=Non
"IPAdapterSaveEmbeds": IPAdapterSaveEmbeds,
"IPAdapterLoadEmbeds": IPAdapterLoadEmbeds,
"IPAdapterWeights": IPAdapterWeights,
"IPAdapterImageSchedule": IPAdapterImageSchedule,
"IPAdapterRegionalConditioning": IPAdapterRegionalConditioning,
"IPAdapterCombineParams": IPAdapterCombineParams,
"IPAdapterCombineParams": IPAdapterCombineParams,
}

NODE_DISPLAY_NAME_MAPPINGS = {
@@ -1555,6 +1594,7 @@ def combine(self, params_1, params_2, params_3=None, params_4=None, params_5=Non
"IPAdapterSaveEmbeds": "IPAdapter Save Embeds",
"IPAdapterLoadEmbeds": "IPAdapter Load Embeds",
"IPAdapterWeights": "IPAdapter Weights",
"IPAdapterImageSchedule": "IPAdapterImageSchedule",
"IPAdapterRegionalConditioning": "IPAdapter Regional Conditioning",
"IPAdapterCombineParams": "IPAdapter Combine Params",
}
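For reference, the new IPAdapterImageSchedule node turns a keyframe string such as the defaultValue above into a flat per-frame list of image indices: each `start_frame:image_index` pair fills from that frame to max_frames, so later keyframes simply overwrite the tail. A standalone sketch of that parsing, assuming the same comma-separated syntax (the function name parse_schedule is illustrative):

```python
def parse_schedule(text: str, max_frames: int) -> list:
    frames = [0] * max_frames
    for item in text.split(","):
        item = item.strip()
        if ":" not in item:
            continue
        parts = item.split(":")
        if len(parts) != 2:
            continue
        start_frame, value = int(parts[0]), int(parts[1])
        # Each keyframe fills to the end, so its value holds until the next keyframe overwrites it.
        for i in range(start_frame, max_frames):
            frames[i] = value
    return frames

schedule = parse_schedule("0:0,\n40:1,\n80:2,\n", max_frames=120)
print(schedule[38:42], schedule[78:82])  # [0, 0, 1, 1] [1, 1, 2, 2]
```

The resulting list is what gets wired into the new image_schedule input of the apply nodes and, from there, passed down to the CrossAttentionPatch lookup shown earlier.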
31 changes: 22 additions & 9 deletions utils.py
@@ -154,21 +154,34 @@ def insightface_loader(provider):
model.prepare(ctx_id=0, det_size=(640, 640))
return model

def encode_image_masked(clip_vision, image, mask=None):
def encode_image_masked(clip_vision, images, mask=None):
model_management.load_model_gpu(clip_vision.patcher)
image = image.to(clip_vision.load_device)

pixel_values = clip_preprocess(image.to(clip_vision.load_device)).float()
# Initialize lists to collect outputs
last_hidden_states = []
image_embeds = []
penultimate_hidden_states = []

if mask is not None:
pixel_values = pixel_values * mask.to(clip_vision.load_device)
# Loop over each image in the batch
for image in images:
pixel_values = clip_preprocess(image.to(clip_vision.load_device).unsqueeze(0)).float()

out = clip_vision.model(pixel_values=pixel_values, intermediate_output=-2)
if mask is not None:
pixel_values *= mask.to(clip_vision.load_device)

out = clip_vision.model(pixel_values=pixel_values, intermediate_output=-2)

# Collect the outputs for each image
last_hidden_states.append(out[0].to(model_management.intermediate_device()))
image_embeds.append(out[2].to(model_management.intermediate_device()))
penultimate_hidden_states.append(out[1].to(model_management.intermediate_device()))

# Concatenate all collected outputs across the batch
outputs = Output()
outputs["last_hidden_state"] = out[0].to(model_management.intermediate_device())
outputs["image_embeds"] = out[2].to(model_management.intermediate_device())
outputs["penultimate_hidden_states"] = out[1].to(model_management.intermediate_device())
outputs["last_hidden_state"] = torch.cat(last_hidden_states, dim=0)
outputs["image_embeds"] = torch.cat(image_embeds, dim=0)
outputs["penultimate_hidden_states"] = torch.cat(penultimate_hidden_states, dim=0)

return outputs

def tensor_to_size(source, dest_size):
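The rewritten encode_image_masked now encodes the batch one image at a time and concatenates the three collected outputs along dim 0, so callers still receive one row per input image. A toy sketch of that collect-then-concatenate pattern, with fake_encode standing in for the clip_vision.model call and made-up output shapes:

```python
import torch

def fake_encode(pixel_values):
    # Stand-in for clip_vision.model(...); shapes are illustrative only.
    b = pixel_values.shape[0]  # always 1 here, since images are fed through one at a time
    return torch.randn(b, 257, 1280), torch.randn(b, 257, 1280), torch.randn(b, 1280)

def encode_batch(images, mask=None):
    last_hidden, penultimate, embeds = [], [], []
    for image in images:                  # iterate over the batch dimension
        pixel_values = image.unsqueeze(0)
        if mask is not None:
            pixel_values = pixel_values * mask
        h_last, h_pen, emb = fake_encode(pixel_values)
        last_hidden.append(h_last)
        penultimate.append(h_pen)
        embeds.append(emb)
    # Concatenating along dim 0 restores one row per input image.
    return torch.cat(last_hidden, 0), torch.cat(penultimate, 0), torch.cat(embeds, 0)

out = encode_batch(torch.randn(4, 224, 224, 3))
print(out[0].shape, out[2].shape)  # torch.Size([4, 257, 1280]) torch.Size([4, 1280])
```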