Implement per-token NVFP4 fprop recipe #2931
Conversation
Greptile Summary: This PR implements per-token (per-row) NVFP4 fprop-only quantization for activations.
Confidence Score: 3/5. Safe to merge for fprop-only workloads without bias; the grouped-GEMM path can crash at runtime when outputs are not pre-allocated. One confirmed P1 (AttributeError on NoneType in grouped GEMM), plus the already-flagged P1s around the early-return's return signature. Extensive test coverage was reported on B200, but the pre-allocation contract is neither enforced nor documented.
Important files changed:
- transformer_engine/pytorch/cpp_extensions/gemm.py — grouped per-token path
- transformer_engine/common/recipe/__init__.py — missing backward_override validation
Flowchart:

```mermaid
%%{init: {'theme': 'neutral'}}%%
flowchart TD
    A[Activation Tensor] --> B[split_quantize per-token]
    B --> C[Per-token CUDA kernel\nFP4 data plus FP8 block scales\nper-row amax vector]
    C --> D[NVFP4TensorStorage\namax shape equals num-rows]
    D --> E[general_gemm detection\namax numel greater than 1]
    E --> F[Strip global amax\nset amax to ones]
    F --> G[cuBLAS GEMM\nblock-scaled only fp32 out]
    G --> H[Multiply by per-token scales\nactivation-amax times weight-amax]
    H --> I[Add bias then cast\nFinal output]
```
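The post-GEMM scaling path (steps F through H in the flowchart) can be sketched with a small numpy example. The shapes, the FP4 max magnitude of 6.0, and names such as `act_scale` and `wgt_scale` are illustrative assumptions here, not the actual TE code:

```python
import numpy as np

# Hypothetical sizes; the real path runs inside general_gemm.
M, K, N = 4, 8, 5
rng = np.random.default_rng(0)
a = rng.standard_normal((M, K)).astype(np.float32)  # activations
b = rng.standard_normal((K, N)).astype(np.float32)  # weight

# Per-token (per-row) decode scales for the activation and one global scale
# for the weight; 6.0 is the largest FP4 (E2M1) magnitude.
act_scale = np.abs(a).max(axis=1, keepdims=True) / 6.0  # shape (M, 1)
wgt_scale = np.float32(np.abs(b).max() / 6.0)

# cuBLAS runs the block-scaled GEMM with the global amax stripped (set to
# ones), producing an fp32 result ...
gemm_out = a @ b
# ... and the per-token scaling is applied afterwards in pytorch code,
# broadcasting the per-row activation scale over the output columns:
final = gemm_out * (act_scale * wgt_scale)
```

Bias addition and the final cast would follow the multiply, as in step I of the flowchart.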
Reviews (6): Last reviewed commit: "Improve accuracy by unfolding weight per..."
```diff
 // Compute "correct" per-block encoding scaling factor
-const float S_enc_b_fp8 = S_dec_b_fp32 == 0.f ? 0.f : S_enc / S_dec_b_fp32;
+const float S_enc_b_fp8 = S_dec_b_fp32 == 0.f ? 0.f :
+    fminf(1.0f / (S_dec_b_fp32 * (1.0f / S_enc)), Numeric_Traits<float>::maxNorm);
```
We have to change here to stay aligned with pytorch reference.
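For intuition, a numpy illustration of why the reordered, clamped expression behaves differently near overflow. Float32 mimics the CUDA code; using FLT_MAX as the stand-in for `Numeric_Traits<float>::maxNorm` is an assumption for demonstration (the real trait value may differ):

```python
import numpy as np

with np.errstate(over="ignore"):
    maxNorm = np.finfo(np.float32).max   # stand-in for Numeric_Traits<float>::maxNorm
    S_enc = np.float32(448.0)            # hypothetical encode scale
    S_dec_b_fp32 = np.float32(1e-42)     # tiny per-block decode scale

    # Old form: a single division, which can overflow to +inf.
    old_sf = np.float32(0) if S_dec_b_fp32 == 0 else S_enc / S_dec_b_fp32

    # New form: reproduces the pytorch reference's order of operations and
    # clamps to maxNorm instead of letting the scale factor become inf.
    new_sf = np.float32(0) if S_dec_b_fp32 == 0 else min(
        np.float32(1.0) / (S_dec_b_fp32 * (np.float32(1.0) / S_enc)), maxNorm)
```

With these inputs the old form yields inf while the new form stays finite, which is one way the two expressions can diverge bitwise.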
Signed-off-by: Ziang Li <ziangli@umich.edu> Co-authored-by: Yigong Qin <qqqyyy1233@outlook.com>
Force-pushed 6998f64 to 5b2f606.
The following extended tests all passed:
…clean up
```python
else:
    out_views = out
for i in range(num_gemms):
    if out_views[i].numel() == 0:
        continue
```
out_views iteration crashes with AttributeError when single_output=False and outputs are not pre-allocated
When single_output=False, out_views = out (line 351). If the caller passes [None] * num_gemms, every out_views[i] is NoneType. The loop immediately calls out_views[i].numel(), raising AttributeError: 'NoneType' object has no attribute 'numel'.
All existing callers appear to pre-allocate, but this contract is not enforced or documented. At minimum, add a guard:
```python
if not all(isinstance(v, torch.Tensor) for v in out_views):
    raise RuntimeError(
        "Per-token NVFP4 grouped GEMM requires pre-allocated output tensors."
    )
```
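A torch-free sketch of the failure mode and of what the suggested guard buys the caller. `FakeTensor` is a hypothetical stand-in for `torch.Tensor` (only `numel()` is needed here), and `grouped_loop` is an illustrative reduction of the grouped-GEMM loop, not the real function:

```python
class FakeTensor:
    """Minimal stand-in for torch.Tensor; only numel() is exercised."""
    def __init__(self, n):
        self._n = n

    def numel(self):
        return self._n


def grouped_loop(out_views):
    # The suggested guard turns a confusing AttributeError on NoneType
    # into an explicit error about the pre-allocation contract.
    if not all(isinstance(v, FakeTensor) for v in out_views):
        raise RuntimeError(
            "Per-token NVFP4 grouped GEMM requires pre-allocated output tensors."
        )
    # Mimics the skip-empty loop: zero-sized outputs are simply passed over.
    return [v for v in out_views if v.numel() != 0]


# Pre-allocated outputs (possibly zero-sized) pass through:
ok = grouped_loop([FakeTensor(0), FakeTensor(12)])

# A [None] * num_gemms caller now gets the explicit error instead:
try:
    grouped_loop([None, None])
    crashed = False
except RuntimeError:
    crashed = True
```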
Description
Implement per-token NVFP4 recipe with fprop only.
Currently, the per-token scaling is handled by separate pytorch code.
Quantization kernels are bitwise exact with the existing TE reference implementation.
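For intuition about what the per-token kernels compute, an emulated quantizer can be sketched in numpy. This is a simplified round-to-nearest model with an assumed block size of 16, not the CUDA kernel, and it omits the FP8 encoding of the block scales:

```python
import numpy as np

FP4_MAX = 6.0  # largest E2M1 magnitude
FP4_GRID = np.array([0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0], dtype=np.float32)

def quantize_per_token(x, block=16):
    """Each row keeps its own amax; each `block`-wide chunk gets a decode scale."""
    M, K = x.shape
    row_amax = np.abs(x).max(axis=1)                 # the per-row amax vector
    chunks = x.reshape(M, K // block, block)
    scale = np.abs(chunks).max(axis=2) / FP4_MAX     # per-block decode scale
    scale = np.where(scale == 0.0, 1.0, scale).astype(np.float32)
    scaled = chunks / scale[..., None]
    # round each magnitude to the nearest representable FP4 value
    idx = np.abs(np.abs(scaled)[..., None] - FP4_GRID).argmin(axis=-1)
    q = np.sign(scaled) * FP4_GRID[idx]
    return q.reshape(M, K), scale, row_amax
```

Dequantizing with `q * scale` recovers each block to within half an FP4 grid step times that block's scale; the per-row amax vector is what the GEMM path later consumes as per-token scales.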
The following tests passed on B200:
Type of change
Changes
Please list the changes introduced in this PR:
- Added a per_token_activation field in the nvfp4 recipe; it can be turned on by NVTE_NVFP4_PER_TOKEN_ACTIVATION.
- Added per-token quantization kernels in transformer_engine/common/cast/nvfp4/quantize_pertoken_nvfp4.cuh, bitwise exact with the existing TE pytorch reference implementation and the per-tensor nvfp4 emulated implementation.
- Updated transformer_engine/common/cast/nvfp4/dequantize_nvfp4.cuh to correctly handle this per-token nvfp4 format.
- Updated TransformerEngine/transformer_engine/pytorch/cpp_extensions/gemm.py: if per-token nvfp4 is detected, it performs separate per-token scaling in pytorch code after the cublas gemm.
- Updated tests/cpp/operator/test_cast_nvfp4_transpose.cu to align with the pytorch reference numerics.

Checklist: