# Source code for dstk.models.tools

"""
Utilities for building, executing, and extending modular linguistic workflows.

This module provides the infrastructure used to define and run workflow-based
processing pipelines. A workflow consists of a sequence of parameters, where
each parameter is associated with one or more processing methods or hooks.
The module dynamically loads and executes these methods, passing the output
of one step as the input to the next.

Core functionalities include:

* Building reusable workflow execution models with ``ModelBuilder``
* Running parameterized linguistic processing pipelines step by step
* Returning intermediate results from selected workflow stages
* Executing custom hooks alongside standard workflow parameters
* Dynamically exposing workflow methods through wrapper objects
* Adapting sentence and token sequence collections for uniform processing
* Defining a protocol for semantic similarity and nearest-neighbor operations
  on word embeddings

The module is intended for constructing flexible and reusable text-processing
pipelines, allowing researchers and digital humanities practitioners to
combine linguistic operations into configurable workflows.
"""

import importlib
import warnings
from functools import wraps

from ..hooks.tools import Hook

from .._dstk_utils.checks import check_return_results
from ..utilities.typeguards import is_workflow, is_sequences, is_documents

from typing import (
    Any,
    Sequence,
    overload,
    Literal,
    Protocol,
    Callable,
    Concatenate,
    ParamSpec,
    TypeVar,
    Generic,
    runtime_checkable,
)
from types import ModuleType
from ..lib_types import (
    ParameterResult,
    ReturnParameterGenerator,
    ReturnAllGenerator,
    Neighbors,
    Workflow,
    Sentence,
    LinguisticSequences,
    Word,
    MethodDict,
)

# Parameters a wrapped method accepts after its first (data) argument; used by ``Wrapper``.
P = ParamSpec("P")
# Return type of a method wrapped by ``Wrapper``.
R = TypeVar("R")
# Result type of a ``ModelBuilder`` call when no intermediate results are requested.
W = TypeVar("W")


def _sequence_list_adaptor(
    sequence_list: LinguisticSequences, method: Callable, **kwargs: Any
) -> list[Sequence[Word]]:
    """
    Apply ``method`` to every sequence of a collection and keep the non-empty results.

    ``Sentence`` items are unwrapped to their ``words`` before being handed to
    ``method``; any other item is passed through unchanged.

    :param sequence_list: Collection of sentences and/or word sequences.
    :type sequence_list: LinguisticSequences
    :param method: Callable invoked once per item, receiving the item's words.
    :type method: Callable
    :param kwargs: Keyword arguments forwarded to ``method`` on every call.
    :type kwargs: Any

    :return: Results of ``method`` in input order, with falsy results dropped.
    :rtype: list[Sequence[Word]]
    """
    kept: list[Sequence[Word]] = []

    for item in sequence_list:
        # Sentences carry their tokens in ``.words``; everything else is used as is.
        tokens = item.words if isinstance(item, Sentence) else item
        outcome = method(tokens, **kwargs)
        if outcome:
            kept.append(outcome)

    return kept


class Wrapper:
    """
    Hold a piece of data and expose workflow methods bound to it.

    Callables registered through :meth:`add_method` become instance attributes
    that invoke the original callable with the stored data as first argument.
    """

    def __init__(self, input_data: Any):
        """
        Stores data and exposes workflow methods as attributes.

        :param input_data: Data wrapped by the instance.
        :type input_data: Any
        """
        self._input_data: Any = input_data

    def add_method(self, method: Callable[Concatenate[Any, P], R]) -> None:
        """
        Expose ``method`` on this instance with the wrapped data pre-bound.

        The attribute is named after ``method.__name__``, so registering a
        second callable with the same name replaces the first one.

        :param method: Callable whose first positional argument receives the
            wrapped data.
        :type method: Callable[Concatenate[Any, P], R]
        """
        setattr(self, method.__name__, self._inject_data(method, self._input_data))

    @staticmethod
    def _inject_data(
        func: Callable[Concatenate[Any, P], R], input_data: Any
    ) -> Callable[P, R]:
        """
        Build a callable that runs ``func`` with ``input_data`` as first argument.

        :param func: Callable to wrap.
        :type func: Callable[Concatenate[Any, P], R]
        :param input_data: Value passed as the first positional argument of ``func``.
        :type input_data: Any

        :return: A callable taking the remaining arguments of ``func``; it keeps
            the metadata of ``func`` via ``functools.wraps``.
        :rtype: Callable[P, R]
        """

        @wraps(func)
        def wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
            return func(input_data, *args, **kwargs)

        return wrapper
