Skip to content

Add padded batched matmul MoE expert backend#10

Open
samuelwheeler wants to merge 2 commits into
saforem2:ezpzfrom
samuelwheeler:feature/moe-batched-mm-padded
Open

Add padded batched matmul MoE expert backend#10
samuelwheeler wants to merge 2 commits into
saforem2:ezpzfrom
samuelwheeler:feature/moe-batched-mm-padded

Conversation

@samuelwheeler

@samuelwheeler samuelwheeler commented May 11, 2026

Copy link
Copy Markdown

Summary

Adds a batched_mm_padded expert compute backend for MoE GroupedExperts.

The backend pads routed expert inputs to [num_experts, max_tokens_per_expert, dim], uses torch.bmm for the SwiGLU expert projections, then gathers back to routed token order.

Default behavior is unchanged: if compute_backend is unset, the existing use_grouped_mm flag still selects grouped-mm vs for-loop. The new path is opt-in via compute_backend="batched_mm_padded".

Also threads the backend selector through make_experts_config, adds an ezpz 10B_2B_sdpa_batched_mm_padded flavor, and adds unit coverage comparing outputs and gradients against the for-loop implementation.

Validation

  • git diff --cached --check
  • /lus/flare/projects/AuroraGPT/sww/new_tt_aurora/torchtitan/.venv/bin/python -m py_compile torchtitan/models/common/moe.py torchtitan/models/common/config_utils.py torchtitan/experiments/ezpz/moe/__init__.py torchtitan/experiments/ezpz/moe/config_registry.py tests/unit_tests/test_moe_expert_backends.py
  • PYTHONPATH=. timeout 300s /lus/flare/projects/AuroraGPT/sww/new_tt_aurora/torchtitan/.venv/bin/python -m unittest -v tests.unit_tests.test_moe_expert_backends

Summary by Sourcery

Add a new padded batched matmul expert compute backend for MoE GroupedExperts and wire it through configs and tests.

New Features:

  • Introduce a batched_mm_padded expert compute backend for MoE GroupedExperts using padded batched matrix multiplications.
  • Add a 10B_2B_sdpa_batched_mm_padded MoE configuration variant that uses the new backend.

Enhancements:

  • Thread a generic expert compute_backend selector through GroupedExperts and make_experts_config while preserving existing default behavior.

Tests:

  • Add unit tests to validate that the batched_mm_padded backend matches the for_loop implementation in outputs and gradients, and that backend selection in GroupedExperts behaves correctly.

@sourcery-ai

sourcery-ai Bot commented May 11, 2026

Copy link
Copy Markdown

Reviewer's Guide

Introduces a new padded batched matmul expert compute backend for MoE GroupedExperts, wires it into configuration and ezpz experiment presets, and adds unit tests to verify numerical and gradient equivalence with the existing for-loop backend while preserving existing defaults.

Flow diagram for configuring batched_mm_padded MoE expert backend

flowchart LR
    A[m oe_10b_2b_sdpa_batched_mm_padded] --> B[_10b_2b_sdpa_batched_mm_padded]
    B --> C[set layer_cfg.moe.experts.compute_backend = batched_mm_padded]
    C --> D[make_experts_config<br>compute_backend = batched_mm_padded]
    D --> E[GroupedExperts.Config<br>compute_backend = batched_mm_padded]
    E --> F[GroupedExperts<br>compute_backend == batched_mm_padded]
Loading

File-Level Changes

Change Details Files
Add a padded batched matmul MoE expert backend and integrate it with the expert backend selector.
  • Define ExpertComputeBackend Literal type and thread it into GroupedExperts configuration.
  • Implement _run_experts_batched_mm_padded that pads expert inputs per expert, runs SwiGLU projections via torch.bmm, and gathers results back to routed token order.
  • Refactor GroupedExperts to derive compute_backend from compute_backend/use_grouped_mm, route _experts_forward based on the backend string, and raise on unknown backends.
torchtitan/models/common/moe.py
Expose the new backend through config utilities and ezpz MoE experiment presets.
  • Extend make_experts_config to accept and pass through a compute_backend parameter alongside use_grouped_mm.
  • Add a 10B_2B_sdpa_batched_mm_padded MoE experiment config that selects the new backend and disables grouped_mm.
  • Register the new experiment in the ezpz MoE config registry and JSON loader helpers.
torchtitan/models/common/config_utils.py
torchtitan/experiments/ezpz/moe/__init__.py
torchtitan/experiments/ezpz/moe/config_registry.py
Add unit tests to validate the new backend and backend selection logic against the for-loop implementation.
  • Introduce tests comparing outputs and gradients between _run_experts_batched_mm_padded and _run_experts_for_loop for a fixed routing pattern.
  • Add tests ensuring GroupedExperts with compute_backend='batched_mm_padded' matches the for-loop backend in both outputs and gradients, including parameter grads and input grads.
tests/unit_tests/test_moe_expert_backends.py

Tips and commands

Interacting with Sourcery

  • Trigger a new review: Comment @sourcery-ai review on the pull request.
  • Continue discussions: Reply directly to Sourcery's review comments.
  • Generate a GitHub issue from a review comment: Ask Sourcery to create an
    issue from a review comment by replying to it. You can also reply to a
    review comment with @sourcery-ai issue to create an issue from it.
  • Generate a pull request title: Write @sourcery-ai anywhere in the pull
    request title to generate a title at any time. You can also comment
    @sourcery-ai title on the pull request to (re-)generate the title at any time.
  • Generate a pull request summary: Write @sourcery-ai summary anywhere in
    the pull request body to generate a PR summary at any time exactly where you
    want it. You can also comment @sourcery-ai summary on the pull request to
    (re-)generate the summary at any time.
  • Generate reviewer's guide: Comment @sourcery-ai guide on the pull
    request to (re-)generate the reviewer's guide at any time.
  • Resolve all Sourcery comments: Comment @sourcery-ai resolve on the
    pull request to resolve all Sourcery comments. Useful if you've already
    addressed all the comments and don't want to see them anymore.
  • Dismiss all Sourcery reviews: Comment @sourcery-ai dismiss on the pull
    request to dismiss all existing Sourcery reviews. Especially useful if you
    want to start fresh with a new review - don't forget to comment
    @sourcery-ai review to trigger a new review!

