
[Common][PyTorch] Fix int32 overflow and -1 sentinel handling in moe_permute#2907

Open
jing-4369 wants to merge 5 commits into NVIDIA:main from jing-4369:fix/moe-permute-int-overflow-and-minus-one

Conversation

@jing-4369

@jing-4369 jing-4369 commented Apr 21, 2026

Fixes #2908 — full description, repros, and DeepSeek-V3 context there.

Changes

  • permutation.cu — widen source_token, source_row, dest_row to int64_t inside moe_unpermute_kernel and moe_permute_kernel so row * num_cols stays 64-bit. Simplify moe_permute_row_map to only process the valid [0, num_out_tokens) range; launcher grid becomes num_out_tokens blocks.
  • permutation.cpp — advance sorted_row_id_ptr past the num_minus_ones sentinel prefix left by cub::DeviceRadixSort (signed ascending), and pre-fill row_id_map with -1 via torch::full so dropped slots are marked without the kernel ever dereferencing a sentinel.

No public API / dtype changes. +17 / -18 lines across the two files.

Test plan

@greptile-apps
Contributor

greptile-apps Bot commented Apr 21, 2026

Greptile Summary

This PR widens source_token, source_row, and dest_row to int64_t in moe_unpermute_kernel and moe_permute_kernel to prevent row * num_cols from overflowing 32 bits, and switches nvte_device_radix_sort_pairs to sort keys as uint32_t so that -1 sentinel entries (bit-pattern 0xFFFFFFFF) land at the tail of the sorted output instead of the head.

  • int64_t widening — fixes the illegal-memory-access repro for large num_rows; the guard and index computations in moe_permute_row_map are similarly widened.
  • uint32_t sort — repositions -1 sentinel entries to the tail of sorted_row_id; however, moe_permute_row_map still processes all num_rows × topK positions, and its drop branch computes source_token_id = (-1) / topK and source_topK_id = (-1) % topK for those tail entries, producing a write to row_id_map[-num_rows] (or row_id_map[-1]) that silently corrupts adjacent device memory for any workload with topK > 1 and sentinel routing values.

Confidence Score: 3/5

The uint32_t sort change correctly repositions -1 sentinel entries, but the kernel still writes to a negative row_id_map index when it encounters them, corrupting adjacent device memory for any model with sentinel routing and topK > 1.

The PR's test suite only exercises topK=1 for the -1 sentinel repro. For topK > 1 (which DeepSeek-V3 uses), (-1) % topK evaluates to -1 in C++, and the write row_id_map[-1 * num_rows + 0] lands num_rows elements before the buffer — silent device-memory corruption that won't surface until the next kernel reads that region. A single if (source_row < 0) return guard closes this, but without it the -1 sentinel path is still broken for the primary target workload.

transformer_engine/common/permutation/permutation.cu — specifically the drop branch of moe_permute_row_map after the uint32_t sort change

Important Files Changed

Filename Overview
transformer_engine/common/permutation/permutation.cu Widens source_token/dest_row to int64_t in forward/backward kernels and fixes the launcher grid computation; changes radix sort to uint32_t so -1 sentinels land at the tail. The moe_permute_row_map drop branch still processes those tail entries without guarding against source_row == -1, causing an OOB write for any topK > 1 workload with sentinel routing.

Flowchart

%%{init: {'theme': 'neutral'}}%%
flowchart TD
    A[indices tensor
routing map with possible -1 sentinels] --> B[nvte_device_radix_sort_pairs
uint32_t sort: -1 goes to tail]
    B --> C[sorted_row_id
valid IDs first, -1 sentinels last]
    C --> D[moe_permute_row_map
idx in 0..num_rows x topK-1]
    D --> E{idx >= num_out_tokens?}
    E -- No --> F[Write idx to row_id_map OK]
    E -- Yes --> G{source_row < 0?}
    G -- No valid dropped token --> H[Write -1 to row_id_map OK]
    G -- Yes -1 sentinel source_topK_id=-1 --> I[OOB write row_id_map at negative index]
    F --> J[moe_permute_kernel / moe_unpermute_kernel]
    H --> J

Reviews (7): Last reviewed commit: "Widen num_rows * topK products in moe_pe..."

Comment on lines +59 to +60
const int num_minus_ones = num_tokens * topK - num_out_tokens;
sorted_row_id_ptr = reinterpret_cast<char *>(sorted_row_id_ptr) + num_minus_ones * sizeof(int);
Contributor


P2 Negative num_minus_ones becomes enormous size_t offset