class ModelBuilder(Generic[W]):
    """
    Callable model that runs a workflow over input data.

    Each key of the workflow names a parameter. Its value is either a ``Hook``,
    which is called directly with the current result, or a list of method
    dictionaries whose first item maps a method name to its keyword arguments.
    Methods are looked up in the module ``dstk.parameters.<parameter>``.
    """

    def __init__(
        self,
        workflow: Workflow,
        wrapper: bool = False,
        name: str | None = None,
    ) -> None:
        """
        Store the workflow configuration.

        :param workflow: Mapping of parameter names to method lists or hooks.
        :type workflow: Workflow
        :param wrapper: If ``True``, the methods of the last parameter are
            attached to a ``Wrapper`` around the data instead of being executed.
        :type wrapper: bool
        :param name: Optional name of the model.
        :type name: str | None
        """
        self.name = name
        self.workflow = workflow
        self.wrapper = wrapper

    def _run_workflow(self, input_data: Any) -> ReturnParameterGenerator:
        """
        Executes the workflow and yields intermediate results.

        :param input_data: Input data processed by the workflow.
        :type input_data: Any

        :yield: The result of each workflow parameter.
        :rtype: ParameterResult
        """
        is_document: bool = is_documents(input_data)
        parameters: list[str] = list(self.workflow)

        if not parameters:
            # An empty workflow has no last parameter; yield nothing instead of
            # failing with an IndexError on ``parameters[-1]``.
            return

        last_parameter: str = parameters[-1]
        result: Any = input_data

        for parameter in parameters:
            value: list[MethodDict] | Hook = self.workflow[parameter]

            if isinstance(value, Hook):
                if parameter == last_parameter and self.wrapper:
                    warnings.warn(
                        "You cannot return a Wrapper when the last parameter is a hook. Wrapper will be ignored."
                    )
                    # NOTE(review): this disables the wrapper on the instance
                    # for all later calls, not only for this run.
                    self.wrapper = False

                result = value(result)
            else:
                if parameter == last_parameter and self.wrapper:
                    # The methods of the last parameter are attached to this
                    # wrapper by ``_execute_methods`` instead of being run.
                    result = Wrapper(input_data=result)

                # NOTE(review): this is a substring test, whereas
                # ``_execute_methods`` compares the first dotted component of
                # the parameter name with "context" - confirm which is intended.
                # NOTE(review): when the wrapper branch above has fired,
                # ``result`` is a ``Wrapper``, which defines no iteration
                # support - confirm this combination cannot occur.
                if "context" in parameter and is_document:
                    result = [
                        self._execute_methods(
                            value, parameter, last_parameter, document
                        )
                        for document in result
                    ]
                else:
                    result = self._execute_methods(
                        value, parameter, last_parameter, result
                    )

            yield ParameterResult(name=parameter, result=result)

    @overload
    def __call__(self, input_data: Any) -> W: ...

    @overload
    def __call__(
        self,
        input_data: Any,
        *,
        return_parameters: list[str],
        return_all: Literal[False],
    ) -> ReturnParameterGenerator: ...

    @overload
    def __call__(
        self,
        input_data: Any,
        *,
        return_parameters: None,
        return_all: Literal[True],
    ) -> ReturnAllGenerator: ...

    @overload
    def __call__(
        self,
        input_data: Any,
        *,
        return_parameters: list[str] | None = None,
        return_all: bool = False,
    ) -> W | ReturnParameterGenerator | ReturnAllGenerator: ...

    def __call__(
        self,
        input_data: Any,
        *,
        return_parameters: list[str] | None = None,
        return_all: bool = False,
    ) -> Any:
        """
        Executes the workflow.

        :param input_data: Input data processed by the workflow.
        :type input_data: Any
        :param return_parameters: Names of parameters whose results should be
            returned. Takes precedence over ``return_all`` when non-empty.
        :type return_parameters: list[str] | None
        :param return_all: If ``True``, returns results for all parameters.
        :type return_all: bool

        :return: The final workflow result, a generator of selected results,
            or a generator of all parameter results.
        :rtype: Any | ReturnParameterGenerator | ReturnAllGenerator

        :raises ValueError: If the workflow format is invalid.
        """
        if not is_workflow(self.workflow):
            raise ValueError(
                "The workflow provided does not follow the right format. Please enter a valid workflow."
            )

        if return_parameters:
            check_return_results(
                return_list=return_parameters,
                callable_names=list(self.workflow),
                callable_type="parameter",
            )

            # Lazy: the workflow only advances as the caller consumes results.
            return (
                result
                for name, result in self._run_workflow(input_data)
                if name in return_parameters
            )

        if return_all:
            return self._run_workflow(input_data)

        # Drain the generator and keep only the last result; the input is
        # returned unchanged when the workflow yields nothing.
        result: Any = input_data
        for _, result in self._run_workflow(input_data):
            pass

        return result

    def _execute_methods(
        self,
        method_list: list[MethodDict],
        parameter: str,
        last_parameter: str,
        input_data: Any,
    ) -> Any:
        """
        Executes all methods associated with a workflow parameter.

        When ``parameter`` is the last one and ``self.wrapper`` is set, the
        methods are not run: each is attached to ``input_data`` through its
        ``add_method`` and ``input_data`` is returned.

        :param method_list: Ordered methods to execute.
        :type method_list: list[MethodDict]
        :param parameter: Name of the current workflow parameter.
        :type parameter: str
        :param last_parameter: Name of the last workflow parameter.
        :type last_parameter: str
        :param input_data: Input passed to the first method.
        :type input_data: Any

        :return: Result produced by the last method.
        :rtype: Any

        :raises RuntimeError: If a method cannot be found.
        :raises ModuleNotFoundError: If ``dstk.parameters.<parameter>`` cannot
            be imported.
        """
        if not method_list:
            # Nothing to run, so do not import the parameter module either.
            return input_data

        # These values do not change across methods of the same parameter.
        module: ModuleType = importlib.import_module(f"dstk.parameters.{parameter}")
        attach_to_wrapper: bool = parameter == last_parameter and self.wrapper
        is_context: bool = parameter.split(".")[0] == "context"

        input_output: Any = input_data

        for method_dict in method_list:
            # Only the first item of each method dictionary is used.
            method_name, kwargs = next(iter(method_dict.items()))

            if not hasattr(module, method_name):
                raise RuntimeError(
                    f"Module {parameter} does not have a method called {method_name}"
                )

            method: Callable = getattr(module, method_name)

            if attach_to_wrapper:
                if kwargs:
                    warnings.warn(
                        "Because you set wrapper=True, the arguments you passed to the methods in the workflow will be ignored."
                    )

                input_data.add_method(method)
            elif is_context and is_sequences(input_output):
                # Context methods run once per sequence, not on the collection.
                input_output = _sequence_list_adaptor(input_output, method, **kwargs)
            else:
                input_output = method(input_output, **kwargs)

        return input_output
