Skip to content

archaeo_super_prompt.modeling.other_dag

[docs] module archaeo_super_prompt.modeling.other_dag

  1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
"""DAGs to train the FieldExtractor models."""

from typing import NamedTuple

from archaeo_super_prompt.modeling.struct_extract.field_extractor import (
    FieldExtractor,
)

from ..dataset.thesauri import load_comune
from .DAG_builder import DAGBuilder, DAGComponent
from .entity_extractor import NerModel, NeSelector
from .pdf_to_text import VLLM_Preprocessing
from .struct_extract.chunks_to_text import ChunksToText
from .struct_extract.extractors.archiving_date import ArchivingDateProvider
from .struct_extract.extractors.comune import ComuneExtractor
from .struct_extract.extractors.intervention_date import (
    InterventionStartExtractor,
)


class ExtractionDAGParts(NamedTuple):
    """A decomposition of the general DAG into separate parts.

    Splitting the DAG this way allows the training, inference and
    evaluation modes to handle each part differently (e.g. train the
    field extractors in isolation, then re-bind them for inference).
    """

    # Builder for the shared pre-processing sub-DAG (presumably the
    # PDF-reading / NER stages that every extractor depends on —
    # TODO confirm against the code that constructs this tuple).
    preprocessing_root: DAGBuilder
    # One entry per field extractor: the extractor component paired with
    # the upstream component it must be attached to in the full DAG.
    extraction_parts: list[tuple[DAGComponent[FieldExtractor], DAGComponent]]
    # The terminal union component together with the list of components
    # that feed into it, used to finish building the DAG in inference mode.
    final_component: tuple[DAGComponent, list[DAGComponent]]


def get_advanced_pipeline() -> DAGBuilder:
    """Return the most advanced pre-processing DAG for the model.

    All estimators and transformers are initialized with particular
    parameters (VLM/LLM model ids, sampling temperature, named-entity
    chunk filters).

    Returns:
        The fully wired :class:`DAGBuilder`: a vision-LM reader feeding a
        NER stage, per-field chunk filter/merger chains, the LLM-backed
        field extractors, and a final "FINAL" passthrough node that joins
        every extractor's output.
    """
    # Shared configuration for every LLM-backed field extractor below.
    llm_model_id = "google/gemma-3-27b-it"
    llm_provider = "vllm"
    llm_model_temp = 0.05

    # Vision-LM OCR stage: turns (the incipit of) a scanned PDF into
    # markdown-ish text chunks with embeddings.
    vllm = DAGComponent(
        "vision-lm-Reader",
        VLLM_Preprocessing(
            vlm_provider="vllm",
            vlm_model_id="ibm-granite/granite-vision-3.3-2b",
            incipit_only=True,
            prompt="OCR this part of Italian document for markdown-based processing.",
            embedding_model_hf_id="nomic-ai/nomic-embed-text-v1.5",
        ),
    )
    ner = DAGComponent("NER-Extractor", NerModel())
    # Passthrough node exposing both the raw text and the NER output to
    # downstream chunk filters.
    ner_featured = DAGComponent("ner-featured", "passthrough")
    archiving_date = DAGComponent(
        "archiving-date-Oracle", ArchivingDateProvider()
    )
    # Keep only chunks tagged DATA, boosted with season keywords, for the
    # intervention-start date extraction.
    intervention_date_chunk_filter = DAGComponent(
        "interv-start-CF",
        NeSelector(
            "data",
            {
                "DATA",
            },
            lambda: list(
                enumerate(
                    [
                        "primavera",
                        "estate",
                        "autunno",
                        "inverno",
                    ]
                )
            ),
            True,
        ),
    )
    intervention_date_chunk_merger = DAGComponent(
        "interv-start-CM", ChunksToText()
    )
    intervention_date_extractor = DAGComponent(
        "interv-start-Extractor",
        InterventionStartExtractor(llm_provider, llm_model_id, llm_model_temp),
    )
    comune_extractor = DAGComponent(
        "comune-Extractor",
        ComuneExtractor(llm_provider, llm_model_id, llm_model_temp),
    )
    # Keep only location-like chunks (addresses, postal codes, places),
    # matched against the comune thesaurus.
    comune_chunk_filter = DAGComponent(
        "comune-CF",
        NeSelector(
            "comune",
            {
                "INDIRIZZO",
                "CODICE_POSTALE",
                "LUOGO",
            },
            load_comune,
        ),
    )
    comune_chunk_merger = DAGComponent("comune-CM", ChunksToText())

    # Passthrough joining the merged date chunks with the archiving date
    # so the extractor sees both.
    intervention_date_entrypoint = DAGComponent(
        "interv-start-entrypoint", "passthrough"
    )
    final_results = DAGComponent[FieldExtractor]("FINAL", "passthrough")

    # NOTE(review): fonte_informazione and functionary both reuse
    # ArchivingDateProvider(); looks like placeholders for dedicated
    # extractors — confirm this is intentional.
    fi_entrypoint = DAGComponent("fonte-informaz-entrypoint", "passthrough")
    fonte_informazione = DAGComponent(
        "fonte-informaz-Deductor", ArchivingDateProvider()
    )
    functionary_selector = DAGComponent(
        "functionary-CF",
        NeSelector("functionary", {"NOME", "COGNOME"}, lambda: [], True),
    )
    functionary_merger = DAGComponent("functionary-CM", ChunksToText())
    functionary_entrypoint = DAGComponent(
        "functionary-entrypoint", "passthrough"
    )
    functionary = DAGComponent(
        "functionary-Extractor", ArchivingDateProvider()
    )

    # Wire everything together; the edge lists name each node's parents.
    return (
        DAGBuilder()
        .add_node(vllm)
        .add_node(ner, [vllm])
        .add_node(ner_featured, [vllm, ner])
        .add_node(archiving_date, [vllm])
        .add_linearly_chained_nodes(
            [comune_chunk_filter, comune_chunk_merger],
            [ner_featured],
        )
        .add_linearly_chained_nodes(
            [intervention_date_chunk_filter, intervention_date_chunk_merger],
            [ner_featured],
        )
        .add_node(
            intervention_date_entrypoint,
            [intervention_date_chunk_merger, archiving_date],
        )
        .add_node(intervention_date_extractor, [intervention_date_entrypoint])
        .add_node(comune_extractor, [comune_chunk_merger])
        .add_linearly_chained_nodes(
            [fi_entrypoint, fonte_informazione],
            [archiving_date, comune_extractor],
        )
        .add_linearly_chained_nodes(
            [functionary_selector, functionary_merger], [ner_featured]
        )
        .add_node(
            functionary_entrypoint,
            [functionary_merger, intervention_date_extractor, comune_extractor],
        )
        .add_node(functionary, [functionary_entrypoint])
        .add_node(
            final_results,
            [archiving_date, comune_extractor, intervention_date_extractor,
             fonte_informazione, functionary],
        )
    )