num_minus_ones is computed as int. If a caller passes num_out_tokens > num_tokens * topK (which the function does not validate), num_minus_ones is negative. The pointer advance expression:

sorted_row_id_ptr = reinterpret_cast<char *>(sorted_row_id_ptr) + num_minus_ones * sizeof(int);

involves int * size_t, which promotes num_minus_ones to size_t (unsigned). A value like -4 becomes SIZE_MAX - 3, advancing the pointer far out of the allocation and causing a silent OOB read. A simple clamp or assert before this line would prevent this:

TORCH_CHECK(num_out_tokens <= num_tokens * topK,
            "num_out_tokens (", num_out_tokens, ") cannot exceed num_tokens*topK (",
            num_tokens * topK, ")");

…permute

Two independent bugs in transformer_engine/common/permutation/permutation.cu
and the PyTorch extension caller reproduce on main (264da2b) and v2.13:

1. int32 overflow in moe_unpermute_kernel and moe_permute_kernel.
   `source_token * num_cols` and `source_row * num_cols` are computed with
   int, so for long-sequence MoE workloads where num_out_tokens * num_cols
   reaches 2**31 (e.g. 2**18 tokens x 2**13 hidden), the pointer offset
   wraps and the kernel either reads garbage or raises
   `an illegal memory access was encountered`.
   Widening source_token, source_row and dest_row to int64_t inside the
   kernels keeps the index arithmetic in 64 bits without changing any
   public types.

2. Incorrect handling of -1 sentinels in the routing indices.
   Libraries such as DeepEP (and any expert-parallel mask that sets
   non-local (token, slot) pairs to -1) feed a routing_map that contains
   -1 entries. `cub::DeviceRadixSort::SortPairs` is signed ascending, so
   those sentinels land at the HEAD of sorted_row_id, not the tail.
   moe_permute_row_map currently writes -1 only for idx >= num_out_tokens
   and reads the sentinel prefix as if it were a valid sorted id,
   producing bogus row_id_map writes (for instance
   `source_row / topK == 0, source_row % topK == -1`).

   The caller now advances sorted_row_id_ptr past the num_minus_ones
   prefix and pre-fills row_id_map with -1 via torch::full, so the
   kernel only processes the valid suffix and never dereferences a
   sentinel.  The launcher's grid switches from num_rows*topK blocks
   to num_out_tokens blocks to match the new valid range.

No behaviour change on happy-path routing_map (no -1, no overflow).
Reproducers:

- 8-token, topK=2 routing_map with -1 masking: max |TE - ref| = 4.5e0
  on bf16 with current main; 0.0 with this patch.
- num_tokens=2**18+1, num_cols=2**13, topK=1: current main raises
  CUDA illegal memory access at permutation.cu:252; with this patch
  it succeeds.

Signed-off-by: Jingyi Xi <flotherxi@gmail.com>
@jing-4369 jing-4369 force-pushed the fix/moe-permute-int-overflow-and-minus-one branch from a909254 to ce34774 Compare April 21, 2026 07:58
Comment on lines +61 to +63
const int num_minus_ones = num_tokens * topK - num_out_tokens;
sorted_row_id_ptr = reinterpret_cast<char *>(sorted_row_id_ptr) +
static_cast<size_t>(num_minus_ones) * sizeof(int);
Contributor


P1 num_tokens * topK still computed as int * int

num_tokens and topK are both int, so num_tokens * topK on line 61 is evaluated in 32-bit arithmetic before the result feeds the int64_t subtraction. The same expression appears twice in the NVTE_CHECK on lines 59–60. If num_tokens * topK wraps to a negative int (possible when, e.g., num_tokens ≥ 2^31 / topK), the NVTE_CHECK would either spuriously reject a valid num_out_tokens, or the error-message value would be wrong. Casting to int64_t before the multiplication closes this gap:

Suggested change
const int num_minus_ones = num_tokens * topK - num_out_tokens;
sorted_row_id_ptr = reinterpret_cast<char *>(sorted_row_id_ptr) +
static_cast<size_t>(num_minus_ones) * sizeof(int);
const int64_t total_tokens = static_cast<int64_t>(num_tokens) * topK;
NVTE_CHECK(num_out_tokens <= total_tokens, "num_out_tokens (", num_out_tokens,
") must not exceed num_tokens*topK (", total_tokens, ")");
const int num_minus_ones = static_cast<int>(total_tokens - num_out_tokens);
sorted_row_id_ptr = reinterpret_cast<char *>(sorted_row_id_ptr) +
static_cast<size_t>(num_minus_ones) * sizeof(int);

