Skip to content

Commit 2ad5b5c

Browse files
maryhipp and Mary Hipp authored
Flux Kontext UI support (#8111)
* add support for flux-kontext models in nodes
* flux kontext in canvas
* add aspect ratio support
* lint
* restore aspect ratio logic
* more linting
* typegen
* fix typegen
---------
Co-authored-by: Mary Hipp <[email protected]>
1 parent 24d8a96 commit 2ad5b5c

File tree

29 files changed

+357
-16
lines changed

29 files changed

+357
-16
lines changed

invokeai/app/invocations/fields.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,7 @@ class UIType(str, Enum, metaclass=MetaEnum):
6464
Imagen3Model = "Imagen3ModelField"
6565
Imagen4Model = "Imagen4ModelField"
6666
ChatGPT4oModel = "ChatGPT4oModelField"
67+
FluxKontextModel = "FluxKontextModelField"
6768
# endregion
6869

6970
# region Misc Field Types

invokeai/backend/model_manager/taxonomy.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@ class BaseModelType(str, Enum):
2929
Imagen3 = "imagen3"
3030
Imagen4 = "imagen4"
3131
ChatGPT4o = "chatgpt-4o"
32+
FluxKontext = "flux-kontext"
3233

3334

3435
class ModelType(str, Enum):

invokeai/frontend/web/public/locales/en.json

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1147,6 +1147,7 @@
11471147
"modelIncompatibleScaledBboxWidth": "Scaled bbox width is {{width}} but {{model}} requires multiple of {{multiple}}",
11481148
"modelIncompatibleScaledBboxHeight": "Scaled bbox height is {{height}} but {{model}} requires multiple of {{multiple}}",
11491149
"fluxModelMultipleControlLoRAs": "Can only use 1 Control LoRA at a time",
1150+
"fluxKontextMultipleReferenceImages": "Can only use 1 Reference Image at a time with Flux Kontext",
11501151
"canvasIsFiltering": "Canvas is busy (filtering)",
11511152
"canvasIsTransforming": "Canvas is busy (transforming)",
11521153
"canvasIsRasterizing": "Canvas is busy (rasterizing)",
@@ -1337,6 +1338,7 @@
13371338
"fluxFillIncompatibleWithT2IAndI2I": "FLUX Fill is not compatible with Text to Image or Image to Image. Use other FLUX models for these tasks.",
13381339
"imagenIncompatibleGenerationMode": "Google {{model}} supports Text to Image only. Use other models for Image to Image, Inpainting and Outpainting tasks.",
13391340
"chatGPT4oIncompatibleGenerationMode": "ChatGPT 4o supports Text to Image and Image to Image only. Use other models for Inpainting and Outpainting tasks.",
1341+
"fluxKontextIncompatibleGenerationMode": "Flux Kontext supports Text to Image only. Use other models for Image to Image, Inpainting and Outpainting tasks.",
13401342
"problemUnpublishingWorkflow": "Problem Unpublishing Workflow",
13411343
"problemUnpublishingWorkflowDescription": "There was a problem unpublishing the workflow. Please try again.",
13421344
"workflowUnpublished": "Workflow Unpublished"

invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/enqueueRequestedLinear.ts

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ import { prepareLinearUIBatch } from 'features/nodes/util/graph/buildLinearBatch
1010
import { buildChatGPT4oGraph } from 'features/nodes/util/graph/generation/buildChatGPT4oGraph';
1111
import { buildCogView4Graph } from 'features/nodes/util/graph/generation/buildCogView4Graph';
1212
import { buildFLUXGraph } from 'features/nodes/util/graph/generation/buildFLUXGraph';
13+
import { buildFluxKontextGraph } from 'features/nodes/util/graph/generation/buildFluxKontextGraph';
1314
import { buildImagen3Graph } from 'features/nodes/util/graph/generation/buildImagen3Graph';
1415
import { buildImagen4Graph } from 'features/nodes/util/graph/generation/buildImagen4Graph';
1516
import { buildSD1Graph } from 'features/nodes/util/graph/generation/buildSD1Graph';
@@ -59,6 +60,8 @@ export const addEnqueueRequestedLinear = (startAppListening: AppStartListening)
5960
return await buildImagen4Graph(state, manager);
6061
case 'chatgpt-4o':
6162
return await buildChatGPT4oGraph(state, manager);
63+
case 'flux-kontext':
64+
return await buildFluxKontextGraph(state, manager);
6265
default:
6366
assert(false, `No graph builders for base ${base}`);
6467
}

