From 0f58e71b58330ff92494ad6ff090a0a93821d5dc Mon Sep 17 00:00:00 2001 From: "hao.yin" Date: Mon, 30 Mar 2026 04:13:57 +0000 Subject: [PATCH 1/3] [touch_tts] add tts sft pipeline & benchmark --- examples/libritts/tts/README.md | 46 +++++-- examples/libritts/tts/run_flow.sh | 8 +- tools/add_speech_tokens.py | 8 +- west/bin/train.py | 7 +- west/dataset/dataset.py | 85 +++++++++++-- west/dataset/extractor.py | 7 +- .../models/touch_flow/extractor_touch_flow.py | 72 +++++++++-- west/models/touch_tts/extractor_touch_tts.py | 113 ++++++++++++++---- west/models/touch_tts/modeling_touch_tts.py | 33 ++--- 9 files changed, 307 insertions(+), 72 deletions(-) diff --git a/examples/libritts/tts/README.md b/examples/libritts/tts/README.md index 1b3ad02..0d6b51f 100644 --- a/examples/libritts/tts/README.md +++ b/examples/libritts/tts/README.md @@ -90,22 +90,50 @@ Trained on ~190k hours of data from [EMILIA](https://huggingface.co/datasets/amp **Test set:** [seed-tts-zh](https://github.com/BytedanceSpeech/seed-tts-eval) and [seed-tts-en](https://github.com/BytedanceSpeech/seed-tts-eval) from the [seed-tts-eval](https://github.com/BytedanceSpeech/seed-tts-eval) benchmark. -| LLM | Tokenizer | testset | CER/WER (%) | #N | #SUB | #INS + DEL | SS | -|------------|----------------------------|---------|---------|-------|------|------------|--------| -| Qwen2-0.5B | speech_tokenizer_v3_25hz | test-zh | 1.56 | 42241 | 562 | 96 | 0.812 | -| Qwen2-0.5B | speech_tokenizer_v3_25hz | test-en | 2.34 | 11820 | 212 | 65 | 0.822 | +| LLM | Tokenizer | mode | spk | testset | CER/WER (%) | #N | #SUB | #INS + DEL | SS | +|------------|----------------------------|----------| ---------------|---------|-------------|-------|------|------------|------------| +| Qwen2-0.5B | speech_tokenizer_v3_25hz | pretrain | - | test-zh | 1.56 | 42241 | 562 | 96 | 0.812 | +| Qwen2-0.5B | speech_tokenizer_v3_25hz | pretrain | - | test-en | 2.34 | 11820 | 212 | 65 | 0.822 | +| Qwen2-0.5B | speech_tokenizer_v3_25hz | SFT | biaobei female | test-zh | **1.25** | 42241 | 503 | 24 | **0.869** | +| Qwen2-0.5B | speech_tokenizer_v3_25hz | SFT | internal | test-en | **1.77** | 11820 | 175 | 34 | **0.911** | -- Details +- **spk (SFT)** + - `biaobei female`: [标贝中文标准女声音库](https://www.data-baker.com/open_source.html). + - `internal`: In-house English speaker (female); **1056** utterances, **~1.5** hours in total. + +- **Training details** ``` -LLM: 8 A800 GPUs, pack 20000, 264k steps -Flow: 8 A800 GPUs, batch 64, 100k steps +Pretrain: +LLM: 8 A800 GPUs, pack 20000, 264k steps, lr 3e-4 +Flow: 8 A800 GPUs, batch 64, 100k steps, lr 3e-4 + +SFT: +LLM: 8 A800 GPUs, pack 20000, 7k steps, lr 4e-4 +Flow: 8 A800 GPUs, batch 64, 7k steps, lr 3e-4 ``` **CER/WER comparison with CosyVoice series (Seed-TTS test-zh / test-en)** | Model | test-zh CER (%) | test-en WER (%) | |---------------------------------------|-----------------|-----------------| -| Qwen2-0.5B + speech_tokenizer_v3_25hz | 1.56 | 2.34 | +| Qwen2-0.5B + speech_tokenizer_v3_25hz | *1.56* | *2.34* | | CosyVoice | 3.63 | 4.29 | | CosyVoice2 | 1.45 | 2.57 | -| CosyVoice3-0.5B | 1.16 | 2.02 | \ No newline at end of file +| CosyVoice3-0.5B | 1.16 | 2.02 | + +--- +## Field reference + +| Field | Meaning | +|------|---------| +| `txt` | Text aligned with the reference audio (prompt transcript). | +| `wav` | Path to the reference / prompt waveform. | +| `syn_txt` | Target text to synthesize (used at inference). | +| `spk` | Speaker id or short description (SFT). | +| `ins` | Speaker style instruction (SFT). | + +**Special tokens**: + +- `<|spk_eos|>` — end of speaker block +- `<|ins_eos|>` — end of instruction block +- `<|audio_bos|>` — start of discrete audio tokens diff --git a/examples/libritts/tts/run_flow.sh b/examples/libritts/tts/run_flow.sh index 0d653bd..6599f8e 100755 --- a/examples/libritts/tts/run_flow.sh +++ b/examples/libritts/tts/run_flow.sh @@ -9,9 +9,12 @@ num_gpus=$(echo $CUDA_VISIBLE_DEVICES | awk -F ',' '{print NF}') stage=train data=data -dir=exp/touch_flow-Qwen2.5-0.5B-Audio-FSQ_v3_25hz-libritts +dir=exp/touch_flow-Qwen2.5-0.5B-Audio-FSQ_v3_25hz-spk_ins_sft steps=50000 # training steps +# None for pretrain, JSON for sft +spk_prompt_wav_map_path=conf/spk_prompt_wav_map.example.json + . tools/parse_options.sh if [ $stage == "data" ] || [ $stage == "all" ]; then @@ -45,7 +48,8 @@ if [ $stage == "train" ] || [ $stage == "all" ]; then --dataloader_prefetch_factor 10 \ --ignore_data_skip True \ --deepspeed conf/ds_config_zero2.json \ - --accelerator_config conf/accelerator_config.json + --accelerator_config conf/accelerator_config.json \ + --spk_prompt_wav_map_path $spk_prompt_wav_map_path fi diff --git a/tools/add_speech_tokens.py b/tools/add_speech_tokens.py index 1924d7b..47c606f 100644 --- a/tools/add_speech_tokens.py +++ b/tools/add_speech_tokens.py @@ -11,7 +11,13 @@ model = AutoModelForCausalLM.from_pretrained(model_dir) tokenizer = AutoTokenizer.from_pretrained(model_dir) -special_audio_tokens = ['<|audio|>', '<|audio_bos|>', '<|audio_eos|>'] +special_audio_tokens = [ + "<|audio|>", + "<|audio_bos|>", + "<|audio_eos|>", + "<|spk_eos|>", + "<|ins_eos|>", +] special_tokens_dict = {'additional_special_tokens': special_audio_tokens} diff --git a/west/bin/train.py b/west/bin/train.py index 97327ce..1ce16ec 100644 --- a/west/bin/train.py +++ b/west/bin/train.py @@ -112,8 +112,13 @@ def main(): else: # load from pretrained model = AutoModel.from_pretrained(training_args.model_config_or_dir) config = model.config + p = (getattr(data_args, 'spk_prompt_wav_map_path', None) or '').strip() tokenizer = model.init_tokenizer() - extractor = Extractor.get_class(model.model_type)(tokenizer, config) + extractor = Extractor.get_class(model.model_type)( + tokenizer, + config, + spk_prompt_wav_map_path=p or None, + ) print("Loading data...") train_dataset = SpeechDataset(extractor, data_args) diff --git a/west/dataset/dataset.py b/west/dataset/dataset.py index 8bc9be9..4cf317c 100644 --- a/west/dataset/dataset.py +++ b/west/dataset/dataset.py @@ -6,7 +6,7 @@ import random import sys from dataclasses import dataclass, field -from typing import Dict +from typing import Dict, Optional import torch import torch.distributed as dist @@ -18,10 +18,34 @@ from west.dataset.extractor import Extractor +def _tar_sample_get_field(sample: dict, name: str): + """Resolve webdataset/tar field by exact key or suffix ``.``. + + Examples: ``txt`` / ``wav.txt`` -> ``txt``; ``wav`` / ``clip.wav`` -> ``wav``. + Prefers exact key ``name``; otherwise first key (sorted) ending with ``.``. + Skips dunder keys like ``__key__``. + """ + if name in sample: + return sample[name] + suf = '.' + name + for k in sorted(sample.keys()): + if not isinstance(k, str) or k.startswith('__'): + continue + if k.endswith(suf): + return sample[k] + return None + + @dataclass class DataArguments: data_path: str = field(default=None, metadata={"help": "Path to the training data."}) + spk_prompt_wav_map_path: Optional[str] = field( + default=None, + metadata={ + "help": + "TouchFlow SFT: JSON mapping spk_id -> prompt wav for mel_speaker. " + }) batch_size: int = field(default=1, metadata={"help": "batch size"}) pack_size: int = field( default=0, @@ -102,14 +126,55 @@ def _read_one(self): data = wds.tarfile_samples(src) for x in data: try: - x['txt'] = x['txt'].decode('utf8') - x['wav'] = io.BytesIO(x['wav']) - yield x - except Exception: - logging.info(f'Dataset decode error, {line}') + txt_val = _tar_sample_get_field(x, 'txt') + wav_val = _tar_sample_get_field(x, 'wav') + if txt_val is None or wav_val is None: + logging.warning( + 'Dataset tar sample missing required txt ' + 'or wav, skip. url=%s keys=%s', + line, list(x.keys())) + continue + out = {} + if isinstance(txt_val, bytes): + out['txt'] = txt_val.decode('utf8') + elif isinstance(txt_val, str): + out['txt'] = txt_val + else: + logging.warning( + 'Dataset tar txt must be bytes or str, ' + 'got %s, url=%s', + type(txt_val), line) + continue + if not isinstance(wav_val, bytes): + logging.warning( + 'Dataset tar wav must be bytes, ' + 'got %s, url=%s', + type(wav_val), line) + continue + out['wav'] = io.BytesIO(wav_val) + # for sft mode, spk & instruction are optional + for opt_key in ('spk', 'ins'): + v = _tar_sample_get_field(x, opt_key) + if v is None: + continue + if isinstance(v, bytes): + out[opt_key] = v.decode('utf8') + elif isinstance(v, str): + out[opt_key] = v + else: + logging.warning( + 'Dataset tar %s must be bytes or str, ' + 'got %s, url=%s', + opt_key, type(v), line) + for meta in ('__key__', '__url__'): + if meta in x: + out[meta] = x[meta] + yield out + except Exception as e: + logging.info(f'Dataset decode error, {line}, {e}') continue - except Exception: - logging.info(f'Dataset parsing error, {line}') + except Exception as e: + logging.info(f'Dataset parsing error, {line}, {e}') continue def _pack_sequence(self, seqs): @@ -164,6 +229,8 @@ def _batch(self, seqs, pack=False): fields_static = self.extractor.fields_batch_static - \ self.extractor.fields_pack_offset for k in fields_dynamic: + if not all(k in s for s in seqs): + continue if k == 'input_ids': padding_value = self.tokenizer.pad_token_id elif k == 'labels': @@ -180,6 +247,8 @@ def _batch(self, seqs, pack=False): ret['attention_mask'] = ret['input_ids'].ne( self.tokenizer.pad_token_id) for k in fields_static: + if not all(k in s for s in seqs): + continue ret[k] = torch.tensor([s[k] for s in seqs], dtype=torch.int) if not pack: ret['batch_idx'] = torch.tensor(list(range(len(seqs))), diff --git a/west/dataset/extractor.py b/west/dataset/extractor.py index 7094ba5..10cf688 100644 --- a/west/dataset/extractor.py +++ b/west/dataset/extractor.py @@ -13,10 +13,15 @@ class Extractor(ABC): fields_batch_dynamic = {} fields_pack_offset = {} - def __init__(self, tokenizer, model_config, inference=False): + def __init__(self, + tokenizer, + model_config, + inference=False, + spk_prompt_wav_map_path=None): self.tokenizer = tokenizer self.model_config = model_config self.inference = inference + self.spk_prompt_wav_map_path = spk_prompt_wav_map_path def __init_subclass__(cls, **kwargs): super().__init_subclass__(**kwargs) diff --git a/west/models/touch_flow/extractor_touch_flow.py b/west/models/touch_flow/extractor_touch_flow.py index 98dbf62..6f63526 100644 --- a/west/models/touch_flow/extractor_touch_flow.py +++ b/west/models/touch_flow/extractor_touch_flow.py @@ -1,5 +1,9 @@ # Copyright (c) 2025 Binbin Zhang(binbzha@qq.com) +import json +import logging +import os + import torch import torchaudio from torchaudio.compliance import kaldi @@ -8,15 +12,49 @@ from west.utils.audio import mel_spectrogram +def _mel_speaker_fbank(waveform: torch.Tensor, sample_rate: int) -> torch.Tensor: + """16k mono fbank for WeSpeaker, same as training pipeline.""" + audio = torchaudio.transforms.Resample(sample_rate, 16000)(waveform) + mel_speaker = kaldi.fbank(audio, + num_mel_bins=80, + frame_length=25, + frame_shift=10, + dither=0.0, + sample_frequency=16000,) + return mel_speaker - torch.mean(mel_speaker, 0) + + class ExtractorTouchFlow(Extractor): model_type = 'touch_flow' fields_batch_dynamic = {'mel_speaker', 'mel_token', 'mel_vocoder'} - def __init__(self, tokenizer, model_config, inference=False): - super().__init__(tokenizer, model_config, inference) + def __init__(self, + tokenizer, + model_config, + inference=False, + spk_prompt_wav_map_path=None): + super().__init__(tokenizer, + model_config, + inference, + spk_prompt_wav_map_path=spk_prompt_wav_map_path) if self.inference: self.fields_batch_dynamic.add('llm_token') + self.spk_prompt_wav_map = {} + path = (self.spk_prompt_wav_map_path or '').strip() + if path and os.path.isfile(path): + with open(path, 'r', encoding='utf8') as f: + self.spk_prompt_wav_map = json.load(f) + logging.info('ExtractorTouchFlow: loaded spk_prompt_wav_map from %s (%d entries)', + path, len(self.spk_prompt_wav_map)) + elif path: + logging.warning('ExtractorTouchFlow: spk_prompt_wav_map_path not found: %s', path) + + def _mel_speaker_from_wav_path(self, wav_path: str) -> torch.Tensor: + waveform, sample_rate = torchaudio.load(wav_path) + if waveform.size(0) > 1: + waveform = waveform[:1] + return _mel_speaker_fbank(waveform, sample_rate) def extract(self, item): import s3tokenizer @@ -39,14 +77,28 @@ def extract(self, item): fmax=8000, center=False) mel_vocoder = mel_vocoder[0].transpose(0, 1) - # for campplus-200k model, use povey window - mel_speaker = kaldi.fbank(audio, - num_mel_bins=80, - frame_length=25, - frame_shift=10, - dither=0.0, - sample_frequency=16000,) - mel_speaker = mel_speaker - torch.mean(mel_speaker, 0) + spk_id = (item.get('spk') or '').strip() + use_prompt_spk = bool(spk_id and spk_id in self.spk_prompt_wav_map) + if use_prompt_spk: + # spk: str speaker id (e.g. "biaobei"); + # lookup spk_prompt_wav_map[spk] -> prompt wav for mel_speaker + prompt_path = self.spk_prompt_wav_map[spk_id] + if not os.path.isfile(prompt_path): + logging.warning( + 'ExtractorTouchFlow: spk %s map path missing %s, fallback to item wav', + spk_id, prompt_path) + mel_speaker = _mel_speaker_fbank(waveform, sample_rate) + else: + mel_speaker = self._mel_speaker_from_wav_path(prompt_path) + if self.inference: + logging.info('ExtractorTouchFlow: using spk %s prompt wav: %s', spk_id, prompt_path) + else: + if spk_id and self.spk_prompt_wav_map: + logging.warning( + 'ExtractorTouchFlow: spk %r not in spk_prompt_wav_map, ' + 'mel_speaker from item wav', spk_id) + mel_speaker = _mel_speaker_fbank(waveform, sample_rate) + mel_token = s3tokenizer.log_mel_spectrogram(audio[0]) mel_token = mel_token.transpose(0, 1) ret = { diff --git a/west/models/touch_tts/extractor_touch_tts.py b/west/models/touch_tts/extractor_touch_tts.py index ed05b6c..540f66e 100644 --- a/west/models/touch_tts/extractor_touch_tts.py +++ b/west/models/touch_tts/extractor_touch_tts.py @@ -10,48 +10,109 @@ class ExtractorTouchTTS(Extractor): - model_type = 'touch_tts' - fields_batch_static = {'audio_offsets', 'text_lengths'} - fields_batch_dynamic = {'audio_features', 'input_ids', 'labels'} - fields_pack_offset = {'audio_offsets'} + model_type = "touch_tts" + fields_batch_static = {"audio_offsets", "text_lengths"} + fields_batch_dynamic = {"audio_features", "input_ids", "labels"} + fields_pack_offset = {"audio_offsets"} def extract(self, item): import s3tokenizer + IGNORE_TOKEN_ID = LabelSmoother.ignore_index - waveform, sample_rate = torchaudio.load(item['wav']) - duration = waveform.size(1) / sample_rate - if not self.inference and ( + spk = item.get("spk", "") + ins = item.get("ins", "") + sft_mode = bool(spk or ins) # is sft mode + + # Training: always require txt and wav (both pretrain and SFT) + # Inference: only SFT may optionally omit txt/wav (no prompt) + if not self.inference: + if "txt" not in item or "wav" not in item: + return None + else: + # Inference SFT without prompt: no txt or wav + if sft_mode and ("txt" not in item or "wav" not in item): + include_prompt = False + else: + include_prompt = True + if "txt" not in item or "wav" not in item: + return None + + if not self.inference or include_prompt: + waveform, sample_rate = torchaudio.load(item["wav"]) + duration = waveform.size(1) / sample_rate + if not self.inference and ( duration < self.model_config.min_speech_duration - or duration > self.model_config.max_speech_duration): - return None - audio = torchaudio.transforms.Resample(sample_rate, 16000)(waveform) - audio = audio[0] # get the first channel - mel = s3tokenizer.log_mel_spectrogram(audio) - mel = mel.transpose(0, 1) - # There is 100 frames mel per second, and 25 tokens per second - num_audio_token = math.ceil(mel.size(0) * 25 / 100.0 - 1e-9) + or duration > self.model_config.max_speech_duration + ): + return None + audio = torchaudio.transforms.Resample(sample_rate, 16000)(waveform) + audio = audio[0] + mel = s3tokenizer.log_mel_spectrogram(audio) + mel = mel.transpose(0, 1) + num_audio_token = math.ceil(mel.size(0) * 25 / 100.0 - 1e-9) + else: + mel = None + num_audio_token = 0 + if not self.inference: - content = item['txt'] + '<|audio_bos|>' + if sft_mode: + content = ( + spk + + "<|spk_eos|>" + + ins + + "<|ins_eos|>" + + item["txt"] + + "<|audio_bos|>" + ) + else: + content = item["txt"] + "<|audio_bos|>" token_lengths = 0 else: - content = item['txt'] + item['syn_txt'] + '<|audio_bos|>' - token_lengths = len(self.tokenizer.encode(item['syn_txt'])) - ids_text = [self.tokenizer.bos_token_id - ] + self.tokenizer.encode(content) + if sft_mode: + if include_prompt: + content = ( + spk + + "<|spk_eos|>" + + ins + + "<|ins_eos|>" + + item["txt"] + + item["syn_txt"] + + "<|audio_bos|>" + ) + else: + content = ( + spk + + "<|spk_eos|>" + + ins + + "<|ins_eos|>" + + item["syn_txt"] + + "<|audio_bos|>" + ) + else: + content = item["txt"] + item["syn_txt"] + "<|audio_bos|>" + token_lengths = len(self.tokenizer.encode(item["syn_txt"])) + + if self.inference: + print("content:", content) + ids_text = [self.tokenizer.bos_token_id] + self.tokenizer.encode(content) tgt_text = [IGNORE_TOKEN_ID] * len(ids_text) ids_audio = [0] * num_audio_token + if not self.inference: ids = ids_text + ids_audio + [self.tokenizer.eos_token_id] tgt = tgt_text + ids_audio + [self.tokenizer.eos_token_id] else: ids = ids_text + ids_audio tgt = tgt_text + ids_audio + input_ids = torch.tensor(ids, dtype=torch.long) tgt_ids = torch.tensor(tgt, dtype=torch.long) - return { - 'input_ids': input_ids, - 'labels': tgt_ids, - 'audio_features': mel, - 'audio_offsets': len(ids_text), - 'text_lengths': token_lengths + result = { + "input_ids": input_ids, + "labels": tgt_ids, + "audio_offsets": len(ids_text), + "text_lengths": token_lengths, } + if mel is not None: + result["audio_features"] = mel + return result diff --git a/west/models/touch_tts/modeling_touch_tts.py b/west/models/touch_tts/modeling_touch_tts.py index e1d5313..6466d18 100644 --- a/west/models/touch_tts/modeling_touch_tts.py +++ b/west/models/touch_tts/modeling_touch_tts.py @@ -47,25 +47,30 @@ def reorg_ids( inputs_embeds: Optional[torch.LongTensor] = None, ): """ Extract speech codes by speech tokenizer, and reorg that in - `input_ids`, `labels` + `input_ids`, `labels`. + When audio_features is None (e.g. SFT without prompt wav), + skip speech tokenizer and return text embeddings directly. """ - speech_codes, speech_codes_lens = self.speech_tokenizer.quantize( - audio_features.transpose(1, 2), audio_features_lengths) - for i in range(audio_features.size(0)): - b = batch_idx[i] - s, e = audio_offsets[i], audio_offsets[i] + speech_codes_lens[i] - ids = speech_codes[ - i, :speech_codes_lens[i]] + self.speech_code_start_idx - input_ids[b, s:e] = ids - labels[b, s:e] = ids + if audio_features is not None: + speech_codes, speech_codes_lens = self.speech_tokenizer.quantize( + audio_features.transpose(1, 2), audio_features_lengths) + for i in range(audio_features.size(0)): + b = batch_idx[i] + s, e = audio_offsets[i], audio_offsets[i] + speech_codes_lens[i] + ids = speech_codes[ + i, :speech_codes_lens[i]] + self.speech_code_start_idx + input_ids[b, s:e] = ids + labels[b, s:e] = ids + text_embs = self.llm.get_input_embeddings()(input_ids) if inputs_embeds is None: return text_embs, labels else: # replace speech token emb - for i in range(audio_features.size(0)): - b = batch_idx[i] - s, e = audio_offsets[i], audio_offsets[i] + speech_codes_lens[i] - inputs_embeds[b, s:e] = text_embs[b, s:e] + if audio_features is not None: + for i in range(audio_features.size(0)): + b = batch_idx[i] + s, e = audio_offsets[i], audio_offsets[i] + speech_codes_lens[i] + inputs_embeds[b, s:e] = text_embs[b, s:e] return inputs_embeds, labels @torch.autocast(device_type="cuda", dtype=torch.bfloat16) From 326cb17f0b5fb022542884a984a03b78f03faf0d Mon Sep 17 00:00:00 2001 From: "hao.yin" Date: Mon, 30 Mar 2026 09:28:02 +0000 Subject: [PATCH 2/3] [touch_tts] add spk_prompt_wav_map.example.json --- examples/libritts/tts/conf/spk_prompt_wav_map.example.json | 4 ++++ 1 file changed, 4 insertions(+) create mode 100644 examples/libritts/tts/conf/spk_prompt_wav_map.example.json diff --git a/examples/libritts/tts/conf/spk_prompt_wav_map.example.json b/examples/libritts/tts/conf/spk_prompt_wav_map.example.json new file mode 100644 index 0000000..d0f0a93 --- /dev/null +++ b/examples/libritts/tts/conf/spk_prompt_wav_map.example.json @@ -0,0 +1,4 @@ +{ + "speaker_id_1": "/abs/path/to/speaker_id_1_prompt.wav", + "speaker_id_2": "/abs/path/to/speaker_id_2_prompt.wav" +} From 48ccb947c6f21010d3866dad8457a6879468394f Mon Sep 17 00:00:00 2001 From: "hao.yin" Date: Tue, 31 Mar 2026 06:42:51 +0000 Subject: [PATCH 3/3] [touch_tts] tts sft mode just support raw jsonl --- west/dataset/dataset.py | 67 +++-------------------------------------- 1 file changed, 4 insertions(+), 63 deletions(-) diff --git a/west/dataset/dataset.py b/west/dataset/dataset.py index 4cf317c..bbff337 100644 --- a/west/dataset/dataset.py +++ b/west/dataset/dataset.py @@ -18,24 +18,6 @@ from west.dataset.extractor import Extractor -def _tar_sample_get_field(sample: dict, name: str): - """Resolve webdataset/tar field by exact key or suffix ``.``. - - Examples: ``txt`` / ``wav.txt`` -> ``txt``; ``wav`` / ``clip.wav`` -> ``wav``. - Prefers exact key ``name``; otherwise first key (sorted) ending with ``.``. - Skips dunder keys like ``__key__``. - """ - if name in sample: - return sample[name] - suf = '.' + name - for k in sorted(sample.keys()): - if not isinstance(k, str) or k.startswith('__'): - continue - if k.endswith(suf): - return sample[k] - return None - - @dataclass class DataArguments: data_path: str = field(default=None, @@ -44,7 +26,7 @@ class DataArguments: default=None, metadata={ "help": - "TouchFlow SFT: JSON mapping spk_id -> prompt wav for mel_speaker. " + "TouchFlow SFT: JSON mapping spk_id -> prompt wav for mel_speaker." }) batch_size: int = field(default=1, metadata={"help": "batch size"}) pack_size: int = field( @@ -126,50 +108,9 @@ def _read_one(self): data = wds.tarfile_samples(src) for x in data: try: - txt_val = _tar_sample_get_field(x, 'txt') - wav_val = _tar_sample_get_field(x, 'wav') - if txt_val is None or wav_val is None: - logging.warning( - 'Dataset tar sample missing required txt ' - 'or wav, skip. url=%s keys=%s', - line, list(x.keys())) - continue - out = {} - if isinstance(txt_val, bytes): - out['txt'] = txt_val.decode('utf8') - elif isinstance(txt_val, str): - out['txt'] = txt_val - else: - logging.warning( - 'Dataset tar txt must be bytes or str, ' - 'got %s, url=%s', - type(txt_val), line) - continue - if not isinstance(wav_val, bytes): - logging.warning( - 'Dataset tar wav must be bytes, ' - 'got %s, url=%s', - type(wav_val), line) - continue - out['wav'] = io.BytesIO(wav_val) - # for sft mode, spk & instruction are optional - for opt_key in ('spk', 'ins'): - v = _tar_sample_get_field(x, opt_key) - if v is None: - continue - if isinstance(v, bytes): - out[opt_key] = v.decode('utf8') - elif isinstance(v, str): - out[opt_key] = v - else: - logging.warning( - 'Dataset tar %s must be bytes or str, ' - 'got %s, url=%s', - opt_key, type(v), line) - for meta in ('__key__', '__url__'): - if meta in x: - out[meta] = x[meta] - yield out + x['txt'] = x['txt'].decode('utf8') + x['wav'] = io.BytesIO(x['wav']) + yield x except Exception as e: logging.info(f'Dataset decode error, {line}, {e}') continue