FP8 Megablox for batch split #3770

Open

BirdsOfAFthr wants to merge 1 commit into main from amandaliang

Conversation

@BirdsOfAFthr (Collaborator) commented Apr 29, 2026

Description

(1) This update enables FP8 Megablox quantization support for DeepSeek batch split configurations.

When quantization is active, the following changes apply:

  • Kernel quantization: gmm kernels accept FP8 recipes (defined via the MaxText command line) in both the forward and backward passes.

  • gmm forward: the weight is quantized manually to bypass the explicit sharding error; the activation is quantized using qwix (see the sketch after this list).

  • gmm backward: gradients are quantized using qwix.
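
For intuition, here is a minimal sketch of a per-tensor FP8 quantization step like the manual weight quantization described above. This is not MaxText's actual helper; the function name, the per-tensor scaling choice, and the e4m3 dtype are assumptions for illustration:

```python
import jax.numpy as jnp

def quantize_fp8(x, dtype=jnp.float8_e4m3fn):
  """Hypothetical per-tensor FP8 quantization: scale the tensor so its max
  magnitude fits the FP8 representable range, cast, and return the scale
  for dequantization after the matmul."""
  amax = jnp.max(jnp.abs(x))
  scale = jnp.maximum(amax, 1e-12) / jnp.finfo(dtype).max  # avoid divide-by-zero
  return (x / scale).astype(dtype), scale
```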

(2) This change also enables merging the gating gmm kernels.

In the previous SwiGLU/GLU implementation, the gate projection and up projection were processed using two sequential gmm_fn calls. By concatenating these weights and processing them together, we effectively double the contiguous hidden dimension of the kernel. This is especially critical for FP8 with Expert Parallelism (EP), which shards along the contracting dimension. Because this sharding strategy inherently shrinks the local MLP hidden dimension on each device, the matrix multiplications can become small and bottlenecked by memory bandwidth. Merging $W_0$ and $W_1$ gives us a 2× increase in that local dimension, restoring arithmetic intensity and hardware utilization (see the sketch below).
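
To make the merging concrete, here is a minimal sketch using plain matmuls as stand-ins for the grouped gmm_fn calls; the shapes and variable names are illustrative, not MaxText's code:

```python
import jax
import jax.numpy as jnp

# Hypothetical shapes; d_ff is the per-device MLP hidden dim after EP sharding.
d_model, d_ff = 1024, 4096
k0, k1, k2 = jax.random.split(jax.random.PRNGKey(0), 3)
x = jax.random.normal(k0, (8, d_model))
w_gate = jax.random.normal(k1, (d_model, d_ff))  # W_0, gate projection
w_up = jax.random.normal(k2, (d_model, d_ff))    # W_1, up projection

# Before: two narrow matmuls (stand-ins for two sequential gmm_fn calls).
gate_ref, up_ref = x @ w_gate, x @ w_up

# After: concatenate the weights, issue a single 2x-wider matmul, then split.
w_merged = jnp.concatenate([w_gate, w_up], axis=-1)  # [d_model, 2 * d_ff]
gate, up = jnp.split(x @ w_merged, 2, axis=-1)
out = jax.nn.silu(gate) * up  # SwiGLU activation
```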

Tests

  • Verification: Validated via end-to-end (e2e) perf and convergence benchmarks.

  • Coverage: Unit tests will be added in a subsequent update.

Checklist

Before submitting this PR, please make sure (put X in square brackets):

  • I have performed a self-review of my code. For an optional AI review, add the gemini-review label.
  • I have necessary comments in my code, particularly in hard-to-understand areas.
  • I have run end-to-end tests and provided workload links above if applicable.
  • I have made or will make corresponding changes to the doc if needed, including adding new documentation pages to the relevant Table of Contents (toctree directive) as explained in our documentation.

@gobbleturk (Collaborator) commented:

Note: there is more general support in PR #3736.

@BirdsOfAFthr BirdsOfAFthr changed the title Support merging gating gmm kernels FP8 Megablox for batch split Apr 29, 2026
@shuningjin (Collaborator) left a comment


Thanks for the GMM FP8 support with careful manual quantization, and for bringing back the gating-kernel merging from PR #3199! I had a comment w.r.t. sharding, plus other minor suggestions.

Outdated comment threads:

  • src/maxtext/kernels/megablox/ops.py
  • src/maxtext/kernels/megablox/ops.py
  • src/maxtext/layers/quantizations.py
  • src/maxtext/models/deepseek_batchsplit.py
  Returns:
    The result of the grouped matrix multiplication.
  """
  if config.use_qwix_quantization:

maybe this condition is clearer?

`if config.quantization == "fp8_full":`

