
[Common] Use specialized unfused MXFP8 cast kernels by default #2958

Open

Oleg-Goncharov wants to merge 5 commits into NVIDIA:main from Oleg-Goncharov:pr_fast_default_mxfp8_kernels

Conversation

@Oleg-Goncharov (Collaborator)

Description

This PR enables the fast unfused MXFP8 cast kernels by default.

Previously, these kernels were gated behind an environment variable and therefore were not used unless explicitly enabled. This change makes the specialized cast-only path the default behavior.
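
For concreteness, a minimal sketch of what this gating change looks like, assuming simplified stand-in types and a `hasSpec()`-style predicate as described in the review below; only the `ENABLE_CAST_ONLY` name and the unconditional `return true` come from the PR itself:

```cuda
// Minimal sketch of the gating change. ENABLE_CAST_ONLY and the
// unconditional `return true` are from this PR; all other names and
// signatures here are illustrative stand-ins.
#include <cstdlib>

struct bf16 {};     // stand-in input type tag
struct fp8e4m3 {};  // stand-in output type tag

// Generic case: no specialized kernel available.
template <bool IS_DBIAS, bool IS_DACT, bool IS_ACT, typename IType, typename OType>
bool hasSpec() { return false; }

// Before: the specializations deferred to an env-var check like this one.
inline bool is_cast_only_enabled() {
  const char *env = std::getenv("ENABLE_CAST_ONLY");
  return env != nullptr && env[0] != '0';
}

// After: the cast-only specialization (one of four) opts in unconditionally.
template <>
bool hasSpec<false, false, false, bf16, fp8e4m3>() { return true; }
```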

Type of change

  • Documentation change (change only to the documentation, either a fix or new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

  • Removed the ENABLE_CAST_ONLY environment variable gate

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

Signed-off-by: Oleg Goncharov <ogoncharov@nvidia.com>
@greptile-apps (Contributor)

greptile-apps Bot commented May 5, 2026

Greptile Summary

This PR promotes the unfused MXFP8 cast-only kernels from opt-in (via the ENABLE_CAST_ONLY env var) to the default path for supported type combinations (fp16/bf16 → fp8e5m2/fp8e4m3, no dbias/dact/act, no GEMM-swizzled scales).

  • specialized/quantize_mxfp8.cuh: Removes is_cast_only_enabled() and its ENABLE_CAST_ONLY env-var logic; all four hasSpec() specializations now unconditionally return true.
  • quantize_mxfp8.cuh: Adds an is_full_rowwise_chunk guard (cols % CastTraits::chunkElems == 0, i.e. cols % 32 == 0) so that rowwise shapes with a partial tail fall back to the generic kernel safely (a sketch of this guard follows this list), and removes the previously unreachable COLWISE case from the switch that was flagged in a prior review.
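
A sketch of how such a guard can compose with kernel selection. The cols % 32 check, the names `is_full_rowwise_chunk` and `CastTraits::chunkElems`, and the `scaling_type_has_specialized_support` node from the flowchart below are taken from this summary; the rest of the structure is an assumption:

```cuda
// Sketch of the alignment guard described above. The cols % 32 check and the
// ROWWISE/BIDIMENSIONAL behavior follow the review text; the surrounding
// structure is illustrative.
#include <cstddef>

enum class ScalingType { ROWWISE, COLWISE, BIDIMENSIONAL };

struct CastTraits {
  static constexpr size_t chunkElems = 32;  // one 32-element MXFP8 scaling block
};

bool scaling_type_has_specialized_support(ScalingType scaling_type, size_t cols) {
  // Rowwise shapes with a partial tail chunk would read/write past the
  // logical row end, so they fall back to the generic kernel.
  const bool is_full_rowwise_chunk = (cols % CastTraits::chunkElems == 0);
  switch (scaling_type) {
    case ScalingType::ROWWISE:
      return is_full_rowwise_chunk;
    case ScalingType::BIDIMENSIONAL:
      return true;  // TMA loads bounds-check internally (see below)
    default:
      return false;  // COLWISE: no specialized kernel
  }
}
```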

Confidence Score: 5/5

The change is safe to merge. The correctness risk from enabling the fast path by default is tightly bounded by the new is_full_rowwise_chunk guard for the rowwise kernel and by the TMA-backed bounds checks already present in the bidimensional kernel.

The changes to both files are narrow and well reasoned. The env-var removal is a clean mechanical change. The rowwise alignment guard (cols % 32 == 0) directly addresses the out-of-bounds load/store risk that was the only functional concern with unconditionally enabling this path. The bidimensional kernel uses TMA loads with internal coords.x >= cols early-exit guards, so it handles non-aligned shapes correctly without a comparable external guard. The previously flagged dead COLWISE case has been removed. No regression risk was identified.
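
The internal guard the summary refers to is the standard per-thread early-exit pattern. The following is a simplified sketch: only the coords.x >= cols style comparison is taken from the review, and a plain global load stands in for the kernel's actual TMA staging:

```cuda
#include <cstddef>
#include <cuda_bf16.h>
#include <cuda_fp8.h>

// Simplified sketch of the early-exit bounds guard mentioned above. The real
// kernel stages tiles through TMA; here a plain global load stands in for it.
__global__ void cast_bidimensional_sketch(const __nv_bfloat16 *in,
                                          __nv_fp8_e4m3 *out,
                                          size_t rows, size_t cols) {
  const size_t x = blockIdx.x * blockDim.x + threadIdx.x;
  const size_t y = blockIdx.y * blockDim.y + threadIdx.y;
  if (x >= cols || y >= rows) return;  // tail threads exit before any access
  out[y * cols + x] = __nv_fp8_e4m3(__bfloat162float(in[y * cols + x]));
}
```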

No files require special attention.

Important Files Changed

  • transformer_engine/common/cast/mxfp8/specialized/quantize_mxfp8.cuh: Removes the is_cast_only_enabled() env-var gate and makes all four hasSpec() specializations (fp16/bf16 → fp8e5m2/fp8e4m3 without dbias/dact/act) unconditionally return true; a dead debug FIXME comment is also removed.
  • transformer_engine/common/cast/mxfp8/quantize_mxfp8.cuh: Adds the is_full_rowwise_chunk guard (cols % 32 == 0) to prevent the rowwise cast-only kernel from reading or writing past the logical row end on non-aligned shapes; removes the now-unreachable dead COLWISE case from the switch that was previously flagged in review.

Flowchart

```mermaid
%%{init: {'theme': 'neutral'}}%%
flowchart TD
    A["quantize() called"] --> B{"hasSpec<IS_DBIAS,IS_DACT,IS_ACT,IType,OType>()"}
    B -- "false (unsupported combo)" --> G["Generic kernel"]
    B -- "true (fp16/bf16→fp8, cast-only)" --> C{"WITH_GEMM_SWIZZLED_SCALES?"}
    C -- "yes" --> G
    C -- "no" --> D{"scaling_type_has_specialized_support?"}
    D -- "no (COLWISE, or ROWWISE with cols%32≠0)" --> G
    D -- "yes" --> E{"scaling_type"}
    E -- "ROWWISE\n(cols%32==0 guaranteed)" --> F1["specialized rowwise kernel\n(vectorized 32-elem chunks)"]
    E -- "BIDIMENSIONAL\n(TMA handles non-aligned)" --> F2["specialized bidimensional kernel\n(TMA loads, internal OOB guards)"]
    E -- "other" --> ERR["NVTE_ERROR: Invalid scaling type"]
    F1 --> RET["return"]
    F2 --> RET
    G --> GEN["Generic MXFP8 kernel (TMA-based)"]
```
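Read as host code, the flowchart corresponds to a dispatch of roughly this shape. This is a sketch: hasSpec, scaling_type_has_specialized_support, and the invalid-scaling-type error come from the flowchart, while Tensor and the launch_* helpers are illustrative stubs:

```cuda
#include <cstddef>
#include <stdexcept>

enum class ScalingType { ROWWISE, COLWISE, BIDIMENSIONAL };
struct Tensor { size_t rows = 0, cols = 0; };  // illustrative stand-in

// Stubs for the kernels the flowchart routes between.
void launch_specialized_rowwise(const Tensor &, Tensor &) {}
void launch_specialized_bidimensional(const Tensor &, Tensor &) {}
void launch_generic_mxfp8(const Tensor &, Tensor &) {}

template <bool IS_DBIAS, bool IS_DACT, bool IS_ACT, typename IType, typename OType>
bool hasSpec();  // see the earlier sketch

bool scaling_type_has_specialized_support(ScalingType, size_t cols);

template <bool IS_DBIAS, bool IS_DACT, bool IS_ACT, typename IType, typename OType>
void quantize(const Tensor &input, Tensor &output, ScalingType scaling_type,
              bool with_gemm_swizzled_scales) {
  if (hasSpec<IS_DBIAS, IS_DACT, IS_ACT, IType, OType>() &&
      !with_gemm_swizzled_scales &&
      scaling_type_has_specialized_support(scaling_type, input.cols)) {
    switch (scaling_type) {
      case ScalingType::ROWWISE:        // cols % 32 == 0 guaranteed by the guard
        launch_specialized_rowwise(input, output);
        return;
      case ScalingType::BIDIMENSIONAL:  // TMA handles non-aligned shapes
        launch_specialized_bidimensional(input, output);
        return;
      default:
        throw std::runtime_error("Invalid scaling type");  // NVTE_ERROR in TE
    }
  }
  launch_generic_mxfp8(input, output);  // generic TMA-based fallback
}
```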

Reviews (3). Last reviewed commit: "[pre-commit.ci] auto fixes from pre-comm..."

ksivaman previously approved these changes May 5, 2026

@ksivaman (Member) left a comment:

LGTM

Signed-off-by: Oleg Goncharov <ogoncharov@nvidia.com>
@Oleg-Goncharov (Collaborator, Author)

/te-ci

ksivaman previously approved these changes May 5, 2026
Signed-off-by: Oleg Goncharov <ogoncharov@nvidia.com>
@Oleg-Goncharov (Collaborator, Author)

/te-ci
