Add padded batched matmul MoE expert backend#10
Conversation
Reviewer's GuideIntroduces a new padded batched matmul expert compute backend for MoE GroupedExperts, wires it into configuration and ezpz experiment presets, and adds unit tests to verify numerical and gradient equivalence with the existing for-loop backend while preserving existing defaults. Flow diagram for configuring batched_mm_padded MoE expert backendflowchart LR
A[m oe_10b_2b_sdpa_batched_mm_padded] --> B[_10b_2b_sdpa_batched_mm_padded]
B --> C[set layer_cfg.moe.experts.compute_backend = batched_mm_padded]
C --> D[make_experts_config<br>compute_backend = batched_mm_padded]
D --> E[GroupedExperts.Config<br>compute_backend = batched_mm_padded]
E --> F[GroupedExperts<br>compute_backend == batched_mm_padded]
File-Level Changes
Tips and commandsInteracting with Sourcery
Customizing Your ExperienceAccess your dashboard to:
Getting Help
|
There was a problem hiding this comment.
Hey - I've found 4 issues, and left some high level feedback:
- The interaction between
use_grouped_mmand the newcompute_backendis a bit confusing: the constructor silently derivescompute_backendfromuse_grouped_mm, but downstream configs (e.g._10b_2b_sdpa_batched_mm_padded) now set both; consider either deprecatinguse_grouped_mmin favor ofcompute_backendor asserting they don't conflict to avoid surprising behavior. - In
_run_experts_batched_mm_padded, the expert and token index computation is non-trivial; adding a small helper or inline comment about the expected layout ofx(e.g., that tokens are grouped by expert andcountsis in that order) would make future maintenance and refactors less error-prone.
Prompt for AI Agents
Please address the comments from this code review:
## Overall Comments
- The interaction between `use_grouped_mm` and the new `compute_backend` is a bit confusing: the constructor silently derives `compute_backend` from `use_grouped_mm`, but downstream configs (e.g. `_10b_2b_sdpa_batched_mm_padded`) now set both; consider either deprecating `use_grouped_mm` in favor of `compute_backend` or asserting they don't conflict to avoid surprising behavior.
- In `_run_experts_batched_mm_padded`, the expert and token index computation is non-trivial; adding a small helper or inline comment about the expected layout of `x` (e.g., that tokens are grouped by expert and `counts` is in that order) would make future maintenance and refactors less error-prone.
## Individual Comments
### Comment 1
<location path="torchtitan/experiments/ezpz/moe/__init__.py" line_range="947-948" />
<code_context>
+ for layer_cfg in cfg.layers:
+ if layer_cfg.moe is None:
+ continue
+ layer_cfg.moe.experts.compute_backend = "batched_mm_padded"
+ layer_cfg.moe.experts.use_grouped_mm = False
+ return cfg
+
</code_context>
<issue_to_address>
**suggestion:** Setting both `compute_backend` and `use_grouped_mm` is redundant now that backend drives the flag.
Since `use_grouped_mm` is already derived from `compute_backend` at initialization, explicitly setting it here is redundant and may drift from the initialization logic over time. Please rely on `compute_backend = "batched_mm_padded"` alone and remove the explicit `use_grouped_mm` assignment to keep this configuration consistent and minimal.
</issue_to_address>
### Comment 2
<location path="tests/unit_tests/test_moe_expert_backends.py" line_range="33-42" />
<code_context>
+ def test_batched_mm_padded_matches_for_loop_kernel(self) -> None:
</code_context>
<issue_to_address>
**suggestion (testing):** Add explicit tests for empty and all-zero `num_tokens_per_expert` to cover early-return branches in `_run_experts_batched_mm_padded`.
The new backend has two early-return branches (`counts.numel() == 0` and `max_tokens == 0`), but the current test with `counts = [0, 3, 5, 1, 0, 7]` never hits them. Please add small tests that cover:
1. No experts (empty `num_tokens_per_expert`), and
2. All experts with zero tokens (e.g., `counts = [0, 0, 0]`).
For each, call `_run_experts_for_loop` and `_run_experts_batched_mm_padded` with the same inputs and assert that outputs (and ideally gradients) match, so these early-return paths are covered and protected against regressions.
Suggested implementation:
```python
class MoEExpertBackendsTest(unittest.TestCase):
def test_batched_mm_padded_matches_for_loop_kernel(self) -> None:
torch.manual_seed(1234)
dim = 16
hidden_dim = 24
# Baseline case (existing coverage)
num_experts = 6
counts = torch.tensor([0, 3, 5, 1, 0, 7], dtype=torch.int64)
total_tokens = int(counts.sum().item())
w1_base = torch.randn(num_experts, hidden_dim, dim) * 0.02
w2_base = torch.randn(num_experts, dim, hidden_dim) * 0.02
w3_base = torch.randn(num_experts, hidden_dim, dim) * 0.02
w1_for_loop = w1_base.clone().detach().requires_grad_(True)
w2_for_loop = w2_base.clone().detach().requires_grad_(True)
w3_for_loop = w3_base.clone().detach().requires_grad_(True)
w1_batched = w1_base.clone().detach().requires_grad_(True)
w2_batched = w2_base.clone().detach().requires_grad_(True)
w3_batched = w3_base.clone().detach().requires_grad_(True)
x_for_loop = torch.randn(total_tokens, dim, requires_grad=True)
x_batched = x_for_loop.clone().detach().requires_grad_(True)
out_for_loop = _run_experts_for_loop(
x_for_loop, w1_for_loop, w2_for_loop, w3_for_loop, counts
)
out_batched = _run_experts_batched_mm_padded(
x_batched, w1_batched, w2_batched, w3_batched, counts
)
self.assertTrue(torch.allclose(out_for_loop, out_batched, atol=1e-6, rtol=1e-6))
out_for_loop.sum().backward()
out_batched.sum().backward()
self.assertTrue(torch.allclose(x_for_loop.grad, x_batched.grad, atol=1e-6, rtol=1e-6))
self.assertTrue(torch.allclose(w1_for_loop.grad, w1_batched.grad, atol=1e-6, rtol=1e-6))
self.assertTrue(torch.allclose(w2_for_loop.grad, w2_batched.grad, atol=1e-6, rtol=1e-6))
self.assertTrue(torch.allclose(w3_for_loop.grad, w3_batched.grad, atol=1e-6, rtol=1e-6))
# Early-return case 1: no experts (empty counts; counts.numel() == 0)
counts_empty = torch.empty(0, dtype=torch.int64)
num_experts_empty = 0
total_tokens_empty = int(counts_empty.sum().item())
w1_base_empty = torch.randn(num_experts_empty, hidden_dim, dim) * 0.02
w2_base_empty = torch.randn(num_experts_empty, dim, hidden_dim) * 0.02
w3_base_empty = torch.randn(num_experts_empty, hidden_dim, dim) * 0.02
w1_for_loop_empty = w1_base_empty.clone().detach().requires_grad_(True)
w2_for_loop_empty = w2_base_empty.clone().detach().requires_grad_(True)
w3_for_loop_empty = w3_base_empty.clone().detach().requires_grad_(True)
w1_batched_empty = w1_base_empty.clone().detach().requires_grad_(True)
w2_batched_empty = w2_base_empty.clone().detach().requires_grad_(True)
w3_batched_empty = w3_base_empty.clone().detach().requires_grad_(True)
x_for_loop_empty = torch.randn(total_tokens_empty, dim, requires_grad=True)
x_batched_empty = x_for_loop_empty.clone().detach().requires_grad_(True)
out_for_loop_empty = _run_experts_for_loop(
x_for_loop_empty,
w1_for_loop_empty,
w2_for_loop_empty,
w3_for_loop_empty,
counts_empty,
)
out_batched_empty = _run_experts_batched_mm_padded(
x_batched_empty,
w1_batched_empty,
w2_batched_empty,
w3_batched_empty,
counts_empty,
)
self.assertTrue(
torch.allclose(out_for_loop_empty, out_batched_empty, atol=1e-6, rtol=1e-6)
)
out_for_loop_empty.sum().backward()
out_batched_empty.sum().backward()
self.assertTrue(
torch.allclose(
x_for_loop_empty.grad, x_batched_empty.grad, atol=1e-6, rtol=1e-6
)
)
self.assertTrue(
torch.allclose(
w1_for_loop_empty.grad, w1_batched_empty.grad, atol=1e-6, rtol=1e-6
)
)
self.assertTrue(
torch.allclose(
w2_for_loop_empty.grad, w2_batched_empty.grad, atol=1e-6, rtol=1e-6
)
)
self.assertTrue(
torch.allclose(
w3_for_loop_empty.grad, w3_batched_empty.grad, atol=1e-6, rtol=1e-6
)
)
# Early-return case 2: all experts with zero tokens (max_tokens == 0)
counts_all_zero = torch.tensor([0, 0, 0], dtype=torch.int64)
num_experts_all_zero = counts_all_zero.numel()
total_tokens_all_zero = int(counts_all_zero.sum().item())
w1_base_all_zero = torch.randn(num_experts_all_zero, hidden_dim, dim) * 0.02
w2_base_all_zero = torch.randn(num_experts_all_zero, dim, hidden_dim) * 0.02
w3_base_all_zero = torch.randn(num_experts_all_zero, hidden_dim, dim) * 0.02
w1_for_loop_all_zero = w1_base_all_zero.clone().detach().requires_grad_(True)
w2_for_loop_all_zero = w2_base_all_zero.clone().detach().requires_grad_(True)
w3_for_loop_all_zero = w3_base_all_zero.clone().detach().requires_grad_(True)
w1_batched_all_zero = w1_base_all_zero.clone().detach().requires_grad_(True)
w2_batched_all_zero = w2_base_all_zero.clone().detach().requires_grad_(True)
w3_batched_all_zero = w3_base_all_zero.clone().detach().requires_grad_(True)
x_for_loop_all_zero = torch.randn(
total_tokens_all_zero, dim, requires_grad=True
)
x_batched_all_zero = x_for_loop_all_zero.clone().detach().requires_grad_(True)
out_for_loop_all_zero = _run_experts_for_loop(
x_for_loop_all_zero,
w1_for_loop_all_zero,
w2_for_loop_all_zero,
w3_for_loop_all_zero,
counts_all_zero,
)
out_batched_all_zero = _run_experts_batched_mm_padded(
x_batched_all_zero,
w1_batched_all_zero,
w2_batched_all_zero,
w3_batched_all_zero,
counts_all_zero,
)
self.assertTrue(
torch.allclose(
out_for_loop_all_zero, out_batched_all_zero, atol=1e-6, rtol=1e-6
)
)
out_for_loop_all_zero.sum().backward()
out_batched_all_zero.sum().backward()
self.assertTrue(
torch.allclose(
x_for_loop_all_zero.grad,
x_batched_all_zero.grad,
atol=1e-6,
rtol=1e-6,
)
)
self.assertTrue(
torch.allclose(
w1_for_loop_all_zero.grad,
w1_batched_all_zero.grad,
atol=1e-6,
rtol=1e-6,
)
)
self.assertTrue(
torch.allclose(
w2_for_loop_all_zero.grad,
w2_batched_all_zero.grad,
atol=1e-6,
rtol=1e-6,
)
)
self.assertTrue(
torch.allclose(
w3_for_loop_all_zero.grad,
w3_batched_all_zero.grad,
atol=1e-6,
rtol=1e-6,
)
)
```
This edit assumes that `_run_experts_for_loop` and `_run_experts_batched_mm_padded` both accept `(x, w1, w2, w3, num_tokens_per_expert)` in that order, and that they correctly handle the cases where `num_tokens_per_expert` is empty or sums to zero. If their signatures differ, adjust the call sites accordingly.
If the original implementation of `test_batched_mm_padded_matches_for_loop_kernel` contained more logic than is visible in the provided snippet, you should merge that existing logic into the new test body rather than replacing it wholesale. The key requirement is that the two new early-return scenarios (empty counts and all-zero counts) are exercised, and that both outputs and gradients from the for-loop and batched backends are compared.
</issue_to_address>
### Comment 3
<location path="tests/unit_tests/test_moe_expert_backends.py" line_range="32-41" />
<code_context>
+class MoEExpertBackendsTest(unittest.TestCase):
</code_context>
<issue_to_address>
**suggestion (testing):** Consider parameterizing tests over device and dtype to ensure the backend behaves consistently beyond default CPU/float32.
Both tests currently exercise only CPU with `float32`, so any device- or dtype-specific issues in `_run_experts_batched_mm_padded` (e.g., CUDA vs CPU, `float16`/`bfloat16` vs `float32`) may be missed. If your CI supports GPUs, please consider parametrizing over `device` (CPU and CUDA when available) and/or adding a variant using a lower-precision dtype to validate numerical and autograd behavior against the reference loop.
Suggested implementation:
```python
dim: int,
hidden_dim: int,
*,
device: torch.device | str = "cpu",
dtype: torch.dtype = torch.float32,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
w1 = torch.randn(num_experts, hidden_dim, dim, device=device, dtype=dtype) * 0.02
w2 = torch.randn(num_experts, dim, hidden_dim, device=device, dtype=dtype) * 0.02
w3 = torch.randn(num_experts, hidden_dim, dim, device=device, dtype=dtype) * 0.02
return w1, w2, w3
```
```python
class MoEExpertBackendsTest(unittest.TestCase):
def test_batched_mm_padded_matches_for_loop_kernel(self) -> None:
devices: list[str] = ["cpu"]
if torch.cuda.is_available():
devices.append("cuda")
# Always include float32; add a lower-precision dtype to exercise
# numerical and backend behavior beyond the default.
dtypes: list[torch.dtype] = [torch.float32]
if hasattr(torch, "float16"):
dtypes.append(torch.float16)
for device in devices:
for dtype in dtypes:
with self.subTest(device=device, dtype=dtype):
torch.manual_seed(1234)
num_experts = 6
dim = 16
hidden_dim = 24
counts = torch.tensor(
[0, 3, 5, 1, 0, 7],
dtype=torch.int64,
device=device,
)
total_tokens = int(counts.sum().item())
w1, w2, w3 = _init_weights(
num_experts,
dim,
hidden_dim,
device=device,
dtype=dtype,
)
x = torch.randn(total_tokens, dim, device=device, dtype=dtype)
```
If the remainder of `test_batched_mm_padded_matches_for_loop_kernel` (not shown in the snippet) creates additional tensors (e.g., reference outputs, indices, or masks), they should also be constructed with `device=device` and `dtype=dtype` (or at least `device=device`, and `dtype` chosen consistently with the values being compared) to avoid device/dtype mismatches and to ensure the parameterization fully exercises the backend on the selected configuration.
</issue_to_address>
### Comment 4
<location path="torchtitan/models/common/moe.py" line_range="139" />
<code_context>
torch.empty(config.num_experts, config.hidden_dim, config.dim)
)
- self.use_grouped_mm = config.use_grouped_mm
+ self.compute_backend: ExpertComputeBackend = (
+ config.compute_backend
+ if config.compute_backend is not None
</code_context>
<issue_to_address>
**issue (complexity):** Consider simplifying the expert compute configuration and dispatch logic and refactoring the padded backend to isolate indexing math for better readability and maintainability.
You can simplify the configuration and dispatch layer and slightly decompose the padded backend without changing behavior.
### 1. Remove redundant `use_grouped_mm` instance state
Right now `self.compute_backend` is the real source of truth, and `self.use_grouped_mm` is derived from it but still kept as state. You can normalize legacy configs once and then drop the extra attribute:
```python
@dataclass(kw_only=True, slots=True)
class Config(Module.Config):
dim: int
hidden_dim: int
num_experts: int
# keep for backward-compat config, but treat as deprecated
use_grouped_mm: bool = True
compute_backend: ExpertComputeBackend | None = None
token_dispatcher: LocalTokenDispatcher.Config
def __init__(self, config: Config):
super().__init__()
...
if config.compute_backend is not None:
backend: ExpertComputeBackend = config.compute_backend
else:
backend = "grouped_mm" if config.use_grouped_mm else "for_loop"
self.compute_backend: ExpertComputeBackend = backend
# drop: self.use_grouped_mm = ...
self.token_dispatcher = config.token_dispatcher.build()
```
This removes redundant state and ensures all downstream logic only needs to look at `self.compute_backend`.
### 2. Centralize backend dispatch
The current sequence of independent `if` statements plus a final `ValueError` is a bit brittle. A small dispatch map keeps the logic flat and makes adding/removing backends less error-prone:
```python
def _experts_forward(
self,
x: torch.Tensor,
num_tokens_per_expert: torch.Tensor,
) -> torch.Tensor:
...
backends: dict[ExpertComputeBackend, Callable[..., torch.Tensor]] = {
"grouped_mm": _run_experts_grouped_mm,
"batched_mm_padded": _run_experts_batched_mm_padded,
"for_loop": _run_experts_for_loop,
}
try:
fn = backends[self.compute_backend]
except KeyError:
raise ValueError(f"Unknown expert compute backend: {self.compute_backend}")
return fn(w1, w2, w3, x, num_tokens_per_expert)
```
This keeps behavior identical while making the branching logic clearer.
### 3. Extract the index layout logic in `_run_experts_batched_mm_padded`
The index math is the hardest part to reason about. Pulling it into a small helper gives `_run_experts_batched_mm_padded` a more “algorithmic” shape:
```python
def _compute_expert_layout(
counts: torch.Tensor, total_tokens: int, device: torch.device
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
# counts: [num_experts]
offsets = counts.cumsum(0) - counts
expert_indices = torch.repeat_interleave(
torch.arange(counts.numel(), device=device, dtype=torch.int64),
counts,
)
token_indices_within_expert = torch.arange(
total_tokens, device=device, dtype=torch.int64
) - torch.repeat_interleave(offsets, counts)
return offsets, expert_indices, token_indices_within_expert
```
Then `_run_experts_batched_mm_padded` becomes:
```python
def _run_experts_batched_mm_padded(
w1: torch.Tensor,
w2: torch.Tensor,
w3: torch.Tensor,
x: torch.Tensor,
num_tokens_per_expert: torch.Tensor,
) -> torch.Tensor:
counts = num_tokens_per_expert.to(device=x.device, dtype=torch.int64)
if counts.numel() == 0 or counts.max().item() == 0:
return x.new_empty((0, w2.shape[1]))
total_tokens = x.shape[0]
device = x.device
max_tokens = int(counts.max().item())
_, expert_indices, token_indices_within_expert = _compute_expert_layout(
counts, total_tokens, device
)
padded_x = x.new_zeros((counts.numel(), max_tokens, x.shape[-1]))
padded_x[expert_indices, token_indices_within_expert] = x
h = F.silu(torch.bmm(padded_x, w1.transpose(-2, -1)))
h = h * torch.bmm(padded_x, w3.transpose(-2, -1))
out_padded = torch.bmm(h, w2.transpose(-2, -1))
return out_padded[expert_indices, token_indices_within_expert]
```
This keeps all behavior and shapes intact but isolates the tricky layout logic and simplifies the main computation path.
</issue_to_address>Help me be more useful! Please click 👍 or 👎 on each comment and I'll use the feedback to improve your reviews.
| def test_batched_mm_padded_matches_for_loop_kernel(self) -> None: | ||
| torch.manual_seed(1234) | ||
| num_experts = 6 | ||
| dim = 16 | ||
| hidden_dim = 24 | ||
| counts = torch.tensor([0, 3, 5, 1, 0, 7], dtype=torch.int64) | ||
| total_tokens = int(counts.sum().item()) | ||
|
|
||
| w1, w2, w3 = _init_weights(num_experts, dim, hidden_dim) | ||
| x = torch.randn(total_tokens, dim) |
There was a problem hiding this comment.
suggestion (testing): Add explicit tests for empty and all-zero num_tokens_per_expert to cover early-return branches in _run_experts_batched_mm_padded.
The new backend has two early-return branches (counts.numel() == 0 and max_tokens == 0), but the current test with counts = [0, 3, 5, 1, 0, 7] never hits them. Please add small tests that cover:
- No experts (empty
num_tokens_per_expert), and - All experts with zero tokens (e.g.,
counts = [0, 0, 0]).
For each, call _run_experts_for_loop and _run_experts_batched_mm_padded with the same inputs and assert that outputs (and ideally gradients) match, so these early-return paths are covered and protected against regressions.
Suggested implementation:
class MoEExpertBackendsTest(unittest.TestCase):
def test_batched_mm_padded_matches_for_loop_kernel(self) -> None:
torch.manual_seed(1234)
dim = 16
hidden_dim = 24
# Baseline case (existing coverage)
num_experts = 6
counts = torch.tensor([0, 3, 5, 1, 0, 7], dtype=torch.int64)
total_tokens = int(counts.sum().item())
w1_base = torch.randn(num_experts, hidden_dim, dim) * 0.02
w2_base = torch.randn(num_experts, dim, hidden_dim) * 0.02
w3_base = torch.randn(num_experts, hidden_dim, dim) * 0.02
w1_for_loop = w1_base.clone().detach().requires_grad_(True)
w2_for_loop = w2_base.clone().detach().requires_grad_(True)
w3_for_loop = w3_base.clone().detach().requires_grad_(True)
w1_batched = w1_base.clone().detach().requires_grad_(True)
w2_batched = w2_base.clone().detach().requires_grad_(True)
w3_batched = w3_base.clone().detach().requires_grad_(True)
x_for_loop = torch.randn(total_tokens, dim, requires_grad=True)
x_batched = x_for_loop.clone().detach().requires_grad_(True)
out_for_loop = _run_experts_for_loop(
x_for_loop, w1_for_loop, w2_for_loop, w3_for_loop, counts
)
out_batched = _run_experts_batched_mm_padded(
x_batched, w1_batched, w2_batched, w3_batched, counts
)
self.assertTrue(torch.allclose(out_for_loop, out_batched, atol=1e-6, rtol=1e-6))
out_for_loop.sum().backward()
out_batched.sum().backward()
self.assertTrue(torch.allclose(x_for_loop.grad, x_batched.grad, atol=1e-6, rtol=1e-6))
self.assertTrue(torch.allclose(w1_for_loop.grad, w1_batched.grad, atol=1e-6, rtol=1e-6))
self.assertTrue(torch.allclose(w2_for_loop.grad, w2_batched.grad, atol=1e-6, rtol=1e-6))
self.assertTrue(torch.allclose(w3_for_loop.grad, w3_batched.grad, atol=1e-6, rtol=1e-6))
# Early-return case 1: no experts (empty counts; counts.numel() == 0)
counts_empty = torch.empty(0, dtype=torch.int64)
num_experts_empty = 0
total_tokens_empty = int(counts_empty.sum().item())
w1_base_empty = torch.randn(num_experts_empty, hidden_dim, dim) * 0.02
w2_base_empty = torch.randn(num_experts_empty, dim, hidden_dim) * 0.02
w3_base_empty = torch.randn(num_experts_empty, hidden_dim, dim) * 0.02
w1_for_loop_empty = w1_base_empty.clone().detach().requires_grad_(True)
w2_for_loop_empty = w2_base_empty.clone().detach().requires_grad_(True)
w3_for_loop_empty = w3_base_empty.clone().detach().requires_grad_(True)
w1_batched_empty = w1_base_empty.clone().detach().requires_grad_(True)
w2_batched_empty = w2_base_empty.clone().detach().requires_grad_(True)
w3_batched_empty = w3_base_empty.clone().detach().requires_grad_(True)
x_for_loop_empty = torch.randn(total_tokens_empty, dim, requires_grad=True)
x_batched_empty = x_for_loop_empty.clone().detach().requires_grad_(True)
out_for_loop_empty = _run_experts_for_loop(
x_for_loop_empty,
w1_for_loop_empty,
w2_for_loop_empty,
w3_for_loop_empty,
counts_empty,
)
out_batched_empty = _run_experts_batched_mm_padded(
x_batched_empty,
w1_batched_empty,
w2_batched_empty,
w3_batched_empty,
counts_empty,
)
self.assertTrue(
torch.allclose(out_for_loop_empty, out_batched_empty, atol=1e-6, rtol=1e-6)
)
out_for_loop_empty.sum().backward()
out_batched_empty.sum().backward()
self.assertTrue(
torch.allclose(
x_for_loop_empty.grad, x_batched_empty.grad, atol=1e-6, rtol=1e-6
)
)
self.assertTrue(
torch.allclose(
w1_for_loop_empty.grad, w1_batched_empty.grad, atol=1e-6, rtol=1e-6
)
)
self.assertTrue(
torch.allclose(
w2_for_loop_empty.grad, w2_batched_empty.grad, atol=1e-6, rtol=1e-6
)
)
self.assertTrue(
torch.allclose(
w3_for_loop_empty.grad, w3_batched_empty.grad, atol=1e-6, rtol=1e-6
)
)
# Early-return case 2: all experts with zero tokens (max_tokens == 0)
counts_all_zero = torch.tensor([0, 0, 0], dtype=torch.int64)
num_experts_all_zero = counts_all_zero.numel()
total_tokens_all_zero = int(counts_all_zero.sum().item())
w1_base_all_zero = torch.randn(num_experts_all_zero, hidden_dim, dim) * 0.02
w2_base_all_zero = torch.randn(num_experts_all_zero, dim, hidden_dim) * 0.02
w3_base_all_zero = torch.randn(num_experts_all_zero, hidden_dim, dim) * 0.02
w1_for_loop_all_zero = w1_base_all_zero.clone().detach().requires_grad_(True)
w2_for_loop_all_zero = w2_base_all_zero.clone().detach().requires_grad_(True)
w3_for_loop_all_zero = w3_base_all_zero.clone().detach().requires_grad_(True)
w1_batched_all_zero = w1_base_all_zero.clone().detach().requires_grad_(True)
w2_batched_all_zero = w2_base_all_zero.clone().detach().requires_grad_(True)
w3_batched_all_zero = w3_base_all_zero.clone().detach().requires_grad_(True)
x_for_loop_all_zero = torch.randn(
total_tokens_all_zero, dim, requires_grad=True
)
x_batched_all_zero = x_for_loop_all_zero.clone().detach().requires_grad_(True)
out_for_loop_all_zero = _run_experts_for_loop(
x_for_loop_all_zero,
w1_for_loop_all_zero,
w2_for_loop_all_zero,
w3_for_loop_all_zero,
counts_all_zero,
)
out_batched_all_zero = _run_experts_batched_mm_padded(
x_batched_all_zero,
w1_batched_all_zero,
w2_batched_all_zero,
w3_batched_all_zero,
counts_all_zero,
)
self.assertTrue(
torch.allclose(
out_for_loop_all_zero, out_batched_all_zero, atol=1e-6, rtol=1e-6
)
)
out_for_loop_all_zero.sum().backward()
out_batched_all_zero.sum().backward()
self.assertTrue(
torch.allclose(
x_for_loop_all_zero.grad,
x_batched_all_zero.grad,
atol=1e-6,
rtol=1e-6,
)
)
self.assertTrue(
torch.allclose(
w1_for_loop_all_zero.grad,
w1_batched_all_zero.grad,
atol=1e-6,
rtol=1e-6,
)
)
self.assertTrue(
torch.allclose(
w2_for_loop_all_zero.grad,
w2_batched_all_zero.grad,
atol=1e-6,
rtol=1e-6,
)
)
self.assertTrue(
torch.allclose(
w3_for_loop_all_zero.grad,
w3_batched_all_zero.grad,
atol=1e-6,
rtol=1e-6,
)
)This edit assumes that _run_experts_for_loop and _run_experts_batched_mm_padded both accept (x, w1, w2, w3, num_tokens_per_expert) in that order, and that they correctly handle the cases where num_tokens_per_expert is empty or sums to zero. If their signatures differ, adjust the call sites accordingly.
If the original implementation of test_batched_mm_padded_matches_for_loop_kernel contained more logic than is visible in the provided snippet, you should merge that existing logic into the new test body rather than replacing it wholesale. The key requirement is that the two new early-return scenarios (empty counts and all-zero counts) are exercised, and that both outputs and gradients from the for-loop and batched backends are compared.
| class MoEExpertBackendsTest(unittest.TestCase): | ||
| def test_batched_mm_padded_matches_for_loop_kernel(self) -> None: | ||
| torch.manual_seed(1234) | ||
| num_experts = 6 | ||
| dim = 16 | ||
| hidden_dim = 24 | ||
| counts = torch.tensor([0, 3, 5, 1, 0, 7], dtype=torch.int64) | ||
| total_tokens = int(counts.sum().item()) | ||
|
|
||
| w1, w2, w3 = _init_weights(num_experts, dim, hidden_dim) |
There was a problem hiding this comment.
suggestion (testing): Consider parameterizing tests over device and dtype to ensure the backend behaves consistently beyond default CPU/float32.
Both tests currently exercise only CPU with float32, so any device- or dtype-specific issues in _run_experts_batched_mm_padded (e.g., CUDA vs CPU, float16/bfloat16 vs float32) may be missed. If your CI supports GPUs, please consider parametrizing over device (CPU and CUDA when available) and/or adding a variant using a lower-precision dtype to validate numerical and autograd behavior against the reference loop.
Suggested implementation:
dim: int,
hidden_dim: int,
*,
device: torch.device | str = "cpu",
dtype: torch.dtype = torch.float32,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
w1 = torch.randn(num_experts, hidden_dim, dim, device=device, dtype=dtype) * 0.02
w2 = torch.randn(num_experts, dim, hidden_dim, device=device, dtype=dtype) * 0.02
w3 = torch.randn(num_experts, hidden_dim, dim, device=device, dtype=dtype) * 0.02
return w1, w2, w3class MoEExpertBackendsTest(unittest.TestCase):
def test_batched_mm_padded_matches_for_loop_kernel(self) -> None:
devices: list[str] = ["cpu"]
if torch.cuda.is_available():
devices.append("cuda")
# Always include float32; add a lower-precision dtype to exercise
# numerical and backend behavior beyond the default.
dtypes: list[torch.dtype] = [torch.float32]
if hasattr(torch, "float16"):
dtypes.append(torch.float16)
for device in devices:
for dtype in dtypes:
with self.subTest(device=device, dtype=dtype):
torch.manual_seed(1234)
num_experts = 6
dim = 16
hidden_dim = 24
counts = torch.tensor(
[0, 3, 5, 1, 0, 7],
dtype=torch.int64,
device=device,
)
total_tokens = int(counts.sum().item())
w1, w2, w3 = _init_weights(
num_experts,
dim,
hidden_dim,
device=device,
dtype=dtype,
)
x = torch.randn(total_tokens, dim, device=device, dtype=dtype)If the remainder of test_batched_mm_padded_matches_for_loop_kernel (not shown in the snippet) creates additional tensors (e.g., reference outputs, indices, or masks), they should also be constructed with device=device and dtype=dtype (or at least device=device, and dtype chosen consistently with the values being compared) to avoid device/dtype mismatches and to ensure the parameterization fully exercises the backend on the selected configuration.
|
@samuelwheeler this is excellent, my only concern would be the changes to things outside of the I think it should be possible to integrate these changes into there directly; claude had this to say: Restructure (everything inside
from dataclasses import dataclass
from typing import Literal
from torch.distributed.tensor import DTensor
from torchtitan.models.common.moe import GroupedExperts
ExpertComputeBackend = Literal["for_loop", "grouped_mm", "batched_mm_padded"]
def _run_experts_batched_mm_padded(w1, w2, w3, x, num_tokens_per_expert):
# ... same body as in this PR
...
class EzpzGroupedExperts(GroupedExperts):
@dataclass(kw_only=True, slots=True)
class Config(GroupedExperts.Config):
compute_backend: ExpertComputeBackend | None = None
def __init__(self, config: Config):
super().__init__(config)
self.compute_backend = (
config.compute_backend
if config.compute_backend is not None
else ("grouped_mm" if config.use_grouped_mm else "for_loop")
)
def _experts_forward(self, x, num_tokens_per_expert):
if self.compute_backend != "batched_mm_padded":
return super()._experts_forward(x, num_tokens_per_expert)
if isinstance(self.w1, DTensor):
w1, w2, w3 = self.w1.to_local(), self.w2.to_local(), self.w3.to_local()
else:
w1, w2, w3 = self.w1, self.w2, self.w3
return _run_experts_batched_mm_padded(w1, w2, w3, x, num_tokens_per_expert)Why this works:
|
|
Heads up — upstream just landed pytorch/torchtitan#3308 ( This actually makes the subclass approach more attractive, not less:
Updated sketch: # torchtitan/experiments/ezpz/moe/experts.py
ExpertComputeBackend = Literal["for_loop", "grouped_mm", "batched_mm_padded"]
def _run_experts_for_loop(w1, w2, w3, x, num_tokens_per_expert):
# copy from the just-deleted upstream version
...
def _run_experts_batched_mm_padded(w1, w2, w3, x, num_tokens_per_expert):
# body from this PR
...
class EzpzGroupedExperts(GroupedExperts):
@dataclass(kw_only=True, slots=True)
class Config(GroupedExperts.Config):
compute_backend: ExpertComputeBackend = "grouped_mm"
def __init__(self, config: Config):
super().__init__(config)
self.compute_backend = config.compute_backend
def _experts_forward(self, x, num_tokens_per_expert):
if self.compute_backend == "grouped_mm":
return super()._experts_forward(x, num_tokens_per_expert)
# local DTensor unwrap + dispatch to for_loop / batched_mm_padded
...I held off pulling this latest upstream batch on |
|
Folded into #12. The new backend lives at |
|
FYI — #13 just landed an When you rebase onto the merged
That keeps everything in |
…ch#3308 Replays the pytorch#3308 ("Remove MoE expert for-loop fallback") deletion onto experiments/ezpz/moe/. That upstream PR removed the `use_grouped_mm` config field from `GroupedExperts.Config` and inlined `torch._grouped_mm` as the only expert path. Upstream's argument is that `_grouped_mm` already provides a CUDA fallback on pre-SM90 hardware. **XPU has no `_grouped_mm` kernel at all**, so the unconditional path breaks ezpz on Aurora / Sunspot. - `experiments/ezpz/moe/experts.py` (new): `EzpzGroupedExperts` subclass with `compute_backend: Literal["for_loop", "grouped_mm"]`. Default defers to upstream; `"for_loop"` re-vendors the per-expert matmul loop that pytorch#3308 deleted, restoring the XPU / pre-SM90 path. - `experiments/ezpz/moe/__init__.py`: `make_ezpz_experts_config(...)` wrapper that calls upstream's `make_experts_config(...)` and re-wraps as `EzpzGroupedExperts.Config`. `_build_moe_layers` threads a `compute_backend` kwarg (default `"grouped_mm"`). - `experiments/ezpz/moe/model.py`: `update_from_config` previously mutated `experts.use_grouped_mm = False` on pre-SM90 devices; that field no longer exists. Replaced with `experts_cfg.compute_backend = "for_loop"` guarded by a `getattr(..., "compute_backend", "grouped_mm")` so the block is robust to future config-shape changes. Default behavior is unchanged for SM90+ CUDA: `compute_backend` defaults to `"grouped_mm"`, which calls `super()._experts_forward(...)` and hits the upstream path. The for_loop branch is only taken on devices without `_grouped_mm`. Pending PRs saforem2#9 / saforem2#10 / saforem2#11 will need to rebase onto this and adapt to the `EzpzGroupedExperts` subclass — see docs/upstream-sync.md for notes on what each will need to do.
690-line file with zero non-self importers. ~560 lines are
byte-similar to torchtitan/components/metrics.py (DeviceMemoryMonitor,
BaseLogger, TensorBoardLogger, WandBLogger, LoggerContainer,
ensure_pp_loss_visible, _get_metrics_rank, MetricsProcessor — all
duplicated). Genuine new code is ~130 lines: TT_MOE_DEBUG_MEMORY
helpers + get_current_stats() + _MOE_FASTPATH_COUNTERS integration
in MetricsProcessor.log().
Verified zero importers:
grep -rn "EzpzMetricsProcessor\|ezpz.metrics\|TT_MOE_DEBUG_MEMORY" torchtitan/
_MOE_FASTPATH_COUNTERS is only referenced from within
ezpz/moe/token_dispatcher.py via its own _record_moe_fastpath helper,
which keeps the counters in-process. With this metrics.py gone the
counters still increment; they just don't surface to W&B/TB unless
the trainer is updated to read them — same state as today, since
nothing wires this metrics.py in.
Re-add as a thin MetricsProcessor subclass overriding only log() /
should_log() when there's a consumer in train.py / trainer.py.
The TT_MOE_DEBUG_MEMORY knobs are model-agnostic and should
probably land upstream on MetricsProcessor.Config instead.
Addresses review finding saforem2#10 on saforem2#14.
Summary
Adds a
batched_mm_paddedexpert compute backend for MoEGroupedExperts.The backend pads routed expert inputs to
[num_experts, max_tokens_per_expert, dim], usestorch.bmmfor the SwiGLU expert projections, then gathers back to routed token order.Default behavior is unchanged: if
compute_backendis unset, the existinguse_grouped_mmflag still selects grouped-mm vs for-loop. The new path is opt-in viacompute_backend="batched_mm_padded".Also threads the backend selector through
make_experts_config, adds an ezpz10B_2B_sdpa_batched_mm_paddedflavor, and adds unit coverage comparing outputs and gradients against the for-loop implementation.Validation
git diff --cached --check/lus/flare/projects/AuroraGPT/sww/new_tt_aurora/torchtitan/.venv/bin/python -m py_compile torchtitan/models/common/moe.py torchtitan/models/common/config_utils.py torchtitan/experiments/ezpz/moe/__init__.py torchtitan/experiments/ezpz/moe/config_registry.py tests/unit_tests/test_moe_expert_backends.pyPYTHONPATH=. timeout 300s /lus/flare/projects/AuroraGPT/sww/new_tt_aurora/torchtitan/.venv/bin/python -m unittest -v tests.unit_tests.test_moe_expert_backendsSummary by Sourcery
Add a new padded batched matmul expert compute backend for MoE GroupedExperts and wire it through configs and tests.
New Features:
Enhancements:
Tests: