diff --git a/src/pai_rag/integrations/embeddings/pai/embedding_utils.py b/src/pai_rag/integrations/embeddings/pai/embedding_utils.py
index 89cf4e9e..f0774c64 100644
--- a/src/pai_rag/integrations/embeddings/pai/embedding_utils.py
+++ b/src/pai_rag/integrations/embeddings/pai/embedding_utils.py
@@ -1,19 +1,19 @@
+import os
+from llama_index.core import Settings
+from llama_index.embeddings.openai import OpenAIEmbedding
+from llama_index.embeddings.dashscope import DashScopeEmbedding
+from llama_index.embeddings.huggingface import HuggingFaceEmbedding
+from pai_rag.utils.download_models import ModelScopeDownloader
 from pai_rag.integrations.embeddings.pai.pai_embedding_config import (
     PaiBaseEmbeddingConfig,
     DashScopeEmbeddingConfig,
     OpenAIEmbeddingConfig,
     HuggingFaceEmbeddingConfig,
     CnClipEmbeddingConfig,
+    LangStudioEmbeddingConfig,
 )
-
-from llama_index.core import Settings
-from llama_index.embeddings.openai import OpenAIEmbedding
-from llama_index.embeddings.dashscope import DashScopeEmbedding
-from llama_index.embeddings.huggingface import HuggingFaceEmbedding
 from pai_rag.integrations.embeddings.clip.cnclip_embedding import CnClipEmbedding
-import os
 from loguru import logger
-from pai_rag.utils.download_models import ModelScopeDownloader
 
 
 def create_embedding(
@@ -22,6 +22,7 @@ def create_embedding(
     if isinstance(embed_config, OpenAIEmbeddingConfig):
         embed_model = OpenAIEmbedding(
             api_key=embed_config.api_key,
+            api_base=embed_config.api_base,
             embed_batch_size=embed_config.embed_batch_size,
             callback_manager=Settings.callback_manager,
         )
@@ -91,7 +92,16 @@ def create_embedding(
         logger.info(
             f"Initialized CnClip embedding model {embed_config.model} with {embed_config.embed_batch_size} batch size."
         )
+    elif isinstance(embed_config, LangStudioEmbeddingConfig):
+        from pai_rag.integrations.embeddings.pai.langstudio_utils import (
+            convert_langstudio_embed_config,
+        )
+        converted_embed_config = convert_langstudio_embed_config(embed_config)
+        logger.info(
+            f"Initialized LangStudio embedding model with {converted_embed_config}."
+        )
+        return create_embedding(converted_embed_config, pai_rag_model_dir)
     else:
         raise ValueError(f"Unknown Embedding source: {embed_config}")
diff --git a/src/pai_rag/integrations/embeddings/pai/langstudio_utils.py b/src/pai_rag/integrations/embeddings/pai/langstudio_utils.py
new file mode 100644
index 00000000..d11efee8
--- /dev/null
+++ b/src/pai_rag/integrations/embeddings/pai/langstudio_utils.py
@@ -0,0 +1,100 @@
+import os
+from alibabacloud_credentials.client import Client as CredentialClient
+from alibabacloud_credentials.models import Config as CredentialConfig
+from alibabacloud_pailangstudio20240710.client import Client as LangStudioClient
+from alibabacloud_pailangstudio20240710.models import (
+    GetConnectionRequest,
+    ListConnectionsRequest,
+)
+from alibabacloud_tea_openapi import models as open_api_models
+from pai_rag.utils.constants import DEFAULT_DASHSCOPE_EMBEDDING_MODEL
+from pai_rag.integrations.embeddings.pai.pai_embedding_config import parse_embed_config
+from loguru import logger
+
+
+def get_region_id():
+    return next(
+        (
+            os.environ[key]
+            for key in ["REGION", "REGION_ID", "ALIBABA_CLOUD_REGION_ID"]
+            if key in os.environ and os.environ[key]
+        ),
+        "cn-hangzhou",
+    )
+
+
+def get_connection_info(region_id: str, connection_name: str, workspace_id: str):
+    """
+    Get connection information from the LangStudio API.
+    """
+    config1 = CredentialConfig(
+        type="access_key",
+        access_key_id=os.environ.get("ALIBABA_CLOUD_ACCESS_KEY_ID"),
+        access_key_secret=os.environ.get("ALIBABA_CLOUD_ACCESS_KEY_SECRET"),
+    )
+    public_endpoint = f"pailangstudio.{region_id}.aliyuncs.com"
+    client = LangStudioClient(
+        config=open_api_models.Config(
+            # Use default credential chain, see:
+            # https://help.aliyun.com/zh/sdk/developer-reference/v2-manage-python-access-credentials#3ca299f04bw3c
+            credential=CredentialClient(config=config1),
+            endpoint=public_endpoint,
+        )
+    )
+    resp = client.list_connections(
+        request=ListConnectionsRequest(
+            connection_name=connection_name, workspace_id=workspace_id, max_results=50
+        )
+    )
+    connection_info = next(
+        (
+            conn
+            for conn in resp.body.connections
+            if conn.connection_name == connection_name
+        ),
+        None,
+    )
+    if not connection_info:
+        raise ValueError(f"Connection {connection_name} not found")
+    ls_connection = client.get_connection(
+        connection_id=connection_info.connection_id,
+        request=GetConnectionRequest(
+            workspace_id=workspace_id,
+            encrypt_option="PlainText",
+        ),
+    )
+    conn_info = ls_connection.body
+    configs = conn_info.configs or {}
+    secrets = conn_info.secrets or {}
+
+    logger.info(f"Configs conn_info:\n {conn_info}")
+    return conn_info, configs, secrets
+
+
+def convert_langstudio_embed_config(embed_config):
+    region_id = embed_config.region_id or get_region_id()
+    conn_info, config, secrets = get_connection_info(
+        region_id, embed_config.connection_name, embed_config.workspace_id
+    )
+    if conn_info.custom_type == "OpenEmbeddingConnection":
+        return parse_embed_config(
+            {
+                "source": "openai",
+                "api_key": secrets.get("api_key", None),
+                "api_base": config.get("base_url", None),
+                "model": embed_config.model,
+                "embed_batch_size": embed_config.embed_batch_size,
+            }
+        )
+    elif conn_info.custom_type == "DashScopeConnection":
+        return parse_embed_config(
+            {
+                "source": "dashscope",
+                "api_key": secrets.get("api_key", None)
+                or os.getenv("DASHSCOPE_API_KEY"),
+                "model": embed_config.model or DEFAULT_DASHSCOPE_EMBEDDING_MODEL,
+                "embed_batch_size": embed_config.embed_batch_size,
+            }
+        )
+    else:
+        raise ValueError(f"Unknown connection type: {conn_info.custom_type}")
diff --git a/src/pai_rag/integrations/embeddings/pai/pai_embedding_config.py b/src/pai_rag/integrations/embeddings/pai/pai_embedding_config.py
index 6b4a5b1f..7640a43d 100644
--- a/src/pai_rag/integrations/embeddings/pai/pai_embedding_config.py
+++ b/src/pai_rag/integrations/embeddings/pai/pai_embedding_config.py
@@ -12,6 +12,7 @@ class SupportedEmbedType(str, Enum):
     openai = "openai"
     huggingface = "huggingface"
     cnclip = "cnclip"  # Chinese CLIP
+    langstudio = "langstudio"
 
 
 class PaiBaseEmbeddingConfig(BaseModel):
@@ -42,6 +43,7 @@ class OpenAIEmbeddingConfig(PaiBaseEmbeddingConfig):
     source: Literal[SupportedEmbedType.openai] = SupportedEmbedType.openai
     model: str | None = None  # use default
     api_key: str | None = None  # use default
+    api_base: str | None = None  # use default
 
 
 class HuggingFaceEmbeddingConfig(PaiBaseEmbeddingConfig):
@@ -54,6 +56,14 @@ class CnClipEmbeddingConfig(PaiBaseEmbeddingConfig):
     model: str | None = "ViT-L-14"
 
 
+class LangStudioEmbeddingConfig(PaiBaseEmbeddingConfig):
+    source: Literal[SupportedEmbedType.langstudio] = SupportedEmbedType.langstudio
+    region_id: str | None = "cn-hangzhou"  # use default
+    connection_name: str | None = None
+    workspace_id: str | None = None
+    model: str | None = None
+
+
 SupporttedEmbeddingClsMap = {
     cls.get_type(): cls for cls in PaiBaseEmbeddingConfig.get_subclasses()
 }
diff --git a/src/pai_rag/tools/data_process/docker/Dockerfile_cpu b/src/pai_rag/tools/data_process/docker/Dockerfile_cpu
index 1fd32266..11cc797c 100644
--- a/src/pai_rag/tools/data_process/docker/Dockerfile_cpu
+++ b/src/pai_rag/tools/data_process/docker/Dockerfile_cpu
@@ -11,6 +11,7 @@ COPY . .
 
 RUN poetry install && rm -rf $POETRY_CACHE_DIR
 RUN poetry run aliyun-bootstrap -a install
+RUN pip3 install https://sdk-portal-us-prod.oss-accelerate.aliyuncs.com/downloads/u-5fa6e81f-04cd-41d6-86ac-d8bffa4525e7-python-tea.zip
 RUN pip3 install ray[default]
 
 FROM python:3.11-slim AS prod
diff --git a/src/pai_rag/tools/data_process/ops/embed_op.py b/src/pai_rag/tools/data_process/ops/embed_op.py
index d3d55217..caa2773c 100644
--- a/src/pai_rag/tools/data_process/ops/embed_op.py
+++ b/src/pai_rag/tools/data_process/ops/embed_op.py
@@ -30,6 +30,8 @@ def __init__(
         enable_sparse: bool = False,
         enable_multimodal: bool = False,
         multimodal_source: str = None,
+        connection_name: str = None,
+        workspace_id: str = None,
         *args,
         **kwargs,
     ):
@@ -39,6 +41,8 @@ def __init__(
                 "source": source,
                 "model": model,
                 "enable_sparse": enable_sparse,
+                "connection_name": connection_name,
+                "workspace_id": workspace_id,
             }
         )
         # Init model download list
diff --git a/src/pai_rag/tools/data_process/run.py b/src/pai_rag/tools/data_process/run.py
index 042c299d..7989331a 100644
--- a/src/pai_rag/tools/data_process/run.py
+++ b/src/pai_rag/tools/data_process/run.py
@@ -105,6 +105,8 @@ def process_embedder(args):
             "enable_sparse",
             "enable_multimodal",
             "multimodal_source",
+            "connection_name",
+            "workspace_id",
         ]
     }
     args.process.append("rag_embedder")
@@ -283,6 +285,18 @@ def init_configs():
         default="cnclip",
         help="Multi-modal embedding model source for rag_embedder operator.",
    )
+    parser.add_argument(
+        "--connection_name",
+        type=str,
+        default=None,
+        help="LangStudio connection name for rag_embedder operator.",
+    )
+    parser.add_argument(
+        "--workspace_id",
+        type=str,
+        default=None,
+        help="PAI workspace id for rag_embedder operator.",
+    )
 
     args = parser.parse_args()
diff --git a/src/pai_rag/utils/constants.py b/src/pai_rag/utils/constants.py
index 584ac691..daf92f6f 100644
--- a/src/pai_rag/utils/constants.py
+++ b/src/pai_rag/utils/constants.py
@@ -21,3 +21,5 @@
 )
 
 DEFAULT_DATAFILE_DIR = "./data"
+
+DEFAULT_DASHSCOPE_EMBEDDING_MODEL = "text-embedding-v2"
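
Usage sketch (not part of the patch): the snippet below shows how the new "langstudio" embedding source introduced above could be configured end to end. It is a minimal sketch under several assumptions: the connection name, workspace id, and model value are hypothetical placeholders; ALIBABA_CLOUD_ACCESS_KEY_ID / ALIBABA_CLOUD_ACCESS_KEY_SECRET must be set so langstudio_utils can resolve the LangStudio connection; and the pai_rag_model_dir keyword is taken from the recursive call in the patch and is assumed to be safe to leave as None for API-backed (OpenAI/DashScope) connections.

```python
from pai_rag.integrations.embeddings.pai.pai_embedding_config import parse_embed_config
from pai_rag.integrations.embeddings.pai.embedding_utils import create_embedding

# Hypothetical values: replace with a real LangStudio connection and PAI workspace id.
embed_config = parse_embed_config(
    {
        "source": "langstudio",
        "region_id": "cn-hangzhou",  # defaults to "cn-hangzhou" if omitted
        "connection_name": "my-embedding-connection",
        "workspace_id": "123456",
        "model": "text-embedding-v2",  # optional for DashScope connections (defaults to text-embedding-v2)
    }
)

# create_embedding detects LangStudioEmbeddingConfig, resolves the connection via
# convert_langstudio_embed_config into an OpenAI- or DashScope-backed config, and
# recursively builds the concrete embedding model.
embed_model = create_embedding(embed_config, pai_rag_model_dir=None)
```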