Skip to content

Commit 8d1dd83

Browse files
committed
sdk release
2 parents db4bc9e + 138c013 commit 8d1dd83

File tree

7 files changed

+113
-10
lines changed

7 files changed

+113
-10
lines changed

src/examples/cohere_example/rerank.py

Lines changed: 17 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import cohere
22
from dotenv import find_dotenv, load_dotenv
3+
from datetime import datetime
34

45
from langtrace_python_sdk import langtrace
56

@@ -16,10 +17,22 @@
1617
# @with_langtrace_root_span("embed_create")
1718
def rerank():
1819
docs = [
19-
"Carson City is the capital city of the American state of Nevada.",
20-
"The Commonwealth of the Northern Mariana Islands is a group of islands in the Pacific Ocean. Its capital is Saipan.",
21-
"Washington, D.C. (also known as simply Washington or D.C., and officially as the District of Columbia) is the capital of the United States. It is a federal district.",
22-
"Capital punishment (the death penalty) has existed in the United States since beforethe United States was a country. As of 2017, capital punishment is legal in 30 of the 50 states.",
20+
{
21+
"text": "Carson City is the capital city of the American state of Nevada.",
22+
"date": datetime.now(),
23+
},
24+
{
25+
"text": "The Commonwealth of the Northern Mariana Islands is a group of islands in the Pacific Ocean. Its capital is Saipan.",
26+
"date": datetime(2020, 5, 17),
27+
},
28+
{
29+
"text": "Washington, D.C. (also known as simply Washington or D.C., and officially as the District of Columbia) is the capital of the United States. It is a federal district.",
30+
"date": datetime(1776, 7, 4),
31+
},
32+
{
33+
"text": "Capital punishment (the death penalty) has existed in the United States since before the United States was a country. As of 2017, capital punishment is legal in 30 of the 50 states.",
34+
"date": datetime(2023, 9, 14),
35+
},
2336
]
2437

2538
response = co.rerank(

src/examples/langchain_example/__init__.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
from .basic import basic_app, rag, load_and_split
33
from langtrace_python_sdk import with_langtrace_root_span
44

5-
from .groq_example import groq_basic, groq_streaming
5+
from .groq_example import groq_basic, groq_tool_choice, groq_streaming
66
from .langgraph_example_tools import basic_graph_tools
77

88

@@ -20,3 +20,5 @@ class GroqRunner:
2020
@with_langtrace_root_span("Groq")
2121
def run(self):
2222
groq_streaming()
23+
groq_basic()
24+
groq_tool_choice()

src/examples/langchain_example/groq_example.py

Lines changed: 78 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
1+
import json
2+
13
from dotenv import find_dotenv, load_dotenv
2-
from langchain_core.prompts import ChatPromptTemplate
3-
from langchain_groq import ChatGroq
44
from groq import Groq
55

66
_ = load_dotenv(find_dotenv())
@@ -30,6 +30,82 @@ def groq_basic():
3030
return chat_completion
3131

3232

33+
def groq_tool_choice():
34+
35+
user_prompt = "What is 25 * 4 + 10?"
36+
MODEL = "llama3-groq-70b-8192-tool-use-preview"
37+
38+
def calculate(expression):
39+
"""Evaluate a mathematical expression"""
40+
try:
41+
result = eval(expression)
42+
return json.dumps({"result": result})
43+
except:
44+
return json.dumps({"error": "Invalid expression"})
45+
46+
messages = [
47+
{
48+
"role": "system",
49+
"content": "You are a calculator assistant. Use the calculate function to perform mathematical operations and provide the results.",
50+
},
51+
{
52+
"role": "user",
53+
"content": user_prompt,
54+
},
55+
]
56+
tools = [
57+
{
58+
"type": "function",
59+
"function": {
60+
"name": "calculate",
61+
"description": "Evaluate a mathematical expression",
62+
"parameters": {
63+
"type": "object",
64+
"properties": {
65+
"expression": {
66+
"type": "string",
67+
"description": "The mathematical expression to evaluate",
68+
}
69+
},
70+
"required": ["expression"],
71+
},
72+
},
73+
}
74+
]
75+
response = client.chat.completions.create(
76+
model=MODEL,
77+
messages=messages,
78+
tools=tools,
79+
tool_choice={"type": "function", "function": {"name": "calculate"}},
80+
max_tokens=4096,
81+
)
82+
83+
response_message = response.choices[0].message
84+
tool_calls = response_message.tool_calls
85+
if tool_calls:
86+
available_functions = {
87+
"calculate": calculate,
88+
}
89+
messages.append(response_message)
90+
for tool_call in tool_calls:
91+
function_name = tool_call.function.name
92+
function_to_call = available_functions[function_name]
93+
function_args = json.loads(tool_call.function.arguments)
94+
function_response = function_to_call(
95+
expression=function_args.get("expression")
96+
)
97+
messages.append(
98+
{
99+
"tool_call_id": tool_call.id,
100+
"role": "tool",
101+
"name": function_name,
102+
"content": function_response,
103+
}
104+
)
105+
second_response = client.chat.completions.create(model=MODEL, messages=messages)
106+
return second_response.choices[0].message.content
107+
108+
33109
def groq_streaming():
34110
chat_completion = client.chat.completions.create(
35111
messages=[

src/langtrace_python_sdk/instrumentation/cohere/patch.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
)
2828
from langtrace.trace_attributes import Event, LLMSpanAttributes
2929
from langtrace_python_sdk.utils import set_span_attribute
30+
from langtrace_python_sdk.utils.misc import datetime_encoder
3031
from opentelemetry.trace import SpanKind
3132
from opentelemetry.trace.status import Status, StatusCode
3233

@@ -50,7 +51,9 @@ def traced_method(wrapped, instance, args, kwargs):
5051
SpanAttributes.LLM_REQUEST_MODEL: kwargs.get("model") or "command-r-plus",
5152
SpanAttributes.LLM_URL: APIS["RERANK"]["URL"],
5253
SpanAttributes.LLM_PATH: APIS["RERANK"]["ENDPOINT"],
53-
SpanAttributes.LLM_REQUEST_DOCUMENTS: json.dumps(kwargs.get("documents")),
54+
SpanAttributes.LLM_REQUEST_DOCUMENTS: json.dumps(
55+
kwargs.get("documents"), cls=datetime_encoder
56+
),
5457
SpanAttributes.LLM_COHERE_RERANK_QUERY: kwargs.get("query"),
5558
**get_extra_attributes(),
5659
}

src/langtrace_python_sdk/utils/llm.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -124,6 +124,7 @@ def get_llm_request_attributes(kwargs, prompts=None, model=None, operation_name=
124124

125125
top_p = kwargs.get("p", None) or kwargs.get("top_p", None)
126126
tools = kwargs.get("tools", None)
127+
tool_choice = kwargs.get("tool_choice", None)
127128
return {
128129
SpanAttributes.LLM_OPERATION_NAME: operation_name,
129130
SpanAttributes.LLM_REQUEST_MODEL: model
@@ -141,7 +142,7 @@ def get_llm_request_attributes(kwargs, prompts=None, model=None, operation_name=
141142
SpanAttributes.LLM_FREQUENCY_PENALTY: kwargs.get("frequency_penalty"),
142143
SpanAttributes.LLM_REQUEST_SEED: kwargs.get("seed"),
143144
SpanAttributes.LLM_TOOLS: json.dumps(tools) if tools else None,
144-
SpanAttributes.LLM_TOOL_CHOICE: kwargs.get("tool_choice"),
145+
SpanAttributes.LLM_TOOL_CHOICE: json.dumps(tool_choice) if tool_choice else None,
145146
SpanAttributes.LLM_REQUEST_LOGPROPS: kwargs.get("logprobs"),
146147
SpanAttributes.LLM_REQUEST_LOGITBIAS: kwargs.get("logit_bias"),
147148
SpanAttributes.LLM_REQUEST_TOP_LOGPROPS: kwargs.get("top_logprobs"),

src/langtrace_python_sdk/utils/misc.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,3 +60,11 @@ def is_serializable(value):
6060

6161
# Convert to string representation
6262
return json.dumps(serializable_args)
63+
64+
65+
class datetime_encoder(json.JSONEncoder):
66+
def default(self, o):
67+
if isinstance(o, datetime):
68+
return o.isoformat()
69+
70+
return json.JSONEncoder.default(self, o)
Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
__version__ = "2.3.15"
1+
__version__ = "2.3.16"

0 commit comments

Comments
 (0)