feat(langchain): link langfuse traces to langchain executions (#928)
hassiebp authored Sep 17, 2024
1 parent 7159451 commit 583be36
Showing 2 changed files with 267 additions and 52 deletions.
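
In practice, the linking works as follows: the Langfuse prompt client is passed in a LangChain prompt template's metadata under the "langfuse_prompt" key, the callback handler registers it against the parent run, and the next LLM invocation under that run is attributed to that prompt. Below is a minimal usage sketch based on the new tests in this commit; the prompt name, model wrapper, and import paths are assumptions for illustration, not part of the diff:

from langchain_core.prompts import PromptTemplate  # import path assumed
from langchain_openai import OpenAI  # model wrapper assumed

from langfuse import Langfuse
from langfuse.callback import CallbackHandler

langfuse = Langfuse()

# Fetch a prompt managed in Langfuse ("joke_prompt" is a placeholder name
# with an {{animal}} variable, as in the tests below).
langfuse_prompt = langfuse.get_prompt("joke_prompt")

# Passing the prompt client under the "langfuse_prompt" metadata key is what
# the new handler code picks up and registers against the parent run.
langchain_prompt = PromptTemplate.from_template(
    langfuse_prompt.get_langchain_prompt(),
    metadata={"langfuse_prompt": langfuse_prompt},
)

chain = langchain_prompt | OpenAI()
chain.invoke({"animal": "dog"}, config={"callbacks": [CallbackHandler()]})
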
38 changes: 37 additions & 1 deletion langfuse/callback/langchain.py
@@ -108,6 +108,7 @@ def __init__(
)

self.runs = {}
self.prompt_to_parent_run_map = {}

if stateful_client and isinstance(stateful_client, StatefulSpanClient):
self.runs[stateful_client.id] = stateful_client
@@ -203,6 +204,7 @@ def on_chain_start(
version=self.version,
**kwargs,
)
self._register_langfuse_prompt(parent_run_id, metadata)

content = {
"id": self.next_span_id,
@@ -223,6 +225,21 @@
except Exception as e:
self.log.exception(e)

def _register_langfuse_prompt(
self, parent_run_id: Optional[UUID], metadata: Optional[Dict[str, Any]]
):
"""We need to register any passed Langfuse prompt to the parent_run_id so that we can link following generations with that prompt.
If parent_run_id is None, we are at the root of a trace and should not attempt to register the prompt, as there will be no LLM invocation following it.
Otherwise it would have been traced in with a parent run consisting of the prompt template formatting and the LLM invocation.
"""
if metadata and "langfuse_prompt" in metadata and parent_run_id:
self.prompt_to_parent_run_map[parent_run_id] = metadata["langfuse_prompt"]

def _deregister_langfuse_prompt(self, run_id: Optional[UUID]):
if run_id in self.prompt_to_parent_run_map:
del self.prompt_to_parent_run_map[run_id]

def __generate_trace_and_parent(
self,
serialized: Dict[str, Any],
@@ -384,6 +401,7 @@ def on_chain_end(
self._update_trace_and_remove_state(
run_id, parent_run_id, outputs, input=kwargs.get("inputs")
)
self._deregister_langfuse_prompt(run_id)
except Exception as e:
self.log.exception(e)

@@ -635,6 +653,10 @@ def __on_llm_action(
model_name = None

model_name = self._parse_model_and_log_errors(serialized, kwargs)
registered_prompt = self.prompt_to_parent_run_map.get(parent_run_id, None)

if registered_prompt:
self._deregister_langfuse_prompt(parent_run_id)

content = {
"name": self.get_langchain_run_name(serialized, **kwargs),
@@ -643,6 +665,7 @@
"model": model_name,
"model_parameters": self._parse_model_parameters(kwargs),
"version": self.version,
"prompt": registered_prompt,
}

if parent_run_id in self.runs:
@@ -794,7 +817,7 @@ def __join_tags_and_metadata(
final_dict.update(metadata)
if trace_metadata is not None:
final_dict.update(trace_metadata)
- return final_dict if final_dict != {} else None
return _strip_langfuse_keys_from_dict(final_dict) if final_dict != {} else None

def _report_error(self, error: dict):
event = SdkLogBody(log=error)
@@ -1008,3 +1031,16 @@ def _parse_model(response: LLMResult):
break

return llm_model


def _strip_langfuse_keys_from_dict(metadata: Optional[Dict[str, Any]]):
if metadata is None or not isinstance(metadata, dict):
return metadata

langfuse_keys = ["langfuse_prompt"]
metadata_copy = metadata.copy()

for key in langfuse_keys:
metadata_copy.pop(key, None)

return metadata_copy
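
Taken together, the additions to langfuse/callback/langchain.py form a small lifecycle: on_chain_start registers a prompt found under metadata["langfuse_prompt"] against the parent_run_id, __on_llm_action consumes and deregisters the entry so the resulting generation carries the prompt reference, and on_chain_end clears any leftover registration. _strip_langfuse_keys_from_dict then removes the internal "langfuse_prompt" key before __join_tags_and_metadata reports metadata on the trace, so the linking mechanism does not leak into the recorded metadata.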
281 changes: 230 additions & 51 deletions tests/test_langchain.py
@@ -31,6 +31,7 @@
from tests.utils import create_uuid, get_api
from langchain_core.callbacks.manager import CallbackManagerForLLMRun
from langchain_core.language_models.llms import LLM
from langchain_core.output_parsers import StrOutputParser


def test_callback_init():
@@ -1763,68 +1764,246 @@ def test_disabled_langfuse():
api.trace.get(trace_id)


- # # Enable this test when the ChatBedrock is available in CI
- # def test_chat_bedrock():
- #     handler = CallbackHandler(debug=True)
def test_link_langfuse_prompts_invoke():
langfuse = Langfuse()
trace_name = "test_link_langfuse_prompts_invoke"

- # llm = ChatBedrock(
- #     model_id="anthropic.claude-3-sonnet-20240229-v1:0",
- #     # model_id="amazon.titan-text-lite-v1",
- #     region_name="eu-central-1",
- #     callbacks=[handler],
- # )
# Create prompts
joke_prompt_name = "joke_prompt_" + create_uuid()[:8]
joke_prompt_string = "Tell me a joke involving the animal {{animal}}"

- # messages = [
- #     (
- #         "system",
- #         "You are a expert software engineer.",
- #     ),
- #     ("human", "Give me fizzbuzz algo in C++"),
- # ]
explain_prompt_name = "explain_prompt_" + create_uuid()[:8]
explain_prompt_string = "Explain the joke to me like I'm a 5 year old {{joke}}"

- # ai_msg = llm.stream("Give me fizzbuzz algo in C++")
langfuse.create_prompt(
name=joke_prompt_name,
prompt=joke_prompt_string,
labels=["production"],
)

- # for chunk in ai_msg:
- #     print(chunk)
langfuse.create_prompt(
name=explain_prompt_name,
prompt=explain_prompt_string,
labels=["production"],
)

# Get prompts
langfuse_joke_prompt = langfuse.get_prompt(joke_prompt_name)
langfuse_explain_prompt = langfuse.get_prompt(explain_prompt_name)

- # def test_langchain_anthropic_package():
- #     langfuse_handler = CallbackHandler(debug=False)
- #     from langchain_anthropic import ChatAnthropic
langchain_joke_prompt = PromptTemplate.from_template(
langfuse_joke_prompt.get_langchain_prompt(),
metadata={"langfuse_prompt": langfuse_joke_prompt},
)

- # chat = ChatAnthropic(
- #     model="claude-3-sonnet-20240229",
- #     temperature=0.1,
- # )
langchain_explain_prompt = PromptTemplate.from_template(
langfuse_explain_prompt.get_langchain_prompt(),
metadata={"langfuse_prompt": langfuse_explain_prompt},
)

- # system = "You are a helpful assistant that translates {input_language} to {output_language}."
- # human = "{text}"
- # prompt = ChatPromptTemplate.from_messages([("system", system), ("human", human)])
# Create chain
parser = StrOutputParser()
model = OpenAI()
chain = (
{"joke": langchain_joke_prompt | model | parser}
| langchain_explain_prompt
| model
| parser
)

- # chain = prompt | chat
- # chain.invoke(
- #     {
- #         "input_language": "English",
- #         "output_language": "Korean",
- #         "text": "I love Python",
- #     },
- #     config={"callbacks": [langfuse_handler]},
- # )
# Run chain
langfuse_handler = CallbackHandler(debug=True)

output = chain.invoke(
{"animal": "dog"},
config={
"callbacks": [langfuse_handler],
"run_name": trace_name,
},
)

langfuse_handler.flush()

observations = get_api().trace.get(langfuse_handler.get_trace_id()).observations

generations = sorted(
list(filter(lambda x: x.type == "GENERATION", observations)),
key=lambda x: x.start_time,
)

assert len(generations) == 2
assert generations[0].input == "Tell me a joke involving the animal dog"
assert "Explain the joke to me like I'm a 5 year old" in generations[1].input

- # langfuse_handler.flush()
assert generations[0].prompt_name == joke_prompt_name
assert generations[1].prompt_name == explain_prompt_name

assert generations[0].prompt_version == langfuse_joke_prompt.version
assert generations[1].prompt_version == langfuse_explain_prompt.version

assert generations[1].output == output.strip()


def test_link_langfuse_prompts_stream():
langfuse = Langfuse()
trace_name = "test_link_langfuse_prompts_stream"

# Create prompts
joke_prompt_name = "joke_prompt_" + create_uuid()[:8]
joke_prompt_string = "Tell me a joke involving the animal {{animal}}"

explain_prompt_name = "explain_prompt_" + create_uuid()[:8]
explain_prompt_string = "Explain the joke to me like I'm a 5 year old {{joke}}"

langfuse.create_prompt(
name=joke_prompt_name,
prompt=joke_prompt_string,
labels=["production"],
)

langfuse.create_prompt(
name=explain_prompt_name,
prompt=explain_prompt_string,
labels=["production"],
)

# Get prompts
langfuse_joke_prompt = langfuse.get_prompt(joke_prompt_name)
langfuse_explain_prompt = langfuse.get_prompt(explain_prompt_name)

langchain_joke_prompt = PromptTemplate.from_template(
langfuse_joke_prompt.get_langchain_prompt(),
metadata={"langfuse_prompt": langfuse_joke_prompt},
)

langchain_explain_prompt = PromptTemplate.from_template(
langfuse_explain_prompt.get_langchain_prompt(),
metadata={"langfuse_prompt": langfuse_explain_prompt},
)

# Create chain
parser = StrOutputParser()
model = OpenAI()
chain = (
{"joke": langchain_joke_prompt | model | parser}
| langchain_explain_prompt
| model
| parser
)

- # observations = get_api().trace.get(langfuse_handler.get_trace_id()).observations
# Run chain
langfuse_handler = CallbackHandler(debug=True)

stream = chain.stream(
{"animal": "dog"},
config={
"callbacks": [langfuse_handler],
"run_name": trace_name,
},
)

output = ""
for chunk in stream:
output += chunk

langfuse_handler.flush()

observations = get_api().trace.get(langfuse_handler.get_trace_id()).observations

generations = sorted(
list(filter(lambda x: x.type == "GENERATION", observations)),
key=lambda x: x.start_time,
)

assert len(generations) == 2
assert generations[0].input == "Tell me a joke involving the animal dog"
assert "Explain the joke to me like I'm a 5 year old" in generations[1].input

assert generations[0].prompt_name == joke_prompt_name
assert generations[1].prompt_name == explain_prompt_name

assert generations[0].prompt_version == langfuse_joke_prompt.version
assert generations[1].prompt_version == langfuse_explain_prompt.version

assert generations[1].output == output.strip()


def test_link_langfuse_prompts_batch():
langfuse = Langfuse()
trace_name = "test_link_langfuse_prompts_batch_" + create_uuid()[:8]

# Create prompts
joke_prompt_name = "joke_prompt_" + create_uuid()[:8]
joke_prompt_string = "Tell me a joke involving the animal {{animal}}"

explain_prompt_name = "explain_prompt_" + create_uuid()[:8]
explain_prompt_string = "Explain the joke to me like I'm a 5 year old {{joke}}"

langfuse.create_prompt(
name=joke_prompt_name,
prompt=joke_prompt_string,
labels=["production"],
)

langfuse.create_prompt(
name=explain_prompt_name,
prompt=explain_prompt_string,
labels=["production"],
)

# Get prompts
langfuse_joke_prompt = langfuse.get_prompt(joke_prompt_name)
langfuse_explain_prompt = langfuse.get_prompt(explain_prompt_name)

langchain_joke_prompt = PromptTemplate.from_template(
langfuse_joke_prompt.get_langchain_prompt(),
metadata={"langfuse_prompt": langfuse_joke_prompt},
)

langchain_explain_prompt = PromptTemplate.from_template(
langfuse_explain_prompt.get_langchain_prompt(),
metadata={"langfuse_prompt": langfuse_explain_prompt},
)

# Create chain
parser = StrOutputParser()
model = OpenAI()
chain = (
{"joke": langchain_joke_prompt | model | parser}
| langchain_explain_prompt
| model
| parser
)

# Run chain
langfuse_handler = CallbackHandler(debug=True)

chain.batch(
[{"animal": "dog"}, {"animal": "cat"}, {"animal": "elephant"}],
config={
"callbacks": [langfuse_handler],
"run_name": trace_name,
},
)

langfuse_handler.flush()

traces = get_api().trace.list(name=trace_name).data

assert len(traces) == 3

for trace in traces:
observations = get_api().trace.get(trace.id).observations

generations = sorted(
list(filter(lambda x: x.type == "GENERATION", observations)),
key=lambda x: x.start_time,
)

- # assert len(observations) == 3
assert len(generations) == 2

- # generation = list(filter(lambda x: x.type == "GENERATION", observations))[0]
assert generations[0].prompt_name == joke_prompt_name
assert generations[1].prompt_name == explain_prompt_name

- # assert generation.output is not None
- # assert generation.output != ""
- # assert generation.input is not None
- # assert generation.input != ""
- # assert generation.usage is not None
- # assert generation.usage.input is not None
- # assert generation.usage.output is not None
- # assert generation.usage.total is not None
- # assert generation.model == "claude-3-sonnet-20240229"
assert generations[0].prompt_version == langfuse_joke_prompt.version
assert generations[1].prompt_version == langfuse_explain_prompt.version
