# Source code for text_machina.src.metrics.mauve

import json
from pathlib import Path
from typing import Dict

import matplotlib.pyplot as plt
import mauve
from datasets import Dataset

from ..common.exceptions import InvalidTaskTypeForMetric, UnsupportedMetricParam
from ..types import TaskType
from .base import Metric


class MAUVEMetric(Metric):
    """
    Implements the MAUVE metric: https://arxiv.org/abs/2102.01454

    Generated texts are passed to MAUVE as ``p_text`` and human texts as
    ``q_text``.

    Supported tasks: detection.
    """

    # ``compute_mauve`` input kwargs that can't be accepted because they are
    # mutually exclusive with passing text as input. A tuple (not a set) keeps
    # the reported parameter deterministic: set iteration order over strings
    # varies between interpreter runs due to hash randomization.
    _UNSUPPORTED_KWARGS = ("p_tokens", "q_tokens", "p_features", "q_features")

    # Output keys written to ``summary.json`` as plain floats.
    _SCALAR_KEYS = frozenset({"mauve", "frontier_integral"})
    # Output keys whose values are converted with ``.tolist()`` for JSON.
    _HIST_KEYS = frozenset({"p_hist", "q_hist"})

    def _run(self, dataset: Dataset, **kwargs) -> Dict:
        """
        Compute MAUVE between the generated and human texts of ``dataset``.

        Only rows whose ``label`` is ``"generated"`` or ``"human"`` are used;
        the ``text`` column of each subset is fed to ``mauve.compute_mauve``.

        Args:
            dataset (Dataset): dataset with ``text`` and ``label`` columns.
            kwargs: extra keyword arguments forwarded to
                ``mauve.compute_mauve``.

        Returns:
            Dict: the attributes of the object returned by
                ``mauve.compute_mauve`` (via ``vars``).

        Raises:
            InvalidTaskTypeForMetric: if the task type is not detection.
            UnsupportedMetricParam: if a token/feature input kwarg is given.
            ValueError: if there are no generated or no human texts.
        """
        if self.task_type != TaskType.DETECTION:
            raise InvalidTaskTypeForMetric(self.name, self.task_type)
        self.check_kwargs_ok(kwargs)

        df = dataset.to_pandas()
        # Plain lists rather than pandas Series: the text inputs are
        # iterated by MAUVE, and lists avoid any index-related surprises.
        generated_texts = df.loc[df["label"] == "generated", "text"].tolist()
        human_texts = df.loc[df["label"] == "human", "text"].tolist()

        # Fail early with a clear message instead of an opaque error raised
        # from deep inside MAUVE's featurization/clustering.
        if not generated_texts or not human_texts:
            raise ValueError(
                "MAUVE needs at least one 'generated' and one 'human' text, "
                f"got {len(generated_texts)} generated and "
                f"{len(human_texts)} human."
            )

        outputs = mauve.compute_mauve(
            p_text=generated_texts, q_text=human_texts, **kwargs
        )
        return vars(outputs)

    def _save(self, outputs: Dict, path: Path) -> None:
        """
        Save the divergence-curve plot and a JSON summary under ``path``.

        Writes ``divergence_curve.pdf`` and ``summary.json``; the summary
        contains whichever of ``mauve``, ``frontier_integral``, ``p_hist``
        and ``q_hist`` are present in ``outputs``.

        Args:
            outputs (Dict): outputs returned by ``_run``; must contain a 2-D
                ``divergence_curve`` array (it is indexed as ``[:, i]``).
            path (Path): directory in which the files are written; it is not
                created here.
        """
        curve = outputs["divergence_curve"]

        # Draw on a dedicated figure and always close it. Using pyplot's
        # implicit current figure made successive calls overlay their
        # curves in the same plot and kept the figure alive indefinitely.
        fig, ax = plt.subplots()
        try:
            ax.plot(curve[:, 1], curve[:, 0])
            fig.savefig(path / "divergence_curve.pdf")
        finally:
            plt.close(fig)

        summary = {}
        for key, value in outputs.items():
            if key in self._SCALAR_KEYS:
                # float() turns numpy scalars into JSON-serializable floats.
                summary[key] = float(value)
            elif key in self._HIST_KEYS:
                summary[key] = value.tolist()

        with open(path / "summary.json", "w") as f:
            json.dump(summary, f, indent=4)

    def _log(self, outputs: Dict, logger) -> None:
        """
        Log the MAUVE score and the frontier integral.

        Args:
            outputs (Dict): outputs returned by ``_run``.
            logger: object exposing an ``info(str)`` method.
        """
        logger.info(f"MAUVE score: {outputs['mauve']}")
        logger.info(f"Frontier Integral: {outputs['frontier_integral']}")

    def check_kwargs_ok(self, kwargs: Dict) -> None:
        """
        Reject ``compute_mauve`` kwargs that conflict with text inputs.

        ``_run`` always passes ``p_text``/``q_text``, and token or feature
        inputs are mutually exclusive with passing text, so they can't be
        accepted here.

        Args:
            kwargs (Dict): keyword arguments destined for ``compute_mauve``.

        Raises:
            UnsupportedMetricParam: for the first unsupported kwarg found,
                checked in the order of ``_UNSUPPORTED_KWARGS``.
        """
        for param in self._UNSUPPORTED_KWARGS:
            if param in kwargs:
                raise UnsupportedMetricParam(param, self.name)