Skip to content
Open
6 changes: 4 additions & 2 deletions invokeai/frontend/web/.storybook/ReduxInit.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ import type { PropsWithChildren } from 'react';
import { memo, useEffect } from 'react';

import { useAppDispatch } from '../src/app/store/storeHooks';
import { modelChanged } from '../src/features/controlLayers/store/paramsSlice';
import { modelChanged } from 'features/controlLayers/store/actions';
/**
* Initializes some state for storybook. Must be in a different component
* so that it is run inside the redux context.
Expand All @@ -13,7 +13,9 @@ export const ReduxInit = memo(({ children }: PropsWithChildren) => {
useGlobalModifiersInit();
useEffect(() => {
dispatch(
modelChanged({ model: { key: 'test_model', hash: 'some_hash', name: 'some name', base: 'sd-1', type: 'main' } })
modelChanged({
model: { key: 'test_model', hash: 'some_hash', name: 'some name', base: 'sd-1', type: 'main' },
})
);
}, [dispatch]);

Expand Down
2 changes: 1 addition & 1 deletion invokeai/frontend/web/package.json
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@
"@invoke-ai/ui-library": "^0.0.47",
"@nanostores/react": "^1.0.0",
"@observ33r/object-equals": "^1.1.5",
"@reduxjs/toolkit": "2.8.2",
"@reduxjs/toolkit": "2.9.0",
"@roarr/browser-log-writer": "^1.3.0",
"@xyflow/react": "^12.8.2",
"ag-psd": "^28.2.2",
Expand Down
10 changes: 5 additions & 5 deletions invokeai/frontend/web/pnpm-lock.yaml

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
import type { Middleware, UnknownAction } from '@reduxjs/toolkit';
import { injectTabActionContext } from 'app/store/util';
import { isCanvasInstanceAction } from 'features/controlLayers/store/canvasSlice';
import { selectActiveCanvasId, selectActiveTab } from 'features/controlLayers/store/selectors';
import { isTabInstanceParamsAction } from 'features/controlLayers/store/tabSlice';

export const actionContextMiddleware: Middleware = (store) => (next) => (action) => {
const currentAction = action as UnknownAction;

if (isTabActionContextRequired(currentAction)) {
const state = store.getState();
const tab = selectActiveTab(state);
const canvasId = tab === 'canvas' ? selectActiveCanvasId(state) : undefined;

injectTabActionContext(currentAction, tab, canvasId);
}

return next(action);
};

const isTabActionContextRequired = (action: UnknownAction) => {
return isTabInstanceParamsAction(action) || isCanvasInstanceAction(action);
};
Original file line number Diff line number Diff line change
@@ -1,14 +1,15 @@
import type { AppStartListening } from 'app/store/store';
import { setInfillMethod } from 'features/controlLayers/store/paramsSlice';
import { selectActiveTabParams, setInfillMethod } from 'features/controlLayers/store/paramsSlice';
import { shouldUseNSFWCheckerChanged, shouldUseWatermarkerChanged } from 'features/system/store/systemSlice';
import { appInfoApi } from 'services/api/endpoints/appInfo';

export const addAppConfigReceivedListener = (startAppListening: AppStartListening) => {
startAppListening({
matcher: appInfoApi.endpoints.getAppConfig.matchFulfilled,
effect: (action, { getState, dispatch }) => {
effect: (action, api) => {
const { getState, dispatch } = api;
const { infill_methods = [], nsfw_methods = [], watermarking_methods = [] } = action.payload;
const infillMethod = getState().params.infillMethod;
const infillMethod = selectActiveTabParams(getState()).infillMethod;

if (!infill_methods.includes(infillMethod)) {
// If the selected infill method does not exist, prefer 'lama' if it's in the list, otherwise 'tile'.
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import type { AppStartListening } from 'app/store/store';
import { selectRefImagesSlice } from 'features/controlLayers/store/refImagesSlice';
import { selectCanvasSlice } from 'features/controlLayers/store/selectors';
import { selectCanvases } from 'features/controlLayers/store/selectors';
import { getImageUsage } from 'features/deleteImageModal/store/state';
import { nodeEditorReset } from 'features/nodes/store/nodesSlice';
import { selectNodesSlice } from 'features/nodes/store/selectors';
Expand All @@ -19,12 +19,12 @@ export const addDeleteBoardAndImagesFulfilledListener = (startAppListening: AppS

const state = getState();
const nodes = selectNodesSlice(state);
const canvas = selectCanvasSlice(state);
const canvases = selectCanvases(state);
const upscale = selectUpscaleSlice(state);
const refImages = selectRefImagesSlice(state);

deleted_images.forEach((image_name) => {
const imageUsage = getImageUsage(nodes, canvas, upscale, refImages, image_name);
const imageUsage = getImageUsage(nodes, canvases, upscale, refImages, image_name);

if (imageUsage.isNodesImage && !wasNodeEditorReset) {
dispatch(nodeEditorReset());
Expand Down
Original file line number Diff line number Diff line change
@@ -1,14 +1,18 @@
import { logger } from 'app/logging/logger';
import type { AppStartListening } from 'app/store/store';
import { modelChanged } from 'features/controlLayers/store/actions';
import { bboxSyncedToOptimalDimension, rgRefImageModelChanged } from 'features/controlLayers/store/canvasSlice';
import { buildSelectIsStaging, selectCanvasSessionId } from 'features/controlLayers/store/canvasStagingAreaSlice';
import { loraIsEnabledChanged } from 'features/controlLayers/store/lorasSlice';
import { modelChanged, syncedToOptimalDimension, vaeSelected } from 'features/controlLayers/store/paramsSlice';
import {
buildSelectIsStagingBySessionId,
selectActiveCanvasStagingAreaSessionId,
} from 'features/controlLayers/store/canvasStagingAreaSlice';
import { loraIsEnabledChanged, selectAddedLoRAs } from 'features/controlLayers/store/lorasSlice';
import { selectActiveTabParams, syncedToOptimalDimension, vaeSelected } from 'features/controlLayers/store/paramsSlice';
import { refImageModelChanged, selectReferenceImageEntities } from 'features/controlLayers/store/refImagesSlice';
import {
selectActiveCanvas,
selectAllEntitiesOfType,
selectBboxModelBase,
selectCanvasSlice,
} from 'features/controlLayers/store/selectors';
import { getEntityIdentifier } from 'features/controlLayers/store/types';
import { modelSelected } from 'features/parameters/store/actions';
Expand All @@ -31,7 +35,8 @@ const log = logger('models');
export const addModelSelectedListener = (startAppListening: AppStartListening) => {
startAppListening({
actionCreator: modelSelected,
effect: (action, { getState, dispatch }) => {
effect: (action, api) => {
const { getState, dispatch } = api;
const state = getState();
const result = zParameterModel.safeParse(action.payload);

Expand All @@ -42,22 +47,23 @@ export const addModelSelectedListener = (startAppListening: AppStartListening) =

const newModel = result.data;
const newBase = newModel.base;
const didBaseModelChange = state.params.model?.base !== newBase;
const params = selectActiveTabParams(state);
const didBaseModelChange = params.model?.base !== newBase;

if (didBaseModelChange) {
// we may need to reset some incompatible submodels
let modelsUpdatedDisabledOrCleared = 0;

// handle incompatible loras
state.loras.loras.forEach((lora) => {
selectAddedLoRAs(state).forEach((lora) => {
if (lora.model.base !== newBase) {
dispatch(loraIsEnabledChanged({ id: lora.id, isEnabled: false }));
modelsUpdatedDisabledOrCleared += 1;
}
});

// handle incompatible vae
const { vae } = state.params;
const { vae } = params;
if (vae && vae.base !== newBase) {
dispatch(vaeSelected(null));
modelsUpdatedDisabledOrCleared += 1;
Expand Down Expand Up @@ -118,7 +124,7 @@ export const addModelSelectedListener = (startAppListening: AppStartListening) =
const newRegionalRefImageModel = selectRegionalRefImageModels(state)[0] ?? null;

// All regional guidance entities are updated to use the same new model.
const canvasState = selectCanvasSlice(state);
const canvasState = selectActiveCanvas(state);
const canvasRegionalGuidanceEntities = selectAllEntitiesOfType(canvasState, 'regional_guidance');
for (const entity of canvasRegionalGuidanceEntities) {
for (const refImage of entity.referenceImages) {
Expand Down Expand Up @@ -152,14 +158,16 @@ export const addModelSelectedListener = (startAppListening: AppStartListening) =
}
}

dispatch(modelChanged({ model: newModel, previousModel: state.params.model }));
dispatch(modelChanged({ model: newModel, previousModel: params.model }));

const modelBase = selectBboxModelBase(state);

if (modelBase !== state.params.model?.base) {
if (modelBase !== params.model?.base) {
// Sync generate tab settings whenever the model base changes
dispatch(syncedToOptimalDimension());
const isStaging = buildSelectIsStaging(selectCanvasSessionId(state))(state);
const sessionId = selectActiveCanvasStagingAreaSessionId(state);
const selectIsStaging = buildSelectIsStagingBySessionId(sessionId);
const isStaging = selectIsStaging(state);
if (!isStaging) {
// Canvas tab only syncs if not staging
dispatch(bboxSyncedToOptimalDimension());
Expand Down
Original file line number Diff line number Diff line change
@@ -1,17 +1,18 @@
import { logger } from 'app/logging/logger';
import type { AppDispatch, AppStartListening, RootState } from 'app/store/store';
import { modelChanged } from 'features/controlLayers/store/actions';
import { controlLayerModelChanged, rgRefImageModelChanged } from 'features/controlLayers/store/canvasSlice';
import { loraDeleted } from 'features/controlLayers/store/lorasSlice';
import { loraDeleted, selectAddedLoRAs } from 'features/controlLayers/store/lorasSlice';
import {
clipEmbedModelSelected,
fluxVAESelected,
modelChanged,
refinerModelChanged,
selectActiveTabParams,
t5EncoderModelSelected,
vaeSelected,
} from 'features/controlLayers/store/paramsSlice';
import { refImageModelChanged, selectRefImagesSlice } from 'features/controlLayers/store/refImagesSlice';
import { selectCanvasSlice } from 'features/controlLayers/store/selectors';
import { selectActiveCanvas } from 'features/controlLayers/store/selectors';
import {
getEntityIdentifier,
isFLUXReduxConfig,
Expand Down Expand Up @@ -103,7 +104,7 @@ type ModelHandler = (
) => undefined;

const handleMainModels: ModelHandler = (models, state, dispatch, log) => {
const selectedMainModel = state.params.model;
const selectedMainModel = selectActiveTabParams(state).model;
const allMainModels = models.filter(isNonRefinerMainModelConfig).sort((a) => (a.base === 'sdxl' ? -1 : 1));

const firstModel = allMainModels[0];
Expand Down Expand Up @@ -144,7 +145,7 @@ const handleMainModels: ModelHandler = (models, state, dispatch, log) => {
};

const handleRefinerModels: ModelHandler = (models, state, dispatch, log) => {
const selectedRefinerModel = state.params.refinerModel;
const selectedRefinerModel = selectActiveTabParams(state).refinerModel;

// `null` is a valid refiner model - no need to do anything.
if (selectedRefinerModel === null) {
Expand All @@ -168,7 +169,7 @@ const handleRefinerModels: ModelHandler = (models, state, dispatch, log) => {
};

const handleVAEModels: ModelHandler = (models, state, dispatch, log) => {
const selectedVAEModel = state.params.vae;
const selectedVAEModel = selectActiveTabParams(state).vae;

// `null` is a valid VAE - it means "use the VAE baked into the currently-selected main model"
if (selectedVAEModel === null) {
Expand All @@ -193,7 +194,7 @@ const handleVAEModels: ModelHandler = (models, state, dispatch, log) => {

const handleLoRAModels: ModelHandler = (models, state, dispatch, log) => {
const loraModels = models.filter(isLoRAModelConfig);
state.loras.loras.forEach((lora) => {
selectAddedLoRAs(state).forEach((lora) => {
const isLoRAAvailable = loraModels.some((m) => m.key === lora.model.key);
if (isLoRAAvailable) {
return;
Expand Down Expand Up @@ -221,7 +222,7 @@ const handleVideoModels: ModelHandler = (models, state, dispatch, log) => {

const handleControlAdapterModels: ModelHandler = (models, state, dispatch, log) => {
const caModels = models.filter(isControlLayerModelConfig);
selectCanvasSlice(state).controlLayers.entities.forEach((entity) => {
selectActiveCanvas(state).controlLayers.entities.forEach((entity) => {
const selectedControlAdapterModel = entity.controlAdapter.model;
// `null` is a valid control adapter model - no need to do anything.
if (!selectedControlAdapterModel) {
Expand Down Expand Up @@ -256,7 +257,7 @@ const handleIPAdapterModels: ModelHandler = (models, state, dispatch, log) => {
dispatch(refImageModelChanged({ id: entity.id, modelConfig: null }));
});

selectCanvasSlice(state).regionalGuidance.entities.forEach((entity) => {
selectActiveCanvas(state).regionalGuidance.entities.forEach((entity) => {
entity.referenceImages.forEach(({ id: referenceImageId, config }) => {
if (!isRegionalGuidanceIPAdapterConfig(config)) {
return;
Expand Down Expand Up @@ -299,7 +300,7 @@ const handleFLUXReduxModels: ModelHandler = (models, state, dispatch, log) => {
dispatch(refImageModelChanged({ id: entity.id, modelConfig: null }));
});

selectCanvasSlice(state).regionalGuidance.entities.forEach((entity) => {
selectActiveCanvas(state).regionalGuidance.entities.forEach((entity) => {
entity.referenceImages.forEach(({ id: referenceImageId, config }) => {
if (!isRegionalGuidanceFLUXReduxConfig(config)) {
return;
Expand Down Expand Up @@ -417,7 +418,7 @@ const handleTileControlNetModel: ModelHandler = (models, state, dispatch, log) =
};

const handleT5EncoderModels: ModelHandler = (models, state, dispatch, log) => {
const selectedT5EncoderModel = state.params.t5EncoderModel;
const selectedT5EncoderModel = selectActiveTabParams(state).t5EncoderModel;
const t5EncoderModels = models.filter((m) => isT5EncoderModelConfig(m));

// If the currently selected model is available, we don't need to do anything
Expand Down Expand Up @@ -445,7 +446,7 @@ const handleT5EncoderModels: ModelHandler = (models, state, dispatch, log) => {
};

const handleCLIPEmbedModels: ModelHandler = (models, state, dispatch, log) => {
const selectedCLIPEmbedModel = state.params.clipEmbedModel;
const selectedCLIPEmbedModel = selectActiveTabParams(state).clipEmbedModel;
const CLIPEmbedModels = models.filter((m) => isCLIPEmbedModelConfig(m));

// If the currently selected model is available, we don't need to do anything
Expand Down Expand Up @@ -473,7 +474,7 @@ const handleCLIPEmbedModels: ModelHandler = (models, state, dispatch, log) => {
};

const handleFLUXVAEModels: ModelHandler = (models, state, dispatch, log) => {
const selectedFLUXVAEModel = state.params.fluxVAE;
const selectedFLUXVAEModel = selectActiveTabParams(state).fluxVAE;
const fluxVAEModels = models.filter((m) => isFluxVAEModelConfig(m));

// If the currently selected model is available, we don't need to do anything
Expand Down
Original file line number Diff line number Diff line change
@@ -1,9 +1,13 @@
import type { AppStartListening } from 'app/store/store';
import { isNil } from 'es-toolkit';
import { bboxHeightChanged, bboxWidthChanged } from 'features/controlLayers/store/canvasSlice';
import { buildSelectIsStaging, selectCanvasSessionId } from 'features/controlLayers/store/canvasStagingAreaSlice';
import {
buildSelectIsStagingBySessionId,
selectActiveCanvasStagingAreaSessionId,
} from 'features/controlLayers/store/canvasStagingAreaSlice';
import {
heightChanged,
selectActiveTabParams,
setCfgRescaleMultiplier,
setCfgScale,
setGuidance,
Expand All @@ -13,6 +17,7 @@ import {
vaeSelected,
widthChanged,
} from 'features/controlLayers/store/paramsSlice';
import { selectActiveTab } from 'features/controlLayers/store/selectors';
import { setDefaultSettings } from 'features/parameters/store/actions';
import {
isParameterCFGRescaleMultiplier,
Expand All @@ -26,18 +31,18 @@ import {
zParameterVAEModel,
} from 'features/parameters/types/parameterSchemas';
import { toast } from 'features/toast/toast';
import { selectActiveTab } from 'features/ui/store/uiSelectors';
import { t } from 'i18next';
import { modelConfigsAdapterSelectors, modelsApi } from 'services/api/endpoints/models';
import { isNonRefinerMainModelConfig } from 'services/api/types';

export const addSetDefaultSettingsListener = (startAppListening: AppStartListening) => {
startAppListening({
actionCreator: setDefaultSettings,
effect: async (action, { dispatch, getState }) => {
effect: async (action, api) => {
const { dispatch, getState } = api;
const state = getState();

const currentModel = state.params.model;
const currentModel = selectActiveTabParams(state).model;

if (!currentModel) {
return;
Expand Down Expand Up @@ -115,7 +120,9 @@ export const addSetDefaultSettingsListener = (startAppListening: AppStartListeni
}
const setSizeOptions = { updateAspectRatio: true, clamp: true };

const isStaging = buildSelectIsStaging(selectCanvasSessionId(state))(state);
const sessionId = selectActiveCanvasStagingAreaSessionId(state);
const selectIsStaging = buildSelectIsStagingBySessionId(sessionId);
const isStaging = selectIsStaging(state);

const activeTab = selectActiveTab(getState());
if (activeTab === 'generate') {
Expand Down
Loading