Fix MoE expert FSDP mesh info for HSDP #9

Open: samuelwheeler wants to merge 2 commits into saforem2:ezpz from samuelwheeler:fix/moe-hsdp-mesh-info

@samuelwheeler samuelwheeler commented May 11, 2026


Summary

Fixes the custom MoE expert FSDP sharding path for HSDP meshes.

The previous code constructed FSDPMeshInfo for both 1D and 2D meshes. For 2D HSDP meshes, this can produce DTensor specs whose mesh has both replicate and shard dimensions but whose placements only describe one axis.

This changes the custom placement path to use HSDPMeshInfo for 2D meshes and preserves FSDPMeshInfo for 1D meshes.
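The mismatch described above can be illustrated with a minimal, torch-free sketch. `ToyMesh` and the placement strings are hypothetical stand-ins, not the real `DeviceMesh` or DTensor placement objects, and the replicate-then-shard ordering is an assumption made for illustration. The only point is that a spec needs one placement per mesh axis.

```python
from dataclasses import dataclass
from typing import Tuple


@dataclass(frozen=True)
class ToyMesh:
    """Hypothetical stand-in for a DeviceMesh: only carries dimension names."""
    dim_names: Tuple[str, ...]

    @property
    def ndim(self) -> int:
        return len(self.dim_names)


def fsdp_style_placements() -> Tuple[str, ...]:
    # FSDP-style info describes the shard axis and nothing else.
    return ("Shard(0)",)


def hsdp_style_placements() -> Tuple[str, ...]:
    # HSDP-style info describes both axes (assumed order: replicate, shard).
    return ("Replicate()", "Shard(0)")


def placements_match_mesh(mesh: ToyMesh, placements: Tuple[str, ...]) -> bool:
    # A consistent spec has exactly one placement per mesh dimension.
    return len(placements) == mesh.ndim


mesh_1d = ToyMesh(("fsdp",))
mesh_2d = ToyMesh(("dp_replicate", "fsdp"))

ok_1d = placements_match_mesh(mesh_1d, fsdp_style_placements())
bad_2d = placements_match_mesh(mesh_2d, fsdp_style_placements())
ok_2d = placements_match_mesh(mesh_2d, hsdp_style_placements())
```

In this toy model the FSDP-style spec is consistent on the 1D mesh but not on the 2D mesh, while the HSDP-style spec covers both axes of the 2D mesh.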

Validation

  • git diff --cached --check
  • .venv/bin/python -m py_compile torchtitan/experiments/ezpz/moe/parallelize.py
  • Validated in Aurora EP/HSDP MoE runs while debugging EP setup.

Summary by Sourcery

Bug Fixes:

  • Ensure 2D HSDP meshes use HSDPMeshInfo so DTensor placements include both shard and replicate axes instead of an incomplete single-axis spec.

sourcery-ai Bot commented May 11, 2026


Reviewer's Guide

Adjusts MoE expert FSDP sharding for HSDP by using HSDPMeshInfo for 2D meshes, adds mesh-dimension name handling, and aligns expert and data-parallel mesh info with fully_shard()’s expected 1D/2D behavior.

File-Level Changes

Change: Introduce helper utilities to derive FSDP/HSDP mesh info from DeviceMesh with dimension-name validation.
Details:
  • Add import of HSDPMeshInfo alongside FSDPMeshInfo and ShardPlacementResult.
  • Implement _mesh_dim() helper that resolves a named mesh dimension, enforces dim names for 2D meshes, and raises clear errors when names are missing or mismatched.
  • Implement _dp_mesh_info() helper that returns FSDPMeshInfo for 1D meshes and HSDPMeshInfo (with both shard and replicate dims) for 2D meshes, rejecting meshes with ndim other than 1 or 2.
Files: torchtitan/experiments/ezpz/moe/parallelize.py

Change: Align expert and data-parallel mesh info construction with fully_shard() behavior for 1D and 2D meshes.
Details:
  • Replace direct construction of FSDPMeshInfo with shard_mesh_dim=0 for edp_mesh and dp_mesh with calls to the new _dp_mesh_info() helper.
  • Configure edp_mesh_info to shard on dim name 'efsdp' and dp_mesh_info to shard on dim name 'fsdp', and use 'dp_replicate' as the replicate dimension for 2D meshes.
  • Document in comments that 2D meshes must use HSDPMeshInfo so DTensor placements include both replicate and shard axes, avoiding inconsistent 2D DeviceMesh placements.
Files: torchtitan/experiments/ezpz/moe/parallelize.py
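
The helper behavior summarized above can be sketched roughly as follows. This is not the PR's code. `ToyMesh`, `ToyFSDPMeshInfo`, `ToyHSDPMeshInfo`, `mesh_dim` and `mesh_info` are hypothetical stand-ins written so the dispatch runs without torch or a process group. The field names `shard_mesh_dim` and `replicate_mesh_dim` are assumed to mirror the real mesh info classes.

```python
from dataclasses import dataclass
from typing import Optional, Tuple, Union


@dataclass(frozen=True)
class ToyMesh:
    """Hypothetical stand-in for DeviceMesh: a shape plus optional dim names."""
    shape: Tuple[int, ...]
    dim_names: Optional[Tuple[str, ...]] = None

    @property
    def ndim(self) -> int:
        return len(self.shape)


@dataclass(frozen=True)
class ToyFSDPMeshInfo:
    """Stand-in for FSDPMeshInfo: records the shard axis only."""
    mesh: ToyMesh
    shard_mesh_dim: int


@dataclass(frozen=True)
class ToyHSDPMeshInfo:
    """Stand-in for HSDPMeshInfo: records shard and replicate axes."""
    mesh: ToyMesh
    shard_mesh_dim: int
    replicate_mesh_dim: int


def mesh_dim(mesh: ToyMesh, name: str) -> int:
    """Resolve a named mesh dimension to its index, failing loudly."""
    if mesh.dim_names is None:
        raise ValueError(f"{mesh.ndim}D mesh has no dim names; cannot resolve {name!r}")
    if name not in mesh.dim_names:
        raise ValueError(f"mesh dims {mesh.dim_names} do not include {name!r}")
    return mesh.dim_names.index(name)


def mesh_info(
    mesh: ToyMesh, shard_name: str, replicate_name: str = "dp_replicate"
) -> Union[ToyFSDPMeshInfo, ToyHSDPMeshInfo]:
    """Return FSDP-style info for 1D meshes and HSDP-style info for 2D meshes."""
    if mesh.ndim == 1:
        # 1D case keeps the original behavior: shard on the only axis.
        return ToyFSDPMeshInfo(mesh, shard_mesh_dim=0)
    if mesh.ndim == 2:
        return ToyHSDPMeshInfo(
            mesh,
            shard_mesh_dim=mesh_dim(mesh, shard_name),
            replicate_mesh_dim=mesh_dim(mesh, replicate_name),
        )
    raise ValueError(f"expected a 1D or 2D mesh, got ndim={mesh.ndim}")


dp_info = mesh_info(ToyMesh((8,), ("fsdp",)), "fsdp")
edp_info = mesh_info(ToyMesh((2, 4), ("dp_replicate", "efsdp")), "efsdp")
```

The return annotation is a union of both info types, since a single helper serves the 1D and 2D cases.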

@sourcery-ai sourcery-ai Bot left a comment


Hey - I've left some high-level feedback:

  • The return type annotation of _dp_mesh_info is FSDPMeshInfo but it can return HSDPMeshInfo for 2D meshes; consider broadening or adjusting the type hint to match the actual return types.
  • Given _dp_mesh_info is used for both edp_mesh_info and dp_mesh_info, consider renaming it (and its docstring/comment, if any) to something more general than dp_ to better reflect its purpose.

@saforem2

Owner

Folded into #12 alongside #10 and the expert-side pieces of #11. Your HSDPMeshInfo helper is in commit 09a7c697 (with the 1D/2D dispatch + axis-name lookups intact), wrapped in #11's import-fallback so XPU wheels lacking torch.distributed.fsdp._fully_shard._fsdp_common get a two-phase fully_shard instead. Going to keep this open until #12 lands so we don't lose the diff.
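
An import-fallback of the kind mentioned above might look roughly like this sketch. Only the module path and the two class names come from the comment. `load_mesh_info_classes`, `choose_sharding_path` and the two path labels are hypothetical, and the sketch selects a path without implementing either one.

```python
import importlib
from typing import Optional, Tuple

# Private module path quoted in the comment above; it may be absent on some wheels.
_FSDP_COMMON = "torch.distributed.fsdp._fully_shard._fsdp_common"


def load_mesh_info_classes() -> Optional[Tuple[type, type]]:
    """Return (FSDPMeshInfo, HSDPMeshInfo) if importable, else None."""
    try:
        module = importlib.import_module(_FSDP_COMMON)
    except ImportError:
        # Covers a missing torch install and a missing private submodule.
        return None
    fsdp_cls = getattr(module, "FSDPMeshInfo", None)
    hsdp_cls = getattr(module, "HSDPMeshInfo", None)
    if fsdp_cls is None or hsdp_cls is None:
        return None
    return fsdp_cls, hsdp_cls


def choose_sharding_path() -> str:
    """Pick a path label (hypothetical names) based on what can be imported."""
    if load_mesh_info_classes() is not None:
        return "custom_placement"
    return "two_phase_fully_shard"


selected_path = choose_sharding_path()
```

Catching only `ImportError` keeps unrelated failures visible instead of silently switching paths.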

@saforem2

Owner

FYI: pulled in upstream/main onto ezpz via #13, which now incorporates pytorch#3308 (use_grouped_mm removal). Your branch should still rebase cleanly since #9 only touches experiments/ezpz/moe/parallelize.py, but it is worth a quick git rebase ezpz once #13 lands. I am withdrawing the prior #12 (which had bundled this in) and keeping #9 as the canonical home for the HSDP fix.

nscottnichols pushed a commit to nscottnichols/torchtitan that referenced this pull request Jun 8, 2026
…ch#3308

Replays the pytorch#3308 ("Remove MoE expert for-loop
fallback") deletion onto experiments/ezpz/moe/. That upstream PR
removed the `use_grouped_mm` config field from `GroupedExperts.Config`
and inlined `torch._grouped_mm` as the only expert path. Upstream's
argument is that `_grouped_mm` already provides a CUDA fallback on
pre-SM90 hardware. **XPU has no `_grouped_mm` kernel at all**, so the
unconditional path breaks ezpz on Aurora / Sunspot.

- `experiments/ezpz/moe/experts.py` (new): `EzpzGroupedExperts` subclass
  with `compute_backend: Literal["for_loop", "grouped_mm"]`. Default
  defers to upstream; `"for_loop"` re-vendors the per-expert matmul
  loop that pytorch#3308 deleted, restoring the XPU / pre-SM90 path.
- `experiments/ezpz/moe/__init__.py`: `make_ezpz_experts_config(...)`
  wrapper that calls upstream's `make_experts_config(...)` and
  re-wraps as `EzpzGroupedExperts.Config`. `_build_moe_layers` threads
  a `compute_backend` kwarg (default `"grouped_mm"`).
- `experiments/ezpz/moe/model.py`: `update_from_config` previously
  mutated `experts.use_grouped_mm = False` on pre-SM90 devices; that
  field no longer exists. Replaced with
  `experts_cfg.compute_backend = "for_loop"` guarded by a
  `getattr(..., "compute_backend", "grouped_mm")` so the block is
  robust to future config-shape changes.

Default behavior is unchanged for SM90+ CUDA: `compute_backend` defaults
to `"grouped_mm"`, which calls `super()._experts_forward(...)` and
hits the upstream path. The for_loop branch is only taken on devices
without `_grouped_mm`.
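
As a rough illustration of the backend switch described above, the following sketch dispatches on a `compute_backend` value and implements a per-expert loop in pure Python. It is not the `EzpzGroupedExperts` code. `matmul`, `experts_for_loop`, `experts_forward` and the `grouped_kernel` parameter are hypothetical, and tokens are assumed to be pre-sorted by expert with a per-expert count.

```python
from typing import Callable, List, Optional

Matrix = List[List[float]]


def matmul(a: Matrix, b: Matrix) -> Matrix:
    """Plain-Python matrix product, standing in for a tensor matmul."""
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]


def experts_for_loop(
    tokens: Matrix, weights: List[Matrix], tokens_per_expert: List[int]
) -> Matrix:
    """Apply each expert's weight to its contiguous slice of tokens."""
    out: Matrix = []
    start = 0
    for weight, count in zip(weights, tokens_per_expert):
        chunk = tokens[start:start + count]
        if chunk:  # experts that received no tokens are skipped
            out.extend(matmul(chunk, weight))
        start += count
    return out


def experts_forward(
    tokens: Matrix,
    weights: List[Matrix],
    tokens_per_expert: List[int],
    compute_backend: str = "grouped_mm",
    grouped_kernel: Optional[Callable[..., Matrix]] = None,
) -> Matrix:
    """Dispatch between the loop fallback and a caller-supplied grouped kernel."""
    if compute_backend == "for_loop":
        return experts_for_loop(tokens, weights, tokens_per_expert)
    if compute_backend == "grouped_mm":
        if grouped_kernel is None:
            raise RuntimeError("no grouped kernel available on this device")
        return grouped_kernel(tokens, weights, tokens_per_expert)
    raise ValueError(f"unknown compute_backend: {compute_backend!r}")


tokens = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]
weights = [
    [[1.0, 0.0], [0.0, 1.0]],  # expert 0: identity
    [[2.0, 0.0], [0.0, 2.0]],  # expert 1: scale by 2
]
result = experts_forward(tokens, weights, [1, 2], compute_backend="for_loop")
```

Here the first token goes through expert 0 unchanged and the remaining two are doubled by expert 1.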

Pending PRs saforem2#9 / saforem2#10 / saforem2#11 will need to rebase onto this and adapt to
the `EzpzGroupedExperts` subclass — see docs/upstream-sync.md for
notes on what each will need to do.
saforem2 added a commit to nscottnichols/torchtitan that referenced this pull request Jun 11, 2026
Zero importers in the codebase. Verified with
    grep -rn "EzpzChunkedCELoss\|ezpz.loss" torchtitan/
which returned only the definition itself.

The class adds three Config knobs over upstream ChunkedCELoss
(empty_cache_between_chunks, keep_lm_head_unsharded_between_chunks,
sync_replicated_lm_head_grad) plus a ~150-line __call__ override.
Since no model spec / train.py / trainer.py instantiates it, the
file is 201 lines of inert code that adds sync burden against
upstream ChunkedCELoss changes (e.g. the FSDP set_reshard_after_forward
plumbing) without any consumer benefit.

Re-add as a thin subclass when there's a real consumer in
train.py / a model registry. The three knobs likely belong on
upstream ChunkedCELoss.Config directly anyway — they're not
MoE/SFT-specific.

Addresses review finding saforem2#9 on saforem2#14.