Add Langstudio embedding connection for embedder_operator #338

Merged · 5 commits · Jan 16, 2025
24 changes: 17 additions & 7 deletions src/pai_rag/integrations/embeddings/pai/embedding_utils.py
@@ -1,19 +1,19 @@
import os
from llama_index.core import Settings
from llama_index.embeddings.openai import OpenAIEmbedding
from llama_index.embeddings.dashscope import DashScopeEmbedding
from llama_index.embeddings.huggingface import HuggingFaceEmbedding
from pai_rag.utils.download_models import ModelScopeDownloader
from pai_rag.integrations.embeddings.pai.pai_embedding_config import (
PaiBaseEmbeddingConfig,
DashScopeEmbeddingConfig,
OpenAIEmbeddingConfig,
HuggingFaceEmbeddingConfig,
CnClipEmbeddingConfig,
LangStudioEmbeddingConfig,
)

from llama_index.core import Settings
from llama_index.embeddings.openai import OpenAIEmbedding
from llama_index.embeddings.dashscope import DashScopeEmbedding
from llama_index.embeddings.huggingface import HuggingFaceEmbedding
from pai_rag.integrations.embeddings.clip.cnclip_embedding import CnClipEmbedding
import os
from loguru import logger
from pai_rag.utils.download_models import ModelScopeDownloader


def create_embedding(
@@ -22,6 +22,7 @@ def create_embedding(
if isinstance(embed_config, OpenAIEmbeddingConfig):
embed_model = OpenAIEmbedding(
api_key=embed_config.api_key,
api_base=embed_config.api_base,
embed_batch_size=embed_config.embed_batch_size,
callback_manager=Settings.callback_manager,
)
@@ -91,7 +92,16 @@ def create_embedding(
logger.info(
f"Initialized CnClip embedding model {embed_config.model} with {embed_config.embed_batch_size} batch size."
)
elif isinstance(embed_config, LangStudioEmbeddingConfig):
from pai_rag.integrations.embeddings.pai.langstudio_utils import (
convert_langstudio_embed_config,
)

converted_embed_config = convert_langstudio_embed_config(embed_config)
logger.info(
f"Initialized LangStudio embedding model with {converted_embed_config}."
)
return create_embedding(converted_embed_config, pai_rag_model_dir)
else:
raise ValueError(f"Unknown Embedding source: {embed_config}")

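For orientation, a minimal usage sketch of the new branch (not part of the diff; the connection name, workspace id, and model directory are placeholders): a config with source "langstudio" is resolved into a concrete OpenAI/DashScope config, and create_embedding is then re-entered with the converted config.

from pai_rag.integrations.embeddings.pai.embedding_utils import create_embedding
from pai_rag.integrations.embeddings.pai.pai_embedding_config import parse_embed_config

# Placeholder values; the named connection must exist in the given PAI workspace.
embed_config = parse_embed_config(
    {
        "source": "langstudio",
        "connection_name": "my-embedding-connection",
        "workspace_id": "123456",
    }
)
# Second argument is the local model directory used by other sources (placeholder path).
embed_model = create_embedding(embed_config, "./model_repository")
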
100 changes: 100 additions & 0 deletions src/pai_rag/integrations/embeddings/pai/langstudio_utils.py
@@ -0,0 +1,100 @@
import os
from alibabacloud_credentials.client import Client as CredentialClient
from alibabacloud_credentials.models import Config as CredentialConfig
from alibabacloud_pailangstudio20240710.client import Client as LangStudioClient
from alibabacloud_pailangstudio20240710.models import (
GetConnectionRequest,
ListConnectionsRequest,
)
from alibabacloud_tea_openapi import models as open_api_models
from pai_rag.utils.constants import DEFAULT_DASHSCOPE_EMBEDDING_MODEL
from pai_rag.integrations.embeddings.pai.pai_embedding_config import parse_embed_config
from loguru import logger


def get_region_id():
return next(
(
os.environ[key]
for key in ["REGION", "REGION_ID", "ALIBABA_CLOUD_REGION_ID"]
if key in os.environ and os.environ[key]
),
"cn-hangzhou",
)


def get_connection_info(region_id: str, connection_name: str, workspace_id: str):
"""
Get Connection information from LangStudio API.
"""
config1 = CredentialConfig(
type="access_key",
access_key_id=os.environ.get("ALIBABA_CLOUD_ACCESS_KEY_ID"),
access_key_secret=os.environ.get("ALIBABA_CLOUD_ACCESS_KEY_SECRET"),
)
public_endpoint = f"pailangstudio.{region_id}.aliyuncs.com"
client = LangStudioClient(
config=open_api_models.Config(
# Use default credential chain, see:
# https://help.aliyun.com/zh/sdk/developer-reference/v2-manage-python-access-credentials#3ca299f04bw3c
credential=CredentialClient(config=config1),
endpoint=public_endpoint,
)
)
resp = client.list_connections(
request=ListConnectionsRequest(
connection_name=connection_name, workspace_id=workspace_id, max_results=50
)
)
connection_info = next(
(
conn
for conn in resp.body.connections
if conn.connection_name == connection_name
),
None,
)
if not connection_info:
raise ValueError(f"Connection {connection_name} not found")
ls_connection = client.get_connection(
connection_id=connection_info.connection_id,
request=GetConnectionRequest(
workspace_id=workspace_id,
encrypt_option="PlainText",
),
)
conn_info = ls_connection.body
configs = conn_info.configs or {}
secrets = conn_info.secrets or {}

logger.info(f"Configs conn_info:\n {conn_info}")
return conn_info, configs, secrets


def convert_langstudio_embed_config(embed_config):
region_id = embed_config.region_id or get_region_id()
conn_info, config, secrets = get_connection_info(
region_id, embed_config.connection_name, embed_config.workspace_id
)
if conn_info.custom_type == "OpenEmbeddingConnection":
return parse_embed_config(
{
"source": "openai",
"api_key": secrets.get("api_key", None),
"api_base": config.get("base_url", None),
"model": embed_config.model,
"embed_batch_size": embed_config.embed_batch_size,
}
)
elif conn_info.custom_type == "DashScopeConnection":
return parse_embed_config(
{
"source": "dashscope",
"api_key": secrets.get("api_key", None)
or os.getenv("DASHSCOPE_API_KEY"),
"model": embed_config.model or DEFAULT_DASHSCOPE_EMBEDDING_MODEL,
"embed_batch_size": embed_config.embed_batch_size,
}
)
else:
raise ValueError(f"Unknown connection type: {conn_info.custom_type}")
10 changes: 10 additions & 0 deletions src/pai_rag/integrations/embeddings/pai/pai_embedding_config.py
@@ -12,6 +12,7 @@ class SupportedEmbedType(str, Enum):
openai = "openai"
huggingface = "huggingface"
cnclip = "cnclip" # Chinese CLIP
langstudio = "langstudio"


class PaiBaseEmbeddingConfig(BaseModel):
@@ -42,6 +43,7 @@ class OpenAIEmbeddingConfig(PaiBaseEmbeddingConfig):
source: Literal[SupportedEmbedType.openai] = SupportedEmbedType.openai
model: str | None = None # use default
api_key: str | None = None # use default
api_base: str | None = None # use default


class HuggingFaceEmbeddingConfig(PaiBaseEmbeddingConfig):
@@ -54,6 +56,14 @@ class CnClipEmbeddingConfig(PaiBaseEmbeddingConfig):
model: str | None = "ViT-L-14"


class LangStudioEmbeddingConfig(PaiBaseEmbeddingConfig):
source: Literal[SupportedEmbedType.langstudio] = SupportedEmbedType.langstudio
region_id: str | None = "cn-hangzhou" # use default
connection_name: str | None = None
workspace_id: str | None = None
model: str | None = None


SupporttedEmbeddingClsMap = {
cls.get_type(): cls for cls in PaiBaseEmbeddingConfig.get_subclasses()
}
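
Because the new class subclasses PaiBaseEmbeddingConfig, it is picked up by SupporttedEmbeddingClsMap automatically. A quick sketch with placeholder values (assuming the inherited base-config fields all have defaults):

from pai_rag.integrations.embeddings.pai.pai_embedding_config import (
    LangStudioEmbeddingConfig,
    SupportedEmbedType,
)

config = LangStudioEmbeddingConfig(
    connection_name="my-embedding-connection",  # placeholder
    workspace_id="123456",                      # placeholder
)
assert config.source == SupportedEmbedType.langstudio
assert config.region_id == "cn-hangzhou"  # default declared on the class
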
1 change: 1 addition & 0 deletions src/pai_rag/tools/data_process/docker/Dockerfile_cpu
@@ -11,6 +11,7 @@ COPY . .

RUN poetry install && rm -rf $POETRY_CACHE_DIR
RUN poetry run aliyun-bootstrap -a install
RUN pip3 install https://sdk-portal-us-prod.oss-accelerate.aliyuncs.com/downloads/u-5fa6e81f-04cd-41d6-86ac-d8bffa4525e7-python-tea.zip
RUN pip3 install ray[default]

FROM python:3.11-slim AS prod
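
The extra pip3 install presumably pulls in the generated Alibaba Cloud Tea SDK package that provides the alibabacloud_pailangstudio20240710 client imported by langstudio_utils.py above; the download URL is used as-is from the PR.
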
4 changes: 4 additions & 0 deletions src/pai_rag/tools/data_process/ops/embed_op.py
@@ -30,6 +30,8 @@ def __init__(
enable_sparse: bool = False,
enable_multimodal: bool = False,
multimodal_source: str = None,
connection_name: str = None,
workspace_id: str = None,
*args,
**kwargs,
):
@@ -39,6 +41,8 @@
"source": source,
"model": model,
"enable_sparse": enable_sparse,
"connection_name": connection_name,
"workspace_id": workspace_id,
}
)
# Init model download list
14 changes: 14 additions & 0 deletions src/pai_rag/tools/data_process/run.py
@@ -105,6 +105,8 @@ def process_embedder(args):
"enable_sparse",
"enable_multimodal",
"multimodal_source",
"connection_name",
"workspace_id",
]
}
args.process.append("rag_embedder")
@@ -283,6 +285,18 @@ def init_configs():
default="cnclip",
help="Multi-modal embedding model source for rag_embedder operator.",
)
parser.add_argument(
"--connection_name",
type=str,
default=None,
help="Langstudio connection for rag_embedder operator.",
)
parser.add_argument(
"--workspace_id",
type=str,
default=None,
help="PAI workspace id for rag_embedder operator.",
)

args = parser.parse_args()

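With these flags in place, a LangStudio-backed embedder run only needs the connection name and the PAI workspace id on the command line; process_embedder forwards them (together with the other keys listed above) into the rag_embedder operator. A small illustrative sketch of the flag parsing — the declarations mirror the diff, the invocation and values are made up:

import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--connection_name", type=str, default=None,
                    help="Langstudio connection for rag_embedder operator.")
parser.add_argument("--workspace_id", type=str, default=None,
                    help="PAI workspace id for rag_embedder operator.")

# e.g. run.py ... --connection_name my-embedding-connection --workspace_id 123456
args = parser.parse_args(
    ["--connection_name", "my-embedding-connection", "--workspace_id", "123456"]
)
assert args.connection_name == "my-embedding-connection"
assert args.workspace_id == "123456"
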
2 changes: 2 additions & 0 deletions src/pai_rag/utils/constants.py
@@ -21,3 +21,5 @@
)

DEFAULT_DATAFILE_DIR = "./data"

DEFAULT_DASHSCOPE_EMBEDDING_MODEL = "text-embedding-v2"