diff --git a/langfuse/callback/langchain.py b/langfuse/callback/langchain.py
index 30466e46..812ea4c8 100644
--- a/langfuse/callback/langchain.py
+++ b/langfuse/callback/langchain.py
@@ -108,6 +108,7 @@ def __init__(
         )
 
         self.runs = {}
+        self.prompt_to_parent_run_map = {}
 
         if stateful_client and isinstance(stateful_client, StatefulSpanClient):
             self.runs[stateful_client.id] = stateful_client
@@ -203,6 +204,7 @@ def on_chain_start(
                 version=self.version,
                 **kwargs,
             )
+            self._register_langfuse_prompt(parent_run_id, metadata)
 
             content = {
                 "id": self.next_span_id,
@@ -223,6 +225,21 @@
         except Exception as e:
             self.log.exception(e)
 
+    def _register_langfuse_prompt(
+        self, parent_run_id: Optional[UUID], metadata: Optional[Dict[str, Any]]
+    ):
+        """Register any passed Langfuse prompt under the parent_run_id so that subsequent generations can be linked to that prompt.
+
+        If parent_run_id is None, we are at the root of a trace and should not attempt to register the prompt, as no LLM invocation will follow it.
+        Otherwise, the prompt template formatting and the LLM invocation would have been traced inside a shared parent run.
+        """
+        if metadata and "langfuse_prompt" in metadata and parent_run_id:
+            self.prompt_to_parent_run_map[parent_run_id] = metadata["langfuse_prompt"]
+
+    def _deregister_langfuse_prompt(self, run_id: Optional[UUID]):
+        if run_id in self.prompt_to_parent_run_map:
+            del self.prompt_to_parent_run_map[run_id]
+
     def __generate_trace_and_parent(
         self,
         serialized: Dict[str, Any],
@@ -384,6 +401,7 @@ def on_chain_end(
             self._update_trace_and_remove_state(
                 run_id, parent_run_id, outputs, input=kwargs.get("inputs")
             )
+            self._deregister_langfuse_prompt(run_id)
 
         except Exception as e:
             self.log.exception(e)
@@ -635,6 +653,10 @@ def __on_llm_action(
             model_name = None
 
             model_name = self._parse_model_and_log_errors(serialized, kwargs)
+            registered_prompt = self.prompt_to_parent_run_map.get(parent_run_id, None)
+
+            if registered_prompt:
+                self._deregister_langfuse_prompt(parent_run_id)
 
             content = {
                 "name": self.get_langchain_run_name(serialized, **kwargs),
@@ -643,6 +665,7 @@
                 "model": model_name,
                 "model_parameters": self._parse_model_parameters(kwargs),
                 "version": self.version,
+                "prompt": registered_prompt,
             }
 
             if parent_run_id in self.runs:
@@ -794,7 +817,7 @@ def __join_tags_and_metadata(
             final_dict.update(metadata)
         if trace_metadata is not None:
             final_dict.update(trace_metadata)
-        return final_dict if final_dict != {} else None
+        return _strip_langfuse_keys_from_dict(final_dict) if final_dict != {} else None
 
     def _report_error(self, error: dict):
         event = SdkLogBody(log=error)
@@ -1008,3 +1031,16 @@
             break
 
     return llm_model
+
+
+def _strip_langfuse_keys_from_dict(metadata: Optional[Dict[str, Any]]):
+    if metadata is None or not isinstance(metadata, dict):
+        return metadata
+
+    langfuse_keys = ["langfuse_prompt"]
+    metadata_copy = metadata.copy()
+
+    for key in langfuse_keys:
+        metadata_copy.pop(key, None)
+
+    return metadata_copy
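
Note: the linking above is opt-in from the caller's side. The Langfuse prompt object is attached to a LangChain prompt template under the reserved "langfuse_prompt" metadata key, which _register_langfuse_prompt picks up in on_chain_start and __on_llm_action later attaches to the generation. A minimal usage sketch follows; the prompt name "animal-joke", the langchain_openai import, and the one-step chain are illustrative assumptions, not part of this diff:

    from langchain_core.prompts import PromptTemplate
    from langchain_openai import OpenAI

    from langfuse import Langfuse
    from langfuse.callback import CallbackHandler

    langfuse = Langfuse()

    # "animal-joke" is a hypothetical prompt name managed in Langfuse.
    langfuse_prompt = langfuse.get_prompt("animal-joke")

    # The reserved "langfuse_prompt" metadata key is what on_chain_start
    # inspects to register the prompt for the generation that follows.
    langchain_prompt = PromptTemplate.from_template(
        langfuse_prompt.get_langchain_prompt(),
        metadata={"langfuse_prompt": langfuse_prompt},
    )

    chain = langchain_prompt | OpenAI()
    chain.invoke({"animal": "dog"}, config={"callbacks": [CallbackHandler()]})

The tests below exercise exactly this wiring across invoke, stream, and batch execution.
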
diff --git a/tests/test_langchain.py b/tests/test_langchain.py
index 5c4acb4c..3c743f9c 100644
--- a/tests/test_langchain.py
+++ b/tests/test_langchain.py
@@ -31,6 +31,7 @@
 from tests.utils import create_uuid, get_api
 from langchain_core.callbacks.manager import CallbackManagerForLLMRun
 from langchain_core.language_models.llms import LLM
+from langchain_core.output_parsers import StrOutputParser
 
 
 def test_callback_init():
@@ -1763,68 +1764,246 @@ def test_disabled_langfuse():
         api.trace.get(trace_id)
 
 
-# # Enable this test when the ChatBedrock is available in CI
-# def test_chat_bedrock():
-#     handler = CallbackHandler(debug=True)
+def test_link_langfuse_prompts_invoke():
+    langfuse = Langfuse()
+    trace_name = "test_link_langfuse_prompts_invoke"
 
-#     llm = ChatBedrock(
-#         model_id="anthropic.claude-3-sonnet-20240229-v1:0",
-#         # model_id="amazon.titan-text-lite-v1",
-#         region_name="eu-central-1",
-#         callbacks=[handler],
-#     )
+    # Create prompts
+    joke_prompt_name = "joke_prompt_" + create_uuid()[:8]
+    joke_prompt_string = "Tell me a joke involving the animal {{animal}}"
 
-#     messages = [
-#         (
-#             "system",
-#             "You are a expert software engineer.",
-#         ),
-#         ("human", "Give me fizzbuzz algo in C++"),
-#     ]
+    explain_prompt_name = "explain_prompt_" + create_uuid()[:8]
+    explain_prompt_string = "Explain the joke to me like I'm a 5 year old {{joke}}"
 
-#     ai_msg = llm.stream("Give me fizzbuzz algo in C++")
+    langfuse.create_prompt(
+        name=joke_prompt_name,
+        prompt=joke_prompt_string,
+        labels=["production"],
+    )
 
-#     for chunk in ai_msg:
-#         print(chunk)
+    langfuse.create_prompt(
+        name=explain_prompt_name,
+        prompt=explain_prompt_string,
+        labels=["production"],
+    )
 
+    # Get prompts
+    langfuse_joke_prompt = langfuse.get_prompt(joke_prompt_name)
+    langfuse_explain_prompt = langfuse.get_prompt(explain_prompt_name)
 
-# def test_langchain_anthropic_package():
-#     langfuse_handler = CallbackHandler(debug=False)
-#     from langchain_anthropic import ChatAnthropic
+    langchain_joke_prompt = PromptTemplate.from_template(
+        langfuse_joke_prompt.get_langchain_prompt(),
+        metadata={"langfuse_prompt": langfuse_joke_prompt},
+    )
 
-#     chat = ChatAnthropic(
-#         model="claude-3-sonnet-20240229",
-#         temperature=0.1,
-#     )
+    langchain_explain_prompt = PromptTemplate.from_template(
+        langfuse_explain_prompt.get_langchain_prompt(),
+        metadata={"langfuse_prompt": langfuse_explain_prompt},
+    )
 
-#     system = "You are a helpful assistant that translates {input_language} to {output_language}."
-#     human = "{text}"
-#     prompt = ChatPromptTemplate.from_messages([("system", system), ("human", human)])
+    # Create chain
+    parser = StrOutputParser()
+    model = OpenAI()
+    chain = (
+        {"joke": langchain_joke_prompt | model | parser}
+        | langchain_explain_prompt
+        | model
+        | parser
+    )
 
-#     chain = prompt | chat
-#     chain.invoke(
-#         {
-#             "input_language": "English",
-#             "output_language": "Korean",
-#             "text": "I love Python",
-#         },
-#         config={"callbacks": [langfuse_handler]},
-#     )
+    # Run chain
+    langfuse_handler = CallbackHandler(debug=True)
+
+    output = chain.invoke(
+        {"animal": "dog"},
+        config={
+            "callbacks": [langfuse_handler],
+            "run_name": trace_name,
+        },
+    )
+
+    langfuse_handler.flush()
+
+    observations = get_api().trace.get(langfuse_handler.get_trace_id()).observations
+
+    generations = sorted(
+        list(filter(lambda x: x.type == "GENERATION", observations)),
+        key=lambda x: x.start_time,
+    )
+
+    assert len(generations) == 2
+    assert generations[0].input == "Tell me a joke involving the animal dog"
+    assert "Explain the joke to me like I'm a 5 year old" in generations[1].input
 
-#     langfuse_handler.flush()
+    assert generations[0].prompt_name == joke_prompt_name
+    assert generations[1].prompt_name == explain_prompt_name
+
+    assert generations[0].prompt_version == langfuse_joke_prompt.version
+    assert generations[1].prompt_version == langfuse_explain_prompt.version
+
+    assert generations[1].output == output.strip()
+
+
+def test_link_langfuse_prompts_stream():
+    langfuse = Langfuse()
+    trace_name = "test_link_langfuse_prompts_stream"
+
+    # Create prompts
+    joke_prompt_name = "joke_prompt_" + create_uuid()[:8]
+    joke_prompt_string = "Tell me a joke involving the animal {{animal}}"
+
+    explain_prompt_name = "explain_prompt_" + create_uuid()[:8]
+    explain_prompt_string = "Explain the joke to me like I'm a 5 year old {{joke}}"
+
+    langfuse.create_prompt(
+        name=joke_prompt_name,
+        prompt=joke_prompt_string,
+        labels=["production"],
+    )
+
+    langfuse.create_prompt(
+        name=explain_prompt_name,
+        prompt=explain_prompt_string,
+        labels=["production"],
+    )
+
+    # Get prompts
+    langfuse_joke_prompt = langfuse.get_prompt(joke_prompt_name)
+    langfuse_explain_prompt = langfuse.get_prompt(explain_prompt_name)
+
+    langchain_joke_prompt = PromptTemplate.from_template(
+        langfuse_joke_prompt.get_langchain_prompt(),
+        metadata={"langfuse_prompt": langfuse_joke_prompt},
+    )
+
+    langchain_explain_prompt = PromptTemplate.from_template(
+        langfuse_explain_prompt.get_langchain_prompt(),
+        metadata={"langfuse_prompt": langfuse_explain_prompt},
+    )
+
+    # Create chain
+    parser = StrOutputParser()
+    model = OpenAI()
+    chain = (
+        {"joke": langchain_joke_prompt | model | parser}
+        | langchain_explain_prompt
+        | model
+        | parser
+    )
 
-#     observations = get_api().trace.get(langfuse_handler.get_trace_id()).observations
+    # Run chain
+    langfuse_handler = CallbackHandler(debug=True)
+
+    stream = chain.stream(
+        {"animal": "dog"},
+        config={
+            "callbacks": [langfuse_handler],
+            "run_name": trace_name,
+        },
+    )
+
+    output = ""
+    for chunk in stream:
+        output += chunk
+
+    langfuse_handler.flush()
+
+    observations = get_api().trace.get(langfuse_handler.get_trace_id()).observations
+
+    generations = sorted(
+        list(filter(lambda x: x.type == "GENERATION", observations)),
+        key=lambda x: x.start_time,
+    )
+
+    assert len(generations) == 2
+    assert generations[0].input == "Tell me a joke involving the animal dog"
+    assert "Explain the joke to me like I'm a 5 year old" in generations[1].input
+
+    assert generations[0].prompt_name == joke_prompt_name
+    assert generations[1].prompt_name == explain_prompt_name
+
+    assert generations[0].prompt_version == langfuse_joke_prompt.version
+    assert generations[1].prompt_version == langfuse_explain_prompt.version
+
+    assert generations[1].output == output.strip()
+
+
+def test_link_langfuse_prompts_batch():
+    langfuse = Langfuse()
+    trace_name = "test_link_langfuse_prompts_batch_" + create_uuid()[:8]
+
+    # Create prompts
+    joke_prompt_name = "joke_prompt_" + create_uuid()[:8]
+    joke_prompt_string = "Tell me a joke involving the animal {{animal}}"
+
+    explain_prompt_name = "explain_prompt_" + create_uuid()[:8]
+    explain_prompt_string = "Explain the joke to me like I'm a 5 year old {{joke}}"
+
+    langfuse.create_prompt(
+        name=joke_prompt_name,
+        prompt=joke_prompt_string,
+        labels=["production"],
+    )
+
+    langfuse.create_prompt(
+        name=explain_prompt_name,
+        prompt=explain_prompt_string,
+        labels=["production"],
+    )
+
+    # Get prompts
+    langfuse_joke_prompt = langfuse.get_prompt(joke_prompt_name)
+    langfuse_explain_prompt = langfuse.get_prompt(explain_prompt_name)
+
+    langchain_joke_prompt = PromptTemplate.from_template(
+        langfuse_joke_prompt.get_langchain_prompt(),
+        metadata={"langfuse_prompt": langfuse_joke_prompt},
+    )
+
+    langchain_explain_prompt = PromptTemplate.from_template(
+        langfuse_explain_prompt.get_langchain_prompt(),
+        metadata={"langfuse_prompt": langfuse_explain_prompt},
+    )
+
+    # Create chain
+    parser = StrOutputParser()
+    model = OpenAI()
+    chain = (
+        {"joke": langchain_joke_prompt | model | parser}
+        | langchain_explain_prompt
+        | model
+        | parser
+    )
+
+    # Run chain
+    langfuse_handler = CallbackHandler(debug=True)
+
+    chain.batch(
+        [{"animal": "dog"}, {"animal": "cat"}, {"animal": "elephant"}],
+        config={
+            "callbacks": [langfuse_handler],
+            "run_name": trace_name,
+        },
+    )
+
+    langfuse_handler.flush()
+
+    traces = get_api().trace.list(name=trace_name).data
+
+    assert len(traces) == 3
+
+    for trace in traces:
+        observations = get_api().trace.get(trace.id).observations
+
+        generations = sorted(
+            list(filter(lambda x: x.type == "GENERATION", observations)),
+            key=lambda x: x.start_time,
+        )
 
-#     assert len(observations) == 3
+        assert len(generations) == 2
 
-#     generation = list(filter(lambda x: x.type == "GENERATION", observations))[0]
+        assert generations[0].prompt_name == joke_prompt_name
+        assert generations[1].prompt_name == explain_prompt_name
 
-#     assert generation.output is not None
-#     assert generation.output != ""
-#     assert generation.input is not None
-#     assert generation.input != ""
-#     assert generation.usage is not None
-#     assert generation.usage.input is not None
-#     assert generation.usage.output is not None
-#     assert generation.usage.total is not None
-#     assert generation.model == "claude-3-sonnet-20240229"
+        assert generations[0].prompt_version == langfuse_joke_prompt.version
+        assert generations[1].prompt_version == langfuse_explain_prompt.version
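
Note: because __join_tags_and_metadata now routes metadata through _strip_langfuse_keys_from_dict, the prompt object itself is never persisted in the trace metadata. A quick sketch of the expected helper behavior; importing the private helper like this is for illustration only:

    from langfuse.callback.langchain import _strip_langfuse_keys_from_dict

    meta = {"langfuse_prompt": object(), "user_id": "user-123"}

    # The reserved key is dropped from the returned copy...
    assert _strip_langfuse_keys_from_dict(meta) == {"user_id": "user-123"}

    # ...while the caller's original dict is left untouched.
    assert "langfuse_prompt" in meta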