
Implement per-token NVFP4 fprop recipe#2931

Open
zianglih wants to merge 22 commits intoNVIDIA:mainfrom
zianglih:fp4-per-token

Conversation


@zianglih zianglih commented Apr 27, 2026

Description


Implement a per-token NVFP4 recipe, fprop only.
Currently, the per-token scaling is handled by separate PyTorch code.
The quantization kernels are bitwise exact with the existing TE reference implementation.
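
For intuition, a minimal sketch of the core difference, assuming a 2D activation where rows are tokens (names are illustrative, not TE's API):

import torch

x = torch.randn(16, 64)

# Per-tensor NVFP4: one global amax for the whole tensor.
per_tensor_amax = x.abs().amax()                      # shape []

# Per-token NVFP4: one amax per row (token), which is what this recipe adds.
per_token_amax = x.abs().amax(dim=-1, keepdim=True)   # shape [16, 1]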

The following tests passed on B200:

python3 -m pytest --tb=auto tests/pytorch/nvfp4/test_nvfp4_quantize_exact.py
python3 -m pytest --tb=auto tests/pytorch/nvfp4/test_nvfp4_gemm_exact.py
python3 -m pytest --tb=auto tests/pytorch/test_backward_override.py
python3 -m pytest --tb=auto tests/pytorch/test_sanity.py
python3 -m pytest --tb=auto tests/pytorch/test_recipe.py
python3 -m pytest --tb=auto tests/pytorch/test_torch_compile.py
python3 -m pytest --tb=auto tests/pytorch/test_cpu_offloading.py
PYTORCH_JIT=0 NVTE_TORCH_COMPILE=0 NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 NVTE_FUSED_ATTN=0 python3 -m pytest --tb=auto tests/pytorch/test_cuda_graphs.py
NVTE_TEST_NVINSPECT_ENABLED=1 NVTE_TEST_NVINSPECT_CONFIG_FILE=tests/pytorch/debug/test_configs/dummy_feature.yaml NVTE_TEST_NVINSPECT_FEATURE_DIRS=transformer_engine/debug/features PYTORCH_JIT=0 NVTE_TORCH_COMPILE=0 NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 python3 -m pytest --tb=auto tests/pytorch/test_sanity.py

Type of change

  • Documentation change (change only to the documentation, either a fix or new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

Please list the changes introduced in this PR:

  • Add a per_token_activation field to the NVFP4 recipe; it can be enabled via the NVTE_NVFP4_PER_TOKEN_ACTIVATION environment variable (see the sketch after this list)
  • New per-token NVFP4 quantize kernels in transformer_engine/common/cast/nvfp4/quantize_pertoken_nvfp4.cuh, bitwise exact with the existing TE PyTorch reference implementation and the per-tensor NVFP4 emulated implementation
  • Extend the dequant kernel in transformer_engine/common/cast/nvfp4/dequantize_nvfp4.cuh to correctly handle per-token NVFP4
  • In transformer_engine/pytorch/cpp_extensions/gemm.py, when per-token NVFP4 is detected, perform a separate per-token scaling step in PyTorch after the cuBLAS GEMM
  • Broaden test coverage by expanding 7 test files
  • Modify the 1D quant reference implementation in tests/cpp/operator/test_cast_nvfp4_transpose.cu to align with the PyTorch reference numerics
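
A hedged sketch of how an env-var-defaulted recipe field plausibly looks (the actual definition lives in transformer_engine/common/recipe/__init__.py; the class name and the "0"/"1" parsing here are assumptions):

import os
from dataclasses import dataclass, field

def _pertoken_default() -> bool:
    # Assumed convention: "1" enables per-token activation quantization.
    return os.getenv("NVTE_NVFP4_PER_TOKEN_ACTIVATION", "0") == "1"

@dataclass
class NVFP4BlockScalingSketch:  # hypothetical stand-in for the real recipe class
    per_token_activation: bool = field(default_factory=_pertoken_default)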

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

@zianglih zianglih marked this pull request as draft April 27, 2026 06:24

greptile-apps Bot commented Apr 27, 2026

Greptile Summary

This PR implements per-token (per-row) NVFP4 fprop-only quantization for GroupedLinear, adding new CUDA kernels, amax buffer management, and a post-GEMM per-token scaling step in PyTorch. The kernel implementations are clean and the per-tensor fallback for backward is correctly enforced.

  • P1 — out_views[i].numel() crash in grouped GEMM (non-single_output path): When single_output=False the per-token early-return sets out_views = out. If the caller passes [None] * num_gemms, out_views[i].numel() immediately raises AttributeError. The contract that callers must pre-allocate outputs is not enforced or documented.
  • P2 — Bias fused into cuBLAS then manually subtracted: The bias is not stripped from gemm_args before the per-token cuBLAS call. The post-processing (out - bias) * scales + bias is algebraically correct but risks fp32 catastrophic cancellation when |bias| greatly exceeds |Z|. Stripping the bias from the cuBLAS call and adding it once after scaling would be cleaner and more stable (see the sketch after this list).
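
A small sketch of the numerical concern, with illustrative names and magnitudes:

import torch

M, N = 4, 8
Z = torch.randn(M, N, dtype=torch.float32) * 1e-3     # unbiased GEMM result, small
bias = torch.full((N,), 1e6, dtype=torch.float32)     # |bias| >> |Z|
scales = torch.rand(M, 1, dtype=torch.float32)        # per-token scales

# Current pattern: bias fused into the GEMM, then subtracted before scaling.
fused = Z + bias                                      # Z's low-order bits are lost here
res_subtract = (fused - bias) * scales + bias         # cancellation damage propagates

# Suggested pattern: keep bias out of the GEMM and add it once after scaling.
res_clean = Z * scales + bias

print((res_subtract - res_clean).abs().max())         # nonzero once |bias| dominates |Z|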

Confidence Score: 3/5

Safe to merge for fprop-only workloads without bias; grouped-GEMM path can crash at runtime when outputs are not pre-allocated.

One confirmed P1 (AttributeError on None in grouped GEMM), plus the already-flagged earlier P1s around the early-return path's return signature. Extensive test coverage is reported on B200, but the pre-allocation contract is neither enforced nor documented.

transformer_engine/pytorch/cpp_extensions/gemm.py — grouped per-token path; transformer_engine/common/recipe/__init__.py — missing backward_override validation.

Important Files Changed

transformer_engine/pytorch/cpp_extensions/gemm.py: Adds per-token NVFP4 detection and GEMM dispatch; the grouped-GEMM early return crashes with AttributeError when outputs are not pre-allocated, and the bias forward-then-subtract pattern risks fp32 cancellation.
transformer_engine/common/cast/nvfp4/quantize_per_token_nvfp4.cuh: New per-token NVFP4 CUDA kernels for rowwise and columnwise quantization; the columnwise num_rows divisibility check is only enforced inside the launch helper, not at the public entry point.
transformer_engine/pytorch/csrc/quantizer.cpp: Propagates the per_token_activation flag through create_tensor, convert_and_update_tensor, and quantize_impl; amax buffer size and shape are correctly adjusted for per-token mode.
transformer_engine/pytorch/csrc/extensions/cast.cpp: Adds per-token amax buffer support in bulk_allocate_nvfp4_tensors and split_quantize_nvfp4_impl_helper; the standalone quantize_nvfp4_per_token function looks correct.
transformer_engine/pytorch/quantization.py: Correctly disables per_token_activation for backward quantizers; the forward quantizer role assignment via idx % 3 != 1 is undocumented.
transformer_engine/common/recipe/__init__.py: Adds the per_token_activation field to the NVFP4BlockScaling recipe with an env-var default; no post-init validation enforces backward_override when per-token mode is on.
transformer_engine/pytorch/tensor/nvfp4_tensor.py: Adds the per_token_activation field and correctly sizes amax buffers using flat_first_dim rows instead of 1.
transformer_engine/pytorch/tensor/storage/grouped_tensor_storage.py: Correctly computes total_amax_elements and per-tensor amax offsets for the per-token path.
transformer_engine/common/cast/nvfp4/dequantize_nvfp4.cuh: Dequant kernel is correctly parameterised by amax_numel; it selects tensor_amax[0] for per-tensor or tensor_amax[y] for per-token.

Flowchart

%%{init: {'theme': 'neutral'}}%%
flowchart TD
    A[Activation Tensor] --> B[split_quantize per-token]
    B --> C[Per-token CUDA kernel\nFP4 data plus FP8 block scales\nper-row amax vector]
    C --> D[NVFP4TensorStorage\namax shape equals num-rows]
    D --> E[general_gemm detection\namax numel greater than 1]
    E --> F[Strip global amax\nset amax to ones]
    F --> G[cuBLAS GEMM\nblock-scaled only fp32 out]
    G --> H[Multiply by per-token scales\nactivation-amax times weight-amax]
    H --> I[Add bias then cast\nFinal output]
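
A PyTorch sketch of the post-GEMM rescale the flowchart describes, assuming the standard NVFP4 two-level scaling constants (FP4 E2M1 max 6.0, FP8 E4M3 max 448.0); the function and argument names are illustrative, not TE's API:

import torch

FP4_MAX, FP8_MAX = 6.0, 448.0  # E2M1 and E4M3 max-normal values

def pertoken_rescale(z, act_amax, weight_amax, bias=None):
    # z: fp32 cuBLAS output computed with global amaxes stripped (set to ones)
    # act_amax: per-row activation amax, shape [M, 1]; weight_amax: scalar amax
    act_scale = act_amax / (FP4_MAX * FP8_MAX)
    w_scale = weight_amax / (FP4_MAX * FP8_MAX)
    out = z * (act_scale * w_scale)   # per-token decode scales applied post-GEMM
    return out if bias is None else out + bias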

Reviews (6). Last reviewed commit: "Improve accuracy by unfolding weight per..."

Comment thread on transformer_engine/pytorch/cpp_extensions/gemm.py (outdated)
  // Compute "correct" per-block encoding scaling factor
- const float S_enc_b_fp8 = S_dec_b_fp32 == 0.f ? 0.f : S_enc / S_dec_b_fp32;
+ const float S_enc_b_fp8 = S_dec_b_fp32 == 0.f ? 0.f :
+     fminf(1.0f / (S_dec_b_fp32 * (1.0f / S_enc)), Numeric_Traits<float>::maxNorm);

We have to change this to stay aligned with the PyTorch reference.
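
To see why the exact expression matters for bitwise parity: in fp32, a / b and 1 / (b * (1 / a)) take different rounding paths, so the kernel must use the same form as the reference. A quick illustration (values arbitrary):

import struct

def f32(x):
    # Round a Python double to the nearest fp32 value.
    return struct.unpack("f", struct.pack("f", x))[0]

S_enc, S_dec = 2688.0, 3.7
direct = f32(S_enc / f32(S_dec))                            # one rounding
rearranged = f32(1.0 / f32(f32(S_dec) * f32(1.0 / S_enc)))  # three roundings
print(direct == rearranged)  # can be False: results may differ in the last ulp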

@zianglih zianglih marked this pull request as ready for review April 27, 2026 09:14
@zianglih zianglih marked this pull request as draft May 2, 2026 18:22
zianglih and others added 14 commits May 2, 2026 11:27
All commits signed off by Ziang Li <ziangli@umich.edu>; one co-authored by Yigong Qin <qqqyyy1233@outlook.com>.
@ziang-and ziang-and force-pushed the fp4-per-token branch 2 times, most recently from 6998f64 to 5b2f606 Compare May 2, 2026 19:10
zianglih added 5 commits May 2, 2026 16:33
All five commits signed off by Ziang Li <ziangli@umich.edu>.

zianglih commented May 2, 2026

The following extended tests all passed:

python3 -m pytest --tb=auto tests/pytorch/nvfp4/test_nvfp4_quantize_exact.py
python3 -m pytest --tb=auto tests/pytorch/nvfp4/test_nvfp4_gemm_exact.py
python3 -m pytest --tb=auto tests/pytorch/test_backward_override.py
python3 -m pytest --tb=auto tests/pytorch/test_sanity.py
python3 -m pytest --tb=auto tests/pytorch/test_recipe.py
python3 -m pytest --tb=auto tests/pytorch/test_torch_compile.py
python3 -m pytest --tb=auto tests/pytorch/test_cpu_offloading.py
PYTORCH_JIT=0 NVTE_TORCH_COMPILE=0 NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 NVTE_FUSED_ATTN=0 python3 -m pytest --tb=auto tests/pytorch/test_cuda_graphs.py
NVTE_TEST_NVINSPECT_ENABLED=1 NVTE_TEST_NVINSPECT_CONFIG_FILE=tests/pytorch/debug/test_configs/dummy_feature.yaml NVTE_TEST_NVINSPECT_FEATURE_DIRS=transformer_engine/debug/features PYTORCH_JIT=0 NVTE_TORCH_COMPILE=0 NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 python3 -m pytest --tb=auto tests/pytorch/test_sanity.py

cd /root/TransformerEngine/tests/cpp
cmake --build build -j200
TEST_BIN="$(find build -type f -name test_operator -perm -u+x | head -n 1)"
"$TEST_BIN" --gtest_filter='*FusedCastTransposeNVFP4*:*DequantizeNVFP4*'

@zianglih zianglih marked this pull request as ready for review May 2, 2026 23:54
zianglih added 2 commits May 2, 2026 17:23
…clean up

All signed off by Ziang Li <ziangli@umich.edu>.
Comment on lines +350 to +354 of transformer_engine/pytorch/cpp_extensions/gemm.py

else:
    out_views = out
    for i in range(num_gemms):
        if out_views[i].numel() == 0:
            continue

P1 out_views iteration crashes with AttributeError when single_output=False and outputs are not pre-allocated

When single_output=False, out_views = out (line 351). If the caller passes [None] * num_gemms, every out_views[i] is None, and the loop immediately calls out_views[i].numel(), raising AttributeError: 'NoneType' object has no attribute 'numel'.

All existing callers appear to pre-allocate, but this contract is not enforced or documented. At minimum, add a guard:

if not all(isinstance(v, torch.Tensor) for v in out_views):
    raise RuntimeError(
        "Per-token NVFP4 grouped GEMM requires pre-allocated output tensors."
    )
