archaeo_super_prompt.modeling.train

"""DAGs to train the FieldExtractor models."""

from typing import NamedTuple, cast

from archaeo_super_prompt.dataset.load import MagohDataset
from archaeo_super_prompt.modeling.struct_extract.field_extractor import (
    FieldExtractor,
)

from ..dataset.thesauri import load_comune
from ..types.pdfpaths import PDFPathDataset
from ..utils.result import get_model_store_dir
from .DAG_builder import DAGBuilder, DAGComponent
from .entity_extractor import NerModel, NeSelector
from .pdf_to_text import VLLM_Preprocessing
from .struct_extract.chunks_to_text import ChunksToText
from .struct_extract.extractors.archiving_date import ArchivingDateProvider
from .struct_extract.extractors.comune import ComuneExtractor
from .struct_extract.extractors.intervention_date import (
    InterventionStartExtractor,
)

class ExtractionDAGParts(NamedTuple):
    """A decomposition of the general DAG into different parts for a better handling between the training, the inference and the evaluation modes."""
    preprocessing_root: DAGBuilder
    extraction_parts: list[tuple[DAGComponent[FieldExtractor], DAGComponent]]
    final_component: tuple[DAGComponent, list[DAGComponent]]


def get_training_dag() -> ExtractionDAGParts:
    """Return the most advanced pre-processing DAG for the model.

    All its estimators and transformers are initialized with particular
    parametres.

    Return:
        A part of the complete DAG for getting the pre-processed data.
        The field extractors related to their parent node, to apply on these extractors special training or evaluation operations or to bind them to the preprocessing dag
        The final union component to finish the building of the complete DAG
        in inference mode.
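
    Example:
        An illustrative sketch (an assumption based on the builder calls
        in this module, not a documented recipe) of how the three parts
        can be recombined into a single inference pipeline::

            parts = get_training_dag()
            builder = parts.preprocessing_root
            for extractor, parent in parts.extraction_parts:
                # bind each field extractor to its parent node
                builder = builder.add_node(extractor, [parent])
            final, deps = parts.final_component
            pipeline = builder.add_node(final, deps).make_dag()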
    """
    llm_model_id = "google/gemma-3-27b-it"
    llm_provider = "vllm"
    llm_model_temp = 0.05

    vllm = DAGComponent(
        "vision-lm-Reader",
        VLLM_Preprocessing(
            vlm_provider="vllm",
            vlm_model_id="ibm-granite/granite-vision-3.3-2b",
            incipit_only=True,
            prompt="OCR this part of Italian document for markdown-based processing.",
            embedding_model_hf_id="nomic-ai/nomic-embed-text-v1.5",
        ),
    )
    ner = DAGComponent("NER-Extractor", NerModel())
    ner_featured = DAGComponent("ner-featured", "passthrough")
    archiving_date = DAGComponent(
        "archiving-date-Oracle", ArchivingDateProvider()
    )
    intervention_date_chunk_filter = DAGComponent(
        "interv-start-CF",
        NeSelector(
            "data",
            {
                "DATA",
            },
            lambda: list(
                enumerate(
                    [
                        "primavera",
                        "estate",
                        "autunno",
                        "inverno",
                    ]
                )
            ),
            True,
        ),
    )
    intervention_date_chunk_merger = DAGComponent(
        "interv-start-CM", ChunksToText()
    )
    intervention_date_extractor = DAGComponent(
        "interv-start-Extractor",
        InterventionStartExtractor(llm_provider, llm_model_id, llm_model_temp),
    )
    comune_extractor = DAGComponent(
        "comune-Extractor",
        ComuneExtractor(llm_provider, llm_model_id, llm_model_temp),
    )
    comune_chunk_filter = DAGComponent(
        "comune-CF",
        NeSelector(
            "comune",
            {
                "INDIRIZZO",
                "CODICE_POSTALE",
                "LUOGO",
            },
            load_comune,
        ),
    )
    comune_chunk_merger = DAGComponent("comune-CM", ChunksToText())

    intervention_date_entrypoint = DAGComponent(
        "interv-start-entrypoint", "passthrough"
    )
    final_results = DAGComponent[FieldExtractor]("FINAL", "passthrough")

    preprocessing_part = (
        DAGBuilder()
        .add_node(vllm)
        .add_node(ner, [vllm])
        .add_node(ner_featured, [vllm, ner])
        .add_node(archiving_date, [vllm])
        .add_linearly_chained_nodes(
            [comune_chunk_filter, comune_chunk_merger],
            [ner_featured],
        )
        .add_linearly_chained_nodes(
            [intervention_date_chunk_filter, intervention_date_chunk_merger],
            [ner_featured],
        )
        .add_node(
            intervention_date_entrypoint,
            [intervention_date_chunk_merger, archiving_date],
        )
    )
    extraction_part = cast(
        list[tuple[DAGComponent[FieldExtractor], DAGComponent]],
        [
            (intervention_date_extractor, intervention_date_entrypoint),
            (comune_extractor, comune_chunk_merger),
        ],
    )
    final_part = (
        final_results,
        [
            archiving_date,
            intervention_date_extractor,
            comune_extractor,
        ],
    )
    return ExtractionDAGParts(preprocessing_part, extraction_part, final_part)


def train_from_scratch(
    training_input: PDFPathDataset, ds: MagohDataset
) -> ExtractionDAGParts:
    """Return the most advanced DAG model, fitted from the data.

    Train each FieldExtractor model and save its compiled prompt model
    under get_model_store_dir().
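
    Example:
        A minimal usage sketch; ``pdf_paths`` and ``ds`` are hypothetical
        names for already-constructed dataset objects::

            fitted_parts = train_from_scratch(pdf_paths, ds)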
    """
    preprocessing_part, extraction_part, final_part = get_training_dag()
    preprocess_pipeline = preprocessing_part.make_dag()
    preprocessed_inputs = preprocess_pipeline.fit_transform(training_input, ds)
    for fe_component, dep in extraction_part:
        field_extractor = fe_component.component
        if isinstance(field_extractor, str):
            # Unreachable: "passthrough" string components never appear
            # among the extraction parts; this check only narrows the type.
            continue
        field_extractor.fit(preprocessed_inputs[dep.component_id], ds)
        field_extractor.prompt_model_.save(
            get_model_store_dir() / f"{fe_component.component_id}.json"
        )
    return ExtractionDAGParts(preprocessing_part, extraction_part, final_part)


def get_fitted_model(
    training_input: PDFPathDataset, ds: MagohDataset
) -> ExtractionDAGParts:
    """Return the most advanced DAG model, mockly fitted from the data.

    The FieldExtractor model are supposed already fitted from saved dspy
    models in get_model_store_dir() path.
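
    Example:
        An illustrative sketch; it assumes the ``.json`` files written by
        train_from_scratch already exist under get_model_store_dir();
        ``pdf_paths`` and ``ds`` are hypothetical names for
        already-constructed dataset objects::

            fitted_parts = get_fitted_model(pdf_paths, ds)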
    """
    preprocessing_part, extraction_part, final_part = get_training_dag()
    preprocess_pipeline = preprocessing_part.make_dag()
    preprocessed_inputs = preprocess_pipeline.fit_transform(training_input, ds)
    for fe_component, dep in extraction_part:
        field_extractor = fe_component.component
        if isinstance(field_extractor, str):
            # Unreachable: "passthrough" string components never appear
            # among the extraction parts; this check only narrows the type.
            continue
        field_extractor.fit(
            preprocessed_inputs[dep.component_id],
            ds,
            compiled_dspy_model_path=get_model_store_dir()
            / f"{fe_component.component_id}.json",
        )
    return ExtractionDAGParts(preprocessing_part, extraction_part, final_part)


# TODO: implement inference from the PDF paths, and the evaluation mode.