Collaborator


This (the greptile suggestion) looks correct. Can you please cast num_tokens to int64 before the multiplication and the subtraction of num_out_tokens?

Collaborator


This is not a huge deal because even with topK=128, you would need > 16M tokens per rank for the int product to overflow. But better to be consistent, and this cast of one value on the CPU side would not slow anything down noticeably.

@jing-4369 jing-4369 force-pushed the fix/moe-permute-int-overflow-and-minus-one branch from ba39fba to cd64aa5 Compare April 21, 2026 08:14
Add an NVTE_CHECK that num_out_tokens <= num_tokens * topK and cast
num_minus_ones to size_t before the pointer advance, so a negative
num_minus_ones (from an invalid num_out_tokens) cannot silently wrap
into a huge pointer offset.

Signed-off-by: Jingyi Xi <flotherxi@gmail.com>
@jing-4369 jing-4369 force-pushed the fix/moe-permute-int-overflow-and-minus-one branch from cd64aa5 to b73a1f9 Compare April 21, 2026 08:22
@ptrendx ptrendx added the community-contribution PRs from external contributor outside the core maintainers, representing community-driven work. label Apr 21, 2026


int threads = 64;
int blocks = (num_rows * topK + threads - 1) / threads;
int blocks = (num_out_tokens + threads - 1) / threads;
Collaborator


This is correct here, but it has an implied prerequisite that the host prefills the buffer with -1 and shifts the pointer by num_minus_ones (what you did in the other file). Better to make that explicit with a comment, so no regression sneaks in when someone changes that behavior and breaks the block count here. Something like:

"// row_id_map MUST be pre-initialized to -1; sorted_row_id MUST point past the sentinel prefix"

num_out_tokens = (num_out_tokens > 0) ? num_out_tokens : num_tokens * topK;
NVTE_CHECK(num_out_tokens <= num_tokens * topK, "num_out_tokens (", num_out_tokens,
") must not exceed num_tokens*topK (", num_tokens * topK, ")");
const int num_minus_ones = num_tokens * topK - num_out_tokens;
Collaborator


This is probably going to introduce a regression for the capacity-drop path. The shift assumes the dropped routes are -1 sentinels at the head of sorted_row_id (cub's signed radix sort), which is true for the EP-mask case this PR targets. But the pre-existing capacity-drop path encodes drops as a large positive expert id that sorts to the tail. For that case, the head holds valid low-expert-id rows, and shifting past them drops the wrong tokens. (Just FYI: the capacity-dropping case means no -1 in indices, but num_out_tokens < num_tokens * topK because some expert exceeded capacity.)

See tests/pytorch/test_permutation.py: in pytorch_permute_index_map, we have:

sorted_indices[:num_out_tokens] (keeps the head),
so I'd expect test_permutation_index_map[..., num_out_tokens=2039, ...] to fail. We can run the te_ci to confirm it.

Collaborator


I think another solution to this, without computing num_tokens * topK - num_out_tokens (or counting the -1 entries on the host side), is to sort the keys as uint32_t instead of int32_t. Then -1 becomes UINT_MAX and sorts to the tail, unifying both capacity-dropping and dropless under the original idx >= num_out_tokens --> drop logic. That removes the need for the prefix shift you did and for the row_id_map pre-fill. It just needs expert_id to stay <= UINT_MAX, which I do not think we will reach anytime soon.

Author


Thanks for the careful review. Acknowledging the capacity-drop regression concern and the unsigned-sort suggestion below — both make sense. Waiting on the te_ci result you triggered before I push any code change, so we have a concrete signal on what needs to move.


@tdophung
Collaborator

tdophung commented May 4, 2026

/te_ci pytorch

@tdophung
Collaborator

tdophung commented May 6, 2026

/te-ci pytorch

@tdophung
Collaborator

tdophung commented May 6, 2026

/te-ci pytorch L0

The MoE permute path was correct for the existing capacity-drop convention
(drops encoded as a large positive expert id, sorted to the tail by the
signed cub::DeviceRadixSort), but it broke for callers that mark dropped
(token, slot) pairs with -1 (expert-parallel rank masking, e.g. DeepEP).
With signed sort the -1 sentinels land at the HEAD of sorted_row_id, while
moe_permute_row_map's `idx >= num_out_tokens` branch assumes drops are at
the tail.

Reinterpret the keys as uint32_t inside nvte_device_radix_sort_pairs so
-1 (= UINT_MAX) sorts to the tail, unifying the EP-mask case with the
existing capacity-drop convention. The kernel and caller sides are
unchanged - this is a one-place fix that makes both drop conventions
land in the existing drop branch.

Also widen the loop-carried indices in moe_unpermute_kernel and
moe_permute_kernel to int64_t (`source_token`, `source_row`, `dest_row`)
to keep `row * num_cols` in 64 bits. We hit this on DeepSeek-V3 long-
context training (hidden = 7168, topK = 8): once `num_out_tokens *
num_cols` reaches 2**31 the int product wraps and the kernel either
silently corrupts rows or raises CUDA `illegal memory access`.

Signed-off-by: Jingyi Xi <flotherxi@gmail.com>
@jing-4369
Author

@tdophung Pushed the unsigned-sort approach in 4f46dc2. Net diff is now +13/-6 in permutation.cu only; permutation.cpp is unchanged from upstream, and the earlier comments around NVTE_CHECK / launcher block-count no longer apply.

Could you re-trigger /te-ci pytorch when convenient?

@tdophung
Collaborator

tdophung commented May 7, 2026

@tdophung Pushed the unsigned-sort approach in 4f46dc2. Net diff is now +13/-6 in permutation.cu only; permutation.cpp is unchanged from upstream, and the earlier comments around NVTE_CHECK / launcher block-count no longer apply.

Could you re-trigger /te-ci pytorch when convenient?

Hi, I can trigger it now. For future reference, you can also trigger it by commenting "/te-ci <framework> <level>" on the PR: <framework> is either pytorch or jax, and <level> can be L0, L1, or L2. Alternatively, you can test locally by running the related tests in tests/pytorch or tests/jax, depending on the change, or by using the scripts in qa/L0_pytorch_unittest/test.sh

@tdophung
Collaborator

tdophung commented May 7, 2026

/te-ci pytorch

Collaborator


+1 to greptile's P1 on moe_permute_row_map / its launcher: now that the kernel is back to walking all num_rows * topK entries, the int * int products are the same overflow class this PR is fixing in the other two kernels. Could you widen those to int64_t for consistency?

Collaborator


source_topK_id * num_rows is int * int; same as the comment on line 242.

Collaborator


same as comment on line 242

@tdophung
Collaborator

tdophung commented May 7, 2026

I can see that the CI finished and the failures are not related to your change: https://gitlab-master.nvidia.com/dl/transformerengine/transformerengine/-/pipelines/50584460

The only failing test was fused_qkv, not permutation.

Per reviewer feedback in NVIDIA#2907, promote the
int * int multiplications in moe_permute_row_map and its launcher to
int64_t. These are not the overflow path this PR was originally
fixing (DeepSeek-V3 long-context hits row * num_cols, where num_cols
is the hidden dim ~ 7-8k), and num_rows * topK only crosses 2**31 at
unrealistic per-rank token counts (>= 268M at topK=8). The change is
purely defensive but keeps the index arithmetic in this kernel
consistent with the int64_t source_token / source_row / dest_row
widening already applied to moe_unpermute_kernel and moe_permute_kernel.

Signed-off-by: Jingyi Xi <flotherxi@gmail.com>
Comment on lines 24 to 26
int source_row = sorted_row_id[idx];
int source_token_id = source_row / topK;
int source_topK_id = source_row % topK;
Contributor


P1 After the uint32_t-sort change, -1 sentinel entries land at the tail of sorted_row_id (positions ≥ num_out_tokens). The drop branch then computes source_token_id = (-1) / topK and source_topK_id = (-1) % topK. For topK > 1, C++ truncates toward zero, giving source_token_id = 0 and source_topK_id = -1, so the write becomes row_id_map[-1 * num_rows + 0], landing num_rows words before the buffer start. For topK = 1 the write lands at row_id_map[-1]. Both cases silently corrupt adjacent device memory. A simple early exit on source_row < 0 closes this gap without touching the caller.

Suggested change
int source_row = sorted_row_id[idx];
int source_token_id = source_row / topK;
int source_topK_id = source_row % topK;
int source_row = sorted_row_id[idx];
if (source_row < 0) return; // skip -1 sentinel entries
int source_token_id = source_row / topK;
int source_topK_id = source_row % topK;

@jing-4369
Author

/te-ci pytorch

Labels

community-contribution PRs from external contributor outside the core maintainers, representing community-driven work.


Development

Successfully merging this pull request may close these issues.

[Bug] moe_permute CUDA kernel: int32 overflow and incorrect -1 sentinel handling

3 participants