Source code for text_machina.src.extractors.utils
from typing import List, Optional

import spacy
from tqdm import tqdm

from ..common import get_logger
from .types import SPACY_MODEL_MAPPING
# Logger for this module, created by the project's get_logger helper
# with this module's name.
_logger = get_logger(__name__)
def clean_inputs(texts: List[str]) -> List[str]:
    """
    Remove special symbols from the texts used as prompt inputs, to
    avoid breaking the format of classical kinds of prompts.

    Newlines, tabs, carriage returns, ``->`` and ``:`` are each replaced
    by a single space. Runs of whitespace are then collapsed to one space
    and leading/trailing whitespace is removed.

    Args:
        texts (List[str]): list of texts.

    Returns:
        List[str]: cleaned texts, in the same order as the input.
    """
    repl_map = {"\n": " ", "\t": " ", "\r": " ", "->": " ", ":": " "}
    clean = []
    for text in texts:
        for src, repl in repl_map.items():
            text = text.replace(src, repl)
        # split() without arguments drops leading/trailing whitespace and
        # collapses inner runs, so no separate strip() is required.
        clean.append(" ".join(text.split()))
    return clean
def get_spacy_model(language: str) -> spacy.language.Language:
    """
    Gets or downloads a Spacy model.

    The model name is looked up in ``SPACY_MODEL_MAPPING``; languages
    without an entry fall back to the ``"multilingual"`` entry. If loading
    the model raises ``OSError`` (treated here as "model not installed"),
    the model is downloaded through ``spacy.cli.download`` and loaded again.

    Args:
        language (str): language.

    Returns:
        spacy.language.Language: a Spacy model.
    """
    spacy_model = SPACY_MODEL_MAPPING.get(
        language, SPACY_MODEL_MAPPING["multilingual"]
    )
    try:
        nlp = spacy.load(spacy_model)
    except OSError:
        # First load failed: fetch the model package, then retry once.
        _logger.info(f"Downloading {spacy_model} from SpaCy.")
        spacy.cli.download(spacy_model)
        nlp = spacy.load(spacy_model)
    return nlp
def spacy_pipeline(
    texts: List[str],
    language: str,
    disable_pipes: Optional[List[str]] = None,
    n_process: int = 4,
) -> List[spacy.tokens.Doc]:
    """
    Processes texts with spacy pipeline for entity extraction.

    Args:
        texts (List[str]): list of texts.
        language (str): language of the text.
        disable_pipes (Optional[List[str]]): Spacy pipes to be disabled.
            ``None`` (the default) disables nothing.
        n_process (int): number of processes.

    Returns:
        List[spacy.tokens.Doc]: list of Spacy docs.
    """
    # A None sentinel replaces the former mutable `[]` default so that no
    # list object is shared between calls.
    disabled = [] if disable_pipes is None else disable_pipes

    nlp = get_spacy_model(language)
    docs = nlp.pipe(
        texts,
        n_process=n_process,
        disable=disabled,
    )
    # Wrapping the pipe output in tqdm shows progress while list()
    # consumes it.
    return list(tqdm(docs, total=len(texts), desc="Processing"))