Fix MoE expert FSDP mesh info for HSDP#9
Conversation
Reviewer's GuideAdjusts MoE expert FSDP sharding for HSDP by using HSDPMeshInfo for 2D meshes, adds mesh-dimension name handling, and aligns expert and data-parallel mesh info with fully_shard()’s expected 1D/2D behavior. File-Level Changes
Tips and commandsInteracting with Sourcery
Customizing Your ExperienceAccess your dashboard to:
Getting Help
|
There was a problem hiding this comment.
Hey - I've left some high level feedback:
- The return type annotation of
_dp_mesh_infoisFSDPMeshInfobut it can returnHSDPMeshInfofor 2D meshes; consider broadening or adjusting the type hint to match the actual return types. - Given
_dp_mesh_infois used for bothedp_mesh_infoanddp_mesh_info, consider renaming it (and its docstring/comment, if any) to something more general thandp_to better reflect its purpose.
Prompt for AI Agents
Please address the comments from this code review:
## Overall Comments
- The return type annotation of `_dp_mesh_info` is `FSDPMeshInfo` but it can return `HSDPMeshInfo` for 2D meshes; consider broadening or adjusting the type hint to match the actual return types.
- Given `_dp_mesh_info` is used for both `edp_mesh_info` and `dp_mesh_info`, consider renaming it (and its docstring/comment, if any) to something more general than `dp_` to better reflect its purpose.Help me be more useful! Please click 👍 or 👎 on each comment and I'll use the feedback to improve your reviews.
|
Folded into #12 alongside #10 and the expert-side pieces of #11. Your HSDPMeshInfo helper is in commit |
|
FYI — pulled in upstream/main onto |
…ch#3308 Replays the pytorch#3308 ("Remove MoE expert for-loop fallback") deletion onto experiments/ezpz/moe/. That upstream PR removed the `use_grouped_mm` config field from `GroupedExperts.Config` and inlined `torch._grouped_mm` as the only expert path. Upstream's argument is that `_grouped_mm` already provides a CUDA fallback on pre-SM90 hardware. **XPU has no `_grouped_mm` kernel at all**, so the unconditional path breaks ezpz on Aurora / Sunspot. - `experiments/ezpz/moe/experts.py` (new): `EzpzGroupedExperts` subclass with `compute_backend: Literal["for_loop", "grouped_mm"]`. Default defers to upstream; `"for_loop"` re-vendors the per-expert matmul loop that pytorch#3308 deleted, restoring the XPU / pre-SM90 path. - `experiments/ezpz/moe/__init__.py`: `make_ezpz_experts_config(...)` wrapper that calls upstream's `make_experts_config(...)` and re-wraps as `EzpzGroupedExperts.Config`. `_build_moe_layers` threads a `compute_backend` kwarg (default `"grouped_mm"`). - `experiments/ezpz/moe/model.py`: `update_from_config` previously mutated `experts.use_grouped_mm = False` on pre-SM90 devices; that field no longer exists. Replaced with `experts_cfg.compute_backend = "for_loop"` guarded by a `getattr(..., "compute_backend", "grouped_mm")` so the block is robust to future config-shape changes. Default behavior is unchanged for SM90+ CUDA: `compute_backend` defaults to `"grouped_mm"`, which calls `super()._experts_forward(...)` and hits the upstream path. The for_loop branch is only taken on devices without `_grouped_mm`. Pending PRs saforem2#9 / saforem2#10 / saforem2#11 will need to rebase onto this and adapt to the `EzpzGroupedExperts` subclass — see docs/upstream-sync.md for notes on what each will need to do.
Zero importers in the codebase. Verified with
grep -rn "EzpzChunkedCELoss\|ezpz.loss" torchtitan/
which returned only the definition itself.
The class adds three Config knobs over upstream ChunkedCELoss
(empty_cache_between_chunks, keep_lm_head_unsharded_between_chunks,
sync_replicated_lm_head_grad) plus a ~150-line __call__ override.
Since no model spec / train.py / trainer.py instantiates it, the
file is 201 lines of inert code that adds sync burden against
upstream ChunkedCELoss changes (e.g. the FSDP set_reshard_after_forward
plumbing) without any consumer benefit.
Re-add as a thin subclass when there's a real consumer in
train.py / a model registry. The three knobs likely belong on
upstream ChunkedCELoss.Config directly anyway — they're not
MoE/SFT-specific.
Addresses review finding saforem2#9 on saforem2#14.
Summary
Fixes the custom MoE expert FSDP sharding path for HSDP meshes.
The previous code constructed
FSDPMeshInfofor both 1D and 2D meshes. For 2D HSDP meshes, this can produce DTensor specs whose mesh has both replicate and shard dimensions but whose placements only describe one axis.This changes the custom placement path to use
HSDPMeshInfofor 2D meshes and preservesFSDPMeshInfofor 1D meshes.Validation
git diff --cached --check.venv/bin/python -m py_compile torchtitan/experiments/ezpz/moe/parallelize.pySummary by Sourcery
Bug Fixes: