npu gemm patch#176
Merged
tastelikefeet merged 18 commits into modelscope:main (Apr 25, 2026)
Conversation
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
Contributor
Code Review
This pull request introduces NPU support for FSDP2 MoE models by implementing a monkey patch for the _grouped_mm function using torch_npu. It also adds a utility to detect NPU availability and a shell script for running the model on Ascend hardware. Key feedback points include addressing a missing dist import and an unused rank variable in fsdp2_moe.py, removing a redundant int64 conversion, and simplifying the logic for calculating group counts in the NPU patch.
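The NPU-availability utility mentioned above could look roughly like the following. This is a minimal sketch, not the PR's actual code: the helper name is_torch_npu_available and the exact checks are assumptions.

```python
import importlib.util


def is_torch_npu_available() -> bool:
    """Best-effort check for Ascend NPU support (hypothetical helper).

    Returns True only when both torch and the torch_npu extension are
    installed and an NPU device is actually visible to this process.
    """
    # Bail out early so machines without torch/torch_npu never import them.
    if importlib.util.find_spec("torch") is None:
        return False
    if importlib.util.find_spec("torch_npu") is None:
        return False
    import torch
    import torch_npu  # noqa: F401  (importing registers the "npu" backend)
    return hasattr(torch, "npu") and torch.npu.is_available()
```

On a CPU-only machine this returns False without touching any NPU APIs, so the same code path can run unchanged on every backend.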
Removed Chinese comments and unnecessary code comments for clarity.
tardis-key
reviewed
Apr 22, 2026
lgtm
tastelikefeet
approved these changes
Apr 25, 2026
no_gemm_change.log
gemm_change.log
PR type
PR information
This PR adds an NPU-only monkey patch for the MoE grouped GEMM path used by fsdp2_moe.py. When running on Ascend/NPU, the patch replaces transformers.integrations.moe._grouped_mm with an NPU implementation backed by the torch_npu grouped matmul. The goal is to improve MoE training performance on NPU while keeping numerical behavior aligned with the original implementation.
What is changed
- Replaces transformers.integrations.moe._grouped_mm with _grouped_mm_npu on NPU
- Runs fsdp2_moe.py with the NPU patch enabled
Motivation
For MoE training on Ascend/NPU, the default grouped GEMM path is not optimal for NPU execution. This PR introduces a minimal NPU-specific patch so that the MoE grouped matmul path can use the native NPU grouped GEMM kernel, improving training performance while keeping numerical behavior aligned.
Scope
This PR is intended for the NPU/Ascend path only.
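To make the patched operation concrete, here is a dependency-light reference of the grouped matmul semantics in NumPy. This is only an illustration of what _grouped_mm computes: the argument layout is an assumption, and the real _grouped_mm_npu would instead dispatch to a torch_npu grouped-matmul kernel (whose signature varies by torch_npu version, so it is not reproduced here).

```python
import numpy as np


def grouped_mm_reference(x, w, offs):
    """Reference grouped matmul: rows of x are grouped by expert.

    x:    (num_tokens, in_dim), tokens already sorted by expert
    w:    (num_experts, in_dim, out_dim), one weight matrix per expert
    offs: cumulative end offset of each expert's token group
    """
    out = np.empty((x.shape[0], w.shape[2]), dtype=x.dtype)
    start = 0
    for expert, end in enumerate(offs):
        # Each expert multiplies only its own contiguous slice of tokens.
        out[start:end] = x[start:end] @ w[expert]
        start = end
    return out


# On an NPU machine, the PR's patch is applied roughly like this
# (illustrative; module path taken from the PR description):
#
#   import transformers.integrations.moe as moe
#   moe._grouped_mm = _grouped_mm_npu
```

The monkey patch only swaps the module-level attribute, so callers that resolve _grouped_mm through the module keep working unchanged.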
Experiment results
Accuracy alignment
The patched run is numerically aligned with the original run in the compared logs.
Checked steps show identical values for both loss and grad_norm. Examples:
- loss=11.7626, grad_norm=2.818920
- loss=11.8967, grad_norm=3.034608
- loss=11.2776, grad_norm=4.786496
- loss=10.4782, grad_norm=4.824494
This indicates that the patched NPU grouped GEMM path is numerically aligned with the original implementation for this experiment.
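The per-step alignment check described above can be automated with a small log parser. This is a sketch; the exact line format in no_gemm_change.log and gemm_change.log may differ from the pattern assumed here.

```python
import re

# Assumes training lines contain "loss=<float>" followed by "grad_norm=<float>".
_METRIC_RE = re.compile(r"loss=([\d.]+).*?grad_norm=([\d.]+)")


def extract_metrics(log_text):
    """Return a list of (loss, grad_norm) pairs, one per matched step line."""
    return [
        (float(m.group(1)), float(m.group(2)))
        for m in _METRIC_RE.finditer(log_text)
    ]


def logs_aligned(baseline_text, patched_text, atol=0.0):
    """True when both logs have the same steps with metrics within atol."""
    base = extract_metrics(baseline_text)
    patched = extract_metrics(patched_text)
    if not base or len(base) != len(patched):
        return False
    return all(
        abs(bl - pl) <= atol and abs(bg - pg) <= atol
        for (bl, bg), (pl, pg) in zip(base, patched)
    )
```

With atol=0.0 this enforces the bit-identical loss and grad_norm values reported above; a small positive atol would allow for expected floating-point drift.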
Performance improvement
The patched run logs:
[PATCH] transformers.integrations.moe._grouped_mm -> _grouped_mm_npu
After excluding the first warmup step, this gives an approximate 3.77x speedup on the tested run.
A few concrete step time examples:
Baseline: 1.5999, 2.2731, 1.6419, 1.6620
Patched: 0.4250, 0.9911, 0.4147, 0.4503
The first step is slower in the patched run due to initialization and warmup overhead, but the steady-state step time is significantly lower.
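The steady-state comparison can be reproduced with a small helper that drops the warmup step before averaging. Note this is a sketch: the reported 3.77x figure comes from the full logs, while the four sample step times above yield a smaller ratio on their own.

```python
def mean_step_time(step_times, skip_warmup=1):
    """Average step time after dropping the first skip_warmup warmup steps."""
    steady = step_times[skip_warmup:]
    return sum(steady) / len(steady)


def speedup(baseline_times, patched_times, skip_warmup=1):
    """Ratio of steady-state baseline step time to patched step time."""
    return mean_step_time(baseline_times, skip_warmup) / mean_step_time(
        patched_times, skip_warmup
    )


# Applied to the four sample step times quoted above:
baseline = [1.5999, 2.2731, 1.6419, 1.6620]
patched = [0.4250, 0.9911, 0.4147, 0.4503]
sample_speedup = speedup(baseline, patched)
```

Averaging over all steady-state steps of both logs, rather than these four samples, is what produces the 3.77x figure quoted in this PR.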
Validation
Validation for this PR is based on:
- Comparing loss and grad_norm on matched steps in the compared logs
- Comparing step_time between the baseline and patched runs
This PR focuses on the NPU MoE grouped GEMM path only. It does not change the default behavior for non-NPU backends.