[Common/PyTorch/JAX] make offset of ClampedSwiGLU configurable #2938
hxbai wants to merge 8 commits into
Conversation
Greptile Summary
This PR makes the linear-gate offset in ClampedSwiGLU configurable (exposed as a new glu_linear_offset parameter) across the common C/CUDA library, the PyTorch ops, and the JAX bindings.
Confidence Score: 5/5. Safe to merge. All forward and backward code paths correctly propagate the new offset, the original C API is preserved as a compatibility wrapper, and the cuDNN fused path is correctly gated out when the offset differs from 1.0. No files require special attention; all seventeen changed files are consistent with each other.
Flowchart
%%{init: {'theme': 'neutral'}}%%
flowchart TD
A["Python: ClampedSwiGLU(glu_linear_offset=x)"] --> B{"Is fusion eligible?\nalpha ≈ 1.702 AND\nglu_linear_offset ≈ 1.0?"}
B -- Yes --> C["cuDNN fused kernel\nForwardGroupedMLP_CuTeGEMMSwiGLU_MXFP8\nact_func='geglu' (offset=1.0 hardcoded)"]
B -- No --> D["tex.clamped_swiglu(limit, alpha, glu_linear_offset)"]
D --> E["nvte_clamped_swiglu_v2(limit, alpha, offset, stream)"]
E --> F["ClampedSwiGLUParam{limit, alpha, glu_linear_offset}"]
F --> G1["vectorized_pointwise.h\nval2 = clamp(val2) + p.glu_linear_offset"]
F --> G2["gated_fp8.cuh\ngate_elt = clamp(gate_elt) + p.glu_linear_offset"]
F --> G3["gated_mxfp8.cuh\ngate_elt = clamp(gate_elt) + p.glu_linear_offset"]
H["nvte_clamped_swiglu (v1 compat)"] --> |"hard-codes offset=1.0f"| E
I["JAX: ClampedSwigluParams(glu_linear_offset=x)"] --> J["to_ffi_lowering_dict()"]
J --> K["XLA FFI: ClampedSwigluConfig{limit, alpha, glu_linear_offset}"]
K --> E
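For reference, here is a minimal NumPy sketch of the elementwise math the kernel paths above compute, assuming the GPT-OSS-style formulation (activated half clamped from above by limit, linear half clamped to [-limit, limit], offset added to the linear half after clamping). The function name and clamping convention are illustrative, not the library's kernel code.

```python
import numpy as np

def clamped_swiglu_ref(x_glu, x_linear, limit, alpha, glu_linear_offset=1.0):
    """Illustrative elementwise reference for ClampedSwiGLU (not the TE kernel)."""
    x_glu = np.minimum(x_glu, limit)             # activated half: clamp from above only
    x_linear = np.clip(x_linear, -limit, limit)  # linear half: clamp to [-limit, limit]
    silu = x_glu / (1.0 + np.exp(-alpha * x_glu))   # x * sigmoid(alpha * x)
    return silu * (x_linear + glu_linear_offset)    # offset applied after clamping
```

With glu_linear_offset=1.0 this reduces to the previously hard-coded behavior, which is why the fused path is only taken in that case.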
Reviews (9): Last reviewed commit: "Merge branch 'main' into swiglu_offset"
 * \param[in] glu_linear_offset  Offset added to the linear component after clamping (default 1.0).
 * \param[in] stream             CUDA stream used for the operation.
 */
nvte_clamped_swiglu and nvte_clamped_dswiglu are public symbols declared in a versioned public header. Inserting glu_linear_offset before cudaStream_t is an ABI-breaking change: any external binary or shared library compiled against the old header will silently pass the stream pointer as the offset and a garbage value as the stream, leading to undefined behavior at runtime rather than a clean compile error if called via a pre-compiled library. This should be acknowledged as a breaking change in the PR checklist, and — if this library follows semantic versioning or a compatibility guarantee — a deprecation/transition path or version bump is needed.
timmoon10
left a comment
The fused op for grouped MLP is hard-coded for GPT-OSS, so we should make sure not to fuse if glu_linear_offset != 1:
TransformerEngine/transformer_engine/pytorch/ops/_common.py
Lines 180 to 183 in df0025b
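A minimal sketch of the kind of gating being requested; the helper name and tolerance are assumptions, not the actual `_common.py` code:

```python
import math

def _clamped_swiglu_fusion_eligible(alpha: float, glu_linear_offset: float) -> bool:
    # The cuDNN fused grouped-MLP kernel hard-codes the GPT-OSS variant
    # (alpha = 1.702, glu_linear_offset = 1.0), so only fuse when both match.
    return math.isclose(alpha, 1.702, rel_tol=1e-6) and math.isclose(glu_linear_offset, 1.0)
```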
/te-ci
Signed-off-by: Hongxiao Bai <hongxiaob@nvidia.com>
Signed-off-by: Hongxiao Bai <hongxiaob@nvidia.com>
 void nvte_clamped_swiglu(const NVTETensor input, NVTETensor output, float limit, float alpha,
-                         cudaStream_t stream) {
+                         float glu_linear_offset, cudaStream_t stream) {
Can we define new APIs named nvte_clamped_swiglu_v2 and nvte_clamped_dswiglu_v2, and deprecate this API here, so that backward compatibility is not broken?
Signed-off-by: vthumbe1503 <vthumbe@nvidia.com>
Signed-off-by: Hongxiao Bai <hongxiaob@nvidia.com>
Signed-off-by: Hongxiao Bai <hongxiaob@nvidia.com>
/te-ci
/te-ci
jberchtold-nvidia
left a comment
Overall looks pretty good from the JAX side, thanks for adding the JAX changes too! Left a couple of small comments.
     ::xla::ffi::StructMember<float>("limit"),
-    ::xla::ffi::StructMember<float>("alpha"));
+    ::xla::ffi::StructMember<float>("alpha"),
+    ::xla::ffi::StructMember<float>("glu_linear_offset"));
Can we add a default value for users with HLO from a previous version? Would glu_linear_offset=1 be the same as the current behavior on main?
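For reference, glu_linear_offset=1 does reproduce the current behavior on main, so defaulting to 1.0 (on the FFI struct and/or the Python-side dataclass) would preserve it for older HLO. A minimal Python-side sketch; the defaults for limit and alpha below are illustrative assumptions, not the PR's actual values:

```python
from dataclasses import dataclass

@dataclass
class ClampedSwigluParams:
    limit: float = 7.0               # illustrative default
    alpha: float = 1.702             # illustrative default
    glu_linear_offset: float = 1.0   # 1.0 matches the previously hard-coded behavior

    def __hash__(self):
        return hash((self.limit, self.alpha, self.glu_linear_offset))
```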
         int: Hash value of the dataclass instance.
     """
-    return hash((self.limit, self.alpha))
+    return hash((self.limit, self.alpha, self.glu_linear_offset))
Can you update one of the tests here to use a non-default value of glu_linear_offset?
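For example, a parametrized case along these lines would exercise the non-default path; the test name, shapes, and inline NumPy reference are hypothetical, and the actual TE op call is omitted:

```python
import numpy as np
import pytest

@pytest.mark.parametrize("glu_linear_offset", [1.0, 0.5, 0.0])
def test_clamped_swiglu_offset(glu_linear_offset):
    limit, alpha = 7.0, 1.702
    rng = np.random.default_rng(0)
    x_glu, x_linear = rng.normal(size=(2, 128, 256)).astype(np.float32)
    # NumPy reference of the expected elementwise result.
    glu = np.minimum(x_glu, limit)
    lin = np.clip(x_linear, -limit, limit)
    expected = glu / (1.0 + np.exp(-alpha * glu)) * (lin + glu_linear_offset)
    # Call the TE op here and compare, e.g. np.testing.assert_allclose(actual, expected).
    assert expected.shape == x_glu.shape
```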
Description
The previous ClampedSwiGLU implementation follows GPT-OSS, which hard-codes the linear-gate offset to 1.0.
DeepSeek-V4 uses ClampedSwiGLU without the alpha scaling and without the offset.
This PR makes the offset of ClampedSwiGLU configurable to support DeepSeek-V4.
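As a usage sketch (the import path and constructor signature below are assumptions based on the flowchart above, and "without alpha and offset" is read here as alpha=1.0 and glu_linear_offset=0.0):

```python
# Hypothetical usage; import path and argument names are assumptions, not the exact API.
from transformer_engine.pytorch.ops import ClampedSwiGLU

gpt_oss_act = ClampedSwiGLU(limit=7.0, alpha=1.702, glu_linear_offset=1.0)  # previous behavior
deepseek_act = ClampedSwiGLU(limit=7.0, alpha=1.0, glu_linear_offset=0.0)   # DeepSeek-V4-style
```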
Fixes # (issue)
Type of change
Changes
Please list the changes introduced in this PR:
Checklist: