diff --git a/pyhealth/interpret/methods/integrated_gradients.py b/pyhealth/interpret/methods/integrated_gradients.py index 249f5a5e9..1f56eb9c1 100644 --- a/pyhealth/interpret/methods/integrated_gradients.py +++ b/pyhealth/interpret/methods/integrated_gradients.py @@ -166,7 +166,7 @@ class IntegratedGradients(BaseInterpreter): ... ) """ - def __init__(self, model: BaseModel, use_embeddings: bool = True, steps: int = 50): + def __init__(self, model: BaseModel, use_embeddings: bool = True, steps: int = 50, decision_threshold: float = 0.5): """Initialize IntegratedGradients interpreter. Args: @@ -181,6 +181,8 @@ def __init__(self, model: BaseModel, use_embeddings: bool = True, steps: int = 5 approximation of the path integral. Default is 50. Can be overridden in attribute() calls. More steps lead to better approximation but slower computation. + decision_threshold: Decision threshold used when inferring the default + target for binary and multilabel prediction. Default is 0.5. Raises: AssertionError: If use_embeddings=True but model does not @@ -193,6 +195,7 @@ def __init__(self, model: BaseModel, use_embeddings: bool = True, steps: int = 5 self.use_embeddings = use_embeddings self.steps = steps + self.decision_threshold = decision_threshold def attribute( @@ -217,11 +220,17 @@ def attribute( the integral. If None, uses self.steps (set during initialization). More steps lead to better approximation but slower computation. - target_class_idx: Target class index for attribution. - For binary classification (single logit output), this is - a no-op because there is only one output. For multi-class - or multi-label, specifies which class to explain. If None, - uses the argmax of model output. + target_class_idx: Target used for attribution computation. + Default behavior depends on prediction mode: + - binary: uses (sigmoid(logit) > decision_threshold) + - multiclass: uses argmax(logits) + - multilabel: uses (sigmoid(logits) > decision_threshold) + Notes: + - In binary mode, target_class_idx effectively behaves like a target + label (0 or 1), not a class-axis index. + - In multilabel mode, if target_class_idx is None, the default target is + a multi-hot mask of all predicted-positive labels, and attribution is + computed for the sum of those selected logits. **kwargs: Input data dictionary from a dataloader batch containing: - Feature keys (e.g., 'conditions', 'procedures'): @@ -534,6 +543,9 @@ def _compute_target_output( logits: Model output logits, shape [batch, num_classes]. target_indices: [batch] tensor of target class indices. + In multilabel mode, a multi-hot target corresponds to the sum of + the selected logits. + Returns: Scalar tensor for backpropagation. """