[Common][PyTorch] Fix int32 overflow and -1 sentinel handling in moe_permute#2907
[Common][PyTorch] Fix int32 overflow and -1 sentinel handling in moe_permute#2907jing-4369 wants to merge 5 commits intoNVIDIA:mainfrom
Conversation
Greptile SummaryThis PR widens
Confidence Score: 3/5The uint32_t sort change correctly repositions -1 sentinel entries, but the kernel still writes to a negative row_id_map index when it encounters them, corrupting adjacent device memory for any model with sentinel routing and topK > 1. The PR's test suite only exercises topK=1 for the -1 sentinel repro. For topK > 1 (which DeepSeek-V3 uses), (-1) % topK evaluates to -1 in C++, and the write row_id_map[-1 * num_rows + 0] lands num_rows elements before the buffer — silent device-memory corruption that won't surface until the next kernel reads that region. A single if (source_row < 0) return guard closes this, but without it the -1 sentinel path is still broken for the primary target workload. transformer_engine/common/permutation/permutation.cu — specifically the drop branch of moe_permute_row_map after the uint32_t sort change Important Files Changed
Flowchart%%{init: {'theme': 'neutral'}}%%
flowchart TD
A[indices tensor
routing map with possible -1 sentinels] --> B[nvte_device_radix_sort_pairs
uint32_t sort: -1 goes to tail]
B --> C[sorted_row_id
valid IDs first, -1 sentinels last]
C --> D[moe_permute_row_map
idx in 0..num_rows x topK-1]
D --> E{idx >= num_out_tokens?}
E -- No --> F[Write idx to row_id_map OK]
E -- Yes --> G{source_row < 0?}
G -- No valid dropped token --> H[Write -1 to row_id_map OK]
G -- Yes -1 sentinel source_topK_id=-1 --> I[OOB write row_id_map at negative index]
F --> J[moe_permute_kernel / moe_unpermute_kernel]
H --> J
Reviews (7): Last reviewed commit: "Widen num_rows * topK products in moe_pe..." | Re-trigger Greptile |
| const int num_minus_ones = num_tokens * topK - num_out_tokens; | ||
| sorted_row_id_ptr = reinterpret_cast<char *>(sorted_row_id_ptr) + num_minus_ones * sizeof(int); |
There was a problem hiding this comment.
Negative
num_minus_ones becomes enormous size_t offset
num_minus_ones is computed as int. If a caller passes num_out_tokens > num_tokens * topK (which the function does not validate), num_minus_ones is negative. The pointer advance expression:
sorted_row_id_ptr = reinterpret_cast<char *>(sorted_row_id_ptr) + num_minus_ones * sizeof(int);involves int * size_t, which promotes num_minus_ones to size_t (unsigned). A value like -4 becomes SIZE_MAX - 3, advancing the pointer far out of the allocation and causing a silent OOB read. A simple clamp or assert before this line would prevent this:
TORCH_CHECK(num_out_tokens <= num_tokens * topK,
"num_out_tokens (", num_out_tokens, ") cannot exceed num_tokens*topK (",
num_tokens * topK, ")");…permute Two independent bugs in transformer_engine/common/permutation/permutation.cu and the PyTorch extension caller reproduce on main (264da2b) and v2.13: 1. int32 overflow in moe_unpermute_kernel and moe_permute_kernel. `source_token * num_cols` and `source_row * num_cols` are computed with int, so for long-sequence MoE workloads where num_out_tokens * num_cols reaches 2**31 (e.g. 2**18 tokens x 2**13 hidden), the pointer offset wraps and the kernel either reads garbage or raises `an illegal memory access was encountered`. Widening source_token, source_row and dest_row to int64_t inside the kernels keeps the index arithmetic in 64 bits without changing any public types. 2. Incorrect handling of -1 sentinels in the routing indices. Libraries such as DeepEP (and any expert-parallel mask that sets non-local (token, slot) pairs to -1) feed a routing_map that contains -1 entries. `cub::DeviceRadixSort::SortPairs` is signed ascending, so those sentinels land at the HEAD of sorted_row_id, not the tail. moe_permute_row_map currently writes -1 only for idx >= num_out_tokens and reads the sentinel prefix as if it were a valid sorted id, producing bogus row_id_map writes (for instance `source_row / topK == 0, source_row % topK == -1`). The caller now advances sorted_row_id_ptr past the num_minus_ones prefix and pre-fills row_id_map with -1 via torch::full, so the kernel only processes the valid suffix and never dereferences a sentinel. The launcher's grid switches from num_rows*topK blocks to num_out_tokens blocks to match the new valid range. No behaviour change on happy-path routing_map (no -1, no overflow). Reproducers: - 8-token, topK=2 routing_map with -1 masking: max |TE - ref| = 4.5e0 on bf16 with current main; 0.0 with this patch. - num_tokens=2**18+1, num_cols=2**13, topK=1: current main raises CUDA illegal memory access at permutation.cu:252; with this patch it succeeds. Signed-off-by: Jingyi Xi <flotherxi@gmail.com>
a909254 to
ce34774
Compare
for more information, see https://pre-commit.ci
| const int num_minus_ones = num_tokens * topK - num_out_tokens; | ||
| sorted_row_id_ptr = reinterpret_cast<char *>(sorted_row_id_ptr) + | ||
| static_cast<size_t>(num_minus_ones) * sizeof(int); |
There was a problem hiding this comment.
num_tokens * topK still computed as int * int
num_tokens and topK are both int, so num_tokens * topK on line 61 is evaluated in 32-bit arithmetic before the result feeds the int64_t subtraction. The same expression appears twice in the NVTE_CHECK on lines 59–60. If num_tokens * topK wraps to a negative int (possible when, e.g., num_tokens ≥ 2^31 / topK), the NVTE_CHECK would either spuriously reject a valid num_out_tokens, or the error-message value would be wrong. Casting to int64_t before the multiplication closes this gap:
| const int num_minus_ones = num_tokens * topK - num_out_tokens; | |
| sorted_row_id_ptr = reinterpret_cast<char *>(sorted_row_id_ptr) + | |
| static_cast<size_t>(num_minus_ones) * sizeof(int); | |
| const int64_t total_tokens = static_cast<int64_t>(num_tokens) * topK; | |
| NVTE_CHECK(num_out_tokens <= total_tokens, "num_out_tokens (", num_out_tokens, | |
| ") must not exceed num_tokens*topK (", total_tokens, ")"); | |
| const int num_minus_ones = static_cast<int>(total_tokens - num_out_tokens); | |
| sorted_row_id_ptr = reinterpret_cast<char *>(sorted_row_id_ptr) + | |
| static_cast<size_t>(num_minus_ones) * sizeof(int); |
There was a problem hiding this comment.
This (the greptile command) looks correct. Can you please help cast num_tpkens to int64 before multiplication and - num_out_tokens?
There was a problem hiding this comment.
This is not a huge deal because even with topK=128, youwould need > 16M tokens per rank for the int product to overflow. But better to be consistent, and also, this casting of 1 value on the CPU side probably would not slow dow much
ba39fba to
cd64aa5
Compare
Add an NVTE_CHECK that num_out_tokens <= num_tokens * topK and cast num_minus_ones to size_t before the pointer advance, so a negative num_minus_ones (from an invalid num_out_tokens) cannot silently wrap into a huge pointer offset. Signed-off-by: Jingyi Xi <flotherxi@gmail.com>
cd64aa5 to
b73a1f9
Compare
| const int num_minus_ones = num_tokens * topK - num_out_tokens; | ||
| sorted_row_id_ptr = reinterpret_cast<char *>(sorted_row_id_ptr) + | ||
| static_cast<size_t>(num_minus_ones) * sizeof(int); |
There was a problem hiding this comment.
This (the greptile command) looks correct. Can you please help cast num_tpkens to int64 before multiplication and - num_out_tokens?
| const int num_minus_ones = num_tokens * topK - num_out_tokens; | ||
| sorted_row_id_ptr = reinterpret_cast<char *>(sorted_row_id_ptr) + | ||
| static_cast<size_t>(num_minus_ones) * sizeof(int); |
There was a problem hiding this comment.
This is not a huge deal because even with topK=128, youwould need > 16M tokens per rank for the int product to overflow. But better to be consistent, and also, this casting of 1 value on the CPU side probably would not slow dow much
|
|
||
| int threads = 64; | ||
| int blocks = (num_rows * topK + threads - 1) / threads; | ||
| int blocks = (num_out_tokens + threads - 1) / threads; |
There was a problem hiding this comment.
this is correct here but has an implied prerequisite that host prefills the buffer with -1 and shift the ptr by num_minus_ones (what you did in the other file). Better make it more explicit with a comment so no regression will happen by someone accidentally changing this behavior and mess up the number of blocks here. Something like:
"// row_id_map MUST be pre-initialized to -1; sorted_row_id MUST point past the sentinel prefix"
| num_out_tokens = (num_out_tokens > 0) ? num_out_tokens : num_tokens * topK; | ||
| NVTE_CHECK(num_out_tokens <= num_tokens * topK, "num_out_tokens (", num_out_tokens, | ||
| ") must not exceed num_tokens*topK (", num_tokens * topK, ")"); | ||
| const int num_minus_ones = num_tokens * topK - num_out_tokens; |
There was a problem hiding this comment.
This is probably going to introduce a regression for the capacity-drop path. This shift assumes the dropped routes are -1 sentinels at the head of sorted_row_id (cub's signed radix sort), which is true for the EP-mask case this PR targets. But the pre-existing capacity-drop path encodes drops as a large positive expert id that sorts to the tail. For that case, the head is valid low-expert-id rows, and shifting past them drops the wrong tokens.(just fyi, capacity-dropping case means no -1 in indices, num_out_tokens < num_tokens * topK because some expert exceeded capacity))
See in this file tests/pytorch/test_permutation.py, in pytorch_permute_index_map, we have:
sorted_indices[:num_out_tokens] (keeps the head),
so I'd expect test_permutation_index_map[..., num_out_tokens=2039, ...] to fail. We can run the te_ci to confirm it.
There was a problem hiding this comment.
I think another solution to this without doing num_tokens * topk - num_out_tokens (or counting the number of -1 on host side) is to sort the keys as uint32_t instead of int32_t. So, -1 becomes UINT_MAX and sorts to the tail, unifying both capacity-dropping and dropless under the original idx >= num_out_tokens --> drop logic. That removes the need for the prefix shift you did, and the row_id_map pre-fill. This just needs expert_id to be <= UINT_MAX, which I do not think we are reaching there anytime soon
There was a problem hiding this comment.
Thanks for the careful review. Acknowledging the capacity-drop regression concern and the unsigned-sort suggestion below — both make sense. Waiting on the te_ci result you triggered before I push any code change, so we have a concrete signal on what needs to move.
There was a problem hiding this comment.
Here is the CI pipeline: https://gitlab-master.nvidia.com/dl/transformerengine/transformerengine/-/pipelines/50478896
It failed in the expected tests
|
/te_ci pytorch |
|
/te-ci pytorch |
|
/te-ci pytorch L0 |
The MoE permute path was correct for the existing capacity-drop convention (drops encoded as a large positive expert id, sorted to the tail by the signed cub::DeviceRadixSort), but it broke for callers that mark dropped (token, slot) pairs with -1 (expert-parallel rank masking, e.g. DeepEP). With signed sort the -1 sentinels land at the HEAD of sorted_row_id, while moe_permute_row_map's `idx >= num_out_tokens` branch assumes drops are at the tail. Reinterpret the keys as uint32_t inside nvte_device_radix_sort_pairs so -1 (= UINT_MAX) sorts to the tail, unifying the EP-mask case with the existing capacity-drop convention. The kernel and caller sides are unchanged - this is a one-place fix that makes both drop conventions land in the existing drop branch. Also widen the loop-carried indices in moe_unpermute_kernel and moe_permute_kernel to int64_t (`source_token`, `source_row`, `dest_row`) to keep `row * num_cols` in 64 bits. We hit this on DeepSeek-V3 long- context training (hidden = 7168, topK = 8): once `num_out_tokens * num_cols` reaches 2**31 the int product wraps and the kernel either silently corrupts rows or raises CUDA `illegal memory access`. Signed-off-by: Jingyi Xi <flotherxi@gmail.com>
Hi, I can trigger it now. However, for future reference, you could also trigger it by commenting "/te-ci " on the PR. is either pytorch or jax, and can be L0, L1, L2. Alternatively, you can test it out locally also,, by running the related tests in tests/pytorch or tests/jax, depending on the change, or using the scripts in qa/L0_pytorch_unittest/test.sh |
|
/te-ci pytorch |
There was a problem hiding this comment.
+1 to greptile's P1 on moe_permute_row_map / its launcher: now that the kernel is back to walking all num_rows * topK entries, the int * int products are the same overflow class this PR is fixing in the other two kernels. Could you widen those to int64_t for consistency?
There was a problem hiding this comment.
source_topK_id * num_rows =. int * int. same as comment on line 242
There was a problem hiding this comment.
same as comment on line 242
|
I can see that the CI finished and the failures are not related to your change: https://gitlab-master.nvidia.com/dl/transformerengine/transformerengine/-/pipelines/50584460 the only test failing was fused_qkv not permutation |
Per reviewer feedback in NVIDIA#2907, promote the int * int multiplications in moe_permute_row_map and its launcher to int64_t. These are not the overflow path this PR was originally fixing (DeepSeek-V3 long-context hits row * num_cols, where num_cols is the hidden dim ~ 7-8k), and num_rows * topK only crosses 2**31 at unrealistic per-rank token counts (>= 268M at topK=8). The change is purely defensive but keeps the index arithmetic in this kernel consistent with the int64_t source_token / source_row / dest_row widening already applied to moe_unpermute_kernel and moe_permute_kernel. Signed-off-by: Jingyi Xi <flotherxi@gmail.com>
| int source_row = sorted_row_id[idx]; | ||
| int source_token_id = source_row / topK; | ||
| int source_topK_id = source_row % topK; |
There was a problem hiding this comment.
After the uint32_t-sort change,
-1 sentinel entries land at the tail of sorted_row_id (positions ≥ num_out_tokens). The drop branch then computes source_token_id = (-1) / topK and source_topK_id = (-1) % topK. For topK > 1, C++ truncates toward zero, giving source_token_id = 0 and source_topK_id = -1, so the write becomes row_id_map[-1 * num_rows + 0] — num_rows words before the buffer start. For topK = 1 the write lands at row_id_map[-1]. Both cases silently corrupt adjacent device memory. A simple early-exit on source_row < 0 closes this gap without touching the caller.
| int source_row = sorted_row_id[idx]; | |
| int source_token_id = source_row / topK; | |
| int source_topK_id = source_row % topK; | |
| int source_row = sorted_row_id[idx]; | |
| if (source_row < 0) return; // skip -1 sentinel entries | |
| int source_token_id = source_row / topK; | |
| int source_topK_id = source_row % topK; |
|
/te-ci pytorch |
Fixes #2908 — full description, repros, and DeepSeek-V3 context there.
Changes
permutation.cu— widensource_token,source_row,dest_rowtoint64_tinsidemoe_unpermute_kernelandmoe_permute_kernelsorow * num_colsstays 64-bit. Simplifymoe_permute_row_mapto only process the valid[0, num_out_tokens)range; launcher grid becomesnum_out_tokensblocks.permutation.cpp— advancesorted_row_id_ptrpast thenum_minus_onessentinel prefix left bycub::DeviceRadixSort(signed ascending), and pre-fillrow_id_mapwith-1viatorch::fullso dropped slots are marked without the kernel ever dereferencing a sentinel.No public API / dtype changes.
+17 / -18lines across the two files.Test plan
routing_map(no-1, offsets within int32) — unchanged.-1-sentinel repro from [Bug] moe_permute CUDA kernel: int32 overflow and incorrect -1 sentinel handling #2908 →max |TE - ref| = 0.0on bf16 (was4.56e0).int32-boundary repro from [Bug] moe_permute CUDA kernel: int32 overflow and incorrect -1 sentinel handling #2908 → no longer raisesillegal memory access; matches reference.tests/pytorch/test_permutation.pyvia CI.