Commit 2a1660a

✨ support claude 3 params and stream handling

1 parent 50461af · commit 2a1660a

5 files changed: +376 -71 lines changed


src/libs/agent-runtime/anthropic/index.ts

Lines changed: 12 additions & 55 deletions
@@ -9,13 +9,11 @@ import { AgentRuntimeErrorType } from '../error';
 import {
   ChatCompetitionOptions,
   ChatStreamPayload,
-  ModelProvider,
-  OpenAIChatMessage,
-  UserMessageContentPart,
+  ModelProvider
 } from '../types';
 import { AgentRuntimeError } from '../utils/createError';
 import { debugStream } from '../utils/debugStream';
-import { parseDataUri } from '../utils/uriParser';
+import { buildAnthropicMessages } from '../utils/anthropicHelpers';
 
 export class LobeAnthropicAI implements LobeRuntimeAI {
   private client: Anthropic;
@@ -26,40 +24,22 @@ export class LobeAnthropicAI implements LobeRuntimeAI {
     this.client = new Anthropic({ apiKey });
   }
 
-  private buildAnthropicMessages = (
-    messages: OpenAIChatMessage[],
-  ): Anthropic.Messages.MessageParam[] =>
-    messages.map((message) => this.convertToAnthropicMessage(message));
-
-  private convertToAnthropicMessage = (
-    message: OpenAIChatMessage,
-  ): Anthropic.Messages.MessageParam => {
-    const content = message.content as string | UserMessageContentPart[];
-
-    return {
-      content:
-        typeof content === 'string' ? content : content.map((c) => this.convertToAnthropicBlock(c)),
-      role: message.role === 'function' || message.role === 'system' ? 'assistant' : message.role,
-    };
-  };
-
   async chat(payload: ChatStreamPayload, options?: ChatCompetitionOptions) {
     const { messages, model, max_tokens, temperature, top_p } = payload;
     const system_message = messages.find((m) => m.role === 'system');
     const user_messages = messages.filter((m) => m.role !== 'system');
 
-    const requestParams: Anthropic.MessageCreateParams = {
-      max_tokens: max_tokens || 4096,
-      messages: this.buildAnthropicMessages(user_messages),
-      model: model,
-      stream: true,
-      system: system_message?.content as string,
-      temperature: temperature,
-      top_p: top_p,
-    };
-
     try {
-      const response = await this.client.messages.create(requestParams);
+      const response = await this.client.messages.create({
+        max_tokens: max_tokens || 4096,
+        messages: buildAnthropicMessages(user_messages),
+        model: model,
+        stream: true,
+        system: system_message?.content as string,
+        temperature: temperature,
+        top_p: top_p,
+      });
+
       const [prod, debug] = response.tee();
 
       if (process.env.DEBUG_ANTHROPIC_CHAT_COMPLETION === '1') {
@@ -91,29 +71,6 @@ export class LobeAnthropicAI implements LobeRuntimeAI {
       });
     }
   }
-
-  private convertToAnthropicBlock(
-    content: UserMessageContentPart,
-  ): Anthropic.ContentBlock | Anthropic.ImageBlockParam {
-    switch (content.type) {
-      case 'text': {
-        return content;
-      }
-
-      case 'image_url': {
-        const { mimeType, base64 } = parseDataUri(content.image_url.url);
-
-        return {
-          source: {
-            data: base64 as string,
-            media_type: mimeType as Anthropic.ImageBlockParam.Source['media_type'],
-            type: 'base64',
-          },
-          type: 'image',
-        };
-      }
-    }
-  }
 }
 
 export default LobeAnthropicAI;
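
This diff moves message conversion out of the class: buildAnthropicMessages is now imported from ../utils/anthropicHelpers, and the three private methods are deleted. The helper file itself is not part of this view; below is a minimal sketch of what it plausibly contains, assuming the removed private methods were lifted out as standalone functions (only the buildAnthropicMessages name is confirmed by the diff; everything else mirrors the deleted code).

// Hypothetical sketch of src/libs/agent-runtime/utils/anthropicHelpers.ts.
// Assumption: the deleted private methods were extracted here unchanged.
import Anthropic from '@anthropic-ai/sdk';

import { OpenAIChatMessage, UserMessageContentPart } from '../types';
import { parseDataUri } from './uriParser';

// Convert a single OpenAI-style content part into an Anthropic content block.
const convertToAnthropicBlock = (
  content: UserMessageContentPart,
): Anthropic.ContentBlock | Anthropic.ImageBlockParam => {
  switch (content.type) {
    case 'text': {
      return content;
    }
    case 'image_url': {
      // Anthropic takes inline base64 image data rather than a URL.
      const { mimeType, base64 } = parseDataUri(content.image_url.url);
      return {
        source: {
          data: base64 as string,
          media_type: mimeType as Anthropic.ImageBlockParam.Source['media_type'],
          type: 'base64',
        },
        type: 'image',
      };
    }
  }
};

// Map OpenAI-style chat messages onto Anthropic's MessageParam shape,
// folding function/system roles into 'assistant' as the removed code did.
export const buildAnthropicMessages = (
  messages: OpenAIChatMessage[],
): Anthropic.Messages.MessageParam[] =>
  messages.map((message) => ({
    content:
      typeof message.content === 'string'
        ? message.content
        : (message.content as UserMessageContentPart[]).map(convertToAnthropicBlock),
    role: message.role === 'function' || message.role === 'system' ? 'assistant' : message.role,
  }));

The chat flow itself is unchanged apart from inlining the request params: the streamed response is still split with response.tee() so one branch can feed debugStream when DEBUG_ANTHROPIC_CHAT_COMPLETION=1 while the other serves the caller.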
src/libs/agent-runtime/bedrock/index.test.ts

Lines changed: 217 additions & 0 deletions

@@ -0,0 +1,217 @@
+// @vitest-environment node
+import { Mock, afterEach, beforeEach, describe, expect, it, vi } from 'vitest';
+
+import {
+  InvokeModelWithResponseStreamCommand,
+} from '@aws-sdk/client-bedrock-runtime';
+import * as debugStreamModule from '../utils/debugStream';
+import { LobeBedrockAI } from './index';
+
+const provider = 'bedrock';
+
+// Mock the console.error to avoid polluting test output
+vi.spyOn(console, 'error').mockImplementation(() => {});
+
+vi.mock("@aws-sdk/client-bedrock-runtime", async (importOriginal) => {
+  const module = await importOriginal();
+  return {
+    ...(module as any),
+    InvokeModelWithResponseStreamCommand: vi.fn()
+  }
+})
+
+let instance: LobeBedrockAI;
+
+beforeEach(() => {
+  instance = new LobeBedrockAI({
+    region: 'us-west-2',
+    accessKeyId: 'test-access-key-id',
+    accessKeySecret: 'test-access-key-secret',
+  });
+
+  vi.spyOn(instance['client'], 'send').mockReturnValue(new ReadableStream() as any);
+});
+
+afterEach(() => {
+  vi.clearAllMocks();
+});
+
+describe('LobeBedrockAI', () => {
+  describe('init', () => {
+    it('should correctly initialize with AWS credentials', async () => {
+      const instance = new LobeBedrockAI({
+        region: 'us-west-2',
+        accessKeyId: 'test-access-key-id',
+        accessKeySecret: 'test-access-key-secret',
+      });
+      expect(instance).toBeInstanceOf(LobeBedrockAI);
+    });
+  });
+
+  describe('chat', () => {
+
+    describe('Claude model', () => {
+
+      it('should return a Response on successful API call', async () => {
+        const result = await instance.chat({
+          messages: [{ content: 'Hello', role: 'user' }],
+          model: 'anthropic.claude-v2:1',
+          temperature: 0,
+        });
+
+        // Assert
+        expect(result).toBeInstanceOf(Response);
+      });
+
+      it('should handle text messages correctly', async () => {
+        // Arrange
+        const mockStream = new ReadableStream({
+          start(controller) {
+            controller.enqueue('Hello, world!');
+            controller.close();
+          },
+        });
+        const mockResponse = Promise.resolve(mockStream);
+        (instance['client'].send as Mock).mockResolvedValue(mockResponse);
+
+        // Act
+        const result = await instance.chat({
+          messages: [{ content: 'Hello', role: 'user' }],
+          model: 'anthropic.claude-v2:1',
+          temperature: 0,
+          top_p: 1,
+        });
+
+        // Assert
+        expect(InvokeModelWithResponseStreamCommand).toHaveBeenCalledWith({
+          accept: 'application/json',
+          body: JSON.stringify({
+            anthropic_version: "bedrock-2023-05-31",
+            max_tokens: 4096,
+            messages: [{ content: 'Hello', role: 'user' }],
+            temperature: 0,
+            top_p: 1,
+          }),
+          contentType: 'application/json',
+          modelId: 'anthropic.claude-v2:1',
+        });
+        expect(result).toBeInstanceOf(Response);
+      });
+
+      it('should handle system prompt correctly', async () => {
+        // Arrange
+        const mockStream = new ReadableStream({
+          start(controller) {
+            controller.enqueue('Hello, world!');
+            controller.close();
+          },
+        });
+        const mockResponse = Promise.resolve(mockStream);
+        (instance['client'].send as Mock).mockResolvedValue(mockResponse);
+
+        // Act
+        const result = await instance.chat({
+          messages: [
+            { content: 'You are an awesome greeter', role: 'system' },
+            { content: 'Hello', role: 'user' },
+          ],
+          model: 'anthropic.claude-v2:1',
+          temperature: 0,
+          top_p: 1,
+        });
+
+        // Assert
+        expect(InvokeModelWithResponseStreamCommand).toHaveBeenCalledWith({
+          accept: 'application/json',
+          body: JSON.stringify({
+            anthropic_version: "bedrock-2023-05-31",
+            max_tokens: 4096,
+            messages: [{ content: 'Hello', role: 'user' }],
+            system: 'You are an awesome greeter',
+            temperature: 0,
+            top_p: 1,
+          }),
+          contentType: 'application/json',
+          modelId: 'anthropic.claude-v2:1',
+        });
+        expect(result).toBeInstanceOf(Response);
+      });
+
+      it('should call Anthropic model with supported options', async () => {
+        // Arrange
+        const mockStream = new ReadableStream({
+          start(controller) {
+            controller.enqueue('Hello, world!');
+            controller.close();
+          },
+        });
+        const mockResponse = Promise.resolve(mockStream);
+        (instance['client'].send as Mock).mockResolvedValue(mockResponse);
+
+        // Act
+        const result = await instance.chat({
+          max_tokens: 2048,
+          messages: [{ content: 'Hello', role: 'user' }],
+          model: 'anthropic.claude-v2:1',
+          temperature: 0.5,
+          top_p: 1,
+        });
+
+        // Assert
+        expect(InvokeModelWithResponseStreamCommand).toHaveBeenCalledWith({
+          accept: 'application/json',
+          body: JSON.stringify({
+            anthropic_version: "bedrock-2023-05-31",
+            max_tokens: 2048,
+            messages: [{ content: 'Hello', role: 'user' }],
+            temperature: 0.5,
+            top_p: 1,
+          }),
+          contentType: 'application/json',
+          modelId: 'anthropic.claude-v2:1',
+        });
+        expect(result).toBeInstanceOf(Response);
+      });
+
+      it('should call Anthropic model without unsupported options', async () => {
+        // Arrange
+        const mockStream = new ReadableStream({
+          start(controller) {
+            controller.enqueue('Hello, world!');
+            controller.close();
+          },
+        });
+        const mockResponse = Promise.resolve(mockStream);
+        (instance['client'].send as Mock).mockResolvedValue(mockResponse);
+
+        // Act
+        const result = await instance.chat({
+          frequency_penalty: 0.5, // Unsupported option
+          max_tokens: 2048,
+          messages: [{ content: 'Hello', role: 'user' }],
+          model: 'anthropic.claude-v2:1',
+          presence_penalty: 0.5,
+          temperature: 0.5,
+          top_p: 1,
+        });
+
+        // Assert
+        expect(InvokeModelWithResponseStreamCommand).toHaveBeenCalledWith({
+          accept: 'application/json',
+          body: JSON.stringify({
+            anthropic_version: "bedrock-2023-05-31",
+            max_tokens: 2048,
+            messages: [{ content: 'Hello', role: 'user' }],
+            temperature: 0.5,
+            top_p: 1,
+          }),
+          contentType: 'application/json',
+          modelId: 'anthropic.claude-v2:1',
+        });
+        expect(result).toBeInstanceOf(Response);
+      });
+
+    });
+
+  });
+});
