
Add support for flexible MTP layer architecture in MaxText.#3734

Draft
parambole wants to merge 1 commit into main from parambole/hybrid_mtp

Conversation

@parambole
Collaborator

This change allows specifying a different decoder layer type for Multi-Token Prediction (MTP) blocks than the one used in the base model.

  • Modified `get_decoder_layers` in `decoders.py` and `nnx_decoders.py` to accept an optional `decoder_block_type` parameter.
  • Updated `models.py` to read `mtp_decoder_type` from the config and use it to select the decoder layer for MTP blocks, falling back to the base model's last layer type when unset.
  • Added `mtp_decoder_type` to `base.yml` and `types.py`.
  • Added a unit test, `FlexibleMultiTokenPredictionBlockTest`, to `multi_token_prediction_test.py` to verify this behavior.
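The selection-and-fallback logic described above can be sketched as follows. This is a hypothetical illustration, not the actual MaxText code: the function name `select_mtp_layer` and the `layer_lookup` mapping are invented for the example, while `mtp_decoder_type` and the fallback-to-last-layer behavior come from the PR description.

```python
def select_mtp_layer(base_layers, mtp_decoder_type=None, layer_lookup=None):
    """Pick the decoder layer class for MTP blocks.

    Falls back to the last layer of the base model's decoder stack when
    no explicit mtp_decoder_type is configured (the PR's default behavior).
    """
    if mtp_decoder_type is None:
        return base_layers[-1]
    return layer_lookup[mtp_decoder_type]


# Example: a Llama base stack with a Gemma2 MTP block (strings stand in
# for the real decoder layer classes).
layer_lookup = {"llama2": "Llama2DecoderLayer", "gemma2": "Gemma2DecoderLayer"}
base_layers = [layer_lookup["llama2"]]

print(select_mtp_layer(base_layers))                          # fallback: base layer
print(select_mtp_layer(base_layers, "gemma2", layer_lookup))  # explicit override
```

The key design point is that omitting `mtp_decoder_type` reproduces the pre-PR behavior, so existing configs are unaffected.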

Verified on TPU VM with Llama base + Gemma2 MTP and Mixtral base + Llama MTP.
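Assuming `base.yml` keeps its flat key style, the new knob might be set like this (a hypothetical fragment; only the key name `mtp_decoder_type` comes from the PR, the values and the neighboring key are illustrative):

```yaml
# Base model uses one decoder block type; MTP blocks use another.
# Omit mtp_decoder_type to reuse the base model's layers (default).
decoder_block: "llama2"
mtp_decoder_type: "gemma2"
```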


@codecov

codecov Bot commented Apr 23, 2026

Codecov Report

❌ Patch coverage is 36.00000% with 16 lines in your changes missing coverage. Please review.

| Files with missing lines | Patch % | Lines |
| --- | --- | --- |
| src/maxtext/models/models.py | 17.64% | 14 Missing ⚠️ |
| src/maxtext/layers/nnx_decoders.py | 60.00% | 1 Missing and 1 partial ⚠️ |

