Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Inline selectors #1667

Draft
wants to merge 4 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
481 changes: 481 additions & 0 deletions examples/quickstart/inline_selectors_quickstart.ipynb

Large diffs are not rendered by default.

20 changes: 19 additions & 1 deletion src/core/trulens/core/app/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -363,6 +363,9 @@ def __init__(self, app: App, record_metadata: JSON = None):
self.records: List[mod_record_schema.Record] = []
"""Completed records."""

self.inline_data: Dict[str, Any] = {}
"""Inline data to attach to the currently tracked record."""

self.lock: Lock = Lock()
"""Lock blocking access to `calls` and `records` when adding calls or finishing a record."""

Expand Down Expand Up @@ -417,11 +420,20 @@ def add_call(self, call: mod_record_schema.RecordAppCall):
# processing calls with awaitable or generator results.
self.calls[call.call_id] = call

def add_inline_data(self, key: str, value: Any, **kwargs):
"""
Add inline data to the currently tracked call list.
"""
with self.lock:
# TODO: make value a constant
self.inline_data[key] = {"value": value, **kwargs}

def finish_record(
self,
calls_to_record: Callable[
[
List[mod_record_schema.RecordAppCall],
Dict[str, Dict[str, Any]],
mod_types_schema.Metadata,
Optional[mod_record_schema.Record],
],
Expand All @@ -436,9 +448,13 @@ def finish_record(

with self.lock:
record = calls_to_record(
list(self.calls.values()), self.record_metadata, existing_record
list(self.calls.values()),
self.inline_data,
self.record_metadata,
existing_record,
)
self.calls = {}
self.inline_data = {}

if existing_record is None:
# If existing record was given, we assume it was already
Expand Down Expand Up @@ -1110,6 +1126,7 @@ def on_add_record(

def build_record(
calls: Iterable[mod_record_schema.RecordAppCall],
inline_data: JSON,
record_metadata: JSON,
existing_record: Optional[mod_record_schema.Record] = None,
) -> mod_record_schema.Record:
Expand Down Expand Up @@ -1137,6 +1154,7 @@ def build_record(
perf=perf,
app_id=self.app_id,
tags=self.tags,
inline_data=jsonify(inline_data),
meta=jsonify(record_metadata),
)

Expand Down
7 changes: 7 additions & 0 deletions src/core/trulens/core/feedback/endpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from time import sleep
from types import ModuleType
from typing import (
TYPE_CHECKING,
Any,
Awaitable,
Callable,
Expand All @@ -18,6 +19,7 @@
List,
Optional,
Sequence,
Set,
Tuple,
Type,
TypeVar,
Expand All @@ -43,6 +45,9 @@
from trulens.core.utils.serial import SerialModel
from trulens.core.utils.threading import DEFAULT_NETWORK_TIMEOUT

if TYPE_CHECKING:
from trulens.core.app.base import RecordingContext

logger = logging.getLogger(__name__)

pp = PrettyPrinter()
Expand Down Expand Up @@ -512,6 +517,7 @@ def track_all_costs(
@staticmethod
def track_all_costs_tally(
__func: mod_asynchro_utils.CallableMaybeAwaitable[A, T],
contexts: Set[RecordingContext],
*args,
with_openai: bool = True,
with_hugs: bool = True,
Expand All @@ -524,6 +530,7 @@ def track_all_costs_tally(
Track costs of all of the apis we can currently track, over the
execution of thunk.
"""
assert contexts, "No recording context set."

result, cbs = Endpoint.track_all_costs(
__func,
Expand Down
46 changes: 40 additions & 6 deletions src/core/trulens/core/feedback/feedback.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
Dict,
Iterable,
List,
Mapping,
Optional,
Tuple,
TypeVar,
Expand Down Expand Up @@ -235,6 +236,7 @@ def __init__(

self.imp = imp
self.agg = agg
self.fill_inline_selectors()

# Verify that `imp` expects the arguments specified in `selectors`:
if self.imp is not None:
Expand All @@ -245,6 +247,26 @@ def __init__(
f"Its arguments are {list(sig.parameters.keys())}."
)

def fill_inline_selectors(self):
"""
Use inline data for filling missing feedback function arguments.
"""

assert (
self.imp is not None
), "Feedback function implementation is required to determine default argument names."

sig: Signature = signature(self.imp)
par_names = list(
k for k in sig.parameters.keys() if k not in self.selectors
)
self.selectors = {
par_name: Select.RecordInlineData[par_name]
if par_name not in self.selectors
else self.selectors[par_name]
for par_name in par_names
}

def on_input_output(self) -> Feedback:
"""
Specifies that the feedback implementation arguments are to be the main
Expand Down Expand Up @@ -359,7 +381,7 @@ def evaluate_deferred(

def prepare_feedback(
row,
) -> Optional[mod_feedback_schema.FeedbackResultStatus]:
) -> Optional[mod_feedback_schema.FeedbackResult]:
record_json = row.record_json
record = mod_record_schema.Record.model_validate(record_json)

Expand Down Expand Up @@ -660,7 +682,10 @@ def check_selectors(

# with c.capture() as cap:
for k, q in self.selectors.items():
if q.exists(source_data):
if q.exists(
source_data
) or Select.RecordInlineData.is_immediate_prefix_of(q):
# Skip if q exists in record or references inline data that should be supplied at app runtime.
continue

msg += f"""
Expand Down Expand Up @@ -1047,9 +1072,10 @@ def run_and_log(
)
)

feedback_result = self.run(app=app, record=record).update(
feedback_result_id=feedback_result_id
)
feedback_result = self.run(
app=app,
record=record,
).update(feedback_result_id=feedback_result_id)

except Exception:
# Convert traceback to a UTF-8 string, replacing errors to avoid encoding issues
Expand Down Expand Up @@ -1189,9 +1215,17 @@ def _construct_source_data(
if app is not None:
source_data["__app__"] = app

if record is not None:
if record:
source_data["__record__"] = record.layout_calls_as_app()

if isinstance(record.inline_data, Mapping):
inline_data = {
k: val["value"]
for k, val in record.inline_data.items()
if isinstance(val, Mapping) and "value" in val
}
source_data = {**source_data, **inline_data}

return source_data

def extract_selection(
Expand Down
53 changes: 40 additions & 13 deletions src/core/trulens/core/instruments.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@
from trulens.core.utils.text import retab

if TYPE_CHECKING:
from trulens.core.app import App
from trulens.core.app.base import RecordingContext

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -406,21 +407,22 @@ def tru_wrapper(*args, **kwargs):
inspect.isasyncgenfunction(func),
)

apps = getattr(tru_wrapper, Instrument.APPS)
apps: Iterable[App] = getattr(tru_wrapper, Instrument.APPS)

# If not within a root method, call the wrapped function without
# any recording.

# Get any contexts already known from higher in the call stack.
contexts = get_first_local_in_call_stack(
key="contexts",
func=find_instrumented,
offset=1,
skip=python_utils.caller_frame(),
_contexts: Optional[Set[RecordingContext]] = (
get_first_local_in_call_stack(
key="contexts",
func=find_instrumented,
offset=1,
skip=python_utils.caller_frame(),
)
)
# Note: are empty sets false?
if contexts is None:
contexts = set()
contexts: Set[RecordingContext] = _contexts or set()

# And add any new contexts from all apps wishing to record this
# function. This may produce some of the same contexts that were
Expand Down Expand Up @@ -480,11 +482,8 @@ def tru_wrapper(*args, **kwargs):

# First prepare the stacks for each context.
for ctx in contexts:
# Get app that has instrumented this method.
app = ctx.app

# The path to this method according to the app.
path = app.get_method_path(
path = ctx.app.get_method_path(
args[0], func
) # hopefully args[0] is self, owner of func

Expand Down Expand Up @@ -532,7 +531,7 @@ def tru_wrapper(*args, **kwargs):
bindings: BoundArguments = sig.bind(*args, **kwargs)

rets, cost = mod_endpoint.Endpoint.track_all_costs_tally(
func, *args, **kwargs
func, contexts, *args, **kwargs
)

except BaseException as e:
Expand Down Expand Up @@ -1039,3 +1038,31 @@ def __set_name__(self, cls: type, name: str):
# Note that this does not actually change the method, just adds it to
# list of filters.
self.method(cls, name)


def label_value(
value: Any, labels: Union[str, Iterable[str]], collection: bool = False
):
"""Set inline data for the given key."""

def _find_contexts_frame(f):
return id(f) == id(mod_endpoint.Endpoint.track_all_costs_tally.__code__)

if isinstance(labels, str):
labels = [labels]

# get previously known inline data
contexts: Optional[Set[RecordingContext]] = get_first_local_in_call_stack(
key="contexts",
func=_find_contexts_frame,
offset=1,
skip=caller_frame(),
)
if contexts is None:
return

for context in contexts:
for label in labels:
context.add_inline_data(
label, jsonify(value), collection=collection
)
3 changes: 3 additions & 0 deletions src/core/trulens/core/schema/record.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,9 @@ class Record(serial.SerialModel, Hashable):
main_error: Optional[serial.JSON] = None # if error
"""The app's main error if there was an error."""

inline_data: Optional[serial.JSON] = None
"""Inline data added to the record."""

calls: List[RecordAppCall] = []
"""The collection of calls recorded.

Expand Down
3 changes: 3 additions & 0 deletions src/core/trulens/core/schema/select.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,9 @@ class Select:
RecordOutput: Query = Record.main_output
"""Selector for the main app output."""

RecordInlineData: Query = Record.inline_data
"""Selector for the inline data of the record."""

RecordCalls: Query = Record.app # type: ignore
"""Selector for the calls made by the wrapped app.

Expand Down
6 changes: 5 additions & 1 deletion src/core/trulens/core/tru.py
Original file line number Diff line number Diff line change
Expand Up @@ -485,7 +485,11 @@ def _submit_feedback_functions(
for ffunc in feedback_functions:
# Run feedback function and the on_done callback. This makes sure
# that Future.result() returns only after on_done has finished.
def run_and_call_callback(ffunc, app, record):
def run_and_call_callback(
ffunc: feedback.Feedback,
app: mod_app_schema.AppDefinition,
record: mod_record_schema.Record,
):
temp = ffunc.run(app=app, record=record)
if on_done is not None:
try:
Expand Down
Loading