-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathcustom_number_extractor.py
75 lines (61 loc) · 2.4 KB
/
custom_number_extractor.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
import re
from typing import Any, Dict, List, Text, Type
from rasa.engine.graph import ExecutionContext, GraphComponent
from rasa.engine.recipes.default_recipe import DefaultV1Recipe
from rasa.engine.storage.resource import Resource
from rasa.engine.storage.storage import ModelStorage
from rasa.shared.nlu.constants import ENTITIES, TEXT
from rasa.nlu.extractors.extractor import EntityExtractorMixin
from rasa.shared.nlu.training_data.message import Message
@DefaultV1Recipe.register(
    DefaultV1Recipe.ComponentType.ENTITY_EXTRACTOR,
    is_trainable=False,
)
class CustomNumberExtractor(GraphComponent, EntityExtractorMixin):
    """Entity extractor which uses a regular expression to find numbers.

    Every match of the configured ``number_pattern`` in a message's text is
    appended to the message's ``entities`` as an entity of type ``"number"``.
    The matched substring is stored verbatim (as a string) under ``value``,
    together with its ``start``/``end`` character offsets.
    """

    @staticmethod
    def get_default_config() -> Dict[Text, Any]:
        """Return the component's default config.

        ``number_pattern`` defaults to runs of digits delimited by word
        boundaries.
        """
        return {"number_pattern": r"\b\d+\b"}

    def __init__(self, config: Dict[Text, Any]) -> None:
        """Initialize CustomNumberExtractor.

        Args:
            config: Component config. If ``number_pattern`` is missing, the
                pattern from ``get_default_config`` is used.

        Raises:
            re.error: If ``number_pattern`` is not a valid regular expression.
        """
        self._config = config
        # Compile once instead of on every `process` call; this also surfaces
        # an invalid pattern when the component is built rather than on the
        # first processed batch.
        pattern = config.get(
            "number_pattern", self.get_default_config()["number_pattern"]
        )
        self._number_pattern = re.compile(pattern)

    @classmethod
    def create(
        cls,
        config: Dict[Text, Any],
        model_storage: ModelStorage,
        resource: Resource,
        execution_context: ExecutionContext,
    ) -> GraphComponent:
        """Create a new component from its config.

        ``model_storage``, ``resource`` and ``execution_context`` are part of
        the graph-component interface but are not used here.
        """
        return cls(config)

    def process(self, messages: List[Message]) -> List[Message]:
        """Extract numbers using a regular expression.

        Matches are appended to any entities already present on each message,
        and the entities are marked for output.

        Args:
            messages: List of messages to process; they are updated in place.

        Returns:
            The same list of messages.
        """
        for message in messages:
            text = message.get(TEXT)
            if text is None:
                # Nothing to match against; `finditer(None)` would raise
                # TypeError and abort the whole batch.
                continue
            extracted_entities = [
                {
                    "entity": "number",
                    "value": match.group(),
                    "start": match.start(),
                    "confidence": None,
                    "end": match.end(),
                    "extractor": "CustomNumberExtractor",
                }
                for match in self._number_pattern.finditer(text)
            ]
            message.set(
                ENTITIES,
                message.get(ENTITIES, []) + extracted_entities,
                add_to_output=True,
            )
        return messages