@runtime_checkable
class DistanceMeasurements(Protocol):
    """
    Interface for semantic similarity methods based on word embeddings.

    This protocol represents any object that implements methods for computing
    cosine similarity and retrieving exact or approximate nearest neighbors.

    Methods:
        cos_similarity(first_word, second_word):
            Computes the cosine similarity between two words.
            Equivalent to ``dstk.modules.geometric_distance.cos_similarity``.

        nearest_neighbors(word, metric, n_words, **kwargs):
            Returns the nearest neighbors to a word using a specified metric.
            Equivalent to ``dstk.modules.geometric_distance.nearest_neighbors``.

        approximate_nearest_neighbors(word, metric, n_words, ...):
            Returns words with similar embeddings using an approximate search.
    """

    def cos_similarity(self, first_word: str, second_word: str) -> float:
        """
        Return the cosine similarity between two words.

        :param first_word: First word of the pair.
        :type first_word: str
        :param second_word: Second word of the pair.
        :type second_word: str

        :return: Cosine similarity of the two words.
        :rtype: float
        """
        ...

    def nearest_neighbors(
        self, word: str, metric: str = "cosine", n_words: int = 5, **kwargs: Any
    ) -> Neighbors:
        """
        Return the top-N nearest neighbors to a word using a given metric.

        :param word: Target word.
        :type word: str
        :param metric: Name of the metric to use. Defaults to ``"cosine"``.
        :type metric: str
        :param n_words: Number of neighbors to return. Defaults to 5.
        :type n_words: int
        :param kwargs: Additional keyword arguments accepted by the implementation.
        :type kwargs: Any

        :return: The nearest neighbors of ``word``.
        :rtype: Neighbors
        """
        ...

    def approximate_nearest_neighbors(
        self,
        word: str,
        metric: str = "ivf",
        n_words: int = 5,
        n_centroids: int = 100,
        clusters_to_search: int = 10,
        n_connections: int = 16,
        search_depth: int = 8,
        construction_depth: int = 64,
    ) -> Neighbors:
        """
        Find words with similar embeddings using a fast, memory-efficient approximate search.

        This function returns the closest words to a target word without checking every
        possible word directly. Instead, it uses structures that give very close results
        much faster than an exact search, especially on large embedding sets.

        :param word: Target word.
        :type word: str
        :param metric: Name of the search method to use. Defaults to ``"ivf"``.
        :type metric: str
        :param n_words: Number of neighbors to return. Defaults to 5.
        :type n_words: int

        The remaining parameters tune the underlying search structure; their
        exact meaning is defined by the implementing object.

        :return: The approximate nearest neighbors of ``word``.
        :rtype: Neighbors
        """
        ...