File size: 3,647 Bytes
f71bc95
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
from __future__ import annotations

from dataclasses import dataclass
from typing import Any, Dict, List

from .constants import DEFAULT_ESPEAK_VOICE, EMOTION_TO_SYMBOL, INTENSITY_SYMBOLS


@dataclass(frozen=True)
class PreparedInput:
    """Immutable record of everything ``prepare_input`` derives from one request.

    ``frozen=True`` prevents rebinding the attributes after construction; the
    list-valued fields themselves are ordinary lists.
    """

    # The input text exactly as it was passed to ``prepare_input``.
    text: str
    # Phoneme symbols for the text, followed by the emotion symbol and then
    # the intensity symbol (``prepare_input`` appends both at the end).
    phonemes: List[str]
    # One token id per entry of ``phonemes``, looked up in the config's token map.
    token_ids: List[int]
    # Emotion name after normalization (stripped and lower-cased).
    emotion: str
    # Intensity after clamping to the closed interval [0.0, 1.0].
    intensity: float
    # Symbol that ``EMOTION_TO_SYMBOL`` maps ``emotion`` to.
    emotion_symbol: str
    # Symbol chosen by ``intensity_to_symbol`` for ``intensity``.
    intensity_symbol: str


def clamp_unit(value: float) -> float:
    """Clamp *value* to the closed unit interval ``[0.0, 1.0]``.

    NaN is mapped to ``0.0``. Values already inside the interval are returned
    converted to ``float``.
    """
    # NaN is the only float that is not equal to itself.
    is_nan = value != value
    if is_nan or value < 0.0:
        return 0.0

    return 1.0 if value > 1.0 else float(value)


def load_token_map(config: dict[str, Any]) -> Dict[str, int]:
    """Build a flat ``symbol -> token id`` mapping from a model config.

    Each value in ``config["phoneme_id_map"]`` may be either a plain ``int``
    or a single-item list whose element is convertible with ``int()``.

    Args:
        config: Parsed model configuration; must contain a ``phoneme_id_map``
            dictionary.

    Returns:
        A dictionary mapping every symbol to a single token id.

    Raises:
        KeyError: If ``phoneme_id_map`` is absent or is not a dictionary.
        ValueError: If a symbol's value is neither an int nor a single-item
            list, or if the single list item cannot be converted to an int.
    """
    phoneme_id_map = config.get("phoneme_id_map")
    if not isinstance(phoneme_id_map, dict):
        raise KeyError("config.json is missing phoneme_id_map")

    token_map: Dict[str, int] = {}

    for symbol, raw_value in phoneme_id_map.items():
        if isinstance(raw_value, int):
            token_map[symbol] = raw_value
            continue

        if isinstance(raw_value, list) and len(raw_value) == 1:
            try:
                token_map[symbol] = int(raw_value[0])
            except (TypeError, ValueError, OverflowError) as exc:
                # Report the offending symbol instead of leaking a bare
                # int() conversion error with no context.
                raise ValueError(
                    f"Unsupported token mapping for symbol {symbol!r}: "
                    f"single-item list must contain an integer, got {raw_value[0]!r}"
                ) from exc
            continue

        raise ValueError(
            f"Unsupported token mapping for symbol {symbol!r}: expected int or single-item list"
        )

    return token_map


def intensity_to_symbol(intensity: float) -> str:
    """Pick the entry of ``INTENSITY_SYMBOLS`` that corresponds to *intensity*.

    The intensity is first clamped to ``[0.0, 1.0]`` and then scaled by the
    number of available symbols to obtain an index.
    """
    last_index = len(INTENSITY_SYMBOLS) - 1
    bucket = int(clamp_unit(intensity) * len(INTENSITY_SYMBOLS))

    # An intensity of exactly 1.0 scales to one past the end; pull it back.
    bucket = min(bucket, last_index)
    bucket = max(0, bucket)
    return INTENSITY_SYMBOLS[bucket]


