From 4102a76b991825136af2996af79b881dfc904e61 Mon Sep 17 00:00:00 2001 From: corey Date: Sat, 10 Aug 2024 09:51:55 -0700 Subject: [PATCH 1/3] add inline data support for instrumentation --- src/core/trulens/core/app/base.py | 20 +++++++++- src/core/trulens/core/feedback/endpoint.py | 7 ++++ src/core/trulens/core/feedback/feedback.py | 20 +++++++--- src/core/trulens/core/instruments.py | 44 +++++++++++++++------- src/core/trulens/core/schema/record.py | 3 ++ src/core/trulens/core/tru.py | 6 ++- 6 files changed, 80 insertions(+), 20 deletions(-) diff --git a/src/core/trulens/core/app/base.py b/src/core/trulens/core/app/base.py index 164a9b71c..a163093ed 100644 --- a/src/core/trulens/core/app/base.py +++ b/src/core/trulens/core/app/base.py @@ -356,6 +356,9 @@ def __init__(self, app: App, record_metadata: JSON = None): self.records: List[mod_record_schema.Record] = [] """Completed records.""" + self.inline_data: Dict[str, Any] = {} + """Inline data to attach to the currently tracked record.""" + self.lock: Lock = Lock() """Lock blocking access to `calls` and `records` when adding calls or finishing a record.""" @@ -410,11 +413,20 @@ def add_call(self, call: mod_record_schema.RecordAppCall): # processing calls with awaitable or generator results. self.calls[call.call_id] = call + def add_inline_data(self, key: str, value: Any, **kwargs): + """ + Add inline data to the currently tracked call list. + """ + with self.lock: + # TODO: make value a constant + self.inline_data[key] = {"value": value, **kwargs} + def finish_record( self, calls_to_record: Callable[ [ List[mod_record_schema.RecordAppCall], + Dict[str, Dict[str, Any]], mod_types_schema.Metadata, Optional[mod_record_schema.Record], ], @@ -429,9 +441,13 @@ def finish_record( with self.lock: record = calls_to_record( - list(self.calls.values()), self.record_metadata, existing_record + list(self.calls.values()), + self.inline_data, + self.record_metadata, + existing_record, ) self.calls = {} + self.inline_data = {} if existing_record is None: # If existing record was given, we assume it was already @@ -1088,6 +1104,7 @@ def on_add_record( def build_record( calls: Iterable[mod_record_schema.RecordAppCall], + inline_data: JSON, record_metadata: JSON, existing_record: Optional[mod_record_schema.Record] = None, ) -> mod_record_schema.Record: @@ -1107,6 +1124,7 @@ def build_record( perf=perf, app_id=self.app_id, tags=self.tags, + inline_data=jsonify(inline_data), meta=jsonify(record_metadata), ) diff --git a/src/core/trulens/core/feedback/endpoint.py b/src/core/trulens/core/feedback/endpoint.py index e4d9449bc..7bc1b2230 100644 --- a/src/core/trulens/core/feedback/endpoint.py +++ b/src/core/trulens/core/feedback/endpoint.py @@ -12,6 +12,7 @@ from time import sleep from types import ModuleType from typing import ( + TYPE_CHECKING, Any, Awaitable, Callable, @@ -20,6 +21,7 @@ List, Optional, Sequence, + Set, Tuple, Type, TypeVar, @@ -46,6 +48,9 @@ from trulens.core.utils.serial import SerialModel from trulens.core.utils.threading import DEFAULT_NETWORK_TIMEOUT +if TYPE_CHECKING: + from trulens.core.app.base import RecordingContext + logger = logging.getLogger(__name__) pp = PrettyPrinter() @@ -515,6 +520,7 @@ def track_all_costs( @staticmethod def track_all_costs_tally( __func: mod_asynchro_utils.CallableMaybeAwaitable[A, T], + contexts: Set[RecordingContext], *args, with_openai: bool = True, with_hugs: bool = True, @@ -527,6 +533,7 @@ def track_all_costs_tally( Track costs of all of the apis we can currently track, over the execution of thunk. 
""" + assert contexts, "No recording context set." result, cbs = Endpoint.track_all_costs( __func, diff --git a/src/core/trulens/core/feedback/feedback.py b/src/core/trulens/core/feedback/feedback.py index a52bb0cec..c25449633 100644 --- a/src/core/trulens/core/feedback/feedback.py +++ b/src/core/trulens/core/feedback/feedback.py @@ -16,6 +16,7 @@ Dict, Iterable, List, + Mapping, Optional, Tuple, TypeVar, @@ -359,7 +360,7 @@ def evaluate_deferred( def prepare_feedback( row, - ) -> Optional[mod_feedback_schema.FeedbackResultStatus]: + ) -> Optional[mod_feedback_schema.FeedbackResult]: record_json = row.record_json record = mod_record_schema.Record.model_validate(record_json) @@ -1047,9 +1048,10 @@ def run_and_log( ) ) - feedback_result = self.run(app=app, record=record).update( - feedback_result_id=feedback_result_id - ) + feedback_result = self.run( + app=app, + record=record, + ).update(feedback_result_id=feedback_result_id) except Exception: # Convert traceback to a UTF-8 string, replacing errors to avoid encoding issues @@ -1189,9 +1191,17 @@ def _construct_source_data( if app is not None: source_data["__app__"] = app - if record is not None: + if record: source_data["__record__"] = record.layout_calls_as_app() + if isinstance(record.inline_data, Mapping): + inline_data = { + k: val["value"] + for k, val in record.inline_data + if isinstance(val, Mapping) and "value" in val + } + source_data = {**source_data, **inline_data} + return source_data def extract_selection( diff --git a/src/core/trulens/core/instruments.py b/src/core/trulens/core/instruments.py index 9ab153a38..ace85a7ce 100644 --- a/src/core/trulens/core/instruments.py +++ b/src/core/trulens/core/instruments.py @@ -61,6 +61,7 @@ from trulens.core.utils.text import retab if TYPE_CHECKING: + from trulens.core.app import App from trulens.core.app.base import RecordingContext logger = logging.getLogger(__name__) @@ -406,21 +407,22 @@ def tru_wrapper(*args, **kwargs): inspect.isasyncgenfunction(func), ) - apps = getattr(tru_wrapper, Instrument.APPS) + apps: Iterable[App] = getattr(tru_wrapper, Instrument.APPS) # If not within a root method, call the wrapped function without # any recording. # Get any contexts already known from higher in the call stack. - contexts = get_first_local_in_call_stack( - key="contexts", - func=find_instrumented, - offset=1, - skip=python.caller_frame(), + _contexts: Optional[Set[RecordingContext]] = ( + get_first_local_in_call_stack( + key="contexts", + func=find_instrumented, + offset=1, + skip=python.caller_frame(), + ) ) # Note: are empty sets false? - if contexts is None: - contexts = set() + contexts: Set[RecordingContext] = _contexts or set() # And add any new contexts from all apps wishing to record this # function. This may produce some of the same contexts that were @@ -480,11 +482,8 @@ def tru_wrapper(*args, **kwargs): # First prepare the stacks for each context. for ctx in contexts: - # Get app that has instrumented this method. - app = ctx.app - # The path to this method according to the app. 
- path = app.get_method_path( + path = ctx.app.get_method_path( args[0], func ) # hopefully args[0] is self, owner of func @@ -532,7 +531,7 @@ def tru_wrapper(*args, **kwargs): bindings: BoundArguments = sig.bind(*args, **kwargs) rets, cost = mod_endpoint.Endpoint.track_all_costs_tally( - func, *args, **kwargs + func, contexts, *args, **kwargs ) except BaseException as e: @@ -1035,3 +1034,22 @@ def __set_name__(self, cls: type, name: str): # Note that this does not actually change the method, just adds it to # list of filters. self.method(cls, name) + + +def tag(key: str, value: Any, collection: bool = False): + """Set inline data for the given key.""" + + def _find_contexts_frame(f): + return id(f) == id(mod_endpoint.Endpoint.track_all_costs_tally.__code__) + + # get previously known inline data + contexts: Optional[Set[RecordingContext]] = get_first_local_in_call_stack( + key="contexts", + func=_find_contexts_frame, + offset=1, + skip=python.caller_frame(), + ) + # Note: are empty sets false? + if contexts: + for context in contexts: + context.add_inline_data(key, jsonify(value), collection=collection) diff --git a/src/core/trulens/core/schema/record.py b/src/core/trulens/core/schema/record.py index 382ff8e99..71e1355d1 100644 --- a/src/core/trulens/core/schema/record.py +++ b/src/core/trulens/core/schema/record.py @@ -131,6 +131,9 @@ class Record(serial.SerialModel, Hashable): main_error: Optional[serial.JSON] = None # if error """The app's main error if there was an error.""" + inline_data: Optional[serial.JSON] = None + """Inline data added to the record.""" + calls: List[RecordAppCall] = [] """The collection of calls recorded. diff --git a/src/core/trulens/core/tru.py b/src/core/trulens/core/tru.py index ad7d5e2e4..2a410c6c0 100644 --- a/src/core/trulens/core/tru.py +++ b/src/core/trulens/core/tru.py @@ -485,7 +485,11 @@ def _submit_feedback_functions( for ffunc in feedback_functions: # Run feedback function and the on_done callback. This makes sure # that Future.result() returns only after on_done has finished. - def run_and_call_callback(ffunc, app, record): + def run_and_call_callback( + ffunc: feedback.Feedback, + app: mod_app_schema.AppDefinition, + record: mod_record_schema.Record, + ): temp = ffunc.run(app=app, record=record) if on_done is not None: try: From 565d16342f71a1e43bfcdc93e4e9901884863cef Mon Sep 17 00:00:00 2001 From: corey Date: Sat, 10 Aug 2024 09:57:31 -0700 Subject: [PATCH 2/3] feedback should default missing values to inline data --- src/core/trulens/core/feedback/feedback.py | 28 ++++++++++++++++++++-- src/core/trulens/core/instruments.py | 19 +++++++++++---- src/core/trulens/core/schema/select.py | 3 +++ 3 files changed, 43 insertions(+), 7 deletions(-) diff --git a/src/core/trulens/core/feedback/feedback.py b/src/core/trulens/core/feedback/feedback.py index c25449633..74d4f6685 100644 --- a/src/core/trulens/core/feedback/feedback.py +++ b/src/core/trulens/core/feedback/feedback.py @@ -236,6 +236,7 @@ def __init__( self.imp = imp self.agg = agg + self.fill_inline_selectors() # Verify that `imp` expects the arguments specified in `selectors`: if self.imp is not None: @@ -246,6 +247,26 @@ def __init__( f"Its arguments are {list(sig.parameters.keys())}." ) + def fill_inline_selectors(self): + """ + Use inline data for filling missing feedback function arguments. + """ + + assert ( + self.imp is not None + ), "Feedback function implementation is required to determine default argument names." 
+ + sig: Signature = signature(self.imp) + par_names = list( + k for k in sig.parameters.keys() if k not in self.selectors + ) + self.selectors = { + par_name: Select.RecordInlineData[par_name] + if par_name not in self.selectors + else self.selectors[par_name] + for par_name in par_names + } + def on_input_output(self) -> Feedback: """ Specifies that the feedback implementation arguments are to be the main @@ -661,7 +682,10 @@ def check_selectors( # with c.capture() as cap: for k, q in self.selectors.items(): - if q.exists(source_data): + if q.exists( + source_data + ) or Select.RecordInlineData.is_immediate_prefix_of(q): + # Skip if q exists in record or references inline data that should be supplied at app runtime. continue msg += f""" @@ -1197,7 +1221,7 @@ def _construct_source_data( if isinstance(record.inline_data, Mapping): inline_data = { k: val["value"] - for k, val in record.inline_data + for k, val in record.inline_data.items() if isinstance(val, Mapping) and "value" in val } source_data = {**source_data, **inline_data} diff --git a/src/core/trulens/core/instruments.py b/src/core/trulens/core/instruments.py index ace85a7ce..b6aa72865 100644 --- a/src/core/trulens/core/instruments.py +++ b/src/core/trulens/core/instruments.py @@ -1036,12 +1036,17 @@ def __set_name__(self, cls: type, name: str): self.method(cls, name) -def tag(key: str, value: Any, collection: bool = False): +def label_value( + value: Any, labels: Union[str, Iterable[str]], collection: bool = False +): """Set inline data for the given key.""" def _find_contexts_frame(f): return id(f) == id(mod_endpoint.Endpoint.track_all_costs_tally.__code__) + if isinstance(labels, str): + labels = [labels] + # get previously known inline data contexts: Optional[Set[RecordingContext]] = get_first_local_in_call_stack( key="contexts", @@ -1049,7 +1054,11 @@ def _find_contexts_frame(f): offset=1, skip=python.caller_frame(), ) - # Note: are empty sets false? - if contexts: - for context in contexts: - context.add_inline_data(key, jsonify(value), collection=collection) + if contexts is None: + return + + for context in contexts: + for label in labels: + context.add_inline_data( + label, jsonify(value), collection=collection + ) diff --git a/src/core/trulens/core/schema/select.py b/src/core/trulens/core/schema/select.py index 9ebf74c13..be2b1223a 100644 --- a/src/core/trulens/core/schema/select.py +++ b/src/core/trulens/core/schema/select.py @@ -36,6 +36,9 @@ class Select: RecordOutput: Query = Record.main_output """Selector for the main app output.""" + RecordInlineData: Query = Record.inline_data + """Selector for the inline data of the record.""" + RecordCalls: Query = Record.app # type: ignore """Selector for the calls made by the wrapped app. 
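
Usage sketch for patches 1 and 2 (illustrative, not part of the diff): an instrumented app can attach values to the record as it runs with `label_value`, and a feedback function whose argument name matches the label picks the value up without an explicit selector, via the `fill_inline_selectors` defaulting added above. The imports mirror the quickstart in patch 3, and `label_value` is assumed importable from `trulens.core.instruments` where the diff defines it; `LabeledApp`, `reranker_quality`, and the `reranker_score` label are made-up names for illustration.

from trulens.core import Feedback, TruCustomApp
from trulens.core.app.custom import instrument
from trulens.core.instruments import label_value


class LabeledApp:
    @instrument
    def answer(self, query: str) -> str:
        # Attach a value to the record under the label "reranker_score".
        # label_value locates the active RecordingContext(s) on the call stack
        # and calls add_inline_data, so it only has an effect when invoked
        # inside an instrumented method under an active recording.
        label_value(0.42, "reranker_score")
        return f"stub answer to: {query}"


def reranker_quality(reranker_score: float) -> float:
    # Made-up feedback implementation; its argument name matches the label.
    return reranker_score


# No selector is supplied, so fill_inline_selectors maps `reranker_score` to
# Select.RecordInlineData["reranker_score"], and the value is read back from
# the record's inline_data when the feedback runs.
f_reranker = Feedback(reranker_quality, name="Reranker Score")

app = LabeledApp()
tru_app = TruCustomApp(app, app_id="Labeled app", feedbacks=[f_reranker])

with tru_app as recording:
    app.answer("When was the University of Washington founded?")
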
From 32e5fc211a9187b246d1f37f28f99693bfc38f51 Mon Sep 17 00:00:00 2001 From: corey Date: Mon, 12 Aug 2024 13:57:54 -0700 Subject: [PATCH 3/3] working feedback changes --- .../inline_selectors_quickstart.ipynb | 481 ++++++++++++++++++ src/feedback/trulens/feedback/llm_provider.py | 52 +- 2 files changed, 507 insertions(+), 26 deletions(-) create mode 100644 examples/quickstart/inline_selectors_quickstart.ipynb diff --git a/examples/quickstart/inline_selectors_quickstart.ipynb b/examples/quickstart/inline_selectors_quickstart.ipynb new file mode 100644 index 000000000..8291438cc --- /dev/null +++ b/examples/quickstart/inline_selectors_quickstart.ipynb @@ -0,0 +1,481 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# 📓 TruLens Quickstart\n", + "\n", + "In this quickstart you will create a RAG from scratch and learn how to log it and get feedback on an LLM response.\n", + "\n", + "For evaluation, we will leverage the \"hallucination triad\" of groundedness, context relevance and answer relevance.\n", + "\n", + "[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/truera/trulens/blob/main/examples/quickstart/quickstart.ipynb)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "!pip install trulens trulens-providers-openai chromadb openai" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import os\n", + "\n", + "# os.environ[\"OPENAI_API_KEY\"] = \"sk-...\"" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Get Data\n", + "\n", + "In this case, we'll just initialize some simple text in the notebook." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "uw_info = \"\"\"\n", + "The University of Washington, founded in 1861 in Seattle, is a public research university\n", + "with over 45,000 students across three campuses in Seattle, Tacoma, and Bothell.\n", + "As the flagship institution of the six public universities in Washington state,\n", + "UW encompasses over 500 buildings and 20 million square feet of space,\n", + "including one of the largest library systems in the world.\n", + "\"\"\"\n", + "\n", + "wsu_info = \"\"\"\n", + "Washington State University, commonly known as WSU, founded in 1890, is a public research university in Pullman, Washington.\n", + "With multiple campuses across the state, it is the state's second largest institution of higher education.\n", + "WSU is known for its programs in veterinary medicine, agriculture, engineering, architecture, and pharmacy.\n", + "\"\"\"\n", + "\n", + "seattle_info = \"\"\"\n", + "Seattle, a city on Puget Sound in the Pacific Northwest, is surrounded by water, mountains and evergreen forests, and contains thousands of acres of parkland.\n", + "It's home to a large tech industry, with Microsoft and Amazon headquartered in its metropolitan area.\n", + "The futuristic Space Needle, a legacy of the 1962 World's Fair, is its most iconic landmark.\n", + "\"\"\"\n", + "\n", + "starbucks_info = \"\"\"\n", + "Starbucks Corporation is an American multinational chain of coffeehouses and roastery reserves headquartered in Seattle, Washington.\n", + "As the world's largest coffeehouse chain, Starbucks is seen to be the main representation of the United States' second wave of coffee culture.\n", + "\"\"\"" + ] + }, + { + "cell_type": 
"markdown", + "metadata": {}, + "source": [ + "## Create Vector Store\n", + "\n", + "Create a chromadb vector store in memory." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import chromadb\n", + "from chromadb.utils.embedding_functions import OpenAIEmbeddingFunction\n", + "\n", + "embedding_function = OpenAIEmbeddingFunction(\n", + " api_key=os.environ.get(\"OPENAI_API_KEY\"),\n", + " model_name=\"text-embedding-ada-002\",\n", + ")\n", + "\n", + "\n", + "chroma_client = chromadb.Client()\n", + "vector_store = chroma_client.get_or_create_collection(\n", + " name=\"Washington\", embedding_function=embedding_function\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Populate the vector store." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "vector_store.add(\"uw_info\", documents=uw_info)\n", + "vector_store.add(\"wsu_info\", documents=wsu_info)\n", + "vector_store.add(\"seattle_info\", documents=seattle_info)\n", + "vector_store.add(\"starbucks_info\", documents=starbucks_info)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Build RAG from scratch\n", + "\n", + "Build a custom RAG from scratch, and add TruLens custom instrumentation." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from trulens.core import Tru\n", + "from trulens.core.app.custom import instrument\n", + "\n", + "tru = Tru()\n", + "tru.reset_database()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from openai import OpenAI\n", + "\n", + "oai_client = OpenAI()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from openai import OpenAI\n", + "\n", + "oai_client = OpenAI()\n", + "\n", + "\n", + "class RAG_from_scratch:\n", + " @instrument\n", + " def retrieve(self, query: str) -> list:\n", + " \"\"\"\n", + " Retrieve relevant text from vector store.\n", + " \"\"\"\n", + " results = vector_store.query(query_texts=query, n_results=4)\n", + " # Flatten the list of lists into a single list\n", + " return [doc for sublist in results[\"documents\"] for doc in sublist]\n", + "\n", + " @instrument\n", + " def generate_completion(self, query: str, context_str: list) -> str:\n", + " \"\"\"\n", + " Generate answer from context.\n", + " \"\"\"\n", + " completion = (\n", + " oai_client.chat.completions.create(\n", + " model=\"gpt-3.5-turbo\",\n", + " temperature=0,\n", + " messages=[\n", + " {\n", + " \"role\": \"user\",\n", + " \"content\": f\"We have provided context information below. \\n\"\n", + " f\"---------------------\\n\"\n", + " f\"{context_str}\"\n", + " f\"\\n---------------------\\n\"\n", + " f\"Given this information, please answer the question: {query}\",\n", + " }\n", + " ],\n", + " )\n", + " .choices[0]\n", + " .message.content\n", + " )\n", + " return completion\n", + "\n", + " @instrument\n", + " def query(self, query: str) -> str:\n", + " context_str = self.retrieve(query)\n", + " completion = self.generate_completion(query, context_str)\n", + " return completion\n", + "\n", + "\n", + "rag = RAG_from_scratch()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Set up feedback functions.\n", + "\n", + "Here we'll use groundedness, answer relevance and context relevance to detect hallucination." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import numpy as np\n", + "from trulens.core import Feedback\n", + "from trulens.providers.openai import OpenAI\n", + "\n", + "provider = OpenAI(model_engine=\"gpt-4o\")\n", + "\n", + "# Define a groundedness feedback function\n", + "f_groundedness = Feedback(\n", + " provider.groundedness_measure_with_cot_reasons, name=\"Groundedness\"\n", + ")\n", + "# Question/answer relevance between overall question and answer.\n", + "f_answer_relevance = Feedback(\n", + " provider.relevance_with_cot_reasons, name=\"Answer Relevance\"\n", + ")\n", + "\n", + "# Context relevance between question and each context chunk.\n", + "f_context_relevance = (\n", + " Feedback(\n", + " provider.context_relevance_with_cot_reasons, name=\"Context Relevance\"\n", + " ).aggregate(np.mean) # choose a different aggregation method if you wish\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Construct the app\n", + "Wrap the custom RAG with TruCustomApp, add list of feedbacks for eval" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from trulens.core import TruCustomApp\n", + "\n", + "tru_rag = TruCustomApp(\n", + " rag,\n", + " app_id=\"RAG v1\",\n", + " feedbacks=[f_groundedness, f_answer_relevance, f_context_relevance],\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Run the app\n", + "Use `tru_rag` as a context manager for the custom RAG-from-scratch app." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "with tru_rag as recording:\n", + " rag.query(\"When was the University of Washington founded?\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Check results\n", + "\n", + "We can view results in the leaderboard." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "tru.get_leaderboard()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from trulens.core.utils.trulens import get_feedback_result\n", + "\n", + "last_record = recording.records[-1]\n", + "get_feedback_result(last_record, \"Context Relevance\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Use guardrails\n", + "\n", + "In addition to making informed iteration, we can also directly use feedback results as guardrails at inference time. In particular, here we show how to use the context relevance score as a guardrail to filter out irrelevant context before it gets passed to the LLM. This both reduces hallucination and improves efficiency.\n", + "\n", + "To do so, we'll rebuild our RAG using the @context-filter decorator on the method we want to filter, and pass in the feedback function and threshold to use for guardrailing." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from trulens.core.guardrails.base import context_filter\n", + "\n", + "# note: feedback function used for guardrail must only return a score, not also reasons\n", + "f_context_relevance_score = Feedback(\n", + " provider.context_relevance, name=\"Context Relevance\"\n", + ")\n", + "\n", + "\n", + "class filtered_RAG_from_scratch:\n", + " @instrument\n", + " @context_filter(f_context_relevance_score, 0.75, keyword_for_prompt=\"query\")\n", + " def retrieve(self, query: str) -> list:\n", + " \"\"\"\n", + " Retrieve relevant text from vector store.\n", + " \"\"\"\n", + " results = vector_store.query(query_texts=query, n_results=4)\n", + " return [doc for sublist in results[\"documents\"] for doc in sublist]\n", + "\n", + " @instrument\n", + " def generate_completion(self, query: str, context_str: list) -> str:\n", + " \"\"\"\n", + " Generate answer from context.\n", + " \"\"\"\n", + " completion = (\n", + " oai_client.chat.completions.create(\n", + " model=\"gpt-3.5-turbo\",\n", + " temperature=0,\n", + " messages=[\n", + " {\n", + " \"role\": \"user\",\n", + " \"content\": f\"We have provided context information below. \\n\"\n", + " f\"---------------------\\n\"\n", + " f\"{context_str}\"\n", + " f\"\\n---------------------\\n\"\n", + " f\"Given this information, please answer the question: {query}\",\n", + " }\n", + " ],\n", + " )\n", + " .choices[0]\n", + " .message.content\n", + " )\n", + " return completion\n", + "\n", + " @instrument\n", + " def query(self, query: str) -> str:\n", + " context_str = self.retrieve(query=query)\n", + " completion = self.generate_completion(\n", + " query=query, context_str=context_str\n", + " )\n", + " return completion\n", + "\n", + "\n", + "filtered_rag = filtered_RAG_from_scratch()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Record and operate as normal" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from trulens.core import TruCustomApp\n", + "\n", + "filtered_tru_rag = TruCustomApp(\n", + " filtered_rag,\n", + " app_id=\"RAG v2\",\n", + " feedbacks=[f_groundedness, f_answer_relevance, f_context_relevance],\n", + ")\n", + "\n", + "with filtered_tru_rag as recording:\n", + " filtered_rag.query(query=\"when was the university of washington founded?\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "tru.get_leaderboard(app_ids=[])" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "See the power of filtering!" 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from trulens.core.utils.trulens import get_feedback_result\n", + "\n", + "last_record = recording.records[-1]\n", + "get_feedback_result(last_record, \"Context Relevance\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from trulens.dashboard import run_dashboard\n", + "\n", + "run_dashboard(tru, port=3453, force=True)" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "trulens18_release", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/src/feedback/trulens/feedback/llm_provider.py b/src/feedback/trulens/feedback/llm_provider.py index 636aea823..2a2552ecb 100644 --- a/src/feedback/trulens/feedback/llm_provider.py +++ b/src/feedback/trulens/feedback/llm_provider.py @@ -311,7 +311,7 @@ def _determine_output_space( def context_relevance( self, - question: str, + prompt: str, context: str, criteria: str = "", min_score_val: int = 0, @@ -336,8 +336,8 @@ def context_relevance( ``` Args: - question (str): A question being asked. - context (str): Context related to the question. + prompt (str): The prompt to the application + context (str): Context related to the prompt. criteria (str): Overriding evaluation criteria for evaluation . min_score_val (int): The minimum score value. Defaults to 0. max_score_val (int): The maximum score value. Defaults to 3. @@ -358,7 +358,7 @@ def context_relevance( system_prompt=ContextRelevance.system_prompt, user_prompt=str.format( prompts.CONTEXT_RELEVANCE_USER, - question=question, + question=prompt, context=context, ), min_score_val=min_score_val, @@ -368,7 +368,7 @@ def context_relevance( def context_relevance_with_cot_reasons( self, - question: str, + prompt: str, context: str, criteria: str = "", min_score_val: int = 0, @@ -394,8 +394,8 @@ def context_relevance_with_cot_reasons( ``` Args: - question (str): A question being asked. - context (str): Context related to the question. + prompt (str): The prompt to the application. + context (str): Context related to the prompt. criteria (str): Overriding evaluation criteria for evaluation . min_score_val (int): The minimum score value. Defaults to 0. max_score_val (int): The maximum score value. Defaults to 3. @@ -406,7 +406,7 @@ def context_relevance_with_cot_reasons( """ user_prompt = str.format( - prompts.CONTEXT_RELEVANCE_USER, question=question, context=context + prompts.CONTEXT_RELEVANCE_USER, question=prompt, context=context ) user_prompt = user_prompt.replace( "RELEVANCE:", prompts.COT_REASONS_TEMPLATE @@ -431,7 +431,7 @@ def context_relevance_with_cot_reasons( def context_relevance_verb_confidence( self, - question: str, + prompt: str, context: str, criteria: str = "", min_score_val: int = 0, @@ -457,8 +457,8 @@ def context_relevance_verb_confidence( ``` Args: - question (str): A question being asked. - context (str): Context related to the question. + prompt (str): The prompt to the application. + context (str): Context related to the prompt. criteria (str): Overriding evaluation criteria for evaluation . min_score_val (int): The minimum score value. Defaults to 0. max_score_val (int): The maximum score value. 
Defaults to 3. @@ -483,7 +483,7 @@ def context_relevance_verb_confidence( + ContextRelevance.verb_confidence_prompt, user_prompt=str.format( prompts.CONTEXT_RELEVANCE_USER, - question=question, + question=prompt, context=context, ), min_score_val=min_score_val, @@ -1314,7 +1314,7 @@ def stereotypes_with_cot_reasons( return self.generate_score_and_reasons(system_prompt, user_prompt) def groundedness_measure_with_cot_reasons( - self, source: str, statement: str + self, context: str, response: str ) -> Tuple[float, dict]: """A measure to track if the source material supports each sentence in the statement using an LLM provider. @@ -1339,8 +1339,8 @@ def groundedness_measure_with_cot_reasons( ) ``` Args: - source: The source that should support the statement. - statement: The statement to check groundedness. + context: The source that should support the statement. + response: The statement to check groundedness. Returns: Tuple[float, dict]: A tuple containing a value between 0.0 (not grounded) and 1.0 (grounded) and a dictionary containing the reasons for the evaluation. @@ -1349,13 +1349,13 @@ def groundedness_measure_with_cot_reasons( groundedness_scores = {} reasons_str = "" - hypotheses = sent_tokenize(statement) + hypotheses = sent_tokenize(response) system_prompt = prompts.LLM_GROUNDEDNESS_SYSTEM def evaluate_hypothesis(index, hypothesis): user_prompt = prompts.LLM_GROUNDEDNESS_USER.format( - premise=f"{source}", hypothesis=f"{hypothesis}" + premise=f"{context}", hypothesis=f"{hypothesis}" ) score, reason = self.generate_score_and_reasons( system_prompt, user_prompt @@ -1392,7 +1392,7 @@ def evaluate_hypothesis(index, hypothesis): return average_groundedness_score, {"reasons": reasons_str} def groundedness_measure_with_cot_reasons_consider_answerability( - self, source: str, statement: str, question: str + self, context: str, response: str, prompt: str ) -> Tuple[float, dict]: """A measure to track if the source material supports each sentence in the statement using an LLM provider. @@ -1420,9 +1420,9 @@ def groundedness_measure_with_cot_reasons_consider_answerability( ) ``` Args: - source: The source that should support the statement. - statement: The statement to check groundedness. - question: The question to check answerability. + context: The source that should support the statement. + response: The statement to check groundedness. + prompt: The prompt to check answerability. Returns: Tuple[float, dict]: A tuple containing a value between 0.0 (not grounded) and 1.0 (grounded) and a dictionary containing the reasons for the evaluation. 
@@ -1440,30 +1440,30 @@ def evaluate_abstention(statement): ) return score - def evaluate_answerability(question, source): + def evaluate_answerability(prompt: str, context: str): user_prompt = prompts.LLM_ANSWERABILITY_USER.format( - question=question, source=source + question=prompt, source=context ) score = self.generate_score( prompts.LLM_ANSWERABILITY_SYSTEM, user_prompt ) return score - hypotheses = sent_tokenize(statement) + hypotheses = sent_tokenize(response) system_prompt = prompts.LLM_GROUNDEDNESS_SYSTEM def evaluate_hypothesis(index, hypothesis): abstention_score = evaluate_abstention(hypothesis) if abstention_score > 0.5: - answerability_score = evaluate_answerability(question, source) + answerability_score = evaluate_answerability(prompt, context) if answerability_score > 0.5: return index, 0.0, {"reason": "Answerable abstention"} else: return index, 1.0, {"reason": "Unanswerable abstention"} else: user_prompt = prompts.LLM_GROUNDEDNESS_USER.format( - premise=f"{source}", hypothesis=f"{hypothesis}" + premise=f"{context}", hypothesis=f"{hypothesis}" ) score, reason = self.generate_score_and_reasons( system_prompt, user_prompt
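
Note on the argument renames in this patch (`question` -> `prompt`, `source` -> `context`, `statement` -> `response`): selector-based wiring that binds arguments positionally is unaffected, but callers passing these arguments by keyword need updating. A small before/after sketch of the call sites, reusing the notebook's `provider` and `uw_info` (the answer string is illustrative):

# Before this patch:
#   provider.context_relevance(question="When was UW founded?", context=uw_info)
#   provider.groundedness_measure_with_cot_reasons(source=uw_info, statement="...")

# After this patch (same behavior, renamed keyword arguments):
score = provider.context_relevance(prompt="When was UW founded?", context=uw_info)
score, reasons = provider.groundedness_measure_with_cot_reasons(
    context=uw_info,
    response="The University of Washington was founded in 1861.",
)

The new names also match the parameter names that `fill_inline_selectors` turns into `Select.RecordInlineData` defaults, which is presumably why the quickstart notebook above can construct these feedbacks without any explicit selectors.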