Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion src/providers/workers-ai/api.ts
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,12 @@ const WorkersAiAPIConfig: ProviderAPIConfig = {
case 'chatComplete': {
return `/${model}`;
}
case 'embed':
case 'embed': {
return `/${model}`;
}
case 'imageGenerate': {
return `/${model}`;
}
default:
return '';
}
Expand Down
37 changes: 4 additions & 33 deletions src/providers/workers-ai/chatComplete.ts
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,11 @@ import {
ErrorResponse,
ProviderConfig,
} from '../types';
import { generateInvalidProviderResponseError } from '../utils';
import {
generateErrorResponse,
generateInvalidProviderResponseError,
} from '../utils';
WorkersAiErrorResponse,
WorkersAiErrorResponseTransform,
} from './utils';

export const WorkersAiChatCompleteConfig: ProviderConfig = {
messages: {
Expand Down Expand Up @@ -36,16 +37,6 @@ export const WorkersAiChatCompleteConfig: ProviderConfig = {
},
};

export interface WorkersAiErrorObject {
code: string;
message: string;
}

interface WorkersAiErrorResponse {
success: boolean;
errors: WorkersAiErrorObject[];
}

interface WorkersAiChatCompleteResponse {
result: {
response: string;
Expand All @@ -60,26 +51,6 @@ interface WorkersAiChatCompleteStreamResponse {
p?: string;
}

export const WorkersAiErrorResponseTransform: (
response: WorkersAiErrorResponse
) => ErrorResponse | undefined = (response) => {
if ('errors' in response) {
return generateErrorResponse(
{
message: response.errors
?.map((error) => `Error ${error.code}:${error.message}`)
.join(', '),
type: null,
param: null,
code: null,
},
WORKERS_AI
);
}

return undefined;
};

// TODO: cloudflare do not return the usage
export const WorkersAiChatCompleteResponseTransform: (
response: WorkersAiChatCompleteResponse | WorkersAiErrorResponse,
Expand Down
37 changes: 4 additions & 33 deletions src/providers/workers-ai/complete.ts
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
import { Params } from '../../types/requestBody';
import { CompletionResponse, ErrorResponse, ProviderConfig } from '../types';
import { WORKERS_AI } from '../../globals';
import { generateInvalidProviderResponseError } from '../utils';
import {
generateErrorResponse,
generateInvalidProviderResponseError,
} from '../utils';
WorkersAiErrorResponse,
WorkersAiErrorResponseTransform,
} from './utils';

export const WorkersAiCompleteConfig: ProviderConfig = {
prompt: {
Expand All @@ -24,16 +25,6 @@ export const WorkersAiCompleteConfig: ProviderConfig = {
},
};

export interface WorkersAiErrorObject {
code: string;
message: string;
}

interface WorkersAiErrorResponse {
success: boolean;
errors: WorkersAiErrorObject[];
}

interface WorkersAiCompleteResponse {
result: {
response: string;
Expand All @@ -48,26 +39,6 @@ interface WorkersAiCompleteStreamResponse {
p?: string;
}

export const WorkersAiErrorResponseTransform: (
response: WorkersAiErrorResponse
) => ErrorResponse | undefined = (response) => {
if ('errors' in response) {
return generateErrorResponse(
{
message: response.errors
?.map((error) => `Error ${error.code}:${error.message}`)
.join(', '),
type: null,
param: null,
code: null,
},
WORKERS_AI
);
}

return undefined;
};

export const WorkersAiCompleteResponseTransform: (
response: WorkersAiCompleteResponse | WorkersAiErrorResponse,
responseStatus: number,
Expand Down
37 changes: 4 additions & 33 deletions src/providers/workers-ai/embed.ts
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
import { WORKERS_AI } from '../../globals';
import { EmbedParams, EmbedResponse } from '../../types/embedRequestBody';
import { ErrorResponse, ProviderConfig } from '../types';
import { generateInvalidProviderResponseError } from '../utils';
import {
generateErrorResponse,
generateInvalidProviderResponseError,
} from '../utils';
WorkersAiErrorResponse,
WorkersAiErrorResponseTransform,
} from './utils';

export const WorkersAiEmbedConfig: ProviderConfig = {
input: {
Expand All @@ -20,36 +21,6 @@ export const WorkersAiEmbedConfig: ProviderConfig = {
},
};

export interface WorkersAiErrorObject {
code: string;
message: string;
}

interface WorkersAiErrorResponse {
success: boolean;
errors: WorkersAiErrorObject[];
}

export const WorkersAiErrorResponseTransform: (
response: WorkersAiErrorResponse
) => ErrorResponse | undefined = (response) => {
if ('errors' in response) {
return generateErrorResponse(
{
message: response.errors
?.map((error) => `Error ${error.code}:${error.message}`)
.join(', '),
type: null,
param: null,
code: null,
},
WORKERS_AI
);
}

return undefined;
};

/**
* The structure of the CohereEmbedResponse.
* @interface
Expand Down
83 changes: 83 additions & 0 deletions src/providers/workers-ai/imageGenerate.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
import { WORKERS_AI } from '../../globals';
import { ErrorResponse, ImageGenerateResponse, ProviderConfig } from '../types';
import { generateInvalidProviderResponseError } from '../utils';
import {
WorkersAiErrorResponse,
WorkersAiErrorResponseTransform,
} from './utils';

interface WorkersAiImageGenerateResponse extends ImageGenerateResponse {
result: {
image?: string;
};
success: boolean;
errors: string[];
messages: string[];
}

export const WorkersAiImageGenerateConfig: ProviderConfig = {
prompt: {
param: 'prompt',
required: true,
},
negative_prompt: {
param: 'negative_prompt',
},
base64: {
param: 'base64',
transform: (params: any) => true, // Always true to handle uniform responses
default: true,
required: true,
},
steps: [
{
param: 'num_steps',
},
{
param: 'steps',
},
],
size: [
{
param: 'height',
transform: (params: any) =>
parseInt(params.size.toLowerCase().split('x')[1]),
},
{
param: 'width',
transform: (params: any) =>
parseInt(params.size.toLowerCase().split('x')[0]),
},
],
seed: {
param: 'seed',
},
};

export const WorkersAiImageGenerateResponseTransform: (
response: WorkersAiImageGenerateResponse | WorkersAiErrorResponse,
responseStatus: number
) => ImageGenerateResponse | ErrorResponse = (response, responseStatus) => {
if (responseStatus !== 200) {
const errorResponse = WorkersAiErrorResponseTransform(
response as WorkersAiErrorResponse
);
if (errorResponse) return errorResponse;
}

// the imageGenerate response is not always the same for all Cloudflare image models
// we currently only support the image model that returns a base64 image
if ('result' in response && 'image' in response.result) {
return {
created: Math.floor(Date.now() / 1000),
data: [
{
b64_json: response.result.image,
},
],
provider: WORKERS_AI,
};
}

return generateInvalidProviderResponseError(response, WORKERS_AI);
};
6 changes: 6 additions & 0 deletions src/providers/workers-ai/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -11,18 +11,24 @@ import {
WorkersAiCompleteStreamChunkTransform,
} from './complete';
import { WorkersAiEmbedConfig, WorkersAiEmbedResponseTransform } from './embed';
import {
WorkersAiImageGenerateConfig,
WorkersAiImageGenerateResponseTransform,
} from './imageGenerate';

const WorkersAiConfig: ProviderConfigs = {
complete: WorkersAiCompleteConfig,
chatComplete: WorkersAiChatCompleteConfig,
api: WorkersAiAPIConfig,
embed: WorkersAiEmbedConfig,
imageGenerate: WorkersAiImageGenerateConfig,
responseTransforms: {
'stream-complete': WorkersAiCompleteStreamChunkTransform,
complete: WorkersAiCompleteResponseTransform,
chatComplete: WorkersAiChatCompleteResponseTransform,
'stream-chatComplete': WorkersAiChatCompleteStreamChunkTransform,
embed: WorkersAiEmbedResponseTransform,
imageGenerate: WorkersAiImageGenerateResponseTransform,
},
};

Expand Down
33 changes: 33 additions & 0 deletions src/providers/workers-ai/utils.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
import { ErrorResponse } from '../types';
import { generateErrorResponse } from '../utils';
import { WORKERS_AI } from '../../globals';

export interface WorkersAiErrorResponse {
success: boolean;
errors: WorkersAiErrorObject[];
}

export interface WorkersAiErrorObject {
code: string;
message: string;
}

export const WorkersAiErrorResponseTransform: (
response: WorkersAiErrorResponse
) => ErrorResponse | undefined = (response) => {
if ('errors' in response) {
return generateErrorResponse(
{
message: response.errors
?.map((error) => `Error ${error.code}:${error.message}`)
.join(', '),
type: null,
param: null,
code: null,
},
WORKERS_AI
);
}

return undefined;
};