diff --git a/docs/notebooks/structural_reliability.ipynb b/docs/notebooks/structural_reliability.ipynb
index 1f0ec14..2a4b4af 100644
--- a/docs/notebooks/structural_reliability.ipynb
+++ b/docs/notebooks/structural_reliability.ipynb
@@ -12,32 +12,35 @@
},
{
"cell_type": "code",
+ "execution_count": 10,
"id": "cedd5ec9-31f7-4e7f-91be-f73b1d1d00f1",
"metadata": {
- "scrolled": true,
"ExecuteTime": {
"end_time": "2024-05-30T09:06:58.579066Z",
"start_time": "2024-05-30T09:06:58.552735Z"
- }
+ },
+ "scrolled": true
},
+ "outputs": [],
"source": [
"import pathlib\n",
"\n",
"import matplotlib.pyplot as plt\n",
"import pandas as pd\n",
"import simdec as sd"
- ],
- "outputs": [],
- "execution_count": 1
+ ]
},
{
- "metadata": {},
"cell_type": "markdown",
- "source": "Let's first load the dataset. It's a CSV file, each row represent a simulation or sample. The first column is the output or quantity of interest and other columns are parameters' values.",
- "id": "8700ed278bb1c06d"
+ "id": "8700ed278bb1c06d",
+ "metadata": {},
+ "source": [
+ "Let's first load the dataset. It's a CSV file, each row represent a simulation or sample. The first column is the output or quantity of interest and other columns are parameters' values."
+ ]
},
{
"cell_type": "code",
+ "execution_count": 11,
"id": "0b21846d-edff-4e39-a423-b247f81c4520",
"metadata": {
"ExecuteTime": {
@@ -45,25 +48,9 @@
"start_time": "2024-05-30T09:06:58.579870Z"
}
},
- "source": [
- "fname = pathlib.Path(\"../../tests/data/stress.csv\")\n",
- "\n",
- "data = pd.read_csv(fname)\n",
- "output_name, *inputs_names = list(data.columns)\n",
- "inputs, output = data[inputs_names], data[output_name]\n",
- "inputs.head()"
- ],
"outputs": [
{
"data": {
- "text/plain": [
- " Kf sigma_res Rp0.2 R\n",
- "0 2.454866 -84.530638 297.406169 -0.834480\n",
- "1 2.774116 347.586947 379.499452 -0.131827\n",
- "2 2.504617 946.567040 940.477667 -0.039126\n",
- "3 2.466723 74.222224 406.622486 0.440311\n",
- "4 2.615602 -32.937734 979.498038 0.419690"
- ],
"text/html": [
"
\n",
"\n",
- "
\n",
+ "\n",
" \n",
" \n",
" | | \n",
" | \n",
- " N° | \n",
- " colour | \n",
- " std | \n",
- " min | \n",
- " mean | \n",
- " max | \n",
- " probability | \n",
+ " N° | \n",
+ " colour | \n",
+ " std | \n",
+ " min | \n",
+ " mean | \n",
+ " max | \n",
+ " probability | \n",
"
\n",
" \n",
" | sigma_res | \n",
@@ -478,119 +482,133 @@
"
\n",
" \n",
" \n",
- " | low | \n",
- " low | \n",
- " 9 | \n",
- " | \n",
- " 95.34 | \n",
- " 11.19 | \n",
- " 282.08 | \n",
- " 460.07 | \n",
- " 0.19 | \n",
+ " low | \n",
+ " low | \n",
+ " 9 | \n",
+ " | \n",
+ " 84.73 | \n",
+ " 11.74 | \n",
+ " 226.72 | \n",
+ " 397.62 | \n",
+ " 0.11 | \n",
"
\n",
" \n",
- " | medium | \n",
- " 8 | \n",
- " | \n",
- " 87.76 | \n",
- " 67.53 | \n",
- " 407.79 | \n",
- " 622.35 | \n",
- " 0.12 | \n",
+ " medium | \n",
+ " 8 | \n",
+ " | \n",
+ " 82.65 | \n",
+ " 11.19 | \n",
+ " 385.26 | \n",
+ " 619.78 | \n",
+ " 0.11 | \n",
"
\n",
" \n",
- " | high | \n",
- " 7 | \n",
- " | \n",
- " 108.03 | \n",
- " 237.13 | \n",
- " 541.32 | \n",
- " 819.41 | \n",
- " 0.26 | \n",
+ " high | \n",
+ " 7 | \n",
+ " | \n",
+ " 101.23 | \n",
+ " 384.75 | \n",
+ " 567.44 | \n",
+ " 817.84 | \n",
+ " 0.11 | \n",
"
\n",
" \n",
- " | medium | \n",
- " low | \n",
- " 6 | \n",
- " | \n",
- " 34.92 | \n",
- " 350.30 | \n",
- " 434.90 | \n",
- " 523.84 | \n",
- " 0.09 | \n",
+ " medium | \n",
+ " low | \n",
+ " 6 | \n",
+ " | \n",
+ " 43.37 | \n",
+ " 268.77 | \n",
+ " 376.10 | \n",
+ " 515.25 | \n",
+ " 0.11 | \n",
"
\n",
" \n",
- " | medium | \n",
- " 5 | \n",
- " | \n",
- " 44.39 | \n",
- " 398.42 | \n",
- " 485.72 | \n",
- " 650.98 | \n",
- " 0.06 | \n",
+ " medium | \n",
+ " 5 | \n",
+ " | \n",
+ " 63.56 | \n",
+ " 318.15 | \n",
+ " 485.57 | \n",
+ " 711.72 | \n",
+ " 0.11 | \n",
"
\n",
" \n",
- " | high | \n",
- " 4 | \n",
- " | \n",
- " 75.80 | \n",
- " 414.21 | \n",
- " 534.19 | \n",
- " 814.43 | \n",
- " 0.11 | \n",
+ " high | \n",
+ " 4 | \n",
+ " | \n",
+ " 106.22 | \n",
+ " 420.61 | \n",
+ " 580.03 | \n",
+ " 819.41 | \n",
+ " 0.11 | \n",
"
\n",
" \n",
- " | high | \n",
- " low | \n",
- " 3 | \n",
- " | \n",
- " 35.43 | \n",
- " 630.24 | \n",
- " 703.90 | \n",
- " 794.81 | \n",
- " 0.06 | \n",
+ " high | \n",
+ " low | \n",
+ " 3 | \n",
+ " | \n",
+ " 132.84 | \n",
+ " 383.11 | \n",
+ " 576.80 | \n",
+ " 794.81 | \n",
+ " 0.11 | \n",
"
\n",
" \n",
- " | medium | \n",
- " 2 | \n",
- " | \n",
- " 33.95 | \n",
- " 656.34 | \n",
- " 725.15 | \n",
- " 816.48 | \n",
- " 0.04 | \n",
+ " medium | \n",
+ " 2 | \n",
+ " | \n",
+ " 129.07 | \n",
+ " 410.77 | \n",
+ " 611.25 | \n",
+ " 824.87 | \n",
+ " 0.11 | \n",
"
\n",
" \n",
- " | high | \n",
- " 1 | \n",
- " | \n",
- " 39.51 | \n",
- " 668.50 | \n",
- " 755.65 | \n",
- " 851.00 | \n",
- " 0.08 | \n",
+ " high | \n",
+ " 1 | \n",
+ " | \n",
+ " 127.28 | \n",
+ " 448.85 | \n",
+ " 643.91 | \n",
+ " 851.00 | \n",
+ " 0.11 | \n",
"
\n",
" \n",
"
\n"
+ ],
+ "text/plain": [
+ ""
]
},
- "execution_count": 9,
+ "execution_count": 18,
"metadata": {},
"output_type": "execute_result"
}
],
- "execution_count": 9
+ "source": [
+ "table, styler = sd.tableau(\n",
+ " statistic=res.statistic,\n",
+ " var_names=res.var_names,\n",
+ " states=res.states,\n",
+ " bins=res.bins,\n",
+ " palette=palette,\n",
+ ")\n",
+ "styler"
+ ]
},
{
- "metadata": {},
"cell_type": "markdown",
- "source": "Congratulations, now you know how to use SimDec to get more insights on your problem!",
- "id": "3cdf58c4bd3dbbca"
+ "id": "3cdf58c4bd3dbbca",
+ "metadata": {},
+ "source": [
+ "Congratulations, now you know how to use SimDec to get more insights on your problem!"
+ ]
}
],
"metadata": {
"kernelspec": {
- "display_name": "Python 3 (ipykernel)",
+ "display_name": "Python 3",
"language": "python",
"name": "python3"
},
@@ -604,7 +622,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
- "version": "3.12.1"
+ "version": "3.11.14"
}
},
"nbformat": 4,
diff --git a/panel/index.html b/panel/index.html
index 52dda01..9cc0ffc 100644
--- a/panel/index.html
+++ b/panel/index.html
@@ -203,34 +203,36 @@
-
+
+
-
-
+
-
+
-
-
+
-
+
-
+
diff --git a/panel/simdec_app.py b/panel/simdec_app.py
index cc98c62..93b7a12 100644
--- a/panel/simdec_app.py
+++ b/panel/simdec_app.py
@@ -170,8 +170,11 @@ def filtered_si(sensitivity_indices_table, input_names):
def explained_variance_80(sensitivity_indices_table):
- si = sensitivity_indices_table.value["Indices"]
- pos_80 = bisect.bisect_right(np.cumsum(si), 0.8)
+ df = sensitivity_indices_table.value
+ df = df[df["Inputs"] != "Sum of Indices"]
+ si = df["Indices"].values
+ target = 0.8 * np.sum(si)
+ pos_80 = bisect.bisect_right(np.cumsum(si), target)
# pos_80 = max(2, pos_80)
# pos_80 = min(len(si), pos_80)
diff --git a/pyproject.toml b/pyproject.toml
index ae759fe..34473ab 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -41,6 +41,10 @@ dashboard = [
"cryptography",
]
+ipython = [
+ "ipython"
+]
+
test = [
"pytest",
"pytest-cov",
@@ -55,7 +59,7 @@ doc = [
]
dev = [
- "simdec[doc,test,dashboard]",
+ "simdec[doc,test,dashboard, ipython]",
"watchfiles",
"pre-commit",
]
diff --git a/src/simdec/__init__.py b/src/simdec/__init__.py
index 9394a8f..c91c1dc 100644
--- a/src/simdec/__init__.py
+++ b/src/simdec/__init__.py
@@ -2,6 +2,7 @@
from simdec.decomposition import *
from simdec.sensitivity_indices import *
from simdec.visualization import *
+from simdec.heterogeneity_indices import *
__all__ = [
"sensitivity_indices",
@@ -11,4 +12,5 @@
"two_output_visualization",
"tableau",
"palette",
+ "heterogeneity_indices",
]
diff --git a/src/simdec/heterogeneity_indices.py b/src/simdec/heterogeneity_indices.py
new file mode 100644
index 0000000..99cec27
--- /dev/null
+++ b/src/simdec/heterogeneity_indices.py
@@ -0,0 +1,239 @@
+from dataclasses import dataclass
+import logging
+
+import matplotlib.pyplot as plt
+import numpy as np
+import pandas as pd
+
+import simdec as sd
+
+logger = logging.getLogger(__name__)
+
+__all__ = ["heterogeneity_indices", "plot_heterogeneity"]
+
+
+@dataclass
+class HeterogeneityResult:
+ summary: pd.DataFrame
+ regional_profiles: pd.DataFrame
+ split_name: str
+
+
+def heterogeneity_indices(
+ output: pd.Series,
+ inputs: pd.DataFrame,
+ split_variable: str | pd.Series,
+ n_subdivisions: int | None = None,
+ plot: bool = False,
+) -> HeterogeneityResult:
+ """Heterogeneity indices.
+
+ Compute sensitivity-based heterogeneity across subdivisions
+ of a variable.
+
+ Parameters
+ ----------
+ output : pd.Series
+ Model output vector.
+ inputs : pd.DataFrame
+ Input/feature matrix.
+ split_variable : str or pd.Series
+ Variable to split on. If string, must be a column in 'inputs'.
+ n_subdivisions : int, optional
+ Number of regions for continuous variables. Defaults to 4.
+ plot : bool, default False
+ If True, displays a stacked bar chart of regional sensitivity profiles
+ by calling :func:`plot_heterogeneity`. The chart shows variance
+ contributions of each input across subdivisions of ``split_variable``,
+ ranked by global sensitivity indices. To capture the returned
+ ``matplotlib.axes.Axes`` object, call :func:`plot_heterogeneity`
+ directly on the result instead.
+
+ Returns
+ -------
+ res : HeterogeneityResult
+ An object with attributes:
+
+ summary : DataFrame
+ A summary of calculated heterogeneity indices.
+ regional_profiles : DataFrame
+ Regional sensitivity indices for each input across subdivisions.
+ split_name : str
+ The name of the variable used to split the data.
+
+ """
+ y = pd.Series(output).reset_index(drop=True)
+ X = pd.DataFrame(inputs).reset_index(drop=True)
+
+ if isinstance(split_variable, str):
+ if split_variable not in X.columns:
+ raise ValueError(f"'{split_variable}' not found in inputs.")
+ z = X[split_variable].reset_index(drop=True)
+ split_name = split_variable
+ else:
+ z = pd.Series(split_variable).reset_index(drop=True)
+ split_name = getattr(split_variable, "name", "split_variable")
+
+ unique_vals = z.dropna().unique()
+ n_unique = len(unique_vals)
+
+ # Determine if variable is categorical/binary
+ is_categorical = (
+ isinstance(z.dtype, pd.CategoricalDtype)
+ or pd.api.types.is_object_dtype(z)
+ or pd.api.types.is_string_dtype(z)
+ or pd.api.types.is_bool_dtype(z)
+ or n_unique <= 2
+ )
+
+ if is_categorical:
+ regions = z.astype("category")
+ else:
+ q = n_subdivisions if n_subdivisions is not None else 4
+ try:
+ regions = pd.qcut(z, q=q, duplicates="drop")
+ except ValueError as e:
+ raise ValueError(
+ f"Failed to bin '{split_name}' into {q} quantiles: {e}"
+ ) from e
+
+ regional_profiles = []
+ skipped = []
+
+ for region in regions.cat.categories:
+ mask = regions == region
+ n_in_region = mask.sum()
+
+ if n_in_region < 10:
+ # Need enough samples for meaningful sensitivity indices
+ skipped.append((region, n_in_region, "too few samples (< 10)"))
+ continue
+
+ X_sub = X.loc[mask]
+ y_sub = y.loc[mask]
+
+ # Skip if output has zero or near-zero variance in this region
+ if y_sub.var() < 1e-12:
+ skipped.append((region, n_in_region, "output variance ≈ 0"))
+ continue
+
+ try:
+ res = sd.sensitivity_indices(inputs=X_sub, output=y_sub)
+ si_vals = np.asarray(res.si).ravel()
+
+ # Guard against NaN/Inf from degenerate sensitivity computation
+ if not np.all(np.isfinite(si_vals)):
+ skipped.append((region, n_in_region, "non-finite SI values"))
+ continue
+
+ si_region = pd.Series(si_vals, index=X.columns, name=region)
+ regional_profiles.append(si_region)
+
+ except Exception as e:
+ skipped.append((region, n_in_region, f"exception: {e}"))
+ continue
+
+ if skipped:
+ logger.info("Skipped %d region(s) of '%s':", len(skipped), split_name)
+ for reg, n, reason in skipped:
+ logger.info(" - region=%r, n=%d, reason=%s", reg, n, reason)
+
+ if len(regional_profiles) < 2:
+ total_regions = len(regions.cat.categories)
+ valid = len(regional_profiles)
+        skipped_report = "\n".join(f"  {r!r}: n={n}, {reason}" for r, n, reason in skipped)
+        raise ValueError(
+            f"Not enough valid subdivisions to compute heterogeneity: "
+            f"{valid}/{total_regions} regions passed all checks for '{split_name}'.\n"
+            f"Skipped regions:\n{skipped_report}"
+            "\n\nTry: (1) reducing n_subdivisions, "
+            "(2) using a different split_variable, or "
+            "(3) ensuring more samples per region."
+        )
+
+ regional_si = pd.concat(regional_profiles, axis=1)
+
+ res_global = sd.sensitivity_indices(inputs=X, output=y)
+ overall_si = pd.Series(
+ np.asarray(res_global.si).ravel(),
+ index=X.columns,
+ name="Overall_SI",
+ )
+
+ # Heterogeneity = 2 × population std dev across regions
+ hetero_scores = 2 * regional_si.std(axis=1, ddof=0)
+ total_hetero = hetero_scores.mean()
+
+ hetero_col_name = f"Heterogeneity (across {split_name})"
+ summary = pd.DataFrame(
+ {"Overall_SI": overall_si, hetero_col_name: hetero_scores}
+ ).sort_values(by=hetero_col_name, ascending=False)
+ summary.loc["SUM / TOTAL"] = [overall_si.sum(), total_hetero]
+
+ result = HeterogeneityResult(summary, regional_si, split_name)
+
+ if plot:
+ plot_heterogeneity(result)
+
+ return result
+
+
+def plot_heterogeneity(result: HeterogeneityResult, ax: plt.Axes | None = None) -> plt.Axes:
+ """Plot regional sensitivity profiles.
+
+ Parameters
+ ----------
+ result : HeterogeneityResult
+ The result object from heterogeneity_indices.
+ ax : matplotlib.axes.Axes, optional
+ Existing axes to plot on.
+
+ Returns
+ -------
+ ax : matplotlib.axes.Axes
+ The axes with the plot.
+
+ """
+ summary = result.summary
+ regional_si = result.regional_profiles
+ split_name = result.split_name
+
+ plot_order = summary.index[summary.index != "SUM / TOTAL"]
+ plot_order = (
+ summary.loc[plot_order].sort_values(by="Overall_SI", ascending=False).index
+ )
+
+ cmap = plt.colormaps["terrain"]
+ colors = [cmap(i) for i in np.linspace(0.05, 0.95, len(regional_si.index))]
+
+ data_to_plot = regional_si.loc[plot_order].T
+
+ if ax is None:
+ _, ax = plt.subplots(figsize=(10, 6))
+
+ data_to_plot.plot(
+ kind="bar",
+ stacked=True,
+ ax=ax,
+ color=colors,
+ edgecolor="white",
+ width=0.8,
+ )
+
+ ax.set_title(f"Sensitivity Profiles across {split_name}", fontsize=14)
+ ax.set_ylabel("Variance Contribution", fontsize=12)
+ ax.set_xlabel(f"Regions of {split_name}", fontsize=12)
+
+ ax.legend(
+ title="Inputs (Ranked by Global SI)",
+ bbox_to_anchor=(1.05, 1),
+ loc="upper left",
+ )
+
+ ax.tick_params(axis="x", labelrotation=45)
+ ax.grid(axis="y", linestyle="--", alpha=0.7)
+
+ if plt.get_backend().lower() != "agg":
+ plt.tight_layout()
+
+ return ax
diff --git a/src/simdec/sensitivity_indices.py b/src/simdec/sensitivity_indices.py
index ed9a07b..6c63be6 100644
--- a/src/simdec/sensitivity_indices.py
+++ b/src/simdec/sensitivity_indices.py
@@ -37,7 +37,9 @@ class SensitivityAnalysisResult:
def sensitivity_indices(
- inputs: pd.DataFrame | np.ndarray, output: pd.DataFrame | np.ndarray
+ inputs: pd.DataFrame | np.ndarray,
+ output: pd.DataFrame | np.ndarray,
+ print_indices: bool = False,
) -> SensitivityAnalysisResult:
"""Sensitivity indices.
@@ -50,6 +52,8 @@ def sensitivity_indices(
Input variables.
output : ndarray or DataFrame of shape (n_runs, 1)
Target variable.
+ print_indices : bool, default False
+ If True, displays computed indices.
Returns
-------
@@ -97,11 +101,18 @@ def sensitivity_indices(
"""
# Handle inputs conversion
if isinstance(inputs, pd.DataFrame):
- cat_columns = inputs.select_dtypes(["category", "O"]).columns
- inputs[cat_columns] = inputs[cat_columns].apply(
- lambda x: x.astype("category").cat.codes
- )
+ var_names = inputs.columns.tolist()
+ cat_cols = inputs.select_dtypes(include=["category", "O", "string"]).columns
+ if not cat_cols.empty:
+ inputs = inputs.copy() # Avoid SettingWithCopyWarning
+ inputs[cat_cols] = inputs[cat_cols].apply(
+ lambda x: x.astype("category").cat.codes
+ )
inputs = inputs.to_numpy()
+ else:
+ inputs = np.asarray(inputs)
+ # Fallback names if it's just a numpy array
+ var_names = [f"x{i}" for i in range(inputs.shape[1])]
# Handle output conversion first, then flatten
if isinstance(output, (pd.DataFrame, pd.Series)):
@@ -181,4 +192,12 @@ def sensitivity_indices(
for k in range(n_factors):
si[k] = foe[k] + (soe[:, k].sum() / 2)
+ if print_indices:
+ df_foe = pd.DataFrame(foe, index=var_names, columns=["First-order effect"])
+ df_soe = pd.DataFrame(soe, index=var_names, columns=var_names)
+ df_si = pd.DataFrame(si, index=var_names, columns=["Combined effect"])
+
+ df_indices = pd.concat([df_foe, df_soe, df_si], axis=1)
+ print(f"\n{df_indices}\n")
+
return SensitivityAnalysisResult(si, foe, soe)
diff --git a/src/simdec/visualization.py b/src/simdec/visualization.py
index e77adf4..5f4378d 100644
--- a/src/simdec/visualization.py
+++ b/src/simdec/visualization.py
@@ -10,9 +10,18 @@
import seaborn as sns
import pandas as pd
from pandas.io.formats.style import Styler
+import warnings
+
+from simdec.decomposition import DecompositionResult
__all__ = ["visualization", "two_output_visualization", "tableau", "palette"]
+try:
+ from IPython.display import display
+
+ HAS_IPYTHON = True
+except ImportError:
+ HAS_IPYTHON = False
SEQUENTIAL_PALETTES = [
"#DC267F",
@@ -139,6 +148,8 @@ def visualization(
n_bins: str | int = "auto",
kind: Literal["histogram", "boxplot"] = "histogram",
ax=None,
+ print_legend: bool = False,
+ decomposition: DecompositionResult | None = None,
) -> plt.Axes:
"""Histogram plot of scenarios.
@@ -154,6 +165,10 @@ def visualization(
Histogram or Box Plot.
ax : Axes, optional
Matplotlib axis.
+    print_legend : bool, default False
+        If True, displays the legend table (requires IPython).
+    decomposition : DecompositionResult, optional
+        Decomposition result; required when ``print_legend`` is True.
Returns
-------
@@ -186,6 +201,31 @@ def visualization(
)
else:
raise ValueError("'kind' can only be 'histogram' or 'boxplot'")
+
+ if print_legend:
+ if not HAS_IPYTHON:
+ warnings.warn(
+ "print_legend=True requires ipython to be installed. "
+ "Install it with: pip install simdec[ipython]",
+ stacklevel=2,
+ )
+ elif decomposition is None:
+ warnings.warn(
+ "print_legend=True requires the decomposition parameter. Table skipped.",
+ stacklevel=2,
+ )
+ else:
+ try:
+ _, styler = tableau(
+ var_names=decomposition.var_names,
+ statistic=decomposition.statistic,
+ states=decomposition.states,
+ bins=decomposition.bins,
+ palette=palette,
+ )
+ display(styler)
+ except ImportError:
+ pass
return ax
@@ -200,6 +240,8 @@ def two_output_visualization(
xlim: tuple[float, float] | None = None,
ylim: tuple[float, float] | None = None,
r_scatter: float = 1.0,
+ print_legend: bool = False,
+ decomposition: DecompositionResult | None = None,
) -> tuple[plt.Figure, np.ndarray]:
"""Two-output visualization.
@@ -229,6 +271,10 @@ def two_output_visualization(
Limits for the secondary output axis (scatter y / right histogram).
r_scatter : float, default 1.0
Fraction of data points shown in the scatter plot.
+    print_legend : bool, default False
+        If True, displays the legend table (requires IPython).
+    decomposition : DecompositionResult, optional
+        Decomposition result; required when ``print_legend`` is True.
Returns
-------
@@ -286,6 +332,31 @@ def two_output_visualization(
axs[1, 1].axis("off")
fig.subplots_adjust(wspace=-0.015, hspace=0)
+
+ if print_legend:
+ if not HAS_IPYTHON:
+ warnings.warn(
+ "print_legend=True requires ipython to be installed. "
+ "Install it with: pip install simdec[ipython]",
+ stacklevel=2,
+ )
+ elif decomposition is None:
+ warnings.warn(
+ "print_legend=True requires the decomposition parameter. Table skipped.",
+ stacklevel=2,
+ )
+ else:
+ try:
+ _, styler = tableau(
+ var_names=decomposition.var_names,
+ statistic=decomposition.statistic,
+ states=decomposition.states,
+ bins=decomposition.bins,
+ palette=palette,
+ )
+ display(styler)
+ except ImportError:
+ pass
return fig, axs
diff --git a/tests/test_heterogeneity_indices.py b/tests/test_heterogeneity_indices.py
new file mode 100644
index 0000000..36fdd06
--- /dev/null
+++ b/tests/test_heterogeneity_indices.py
@@ -0,0 +1,131 @@
+import pathlib
+import pytest
+
+import numpy as np
+import pandas as pd
+import matplotlib.pyplot as plt
+
+import simdec as sd
+
+
+path_data = pathlib.Path(__file__).parent / "data"
+
+
+@pytest.fixture(autouse=True)
+def close_plots():
+ yield
+ plt.close("all")
+
+
+@pytest.fixture
+def dummy_data():
+ rng = np.random.default_rng(42)
+ n = 200
+
+ inputs = pd.DataFrame(
+ {
+ "x1": rng.random(n),
+ "x2": rng.random(n),
+ "x3": rng.random(n),
+ "cat_var": rng.choice(["A", "B", "C"], size=n),
+ }
+ )
+
+ # Create a dummy output dependent on x1 and x2
+ y = 2.0 * inputs["x1"] + 0.5 * inputs["x2"] + rng.normal(0, 0.1, n)
+ return inputs, y
+
+
+def test_heterogeneity_categorical_str(dummy_data):
+ """Test splitting by a string column name (categorical)."""
+ inputs, y = dummy_data
+
+ res = sd.heterogeneity_indices(output=y, inputs=inputs, split_variable="cat_var")
+
+ # Check object structure
+ assert hasattr(res, "summary")
+ assert hasattr(res, "regional_profiles")
+ assert res.split_name == "cat_var"
+
+ # Check DataFrame structures
+ assert "Overall_SI" in res.summary.columns
+ assert "Heterogeneity (across cat_var)" in res.summary.columns
+ assert "SUM / TOTAL" in res.summary.index
+
+ # 3 categories
+ assert res.regional_profiles.shape[1] == 3
+ assert list(res.regional_profiles.index) == ["x1", "x2", "x3", "cat_var"]
+
+
+def test_heterogeneity_continuous_series(dummy_data):
+ """Test splitting by passing a pandas Series (continuous)."""
+ inputs, y = dummy_data
+ split_series = inputs["x1"]
+
+ res = sd.heterogeneity_indices(
+ output=y, inputs=inputs, split_variable=split_series, n_subdivisions=4
+ )
+
+ assert res.split_name == "x1"
+ assert res.regional_profiles.shape[1] == 4 # 4 quantiles
+
+
+def test_heterogeneity_missing_column(dummy_data):
+ """Test that a ValueError is raised when split_variable is not in inputs."""
+ inputs, y = dummy_data
+
+ with pytest.raises(ValueError, match="'missing_col' not found in inputs"):
+ sd.heterogeneity_indices(output=y, inputs=inputs, split_variable="missing_col")
+
+
+def test_heterogeneity_too_few_regions():
+ """Test that a ValueError is raised when there are not enough valid subdivisions."""
+ inputs = pd.DataFrame({"x1": [1, 2, 3, 4, 5], "cat": ["A", "B", "C", "D", "E"]})
+ y = pd.Series([1.0, 2.0, 3.0, 4.0, 5.0])
+
+ with pytest.raises(ValueError, match="Not enough valid subdivisions"):
+ sd.heterogeneity_indices(output=y, inputs=inputs, split_variable="cat")
+
+
+def test_heterogeneity_plot_argument(dummy_data):
+ """Test that setting plot=True works without throwing an error."""
+ inputs, y = dummy_data
+
+ res = sd.heterogeneity_indices(
+ output=y, inputs=inputs, split_variable="cat_var", plot=True
+ )
+
+ assert res is not None
+ # Figure exists in the active pyplot state
+ assert len(plt.get_fignums()) > 0
+
+
+def test_plot_heterogeneity(dummy_data):
+ """Test the independent plot_heterogeneity function."""
+ inputs, y = dummy_data
+
+ res = sd.heterogeneity_indices(output=y, inputs=inputs, split_variable="cat_var")
+
+ ax = sd.plot_heterogeneity(res)
+
+ assert isinstance(ax, plt.Axes)
+ assert ax.get_title() == "Sensitivity Profiles across cat_var"
+ assert ax.get_ylabel() == "Variance Contribution"
+ assert ax.get_xlabel() == "Regions of cat_var"
+
+
+def test_heterogeneity_real_data():
+ """Integration test using the real stress.csv dataset from the project."""
+ fname = path_data / "stress.csv"
+ data = pd.read_csv(fname)
+ output_name, *v_names = list(data.columns)
+
+ inputs, output = data[v_names], data[output_name]
+
+ res = sd.heterogeneity_indices(
+ output=output, inputs=inputs, split_variable="R", n_subdivisions=2
+ )
+
+ assert res.split_name == "R"
+ assert not res.summary.empty
+ assert res.regional_profiles.shape[1] == 2
diff --git a/tests/test_visualization.py b/tests/test_visualization.py
index a974ae0..35f84a9 100644
--- a/tests/test_visualization.py
+++ b/tests/test_visualization.py
@@ -1,4 +1,6 @@
import pytest
+import pathlib
+import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import simdec as sd
@@ -62,3 +64,71 @@ def test_two_output_visualization_r_scatter():
bins=bins, bins2=bins2, palette=palette, r_scatter=0.5
)
assert isinstance(fig, plt.Figure)
+
+
+# Data path shared with the decomposition tests.
+path_data = pathlib.Path(__file__).parent / "data"
+
+
+@pytest.fixture
+def stress_results():
+ """Runs the actual decomposition to get a real result object."""
+ fname = path_data / "stress.csv"
+ data = pd.read_csv(fname)
+ output_name, *v_names = list(data.columns)
+ inputs, output = data[v_names], data[output_name]
+ si = np.array([0.04, 0.50, 0.11, 0.28])
+
+ res = sd.decomposition(
+ inputs=inputs, output=output, sensitivity_indices=si, dec_limit=1
+ )
+ return res
+
+
+def test_visualization_with_legend(stress_results):
+ """Verify visualization works with print_legend using live decomposition results."""
+ # Generate palette based on the live results
+ palette = sd.palette(stress_results.states)
+
+ # Test single visualization
+ ax = sd.visualization(
+ bins=stress_results.bins,
+ palette=palette,
+ print_legend=True,
+ decomposition=stress_results,
+ )
+
+ assert isinstance(ax, plt.Axes)
+ # Check that the columns were handled (RangeIndex is applied inside visualization)
+ assert isinstance(stress_results.bins.columns, pd.RangeIndex)
+
+
+def test_two_output_visualization_with_legend(stress_results):
+ """Verify two_output works with print_legend using live decomposition results."""
+ palette = sd.palette(stress_results.states)
+
+ # Using the same bins for both axes for testing purposes
+ fig, axs = sd.two_output_visualization(
+ bins=stress_results.bins,
+ bins2=stress_results.bins,
+ palette=palette,
+ print_legend=True,
+ decomposition=stress_results,
+ output_name="Primary",
+ output_name2="Secondary",
+ )
+
+ assert isinstance(fig, plt.Figure)
+ assert axs.shape == (2, 2)
+ assert axs[1, 0].get_xlabel() == "Primary"
+ assert axs[1, 0].get_ylabel() == "Secondary"
+
+
+def test_visualization_missing_decomposition_warning():
+ """Verify that omitting the decomposition object triggers a warning, not a crash."""
+ # Using small dummy data for a quick standalone check
+ bins = pd.DataFrame({"s1": [1, 2]})
+ pal = [[1, 0, 0, 1]]
+
+ with pytest.warns(UserWarning, match="requires the decomposition parameter"):
+ sd.visualization(bins=bins, palette=pal, print_legend=True, decomposition=None)