Customizing Your Experience

Access your dashboard to:

  • Enable or disable review features such as the Sourcery-generated pull request
    summary, the reviewer's guide, and others.
  • Change the review language.
  • Add, remove or edit custom review instructions.
  • Adjust other review settings.

Getting Help

@sourcery-ai sourcery-ai Bot left a comment

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hey - I've found 4 issues, and left some high level feedback:

  • The interaction between use_grouped_mm and the new compute_backend is a bit confusing: the constructor silently derives compute_backend from use_grouped_mm, but downstream configs (e.g. _10b_2b_sdpa_batched_mm_padded) now set both; consider either deprecating use_grouped_mm in favor of compute_backend or asserting they don't conflict to avoid surprising behavior.
  • In _run_experts_batched_mm_padded, the expert and token index computation is non-trivial; adding a small helper or inline comment about the expected layout of x (e.g., that tokens are grouped by expert and counts is in that order) would make future maintenance and refactors less error-prone.
Prompt for AI Agents
Please address the comments from this code review:

## Overall Comments
- The interaction between `use_grouped_mm` and the new `compute_backend` is a bit confusing: the constructor silently derives `compute_backend` from `use_grouped_mm`, but downstream configs (e.g. `_10b_2b_sdpa_batched_mm_padded`) now set both; consider either deprecating `use_grouped_mm` in favor of `compute_backend` or asserting they don't conflict to avoid surprising behavior.
- In `_run_experts_batched_mm_padded`, the expert and token index computation is non-trivial; adding a small helper or inline comment about the expected layout of `x` (e.g., that tokens are grouped by expert and `counts` is in that order) would make future maintenance and refactors less error-prone.

## Individual Comments

### Comment 1
<location path="torchtitan/experiments/ezpz/moe/__init__.py" line_range="947-948" />
<code_context>
+    for layer_cfg in cfg.layers:
+        if layer_cfg.moe is None:
+            continue
+        layer_cfg.moe.experts.compute_backend = "batched_mm_padded"
+        layer_cfg.moe.experts.use_grouped_mm = False
+    return cfg
+
</code_context>
<issue_to_address>
**suggestion:** Setting both `compute_backend` and `use_grouped_mm` is redundant now that backend drives the flag.

Since `use_grouped_mm` is already derived from `compute_backend` at initialization, explicitly setting it here is redundant and may drift from the initialization logic over time. Please rely on `compute_backend = "batched_mm_padded"` alone and remove the explicit `use_grouped_mm` assignment to keep this configuration consistent and minimal.
</issue_to_address>

### Comment 2
<location path="tests/unit_tests/test_moe_expert_backends.py" line_range="33-42" />
<code_context>
+    def test_batched_mm_padded_matches_for_loop_kernel(self) -> None:
</code_context>
<issue_to_address>
**suggestion (testing):** Add explicit tests for empty and all-zero `num_tokens_per_expert` to cover early-return branches in `_run_experts_batched_mm_padded`.

The new backend has two early-return branches (`counts.numel() == 0` and `max_tokens == 0`), but the current test with `counts = [0, 3, 5, 1, 0, 7]` never hits them. Please add small tests that cover:

1. No experts (empty `num_tokens_per_expert`), and
2. All experts with zero tokens (e.g., `counts = [0, 0, 0]`).

For each, call `_run_experts_for_loop` and `_run_experts_batched_mm_padded` with the same inputs and assert that outputs (and ideally gradients) match, so these early-return paths are covered and protected against regressions.

Suggested implementation:

```python
class MoEExpertBackendsTest(unittest.TestCase):
    def test_batched_mm_padded_matches_for_loop_kernel(self) -> None:
        torch.manual_seed(1234)
        dim = 16
        hidden_dim = 24

        # Baseline case (existing coverage)
        num_experts = 6
        counts = torch.tensor([0, 3, 5, 1, 0, 7], dtype=torch.int64)
        total_tokens = int(counts.sum().item())

        w1_base = torch.randn(num_experts, hidden_dim, dim) * 0.02
        w2_base = torch.randn(num_experts, dim, hidden_dim) * 0.02
        w3_base = torch.randn(num_experts, hidden_dim, dim) * 0.02

        w1_for_loop = w1_base.clone().detach().requires_grad_(True)
        w2_for_loop = w2_base.clone().detach().requires_grad_(True)
        w3_for_loop = w3_base.clone().detach().requires_grad_(True)

        w1_batched = w1_base.clone().detach().requires_grad_(True)
        w2_batched = w2_base.clone().detach().requires_grad_(True)
        w3_batched = w3_base.clone().detach().requires_grad_(True)

        x_for_loop = torch.randn(total_tokens, dim, requires_grad=True)
        x_batched = x_for_loop.clone().detach().requires_grad_(True)

        out_for_loop = _run_experts_for_loop(
            x_for_loop, w1_for_loop, w2_for_loop, w3_for_loop, counts
        )
        out_batched = _run_experts_batched_mm_padded(
            x_batched, w1_batched, w2_batched, w3_batched, counts
        )

        self.assertTrue(torch.allclose(out_for_loop, out_batched, atol=1e-6, rtol=1e-6))

        out_for_loop.sum().backward()
        out_batched.sum().backward()

        self.assertTrue(torch.allclose(x_for_loop.grad, x_batched.grad, atol=1e-6, rtol=1e-6))
        self.assertTrue(torch.allclose(w1_for_loop.grad, w1_batched.grad, atol=1e-6, rtol=1e-6))
        self.assertTrue(torch.allclose(w2_for_loop.grad, w2_batched.grad, atol=1e-6, rtol=1e-6))
        self.assertTrue(torch.allclose(w3_for_loop.grad, w3_batched.grad, atol=1e-6, rtol=1e-6))

        # Early-return case 1: no experts (empty counts; counts.numel() == 0)
        counts_empty = torch.empty(0, dtype=torch.int64)
        num_experts_empty = 0
        total_tokens_empty = int(counts_empty.sum().item())

        w1_base_empty = torch.randn(num_experts_empty, hidden_dim, dim) * 0.02
        w2_base_empty = torch.randn(num_experts_empty, dim, hidden_dim) * 0.02
        w3_base_empty = torch.randn(num_experts_empty, hidden_dim, dim) * 0.02

        w1_for_loop_empty = w1_base_empty.clone().detach().requires_grad_(True)
        w2_for_loop_empty = w2_base_empty.clone().detach().requires_grad_(True)
        w3_for_loop_empty = w3_base_empty.clone().detach().requires_grad_(True)

        w1_batched_empty = w1_base_empty.clone().detach().requires_grad_(True)
        w2_batched_empty = w2_base_empty.clone().detach().requires_grad_(True)
        w3_batched_empty = w3_base_empty.clone().detach().requires_grad_(True)

        x_for_loop_empty = torch.randn(total_tokens_empty, dim, requires_grad=True)
        x_batched_empty = x_for_loop_empty.clone().detach().requires_grad_(True)

        out_for_loop_empty = _run_experts_for_loop(
            x_for_loop_empty,
            w1_for_loop_empty,
            w2_for_loop_empty,
            w3_for_loop_empty,
            counts_empty,
        )
        out_batched_empty = _run_experts_batched_mm_padded(
            x_batched_empty,
            w1_batched_empty,
            w2_batched_empty,
            w3_batched_empty,
            counts_empty,
        )

        self.assertTrue(
            torch.allclose(out_for_loop_empty, out_batched_empty, atol=1e-6, rtol=1e-6)
        )

        out_for_loop_empty.sum().backward()
        out_batched_empty.sum().backward()

        self.assertTrue(
            torch.allclose(
                x_for_loop_empty.grad, x_batched_empty.grad, atol=1e-6, rtol=1e-6
            )
        )
        self.assertTrue(
            torch.allclose(
                w1_for_loop_empty.grad, w1_batched_empty.grad, atol=1e-6, rtol=1e-6
            )
        )
        self.assertTrue(
            torch.allclose(
                w2_for_loop_empty.grad, w2_batched_empty.grad, atol=1e-6, rtol=1e-6
            )
        )
        self.assertTrue(
            torch.allclose(
                w3_for_loop_empty.grad, w3_batched_empty.grad, atol=1e-6, rtol=1e-6
            )
        )

        # Early-return case 2: all experts with zero tokens (max_tokens == 0)
        counts_all_zero = torch.tensor([0, 0, 0], dtype=torch.int64)
        num_experts_all_zero = counts_all_zero.numel()
        total_tokens_all_zero = int(counts_all_zero.sum().item())

        w1_base_all_zero = torch.randn(num_experts_all_zero, hidden_dim, dim) * 0.02
        w2_base_all_zero = torch.randn(num_experts_all_zero, dim, hidden_dim) * 0.02
        w3_base_all_zero = torch.randn(num_experts_all_zero, hidden_dim, dim) * 0.02

        w1_for_loop_all_zero = w1_base_all_zero.clone().detach().requires_grad_(True)
        w2_for_loop_all_zero = w2_base_all_zero.clone().detach().requires_grad_(True)
        w3_for_loop_all_zero = w3_base_all_zero.clone().detach().requires_grad_(True)

        w1_batched_all_zero = w1_base_all_zero.clone().detach().requires_grad_(True)
        w2_batched_all_zero = w2_base_all_zero.clone().detach().requires_grad_(True)
        w3_batched_all_zero = w3_base_all_zero.clone().detach().requires_grad_(True)

        x_for_loop_all_zero = torch.randn(
            total_tokens_all_zero, dim, requires_grad=True
        )
        x_batched_all_zero = x_for_loop_all_zero.clone().detach().requires_grad_(True)

        out_for_loop_all_zero = _run_experts_for_loop(
            x_for_loop_all_zero,
            w1_for_loop_all_zero,
            w2_for_loop_all_zero,
            w3_for_loop_all_zero,
            counts_all_zero,
        )
        out_batched_all_zero = _run_experts_batched_mm_padded(
            x_batched_all_zero,
            w1_batched_all_zero,
            w2_batched_all_zero,
            w3_batched_all_zero,
            counts_all_zero,
        )

        self.assertTrue(
            torch.allclose(
                out_for_loop_all_zero, out_batched_all_zero, atol=1e-6, rtol=1e-6
            )
        )

        out_for_loop_all_zero.sum().backward()
        out_batched_all_zero.sum().backward()

        self.assertTrue(
            torch.allclose(
                x_for_loop_all_zero.grad,
                x_batched_all_zero.grad,
                atol=1e-6,
                rtol=1e-6,
            )
        )
        self.assertTrue(
            torch.allclose(
                w1_for_loop_all_zero.grad,
                w1_batched_all_zero.grad,
                atol=1e-6,
                rtol=1e-6,
            )
        )
        self.assertTrue(
            torch.allclose(
                w2_for_loop_all_zero.grad,
                w2_batched_all_zero.grad,
                atol=1e-6,
                rtol=1e-6,
            )
        )
        self.assertTrue(
            torch.allclose(
                w3_for_loop_all_zero.grad,
                w3_batched_all_zero.grad,
                atol=1e-6,
                rtol=1e-6,
            )
        )

```

This edit assumes that `_run_experts_for_loop` and `_run_experts_batched_mm_padded` both accept `(x, w1, w2, w3, num_tokens_per_expert)` in that order, and that they correctly handle the cases where `num_tokens_per_expert` is empty or sums to zero. If their signatures differ, adjust the call sites accordingly.

If the original implementation of `test_batched_mm_padded_matches_for_loop_kernel` contained more logic than is visible in the provided snippet, you should merge that existing logic into the new test body rather than replacing it wholesale. The key requirement is that the two new early-return scenarios (empty counts and all-zero counts) are exercised, and that both outputs and gradients from the for-loop and batched backends are compared.
</issue_to_address>

### Comment 3
<location path="tests/unit_tests/test_moe_expert_backends.py" line_range="32-41" />
<code_context>
+class MoEExpertBackendsTest(unittest.TestCase):
</code_context>
<issue_to_address>
**suggestion (testing):** Consider parameterizing tests over device and dtype to ensure the backend behaves consistently beyond default CPU/float32.

Both tests currently exercise only CPU with `float32`, so any device- or dtype-specific issues in `_run_experts_batched_mm_padded` (e.g., CUDA vs CPU, `float16`/`bfloat16` vs `float32`) may be missed. If your CI supports GPUs, please consider parametrizing over `device` (CPU and CUDA when available) and/or adding a variant using a lower-precision dtype to validate numerical and autograd behavior against the reference loop.

Suggested implementation:

```python
    dim: int,
    hidden_dim: int,
    *,
    device: torch.device | str = "cpu",
    dtype: torch.dtype = torch.float32,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    w1 = torch.randn(num_experts, hidden_dim, dim, device=device, dtype=dtype) * 0.02
    w2 = torch.randn(num_experts, dim, hidden_dim, device=device, dtype=dtype) * 0.02
    w3 = torch.randn(num_experts, hidden_dim, dim, device=device, dtype=dtype) * 0.02
    return w1, w2, w3

```

```python
class MoEExpertBackendsTest(unittest.TestCase):
    def test_batched_mm_padded_matches_for_loop_kernel(self) -> None:
        devices: list[str] = ["cpu"]
        if torch.cuda.is_available():
            devices.append("cuda")

        # Always include float32; add a lower-precision dtype to exercise
        # numerical and backend behavior beyond the default.
        dtypes: list[torch.dtype] = [torch.float32]
        if hasattr(torch, "float16"):
            dtypes.append(torch.float16)

        for device in devices:
            for dtype in dtypes:
                with self.subTest(device=device, dtype=dtype):
                    torch.manual_seed(1234)
                    num_experts = 6
                    dim = 16
                    hidden_dim = 24
                    counts = torch.tensor(
                        [0, 3, 5, 1, 0, 7],
                        dtype=torch.int64,
                        device=device,
                    )
                    total_tokens = int(counts.sum().item())

                    w1, w2, w3 = _init_weights(
                        num_experts,
                        dim,
                        hidden_dim,
                        device=device,
                        dtype=dtype,
                    )
                    x = torch.randn(total_tokens, dim, device=device, dtype=dtype)

```

If the remainder of `test_batched_mm_padded_matches_for_loop_kernel` (not shown in the snippet) creates additional tensors (e.g., reference outputs, indices, or masks), they should also be constructed with `device=device` and `dtype=dtype` (or at least `device=device`, and `dtype` chosen consistently with the values being compared) to avoid device/dtype mismatches and to ensure the parameterization fully exercises the backend on the selected configuration.
</issue_to_address>

### Comment 4
<location path="torchtitan/models/common/moe.py" line_range="139" />
<code_context>
             torch.empty(config.num_experts, config.hidden_dim, config.dim)
         )
-        self.use_grouped_mm = config.use_grouped_mm
+        self.compute_backend: ExpertComputeBackend = (
+            config.compute_backend
+            if config.compute_backend is not None
</code_context>
<issue_to_address>
**issue (complexity):** Consider simplifying the expert compute configuration and dispatch logic and refactoring the padded backend to isolate indexing math for better readability and maintainability.

You can simplify the configuration and dispatch layer and slightly decompose the padded backend without changing behavior.

### 1. Remove redundant `use_grouped_mm` instance state

Right now `self.compute_backend` is the real source of truth, and `self.use_grouped_mm` is derived from it but still kept as state. You can normalize legacy configs once and then drop the extra attribute:

```python
@dataclass(kw_only=True, slots=True)
class Config(Module.Config):
    dim: int
    hidden_dim: int
    num_experts: int
    # keep for backward-compat config, but treat as deprecated
    use_grouped_mm: bool = True
    compute_backend: ExpertComputeBackend | None = None
    token_dispatcher: LocalTokenDispatcher.Config

def __init__(self, config: Config):
    super().__init__()
    ...
    if config.compute_backend is not None:
        backend: ExpertComputeBackend = config.compute_backend
    else:
        backend = "grouped_mm" if config.use_grouped_mm else "for_loop"

    self.compute_backend: ExpertComputeBackend = backend
    # drop: self.use_grouped_mm = ...
    self.token_dispatcher = config.token_dispatcher.build()
```

This removes redundant state and ensures all downstream logic only needs to look at `self.compute_backend`.

### 2. Centralize backend dispatch

The current sequence of independent `if` statements plus a final `ValueError` is a bit brittle. A small dispatch map keeps the logic flat and makes adding/removing backends less error-prone:

```python
def _experts_forward(
    self,
    x: torch.Tensor,
    num_tokens_per_expert: torch.Tensor,
) -> torch.Tensor:
    ...
    backends: dict[ExpertComputeBackend, Callable[..., torch.Tensor]] = {
        "grouped_mm": _run_experts_grouped_mm,
        "batched_mm_padded": _run_experts_batched_mm_padded,
        "for_loop": _run_experts_for_loop,
    }

    try:
        fn = backends[self.compute_backend]
    except KeyError:
        raise ValueError(f"Unknown expert compute backend: {self.compute_backend}")

    return fn(w1, w2, w3, x, num_tokens_per_expert)
```

This keeps behavior identical while making the branching logic clearer.

### 3. Extract the index layout logic in `_run_experts_batched_mm_padded`

The index math is the hardest part to reason about. Pulling it into a small helper gives `_run_experts_batched_mm_padded` a more “algorithmic” shape:

```python
def _compute_expert_layout(
    counts: torch.Tensor, total_tokens: int, device: torch.device
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    # counts: [num_experts]
    offsets = counts.cumsum(0) - counts
    expert_indices = torch.repeat_interleave(
        torch.arange(counts.numel(), device=device, dtype=torch.int64),
        counts,
    )
    token_indices_within_expert = torch.arange(
        total_tokens, device=device, dtype=torch.int64
    ) - torch.repeat_interleave(offsets, counts)
    return offsets, expert_indices, token_indices_within_expert
```

Then `_run_experts_batched_mm_padded` becomes:

```python
def _run_experts_batched_mm_padded(
    w1: torch.Tensor,
    w2: torch.Tensor,
    w3: torch.Tensor,
    x: torch.Tensor,
    num_tokens_per_expert: torch.Tensor,
) -> torch.Tensor:
    counts = num_tokens_per_expert.to(device=x.device, dtype=torch.int64)
    if counts.numel() == 0 or counts.max().item() == 0:
        return x.new_empty((0, w2.shape[1]))

    total_tokens = x.shape[0]
    device = x.device
    max_tokens = int(counts.max().item())

    _, expert_indices, token_indices_within_expert = _compute_expert_layout(
        counts, total_tokens, device
    )

    padded_x = x.new_zeros((counts.numel(), max_tokens, x.shape[-1]))
    padded_x[expert_indices, token_indices_within_expert] = x

    h = F.silu(torch.bmm(padded_x, w1.transpose(-2, -1)))
    h = h * torch.bmm(padded_x, w3.transpose(-2, -1))
    out_padded = torch.bmm(h, w2.transpose(-2, -1))

    return out_padded[expert_indices, token_indices_within_expert]
```

This keeps all behavior and shapes intact but isolates the tricky layout logic and simplifies the main computation path.
</issue_to_address>

Sourcery is free for open source - if you like our reviews please consider sharing them ✨
Help me be more useful! Please click 👍 or 👎 on each comment and I'll use the feedback to improve your reviews.

Comment thread torchtitan/experiments/ezpz/moe/__init__.py Outdated
Comment on lines +33 to +42
def test_batched_mm_padded_matches_for_loop_kernel(self) -> None:
torch.manual_seed(1234)
num_experts = 6
dim = 16
hidden_dim = 24
counts = torch.tensor([0, 3, 5, 1, 0, 7], dtype=torch.int64)
total_tokens = int(counts.sum().item())

w1, w2, w3 = _init_weights(num_experts, dim, hidden_dim)
x = torch.randn(total_tokens, dim)

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

suggestion (testing): Add explicit tests for empty and all-zero num_tokens_per_expert to cover early-return branches in _run_experts_batched_mm_padded.

The new backend has two early-return branches (counts.numel() == 0 and max_tokens == 0), but the current test with counts = [0, 3, 5, 1, 0, 7] never hits them. Please add small tests that cover:

  1. No experts (empty num_tokens_per_expert), and
  2. All experts with zero tokens (e.g., counts = [0, 0, 0]).

For each, call _run_experts_for_loop and _run_experts_batched_mm_padded with the same inputs and assert that outputs (and ideally gradients) match, so these early-return paths are covered and protected against regressions.

Suggested implementation:

class MoEExpertBackendsTest(unittest.TestCase):
    def test_batched_mm_padded_matches_for_loop_kernel(self) -> None:
        torch.manual_seed(1234)
        dim = 16
        hidden_dim = 24

        # Baseline case (existing coverage)
        num_experts = 6
        counts = torch.tensor([0, 3, 5, 1, 0, 7], dtype=torch.int64)
        total_tokens = int(counts.sum().item())

        w1_base = torch.randn(num_experts, hidden_dim, dim) * 0.02
        w2_base = torch.randn(num_experts, dim, hidden_dim) * 0.02
        w3_base = torch.randn(num_experts, hidden_dim, dim) * 0.02

        w1_for_loop = w1_base.clone().detach().requires_grad_(True)
        w2_for_loop = w2_base.clone().detach().requires_grad_(True)
        w3_for_loop = w3_base.clone().detach().requires_grad_(True)

        w1_batched = w1_base.clone().detach().requires_grad_(True)
        w2_batched = w2_base.clone().detach().requires_grad_(True)
        w3_batched = w3_base.clone().detach().requires_grad_(True)

        x_for_loop = torch.randn(total_tokens, dim, requires_grad=True)
        x_batched = x_for_loop.clone().detach().requires_grad_(True)

        out_for_loop = _run_experts_for_loop(
            x_for_loop, w1_for_loop, w2_for_loop, w3_for_loop, counts
        )
        out_batched = _run_experts_batched_mm_padded(
            x_batched, w1_batched, w2_batched, w3_batched, counts
        )

        self.assertTrue(torch.allclose(out_for_loop, out_batched, atol=1e-6, rtol=1e-6))

        out_for_loop.sum().backward()
        out_batched.sum().backward()

        self.assertTrue(torch.allclose(x_for_loop.grad, x_batched.grad, atol=1e-6, rtol=1e-6))
        self.assertTrue(torch.allclose(w1_for_loop.grad, w1_batched.grad, atol=1e-6, rtol=1e-6))
        self.assertTrue(torch.allclose(w2_for_loop.grad, w2_batched.grad, atol=1e-6, rtol=1e-6))
        self.assertTrue(torch.allclose(w3_for_loop.grad, w3_batched.grad, atol=1e-6, rtol=1e-6))

        # Early-return case 1: no experts (empty counts; counts.numel() == 0)
        counts_empty = torch.empty(0, dtype=torch.int64)
        num_experts_empty = 0
        total_tokens_empty = int(counts_empty.sum().item())

        w1_base_empty = torch.randn(num_experts_empty, hidden_dim, dim) * 0.02
        w2_base_empty = torch.randn(num_experts_empty, dim, hidden_dim) * 0.02
        w3_base_empty = torch.randn(num_experts_empty, hidden_dim, dim) * 0.02

        w1_for_loop_empty = w1_base_empty.clone().detach().requires_grad_(True)
        w2_for_loop_empty = w2_base_empty.clone().detach().requires_grad_(True)
        w3_for_loop_empty = w3_base_empty.clone().detach().requires_grad_(True)

        w1_batched_empty = w1_base_empty.clone().detach().requires_grad_(True)
        w2_batched_empty = w2_base_empty.clone().detach().requires_grad_(True)
        w3_batched_empty = w3_base_empty.clone().detach().requires_grad_(True)

        x_for_loop_empty = torch.randn(total_tokens_empty, dim, requires_grad=True)
        x_batched_empty = x_for_loop_empty.clone().detach().requires_grad_(True)

        out_for_loop_empty = _run_experts_for_loop(
            x_for_loop_empty,
            w1_for_loop_empty,
            w2_for_loop_empty,
            w3_for_loop_empty,
            counts_empty,
        )
        out_batched_empty = _run_experts_batched_mm_padded(
            x_batched_empty,
            w1_batched_empty,
            w2_batched_empty,
            w3_batched_empty,
            counts_empty,
        )

        self.assertTrue(
            torch.allclose(out_for_loop_empty, out_batched_empty, atol=1e-6, rtol=1e-6)
        )

        out_for_loop_empty.sum().backward()
        out_batched_empty.sum().backward()

        self.assertTrue(
            torch.allclose(
                x_for_loop_empty.grad, x_batched_empty.grad, atol=1e-6, rtol=1e-6
            )
        )
        self.assertTrue(
            torch.allclose(
                w1_for_loop_empty.grad, w1_batched_empty.grad, atol=1e-6, rtol=1e-6
            )
        )
        self.assertTrue(
            torch.allclose(
                w2_for_loop_empty.grad, w2_batched_empty.grad, atol=1e-6, rtol=1e-6
            )
        )
        self.assertTrue(
            torch.allclose(
                w3_for_loop_empty.grad, w3_batched_empty.grad, atol=1e-6, rtol=1e-6
            )
        )

        # Early-return case 2: all experts with zero tokens (max_tokens == 0)
        counts_all_zero = torch.tensor([0, 0, 0], dtype=torch.int64)
        num_experts_all_zero = counts_all_zero.numel()
        total_tokens_all_zero = int(counts_all_zero.sum().item())

        w1_base_all_zero = torch.randn(num_experts_all_zero, hidden_dim, dim) * 0.02
        w2_base_all_zero = torch.randn(num_experts_all_zero, dim, hidden_dim) * 0.02
        w3_base_all_zero = torch.randn(num_experts_all_zero, hidden_dim, dim) * 0.02

        w1_for_loop_all_zero = w1_base_all_zero.clone().detach().requires_grad_(True)
        w2_for_loop_all_zero = w2_base_all_zero.clone().detach().requires_grad_(True)
        w3_for_loop_all_zero = w3_base_all_zero.clone().detach().requires_grad_(True)

        w1_batched_all_zero = w1_base_all_zero.clone().detach().requires_grad_(True)
        w2_batched_all_zero = w2_base_all_zero.clone().detach().requires_grad_(True)
        w3_batched_all_zero = w3_base_all_zero.clone().detach().requires_grad_(True)

        x_for_loop_all_zero = torch.randn(
            total_tokens_all_zero, dim, requires_grad=True
        )
        x_batched_all_zero = x_for_loop_all_zero.clone().detach().requires_grad_(True)

        out_for_loop_all_zero = _run_experts_for_loop(
            x_for_loop_all_zero,
            w1_for_loop_all_zero,
            w2_for_loop_all_zero,
            w3_for_loop_all_zero,
            counts_all_zero,
        )
        out_batched_all_zero = _run_experts_batched_mm_padded(
            x_batched_all_zero,
            w1_batched_all_zero,
            w2_batched_all_zero,
            w3_batched_all_zero,
            counts_all_zero,
        )

        self.assertTrue(
            torch.allclose(
                out_for_loop_all_zero, out_batched_all_zero, atol=1e-6, rtol=1e-6
            )
        )

        out_for_loop_all_zero.sum().backward()
        out_batched_all_zero.sum().backward()

        self.assertTrue(
            torch.allclose(
                x_for_loop_all_zero.grad,
                x_batched_all_zero.grad,
                atol=1e-6,
                rtol=1e-6,
            )
        )
        self.assertTrue(
            torch.allclose(
                w1_for_loop_all_zero.grad,
                w1_batched_all_zero.grad,
                atol=1e-6,
                rtol=1e-6,
            )
        )
        self.assertTrue(
            torch.allclose(
                w2_for_loop_all_zero.grad,
                w2_batched_all_zero.grad,
                atol=1e-6,
                rtol=1e-6,
            )
        )
        self.assertTrue(
            torch.allclose(
                w3_for_loop_all_zero.grad,
                w3_batched_all_zero.grad,
                atol=1e-6,
                rtol=1e-6,
            )
        )

This edit assumes that _run_experts_for_loop and _run_experts_batched_mm_padded both accept (x, w1, w2, w3, num_tokens_per_expert) in that order, and that they correctly handle the cases where num_tokens_per_expert is empty or sums to zero. If their signatures differ, adjust the call sites accordingly.

If the original implementation of test_batched_mm_padded_matches_for_loop_kernel contained more logic than is visible in the provided snippet, you should merge that existing logic into the new test body rather than replacing it wholesale. The key requirement is that the two new early-return scenarios (empty counts and all-zero counts) are exercised, and that both outputs and gradients from the for-loop and batched backends are compared.

Comment on lines +32 to +41
class MoEExpertBackendsTest(unittest.TestCase):
def test_batched_mm_padded_matches_for_loop_kernel(self) -> None:
torch.manual_seed(1234)
num_experts = 6
dim = 16
hidden_dim = 24
counts = torch.tensor([0, 3, 5, 1, 0, 7], dtype=torch.int64)
total_tokens = int(counts.sum().item())

w1, w2, w3 = _init_weights(num_experts, dim, hidden_dim)

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

suggestion (testing): Consider parameterizing tests over device and dtype to ensure the backend behaves consistently beyond default CPU/float32.

Both tests currently exercise only CPU with float32, so any device- or dtype-specific issues in _run_experts_batched_mm_padded (e.g., CUDA vs CPU, float16/bfloat16 vs float32) may be missed. If your CI supports GPUs, please consider parametrizing over device (CPU and CUDA when available) and/or adding a variant using a lower-precision dtype to validate numerical and autograd behavior against the reference loop.

Suggested implementation:

    dim: int,
    hidden_dim: int,
    *,
    device: torch.device | str = "cpu",
    dtype: torch.dtype = torch.float32,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    w1 = torch.randn(num_experts, hidden_dim, dim, device=device, dtype=dtype) * 0.02
    w2 = torch.randn(num_experts, dim, hidden_dim, device=device, dtype=dtype) * 0.02
    w3 = torch.randn(num_experts, hidden_dim, dim, device=device, dtype=dtype) * 0.02
    return w1, w2, w3
class MoEExpertBackendsTest(unittest.TestCase):
    def test_batched_mm_padded_matches_for_loop_kernel(self) -> None:
        devices: list[str] = ["cpu"]
        if torch.cuda.is_available():
            devices.append("cuda")

        # Always include float32; add a lower-precision dtype to exercise
        # numerical and backend behavior beyond the default.
        dtypes: list[torch.dtype] = [torch.float32]
        if hasattr(torch, "float16"):
            dtypes.append(torch.float16)

        for device in devices:
            for dtype in dtypes:
                with self.subTest(device=device, dtype=dtype):
                    torch.manual_seed(1234)
                    num_experts = 6
                    dim = 16
                    hidden_dim = 24
                    counts = torch.tensor(
                        [0, 3, 5, 1, 0, 7],
                        dtype=torch.int64,
                        device=device,
                    )
                    total_tokens = int(counts.sum().item())

                    w1, w2, w3 = _init_weights(
                        num_experts,
                        dim,
                        hidden_dim,
                        device=device,
                        dtype=dtype,
                    )
                    x = torch.randn(total_tokens, dim, device=device, dtype=dtype)

If the remainder of test_batched_mm_padded_matches_for_loop_kernel (not shown in the snippet) creates additional tensors (e.g., reference outputs, indices, or masks), they should also be constructed with device=device and dtype=dtype (or at least device=device, and dtype chosen consistently with the values being compared) to avoid device/dtype mismatches and to ensure the parameterization fully exercises the backend on the selected configuration.

Comment thread torchtitan/models/common/moe.py
@saforem2

Copy link
Copy Markdown
Owner

@samuelwheeler this is excellent, my only concern would be the changes to things outside of the torchtitan/experiments/ezpz/ directory

I think it should be possible to integrate these changes into there directly; claude had this to say:

Restructure (everything inside experiments/ezpz/):

torchtitan/experiments/ezpz/moe/
├── experts.py          # NEW: EzpzGroupedExperts(GroupedExperts) + batched_mm_padded
├── __init__.py         # add make_experts_config_ezpz wrapper
├── config_registry.py  # 10B_2B_sdpa_batched_mm_padded uses the wrapper
└── ...

torchtitan/experiments/ezpz/tests/
└── test_moe_expert_backends.py   # MOVED from tests/unit_tests/

torchtitan/experiments/ezpz/moe/experts.py:

from dataclasses import dataclass
from typing import Literal

from torch.distributed.tensor import DTensor
from torchtitan.models.common.moe import GroupedExperts

ExpertComputeBackend = Literal["for_loop", "grouped_mm", "batched_mm_padded"]


def _run_experts_batched_mm_padded(w1, w2, w3, x, num_tokens_per_expert):
    # ... same body as in this PR
    ...


class EzpzGroupedExperts(GroupedExperts):
    @dataclass(kw_only=True, slots=True)
    class Config(GroupedExperts.Config):
        compute_backend: ExpertComputeBackend | None = None

    def __init__(self, config: Config):
        super().__init__(config)
        self.compute_backend = (
            config.compute_backend
            if config.compute_backend is not None
            else ("grouped_mm" if config.use_grouped_mm else "for_loop")
        )

    def _experts_forward(self, x, num_tokens_per_expert):
        if self.compute_backend != "batched_mm_padded":
            return super()._experts_forward(x, num_tokens_per_expert)

        if isinstance(self.w1, DTensor):
            w1, w2, w3 = self.w1.to_local(), self.w2.to_local(), self.w3.to_local()
        else:
            w1, w2, w3 = self.w1, self.w2, self.w3
        return _run_experts_batched_mm_padded(w1, w2, w3, x, num_tokens_per_expert)

Why this works:

  • Zero changes to torchtitan/models/common/
  • Every other MoE model (deepseek_v3, qwen3, llama4, …) keeps using upstream's GroupedExperts unchanged
  • EzpzGroupedExperts inherits Config + __init__ + forward; only overrides _experts_forward to dispatch to the new backend, falling back to super() for grouped_mm / for_loop
  • No upstream-merge conflicts on moe.py going forward
  • The numel() == 0 guard added to _run_experts_for_loop is unrelated to the new backend — if the empty-expert case is real, it's worth its own small PR upstream so every model benefits

@saforem2

Copy link
Copy Markdown
Owner

Heads up — upstream just landed pytorch/torchtitan#3308 (b301dfa0, "Remove MoE expert for-loop fallback") which removes use_grouped_mm from GroupedExperts.Config entirely and inlines _grouped_mm as the only path. So this PR will no longer cleanly merge after the next upstream sync — both the use_grouped_mm attribute and the _run_experts_for_loop reference are gone from models/common/moe.py.

This actually makes the subclass approach more attractive, not less:

  1. The use_grouped_mm legacy fallback in this PR's __init__ is now dead code upstreamEzpzGroupedExperts.Config would just have compute_backend: ExpertComputeBackend = "grouped_mm" (no None / no legacy bridge needed).

  2. We actually need a subclass anyway for XPU. Our experiments/ezpz/moe/model.py:226-235 was using use_grouped_mm = False as the XPU (pre-SM90) fallback escape hatch. Upstream's claim is that torch._grouped_mm has its own CUDA fallback — but XPU isn't CUDA, and we have no validation either way. So an EzpzGroupedExperts with a for-loop backend opt-in solves both problems in one place: your new batched_mm_padded lives there, and we get our XPU for-loop fallback back.

  3. The _run_experts_for_loop body that was just deleted from upstream can be copied verbatim into experiments/ezpz/moe/experts.py — it'll be ezpz-owned forever, won't drift with upstream changes, and gives us a known-working path on XPU.

Updated sketch:

# torchtitan/experiments/ezpz/moe/experts.py
ExpertComputeBackend = Literal["for_loop", "grouped_mm", "batched_mm_padded"]

def _run_experts_for_loop(w1, w2, w3, x, num_tokens_per_expert):
    # copy from the just-deleted upstream version
    ...

def _run_experts_batched_mm_padded(w1, w2, w3, x, num_tokens_per_expert):
    # body from this PR
    ...

class EzpzGroupedExperts(GroupedExperts):
    @dataclass(kw_only=True, slots=True)
    class Config(GroupedExperts.Config):
        compute_backend: ExpertComputeBackend = "grouped_mm"

    def __init__(self, config: Config):
        super().__init__(config)
        self.compute_backend = config.compute_backend

    def _experts_forward(self, x, num_tokens_per_expert):
        if self.compute_backend == "grouped_mm":
            return super()._experts_forward(x, num_tokens_per_expert)
        # local DTensor unwrap + dispatch to for_loop / batched_mm_padded
        ...

I held off pulling this latest upstream batch on ezpz until we have a plan, since the XPU fallback removal would break our MoE runs on Aurora.

@saforem2

Copy link
Copy Markdown
Owner

Folded into #12. The new backend lives at torchtitan/experiments/ezpz/moe/experts.py as EzpzGroupedExperts(GroupedExperts) with a compute_backend selector (grouped_mm / for_loop / batched_mm_padded); the new flavor is registered as 10B_2B_sdpa_batched_mm_padded. Zero core edits — pytorch#3308 actually made this much cleaner since the legacy use_grouped_mm bridge is no longer needed. Tests moved to experiments/ezpz/tests/. Going to keep this open until #12 lands.

@saforem2

Copy link
Copy Markdown
Owner

FYI — #13 just landed an EzpzGroupedExperts(GroupedExperts) subclass in experiments/ezpz/moe/experts.py to handle pytorch#3308 (use_grouped_mm field deletion). It currently exposes compute_backend: Literal["for_loop", "grouped_mm"].

When you rebase onto the merged ezpz, the cleanest landing path for this PR is probably:

  1. Add "batched_mm_padded" as a third literal in EzpzGroupedExperts.ExpertComputeBackend
  2. Add a third branch in EzpzGroupedExperts._experts_forward that calls your _run_experts_batched_mm_padded helper (vendored into experts.py next to the existing _run_experts_for_loop)
  3. Move the new 10B_2B_sdpa_batched_mm_padded flavor to use make_ezpz_experts_config(..., compute_backend="batched_mm_padded")
  4. Move the unit tests from tests/unit_tests/ to torchtitan/experiments/ezpz/tests/

That keeps everything in experiments/ezpz/ and avoids the models/common/moe.py and config_utils.py edits this PR currently needs. Withdrawing the prior #12 which had attempted this; this can stay open as the canonical home.

nscottnichols pushed a commit to nscottnichols/torchtitan that referenced this pull request Jun 8, 2026
…ch#3308

Replays the pytorch#3308 ("Remove MoE expert for-loop
fallback") deletion onto experiments/ezpz/moe/. That upstream PR
removed the `use_grouped_mm` config field from `GroupedExperts.Config`
and inlined `torch._grouped_mm` as the only expert path. Upstream's
argument is that `_grouped_mm` already provides a CUDA fallback on
pre-SM90 hardware. **XPU has no `_grouped_mm` kernel at all**, so the
unconditional path breaks ezpz on Aurora / Sunspot.

- `experiments/ezpz/moe/experts.py` (new): `EzpzGroupedExperts` subclass
  with `compute_backend: Literal["for_loop", "grouped_mm"]`. Default
  defers to upstream; `"for_loop"` re-vendors the per-expert matmul
  loop that pytorch#3308 deleted, restoring the XPU / pre-SM90 path.
- `experiments/ezpz/moe/__init__.py`: `make_ezpz_experts_config(...)`
  wrapper that calls upstream's `make_experts_config(...)` and
  re-wraps as `EzpzGroupedExperts.Config`. `_build_moe_layers` threads
  a `compute_backend` kwarg (default `"grouped_mm"`).
- `experiments/ezpz/moe/model.py`: `update_from_config` previously
  mutated `experts.use_grouped_mm = False` on pre-SM90 devices; that
  field no longer exists. Replaced with
  `experts_cfg.compute_backend = "for_loop"` guarded by a
  `getattr(..., "compute_backend", "grouped_mm")` so the block is
  robust to future config-shape changes.

Default behavior is unchanged for SM90+ CUDA: `compute_backend` defaults
to `"grouped_mm"`, which calls `super()._experts_forward(...)` and
hits the upstream path. The for_loop branch is only taken on devices
without `_grouped_mm`.

Pending PRs saforem2#9 / saforem2#10 / saforem2#11 will need to rebase onto this and adapt to
the `EzpzGroupedExperts` subclass — see docs/upstream-sync.md for
notes on what each will need to do.
saforem2 added a commit to nscottnichols/torchtitan that referenced this pull request Jun 11, 2026
690-line file with zero non-self importers. ~560 lines are
byte-similar to torchtitan/components/metrics.py (DeviceMemoryMonitor,
BaseLogger, TensorBoardLogger, WandBLogger, LoggerContainer,
ensure_pp_loss_visible, _get_metrics_rank, MetricsProcessor — all
duplicated). Genuine new code is ~130 lines: TT_MOE_DEBUG_MEMORY
helpers + get_current_stats() + _MOE_FASTPATH_COUNTERS integration
in MetricsProcessor.log().

Verified zero importers:
    grep -rn "EzpzMetricsProcessor\|ezpz.metrics\|TT_MOE_DEBUG_MEMORY" torchtitan/

_MOE_FASTPATH_COUNTERS is only referenced from within
ezpz/moe/token_dispatcher.py via its own _record_moe_fastpath helper,
which keeps the counters in-process. With this metrics.py gone the
counters still increment; they just don't surface to W&B/TB unless
the trainer is updated to read them — same state as today, since
nothing wires this metrics.py in.

Re-add as a thin MetricsProcessor subclass overriding only log() /
should_log() when there's a consumer in train.py / trainer.py.
The TT_MOE_DEBUG_MEMORY knobs are model-agnostic and should
probably land upstream on MetricsProcessor.Config instead.

Addresses review finding saforem2#10 on saforem2#14.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants