Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 18 additions & 6 deletions pyhealth/interpret/methods/integrated_gradients.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,7 +166,7 @@ class IntegratedGradients(BaseInterpreter):
... )
"""

def __init__(self, model: BaseModel, use_embeddings: bool = True, steps: int = 50):
def __init__(self, model: BaseModel, use_embeddings: bool = True, steps: int = 50, decision_threshold: float = 0.5):
"""Initialize IntegratedGradients interpreter.

Args:
Expand All @@ -181,6 +181,8 @@ def __init__(self, model: BaseModel, use_embeddings: bool = True, steps: int = 5
approximation of the path integral. Default is 50.
Can be overridden in attribute() calls. More steps lead to
better approximation but slower computation.
decision_threshold: Decision threshold used when inferring the default
target for binary and multilabel prediction. Default is 0.5.

Raises:
AssertionError: If use_embeddings=True but model does not
Expand All @@ -193,6 +195,7 @@ def __init__(self, model: BaseModel, use_embeddings: bool = True, steps: int = 5

self.use_embeddings = use_embeddings
self.steps = steps
self.decision_threshold = decision_threshold

Comment on lines +198 to 199
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I dont see this field used anywhere.


def attribute(
Expand All @@ -217,11 +220,17 @@ def attribute(
the integral. If None, uses self.steps (set during
initialization). More steps lead to better approximation but
slower computation.
target_class_idx: Target class index for attribution.
For binary classification (single logit output), this is
a no-op because there is only one output. For multi-class
or multi-label, specifies which class to explain. If None,
uses the argmax of model output.
target_class_idx: Target used for attribution computation.
Default behavior depends on prediction mode:
- binary: uses (sigmoid(logit) > decision_threshold)
- multiclass: uses argmax(logits)
- multilabel: uses (sigmoid(logits) > decision_threshold)
Notes:
- In binary mode, target_class_idx effectively behaves like a target
label (0 or 1), not a class-axis index.
- In multilabel mode, if target_class_idx is None, the default target is
a multi-hot mask of all predicted-positive labels, and attribution is
computed for the sum of those selected logits.
**kwargs: Input data dictionary from a dataloader batch
containing:
- Feature keys (e.g., 'conditions', 'procedures'):
Expand Down Expand Up @@ -534,6 +543,9 @@ def _compute_target_output(
logits: Model output logits, shape [batch, num_classes].
target_indices: [batch] tensor of target class indices.

In multilabel mode, a multi-hot target corresponds to the sum of
the selected logits.

Returns:
Scalar tensor for backpropagation.
"""
Expand Down
Loading