Skip to content

[WIP][Common, PyTorch] Improve mHC to match DeepSeek's implementation#2953

Draft
kainzhong wants to merge 2 commits intoNVIDIA:mainfrom
kainzhong:feat/mhc_enhancement
Draft

[WIP][Common, PyTorch] Improve mHC to match DeepSeek's implementation#2953
kainzhong wants to merge 2 commits intoNVIDIA:mainfrom
kainzhong:feat/mhc_enhancement

Conversation

@kainzhong
Copy link
Copy Markdown
Collaborator

@kainzhong kainzhong commented May 1, 2026

Description

Some enhancement for mHC to better align with DeepSeek's tilelang implementation: https://github.com/deepseek-ai/TileKernels/tree/main/tile_kernels/mhc

Fixes # (issue)

Type of change

  • Documentation change (change only to the documentation, either a fix or a new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

  • Allow mhc_fused_projection to accept arguments with mixed dtype: x.dtype=bf16, phi.dtype=fp32, which matches DeepSeek's implementation
  • mhc_fused_projection now outputs fp32 regardless of the input dtype, matching DeepSeek's implementation
  • Add fuse_grad_x_acc optimization, which will reuse the same grad_x buffer to accumulate the initial mHC input x's gradient for mhc_fused_expand_combine, mhc_fused_aggregate and mhc_fused_projection
  • Support norm_weight for mhc_fused_projection, which would be equivalent to apply RMSNorm in the unfused manner with elementwise_affine=True, which would be the learnable per-element affine parameters for RMSNorm
  • Add main_grad optimization for Megatron-LM integration, which will accumulate the gradient of phi, alpha, beta (they are all supposed to be torch.nn.Parameter) to main_grad if such attribute exists.
  • [TODO]: add checkpoint recomputing fused kernel to match DeepSeek's implementation
  • [TODO]: add a fused projection + aggregate only kernel (no expand & combine path) for the last mHC layer, which seems to be also used for MTP. See function learned_output_contract in [dev] [DeepSeek-v4] Part 3: MTP support with mHC and new mHC contract Megatron-LM#4518

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

kainzhong and others added 2 commits May 1, 2026 20:26
@kainzhong kainzhong changed the title [Common, PyTorch] Enhancement for mHC to match DeepSeek's Tilelang im… [Common, PyTorch] Improve mHC to match DeepSeek's implementation May 1, 2026
@kainzhong kainzhong changed the title [Common, PyTorch] Improve mHC to match DeepSeek's implementation [WIP][Common, PyTorch] Improve mHC to match DeepSeek's implementation May 1, 2026

@staticmethod
def forward(ctx, f, bias, H_post, x, H_res, n, use_tf32=True):
def forward(ctx, f, bias, H_post, x, H_res, n, use_tf32=True, fuse_grad_x_acc=True):
Copy link
Copy Markdown
Collaborator Author

@kainzhong kainzhong May 1, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Make it default to False, and also add this to other two functions to make users aware

return grad_f, grad_bias, grad_H_post, grad_x, grad_H_res, None, None
if ctx.fuse_grad_x_acc:
# When fused x gradient accumulation is enabled, use fp32 for the accumulation buffer
x.untyped_storage().grad_x_acc = grad_x.to(torch.float32)
Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Check if this is overwriting anything that is already there

M = s * b

grad_x = torch.empty_like(x)
fuse_grad_x_acc = hasattr(x.untyped_storage(), "grad_x_acc")
Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Add fuse_grad_x_acc param in other functions as well


ctx.dtype = H.dtype
H = H.to(torch.float32)
ctx.alpha_main_grad = getattr(alpha, "main_grad", None)
Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

With FSDP this will break
Check linear.py for reference
Talk with Varun

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant