[Common/PyTorch/JAX] make offset of ClampedSwiGLU configurable#2938

Open
hxbai wants to merge 8 commits into NVIDIA:main from hxbai:swiglu_offset

Conversation

@hxbai
Contributor

@hxbai hxbai commented Apr 28, 2026

Description

The previous ClampedSwiGLU follows GPT-OSS and hard-codes the linear offset to 1.0.
DeepSeek-V4 uses ClampedSwiGLU without the alpha scaling and without the offset.
This PR makes the offset of ClampedSwiGLU configurable to support DeepSeek-V4.
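
For reference, a scalar sketch of the activation with the offset written out. This is illustrative only: the clamp bounds follow the GPT-OSS convention, and the actual kernels are vectorized CUDA.

```cpp
#include <algorithm>
#include <cmath>

// Scalar reference for ClampedSwiGLU with a configurable linear offset.
// glu_linear_offset = 1.0f reproduces the GPT-OSS behavior that was hard-coded before;
// per the description above, DeepSeek-V4 uses the variant without the alpha scaling
// and without the offset.
float clamped_swiglu_ref(float gate, float linear, float limit, float alpha,
                         float glu_linear_offset) {
  float g = std::min(gate, limit);                      // clamp the activated branch from above
  float l = std::min(std::max(linear, -limit), limit);  // clamp the linear branch to [-limit, limit]
  float act = g / (1.0f + std::exp(-alpha * g));        // g * sigmoid(alpha * g)
  return act * (l + glu_linear_offset);                 // offset used to be the literal 1.0
}
```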

Fixes # (issue)

Type of change

  • Documentation change (change only to the documentation, either a fix or a new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

Please list the changes introduced in this PR:

  • Change A
  • Change B

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

@greptile-apps
Contributor

greptile-apps Bot commented Apr 28, 2026

Greptile Summary

This PR makes the linear-gate offset in ClampedSwiGLU configurable (previously hardcoded to 1.0) to support DeepSeek-V4's variant, which uses an offset of 0.0. The implementation introduces a new _v2 C API pair (nvte_clamped_swiglu_v2 / nvte_clamped_dswiglu_v2) with an explicit glu_linear_offset parameter while keeping the original API as a compatibility wrapper.

  • New glu_linear_offset field added to ClampedSwiGLUParam struct (default 1.0f); all affected CUDA kernels in vectorized_pointwise.h, gated_fp8.cuh, and gated_mxfp8.cuh updated to use the struct field instead of the hardcoded constant.
  • PyTorch (clamped_swiglu/clamped_dswiglu extensions) and JAX (XLA FFI struct decoding and ClampedSwigluParams) bindings consistently updated with the new parameter, threaded from public Python API through C++ to CUDA.
  • The cuDNN-fused GroupedLinear + ScaledClampedQGeGLU + GroupedLinear path in _common.py is correctly gated out when glu_linear_offset != 1.0, since the cuDNN grouped_gemm_glu_wrapper_sm100 only supports the hardcoded-offset variant.
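
Paraphrasing the resulting public C API surface (declarations reconstructed from the diff excerpts quoted later in this thread; the exact header text may differ):

```cpp
// Original entry point, kept as a compatibility wrapper that forwards glu_linear_offset = 1.0f.
void nvte_clamped_swiglu(const NVTETensor input, NVTETensor output, float limit, float alpha,
                         cudaStream_t stream);

// New entry point with the offset exposed; nvte_clamped_dswiglu_v2 follows the same pattern
// for the backward pass.
void nvte_clamped_swiglu_v2(const NVTETensor input, NVTETensor output, float limit, float alpha,
                            float glu_linear_offset, cudaStream_t stream);
```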

Confidence Score: 5/5

Safe to merge. All forward and backward code paths correctly propagate the new offset, the original C API is preserved as a compatibility wrapper, and the cuDNN fused path is correctly gated out when the offset differs from 1.0.

The new glu_linear_offset field is threaded consistently through every affected CUDA kernel, C++ extension, and Python binding in both the PyTorch and JAX stacks. The mathematical correctness of the backward pass is preserved — the derivative of clamp(x) + offset with respect to x is identical to the derivative of clamp(x), so no changes to the boolean gate-derivative were needed. The old C API symbols are untouched, and the struct field defaults to 1.0f, so all existing callers produce identical results without recompilation.
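
In symbols (a quick sanity check, not taken from the diff): writing the forward output as the activated branch times the offset linear branch,

$$\mathrm{out} = \mathrm{act}(g)\,\bigl(\mathrm{clamp}(l) + c\bigr), \qquad \frac{\partial\,\mathrm{out}}{\partial l} = \mathrm{act}(g)\,\mathrm{clamp}'(l), \qquad \frac{\partial\,\mathrm{out}}{\partial g} = \mathrm{act}'(g)\,\bigl(\mathrm{clamp}(l) + c\bigr)$$

The 0/1 mask clamp'(l) is independent of the offset c, so the gate-derivative logic is untouched, while c still appears as a factor in the other branch's gradient, which is why both the forward and backward kernels read p.glu_linear_offset.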

No files require special attention. All seventeen changed files are consistent with each other.

Important Files Changed

Filename | Overview
transformer_engine/common/util/math.h | Adds glu_linear_offset field (default 1.0f) to ClampedSwiGLUParam struct; backward-compatible struct change.
transformer_engine/common/activation/swiglu.cu | Preserves original nvte_clamped_swiglu/nvte_clamped_dswiglu as wrappers with hardcoded offset=1.0; adds new _v2 variants that accept explicit glu_linear_offset. Clean backward-compatible API design.
transformer_engine/common/util/vectorized_pointwise.h | Replaces hardcoded + 1 and + 1.0f with + p.glu_linear_offset in both forward and backward gated-activation kernels. Derivative of the gate (clamp boolean) is correctly unchanged.
transformer_engine/common/cast/fp8/gated_fp8.cuh | Replaces hardcoded + 1 with + p.glu_linear_offset in the FP8 gated kernel; backward derivative unchanged (offset is a constant w.r.t. the gate input).
transformer_engine/common/cast/mxfp8/gated_mxfp8.cuh | Both MXFP8 gated kernel instances updated with p.glu_linear_offset; same correct treatment of derivative as gated_fp8.cuh.
transformer_engine/pytorch/ops/_common.py | Adds glu_linear_offset guard to fuse_grouped_mlp_ops, correctly preventing cuDNN fused path when offset is not 1.0 since that kernel hardcodes the +1 offset.
transformer_engine/pytorch/ops/basic/swiglu.py | Adds glu_linear_offset to ClampedSwiGLU and ScaledClampedQGeGLU; consistently threads it through _tex_clamped_swiglu_forward and _tex_clamped_dswiglu calls.
transformer_engine/jax/cpp_extensions/activation.py | Adds glu_linear_offset to ClampedSwigluParams, updates __hash__ and to_ffi_lowering_dict; updates _convert_to_activation_function closure to use the offset.
transformer_engine/jax/csrc/extensions.h | Adds glu_linear_offset to ClampedSwigluConfig struct and XLA FFI struct member registration.
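
For orientation, the shared parameter struct described above would look roughly like this (field names taken from the review; the actual definition in math.h may differ):

```cpp
// Parameters threaded from the C API into the CUDA gated-activation kernels.
struct ClampedSwiGLUParam {
  float limit;
  float alpha;
  float glu_linear_offset = 1.0f;  // new field; the default keeps existing callers bit-identical
};
```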

Flowchart

%%{init: {'theme': 'neutral'}}%%
flowchart TD
    A["Python: ClampedSwiGLU(glu_linear_offset=x)"] --> B{"Is fusion eligible?\nalpha ≈ 1.702 AND\nglu_linear_offset ≈ 1.0?"}
    B -- Yes --> C["cuDNN fused kernel\nForwardGroupedMLP_CuTeGEMMSwiGLU_MXFP8\nact_func='geglu' (offset=1.0 hardcoded)"]
    B -- No --> D["tex.clamped_swiglu(limit, alpha, glu_linear_offset)"]
    D --> E["nvte_clamped_swiglu_v2(limit, alpha, offset, stream)"]
    E --> F["ClampedSwiGLUParam{limit, alpha, glu_linear_offset}"]
    F --> G1["vectorized_pointwise.h\nval2 = clamp(val2) + p.glu_linear_offset"]
    F --> G2["gated_fp8.cuh\ngate_elt = clamp(gate_elt) + p.glu_linear_offset"]
    F --> G3["gated_mxfp8.cuh\ngate_elt = clamp(gate_elt) + p.glu_linear_offset"]
    H["nvte_clamped_swiglu (v1 compat)"] --> |"hard-codes offset=1.0f"| E
    I["JAX: ClampedSwigluParams(glu_linear_offset=x)"] --> J["to_ffi_lowering_dict()"]
    J --> K["XLA FFI: ClampedSwigluConfig{limit, alpha, glu_linear_offset}"]
    K --> E

Reviews (9): Last reviewed commit: "Merge branch 'main' into swiglu_offset"

Comment on lines 339 to 341
* \param[in] glu_linear_offset Offset added to the linear component after clamping (default 1.0).
* \param[in] stream CUDA stream used for the operation.
*/
Contributor

P1 Breaking public C API change

nvte_clamped_swiglu and nvte_clamped_dswiglu are public symbols declared in a versioned public header. Inserting glu_linear_offset before cudaStream_t is an ABI-breaking change: an external binary or shared library compiled against the old header will silently pass the stream pointer as the offset and a garbage value as the stream, producing undefined behavior at runtime instead of a clean compile error. This should be acknowledged as a breaking change in the PR checklist, and, if this library follows semantic versioning or a compatibility guarantee, a deprecation/transition path or version bump is needed.

Collaborator

@timmoon10 timmoon10 left a comment

The fused op for grouped MLP is hard-coded for GPT-OSS, so we should make sure not to fuse if glu_linear_offset != 1:

elif isinstance(window[1], ScaledClampedQGeGLU) and (
    abs(window[1]._clamped.alpha - 1.702) > 0.001
    or not _nvidia_cudnn_frontend_supports_scaled_clamped_qgeglu()
):

@timmoon10
Collaborator

/te-ci

Signed-off-by: Hongxiao Bai <hongxiaob@nvidia.com>
@hxbai hxbai marked this pull request as draft April 29, 2026 00:28
Signed-off-by: Hongxiao Bai <hongxiaob@nvidia.com>
@hxbai hxbai marked this pull request as ready for review April 29, 2026 01:01

  void nvte_clamped_swiglu(const NVTETensor input, NVTETensor output, float limit, float alpha,
-                          cudaStream_t stream) {
+                          float glu_linear_offset, cudaStream_t stream) {
Collaborator

Can we define new APIs named nvte_clamped_swiglu_v2 and nvte_clamped_dswiglu_v2
and deprecate this API here to not break backward compatibility?

Contributor Author

Rewrote this part.
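
A sketch of the wrapper arrangement the review summary describes (illustrative only, not the actual source; the include path is assumed):

```cpp
#include <transformer_engine/activation.h>  // assumed location of the public declarations

// v1 entry point preserved for backward compatibility: forwards the historical
// GPT-OSS offset of 1.0f to the new v2 implementation.
void nvte_clamped_swiglu(const NVTETensor input, NVTETensor output, float limit, float alpha,
                         cudaStream_t stream) {
  nvte_clamped_swiglu_v2(input, output, limit, alpha, /*glu_linear_offset=*/1.0f, stream);
}
```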

vthumbe1503 and others added 3 commits May 6, 2026 11:38
Signed-off-by: vthumbe1503 <vthumbe@nvidia.com>
Signed-off-by: Hongxiao Bai <hongxiaob@nvidia.com>
Signed-off-by: Hongxiao Bai <hongxiaob@nvidia.com>
@vthumbe1503
Collaborator

/te-ci

hxbai added 2 commits May 12, 2026 15:13
Signed-off-by: Hongxiao Bai <hongxiaob@nvidia.com>
@vthumbe1503
Collaborator

/te-ci

Collaborator

@jberchtold-nvidia jberchtold-nvidia left a comment

Overall looks pretty good from the JAX side, thanks for adding the JAX changes too! Left a couple small comments

  ::xla::ffi::StructMember<float>("limit"),
- ::xla::ffi::StructMember<float>("alpha"));
+ ::xla::ffi::StructMember<float>("alpha"),
+ ::xla::ffi::StructMember<float>("glu_linear_offset"));
Collaborator

Can we add a default value for users with HLO from a previous version? Would glu_linear_offset=1 be the same as the current behavior on main?

      int: Hash value of the dataclass instance.
  """
- return hash((self.limit, self.alpha))
+ return hash((self.limit, self.alpha, self.glu_linear_offset))
Collaborator

Can you update one of the tests here to use a non-default value of glu_linear_offset?