def normalize_emotion(emotion: str | None) -> str:
    """Return the canonical (stripped, lower-cased) name for *emotion*.

    ``None`` and the empty string fall back to ``"neutral"``.

    Raises:
        ValueError: If the normalized name is not a key of
            ``EMOTION_TO_SYMBOL``.
    """
    candidate = emotion if emotion else "neutral"
    candidate = candidate.strip().lower()
    if candidate in EMOTION_TO_SYMBOL:
        return candidate

    supported = ", ".join(EMOTION_TO_SYMBOL)
    raise ValueError(
        f"Unsupported emotion {emotion!r}. Expected one of: {supported}"
    )


def phonemize_full_utterance(text: str, espeak_voice: str = DEFAULT_ESPEAK_VOICE) -> List[str]:
    """Phonemize *text* and flatten the result into one list of symbols.

    ``phonemize_espeak`` returns its output in groups; empty groups are
    dropped and the remaining ones are concatenated with a single ``" "``
    symbol between consecutive groups.

    Raises:
        ImportError: If ``piper_phonemize`` is not installed.
    """
    # Imported lazily so the module can be loaded without the optional
    # dependency installed.
    try:
        from piper_phonemize import phonemize_espeak
    except ImportError as exc:
        raise ImportError(
            "wfloat-tts requires piper-phonemize for phonemization. "
            "Install it with: pip install \"piper-phonemize==1.3.0\" "
            "-f https://k2-fsa.github.io/icefall/piper_phonemize"
        ) from exc

    non_empty_groups = [
        chunk for chunk in phonemize_espeak(text, espeak_voice) if chunk
    ]

    flattened: List[str] = []
    for position, chunk in enumerate(non_empty_groups):
        # Separator goes between groups only, never before the first one.
        if position:
            flattened.append(" ")
        flattened.extend(chunk)

    return flattened


def prepare_input(
    text: str,
    config: dict[str, Any],
    emotion: str = "neutral",
    intensity: float = 0.5,
    espeak_voice: str = DEFAULT_ESPEAK_VOICE,
) -> PreparedInput:
    """Turn raw text plus emotion settings into a ``PreparedInput``.

    The text is phonemized, the emotion symbol and the intensity symbol are
    appended (in that order) to the phoneme sequence, and every symbol is
    translated to its token id using the config's ``phoneme_id_map``.

    Args:
        text: Text to phonemize.
        config: Parsed model configuration containing ``phoneme_id_map``.
        emotion: Emotion name; normalized via ``normalize_emotion``.
        intensity: Emotion intensity; clamped to ``[0.0, 1.0]``.
        espeak_voice: Voice passed through to the phonemizer.

    Raises:
        ValueError: If the emotion is unsupported or the token map is malformed.
        KeyError: If ``phoneme_id_map`` is missing, or if any symbol in the
            final sequence has no entry in it.
        ImportError: If the phonemizer dependency is unavailable.
    """
    emotion_name = normalize_emotion(emotion)
    unit_intensity = clamp_unit(intensity)

    symbols = phonemize_full_utterance(text, espeak_voice=espeak_voice)
    emotion_marker = EMOTION_TO_SYMBOL[emotion_name]
    intensity_marker = intensity_to_symbol(unit_intensity)
    symbols.append(emotion_marker)
    symbols.append(intensity_marker)

    lookup = load_token_map(config)

    # Collect every unknown symbol first so the error reports all of them.
    unknown = {symbol for symbol in symbols if symbol not in lookup}
    if unknown:
        listing = ", ".join(sorted(unknown))
        raise KeyError(f"Missing symbol(s) in config.json phoneme_id_map: {listing}")

    ids = list(map(lookup.__getitem__, symbols))

    return PreparedInput(
        text=text,
        phonemes=symbols,
        token_ids=ids,
        emotion=emotion_name,
        intensity=unit_intensity,
        emotion_symbol=emotion_marker,
        intensity_symbol=intensity_marker,
    )