
Commit a6f4791

🐛 fix: implement the chat
1 parent 6e30e85 commit a6f4791

File tree

package.json
src/app/api/chat/agentRuntime.ts
src/config/llm.ts
src/config/modelProviders/huggingface.ts
src/libs/agent-runtime/AgentRuntime.ts
src/libs/agent-runtime/huggingface/index.ts
src/libs/agent-runtime/utils/streams/huggingface.ts
src/libs/agent-runtime/utils/streams/protocol.ts

8 files changed: +91 -15 lines changed

package.json

Lines changed: 1 addition & 0 deletions
@@ -236,6 +236,7 @@
   "devDependencies": {
     "@commitlint/cli": "^19.4.0",
     "@edge-runtime/vm": "^4.0.2",
+    "@huggingface/tasks": "^0.12.12",
     "@lobehub/i18n-cli": "^1.19.1",
     "@lobehub/lint": "^1.24.4",
     "@lobehub/seo-cli": "^1.4.2",

src/app/api/chat/agentRuntime.ts

Lines changed: 10 additions & 0 deletions
@@ -225,6 +225,16 @@ const getLlmOptionsFromPayload = (provider: string, payload: JWTPayload) => {

       return { apiKey, baseURL };
     }
+
+    case ModelProvider.HuggingFace: {
+      const { HUGGINGFACE_PROXY_URL, HUGGINGFACE_API_KEY } = getLLMConfig();
+
+      const apiKey = apiKeyManager.pick(payload?.apiKey || HUGGINGFACE_API_KEY);
+      const baseURL = payload?.endpoint || HUGGINGFACE_PROXY_URL;
+
+      return { apiKey, baseURL };
+    }
+
     case ModelProvider.Upstage: {
       const { UPSTAGE_API_KEY } = getLLMConfig();
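The new case mirrors the precedence used by the neighboring providers: an apiKey carried in the request payload wins over the server-side HUGGINGFACE_API_KEY, and a user-supplied endpoint wins over HUGGINGFACE_PROXY_URL. A stripped-down sketch of that resolution (resolveHuggingFaceOptions is a hypothetical name, not part of the commit):

const resolveHuggingFaceOptions = (
  payload: { apiKey?: string; endpoint?: string },
  env: { HUGGINGFACE_API_KEY?: string; HUGGINGFACE_PROXY_URL?: string },
) => ({
  // In the real code apiKeyManager.pick can additionally rotate through a
  // comma-separated key list; a plain fallback is shown here for brevity.
  apiKey: payload.apiKey || env.HUGGINGFACE_API_KEY,
  baseURL: payload.endpoint || env.HUGGINGFACE_PROXY_URL,
});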

src/config/llm.ts

Lines changed: 8 additions & 0 deletions
@@ -125,6 +125,10 @@ export const getLLMConfig = () => {
       ENABLED_HUNYUAN: z.boolean(),
       HUNYUAN_API_KEY: z.string().optional(),
       HUNYUAN_MODEL_LIST: z.string().optional(),
+
+      ENABLED_HUGGINGFACE: z.boolean(),
+      HUGGINGFACE_API_KEY: z.string().optional(),
+      HUGGINGFACE_PROXY_URL: z.string().optional(),
     },
     runtimeEnv: {
       API_KEY_SELECT_MODE: process.env.API_KEY_SELECT_MODE,
@@ -247,6 +251,10 @@ export const getLLMConfig = () => {
       ENABLED_HUNYUAN: !!process.env.HUNYUAN_API_KEY,
       HUNYUAN_API_KEY: process.env.HUNYUAN_API_KEY,
       HUNYUAN_MODEL_LIST: process.env.HUNYUAN_MODEL_LIST,
+
+      ENABLED_HUGGINGFACE: !!process.env.HUGGINGFACE_API_KEY,
+      HUGGINGFACE_API_KEY: process.env.HUGGINGFACE_API_KEY,
+      HUGGINGFACE_PROXY_URL: process.env.HUGGINGFACE_PROXY_URL,
     },
   });
};
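Because ENABLED_HUGGINGFACE is derived from the presence of HUGGINGFACE_API_KEY rather than set independently, a deployment turns the provider on simply by exporting that variable. A minimal sketch of reading the new values (the names come straight from the schema above):

const { ENABLED_HUGGINGFACE, HUGGINGFACE_PROXY_URL } = getLLMConfig();

if (ENABLED_HUGGINGFACE) {
  // True exactly when process.env.HUGGINGFACE_API_KEY is a non-empty string.
  console.log('HuggingFace provider active', { proxy: HUGGINGFACE_PROXY_URL });
}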

src/config/modelProviders/huggingface.ts

Lines changed: 4 additions & 1 deletion
@@ -2,7 +2,10 @@ import { ModelProviderCard } from '@/types/llm';

 // ref https://cloud.tencent.com/document/product/1729/104753
 const HuggingFace: ModelProviderCard = {
-  chatModels: [],
+  chatModels: [
+    { id: 'google/gemma-2-2b-it' },
+    { enabled: true, id: 'mistralai/Mistral-7B-Instruct-v0.2' },
+  ],
   checkModel: 'meta-llama/Llama-3.1-8B-Instruct',
   description:
     'The HuggingFace Inference API offers a fast, free way to explore thousands of models for a variety of tasks. Whether you are prototyping a new application or experimenting with the capabilities of machine learning, the API gives you instant access to high-performing models across many domains.',
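Of the two seeded models, only the Mistral entry carries enabled: true, so it is the one surfaced by default; the gemma model stays listed but off until a user switches it on. A sketch of that distinction, assuming the card is default-exported like the sibling provider files:

import HuggingFace from '@/config/modelProviders/huggingface';

// Models shown without any user action are the ones flagged `enabled`.
const defaultModels = HuggingFace.chatModels.filter((model) => model.enabled);
// -> [{ enabled: true, id: 'mistralai/Mistral-7B-Instruct-v0.2' }]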

src/libs/agent-runtime/AgentRuntime.ts

Lines changed: 1 addition & 1 deletion
@@ -135,7 +135,7 @@ class AgentRuntime {
     github: Partial<ClientOptions>;
     google: { apiKey?: string; baseURL?: string };
     groq: Partial<ClientOptions>;
-    huggingface: { apiKey?: string };
+    huggingface: { apiKey?: string; baseURL?: string };
     hunyuan: Partial<ClientOptions>;
     minimax: Partial<ClientOptions>;
     mistral: Partial<ClientOptions>;
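With baseURL accepted here, a HuggingFace runtime can be pointed at a dedicated Inference Endpoint instead of the shared serverless API. A hypothetical initialization (the helper name and URL are assumptions for illustration, not taken from this diff):

const runtime = await AgentRuntime.initializeWithProviderOptions('huggingface', {
  huggingface: {
    apiKey: 'hf_xxx', // placeholder
    baseURL: 'https://my-model.endpoints.huggingface.cloud', // placeholder
  },
});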

src/libs/agent-runtime/huggingface/index.ts

Lines changed: 18 additions & 13 deletions
@@ -1,5 +1,4 @@
 import { HfInference } from '@huggingface/inference';
-import { HuggingFaceStream, StreamingTextResponse } from 'ai';

 import {
   AgentRuntimeError,
@@ -8,43 +7,49 @@ import {
   ChatStreamPayload,
   LobeRuntimeAI,
 } from '@/libs/agent-runtime';
+import { OpenAIStream } from '@/libs/agent-runtime/utils/streams';

 import { debugStream } from '../utils/debugStream';
+import { StreamingResponse } from '../utils/response';
+import { HuggingfaceResultToStream } from '../utils/streams/huggingface';

 export class LobeHuggingFaceAI implements LobeRuntimeAI {
   private client: HfInference;
   baseURL?: string;

-  constructor({ apiKey }: { apiKey?: string } = {}) {
+  constructor({ apiKey, baseURL }: { apiKey?: string; baseURL?: string } = {}) {
     if (!apiKey) throw AgentRuntimeError.createError(AgentRuntimeErrorType.InvalidProviderAPIKey);

     this.client = new HfInference(apiKey);
+
+    if (baseURL) {
+      this.client.endpoint(baseURL);
+    }
   }

   async chat(payload: ChatStreamPayload, options?: ChatCompetitionOptions) {
     try {
-      const hfStream = this.client.textGenerationStream({
-        inputs: payload.messages,
+      const hfRes = this.client.chatCompletionStream({
+        messages: payload.messages,
         model: payload.model,
-        parameters: {
-          temperature: payload.temperature,
-          top_p: payload.top_p,
-        },
+
         stream: true,
+        temperature: payload.temperature,
+        top_p: payload.top_p,
       });

+      const rawStream = HuggingfaceResultToStream(hfRes);
       // Convert the response into a friendly text-stream
-
-      const stream = HuggingFaceStream(hfStream, options?.callback);
-
-      const [debug, output] = stream.tee();
+      const [debug, output] = rawStream.tee();

       if (process.env.DEBUG_HUGGINGFACE_CHAT_COMPLETION === '1') {
         debugStream(debug).catch(console.error);
       }

+      const stream = OpenAIStream(output, options?.callback);
+
       // Respond with the stream
-      return new StreamingTextResponse(output, { headers: options?.headers });
+      return StreamingResponse(stream, { headers: options?.headers });
     } catch (e) {
       const err = e as Error;
src/libs/agent-runtime/utils/streams/huggingface.ts

Lines changed: 48 additions & 0 deletions

@@ -0,0 +1,48 @@
+import { ChatCompletionStreamOutput } from '@huggingface/tasks';
+import { readableFromAsyncIterable } from 'ai';
+
+import { ChatStreamCallbacks } from '@/libs/agent-runtime';
+import { nanoid } from '@/utils/uuid';
+
+import { ChatResp } from '../../wenxin/type';
+import {
+  StreamProtocolChunk,
+  StreamStack,
+  chatStreamable,
+  createCallbacksTransformer,
+  createSSEProtocolTransformer,
+} from './protocol';
+
+const transformHuggingfaceStream = (chunk: ChatResp): StreamProtocolChunk => {
+  console.log(chunk);
+  const finished = chunk.is_end;
+  if (finished) {
+    return { data: chunk.finish_reason || 'stop', id: chunk.id, type: 'stop' };
+  }
+
+  if (chunk.result) {
+    return { data: chunk.result, id: chunk.id, type: 'text' };
+  }
+
+  return {
+    data: chunk,
+    id: chunk.id,
+    type: 'data',
+  };
+};
+
+export const HuggingfaceResultToStream = (stream: AsyncIterable<ChatCompletionStreamOutput>) => {
+  // convert the response into the streamable format
+  return readableFromAsyncIterable(chatStreamable(stream));
+};
+
+export const HuggingFaceStream = (
+  rawStream: ReadableStream<ChatResp>,
+  callbacks?: ChatStreamCallbacks,
+) => {
+  const streamStack: StreamStack = { id: 'chat_' + nanoid() };
+
+  return rawStream
+    .pipeThrough(createSSEProtocolTransformer(transformHuggingfaceStream, streamStack))
+    .pipeThrough(createCallbacksTransformer(callbacks));
+};
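Only HuggingfaceResultToStream is actually consumed by the runtime above: chatCompletionStream already yields OpenAI-shaped chunks, so wrapping the async iterable in a ReadableStream is all that is needed before handing off to OpenAIStream. The HuggingFaceStream export, built on the Wenxin ChatResp shape, is not referenced anywhere in this commit. A toy demonstration of the wrapper (the chunk is abbreviated and cast for brevity):

const chunks = (async function* () {
  yield { choices: [{ delta: { content: 'Hi' }, index: 0 }], id: 'x' } as any;
})();

// Replays the generator as a web ReadableStream, ready for OpenAIStream.
const readable = HuggingfaceResultToStream(chunks);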

src/libs/agent-runtime/utils/streams/protocol.ts

Lines changed: 1 addition & 0 deletions
@@ -38,6 +38,7 @@ export const generateToolCallId = (index: number, functionName?: string) =>

 export const chatStreamable = async function* <T>(stream: AsyncIterable<T>) {
   for await (const response of stream) {
+    console.log(response);
     yield response;
   }
 };
