From a8920ab1355f8f0cc86ea6c29c789623c4defee7 Mon Sep 17 00:00:00 2001
From: Hassieb Pakzad <68423100+hassiebp@users.noreply.github.com>
Date: Mon, 2 Sep 2024 15:17:40 -0300
Subject: [PATCH] fix(openai): allow iteration with anext protocol (#870)

Co-authored-by: Sin-Woo Bang @sinwoobang
---
 langfuse/openai.py   |  24 ++++++++
 tests/test_openai.py | 128 +++++++++++++++++++++++++++++++++++++++++--
 2 files changed, 147 insertions(+), 5 deletions(-)

diff --git a/langfuse/openai.py b/langfuse/openai.py
index 5f33632e..8e0bccb1 100644
--- a/langfuse/openai.py
+++ b/langfuse/openai.py
@@ -722,6 +722,18 @@ def __iter__(self):
         finally:
             self._finalize()
 
+    def __next__(self):
+        try:
+            item = self.response.__next__()
+            self.items.append(item)
+
+            return item
+
+        except StopIteration:
+            self._finalize()
+
+            raise
+
     def __enter__(self):
         return self.__iter__()
 
@@ -769,6 +781,18 @@ async def __aiter__(self):
         finally:
             await self._finalize()
 
+    async def __anext__(self):
+        try:
+            item = await self.response.__anext__()
+            self.items.append(item)
+
+            return item
+
+        except StopAsyncIteration:
+            await self._finalize()
+
+            raise
+
     async def __aenter__(self):
         return self.__aiter__()
 
diff --git a/tests/test_openai.py b/tests/test_openai.py
index 34af30f4..962c7ba3 100644
--- a/tests/test_openai.py
+++ b/tests/test_openai.py
@@ -161,6 +161,70 @@ def test_openai_chat_completion_stream():
     assert trace.output == chat_content
 
 
+def test_openai_chat_completion_stream_with_next_iteration():
+    api = get_api()
+    generation_name = create_uuid()
+    completion = chat_func(
+        name=generation_name,
+        model="gpt-3.5-turbo",
+        messages=[{"role": "user", "content": "1 + 1 = "}],
+        temperature=0,
+        metadata={"someKey": "someResponse"},
+        stream=True,
+    )
+
+    assert iter(completion)
+
+    chat_content = ""
+
+    while True:
+        try:
+            c = next(completion)
+            chat_content += c.choices[0].delta.content or ""
+
+        except StopIteration:
+            break
+
+    assert len(chat_content) > 0
+
+    openai.flush_langfuse()
+
+    generation = api.observations.get_many(name=generation_name, type="GENERATION")
+
+    assert len(generation.data) != 0
+    assert generation.data[0].name == generation_name
+    assert generation.data[0].metadata == {"someKey": "someResponse"}
+
+    assert generation.data[0].input == [{"content": "1 + 1 = ", "role": "user"}]
+    assert generation.data[0].type == "GENERATION"
+    assert generation.data[0].model == "gpt-3.5-turbo-0125"
+    assert generation.data[0].start_time is not None
+    assert generation.data[0].end_time is not None
+    assert generation.data[0].start_time < generation.data[0].end_time
+    assert generation.data[0].model_parameters == {
+        "temperature": 0,
+        "top_p": 1,
+        "frequency_penalty": 0,
+        "max_tokens": "inf",
+        "presence_penalty": 0,
+    }
+    assert generation.data[0].usage.input is not None
+    assert generation.data[0].usage.output is not None
+    assert generation.data[0].usage.total is not None
+    assert generation.data[0].output == "2"
+    assert isinstance(generation.data[0].output, str) is True
+    assert generation.data[0].completion_start_time is not None
+
+    # Completion start time for time-to-first-token
+    assert generation.data[0].completion_start_time is not None
+    assert generation.data[0].completion_start_time >= generation.data[0].start_time
+    assert generation.data[0].completion_start_time <= generation.data[0].end_time
+
+    trace = api.trace.get(generation.data[0].trace_id)
+    assert trace.input == [{"role": "user", "content": "1 + 1 = "}]
+    assert trace.output == chat_content
+
+
 def test_openai_chat_completion_stream_fail():
     api = get_api()
     generation_name = create_uuid()
@@ -696,7 +760,6 @@ async def test_async_chat():
     )
 
     openai.flush_langfuse()
-    print(completion)
 
     generation = api.observations.get_many(name=generation_name, type="GENERATION")
 
@@ -742,7 +805,6 @@ async def test_async_chat_stream():
         print(c)
 
     openai.flush_langfuse()
-    print(completion)
 
     generation = api.observations.get_many(name=generation_name, type="GENERATION")
 
@@ -772,6 +834,64 @@ async def test_async_chat_stream():
     assert generation.data[0].completion_start_time <= generation.data[0].end_time
 
 
+@pytest.mark.asyncio
+async def test_async_chat_stream_with_anext():
+    api = get_api()
+    client = AsyncOpenAI()
+
+    generation_name = create_uuid()
+
+    completion = await client.chat.completions.create(
+        messages=[{"role": "user", "content": "Give me a one-liner joke"}],
+        model="gpt-3.5-turbo",
+        name=generation_name,
+        stream=True,
+    )
+
+    result = ""
+
+    while True:
+        try:
+            c = await completion.__anext__()
+
+            result += c.choices[0].delta.content or ""
+
+        except StopAsyncIteration:
+            break
+
+    openai.flush_langfuse()
+
+    print(result)
+
+    generation = api.observations.get_many(name=generation_name, type="GENERATION")
+
+    assert len(generation.data) != 0
+    assert generation.data[0].name == generation_name
+    assert generation.data[0].input == [
+        {"content": "Give me a one-liner joke", "role": "user"}
+    ]
+    assert generation.data[0].type == "GENERATION"
+    assert generation.data[0].model == "gpt-3.5-turbo-0125"
+    assert generation.data[0].start_time is not None
+    assert generation.data[0].end_time is not None
+    assert generation.data[0].start_time < generation.data[0].end_time
+    assert generation.data[0].model_parameters == {
+        "temperature": 1,
+        "top_p": 1,
+        "frequency_penalty": 0,
+        "max_tokens": "inf",
+        "presence_penalty": 0,
+    }
+    assert generation.data[0].usage.input is not None
+    assert generation.data[0].usage.output is not None
+    assert generation.data[0].usage.total is not None
+
+    # Completion start time for time-to-first-token
+    assert generation.data[0].completion_start_time is not None
+    assert generation.data[0].completion_start_time >= generation.data[0].start_time
+    assert generation.data[0].completion_start_time <= generation.data[0].end_time
+
+
 def test_openai_function_call():
     from typing import List
 
@@ -880,7 +1000,7 @@ def test_openai_tool_call():
         }
     ]
     messages = [{"role": "user", "content": "What's the weather like in Boston today?"}]
-    completion = openai.chat.completions.create(
+    openai.chat.completions.create(
         model="gpt-3.5-turbo",
         messages=messages,
         tools=tools,
@@ -888,8 +1008,6 @@ def test_openai_tool_call():
         name=generation_name,
    )
 
-    print(completion)
-
     openai.flush_langfuse()
 
     generation = api.observations.get_many(name=generation_name, type="GENERATION")
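
A minimal usage sketch of what this change enables, mirroring the new tests:
consuming a streamed completion through the explicit iterator protocol
(next / __anext__) instead of a for-loop, which previously failed because the
stream wrapper only defined __iter__ / __aiter__. The sketch assumes the
langfuse drop-in imports shown below; model and prompt values are
illustrative, not part of the patch.

    import asyncio

    from langfuse.openai import AsyncOpenAI, openai

    # Sync: next() now works because the wrapper defines __next__. Each
    # chunk is recorded, and the generation is finalized when the
    # underlying stream raises StopIteration.
    stream = openai.chat.completions.create(
        model="gpt-3.5-turbo",
        messages=[{"role": "user", "content": "1 + 1 = "}],
        stream=True,
    )
    text = ""
    while True:
        try:
            chunk = next(stream)
            text += chunk.choices[0].delta.content or ""
        except StopIteration:  # wrapper finalizes the observation here
            break

    # Async: the same pattern with __anext__ / StopAsyncIteration.
    async def consume() -> str:
        stream = await AsyncOpenAI().chat.completions.create(
            model="gpt-3.5-turbo",
            messages=[{"role": "user", "content": "1 + 1 = "}],
            stream=True,
        )
        result = ""
        while True:
            try:
                chunk = await stream.__anext__()
                result += chunk.choices[0].delta.content or ""
            except StopAsyncIteration:
                break
        return result

    asyncio.run(consume())
    openai.flush_langfuse()  # flush buffered observations before exiting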