fix(openai): structured output parsing with openai >= 1.50 (#957)
hassiebp authored Oct 9, 2024
Parent: 42d3b6e · Commit: 3399a69
Showing 2 changed files with 81 additions and 13 deletions.
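In short: the Langfuse OpenAI wrapper now traces openai.beta.chat.completions.parse (the structured-output entry point available from openai 1.50.0) and serializes a Pydantic response_format class to its JSON schema before storing it as generation metadata. Below is a minimal sketch of the call path this commit fixes, adapted from the test in this diff; the import follows Langfuse's drop-in pattern, and the model name and event schema are illustrative.

# Minimal usage sketch (assumes langfuse and openai >= 1.50.0 are installed and
# Langfuse/OpenAI credentials are configured). Adapted from the test below.
from typing import List

from pydantic import BaseModel

from langfuse.openai import openai  # Langfuse drop-in replacement for the openai module


class CalendarEvent(BaseModel):
    name: str
    date: str
    participants: List[str]


completion = openai.beta.chat.completions.parse(
    model="gpt-4o-2024-08-06",  # illustrative model name
    messages=[
        {"role": "system", "content": "Extract the event information."},
        {"role": "user", "content": "Alice and Bob are going to a science fair on Friday."},
    ],
    response_format=CalendarEvent,  # a BaseModel subclass, not a plain dict
    name="structured-output-example",  # Langfuse generation name, handled by the wrapper
)

openai.flush_langfuse()  # send the captured generation to Langfuse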
41 changes: 33 additions & 8 deletions langfuse/openai.py
@@ -19,9 +19,11 @@
 
 import copy
 import logging
+from inspect import isclass
 import types
 
 from collections import defaultdict
+from dataclasses import dataclass
 from typing import List, Optional
 
 import openai.resources
@@ -34,6 +36,7 @@
 from langfuse.decorators import langfuse_context
 from langfuse.utils import _get_timestamp
 from langfuse.utils.langfuse_singleton import LangfuseSingleton
+from pydantic import BaseModel
 
 try:
     import openai
@@ -53,19 +56,14 @@
 log = logging.getLogger("langfuse")
 
 
+@dataclass
 class OpenAiDefinition:
     module: str
     object: str
     method: str
     type: str
     sync: bool
-
-    def __init__(self, module: str, object: str, method: str, type: str, sync: bool):
-        self.module = module
-        self.object = object
-        self.method = method
-        self.type = type
-        self.sync = sync
+    min_version: Optional[str] = None
 
 
 OPENAI_METHODS_V0 = [
@@ -115,6 +113,22 @@ def __init__(self, module: str, object: str, method: str, type: str, sync: bool)
         type="completion",
         sync=False,
     ),
+    OpenAiDefinition(
+        module="openai.resources.beta.chat.completions",
+        object="Completions",
+        method="parse",
+        type="chat",
+        sync=True,
+        min_version="1.50.0",
+    ),
+    OpenAiDefinition(
+        module="openai.resources.beta.chat.completions",
+        object="AsyncCompletions",
+        method="parse",
+        type="chat",
+        sync=False,
+        min_version="1.50.0",
+    ),
 ]
 
 
@@ -136,7 +150,13 @@ def __init__(
         self.args["metadata"] = (
             metadata
             if "response_format" not in kwargs
-            else {**(metadata or {}), "response_format": kwargs["response_format"]}
+            else {
+                **(metadata or {}),
+                "response_format": kwargs["response_format"].model_json_schema()
+                if isclass(kwargs["response_format"])
+                and issubclass(kwargs["response_format"], BaseModel)
+                else kwargs["response_format"],
+            }
         )
         self.args["trace_id"] = trace_id
         self.args["session_id"] = session_id
@@ -651,6 +671,11 @@ def register_tracing(self):
         resources = OPENAI_METHODS_V1 if _is_openai_v1() else OPENAI_METHODS_V0
 
         for resource in resources:
+            if resource.min_version is not None and Version(
+                openai.__version__
+            ) < Version(resource.min_version):
+                continue
+
             wrap_function_wrapper(
                 resource.module,
                 f"{resource.object}.{resource.method}",
53 changes: 48 additions & 5 deletions tests/test_openai.py
@@ -1482,26 +1482,69 @@ def test_structured_output_response_format_kwarg():
 
 def test_structured_output_beta_completions_parse():
     from typing import List
+    from packaging.version import Version
 
     class CalendarEvent(BaseModel):
         name: str
         date: str
         participants: List[str]
 
-    openai.beta.chat.completions.parse(
-        model="gpt-4o-2024-08-06",
-        messages=[
+    generation_name = create_uuid()
+    api = get_api()
+
+    params = {
+        "model": "gpt-4o-2024-08-06",
+        "messages": [
             {"role": "system", "content": "Extract the event information."},
             {
                 "role": "user",
                 "content": "Alice and Bob are going to a science fair on Friday.",
             },
         ],
-        response_format=CalendarEvent,
-    )
+        "response_format": CalendarEvent,
+        "name": generation_name,
+    }
+
+    # The beta API is only wrapped for openai >= 1.50.0. Prior to that, another wrapped method was called implicitly.
+    if Version(openai.__version__) < Version("1.50.0"):
+        params.pop("name")
+
+    openai.beta.chat.completions.parse(**params)
 
     openai.flush_langfuse()
 
+    if Version(openai.__version__) >= Version("1.50.0"):
+        # Check the trace and observation properties
+        generation = api.observations.get_many(name=generation_name, type="GENERATION")
+
+        assert len(generation.data) == 1
+        assert generation.data[0].name == generation_name
+        assert generation.data[0].type == "GENERATION"
+        assert generation.data[0].model == "gpt-4o-2024-08-06"
+        assert generation.data[0].start_time is not None
+        assert generation.data[0].end_time is not None
+        assert generation.data[0].start_time < generation.data[0].end_time
+
+        # Check input and output
+        assert len(generation.data[0].input) == 2
+        assert generation.data[0].input[0]["role"] == "system"
+        assert generation.data[0].input[1]["role"] == "user"
+        assert isinstance(generation.data[0].output, dict)
+        assert "name" in generation.data[0].output["content"]
+        assert "date" in generation.data[0].output["content"]
+        assert "participants" in generation.data[0].output["content"]
+
+        # Check usage
+        assert generation.data[0].usage.input is not None
+        assert generation.data[0].usage.output is not None
+        assert generation.data[0].usage.total is not None
+
+        # Check trace
+        trace = api.trace.get(generation.data[0].trace_id)
+
+        assert trace.input is not None
+        assert trace.output is not None
+
 
 @pytest.mark.asyncio
 async def test_close_async_stream():
