ezpz/moe: resync with upstream + adapt to PR #3308 (use_grouped_mm removal)#13
Conversation
…endency (pytorch#3242) Stack from [ghstack](https://github.com/ezyang/ghstack) (oldest at bottom): * pytorch#3236 * pytorch#3142 * __->__ pytorch#3242 vllm has this customized config parser registry support so we can plug in TorchTitan's config parser. Why we need this: - get rid of dependency on a HF format checkpoint folder when initializing. Don't implicitly depend on `config.json` as config source of truth Another changes in this PR: - remove the round-trip translation from torchtitan config -> vllm config -> torchtitan config. Using closure to bypass.
As titled
…pytorch#3315) AutoParallel's input_fn() only returned tokens, but Decoder.forward() also receives positions via extra_kwargs. This caused a mismatch between the number of graph placeholders (from tracing with tokens only) and the actual runtime args (which include positions), failing with "expected N arguments for placeholders but received N+1". Add positions to input_fn() and a matching input constraint so AutoParallel traces with both inputs. Authored by Claude.
…h#3311) ## Summary - Add a simple `log_timer` context manager to `common_utils.py` that measures wall-clock elapsed time and logs it to console (e.g. `trace_train_step took 0.043s`). - Apply `log_timer` to the `trace_train_step` call in `GraphTrainer._make_fx_forward_backward_step` to measure tracing time. ## Test plan - [x] Verify `log_timer` output appears in training logs during aot_fx_trace runs - [ ] Existing unit tests pass: `pytest torchtitan/experiments/graph_trainer/tests/ -x`
…rch#3270) Summary: The remat pass previously rebuilt the graph wholesale (fx.Graph() + node_copy of every node) and relied on whole-graph DCE to remove dead must_recompute originals. Refactor to mutate gm.graph in place: dups are inserted in front of their first backward consumer, backward args are redirected to the dups, and only originals whose users became empty are erased. Original node identities and names are preserved, the topological-order assumption is explicit (input graph order drives insertion, validated by gm.graph.lint() at the end), and the underlying function takes the standard (gm, example_inputs) graph pass signature. CPU-offload reload chains are handled by hoisting the chain in front of the earliest dup that needs it - the in-place equivalent of upstream's "eagerly copy reload chain into the new graph" trick. Authored with Claude. Test Plan: Unit tests: pytest torchtitan/experiments/graph_trainer/tests/test_passes.py -x -> 68 passed, 1 skipped (3 new in TestSelectiveActivationRematPass) End-to-end on 8xH100 / Llama3 8B / FSDP=4 + TP=2 / aot_fx_trace / no cudagraph, with --debug.seed=42 --debug.deterministic: <img width="1728" height="625" alt="Screenshot 2026-05-07 at 5 16 08 PM" src="https://github.com/user-attachments/assets/b28e3acf-626e-4014-b5e5-4ffa6a686f08" /> CPU offload using upstream remat pass <img width="1728" height="612" alt="Screenshot 2026-05-07 at 5 16 23 PM" src="https://github.com/user-attachments/assets/6c391927-7022-496d-bb9b-0e38e4808df8" /> CPU offload using our refactor Before: https://manifold.edge.x2p.facebook.net/v0/read/tree/logs/.tmpfmqvbv/index.html?bucketName=tlparse_reports&apiKey=tlparse_reports-key&withPayload=1&timeoutMsec=10000 After: https://manifold.edge.x2p.facebook.net/v0/read/tree/logs/.tmp9ImSg9/index.html?bucketName=tlparse_reports&apiKey=tlparse_reports-key&withPayload=1&timeoutMsec=10000
torch._grouped_mm already provides a CUDA fallback path when the fused grouped GEMM kernel is unavailable, including on pre-SM90 hardware. Keeping a separate Python for-loop expert implementation duplicates that fallback, carries an extra configuration branch, and makes MoE behavior diverge across models. Use the grouped-mm path unconditionally and rely on PyTorch to choose either the fused kernel or its built-in loopy fallback.
Code Move Only! ## Summary - Split the monolithic `passes.py` (~1037 lines) into focused modules, keeping `passes.py` as the orchestration layer (~400 lines): - **`memory_policy.py`** — SAC tagging, default/eager/offload policies, and `tag_with_memory_policy_pass` - **`inductor_passes.py`** — `regional_inductor_pass`, `full_inductor_compilation_pass`, and `annotate_flex_attention_for_regional_inductor_pass` - **`cudagraph.py`** — `cudagraph_pass` and `insert_kernel_annotations_pass` (appended to existing file) - **`registry.py`** — `MEMORY_POLICY_REGISTRY`, `PASS_PIPELINE_REGISTRY`, `POST_INIT_HOOKS`, `PRE_TRAIN_STEP_HOOKS` and their decorators (breaks circular dep between `passes.py` and `memory_policy.py`) - Code-move only — all function bodies are identical; the only diffs are removing local imports that became unnecessary when code moved to the same file - Updated `test_passes.py` imports to use new module paths ## Test plan - [ ] `ruff check --select F401` passes clean (no unused imports) - [ ] `pytest torchtitan/experiments/graph_trainer/tests/test_passes.py -x` - [ ] `pytest torchtitan/experiments/graph_trainer/tests/test_precompile.py -x` - [ ] `pre-commit run --all-files`
Stack from [ghstack](https://github.com/ezyang/ghstack) (oldest at bottom): * pytorch#3250 * pytorch#3248 * pytorch#3247 * pytorch#3246 * __->__ pytorch#3249 Add a ``support_autograd_grad`` opt-in flag to ``ChunkedCELoss.Config`` that exposes the lm_head parameter gradients as explicit autograd outputs of the returned loss tensor. Designed for FX-tracing flows (e.g. graph_trainer's aot_fx_trace) where ``param.grad`` side-effect writes from ``chunk_loss.backward()`` inside the chunk loop don't survive into the captured graph and replay therefore produces all-zero param grads. Mechanism, when the flag is True: - The chunk loop runs unchanged: per-chunk ``chunk_loss.backward()`` populates ``lm_head.weight.grad`` with the correctly sharded value (FSDP last-chunk reduce-scatter still handles the actual reduction). - After the loop, the sharded ``param.grad`` is captured via ``p.grad.detach()`` and ``p.grad`` is cleared. - The captured grads (plus the existing accumulated hidden_states grad) are plumbed through a new ``_ChunkedLossWithParamGrads`` autograd Function as saved tensors. Its backward returns them as grads for the corresponding lm_head parameter inputs, so outer ``torch.autograd.grad(loss, lm_head.parameters())`` resolves to real gradients instead of None / zeros. - Under FSDP, ``set_requires_gradient_sync(False)`` is set on lm_head and restored after the outer backward via a callback queued on the autograd engine. Without this, outer ``loss.backward()`` would re-fire the post-accumulate-grad hook on already-sharded grads and either error or produce wrong values. Both autograd Functions (the existing ``_DecoderOutputGradientBackProp`` and the new ``_ChunkedLossWithParamGrads``) return saved grads as-is without chain-ruling through ``grad_output``. The contract is that the loss returned by these Functions is the autograd endpoint — callers must pass any scaling factor (e.g. ``global_valid_tokens``) into ``loss_fn`` rather than dividing the returned loss externally. See graph_trainer's ``compute_loss`` for the canonical pattern. This matches the pre-existing behavior of ``_DecoderOutputGradientBackProp`` and avoids a structural FSDP+TP problem: ``grad_output`` is a ``Replicate()`` scalar DTensor on the loss's mesh (typically ``(tp,)``) while saved param grads live on the params' mesh (e.g. ``(fsdp, tp)``); DTensor refuses cross-mesh ``aten.mul.Tensor``, so any chain-rule multiply would crash at runtime. Tests: - tests/unit_tests/test_loss.py: parametrized ``support_autograd_grad in {False, True}`` equivalence check against the unchunked CE standard path; bitwise (rtol=0, atol=0) check that True and False paths produce identical grads; side-effect contract check that the True path doesn't populate ``param.grad``. - graph_trainer/tests/test_trace_module.py: end-to-end test that traces a small ``lm_head + ChunkedCELoss(support_autograd_grad= True)`` train_step via ``trace_train_step`` and verifies ``torch.equal`` between eager and replayed loss + h grad + lm_head grad on a CUDA model. The flag defaults to False so existing callers (the eager torchtitan trainer) are unaffected; consumers that want explicit param grads opt in. graph_trainer wires this in a separate commit.
Wrap each layer's inner attention forward via a new
`apply_cp_to_attention` helper in common_utils, called from
parallelize_{llama,deepseekv3,qwen3} when `cp_enabled`. Adds CP and
TP+CP integration tests for all three models.
<!-- ps-id: 089851fc-7329-4da3-8ed7-103f786148a0 -->
Co-authored-by: Aditya Venkataraman <avenkataraman@fb.com>
Hey there 👋 As explained in [pytorch#3158](pytorch#3158 (comment)), this PR refactors the Lychee link-checking logic to eliminate flaky CI failures while significantly improving execution speed. <br> ## 1. Multi-Tier Check Strategy The "Commit-time" experience is now separated from "Infrastructure monitoring" to ensure the development flow is not blocked by external outages. * **Pre-commit (Local & PR CI):** Configured to **fail only on 404**. Codes like `502` (e.g., "GitHub Unicorns") are accepted to prevent transient failures from stalling the workflow. * **Nightly CI:** Performs a strict check (accepting only `200`, `403`, `429`, `503`) with high persistence. It retries for up to **30 minutes**, outlasting most service outages. This run populates the cache with verified statuses that later can be reused in other workflow runs. <br> ## 2. The Cache Lifecycle: Why `key` and `restore-keys` are Necessary To understand the necessity of a dynamic `key` and the `restore-keys` fallback, one must first recognize that **GitHub Actions caches are immutable**. ### The Problem with Static Keys If a static key like `key: lychee-cache` is used without any dynamic parts, the workflow encounters a "Cache Hit" lock: 1. **The First Run:** GitHub creates the `.lycheecache` file and saves it as `lychee-cache`. 2. **Subsequent Runs:** GitHub finds an exact match for `lychee-cache` and downloads it. Because GitHub cannot update an existing cache, any new links or status updates discovered during the run are **discarded**. 3. **The "Stale Cache" Effect:** Over time, the cache becomes frozen in the state of the first run. Fixed links remain marked as "broken," and new links are re-checked every single time, slowing down CI. ### The Solution: Dynamic Keys & Branch Isolation GitHub restricts cache access based on branches: the **Default Branch** (`main`) is accessible by all, while **Feature Branches** can only access their own caches or the default branch's cache. By using a dynamic key (e.g., `cache-lychee-${{ github.sha }}`) combined with `restore-keys`, the process moves through two distinct phases: - **Restore** (start of job) - **Save** (end of job). #### Step-by-Step Lifecycle (3-Commit Example) | Phase | Commit 1: The "Inheritance" | Commit 2: The "Update" | Commit 3: The "Iteration" | | :--- | :--- | :--- | :--- | | **Restore** | Misses `cache-lychee-SHA1`. Falls back to `restore-keys` to pull the latest `main` cache that matches `cache-lychee-` pattern. | Misses `SHA2`. Pulls the most recent match from the branch scope (**SHA1**). | Misses `SHA3`. Pulls the latest available version from the branch (**SHA2**). | | **Action** | Lychee runs, checks only new links, and uses the inherited cache for the rest. | Lychee updates results with any new findings. | Lychee uses the `SHA2` baseline, ensuring zero redundant checks. | | **Save** | GitHub saves a **new** cache entry: `lychee-cache-SHA1`. | GitHub saves a new immutable entry: `lychee-cache-SHA2`. | GitHub saves the final state: `lychee-cache-SHA3`. | > [!NOTE] > This "chaining" effect ensures every commit builds upon the previous one, while keeping PR runs fast and the nightly "source of truth" accessible. <br> ## 3. Optimization: Parallelism & Output Previously, the configuration relied on `require_serial: true` and `verbose: true`. * **The Problem:** `verbose: true` was required to display the "Lychee not found" warning (since the script uses `exit 0`). However, because `pre-commit` spawns a process for every file, this caused the warning to print for every single file checked. `require_serial: true` was used to stop the log spam, but caused a **2x-3x slowdown**. * **The Fix:** An **Atomic Sentinel** is now used via `mkdir /tmp/lychee_lock`. Because `mkdir` is an atomic operation, only the first process successfully creates the directory and prints the warning. All other parallel processes fail to create the directory and remain silent. * **The Result:** `require_serial` and `verbose` are now **false**. The check runs in parallel (fast), and the warning prints exactly once (clean) by redirecting directly to the user's terminal screen via `> /dev/tty`. <br> ## 4. Fix Lychee version to install In **Lychee v0.24.1**, a change in the release archive structure broke dynamic installation scripts that fetched the "Latest" release. * **Change:** The installation now uses a **fixed, verified version** to prevent upstream changes from breaking the CI pipeline. <br> ## 5. GITHUB_TOKEN The `GITHUB_TOKEN` is explicitly passed to the Lychee action and pre-commit steps. This increases the rate limit for GitHub API requests, reducing `403 Forbidden` and `429 Too Many Requests` errors when checking internal repository links. <br> ## Summary Overview | Feature | Local / PR CI | Nightly CI | | :--- | :--- | :--- | | **Failure Condition** | Only `404` | Most non-200 codes | | **Duration/Retries** | Fast (5 retries / 15 secs) | Patient (30 retries / 30 mins) | | **Execution** | Parallel (via Sentinel) | Standard Action | | **Cache Goal** | Consume & Increment | Refresh & Validate |
Add Codex-facing AGENTS.md symlinks that point at the existing Claude instruction files. This lets Codex reuse the same repo and graph_trainer guidance without duplicating instruction content or creating a second source of truth. The root AGENTS.md points to the repo-level .claude/CLAUDE.md file. The graph_trainer AGENTS.md points to the graph_trainer-local .claude/CLAUDE.md file so directory-local instructions continue to apply when Codex is working in that subtree. Test Plan: - git ls-tree -l HEAD AGENTS.md torchtitan/experiments/graph_trainer/AGENTS.md - git diff --name-status origin/main..HEAD
…torch#3321) ## Summary - Rename `apply_sac_pass` to `tag_sac_policy` to clarify it only tags nodes (not the remat transform) - Skip `torch.ops.device_mesh._get_submesh` in SAC tagging (metadata-only op, no tensor output) - Propagate source node metadata (stacktrace, module_fqn, seq_nr) to CPU offload chain nodes for better tlparse/graph dump readability - Remove unused `cpu_offload_reload_node` metadata (remat pass discovers offload chain via graph structure) - Fix `defer_offload_waits` to defer from production site instead of latest storage-chain consumer — avoids pushing waits too far (e.g. layers.1's wait landing at layers.3) - Fix deferred waits landing before offload ops by including `ao.offload` in region anchors ## Test plan - [ ] `pytest torchtitan/experiments/graph_trainer/tests/test_passes.py -x` - [ ] `NGPU=8` run with `--compile.debug_graph_passes` and verify tlparse graph dumps show correct metadata and ordering - [ ] Verify bitwise deterministic test still passes **Following are desirable, but didn't' land.** Before the fix, we have both waits from layer0 and 1 at the end of layer 2 <img width="1727" height="447" alt="image" src="https://github.com/user-attachments/assets/423d0ed1-8886-4c51-ba6a-ba40892ee7d2" /> after the fix <img width="1725" height="298" alt="image" src="https://github.com/user-attachments/assets/7d5df701-031a-46bf-ab93-68f578471e66" /> Before the fix https://www.internalfb.com/intern/diffing/?before_paste_number=2321471471&after_paste_number=2321471510®ex_remove_pattern=&enable_regex_remove=0&strip_empty_lines=0&line_wrap=0&selected_tab=plain_diff After the fix (removing these few line) https://www.internalfb.com/intern/diffing/?before_paste_number=2320391950&after_paste_number=2320391970®ex_remove_pattern=&enable_regex_remove=0&strip_empty_lines=0&line_wrap=0&selected_tab=plain_diff
…ch#3308 Replays the pytorch#3308 ("Remove MoE expert for-loop fallback") deletion onto experiments/ezpz/moe/. That upstream PR removed the `use_grouped_mm` config field from `GroupedExperts.Config` and inlined `torch._grouped_mm` as the only expert path. Upstream's argument is that `_grouped_mm` already provides a CUDA fallback on pre-SM90 hardware. **XPU has no `_grouped_mm` kernel at all**, so the unconditional path breaks ezpz on Aurora / Sunspot. - `experiments/ezpz/moe/experts.py` (new): `EzpzGroupedExperts` subclass with `compute_backend: Literal["for_loop", "grouped_mm"]`. Default defers to upstream; `"for_loop"` re-vendors the per-expert matmul loop that pytorch#3308 deleted, restoring the XPU / pre-SM90 path. - `experiments/ezpz/moe/__init__.py`: `make_ezpz_experts_config(...)` wrapper that calls upstream's `make_experts_config(...)` and re-wraps as `EzpzGroupedExperts.Config`. `_build_moe_layers` threads a `compute_backend` kwarg (default `"grouped_mm"`). - `experiments/ezpz/moe/model.py`: `update_from_config` previously mutated `experts.use_grouped_mm = False` on pre-SM90 devices; that field no longer exists. Replaced with `experts_cfg.compute_backend = "for_loop"` guarded by a `getattr(..., "compute_backend", "grouped_mm")` so the block is robust to future config-shape changes. Default behavior is unchanged for SM90+ CUDA: `compute_backend` defaults to `"grouped_mm"`, which calls `super()._experts_forward(...)` and hits the upstream path. The for_loop branch is only taken on devices without `_grouped_mm`. Pending PRs #9 / #10 / #11 will need to rebase onto this and adapt to the `EzpzGroupedExperts` subclass — see docs/upstream-sync.md for notes on what each will need to do.
There was a problem hiding this comment.
Sorry @saforem2, your pull request is larger than the review limit of 150000 diff characters
There was a problem hiding this comment.
Pull request overview
This PR resyncs the ezpz/moe fork with 12 upstream commits (including upstream removal of the MoE use_grouped_mm flag/for-loop fallback), and reintroduces an ezpz-owned fallback expert compute path for devices without a usable torch._grouped_mm kernel (notably XPU). It also pulls in substantial upstream refactors across graph_trainer and the RL vLLM integration, plus link-checker workflow changes.
Changes:
- Remove
use_grouped_mmplumbing from core MoE and model configs, making grouped-mm the sole upstream expert path. - Add
EzpzGroupedExpertswith acompute_backendselector (grouped_mmvsfor_loop) and thread it through ezpz MoE configs + runtime config updates for pre-SM90/XPU fallback. - Upstream resync: graph_trainer pass refactors + CP support, RL vLLM registry/parser updates, ChunkedCE autograd.grad support, and Lychee link-checker caching/workflow updates.
Reviewed changes
Copilot reviewed 52 out of 53 changed files in this pull request and generated 2 comments.
Show a summary per file
| File | Description |
|---|---|
| torchtitan/models/llama4/model.py | Removes deprecated use_grouped_mm SM90 gating logic in config update. |
| torchtitan/models/gpt_oss/moe.py | Drops for-loop/grouped-mm toggle and inlines grouped-mm expert compute. |
| torchtitan/models/gpt_oss/model.py | Removes deprecated use_grouped_mm SM90 gating logic in config update. |
| torchtitan/models/deepseek_v3/model.py | Removes deprecated use_grouped_mm SM90 gating logic in config update. |
| torchtitan/models/common/moe.py | Removes use_grouped_mm and for-loop fallback; always uses torch._grouped_mm. |
| torchtitan/models/common/config_utils.py | Updates make_experts_config signature to remove use_grouped_mm. |
| torchtitan/experiments/rl/tests/test_bitwise_parity.py | Updates vLLM registration callsite to new registry API. |
| torchtitan/experiments/rl/models/vllm_wrapper.py | Renames wrapper and switches to torchtitan-parallelism–driven ParallelDims. |
| torchtitan/experiments/rl/models/vllm_registry.py | Reworks vLLM registration to include config parser + HF-shaped config dict generation. |
| torchtitan/experiments/rl/models/parallelize.py | Updates config imports to new consolidated module exports. |
| torchtitan/experiments/rl/grpo.py | Updates config imports to new consolidated module exports. |
| torchtitan/experiments/rl/generate.py | Updates vLLM registration callsite to new registry API and threads parallelism. |
| torchtitan/experiments/rl/config_registry.py | Updates config imports and enforces generator parallelism invariants in configs. |
| torchtitan/experiments/rl/actors/trainer.py | Updates config imports to new consolidated module exports. |
| torchtitan/experiments/rl/actors/generator.py | Adds generator-side validation (no SP / loss parallel) and uses new vLLM parser registration path. |
| torchtitan/experiments/rl/init.py | Exposes renamed wrapper + new registry entry point. |
| torchtitan/experiments/graph_trainer/trainer.py | Adds timing around trace step and adjusts imports for refactored registries. |
| torchtitan/experiments/graph_trainer/tests/test_trace_module.py | Adds tests for ChunkedCE autograd.grad tracing and CP tracing expectations. |
| torchtitan/experiments/graph_trainer/tests/test_precompile.py | Updates precompile test fixtures with new trace metadata fields. |
| torchtitan/experiments/graph_trainer/tests/test_passes.py | Updates pass imports and adds extensive remat/offload behavioral tests. |
| torchtitan/experiments/graph_trainer/tests/test_chunked_loss.py | Adds unit tests for ChunkedCELossWithParamGrads parity/grad side effects. |
| torchtitan/experiments/graph_trainer/tests/integration_tests.py | Updates integration test overrides for CP and memory policy rename. |
| torchtitan/experiments/graph_trainer/selective_activation_remat.py | Refactors remat to in-place duplication + targeted erase and offload-chain hoisting. |
| torchtitan/experiments/graph_trainer/registry.py | New centralized registries to avoid circular imports across pass modules. |
| torchtitan/experiments/graph_trainer/qwen3/parallelize.py | Adds CP application to attention modules when CP is enabled. |
| torchtitan/experiments/graph_trainer/precompile.py | Moves distributed metadata filtering import to refactored inductor passes module. |
| torchtitan/experiments/graph_trainer/passes.py | Refactors into orchestrator module; delegates passes to focused modules and registry. |
| torchtitan/experiments/graph_trainer/memory_policy.py | New memory policy tagging module (SAC + optional offload) and dispatcher. |
| torchtitan/experiments/graph_trainer/llama3/parallelize.py | Adds CP application to attention modules when CP is enabled. |
| torchtitan/experiments/graph_trainer/llama3/parallelize_autoparallel.py | Updates autoparallel input_fn to include positions tensor and constraints. |
| torchtitan/experiments/graph_trainer/inductor_passes.py | New module for regional/full Inductor compilation passes and FlexAttention annotation. |
| torchtitan/experiments/graph_trainer/graph_utils.py | Removes obsolete grouped-mm comment; updates memory policy error text. |
| torchtitan/experiments/graph_trainer/fsdp_passes.py | Preserves MUST_CPU_OFFLOAD tags and updates backward-node classification. |
| torchtitan/experiments/graph_trainer/deepseek_v3/parallelize.py | Switches CP checks to shared helper; applies CP when enabled. |
| torchtitan/experiments/graph_trainer/cudagraph.py | Moves kernel annotation pass + cudagraph wrapper into cudagraph module. |
| torchtitan/experiments/graph_trainer/cpu_offload.py | Propagates node metadata into offload chain and adjusts wait deferral behavior. |
| torchtitan/experiments/graph_trainer/configs.py | Renames memory policy option to sac_and_offload. |
| torchtitan/experiments/graph_trainer/common_utils.py | Adds log_timer and CP-to-attention helper; removes recomputed-node workaround. |
| torchtitan/experiments/graph_trainer/chunked_loss.py | New ChunkedCE variant to plumb lm_head param grads via explicit autograd outputs. |
| torchtitan/experiments/graph_trainer/autoparallel_api.py | Removes redundant forward-arg validation call. |
| torchtitan/experiments/graph_trainer/.claude/CLAUDE.md | Updates documentation to match in-place remat behavior. |
| torchtitan/experiments/ezpz/moe/model.py | Switches ezpz MoE fallback from removed use_grouped_mm to compute_backend="for_loop". |
| torchtitan/experiments/ezpz/moe/experts.py | New EzpzGroupedExperts with compute_backend selector and for-loop fallback. |
| torchtitan/experiments/ezpz/moe/init.py | Adds make_ezpz_experts_config wrapper and threads compute_backend into layer build. |
| torchtitan/experiments/ezpz/docs/upstream-sync.md | Documents the upstream sync and the ezpz MoE replay strategy. |
| torchtitan/components/loss.py | Refactors chunked CE loss to allow subclasses to override autograd bridging. |
| tests/unit_tests/test_compile_moe.py | Updates MoE compile test to use GroupedExperts._experts_forward directly. |
| .pre-commit-config.yaml | Makes lychee less noisy, enables caching, and adds a “single warning” sentinel. |
| .gitignore | Ignores .lycheecache produced by lychee --cache. |
| .github/workflows/lint.yaml | Pins lychee version, adds cache restore, and runs lychee via pre-commit across all files. |
| .github/workflows/link_check.yaml | Removes standalone scheduled link-check workflow (now covered in lint). |
💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.
Three lint hits surfaced by CI on the new files in #13: - **ufmt**: applied the formatter's reformatting to the three new `experiments/ezpz/moe/` files (import grouping, comprehension collapse, line-wrap nits in `model.py`). - **flake8 N801**: pre-existing class names `moeTransformerBlock` / `moeModel` violate CapWords, but renaming would touch every callsite outside of this PR's scope. Suppressed in-place with `# noqa: N801` on the two class declarations. - **codespell**: a quoted upstream PR title in `experiments/ezpz/docs/upstream-sync.md` contains "Reenable" which codespell flags as "Re-enable". Added a `<!-- codespell:ignore-line -->` marker so the historical quote stays verbatim. Strictly inside `experiments/ezpz/` per Golden Rule #1 — no edits to core `pyproject.toml`.
- `docs/summaries/2026-04-12_to_2026-04-27.md:230` linked to `journal.md` (resolves to `summaries/journal.md`, doesn't exist). Fix: `../journal.md` (the file lives at `docs/journal.md`). - `README.md:95` linked to `torchtitan/experiments/ezpz/run_train.sh` which doubled the path (resolved to `experiments/ezpz/torchtitan/experiments/ezpz/run_train.sh`). Fix: `run_train.sh` (sibling of the README).
Two pre-commit failures landed by the previous link-fix commit (which brought README.md into the lint set): - **trailing-whitespace** on line 4: stray `` `` (markdown line-break) after `i.e.:`. Dropped — the line breaks naturally on the `[saforem2/torchtitan@ezpz](...)` link that follows. - **codespell** on line 18: `preemptable` is the actual ALCF Polaris queue name, but codespell wants `preemptible`. Stash it as `preempt"able"` (shell concatenation, resolves to the same literal at submission time) so codespell sees two separate tokens. Header parenthetical clarifies what the queue is for the reader.
50-step `moe_500m` smoke on Sunspot 8N (job 12466707) validating PR #13's `EzpzGroupedExperts.compute_backend = "for_loop"` fallback. - 11 layer-wise warnings confirm the fallback fires (1 per MoE layer). - Loss descends cleanly 12.90 → 6.66 (-6.24 nats) over 50 steps. - Steady-state ~8,694 TPS / GPU, ~13% MFU — ~20% per-GPU uplift vs the 2026-04-13 2N baseline (torch.compile inductor improvements; not regression in the for_loop kernel). - Memory stable at 54.6% across the run; wall 541s end-to-end. W&B: fluent-glitter-2042 (https://wandb.ai/aurora_gpt/torchtitan.ezpz.train/runs/lt77xx0o)
Validates the 35th upstream sync (PR pytorch#3159 Module.parallelize signature change replay) on Sunspot 2N (job 12467124) across 5 configs: agpt_debugmodel LBS=2 loss 10.81 → 6.77 ~37,500 TPS / ~2.9% MFU agpt_2b LBS=1 loss 12.93 → 6.01 ~6,100 TPS / ~22.8% MFU agpt_2b LBS=2 loss 12.91 → 6.12 ~7,200 TPS / ~27.0% MFU moe_debugmodel LBS=2 loss 12.96 → 7.01 ~13,000 TPS / ~9.0% MFU moe_2b LBS=1 loss 12.91 → 6.16 ~2,900 TPS / ~8.4% MFU agpt_2b LBS=2 matches the 2026-04-25 n=2 baseline (7,142 TPS / 27.6% MFU) within noise; resync introduces no throughput regression. The for_loop MoE expert backend (PR #13) still fires correctly on XPU under the new parallelize signature. Default moe_2b() LBS=16 OOMs on Max 1550 (single 62.53 GiB allocation, likely the fused activation for all experts × full-batch-tokens); LBS=1 documented as the safe XPU baseline. Adds two new reports under docs/experiments/{agpt,moe}/sunspot/ and indexes them in the parent READMEs.
Summary
Pulls in 12 commits from
upstream/mainand replays the only ezpz-affecting change ontoexperiments/ezpz/moe/. Two-commit history; no expert-side perf work or HSDP fixes here — those stay with the open PRs (#9, #10, #11) which will need to rebase onto this.Commits
62b395b5— Mergeupstream/mainintoezpz-moe-resync. Clean merge (0 conflicts pre-checked). Notable upstream changes:b301dfa0— [MoE] Remove expert for-loop fallback (Remove MoE expert for-loop fallback pytorch/torchtitan#3308) — deletes_run_experts_for_loopand theuse_grouped_mmconfig field frommodels/common/moe.py. Inlinestorch._grouped_mmas the only expert path. This breaksexperiments/ezpz/moe/model.pywhich useduse_grouped_mm = Falseas the XPU fallback (XPU has no_grouped_mmkernel at all).5ca23a5d— [GraphTrainer] Add Context Parallel supportd57df092— Make ChunkedCELoss supporttorch.autograd.grad1a0fe3e3— [graph_trainer] Refactor passes.py into focused modules12aba852—feat(ezpz/moe): introduce EzpzGroupedExperts to handle upstream #3308New
experiments/ezpz/moe/experts.py:EzpzGroupedExperts(GroupedExperts)subclass withcompute_backend: Literal[\"for_loop\", \"grouped_mm\"]. Default\"grouped_mm\"defers to upstream viasuper()._experts_forward(...).\"for_loop\"re-vendors the per-expert matmul loop that Remove MoE expert for-loop fallback pytorch/torchtitan#3308 deleted, restoring the XPU / pre-SM90 path.experiments/ezpz/moe/__init__.py: newmake_ezpz_experts_config(...)wrapper that calls upstream'smake_experts_config(...)and re-wraps asEzpzGroupedExperts.Config._build_moe_layersnow threads acompute_backendkwarg (default\"grouped_mm\") through to it. All existing flavors (debugmodel,500M,2B,10B_2B,10B_2B_sdpa, …) build unchanged.experiments/ezpz/moe/model.py:update_from_configpreviously mutatedexperts.use_grouped_mm = Falseon pre-SM90 devices. That field no longer exists. Replaced withexperts_cfg.compute_backend = \"for_loop\", guarded by agetattr(..., \"compute_backend\", \"grouped_mm\")so the block is robust to future config-shape changes.experiments/ezpz/docs/upstream-sync.md: added the 33rd-sync entry per the upstream-sync protocol, documenting Remove MoE expert for-loop fallback pytorch/torchtitan#3308 and the replay.Default behavior is unchanged
For SM90+ CUDA,
compute_backenddefaults to\"grouped_mm\", which callssuper()._experts_forward(...)and hits the upstream path. The newfor_loopbranch only fires on devices without_grouped_mm(currently XPU and pre-SM90 CUDA).End-to-end config build verified locally for
debugmodel,500M,10B_2B_sdpa— all instantiateEzpzGroupedExpertswithcompute_backend=\"grouped_mm\".Notes for downstream PRs
PRs #9, #10, #11 all assume the pre-pytorch#3308
GroupedExpertsshape:moe.py, but lives on top of pre-mergeparallelize.py. Should rebase cleanly.batched_mm_paddedbackend): extendsGroupedExperts.Configwith acompute_backendfield that's basically the same shape as the one this PR adds via subclass. Cleanest landing path is to add\"batched_mm_padded\"as a third literal inEzpzGroupedExperts.ExpertComputeBackendand a third branch inEzpzGroupedExperts._experts_forward.w13/w2_t+ equal-counts no-grad bmm fast path) layer onto thefor_loopmethod here. Token-dispatcher rewrite is unaffected by this PR.Will leave pointer comments on #9 / #10 / #11.
Out of scope
This PR deliberately does NOT bundle the changes from open PRs #9, #10, #11. Those are owned by Sam and Nathan respectively and should land independently after rebasing onto this. Supersedes #12 (the prior 4-commit integration attempt) which will be closed.
Validation
upstream/main(0 conflicts, pre-checked withgit merge-tree)EzpzGroupedExperts.Configwrapperuse_grouped_mminexperiments/ezpz/(verified via grep)for_looppath) — to be filed as a follow-up smoke-test report underexperiments/ezpz/docs/experiments/moe/