Random MLM replacement can sample special tokens #28

@sbauer0

Hi, I noticed a small issue in the MLM masking logic.

In methylbert/src/data/vocab.py, the vocabulary is constructed with the special tokens prepended before the actual k-mer tokens:

special_tokens = ["<pad>", "<unk>", "<eos>", "<sos>", "<mask>"]

self.pad_index = 0
self.unk_index = 1
self.eos_index = 2
self.sos_index = 3
self.mask_index = 4

self.itos = list(special_tokens) + vocabs
self.stoi = {t: i for i, t in enumerate(self.itos)}

So the first normal k-mer token starts at index 5.
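
To double-check the layout, here is a minimal standalone sketch of the same construction. The `vocabs` list is a made-up stand-in (plain 3-mers over ACGT), not the actual methylbert vocabulary:

```python
from itertools import product

special_tokens = ["<pad>", "<unk>", "<eos>", "<sos>", "<mask>"]

# Hypothetical k-mer list for illustration only; the real vocabulary
# in methylbert may contain different tokens.
vocabs = ["".join(p) for p in product("ACGT", repeat=3)]

itos = list(special_tokens) + vocabs
stoi = {t: i for i, t in enumerate(itos)}

# Index of the first non-special token.
first_kmer_index = stoi[vocabs[0]]
print(stoi["<mask>"], first_kmer_index, len(itos))  # prints: 4 5 69
```

With this stand-in, IDs 0–4 are the special tokens and every k-mer sits at index 5 or above.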

In methylbert/src/data/dataset.py, lines 192–195, the 10% random replacement step samples from the full vocabulary:

# 10% of the time, we replace masked input tokens with random word
indices_random = torch.bernoulli(torch.full(labels.shape, 0.5)).bool() & masked_indices & ~indices_replaced
random_words = torch.randint(len(self.vocab), labels.shape, dtype=torch.int16)
inputs[indices_random] = random_words[indices_random]

Because torch.randint(len(self.vocab), ...) samples uniformly from all token IDs in [0, len(self.vocab)), it can also return the special-token IDs 0–4, so a masked k-mer may be randomly replaced with <pad>, <unk>, <eos>, <sos>, or <mask>.
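
For a sense of scale: with uniform sampling, the chance that a single random replacement lands on a special token is 5 / len(self.vocab). A quick sketch of that arithmetic follows. The vocabulary sizes are assumptions on my part (5 special tokens plus 4^k plain k-mers), not the real methylbert sizes:

```python
NUM_SPECIAL_TOKENS = 5  # <pad>, <unk>, <eos>, <sos>, <mask>


def special_token_share(vocab_size: int) -> float:
    """Probability that a uniform draw from [0, vocab_size) is a special-token ID."""
    return NUM_SPECIAL_TOKENS / vocab_size


# Assumed vocabulary sizes: 5 special tokens + 4**k k-mers over ACGT.
for k in (3, 4, 6):
    vocab_size = NUM_SPECIAL_TOKENS + 4 ** k
    print(k, vocab_size, round(special_token_share(vocab_size), 4))
```

Under these assumptions the share is roughly 7% for a 69-token vocabulary and shrinks as the vocabulary grows, so the effect is small but largest for short k-mers.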

I think the random replacement step should probably sample only from the actual k-mer vocabulary and exclude special tokens. For example:

num_special_tokens = 5

random_words = torch.randint(
    low=num_special_tokens,
    high=len(self.vocab),
    size=labels.shape,
    dtype=torch.int16,
)
inputs[indices_random] = random_words[indices_random]
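
A slightly more defensive variant could take the offset as a parameter instead of hard-coding 5 inline. This is only a sketch: `sample_random_kmers` is a hypothetical helper that does not exist in the repo, and it assumes the special tokens always occupy the lowest IDs, as in vocab.py above.

```python
import torch


def sample_random_kmers(vocab_size, num_special_tokens, shape, dtype=torch.long):
    """Draw random token IDs from [num_special_tokens, vocab_size)."""
    if not 0 <= num_special_tokens < vocab_size:
        raise ValueError("vocabulary has no non-special tokens to sample from")
    return torch.randint(
        low=num_special_tokens,
        high=vocab_size,
        size=shape,
        dtype=dtype,
    )


# Toy check with assumed sizes (5 special tokens, 69 tokens in total).
random_words = sample_random_kmers(69, 5, (8, 128), dtype=torch.int16)
assert random_words.min().item() >= 5
assert random_words.max().item() < 69
```

In dataset.py the call would presumably pass len(self.vocab) and the number of special tokens, with the dtype chosen to match inputs.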

Happy to open a PR if this looks right.
