
archaeo_super_prompt.modeling.entity_extractor.model

"""Core functions for inferring and filtering named entities in chunks."""

import itertools
from typing import cast

import requests
from tqdm import tqdm

from ...config.env import getenv_or_throw
from .types import CompleteEntity, NerOutput, NerXXLEntities


def _fetch_entities(ner_model_hosturl: str, chunks: list[str]) -> list[list[NerOutput]]:
    if not chunks:
        return []
    print("Fetching the transformers model")
    payload = {"chunks": chunks}
    response = requests.post(f"{ner_model_hosturl}/ner", json=payload,
                             timeout=60)
    response.raise_for_status()
    # Parse the JSON response into one list of NerOutput per input chunk
    return [
        [NerOutput(**dct) for dct in lst]
        for lst in cast(list[list[dict]], response.json())
    ]
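
# Illustrative sketch of the HTTP contract assumed by _fetch_entities
# (inferred from the code above, not from documented service behavior): the
# host behind NER_MODEL_HOST_URL exposes a POST /ner endpoint taking
# {"chunks": [...]} and returning, for each chunk, a list of
# NerOutput-shaped dicts. For example:
#
#   POST {NER_MODEL_HOST_URL}/ner
#   body:     {"chunks": ["Anna lives in Rome."]}
#   response: [[{"entity": "B-PER", "word": "Anna", "score": 0.99,
#                "start": 0, "end": 4}, ...]]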


def fetch_entities(chunks: list[str]) -> list[list[NerOutput]]:
    """Query the remote NER model to find the named entities in each chunk."""
    ner_model_hosturl = getenv_or_throw("NER_MODEL_HOST_URL")
    return list(
        itertools.chain.from_iterable(
            _fetch_entities(ner_model_hosturl, list(batch))
            for batch in tqdm(
                itertools.batched(chunks, 50),
                desc="NER analysing",
                unit="batch",
                # number of batches: ceil(len(chunks) / 50)
                total=(len(chunks) + 49) // 50,
            )
        )
    )
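
# Hypothetical usage sketch for fetch_entities (assumes NER_MODEL_HOST_URL is
# set and the remote service is reachable). The per-batch results are chained
# back into a single list of NerOutput lists, aligned one-to-one with the
# input chunks:
#
#   chunks = ["Anna lives in Rome.", "The dig started in 1998."]
#   raw_entities = fetch_entities(chunks)
#   assert len(raw_entities) == len(chunks)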


def gather_entity_chunks(
    entity_chunks: list[NerOutput], confidence_threshold: float
) -> list[CompleteEntity]:
    """Merge the raw NER output chunks of one text chunk into complete entities."""
    entity_set: list[CompleteEntity] = list()
    current_accumulated_entity: CompleteEntity | None = None
    for current_entity_chunk in entity_chunks:
        # Edge case: when a chunk is under the confidence threshold,
        # we only keep the already appended confident chunks of the entity
        # and ignore the following chunks
        if current_entity_chunk.score < confidence_threshold:
            if current_accumulated_entity is not None:
                entity_set.append(current_accumulated_entity)
                current_accumulated_entity = None
            continue

        if current_entity_chunk.entity.startswith("B-"):
            # Start a new entity with B- entities
            if current_accumulated_entity is not None:
                entity_set.append(current_accumulated_entity)
            current_accumulated_entity = CompleteEntity(
                entity=cast(NerXXLEntities, current_entity_chunk.entity[2:]),
                word=current_entity_chunk.word,
                start=current_entity_chunk.start,
                end=current_entity_chunk.end,
            )
        elif (
            current_accumulated_entity is not None
            # only merge chunks of the same entity type (the comment intent of
            # the original merge rule) that are consecutive or separated by
            # at most one space
            # WARN: it is expected that the text processed by the NER model
            # is normalized so words are separated by at most one space
            and current_entity_chunk.entity[2:]
            == current_accumulated_entity.entity
            and abs(
                current_entity_chunk.start - current_accumulated_entity.end
            )
            <= 1
        ):
            current_accumulated_entity.end = current_entity_chunk.end
            # Complete an entity with its additional chunks
            if current_entity_chunk.word.startswith("##"):
                # the chunk belongs to the same entity word
                current_accumulated_entity.word += current_entity_chunk.word[
                    2:
                ]
            else:
                # the entity is composed of several words
                current_accumulated_entity.word += (
                    " " + current_entity_chunk.word
                )
    # Flush the last accumulated entity once all chunks are consumed;
    # without this, the final entity of the chunk would be dropped
    if current_accumulated_entity is not None:
        entity_set.append(current_accumulated_entity)
    return entity_set
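
# Worked example for gather_entity_chunks (scores and offsets are made up for
# illustration; "PER" stands in for any entity type in NerXXLEntities):
#
#   raw = [
#       NerOutput(entity="B-PER", word="An", score=0.99, start=0, end=2),
#       NerOutput(entity="I-PER", word="##na", score=0.98, start=2, end=4),
#       NerOutput(entity="I-PER", word="Rossi", score=0.97, start=5, end=10),
#   ]
#   gather_entity_chunks(raw, confidence_threshold=0.5)
#   # -> [CompleteEntity(entity="PER", word="Anna Rossi", start=0, end=10)]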


def postprocess_entities(
    entities_per_text_chunk: list[list[NerOutput]], confidence_threshold: float
) -> list[list[CompleteEntity]]:
    """Return the set of entities that occurred in each chunk.

    Arguments:
        entities_per_text_chunk: for each chunk, a list of its retrieved \
entities ordered by their occurrence in the chunk's text content
        confidence_threshold: a threshold between 0 and 1 to keep only \
entities inferred with sufficient confidence
    """
    return [
        gather_entity_chunks(entity_chunks, confidence_threshold)
        for entity_chunks in entities_per_text_chunk
    ]


def filter_entities(
    complete_entity_sets: list[list[CompleteEntity]],
    allowed_entities: set[NerXXLEntities],
) -> list[list[CompleteEntity]]:
    """For each text chunk, keep only the entities whose type is in the given set of allowed entity types."""
    # each inner list is conceptually a set of entities for one chunk
    return [
        [e for e in s if e.entity in allowed_entities]
        for s in complete_entity_sets
    ]
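

# End-to-end usage sketch tying the module together (the entity type names
# below are assumptions for illustration, not guaranteed members of
# NerXXLEntities):
#
#   raw = fetch_entities(text_chunks)
#   merged = postprocess_entities(raw, confidence_threshold=0.8)
#   kept = filter_entities(merged, allowed_entities={"PER", "LOC"})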