invokeai/frontend/web/src/features/controlLayers/hooks/addLayerHooks.ts

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@ import type {
2929
import {
3030
initialChatGPT4oReferenceImage,
3131
initialControlNet,
32+
initialFluxKontextReferenceImage,
3233
initialIPAdapter,
3334
initialT2IAdapter,
3435
} from 'features/controlLayers/store/util';
@@ -87,6 +88,12 @@ export const selectDefaultRefImageConfig = createSelector(
8788
return referenceImage;
8889
}
8990

91+
if (selectedMainModel?.base === 'flux-kontext') {
92+
const referenceImage = deepClone(initialFluxKontextReferenceImage);
93+
referenceImage.model = zModelIdentifierField.parse(selectedMainModel);
94+
return referenceImage;
95+
}
96+
9097
const { data } = query;
9198
let model: IPAdapterModelConfig | null = null;
9299
if (data) {

invokeai/frontend/web/src/features/controlLayers/hooks/useIsEntityTypeEnabled.ts

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,12 @@ import { useAppSelector } from 'app/store/storeHooks';
22
import {
33
selectIsChatGTP4o,
44
selectIsCogView4,
5+
selectIsFluxKontext,
56
selectIsImagen3,
67
selectIsImagen4,
78
selectIsSD3,
89
} from 'features/controlLayers/store/paramsSlice';
10+
import { selectActiveReferenceImageEntities } from 'features/controlLayers/store/selectors';
911
import type { CanvasEntityType } from 'features/controlLayers/store/types';
1012
import { useMemo } from 'react';
1113
import type { Equals } from 'tsafe';
@@ -17,23 +19,28 @@ export const useIsEntityTypeEnabled = (entityType: CanvasEntityType) => {
1719
const isImagen3 = useAppSelector(selectIsImagen3);
1820
const isImagen4 = useAppSelector(selectIsImagen4);
1921
const isChatGPT4o = useAppSelector(selectIsChatGTP4o);
22+
const isFluxKontext = useAppSelector(selectIsFluxKontext);
23+
const activeReferenceImageEntities = useAppSelector(selectActiveReferenceImageEntities);
2024

2125
const isEntityTypeEnabled = useMemo<boolean>(() => {
2226
switch (entityType) {
2327
case 'reference_image':
28+
if (isFluxKontext) {
29+
return activeReferenceImageEntities.length === 0;
30+
}
2431
return !isSD3 && !isCogView4 && !isImagen3 && !isImagen4;
2532
case 'regional_guidance':
26-
return !isSD3 && !isCogView4 && !isImagen3 && !isImagen4 && !isChatGPT4o;
33+
return !isSD3 && !isCogView4 && !isImagen3 && !isImagen4 && !isFluxKontext && !isChatGPT4o;
2734
case 'control_layer':
28-
return !isSD3 && !isCogView4 && !isImagen3 && !isImagen4 && !isChatGPT4o;
35+
return !isSD3 && !isCogView4 && !isImagen3 && !isImagen4 && !isFluxKontext && !isChatGPT4o;
2936
case 'inpaint_mask':
30-
return !isImagen3 && !isImagen4 && !isChatGPT4o;
37+
return !isImagen3 && !isImagen4 && !isFluxKontext && !isChatGPT4o;
3138
case 'raster_layer':
32-
return !isImagen3 && !isImagen4 && !isChatGPT4o;
39+
return !isImagen3 && !isImagen4 && !isFluxKontext && !isChatGPT4o;
3340
default:
3441
assert<Equals<typeof entityType, never>>(false);
3542
}
36-
}, [entityType, isSD3, isCogView4, isImagen3, isImagen4, isChatGPT4o]);
43+
}, [entityType, isSD3, isCogView4, isImagen3, isImagen4, isFluxKontext, isChatGPT4o, activeReferenceImageEntities]);
3744

3845
return isEntityTypeEnabled;
3946
};

invokeai/frontend/web/src/features/controlLayers/store/canvasSlice.ts

