Random MLM replacement can sample special tokens
Hi, I noticed a small issue in the MLM masking logic.
In `methylbert/src/data/vocab.py`, the vocabulary is constructed with the special tokens prepended before the actual k-mer tokens:
```python
special_tokens = ["<pad>", "<unk>", "<eos>", "<sos>", "<mask>"]
self.pad_index = 0
self.unk_index = 1
self.eos_index = 2
self.sos_index = 3
self.mask_index = 4
self.itos = list(special_tokens) + vocabs
self.stoi = {t: i for i, t in enumerate(self.itos)}
```
So the first normal k-mer token starts at index 5.
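To make the index layout concrete, here is a minimal stand-alone sketch that mirrors the snippet above. `ToyVocab` is a hypothetical class, not the repository's vocabulary class, and the 64 3-mers built from `ACGT` are purely illustrative; the real k-mer list used by MethylBERT may differ.

```python
from itertools import product


class ToyVocab:
    """Hypothetical stand-in that mirrors the special-token layout in vocab.py."""

    def __init__(self, vocabs):
        special_tokens = ["<pad>", "<unk>", "<eos>", "<sos>", "<mask>"]
        self.pad_index = 0
        self.unk_index = 1
        self.eos_index = 2
        self.sos_index = 3
        self.mask_index = 4
        # Special tokens occupy IDs 0-4; k-mers follow from ID 5 onwards.
        self.itos = list(special_tokens) + vocabs
        self.stoi = {t: i for i, t in enumerate(self.itos)}

    def __len__(self):
        return len(self.itos)


# Illustrative 3-mers only: "AAA", "AAC", ..., "TTT" (64 tokens).
kmers = ["".join(p) for p in product("ACGT", repeat=3)]
vocab = ToyVocab(kmers)

print(len(vocab))         # 69
print(vocab.stoi["AAA"])  # 5
print(vocab.itos[:6])     # ['<pad>', '<unk>', '<eos>', '<sos>', '<mask>', 'AAA']
```

With this layout, any sampler that starts at ID 0 covers the five special tokens as well as the k-mers.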
In `methylbert/src/data/dataset.py`, lines 192–195, the 10% random replacement step samples from the full vocabulary:
```python
# 10% of the time, we replace masked input tokens with random word
indices_random = torch.bernoulli(torch.full(labels.shape, 0.5)).bool() & masked_indices & ~indices_replaced
random_words = torch.randint(len(self.vocab), labels.shape, dtype=torch.int16)
inputs[indices_random] = random_words[indices_random]
```
Because `torch.randint(len(self.vocab), ...)` draws from all token IDs, it can also return the special-token IDs 0–4, so a masked k-mer may be randomly replaced with `<pad>`, `<unk>`, `<eos>`, `<sos>`, or `<mask>`.
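A quick way to see how often this happens, assuming a hypothetical vocabulary of 69 IDs (5 special tokens plus 64 k-mers): `torch.randint(high, size)` draws uniformly from `[0, high)`, so about 5/69 ≈ 7.2% of draws land on a special-token ID. Since only about 10% of masked positions go through this branch, roughly 0.1 × 5/69 ≈ 0.72% of masked positions would end up as a special token. The share shrinks as the vocabulary grows, but it never reaches zero.

```python
import torch

torch.manual_seed(0)

VOCAB_SIZE = 69   # hypothetical: 5 special tokens + 64 k-mers
NUM_SPECIAL = 5   # IDs 0-4 are <pad>, <unk>, <eos>, <sos>, <mask>

# Same call pattern as the current dataset.py code: sample from the full ID range.
samples = torch.randint(VOCAB_SIZE, (100_000,), dtype=torch.int16)

# Fraction of draws that hit a special-token ID.
frac_special = (samples < NUM_SPECIAL).float().mean().item()
print(f"observed {frac_special:.4f}, expected {NUM_SPECIAL / VOCAB_SIZE:.4f}")
```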
I think the random replacement step should sample only from the actual k-mer vocabulary and exclude the special tokens. For example:
```python
num_special_tokens = 5  # IDs 0-4 are <pad>, <unk>, <eos>, <sos>, <mask>
random_words = torch.randint(
    low=num_special_tokens,
    high=len(self.vocab),
    size=labels.shape,
    dtype=torch.int16,
)
inputs[indices_random] = random_words[indices_random]
```
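For context, here is a minimal, self-contained sketch of the whole 80/10/10 masking step with the proposed fix applied. It is not the repository's code. The following are all assumptions made for illustration: `mask_tokens`, its parameters, the 15% masking rate, and `-100` as the ignored label (PyTorch's default `ignore_index` for `CrossEntropyLoss`). Padding is not handled. The random IDs are drawn with `dtype=inputs.dtype` so that the source and destination dtypes match in the masked assignment. Rather than hard-coding 5, `num_special_tokens` could be derived from the vocabulary (e.g. `self.vocab.mask_index + 1`), though that only holds while `<mask>` is the last special token.

```python
import torch


def mask_tokens(inputs, vocab_size, mask_index, num_special_tokens, mlm_probability=0.15):
    """Hypothetical BERT-style masking that never samples special tokens as random words."""
    inputs = inputs.clone()
    labels = inputs.clone()

    # Choose which positions take part in the MLM objective.
    masked_indices = torch.bernoulli(torch.full(labels.shape, mlm_probability)).bool()
    labels[~masked_indices] = -100  # assumed ignore value for the loss

    # 80% of masked positions -> <mask>
    indices_replaced = torch.bernoulli(torch.full(labels.shape, 0.8)).bool() & masked_indices
    inputs[indices_replaced] = mask_index

    # 10% of masked positions -> random k-mer (half of the remaining 20%)
    indices_random = (
        torch.bernoulli(torch.full(labels.shape, 0.5)).bool() & masked_indices & ~indices_replaced
    )
    random_words = torch.randint(
        low=num_special_tokens,  # skip IDs 0 .. num_special_tokens - 1
        high=vocab_size,
        size=labels.shape,
        dtype=inputs.dtype,
    )
    inputs[indices_random] = random_words[indices_random]

    # The remaining 10% of masked positions keep their original token.
    return inputs, labels


torch.manual_seed(0)
VOCAB_SIZE = 69  # hypothetical: 5 special tokens + 64 k-mers
tokens = torch.randint(5, VOCAB_SIZE, (64, 100))  # k-mer IDs only
masked, labels = mask_tokens(tokens, VOCAB_SIZE, mask_index=4, num_special_tokens=5)

# Every position is now either a k-mer (>= 5) or <mask> (4); IDs 0-3 never appear.
assert (masked >= 4).all()
```

With the original `torch.randint(vocab_size, ...)` call in place of the `low=`/`high=` version, the final check would be expected to fail occasionally, because the random branch could then emit IDs 0–3.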
Happy to open a PR if this looks right.