Skip to content

Commit

Permalink
feat(prompts): return stale prompt version and refresh prompt cache i…
Browse files Browse the repository at this point in the history
…n background thread instead of waiting for the updated prompt (#902)
  • Loading branch information
marcklingen authored Sep 3, 2024
1 parent 150f09b commit 9fe2bee
Show file tree
Hide file tree
Showing 6 changed files with 348 additions and 19 deletions.
31 changes: 20 additions & 11 deletions langfuse/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -1029,7 +1029,7 @@ def get_prompt(
)
except Exception as e:
if fallback:
self.log.warn(
self.log.warning(
f"Returning fallback prompt for '{cache_key}' due to fetch error: {e}"
)

Expand Down Expand Up @@ -1058,21 +1058,30 @@ def get_prompt(
raise e

if cached_prompt.is_expired():
self.log.debug(f"Stale prompt '{cache_key}' found in cache.")
try:
return self._fetch_prompt_and_update_cache(
name,
version=version,
label=label,
ttl_seconds=cache_ttl_seconds,
max_retries=bounded_max_retries,
fetch_timeout_seconds=fetch_timeout_seconds,
# refresh prompt in background thread, refresh_prompt deduplicates tasks
self.log.debug(f"Refreshing prompt '{cache_key}' in background.")
self.prompt_cache.add_refresh_prompt_task(
cache_key,
lambda: self._fetch_prompt_and_update_cache(
name,
version=version,
label=label,
ttl_seconds=cache_ttl_seconds,
max_retries=bounded_max_retries,
fetch_timeout_seconds=fetch_timeout_seconds,
),
)
self.log.debug(f"Returning stale prompt '{cache_key}' from cache.")
# return stale prompt
return cached_prompt.value

except Exception as e:
self.log.warn(
f"Returning expired prompt cache for '{cache_key}' due to fetch error: {e}"
self.log.warning(
f"Error when refreshing cached prompt '{cache_key}', returning cached version. Error: {e}"
)

# creation of refresh prompt task failed, return stale prompt
return cached_prompt.value

return cached_prompt.value
Expand Down
122 changes: 120 additions & 2 deletions langfuse/prompt_cache.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,19 @@
"""@private"""

from datetime import datetime
from typing import Optional, Dict
from typing import List, Optional, Dict, Set
from threading import Thread
import atexit
import logging
from queue import Empty, Queue

from langfuse.model import PromptClient


DEFAULT_PROMPT_CACHE_TTL_SECONDS = 60

DEFAULT_PROMPT_CACHE_REFRESH_WORKERS = 1


class PromptCacheItem:
def __init__(self, prompt: PromptClient, ttl_seconds: int):
Expand All @@ -22,11 +28,119 @@ def get_epoch_seconds() -> int:
return int(datetime.now().timestamp())


class PromptCacheRefreshConsumer(Thread):
    """Daemon worker thread that executes prompt-cache refresh tasks from a queue.

    Tasks are zero-argument callables. A task that raises is logged as a
    warning and still marked as done on the queue, so a failing refresh never
    stops the worker.
    """

    _log = logging.getLogger("langfuse")
    _queue: Queue
    _identifier: int
    running: bool = True

    def __init__(self, queue: Queue, identifier: int):
        """Bind the worker to ``queue``; ``identifier`` only appears in log lines."""
        # Daemon thread: it must not keep the interpreter alive on exit.
        super().__init__(daemon=True)
        self._identifier = identifier
        self._queue = queue

    def run(self):
        """Process queued tasks until ``pause()`` clears the ``running`` flag."""
        while self.running:
            # Poll with a timeout so the ``running`` flag is re-checked
            # regularly even when the queue stays empty.
            try:
                job = self._queue.get(timeout=1)
            except Empty:
                continue

            self._log.debug(
                f"PromptCacheRefreshConsumer processing task, {self._identifier}"
            )
            try:
                job()
            except Exception as err:
                # Task failed, but we still consider it processed
                self._log.warning(
                    f"PromptCacheRefreshConsumer encountered an error, cache was not refreshed: {self._identifier}, {err}"
                )

            self._queue.task_done()

    def pause(self):
        """Pause the consumer."""
        self.running = False


class PromptCacheTaskManager(object):
    """Queue prompt-cache refresh tasks and run them on background consumers.

    At most one refresh task per cache key is meant to be in flight: keys are
    tracked in ``_processing_keys`` from submission until the task finishes,
    whether it succeeds or raises. The membership check in ``add_task`` is not
    performed under a lock, so this de-duplication is best-effort.
    """

    _log = logging.getLogger("langfuse")
    # String forward reference: the consumer class is resolved lazily, so this
    # class does not depend on definition order within the module.
    _consumers: "List[PromptCacheRefreshConsumer]"
    _threads: int
    _queue: Queue
    _processing_keys: Set[str]

    def __init__(self, threads: int = 1):
        """Start ``threads`` consumer threads and register shutdown at exit."""
        self._queue = Queue()
        self._consumers = []
        self._threads = threads
        self._processing_keys = set()

        for i in range(self._threads):
            consumer = PromptCacheRefreshConsumer(self._queue, i)
            consumer.start()
            self._consumers.append(consumer)

        atexit.register(self.shutdown)

    def add_task(self, key: str, task):
        """Enqueue ``task`` for ``key`` unless a task for that key is pending.

        Args:
            key: Cache key used to de-duplicate refresh tasks.
            task: Zero-argument callable that performs the refresh.
        """
        if key in self._processing_keys:
            self._log.debug(
                f"Prompt cache refresh task already submitted for key: {key}"
            )
            return

        self._log.debug(f"Adding prompt cache refresh task for key: {key}")
        self._processing_keys.add(key)
        self._queue.put(self._wrap_task(key, task))

    def active_tasks(self) -> int:
        """Return the number of keys with a queued or running refresh task."""
        return len(self._processing_keys)

    def _wrap_task(self, key: str, task):
        """Wrap ``task`` so that ``key`` is always released when it finishes."""

        def wrapped():
            self._log.debug(f"Refreshing prompt cache for key: {key}")
            try:
                task()
            finally:
                # discard() instead of remove(): if a duplicate task for the
                # same key slipped through the unlocked check in add_task, the
                # key may already be gone; remove() would raise KeyError here
                # and mask the real outcome of the refresh.
                self._processing_keys.discard(key)
            # Only reached when the task did not raise.
            self._log.debug(f"Refreshed prompt cache for key: {key}")

        return wrapped

    def shutdown(self):
        """Signal all consumers to stop and wait for their threads to exit."""
        self._log.debug(
            f"Shutting down prompt refresh task manager, {len(self._consumers)} consumers,..."
        )

        # Signal every consumer first so they wind down concurrently, then join.
        for consumer in self._consumers:
            consumer.pause()

        for consumer in self._consumers:
            try:
                consumer.join()
            except RuntimeError:
                # consumer thread has not started
                pass

        self._log.debug("Shutdown of prompt refresh task manager completed.")


class PromptCache:
_cache: Dict[str, PromptCacheItem]

def __init__(self):
_task_manager: PromptCacheTaskManager
"""Task manager for refreshing cache"""

_log = logging.getLogger("langfuse")

def __init__(
    self, max_prompt_refresh_workers: int = DEFAULT_PROMPT_CACHE_REFRESH_WORKERS
):
    """Create an empty cache and its background refresh task manager.

    Args:
        max_prompt_refresh_workers: Number of consumer threads passed to
            ``PromptCacheTaskManager`` as ``threads``.
    """
    self._cache = {}
    # Constructing the task manager starts its consumer threads immediately.
    self._task_manager = PromptCacheTaskManager(threads=max_prompt_refresh_workers)
    self._log.debug("Prompt cache initialized.")

def get(self, key: str) -> Optional[PromptCacheItem]:
    """Return the cache item stored under ``key``, or ``None`` if absent.

    No expiry check is performed here; the stored item is returned as-is.
    """
    return self._cache.get(key, None)
Expand All @@ -37,6 +151,10 @@ def set(self, key: str, value: PromptClient, ttl_seconds: Optional[int]):

self._cache[key] = PromptCacheItem(value, ttl_seconds)

def add_refresh_prompt_task(self, key: str, fetch_func):
    """Hand ``fetch_func`` to the task manager as the refresh task for ``key``.

    Args:
        key: Cache key the refresh belongs to; passed to the task manager's
            ``add_task`` together with the callable.
        fetch_func: Callable forwarded unchanged as the task.
    """
    self._log.debug(f"Submitting refresh task for key: {key}")
    self._task_manager.add_task(key, fetch_func)

@staticmethod
def generate_cache_key(
name: str, *, version: Optional[int], label: Optional[str]
Expand Down
2 changes: 1 addition & 1 deletion langfuse/version.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
"""@private"""

__version__ = "2.45.2"
__version__ = "2.45.3a0"
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "langfuse"
version = "2.45.2"
version = "2.45.3a0"
description = "A client library for accessing langfuse"
authors = ["langfuse <[email protected]>"]
license = "MIT"
Expand Down
90 changes: 86 additions & 4 deletions tests/test_prompt.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from time import sleep
import pytest
from unittest.mock import Mock, patch

Expand Down Expand Up @@ -628,7 +629,7 @@ def test_get_valid_cached_chat_prompt(langfuse):

# Should refetch and return new prompt if cached one is expired according to custom TTL
@patch.object(PromptCacheItem, "get_epoch_seconds")
def test_get_fresh_prompt_when_expired_cache_custom_ttl(mock_time, langfuse):
def test_get_fresh_prompt_when_expired_cache_custom_ttl(mock_time, langfuse: Langfuse):
mock_time.return_value = 0
ttl_seconds = 20

Expand Down Expand Up @@ -662,13 +663,84 @@ def test_get_fresh_prompt_when_expired_cache_custom_ttl(mock_time, langfuse):
mock_time.return_value = ttl_seconds + 1

result_call_3 = langfuse.get_prompt(prompt_name)

while True:
if langfuse.prompt_cache._task_manager.active_tasks() == 0:
break
sleep(0.1)

assert mock_server_call.call_count == 2 # New call
assert result_call_3 == prompt_client


# Should return stale prompt immediately if cached one is expired according to
# default TTL, and enqueue exactly one background refresh task for it
@patch.object(PromptCacheItem, "get_epoch_seconds")
def test_get_stale_prompt_when_expired_cache_default_ttl(mock_time, langfuse: Langfuse):
    mock_time.return_value = 0

    prompt_name = "test"
    prompt = Prompt_Text(
        name=prompt_name,
        version=1,
        prompt="Make me laugh",
        labels=[],
        type="text",
        config={},
        tags=[],
    )
    prompt_client = TextPromptClient(prompt)

    mock_server_call = langfuse.client.prompts.get
    mock_server_call.return_value = prompt

    result_call_1 = langfuse.get_prompt(prompt_name)
    assert mock_server_call.call_count == 1
    assert result_call_1 == prompt_client

    # Update the version of the returned mocked prompt
    updated_prompt = Prompt_Text(
        name=prompt_name,
        version=2,
        prompt="Make me laugh",
        labels=[],
        type="text",
        config={},
        tags=[],
    )
    mock_server_call.return_value = updated_prompt

    # Set time to just AFTER cache expiry
    mock_time.return_value = DEFAULT_PROMPT_CACHE_TTL_SECONDS + 1

    stale_result = langfuse.get_prompt(prompt_name)
    assert stale_result == prompt_client

    # Ensure that only one refresh is triggered despite multiple calls
    # Cannot check for value as the prompt might have already been updated
    for _ in range(4):
        langfuse.get_prompt(prompt_name)

    # Wait for the background refresh to finish, but never longer than ~10s:
    # an unbounded loop would hang the whole test run if the task got stuck.
    for _ in range(100):
        if langfuse.prompt_cache._task_manager.active_tasks() == 0:
            break
        sleep(0.1)
    else:
        pytest.fail("Prompt cache refresh task did not finish within 10 seconds")

    assert mock_server_call.call_count == 2  # Only one new call to server

    # Check that the prompt has been updated after refresh
    updated_result = langfuse.get_prompt(prompt_name)
    assert updated_result.version == 2
    assert updated_result == TextPromptClient(updated_prompt)


# Should refetch and return new prompt if cached one is expired according to default TTL
@patch.object(PromptCacheItem, "get_epoch_seconds")
def test_get_fresh_prompt_when_expired_cache_default_ttl(mock_time, langfuse):
def test_get_fresh_prompt_when_expired_cache_default_ttl(mock_time, langfuse: Langfuse):
mock_time.return_value = 0

prompt_name = "test"
Expand Down Expand Up @@ -701,13 +773,18 @@ def test_get_fresh_prompt_when_expired_cache_default_ttl(mock_time, langfuse):
mock_time.return_value = DEFAULT_PROMPT_CACHE_TTL_SECONDS + 1

result_call_3 = langfuse.get_prompt(prompt_name)
while True:
if langfuse.prompt_cache._task_manager.active_tasks() == 0:
break
sleep(0.1)

assert mock_server_call.call_count == 2 # New call
assert result_call_3 == prompt_client


# Should return expired prompt if refetch fails
@patch.object(PromptCacheItem, "get_epoch_seconds")
def test_get_expired_prompt_when_failing_fetch(mock_time, langfuse):
def test_get_expired_prompt_when_failing_fetch(mock_time, langfuse: Langfuse):
mock_time.return_value = 0

prompt_name = "test"
Expand Down Expand Up @@ -735,12 +812,17 @@ def test_get_expired_prompt_when_failing_fetch(mock_time, langfuse):
mock_server_call.side_effect = Exception("Server error")

result_call_2 = langfuse.get_prompt(prompt_name, max_retries=1)
while True:
if langfuse.prompt_cache._task_manager.active_tasks() == 0:
break
sleep(0.1)

assert mock_server_call.call_count == 2
assert result_call_2 == prompt_client


# Should fetch new prompt if version changes
def test_get_fresh_prompt_when_version_changes(langfuse):
def test_get_fresh_prompt_when_version_changes(langfuse: Langfuse):
prompt_name = "test"
prompt = Prompt_Text(
name=prompt_name,
Expand Down
Loading

0 comments on commit 9fe2bee

Please sign in to comment.