Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 37 additions & 9 deletions examples/libritts/tts/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -90,22 +90,50 @@ Trained on ~190k hours of data from [EMILIA](https://huggingface.co/datasets/amp
**Test set:** [seed-tts-zh](https://github.com/BytedanceSpeech/seed-tts-eval) and [seed-tts-en](https://github.com/BytedanceSpeech/seed-tts-eval) from the [seed-tts-eval](https://github.com/BytedanceSpeech/seed-tts-eval) benchmark.


| LLM | Tokenizer | testset | CER/WER (%) | #N | #SUB | #INS + DEL | SS |
|------------|----------------------------|---------|---------|-------|------|------------|--------|
| Qwen2-0.5B | speech_tokenizer_v3_25hz | test-zh | 1.56 | 42241 | 562 | 96 | 0.812 |
| Qwen2-0.5B | speech_tokenizer_v3_25hz | test-en | 2.34 | 11820 | 212 | 65 | 0.822 |
| LLM | Tokenizer | mode | spk | testset | CER/WER (%) | #N | #SUB | #INS + DEL | SS |
|------------|----------------------------|----------| ---------------|---------|-------------|-------|------|------------|------------|
| Qwen2-0.5B | speech_tokenizer_v3_25hz | pretrain | - | test-zh | 1.56 | 42241 | 562 | 96 | 0.812 |
| Qwen2-0.5B | speech_tokenizer_v3_25hz | pretrain | - | test-en | 2.34 | 11820 | 212 | 65 | 0.822 |
| Qwen2-0.5B | speech_tokenizer_v3_25hz | SFT | biaobei female | test-zh | **1.25** | 42241 | 503 | 24 | **0.869** |
| Qwen2-0.5B | speech_tokenizer_v3_25hz | SFT | internal | test-en | **1.77** | 11820 | 175 | 34 | **0.911** |

- Details
- **spk (SFT)**
- `biaobei female`: [标贝中文标准女声音库](https://www.data-baker.com/open_source.html).
- `internal`: In-house English speaker (female); **1056** utterances, **~1.5** hours in total.

- **Training details**
```
LLM: 8 A800 GPUs, pack 20000, 264k steps
Flow: 8 A800 GPUs, batch 64, 100k steps
Pretrain:
LLM: 8 A800 GPUs, pack 20000, 264k steps, lr 3e-4
Flow: 8 A800 GPUs, batch 64, 100k steps, lr 3e-4

SFT:
LLM: 8 A800 GPUs, pack 20000, 7k steps, lr 4e-4
Flow: 8 A800 GPUs, batch 64, 7k steps, lr 3e-4
```

**CER/WER comparison with CosyVoice series (Seed-TTS test-zh / test-en)**

| Model | test-zh CER (%) | test-en WER (%) |
|---------------------------------------|-----------------|-----------------|
| Qwen2-0.5B + speech_tokenizer_v3_25hz | 1.56 | 2.34 |
| Qwen2-0.5B + speech_tokenizer_v3_25hz | *1.56* | *2.34* |
| CosyVoice | 3.63 | 4.29 |
| CosyVoice2 | 1.45 | 2.57 |
| CosyVoice3-0.5B | 1.16 | 2.02 |
| CosyVoice3-0.5B | 1.16 | 2.02 |

---
## Field reference

| Field | Meaning |
|------|---------|
| `txt` | Text aligned with the reference audio (prompt transcript). |
| `wav` | Path to the reference / prompt waveform. |
| `syn_txt` | Target text to synthesize (used at inference). |
| `spk` | Speaker id or short description (SFT). |
| `ins` | Speaker style instruction (SFT). |

**Special tokens**:

- `<|spk_eos|>` — end of speaker block
- `<|ins_eos|>` — end of instruction block
- `<|audio_bos|>` — start of discrete audio tokens
4 changes: 4 additions & 0 deletions examples/libritts/tts/conf/spk_prompt_wav_map.example.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
{
"speaker_id_1": "/abs/path/to/speaker_id_1_prompt.wav",
"speaker_id_2": "/abs/path/to/speaker_id_2_prompt.wav"
}
8 changes: 6 additions & 2 deletions examples/libritts/tts/run_flow.sh
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,12 @@ num_gpus=$(echo $CUDA_VISIBLE_DEVICES | awk -F ',' '{print NF}')

stage=train
data=data
dir=exp/touch_flow-Qwen2.5-0.5B-Audio-FSQ_v3_25hz-libritts
dir=exp/touch_flow-Qwen2.5-0.5B-Audio-FSQ_v3_25hz-spk_ins_sft
steps=50000 # training steps

# None for pretrain, JSON for sft
spk_prompt_wav_map_path=conf/spk_prompt_wav_map.example.json

. tools/parse_options.sh

if [ $stage == "data" ] || [ $stage == "all" ]; then
Expand Down Expand Up @@ -45,7 +48,8 @@ if [ $stage == "train" ] || [ $stage == "all" ]; then
--dataloader_prefetch_factor 10 \
--ignore_data_skip True \
--deepspeed conf/ds_config_zero2.json \
--accelerator_config conf/accelerator_config.json
--accelerator_config conf/accelerator_config.json \
--spk_prompt_wav_map_path $spk_prompt_wav_map_path
fi


Expand Down
8 changes: 7 additions & 1 deletion tools/add_speech_tokens.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,13 @@
model = AutoModelForCausalLM.from_pretrained(model_dir)
tokenizer = AutoTokenizer.from_pretrained(model_dir)

special_audio_tokens = ['<|audio|>', '<|audio_bos|>', '<|audio_eos|>']
special_audio_tokens = [
"<|audio|>",
"<|audio_bos|>",
"<|audio_eos|>",
"<|spk_eos|>",
"<|ins_eos|>",
]

special_tokens_dict = {'additional_special_tokens': special_audio_tokens}

Expand Down
7 changes: 6 additions & 1 deletion west/bin/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,8 +112,13 @@ def main():
else: # load from pretrained
model = AutoModel.from_pretrained(training_args.model_config_or_dir)
config = model.config
p = (getattr(data_args, 'spk_prompt_wav_map_path', None) or '').strip()
tokenizer = model.init_tokenizer()
extractor = Extractor.get_class(model.model_type)(tokenizer, config)
extractor = Extractor.get_class(model.model_type)(
tokenizer,
config,
spk_prompt_wav_map_path=p or None,
)

print("Loading data...")
train_dataset = SpeechDataset(extractor, data_args)
Expand Down
20 changes: 15 additions & 5 deletions west/dataset/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import random
import sys
from dataclasses import dataclass, field
from typing import Dict
from typing import Dict, Optional

import torch
import torch.distributed as dist
Expand All @@ -22,6 +22,12 @@
class DataArguments:
data_path: str = field(default=None,
metadata={"help": "Path to the training data."})
spk_prompt_wav_map_path: Optional[str] = field(
default=None,
metadata={
"help":
"TouchFlow SFT: JSON mapping spk_id -> prompt wav for mel_speaker."
})
batch_size: int = field(default=1, metadata={"help": "batch size"})
pack_size: int = field(
default=0,
Expand Down Expand Up @@ -105,11 +111,11 @@ def _read_one(self):
x['txt'] = x['txt'].decode('utf8')
x['wav'] = io.BytesIO(x['wav'])
yield x
except Exception:
logging.info(f'Dataset decode error, {line}')
except Exception as e:
logging.info(f'Dataset decode error, {line}, {e}')
continue
except Exception:
logging.info(f'Dataset parsing error, {line}')
except Exception as e:
logging.info(f'Dataset parsing error, {line}, {e}')
continue

def _pack_sequence(self, seqs):
Expand Down Expand Up @@ -164,6 +170,8 @@ def _batch(self, seqs, pack=False):
fields_static = self.extractor.fields_batch_static - \
self.extractor.fields_pack_offset
for k in fields_dynamic:
if not all(k in s for s in seqs):
continue
if k == 'input_ids':
padding_value = self.tokenizer.pad_token_id
elif k == 'labels':
Expand All @@ -180,6 +188,8 @@ def _batch(self, seqs, pack=False):
ret['attention_mask'] = ret['input_ids'].ne(
self.tokenizer.pad_token_id)
for k in fields_static:
if not all(k in s for s in seqs):
continue
ret[k] = torch.tensor([s[k] for s in seqs], dtype=torch.int)
if not pack:
ret['batch_idx'] = torch.tensor(list(range(len(seqs))),
Expand Down
7 changes: 6 additions & 1 deletion west/dataset/extractor.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,15 @@ class Extractor(ABC):
fields_batch_dynamic = {}
fields_pack_offset = {}

def __init__(self, tokenizer, model_config, inference=False):
def __init__(self,
tokenizer,
model_config,
inference=False,
spk_prompt_wav_map_path=None):
self.tokenizer = tokenizer
self.model_config = model_config
self.inference = inference
self.spk_prompt_wav_map_path = spk_prompt_wav_map_path

def __init_subclass__(cls, **kwargs):
super().__init_subclass__(**kwargs)
Expand Down
72 changes: 62 additions & 10 deletions west/models/touch_flow/extractor_touch_flow.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,9 @@
# Copyright (c) 2025 Binbin Zhang(binbzha@qq.com)

import json
import logging
import os

import torch
import torchaudio
from torchaudio.compliance import kaldi
Expand All @@ -8,15 +12,49 @@
from west.utils.audio import mel_spectrogram


def _mel_speaker_fbank(waveform: torch.Tensor, sample_rate: int) -> torch.Tensor:
"""16k mono fbank for WeSpeaker, same as training pipeline."""
audio = torchaudio.transforms.Resample(sample_rate, 16000)(waveform)
mel_speaker = kaldi.fbank(audio,
num_mel_bins=80,
frame_length=25,
frame_shift=10,
dither=0.0,
sample_frequency=16000,)
return mel_speaker - torch.mean(mel_speaker, 0)


class ExtractorTouchFlow(Extractor):
model_type = 'touch_flow'

fields_batch_dynamic = {'mel_speaker', 'mel_token', 'mel_vocoder'}

def __init__(self, tokenizer, model_config, inference=False):
super().__init__(tokenizer, model_config, inference)
def __init__(self,
tokenizer,
model_config,
inference=False,
spk_prompt_wav_map_path=None):
super().__init__(tokenizer,
model_config,
inference,
spk_prompt_wav_map_path=spk_prompt_wav_map_path)
if self.inference:
self.fields_batch_dynamic.add('llm_token')
self.spk_prompt_wav_map = {}
path = (self.spk_prompt_wav_map_path or '').strip()
if path and os.path.isfile(path):
with open(path, 'r', encoding='utf8') as f:
self.spk_prompt_wav_map = json.load(f)
logging.info('ExtractorTouchFlow: loaded spk_prompt_wav_map from %s (%d entries)',
path, len(self.spk_prompt_wav_map))
elif path:
logging.warning('ExtractorTouchFlow: spk_prompt_wav_map_path not found: %s', path)

def _mel_speaker_from_wav_path(self, wav_path: str) -> torch.Tensor:
waveform, sample_rate = torchaudio.load(wav_path)
if waveform.size(0) > 1:
waveform = waveform[:1]
return _mel_speaker_fbank(waveform, sample_rate)

def extract(self, item):
import s3tokenizer
Expand All @@ -39,14 +77,28 @@ def extract(self, item):
fmax=8000,
center=False)
mel_vocoder = mel_vocoder[0].transpose(0, 1)
# for campplus-200k model, use povey window
mel_speaker = kaldi.fbank(audio,
num_mel_bins=80,
frame_length=25,
frame_shift=10,
dither=0.0,
sample_frequency=16000,)
mel_speaker = mel_speaker - torch.mean(mel_speaker, 0)
spk_id = (item.get('spk') or '').strip()
use_prompt_spk = bool(spk_id and spk_id in self.spk_prompt_wav_map)
if use_prompt_spk:
# spk: str speaker id (e.g. "biaobei");
# lookup spk_prompt_wav_map[spk] -> prompt wav for mel_speaker
prompt_path = self.spk_prompt_wav_map[spk_id]
if not os.path.isfile(prompt_path):
logging.warning(
'ExtractorTouchFlow: spk %s map path missing %s, fallback to item wav',
spk_id, prompt_path)
mel_speaker = _mel_speaker_fbank(waveform, sample_rate)
else:
mel_speaker = self._mel_speaker_from_wav_path(prompt_path)
if self.inference:
logging.info('ExtractorTouchFlow: using spk %s prompt wav: %s', spk_id, prompt_path)
else:
if spk_id and self.spk_prompt_wav_map:
logging.warning(
'ExtractorTouchFlow: spk %r not in spk_prompt_wav_map, '
'mel_speaker from item wav', spk_id)
mel_speaker = _mel_speaker_fbank(waveform, sample_rate)

mel_token = s3tokenizer.log_mel_spectrogram(audio[0])
mel_token = mel_token.transpose(0, 1)
ret = {
Expand Down
Loading