Lines changed: 43 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,13 @@ import type {
6969
IPMethodV2,
7070
T2IAdapterConfig,
7171
} from './types';
72-
import { getEntityIdentifier, isChatGPT4oAspectRatioID, isImagenAspectRatioID, isRenderableEntity } from './types';
72+
import {
73+
getEntityIdentifier,
74+
isChatGPT4oAspectRatioID,
75+
isFluxKontextAspectRatioID,
76+
isImagenAspectRatioID,
77+
isRenderableEntity,
78+
} from './types';
7379
import {
7480
converters,
7581
getControlLayerState,
@@ -81,6 +87,7 @@ import {
8187
initialChatGPT4oReferenceImage,
8288
initialControlLoRA,
8389
initialControlNet,
90+
initialFluxKontextReferenceImage,
8491
initialFLUXRedux,
8592
initialIPAdapter,
8693
initialT2IAdapter,
@@ -686,6 +693,16 @@ export const canvasSlice = createSlice({
686693
return;
687694
}
688695

696+
if (entity.ipAdapter.model.base === 'flux-kontext') {
697+
// Switching to flux-kontext
698+
entity.ipAdapter = {
699+
...initialFluxKontextReferenceImage,
700+
image: entity.ipAdapter.image,
701+
model: entity.ipAdapter.model,
702+
};
703+
return;
704+
}
705+
689706
if (entity.ipAdapter.model.type === 'flux_redux') {
690707
// Switching to flux_redux
691708
entity.ipAdapter = {
@@ -1322,6 +1339,31 @@ export const canvasSlice = createSlice({
13221339
}
13231340
state.bbox.aspectRatio.value = state.bbox.rect.width / state.bbox.rect.height;
13241341
state.bbox.aspectRatio.isLocked = true;
1342+
} else if (state.bbox.modelBase === 'flux-kontext' && isFluxKontextAspectRatioID(id)) {
1343+
if (id === '3:4') {
1344+
state.bbox.rect.width = 880;
1345+
state.bbox.rect.height = 1184;
1346+
} else if (id === '4:3') {
1347+
state.bbox.rect.width = 1184;
1348+
state.bbox.rect.height = 880;
1349+
} else if (id === '9:16') {
1350+
state.bbox.rect.width = 752;
1351+
state.bbox.rect.height = 1392;
1352+
} else if (id === '16:9') {
1353+
state.bbox.rect.width = 1392;
1354+
state.bbox.rect.height = 752;
1355+
} else if (id === '21:9') {
1356+
state.bbox.rect.width = 1568;
1357+
state.bbox.rect.height = 672;
1358+
} else if (id === '9:21') {
1359+
state.bbox.rect.width = 672;
1360+
state.bbox.rect.height = 1568;
1361+
} else if (id === '1:1') {
1362+
state.bbox.rect.width = 880;
1363+
state.bbox.rect.height = 880;
1364+
}
1365+
state.bbox.aspectRatio.value = state.bbox.rect.width / state.bbox.rect.height;
1366+
state.bbox.aspectRatio.isLocked = true;
13251367
} else {
13261368
state.bbox.aspectRatio.isLocked = true;
13271369
state.bbox.aspectRatio.value = ASPECT_RATIO_MAP[id].ratio;

invokeai/frontend/web/src/features/controlLayers/store/paramsSlice.ts

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -383,6 +383,7 @@ export const selectIsCogView4 = createParamsSelector((params) => params.model?.b
383383
export const selectIsImagen3 = createParamsSelector((params) => params.model?.base === 'imagen3');
384384
export const selectIsImagen4 = createParamsSelector((params) => params.model?.base === 'imagen4');
385385
export const selectIsChatGTP4o = createParamsSelector((params) => params.model?.base === 'chatgpt-4o');
386+
export const selectIsFluxKontext = createParamsSelector((params) => params.model?.base === 'flux-kontext');
386387

387388
export const selectModel = createParamsSelector((params) => params.model);
388389
export const selectModelKey = createParamsSelector((params) => params.model?.key);

invokeai/frontend/web/src/features/controlLayers/store/types.ts

Lines changed: 21 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -258,6 +258,13 @@ const zChatGPT4oReferenceImageConfig = z.object({
258258
});
259259
export type ChatGPT4oReferenceImageConfig = z.infer<typeof zChatGPT4oReferenceImageConfig>;
260260

261+
const zFluxKontextReferenceImageConfig = z.object({
262+
type: z.literal('flux_kontext_reference_image'),
263+
image: zImageWithDims.nullable(),
264+
model: zServerValidatedModelIdentifierField.nullable(),
265+
});
266+
export type FluxKontextReferenceImageConfig = z.infer<typeof zFluxKontextReferenceImageConfig>;
267+
261268
const zCanvasEntityBase = z.object({
262269
id: zId,
263270
name: zName,
@@ -268,7 +275,12 @@ const zCanvasEntityBase = z.object({
268275
const zCanvasReferenceImageState = zCanvasEntityBase.extend({
269276
type: z.literal('reference_image'),
270277
// This should be named `referenceImage` but we need to keep it as `ipAdapter` for backwards compatibility
271-
ipAdapter: z.discriminatedUnion('type', [zIPAdapterConfig, zFLUXReduxConfig, zChatGPT4oReferenceImageConfig]),
278+
ipAdapter: z.discriminatedUnion('type', [
279+
zIPAdapterConfig,
280+
zFLUXReduxConfig,
281+
zChatGPT4oReferenceImageConfig,
282+
zFluxKontextReferenceImageConfig,
283+
]),
272284
});
273285
export type CanvasReferenceImageState = z.infer<typeof zCanvasReferenceImageState>;
274286

@@ -280,6 +292,9 @@ export const isFLUXReduxConfig = (config: CanvasReferenceImageState['ipAdapter']
280292
export const isChatGPT4oReferenceImageConfig = (
281293
config: CanvasReferenceImageState['ipAdapter']
282294
): config is ChatGPT4oReferenceImageConfig => config.type === 'chatgpt_4o_reference_image';
295+
export const isFluxKontextReferenceImageConfig = (
296+
config: CanvasReferenceImageState['ipAdapter']
297+
): config is FluxKontextReferenceImageConfig => config.type === 'flux_kontext_reference_image';
283298

284299
const zFillStyle = z.enum(['solid', 'grid', 'crosshatch', 'diagonal', 'horizontal', 'vertical']);
285300
export type FillStyle = z.infer<typeof zFillStyle>;
@@ -406,7 +421,7 @@ export type StagingAreaImage = {
406421
offsetY: number;
407422
};
408423

409-
export const zAspectRatioID = z.enum(['Free', '16:9', '3:2', '4:3', '1:1', '3:4', '2:3', '9:16']);
424+
export const zAspectRatioID = z.enum(['Free', '21:9', '9:21', '16:9', '3:2', '4:3', '1:1', '3:4', '2:3', '9:16']);
410425

411426
export const zImagen3AspectRatioID = z.enum(['16:9', '4:3', '1:1', '3:4', '9:16']);
412427
export const isImagenAspectRatioID = (v: unknown): v is z.infer<typeof zImagen3AspectRatioID> =>
@@ -416,6 +431,10 @@ export const zChatGPT4oAspectRatioID = z.enum(['3:2', '1:1', '2:3']);
416431
export const isChatGPT4oAspectRatioID = (v: unknown): v is z.infer<typeof zChatGPT4oAspectRatioID> =>
417432
zChatGPT4oAspectRatioID.safeParse(v).success;
418433

434+
export const zFluxKontextAspectRatioID = z.enum(['21:9', '4:3', '1:1', '3:4', '9:21', '16:9', '9:16']);
435+
export const isFluxKontextAspectRatioID = (v: unknown): v is z.infer<typeof zFluxKontextAspectRatioID> =>
436+
zFluxKontextAspectRatioID.safeParse(v).success;
437+
419438
export type AspectRatioID = z.infer<typeof zAspectRatioID>;
420439
export const isAspectRatioID = (v: unknown): v is AspectRatioID => zAspectRatioID.safeParse(v).success;
421440

invokeai/frontend/web/src/features/controlLayers/store/util.ts

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ import type {
1010
ChatGPT4oReferenceImageConfig,
1111
ControlLoRAConfig,
1212
ControlNetConfig,
13+
FluxKontextReferenceImageConfig,
1314
FLUXReduxConfig,
1415
ImageWithDims,
1516
IPAdapterConfig,
@@ -83,6 +84,11 @@ export const initialChatGPT4oReferenceImage: ChatGPT4oReferenceImageConfig = {
8384
image: null,
8485
model: null,
8586
};
87+
export const initialFluxKontextReferenceImage: FluxKontextReferenceImageConfig = {
88+
type: 'flux_kontext_reference_image',
89+
image: null,
90+
model: null,
91+
};
8692
export const initialT2IAdapter: T2IAdapterConfig = {
8793
type: 't2i_adapter',
8894
model: null,

0 commit comments

Comments
 (0)