
Commit

Bugfix for rag_eval (#350)
wwxxzz authored Jan 23, 2025
1 parent fe822cc commit ac520d4
Showing 5 changed files with 33 additions and 34 deletions.
8 changes: 5 additions & 3 deletions docs/qca_generation_and_evaluation.md
@@ -17,21 +17,23 @@ A RAG evaluation tool is used for testing and evaluating retrieval-based text generation systems
 1. Example configuration:
 
    ```yaml
-   - name: "exp1"
+   - name: "text_exp1"
      eval_data_path: "example_data/eval_docs_text"
+     rag_setting_file: "src/pai_rag/config/evaluation/settings_eval_for_text.toml"
      eval_model_llm:
        source: "dashscope"
        model: "qwen-max"
        max_tokens: 1024
-     rag_setting_file: "src/pai_rag/config/evaluation/settings_eval_for_text.toml"
      use_pai_eval: true
    ```
 2. Parameter descriptions:
    - name: name of the evaluation experiment.
    - eval_data_path: path to the evaluation dataset; local file paths and OSS paths are supported.
-   - eval_model_llm: configuration of the LLM used for evaluation; dashscope, openai, paieas, etc. are supported.
    - rag_setting_file: path to the RAG settings file.
+   - eval_model_llm: configuration of the LLM used for evaluation; dashscope, openai, paieas, etc. are supported.
    - use_pai_eval: whether to evaluate with pai_llm_evals; defaults to true. If false, local evaluation is used.
 3. Evaluation dimensions:
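The experiment list documented above is plain YAML, so it can be loaded and sanity-checked with standard tooling before a run. Below is a minimal sketch, assuming PyYAML and assuming the file's top level is the experiment list shown in the docs; the file name and the required-key check are illustrative, not part of this commit or of pai_rag's own loader.

```python
# Hypothetical helper: load and sanity-check an evaluation experiment list
# shaped like the YAML documented above. File name and required keys are
# assumptions for illustration only.
import yaml

REQUIRED_KEYS = {"name", "eval_data_path", "rag_setting_file", "eval_model_llm"}


def load_experiments(path: str = "exp_config.yaml") -> list[dict]:
    with open(path, encoding="utf-8") as f:
        experiments = yaml.safe_load(f)
    for exp in experiments:
        missing = REQUIRED_KEYS - exp.keys()
        if missing:
            raise ValueError(f"experiment {exp.get('name', '?')} is missing: {sorted(missing)}")
        # per the docs above, use_pai_eval defaults to true when omitted
        exp.setdefault("use_pai_eval", True)
    return experiments
```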
@@ -43,8 +43,8 @@ vector_store.type = "FAISS"
 # endpoint = ""
 # token = ""
 [rag.llm]
-source = "OpenAI"
-model = "gpt-4o-2024-08-06"
+source = "DashScope"
+model = "qwen-max"
 
 [rag.multimodal_embedding]
 source = "cnclip"
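For reference, the `[rag.llm]` table changed above is ordinary TOML, so the evaluation settings can be inspected with the standard library. A minimal sketch, assuming Python 3.11+ (`tomllib`) and the settings path given in the docs; this is illustrative and not pai_rag's own config loader.

```python
# Inspect the [rag.llm] section of an evaluation settings file.
# The path comes from the docs above; the parsing here is illustrative only.
import tomllib

with open("src/pai_rag/config/evaluation/settings_eval_for_text.toml", "rb") as f:
    settings = tomllib.load(f)

llm_cfg = settings["rag"]["llm"]
print(llm_cfg["source"], llm_cfg["model"])  # e.g. "DashScope" "qwen-max" after this change
```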
19 changes: 9 additions & 10 deletions src/pai_rag/evaluation/dataset/rag_eval_dataset.py
@@ -1,4 +1,4 @@
-from typing import List, Optional, Type, Dict
+from typing import List, Optional, Dict
 from llama_index.core.bridge.pydantic import Field
 import json
 from llama_index.core.bridge.pydantic import BaseModel
@@ -11,31 +11,31 @@ class EvaluationSample(RagQcaSample):
     """Response Evaluation RAG example class."""
 
     hitrate: Optional[float] = Field(
-        default_factory=None,
+        default=None,
         description="The hitrate value for retrieval evaluation.",
     )
     mrr: Optional[float] = Field(
-        default_factory=None,
+        default=None,
         description="The mrr value for retrieval evaluation.",
     )
 
     faithfulness_score: Optional[float] = Field(
-        default_factory=None,
+        default=None,
         description="The faithfulness score for response evaluation.",
     )
 
     faithfulness_reason: Optional[str] = Field(
-        default_factory=None,
+        default=None,
         description="The faithfulness reason for response evaluation.",
     )
 
     correctness_score: Optional[float] = Field(
-        default_factory=None,
+        default=None,
         description="The correctness score for response evaluation.",
     )
 
     correctness_reason: Optional[str] = Field(
-        default_factory=None,
+        default=None,
         description="The correctness reason for response evaluation.",
     )
     evaluated_by: Optional[CreatedBy] = Field(
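The change running through this class is the replacement of `default_factory=None` with `default=None`. In Pydantic, `default_factory` is expected to be a zero-argument callable that builds a default; passing `None` is treated as "no factory", so the field ends up with no default at all and, under Pydantic v2, becomes required. A minimal standalone sketch of the difference (illustrative example, not pai_rag code):

```python
# Standalone illustration of default_factory=None vs default=None (Pydantic v2).
from typing import Optional

from pydantic import BaseModel, Field, ValidationError


class Broken(BaseModel):
    # default_factory=None is treated as "no factory", so this field is required
    hitrate: Optional[float] = Field(default_factory=None)


class Fixed(BaseModel):
    hitrate: Optional[float] = Field(default=None)


try:
    Broken()
except ValidationError as exc:
    print(exc)  # reports that "hitrate" is missing

print(Fixed())  # hitrate=None, as intended
```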
@@ -49,7 +49,6 @@ def class_name(self) -> str:
 
 
 class PaiRagEvalDataset(BaseModel):
-    _example_type: Type[EvaluationSample] = EvaluationSample  # type: ignore[misc]
     examples: List[EvaluationSample] = Field(
         default=[], description="Data examples of this dataset."
     )
@@ -93,7 +92,7 @@ def save_json(self, path: str) -> None:
         self.cal_mean_metric_score()
 
         with open(path, "w", encoding="utf-8") as f:
-            examples = [self._example_type.dict(el) for el in self.examples]
+            examples = [el.model_dump() for el in self.examples]
             data = {
                 "examples": examples,
                 "results": self.results,
@@ -109,7 +108,7 @@ def from_json(cls, path: str) -> "PaiRagEvalDataset":
         with open(path) as f:
             data = json.load(f)
 
-        examples = [cls._example_type.parse_obj(el) for el in data["examples"]]
+        examples = [EvaluationSample.model_validate(el) for el in data["examples"]]
         results = data["results"]
         status = data["status"]
 
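The `save_json`/`from_json` changes drop the `_example_type` class attribute and move from the Pydantic v1 helpers (`.dict()` invoked through the class, `parse_obj`) to the v2 methods `model_dump()` and `model_validate()`. A minimal round-trip sketch of the pattern the dataset now follows, using a simplified stand-in model rather than the pai_rag classes:

```python
# Standalone sketch of a JSON round trip with Pydantic v2's model_dump / model_validate.
import json
from typing import Optional

from pydantic import BaseModel, Field


class Sample(BaseModel):
    query: str
    hitrate: Optional[float] = Field(default=None)


samples = [Sample(query="q1", hitrate=1.0), Sample(query="q2")]

# serialize: one plain dict per example, then dump the whole payload
payload = {"examples": [s.model_dump() for s in samples]}
text = json.dumps(payload, ensure_ascii=False)

# deserialize: validate each dict back into the model class
restored = [Sample.model_validate(el) for el in json.loads(text)["examples"]]
assert restored == samples
```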
31 changes: 14 additions & 17 deletions src/pai_rag/evaluation/dataset/rag_qca_dataset.py
@@ -1,4 +1,4 @@
-from typing import List, Optional, Type
+from typing import List, Optional
 from llama_index.core.bridge.pydantic import Field
 from llama_index.core.llama_dataset.base import BaseLlamaDataExample
 from llama_index.core.llama_dataset import CreatedBy
@@ -13,49 +13,47 @@ class RagQcaSample(BaseLlamaDataExample):
     to evaluate the prediction.
     """
 
-    query: str = Field(
-        default_factory=str, description="The user query for the example."
-    )
+    query: str = Field(default=str, description="The user query for the example.")
     query_by: Optional[CreatedBy] = Field(
         default=None, description="What generated the query."
     )
     reference_contexts: Optional[List[str]] = Field(
-        default_factory=None,
+        default=None,
         description="The contexts used to generate the reference answer.",
     )
     reference_node_ids: Optional[List[str]] = Field(
-        default_factory=None, description="The node id corresponding to the contexts"
+        default=None, description="The node id corresponding to the contexts"
     )
     reference_image_url_list: Optional[List[str]] = Field(
-        default_factory=None,
+        default=None,
         description="The image urls used to generate the reference answer.",
     )
     reference_answer: str = Field(
-        default_factory=str,
+        default=str,
         description="The reference (ground-truth) answer to the example.",
     )
     reference_answer_by: Optional[CreatedBy] = Field(
         default=None, description="What model generated the reference answer."
     )
 
     predicted_contexts: Optional[List[str]] = Field(
-        default_factory=None,
+        default=None,
         description="The contexts used to generate the predicted answer.",
     )
     predicted_node_ids: Optional[List[str]] = Field(
-        default_factory=None,
+        default=None,
         description="The node id corresponding to the predicted contexts",
     )
     predicted_node_scores: Optional[List[float]] = Field(
-        default_factory=None,
+        default=None,
         description="The node score corresponding to the predicted contexts",
     )
     predicted_image_url_list: Optional[List[str]] = Field(
-        default_factory=None,
+        default=None,
         description="The image urls used to generate the reference answer.",
     )
     predicted_answer: str = Field(
-        default_factory=str,
+        default="",
         description="The predicted answer to the example.",
     )
     predicted_answer_by: Optional[CreatedBy] = Field(
@@ -69,9 +67,8 @@ def class_name(self) -> str:
 
 
 class PaiRagQcaDataset(BaseModel):
-    _example_type: Type[RagQcaSample] = RagQcaSample  # type: ignore[misc]
     examples: List[RagQcaSample] = Field(
-        default=[], description="Data examples of this dataset."
+        default_factory=list, description="Data examples of this dataset."
     )
     labelled: bool = Field(
         default=False, description="Whether the dataset is labelled or not."
@@ -88,7 +85,7 @@ def class_name(self) -> str:
     def save_json(self, path: str) -> None:
         """Save json."""
         with open(path, "w", encoding="utf-8") as f:
-            examples = [self._example_type.dict(el) for el in self.examples]
+            examples = [el.model_dump() for el in self.examples]
             data = {
                 "examples": examples,
                 "labelled": self.labelled,
@@ -105,7 +102,7 @@ def from_json(cls, path: str) -> "PaiRagQcaDataset":
             data = json.load(f)
 
         if len(data["examples"]) > 0:
-            examples = [cls._example_type.parse_obj(el) for el in data["examples"]]
+            examples = [RagQcaSample.model_validate(el) for el in data["examples"]]
             labelled = data["labelled"]
             predicted = data["predicted"]
 
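Besides the same `default_factory=None` → `default=None` substitutions, this file also changes the `examples` field from `default=[]` to `default_factory=list`. A factory builds a brand-new list for every instance, so the dataset never depends on how the library copies (or does not copy) a shared default object, and it is the idiomatic spelling for mutable defaults. A standalone sketch (not pai_rag code):

```python
# Standalone sketch: default_factory constructs a fresh mutable default per instance.
from typing import List

from pydantic import BaseModel, Field


class Dataset(BaseModel):
    examples: List[str] = Field(default_factory=list)


a = Dataset()
b = Dataset()
a.examples.append("only in a")
print(a.examples, b.examples)  # ['only in a'] [] -- each instance owns its own list
```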
5 changes: 3 additions & 2 deletions src/pai_rag/evaluation/generator/rag_qca_generator.py
@@ -26,6 +26,7 @@
     DEFAULT_TEXT_QA_PROMPT_TMPL,
     DEFAULT_MULTI_MODAL_IMAGE_QA_PROMPT_TMPL,
 )
+import asyncio
 
 
 class RagQcaGenerator:
@@ -213,7 +214,7 @@ async def agenerate_labelled_qca_dataset(
     async def agenerate_predicted_multimodal_qca_sample(self, qca_sample):
         query_bundle = PaiQueryBundle(query_str=qca_sample.query)
         response = await self._query_engine.aquery(query_bundle)
-
+        await asyncio.sleep(3)
         qca_sample.predicted_answer = response.response
         predicted_contexts = []
         predicted_node_ids = []
@@ -247,7 +248,7 @@ async def agenerate_predicted_multimodal_qca_sample(self, qca_sample):
     async def agenerate_predicted_qca_sample(self, qca_sample):
         query_bundle = PaiQueryBundle(query_str=qca_sample.query)
         response = await self._query_engine.aquery(query_bundle)
-
+        await asyncio.sleep(3)
         qca_sample.predicted_answer = response.response
         qca_sample.predicted_contexts = [
             node.node.text for node in response.source_nodes
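The generator now pauses for a fixed three seconds after each `aquery` call, presumably to space out successive LLM requests and stay under provider rate limits when many samples are evaluated. A standalone sketch of the pattern (the query coroutine is a stand-in; the three-second delay is the value used in this commit):

```python
# Standalone sketch: throttle concurrent async calls with a fixed pause after each query.
import asyncio


async def fake_query(q: str) -> str:
    # stand-in for self._query_engine.aquery(...); illustrative only
    await asyncio.sleep(0.1)
    return f"answer to {q}"


async def predict_one(q: str) -> str:
    answer = await fake_query(q)
    await asyncio.sleep(3)  # same fixed pause the commit inserts after each query
    return answer


async def main() -> None:
    answers = await asyncio.gather(*(predict_one(q) for q in ["q1", "q2", "q3"]))
    print(answers)


asyncio.run(main())
```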
