Skip to content

ezpz/moe: resync with upstream + adapt to PR #3308 (use_grouped_mm removal)#13

Merged
saforem2 merged 18 commits into
ezpzfrom
ezpz-moe-resync
May 13, 2026
Merged

ezpz/moe: resync with upstream + adapt to PR #3308 (use_grouped_mm removal)#13
saforem2 merged 18 commits into
ezpzfrom
ezpz-moe-resync

Conversation

@saforem2

Copy link
Copy Markdown
Owner

Summary

Pulls in 12 commits from upstream/main and replays the only ezpz-affecting change onto experiments/ezpz/moe/. Two-commit history; no expert-side perf work or HSDP fixes here — those stay with the open PRs (#9, #10, #11) which will need to rebase onto this.

Commits

  1. 62b395b5 — Merge upstream/main into ezpz-moe-resync. Clean merge (0 conflicts pre-checked). Notable upstream changes:

    • b301dfa0[MoE] Remove expert for-loop fallback (Remove MoE expert for-loop fallback pytorch/torchtitan#3308) — deletes _run_experts_for_loop and the use_grouped_mm config field from models/common/moe.py. Inlines torch._grouped_mm as the only expert path. This breaks experiments/ezpz/moe/model.py which used use_grouped_mm = False as the XPU fallback (XPU has no _grouped_mm kernel at all).
    • 5ca23a5d — [GraphTrainer] Add Context Parallel support
    • d57df092 — Make ChunkedCELoss support torch.autograd.grad
    • 1a0fe3e3 — [graph_trainer] Refactor passes.py into focused modules
    • …plus 8 more (graph_trainer / RL / lint / docs)
  2. 12aba852feat(ezpz/moe): introduce EzpzGroupedExperts to handle upstream #3308

    • New experiments/ezpz/moe/experts.py: EzpzGroupedExperts(GroupedExperts) subclass with compute_backend: Literal[\"for_loop\", \"grouped_mm\"]. Default \"grouped_mm\" defers to upstream via super()._experts_forward(...). \"for_loop\" re-vendors the per-expert matmul loop that Remove MoE expert for-loop fallback pytorch/torchtitan#3308 deleted, restoring the XPU / pre-SM90 path.

    • experiments/ezpz/moe/__init__.py: new make_ezpz_experts_config(...) wrapper that calls upstream's make_experts_config(...) and re-wraps as EzpzGroupedExperts.Config. _build_moe_layers now threads a compute_backend kwarg (default \"grouped_mm\") through to it. All existing flavors (debugmodel, 500M, 2B, 10B_2B, 10B_2B_sdpa, …) build unchanged.

    • experiments/ezpz/moe/model.py: update_from_config previously mutated experts.use_grouped_mm = False on pre-SM90 devices. That field no longer exists. Replaced with experts_cfg.compute_backend = \"for_loop\", guarded by a getattr(..., \"compute_backend\", \"grouped_mm\") so the block is robust to future config-shape changes.

    • experiments/ezpz/docs/upstream-sync.md: added the 33rd-sync entry per the upstream-sync protocol, documenting Remove MoE expert for-loop fallback pytorch/torchtitan#3308 and the replay.

Default behavior is unchanged

For SM90+ CUDA, compute_backend defaults to \"grouped_mm\", which calls super()._experts_forward(...) and hits the upstream path. The new for_loop branch only fires on devices without _grouped_mm (currently XPU and pre-SM90 CUDA).

End-to-end config build verified locally for debugmodel, 500M, 10B_2B_sdpa — all instantiate EzpzGroupedExperts with compute_backend=\"grouped_mm\".

Notes for downstream PRs

PRs #9, #10, #11 all assume the pre-pytorch#3308 GroupedExperts shape:

  • Fix MoE expert FSDP mesh info for HSDP #9 (Sam Wheeler — HSDP fix): doesn't touch moe.py, but lives on top of pre-merge parallelize.py. Should rebase cleanly.
  • Add padded batched matmul MoE expert backend #10 (Sam Wheeler — batched_mm_padded backend): extends GroupedExperts.Config with a compute_backend field that's basically the same shape as the one this PR adds via subclass. Cleanest landing path is to add \"batched_mm_padded\" as a third literal in EzpzGroupedExperts.ExpertComputeBackend and a third branch in EzpzGroupedExperts._experts_forward.
  • Optimize 10b 2b sdpa moe #11 (Nathan Nichols — MoE optimizations): expert-side optimizations (cached w13 / w2_t + equal-counts no-grad bmm fast path) layer onto the for_loop method here. Token-dispatcher rewrite is unaffected by this PR.

Will leave pointer comments on #9 / #10 / #11.

Out of scope

This PR deliberately does NOT bundle the changes from open PRs #9, #10, #11. Those are owned by Sam and Nathan respectively and should land independently after rebasing onto this. Supersedes #12 (the prior 4-commit integration attempt) which will be closed.

Validation

  • Clean merge of upstream/main (0 conflicts, pre-checked with git merge-tree)
  • All existing MoE flavors build with the new EzpzGroupedExperts.Config wrapper
  • No remaining references to use_grouped_mm in experiments/ezpz/ (verified via grep)
  • MoE smoke run on Aurora / Sunspot (XPU for_loop path) — to be filed as a follow-up smoke-test report under experiments/ezpz/docs/experiments/moe/

wwwjn and others added 14 commits May 11, 2026 11:49
…endency (pytorch#3242)

Stack from [ghstack](https://github.com/ezyang/ghstack) (oldest at
bottom):
* pytorch#3236
* pytorch#3142
* __->__ pytorch#3242

vllm has this customized config parser registry support so we can plug
in TorchTitan's config parser. Why we need this:
- get rid of dependency on a HF format checkpoint folder when
initializing. Don't implicitly depend on `config.json` as config source
of truth

Another changes in this PR:
- remove the round-trip translation from torchtitan config -> vllm
config -> torchtitan config. Using closure to bypass.
…pytorch#3315)

AutoParallel's input_fn() only returned tokens, but Decoder.forward()
also receives positions via extra_kwargs. This caused a mismatch between
the number of graph placeholders (from tracing with tokens only) and the
actual runtime args (which include positions), failing with "expected N
arguments for placeholders but received N+1".

Add positions to input_fn() and a matching input constraint so
AutoParallel traces with both inputs.

Authored by Claude.
…h#3311)

## Summary
- Add a simple `log_timer` context manager to `common_utils.py` that
measures wall-clock elapsed time and logs it to console (e.g.
`trace_train_step took 0.043s`).
- Apply `log_timer` to the `trace_train_step` call in
`GraphTrainer._make_fx_forward_backward_step` to measure tracing time.

## Test plan
- [x] Verify `log_timer` output appears in training logs during
aot_fx_trace runs
- [ ] Existing unit tests pass: `pytest
torchtitan/experiments/graph_trainer/tests/ -x`
…rch#3270)

Summary:

The remat pass previously rebuilt the graph wholesale (fx.Graph() +
node_copy of every node) and relied on whole-graph DCE to remove dead
must_recompute originals. Refactor to mutate gm.graph in place: dups are
inserted in front of their first backward consumer, backward args are
redirected to the dups, and only originals whose users became empty are
erased. Original node identities and names are preserved, the
topological-order assumption is explicit (input graph order drives
insertion, validated by gm.graph.lint() at the end), and the underlying
function takes the standard (gm, example_inputs) graph pass signature.

CPU-offload reload chains are handled by hoisting the chain in front of
the earliest dup that needs it - the in-place equivalent of upstream's
"eagerly copy reload chain into the new graph" trick.


Authored with Claude.

Test Plan:
Unit tests:
  pytest torchtitan/experiments/graph_trainer/tests/test_passes.py -x
  -> 68 passed, 1 skipped (3 new in TestSelectiveActivationRematPass)

End-to-end on 8xH100 / Llama3 8B / FSDP=4 + TP=2 / aot_fx_trace / no
cudagraph, with --debug.seed=42 --debug.deterministic:

<img width="1728" height="625" alt="Screenshot 2026-05-07 at 5 16 08 PM"
src="https://github.com/user-attachments/assets/b28e3acf-626e-4014-b5e5-4ffa6a686f08"
/>

CPU offload using upstream remat pass 

<img width="1728" height="612" alt="Screenshot 2026-05-07 at 5 16 23 PM"
src="https://github.com/user-attachments/assets/6c391927-7022-496d-bb9b-0e38e4808df8"
/>
CPU offload using our refactor 


Before:
https://manifold.edge.x2p.facebook.net/v0/read/tree/logs/.tmpfmqvbv/index.html?bucketName=tlparse_reports&apiKey=tlparse_reports-key&withPayload=1&timeoutMsec=10000
After:
https://manifold.edge.x2p.facebook.net/v0/read/tree/logs/.tmp9ImSg9/index.html?bucketName=tlparse_reports&apiKey=tlparse_reports-key&withPayload=1&timeoutMsec=10000
torch._grouped_mm already provides a CUDA fallback path when the fused grouped GEMM kernel is unavailable, including on pre-SM90 hardware. Keeping a separate Python for-loop expert implementation duplicates that fallback, carries an extra configuration branch, and makes MoE behavior diverge across models.

Use the grouped-mm path unconditionally and rely on PyTorch to choose either the fused kernel or its built-in loopy fallback.
Code Move Only! 

## Summary

- Split the monolithic `passes.py` (~1037 lines) into focused modules,
keeping `passes.py` as the orchestration layer (~400 lines):
- **`memory_policy.py`** — SAC tagging, default/eager/offload policies,
and `tag_with_memory_policy_pass`
- **`inductor_passes.py`** — `regional_inductor_pass`,
`full_inductor_compilation_pass`, and
`annotate_flex_attention_for_regional_inductor_pass`
- **`cudagraph.py`** — `cudagraph_pass` and
`insert_kernel_annotations_pass` (appended to existing file)
- **`registry.py`** — `MEMORY_POLICY_REGISTRY`,
`PASS_PIPELINE_REGISTRY`, `POST_INIT_HOOKS`, `PRE_TRAIN_STEP_HOOKS` and
their decorators (breaks circular dep between `passes.py` and
`memory_policy.py`)
- Code-move only — all function bodies are identical; the only diffs are
removing local imports that became unnecessary when code moved to the
same file
- Updated `test_passes.py` imports to use new module paths

## Test plan

- [ ] `ruff check --select F401` passes clean (no unused imports)
- [ ] `pytest torchtitan/experiments/graph_trainer/tests/test_passes.py
-x`
- [ ] `pytest
torchtitan/experiments/graph_trainer/tests/test_precompile.py -x`
- [ ] `pre-commit run --all-files`
Stack from [ghstack](https://github.com/ezyang/ghstack) (oldest at
bottom):
* pytorch#3250
* pytorch#3248
* pytorch#3247
* pytorch#3246
* __->__ pytorch#3249

Add a ``support_autograd_grad`` opt-in flag to ``ChunkedCELoss.Config``
that exposes the lm_head parameter gradients as explicit autograd
outputs of the returned loss tensor. Designed for FX-tracing flows
(e.g. graph_trainer's aot_fx_trace) where ``param.grad`` side-effect
writes from ``chunk_loss.backward()`` inside the chunk loop don't
survive into the captured graph and replay therefore produces
all-zero param grads.

Mechanism, when the flag is True:

  - The chunk loop runs unchanged: per-chunk ``chunk_loss.backward()``
    populates ``lm_head.weight.grad`` with the correctly sharded value
    (FSDP last-chunk reduce-scatter still handles the actual reduction).
  - After the loop, the sharded ``param.grad`` is captured via
    ``p.grad.detach()`` and ``p.grad`` is cleared.
  - The captured grads (plus the existing accumulated hidden_states
    grad) are plumbed through a new ``_ChunkedLossWithParamGrads``
    autograd Function as saved tensors. Its backward returns them as
    grads for the corresponding lm_head parameter inputs, so outer
    ``torch.autograd.grad(loss, lm_head.parameters())`` resolves to
    real gradients instead of None / zeros.
  - Under FSDP, ``set_requires_gradient_sync(False)`` is set on lm_head
    and restored after the outer backward via a callback queued on the
    autograd engine. Without this, outer ``loss.backward()`` would
    re-fire the post-accumulate-grad hook on already-sharded grads and
    either error or produce wrong values.

Both autograd Functions (the existing ``_DecoderOutputGradientBackProp``
and the new ``_ChunkedLossWithParamGrads``) return saved grads as-is
without chain-ruling through ``grad_output``. The contract is that the
loss returned by these Functions is the autograd endpoint — callers
must pass any scaling factor (e.g. ``global_valid_tokens``) into
``loss_fn`` rather than dividing the returned loss externally. See
graph_trainer's ``compute_loss`` for the canonical pattern. This
matches the pre-existing behavior of ``_DecoderOutputGradientBackProp``
and avoids a structural FSDP+TP problem: ``grad_output`` is a
``Replicate()`` scalar DTensor on the loss's mesh (typically ``(tp,)``)
while saved param grads live on the params' mesh (e.g. ``(fsdp, tp)``);
DTensor refuses cross-mesh ``aten.mul.Tensor``, so any chain-rule
multiply would crash at runtime.

Tests:

  - tests/unit_tests/test_loss.py: parametrized ``support_autograd_grad
    in {False, True}`` equivalence check against the unchunked CE
    standard path; bitwise (rtol=0, atol=0) check that True and False
    paths produce identical grads; side-effect contract check that the
    True path doesn't populate ``param.grad``.
  - graph_trainer/tests/test_trace_module.py: end-to-end test that
    traces a small ``lm_head + ChunkedCELoss(support_autograd_grad=
    True)`` train_step via ``trace_train_step`` and verifies
    ``torch.equal`` between eager and replayed loss + h grad + lm_head
    grad on a CUDA model.

The flag defaults to False so existing callers (the eager torchtitan
trainer) are unaffected; consumers that want explicit param grads
opt in. graph_trainer wires this in a separate commit.
Wrap each layer's inner attention forward via a new
`apply_cp_to_attention` helper in common_utils, called from
parallelize_{llama,deepseekv3,qwen3} when `cp_enabled`. Adds CP and
TP+CP integration tests for all three models.

<!-- ps-id: 089851fc-7329-4da3-8ed7-103f786148a0 -->

Co-authored-by: Aditya Venkataraman <avenkataraman@fb.com>
Hey there 👋 

As explained in
[pytorch#3158](pytorch#3158 (comment)),
this PR refactors the Lychee link-checking logic to eliminate flaky CI
failures while significantly improving execution speed.

<br>

## 1. Multi-Tier Check Strategy
The "Commit-time" experience is now separated from "Infrastructure
monitoring" to ensure the development flow is not blocked by external
outages.

* **Pre-commit (Local & PR CI):** Configured to **fail only on 404**.
Codes like `502` (e.g., "GitHub Unicorns") are accepted to prevent
transient failures from stalling the workflow.
* **Nightly CI:** Performs a strict check (accepting only `200`, `403`,
`429`, `503`) with high persistence. It retries for up to **30
minutes**, outlasting most service outages. This run populates the cache
with verified statuses that later can be reused in other workflow runs.

<br>

## 2. The Cache Lifecycle: Why `key` and `restore-keys` are Necessary
To understand the necessity of a dynamic `key` and the `restore-keys`
fallback, one must first recognize that **GitHub Actions caches are
immutable**.

### The Problem with Static Keys
If a static key like `key: lychee-cache` is used without any dynamic
parts, the workflow encounters a "Cache Hit" lock:
1. **The First Run:** GitHub creates the `.lycheecache` file and saves
it as `lychee-cache`.
2. **Subsequent Runs:** GitHub finds an exact match for `lychee-cache`
and downloads it. Because GitHub cannot update an existing cache, any
new links or status updates discovered during the run are **discarded**.
3. **The "Stale Cache" Effect:** Over time, the cache becomes frozen in
the state of the first run. Fixed links remain marked as "broken," and
new links are re-checked every single time, slowing down CI.

### The Solution: Dynamic Keys & Branch Isolation
GitHub restricts cache access based on branches: the **Default Branch**
(`main`) is accessible by all, while **Feature Branches** can only
access their own caches or the default branch's cache.

By using a dynamic key (e.g., `cache-lychee-${{ github.sha }}`) combined
with `restore-keys`, the process moves through two distinct phases:
- **Restore** (start of job)
- **Save** (end of job).

#### Step-by-Step Lifecycle (3-Commit Example)
| Phase | Commit 1: The "Inheritance" | Commit 2: The "Update" | Commit
3: The "Iteration" |
| :--- | :--- | :--- | :--- |
| **Restore** | Misses `cache-lychee-SHA1`. Falls back to `restore-keys`
to pull the latest `main` cache that matches `cache-lychee-` pattern. |
Misses `SHA2`. Pulls the most recent match from the branch scope
(**SHA1**). | Misses `SHA3`. Pulls the latest available version from the
branch (**SHA2**). |
| **Action** | Lychee runs, checks only new links, and uses the
inherited cache for the rest. | Lychee updates results with any new
findings. | Lychee uses the `SHA2` baseline, ensuring zero redundant
checks. |
| **Save** | GitHub saves a **new** cache entry: `lychee-cache-SHA1`. |
GitHub saves a new immutable entry: `lychee-cache-SHA2`. | GitHub saves
the final state: `lychee-cache-SHA3`. |

> [!NOTE]
> This "chaining" effect ensures every commit builds upon the previous
one, while keeping PR runs fast and the nightly "source of truth"
accessible.

<br>

## 3. Optimization: Parallelism & Output
Previously, the configuration relied on `require_serial: true` and
`verbose: true`.

* **The Problem:** `verbose: true` was required to display the "Lychee
not found" warning (since the script uses `exit 0`). However, because
`pre-commit` spawns a process for every file, this caused the warning to
print for every single file checked. `require_serial: true` was used to
stop the log spam, but caused a **2x-3x slowdown**.
* **The Fix:** An **Atomic Sentinel** is now used via `mkdir
/tmp/lychee_lock`. Because `mkdir` is an atomic operation, only the
first process successfully creates the directory and prints the warning.
All other parallel processes fail to create the directory and remain
silent.
* **The Result:** `require_serial` and `verbose` are now **false**. The
check runs in parallel (fast), and the warning prints exactly once
(clean) by redirecting directly to the user's terminal screen via `>
/dev/tty`.

<br>

## 4. Fix Lychee version to install
In **Lychee v0.24.1**, a change in the release archive structure broke
dynamic installation scripts that fetched the "Latest" release.
* **Change:** The installation now uses a **fixed, verified version** to
prevent upstream changes from breaking the CI pipeline.

<br>

## 5. GITHUB_TOKEN
The `GITHUB_TOKEN` is explicitly passed to the Lychee action and
pre-commit steps. This increases the rate limit for GitHub API requests,
reducing `403 Forbidden` and `429 Too Many Requests` errors when
checking internal repository links.

<br>

## Summary Overview
| Feature | Local / PR CI | Nightly CI |
| :--- | :--- | :--- |
| **Failure Condition** | Only `404` | Most non-200 codes |
| **Duration/Retries** | Fast (5 retries / 15 secs) | Patient (30
retries / 30 mins) |
| **Execution** | Parallel (via Sentinel) | Standard Action |
| **Cache Goal** | Consume & Increment | Refresh & Validate |
Add Codex-facing AGENTS.md symlinks that point at the existing Claude instruction files. This lets Codex reuse the same repo and graph_trainer guidance without duplicating instruction content or creating a second source of truth.

The root AGENTS.md points to the repo-level .claude/CLAUDE.md file. The graph_trainer AGENTS.md points to the graph_trainer-local .claude/CLAUDE.md file so directory-local instructions continue to apply when Codex is working in that subtree.

Test Plan:
- git ls-tree -l HEAD AGENTS.md torchtitan/experiments/graph_trainer/AGENTS.md
- git diff --name-status origin/main..HEAD
…torch#3321)

## Summary
- Rename `apply_sac_pass` to `tag_sac_policy` to clarify it only tags
nodes (not the remat transform)
- Skip `torch.ops.device_mesh._get_submesh` in SAC tagging
(metadata-only op, no tensor output)
- Propagate source node metadata (stacktrace, module_fqn, seq_nr) to CPU
offload chain nodes for better tlparse/graph dump readability
- Remove unused `cpu_offload_reload_node` metadata (remat pass discovers
offload chain via graph structure)
- Fix `defer_offload_waits` to defer from production site instead of
latest storage-chain consumer — avoids pushing waits too far (e.g.
layers.1's wait landing at layers.3)
- Fix deferred waits landing before offload ops by including
`ao.offload` in region anchors

## Test plan
- [ ] `pytest torchtitan/experiments/graph_trainer/tests/test_passes.py
-x`
- [ ] `NGPU=8` run with `--compile.debug_graph_passes` and verify
tlparse graph dumps show correct metadata and ordering
- [ ] Verify bitwise deterministic test still passes


**Following are desirable, but didn't' land.** 

Before the fix, we have both waits from layer0 and 1 at the end of layer
2
<img width="1727" height="447" alt="image"
src="https://github.com/user-attachments/assets/423d0ed1-8886-4c51-ba6a-ba40892ee7d2"
/>


after the fix
<img width="1725" height="298" alt="image"
src="https://github.com/user-attachments/assets/7d5df701-031a-46bf-ab93-68f578471e66"
/>


Before the fix

https://www.internalfb.com/intern/diffing/?before_paste_number=2321471471&after_paste_number=2321471510&regex_remove_pattern=&enable_regex_remove=0&strip_empty_lines=0&line_wrap=0&selected_tab=plain_diff

After the fix (removing these few line)

https://www.internalfb.com/intern/diffing/?before_paste_number=2320391950&after_paste_number=2320391970&regex_remove_pattern=&enable_regex_remove=0&strip_empty_lines=0&line_wrap=0&selected_tab=plain_diff
…ch#3308

Replays the pytorch#3308 ("Remove MoE expert for-loop
fallback") deletion onto experiments/ezpz/moe/. That upstream PR
removed the `use_grouped_mm` config field from `GroupedExperts.Config`
and inlined `torch._grouped_mm` as the only expert path. Upstream's
argument is that `_grouped_mm` already provides a CUDA fallback on
pre-SM90 hardware. **XPU has no `_grouped_mm` kernel at all**, so the
unconditional path breaks ezpz on Aurora / Sunspot.

- `experiments/ezpz/moe/experts.py` (new): `EzpzGroupedExperts` subclass
  with `compute_backend: Literal["for_loop", "grouped_mm"]`. Default
  defers to upstream; `"for_loop"` re-vendors the per-expert matmul
  loop that pytorch#3308 deleted, restoring the XPU / pre-SM90 path.
- `experiments/ezpz/moe/__init__.py`: `make_ezpz_experts_config(...)`
  wrapper that calls upstream's `make_experts_config(...)` and
  re-wraps as `EzpzGroupedExperts.Config`. `_build_moe_layers` threads
  a `compute_backend` kwarg (default `"grouped_mm"`).
- `experiments/ezpz/moe/model.py`: `update_from_config` previously
  mutated `experts.use_grouped_mm = False` on pre-SM90 devices; that
  field no longer exists. Replaced with
  `experts_cfg.compute_backend = "for_loop"` guarded by a
  `getattr(..., "compute_backend", "grouped_mm")` so the block is
  robust to future config-shape changes.

Default behavior is unchanged for SM90+ CUDA: `compute_backend` defaults
to `"grouped_mm"`, which calls `super()._experts_forward(...)` and
hits the upstream path. The for_loop branch is only taken on devices
without `_grouped_mm`.

Pending PRs #9 / #10 / #11 will need to rebase onto this and adapt to
the `EzpzGroupedExperts` subclass — see docs/upstream-sync.md for
notes on what each will need to do.
Copilot AI review requested due to automatic review settings May 12, 2026 21:01

@sourcery-ai sourcery-ai Bot left a comment

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sorry @saforem2, your pull request is larger than the review limit of 150000 diff characters

Copilot AI left a comment

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pull request overview

This PR resyncs the ezpz/moe fork with 12 upstream commits (including upstream removal of the MoE use_grouped_mm flag/for-loop fallback), and reintroduces an ezpz-owned fallback expert compute path for devices without a usable torch._grouped_mm kernel (notably XPU). It also pulls in substantial upstream refactors across graph_trainer and the RL vLLM integration, plus link-checker workflow changes.

Changes:

  • Remove use_grouped_mm plumbing from core MoE and model configs, making grouped-mm the sole upstream expert path.
  • Add EzpzGroupedExperts with a compute_backend selector (grouped_mm vs for_loop) and thread it through ezpz MoE configs + runtime config updates for pre-SM90/XPU fallback.
  • Upstream resync: graph_trainer pass refactors + CP support, RL vLLM registry/parser updates, ChunkedCE autograd.grad support, and Lychee link-checker caching/workflow updates.

Reviewed changes

Copilot reviewed 52 out of 53 changed files in this pull request and generated 2 comments.

Show a summary per file
File Description
torchtitan/models/llama4/model.py Removes deprecated use_grouped_mm SM90 gating logic in config update.
torchtitan/models/gpt_oss/moe.py Drops for-loop/grouped-mm toggle and inlines grouped-mm expert compute.
torchtitan/models/gpt_oss/model.py Removes deprecated use_grouped_mm SM90 gating logic in config update.
torchtitan/models/deepseek_v3/model.py Removes deprecated use_grouped_mm SM90 gating logic in config update.
torchtitan/models/common/moe.py Removes use_grouped_mm and for-loop fallback; always uses torch._grouped_mm.
torchtitan/models/common/config_utils.py Updates make_experts_config signature to remove use_grouped_mm.
torchtitan/experiments/rl/tests/test_bitwise_parity.py Updates vLLM registration callsite to new registry API.
torchtitan/experiments/rl/models/vllm_wrapper.py Renames wrapper and switches to torchtitan-parallelism–driven ParallelDims.
torchtitan/experiments/rl/models/vllm_registry.py Reworks vLLM registration to include config parser + HF-shaped config dict generation.
torchtitan/experiments/rl/models/parallelize.py Updates config imports to new consolidated module exports.
torchtitan/experiments/rl/grpo.py Updates config imports to new consolidated module exports.
torchtitan/experiments/rl/generate.py Updates vLLM registration callsite to new registry API and threads parallelism.
torchtitan/experiments/rl/config_registry.py Updates config imports and enforces generator parallelism invariants in configs.
torchtitan/experiments/rl/actors/trainer.py Updates config imports to new consolidated module exports.
torchtitan/experiments/rl/actors/generator.py Adds generator-side validation (no SP / loss parallel) and uses new vLLM parser registration path.
torchtitan/experiments/rl/init.py Exposes renamed wrapper + new registry entry point.
torchtitan/experiments/graph_trainer/trainer.py Adds timing around trace step and adjusts imports for refactored registries.
torchtitan/experiments/graph_trainer/tests/test_trace_module.py Adds tests for ChunkedCE autograd.grad tracing and CP tracing expectations.
torchtitan/experiments/graph_trainer/tests/test_precompile.py Updates precompile test fixtures with new trace metadata fields.
torchtitan/experiments/graph_trainer/tests/test_passes.py Updates pass imports and adds extensive remat/offload behavioral tests.
torchtitan/experiments/graph_trainer/tests/test_chunked_loss.py Adds unit tests for ChunkedCELossWithParamGrads parity/grad side effects.
torchtitan/experiments/graph_trainer/tests/integration_tests.py Updates integration test overrides for CP and memory policy rename.
torchtitan/experiments/graph_trainer/selective_activation_remat.py Refactors remat to in-place duplication + targeted erase and offload-chain hoisting.
torchtitan/experiments/graph_trainer/registry.py New centralized registries to avoid circular imports across pass modules.
torchtitan/experiments/graph_trainer/qwen3/parallelize.py Adds CP application to attention modules when CP is enabled.
torchtitan/experiments/graph_trainer/precompile.py Moves distributed metadata filtering import to refactored inductor passes module.
torchtitan/experiments/graph_trainer/passes.py Refactors into orchestrator module; delegates passes to focused modules and registry.
torchtitan/experiments/graph_trainer/memory_policy.py New memory policy tagging module (SAC + optional offload) and dispatcher.
torchtitan/experiments/graph_trainer/llama3/parallelize.py Adds CP application to attention modules when CP is enabled.
torchtitan/experiments/graph_trainer/llama3/parallelize_autoparallel.py Updates autoparallel input_fn to include positions tensor and constraints.
torchtitan/experiments/graph_trainer/inductor_passes.py New module for regional/full Inductor compilation passes and FlexAttention annotation.
torchtitan/experiments/graph_trainer/graph_utils.py Removes obsolete grouped-mm comment; updates memory policy error text.
torchtitan/experiments/graph_trainer/fsdp_passes.py Preserves MUST_CPU_OFFLOAD tags and updates backward-node classification.
torchtitan/experiments/graph_trainer/deepseek_v3/parallelize.py Switches CP checks to shared helper; applies CP when enabled.
torchtitan/experiments/graph_trainer/cudagraph.py Moves kernel annotation pass + cudagraph wrapper into cudagraph module.
torchtitan/experiments/graph_trainer/cpu_offload.py Propagates node metadata into offload chain and adjusts wait deferral behavior.
torchtitan/experiments/graph_trainer/configs.py Renames memory policy option to sac_and_offload.
torchtitan/experiments/graph_trainer/common_utils.py Adds log_timer and CP-to-attention helper; removes recomputed-node workaround.
torchtitan/experiments/graph_trainer/chunked_loss.py New ChunkedCE variant to plumb lm_head param grads via explicit autograd outputs.
torchtitan/experiments/graph_trainer/autoparallel_api.py Removes redundant forward-arg validation call.
torchtitan/experiments/graph_trainer/.claude/CLAUDE.md Updates documentation to match in-place remat behavior.
torchtitan/experiments/ezpz/moe/model.py Switches ezpz MoE fallback from removed use_grouped_mm to compute_backend="for_loop".
torchtitan/experiments/ezpz/moe/experts.py New EzpzGroupedExperts with compute_backend selector and for-loop fallback.
torchtitan/experiments/ezpz/moe/init.py Adds make_ezpz_experts_config wrapper and threads compute_backend into layer build.
torchtitan/experiments/ezpz/docs/upstream-sync.md Documents the upstream sync and the ezpz MoE replay strategy.
torchtitan/components/loss.py Refactors chunked CE loss to allow subclasses to override autograd bridging.
tests/unit_tests/test_compile_moe.py Updates MoE compile test to use GroupedExperts._experts_forward directly.
.pre-commit-config.yaml Makes lychee less noisy, enables caching, and adds a “single warning” sentinel.
.gitignore Ignores .lycheecache produced by lychee --cache.
.github/workflows/lint.yaml Pins lychee version, adds cache restore, and runs lychee via pre-commit across all files.
.github/workflows/link_check.yaml Removes standalone scheduled link-check workflow (now covered in lint).

💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.

Comment thread torchtitan/components/loss.py
Comment thread torchtitan/experiments/graph_trainer/chunked_loss.py
saforem2 added 4 commits May 12, 2026 16:15
Three lint hits surfaced by CI on the new files in #13:

- **ufmt**: applied the formatter's reformatting to the three new
  `experiments/ezpz/moe/` files (import grouping, comprehension
  collapse, line-wrap nits in `model.py`).
- **flake8 N801**: pre-existing class names `moeTransformerBlock` /
  `moeModel` violate CapWords, but renaming would touch every
  callsite outside of this PR's scope. Suppressed in-place with
  `# noqa: N801` on the two class declarations.
- **codespell**: a quoted upstream PR title in
  `experiments/ezpz/docs/upstream-sync.md` contains "Reenable" which
  codespell flags as "Re-enable". Added a `<!-- codespell:ignore-line -->`
  marker so the historical quote stays verbatim. Strictly inside
  `experiments/ezpz/` per Golden Rule #1 — no edits to core
  `pyproject.toml`.
- `docs/summaries/2026-04-12_to_2026-04-27.md:230` linked to
  `journal.md` (resolves to `summaries/journal.md`, doesn't exist).
  Fix: `../journal.md` (the file lives at `docs/journal.md`).

- `README.md:95` linked to `torchtitan/experiments/ezpz/run_train.sh`
  which doubled the path (resolved to
  `experiments/ezpz/torchtitan/experiments/ezpz/run_train.sh`). Fix:
  `run_train.sh` (sibling of the README).
Two pre-commit failures landed by the previous link-fix commit (which
brought README.md into the lint set):

- **trailing-whitespace** on line 4: stray ``  `` (markdown line-break)
  after `i.e.:`. Dropped — the line breaks naturally on the
  `[saforem2/torchtitan@ezpz](...)` link that follows.
- **codespell** on line 18: `preemptable` is the actual ALCF Polaris
  queue name, but codespell wants `preemptible`. Stash it as
  `preempt"able"` (shell concatenation, resolves to the same literal
  at submission time) so codespell sees two separate tokens. Header
  parenthetical clarifies what the queue is for the reader.
50-step `moe_500m` smoke on Sunspot 8N (job 12466707) validating PR
#13's `EzpzGroupedExperts.compute_backend = "for_loop"` fallback.

- 11 layer-wise warnings confirm the fallback fires (1 per MoE layer).
- Loss descends cleanly 12.90 → 6.66 (-6.24 nats) over 50 steps.
- Steady-state ~8,694 TPS / GPU, ~13% MFU — ~20% per-GPU uplift vs the
  2026-04-13 2N baseline (torch.compile inductor improvements; not
  regression in the for_loop kernel).
- Memory stable at 54.6% across the run; wall 541s end-to-end.

W&B: fluent-glitter-2042 (https://wandb.ai/aurora_gpt/torchtitan.ezpz.train/runs/lt77xx0o)
@saforem2 saforem2 merged commit d958fc1 into ezpz May 13, 2026
3 of 4 checks passed
@saforem2 saforem2 deleted the ezpz-moe-resync branch May 13, 2026 03:49
saforem2 added a commit that referenced this pull request May 20, 2026
Validates the 35th upstream sync (PR pytorch#3159 Module.parallelize signature
change replay) on Sunspot 2N (job 12467124) across 5 configs:

  agpt_debugmodel       LBS=2  loss 10.81 → 6.77  ~37,500 TPS / ~2.9% MFU
  agpt_2b               LBS=1  loss 12.93 → 6.01  ~6,100  TPS / ~22.8% MFU
  agpt_2b               LBS=2  loss 12.91 → 6.12  ~7,200  TPS / ~27.0% MFU
  moe_debugmodel        LBS=2  loss 12.96 → 7.01  ~13,000 TPS / ~9.0% MFU
  moe_2b                LBS=1  loss 12.91 → 6.16  ~2,900  TPS / ~8.4% MFU

agpt_2b LBS=2 matches the 2026-04-25 n=2 baseline (7,142 TPS / 27.6%
MFU) within noise; resync introduces no throughput regression. The
for_loop MoE expert backend (PR #13) still fires correctly on XPU
under the new parallelize signature.

Default moe_2b() LBS=16 OOMs on Max 1550 (single 62.53 GiB allocation,
likely the fused activation for all experts × full-batch-tokens);
LBS=1 documented as the safe XPU baseline.

Adds two new reports under docs/experiments/{agpt,moe}/sunspot/ and
indexes them in the parent READMEs.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

10 participants