diff --git a/docs/api/datasets.rst b/docs/api/datasets.rst index 8d9a59d21..ed67dcfcc 100644 --- a/docs/api/datasets.rst +++ b/docs/api/datasets.rst @@ -246,3 +246,4 @@ Available Datasets datasets/pyhealth.datasets.TCGAPRADDataset datasets/pyhealth.datasets.splitter datasets/pyhealth.datasets.utils + datasets/pyhealth.datasets.mimic3_cf \ No newline at end of file diff --git a/docs/api/datasets/pyhealth.datasets.mimic3_cf.rst b/docs/api/datasets/pyhealth.datasets.mimic3_cf.rst new file mode 100644 index 000000000..b0d47472c --- /dev/null +++ b/docs/api/datasets/pyhealth.datasets.mimic3_cf.rst @@ -0,0 +1,26 @@ +pyhealth.datasets.mimic3_cf +=========================== + +Overview +-------- + +MIMIC3CirculatoryFailureDataset is a MIMIC-III based dataset for early warning +prediction of circulatory failure. + +It constructs an ICU-stay-level cohort from PATIENTS, ADMISSIONS, and ICUSTAYS, +and uses CHARTEVENTS to extract Mean Arterial Pressure (MAP) measurements. + +Circulatory failure is defined using a proxy event: + +- MAP < 65 mmHg + +For each ICU stay, the dataset identifies the first occurrence of this event and +supports building task-ready patient records for downstream prediction tasks. + +API Reference +------------- + +.. 
autoclass:: pyhealth.datasets.MIMIC3CirculatoryFailureDataset + :members: + :undoc-members: + :show-inheritance: \ No newline at end of file diff --git a/docs/api/tasks.rst b/docs/api/tasks.rst index 23a4e06e5..e68d80185 100644 --- a/docs/api/tasks.rst +++ b/docs/api/tasks.rst @@ -230,3 +230,4 @@ Available Tasks Mutation Pathogenicity (COSMIC) Cancer Survival Prediction (TCGA) Cancer Mutation Burden (TCGA) + Circulatory Failure Prediction \ No newline at end of file diff --git a/docs/api/tasks/pyhealth.tasks.circulatory_failure_prediction.rst b/docs/api/tasks/pyhealth.tasks.circulatory_failure_prediction.rst new file mode 100644 index 000000000..6a9cf0e2c --- /dev/null +++ b/docs/api/tasks/pyhealth.tasks.circulatory_failure_prediction.rst @@ -0,0 +1,24 @@ +pyhealth.tasks.circulatory_failure_prediction +============================================= + +Overview +-------- + +CirculatoryFailurePredictionTask defines a time-series prediction task for early +detection of circulatory failure. + +The task predicts whether a patient will experience circulatory failure within +the next 12 hours based on physiological measurements. + +Label definition: + +- label = 1 if circulatory failure occurs within the next 12 hours +- label = 0 otherwise + +API Reference +------------- + +.. autoclass:: pyhealth.tasks.CirculatoryFailurePredictionTask + :members: + :undoc-members: + :show-inheritance: \ No newline at end of file diff --git a/examples/mimic3_cf_circulatory_failure_logreg.py b/examples/mimic3_cf_circulatory_failure_logreg.py new file mode 100644 index 000000000..72e81cb5a --- /dev/null +++ b/examples/mimic3_cf_circulatory_failure_logreg.py @@ -0,0 +1,145 @@ +""" +Example ablation script for MIMIC-III circulatory failure prediction. + +This script compares different prediction windows (6h, 12h, 24h) and +feature settings using logistic regression. It is intended as an example +usage script for the dataset-task pipeline and ablation study. 
def samples_to_df(samples: list[dict]) -> pd.DataFrame:
    """Flatten task samples into a DataFrame with one row per sample.

    Args:
        samples: Sample dicts as produced by the prediction task. Each must
            provide ``patient_id``, ``icustay_id``, ``gender``,
            ``timestamp``, ``label`` and ``features["map"]``.

    Returns:
        A DataFrame with the columns ``patient_id``, ``icustay_id``,
        ``gender``, ``timestamp``, ``map`` and ``label``. The columns are
        present even when ``samples`` is empty, so downstream column access
        does not raise ``KeyError``.
    """
    columns = ["patient_id", "icustay_id", "gender", "timestamp", "map", "label"]
    rows = [
        {
            "patient_id": s["patient_id"],
            "icustay_id": s["icustay_id"],
            "gender": s["gender"],
            "timestamp": s["timestamp"],
            "map": s["features"]["map"],
            "label": s["label"],
        }
        for s in samples
    ]
    # Passing ``columns`` keeps the schema stable for an empty sample list.
    return pd.DataFrame(rows, columns=columns)


def add_advanced_features(df: pd.DataFrame) -> pd.DataFrame:
    """Add simple temporal features for the advanced setting.

    Args:
        df: Frame with at least ``icustay_id``, ``timestamp`` and ``map``.

    Returns:
        A sorted copy of ``df`` with two extra columns: ``map_prev`` (the
        previous MAP reading of the same ICU stay) and ``map_diff`` (current
        minus previous). The first reading of each stay has no predecessor,
        so ``map_prev`` falls back to the current value and ``map_diff``
        to ``0.0``.
    """
    df = df.sort_values(["icustay_id", "timestamp"]).copy()
    df["map_prev"] = df.groupby("icustay_id")["map"].shift(1)
    # Compute the difference before back-filling map_prev so the first
    # reading of a stay yields NaN here and is then set to 0.0 explicitly.
    df["map_diff"] = df["map"] - df["map_prev"]
    df["map_prev"] = df["map_prev"].fillna(df["map"])
    df["map_diff"] = df["map_diff"].fillna(0.0)
    return df


def _empty_metrics(n_samples: int) -> dict:
    """Return the metrics dict used when a model cannot be evaluated."""
    return {
        "n_samples": n_samples,
        "accuracy": None,
        "roc_auc": None,
        "recall": None,
    }


def evaluate_model(
    df: pd.DataFrame,
    feature_cols: list[str],
    balanced: bool = False,
) -> dict:
    """Fit a logistic regression on an 80/20 stratified split and score it.

    Args:
        df: Frame holding the feature columns and a binary ``label`` column.
        feature_cols: Names of the columns used as model inputs.
        balanced: If True, fit with ``class_weight="balanced"``.

    Returns:
        A dict with ``n_samples``, ``accuracy``, ``roc_auc`` and ``recall``.
        All three metrics are ``None`` when the data cannot support a
        stratified split (empty frame, a single class, or a class with fewer
        than two rows). ``roc_auc`` alone is ``None`` when the held-out fold
        ends up containing a single class.
    """
    if df.empty or df["label"].nunique() < 2:
        return _empty_metrics(len(df))

    # A stratified split needs at least two rows per class; small subgroups
    # (e.g. the per-gender runs) can easily violate this.
    if df["label"].value_counts().min() < 2:
        return _empty_metrics(len(df))

    X = df[feature_cols]
    y = df["label"]

    try:
        X_train, X_test, y_train, y_test = train_test_split(
            X,
            y,
            test_size=0.2,
            random_state=42,
            stratify=y,
        )
    except ValueError:
        # Raised when the frame is too small for the test fold to hold one
        # row of every class.
        return _empty_metrics(len(df))

    model = LogisticRegression(
        max_iter=1000,
        class_weight="balanced" if balanced else None,
    )
    model.fit(X_train, y_train)

    preds = model.predict(X_test)
    probs = model.predict_proba(X_test)[:, 1]

    # ROC AUC is undefined when only one class is present in y_test.
    roc_auc = roc_auc_score(y_test, probs) if y_test.nunique() > 1 else None

    return {
        "n_samples": len(df),
        "accuracy": accuracy_score(y_test, preds),
        "roc_auc": roc_auc,
        "recall": recall_score(y_test, preds),
    }


def print_metrics(title: str, metrics: dict) -> None:
    """Print one titled block of evaluation metrics to stdout."""
    print(f"\n=== {title} ===")
    print(f"n_samples: {metrics['n_samples']}")
    print(f"accuracy: {metrics['accuracy']}")
    print(f"roc_auc: {metrics['roc_auc']}")
    print(f"recall: {metrics['recall']}")


def main() -> None:
    """Run the prediction-window / feature-set ablation and print results."""
    dataset = MIMIC3CirculatoryFailureDataset(
        # path to the unzipped MIMIC-III database on your machine
        root="mimic-iii-dataset"
    )

    # task ablation: prediction windows
    for window in [6, 12, 24]:
        print("\n############################")
        print(f"Prediction window = {window}h")
        print("############################")

        task = CirculatoryFailurePredictionTask(prediction_window_hours=window)
        samples = dataset.set_task(task, max_patients=100)
        df = samples_to_df(samples)

        if df.empty:
            print("\nNo samples were generated for this window; skipping.")
            continue

        print("\nSample preview:")
        print(df.head())

        # baseline setting
        baseline_metrics = evaluate_model(
            df=df,
            feature_cols=["map"],
            balanced=False,
        )
        print_metrics("Baseline: LogisticRegression(map)", baseline_metrics)

        # advanced setting
        df_adv = add_advanced_features(df)
        advanced_metrics = evaluate_model(
            df=df_adv,
            feature_cols=["map", "map_diff"],
            balanced=True,
        )
        print_metrics(
            "Advanced: LogisticRegression(map + map_diff, balanced)",
            advanced_metrics,
        )

        # subgroup fairness
        for gender in ["M", "F"]:
            subgroup_df = df_adv[df_adv["gender"] == gender].copy()
            subgroup_metrics = evaluate_model(
                df=subgroup_df,
                feature_cols=["map", "map_diff"],
                balanced=True,
            )
            print_metrics(f"Advanced subgroup gender={gender}", subgroup_metrics)


if __name__ == "__main__":
    main()
task + samples = dataset.set_task(task, max_patients=5) + + print(f"Total samples: {len(samples)}") + + if samples: + print("Sample example:") + print(samples[0]) + + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/pyhealth/datasets/__init__.py b/pyhealth/datasets/__init__.py index 50b1b3887..2cafac05d 100644 --- a/pyhealth/datasets/__init__.py +++ b/pyhealth/datasets/__init__.py @@ -91,3 +91,4 @@ def __init__(self, *args, **kwargs): save_processors, ) from .collate import collate_temporal +from .mimic3_cf import MIMIC3CirculatoryFailureDataset \ No newline at end of file diff --git a/pyhealth/datasets/configs/mimic3_cf.yaml b/pyhealth/datasets/configs/mimic3_cf.yaml new file mode 100644 index 000000000..0de518e58 --- /dev/null +++ b/pyhealth/datasets/configs/mimic3_cf.yaml @@ -0,0 +1,47 @@ +version: "1.4" +tables: + patients: + file_path: "PATIENTS.csv.gz" + patient_id: "subject_id" + timestamp: null + attributes: + - "gender" + - "dob" + - "dod" + - "expire_flag" + + admissions: + file_path: "ADMISSIONS.csv.gz" + patient_id: "subject_id" + timestamp: "admittime" + attributes: + - "hadm_id" + - "admittime" + - "dischtime" + - "deathtime" + - "hospital_expire_flag" + - "ethnicity" + + icustays: + file_path: "ICUSTAYS.csv.gz" + patient_id: "subject_id" + timestamp: "intime" + attributes: + - "hadm_id" + - "icustay_id" + - "intime" + - "outtime" + - "first_careunit" + - "last_careunit" + + chartevents: + file_path: "CHARTEVENTS.csv.gz" + patient_id: "subject_id" + timestamp: "charttime" + attributes: + - "hadm_id" + - "icustay_id" + - "itemid" + - "charttime" + - "value" + - "valuenum" \ No newline at end of file diff --git a/pyhealth/datasets/mimic3_cf.py b/pyhealth/datasets/mimic3_cf.py new file mode 100644 index 000000000..8aeaed876 --- /dev/null +++ b/pyhealth/datasets/mimic3_cf.py @@ -0,0 +1,284 @@ +import logging +from pathlib import Path +from typing import List, Optional +import pandas as pd +from .base_dataset import BaseDataset 
logger = logging.getLogger(__name__)

# CHARTEVENTS ITEMID whose readings are treated as Mean Arterial Pressure.
# NOTE(review): 220052 looks like the MetaVision invasive "Arterial Blood
# Pressure mean" item; stays charted under other item ids would yield no MAP
# rows here -- confirm against D_ITEMS before relying on cohort coverage.
MAP_ITEMID = 220052

# Proxy definition of circulatory failure: MAP strictly below this (mmHg).
MAP_FAILURE_THRESHOLD = 65

_COHORT_COLUMNS = [
    "patient_id",
    "gender",
    "hadm_id",
    "icustay_id",
    "admittime",
    "intime",
    "outtime",
]
_FAILURE_COLUMNS = ["ICUSTAY_ID", "first_failure_time"]


class MIMIC3CirculatoryFailureDataset(BaseDataset):
    """MIMIC-III dataset for circulatory failure early-warning prediction.

    The cohort has one row per ICU stay, built by joining PATIENTS,
    ADMISSIONS and ICUSTAYS. MAP readings are taken from CHARTEVENTS rows
    whose ``ITEMID`` equals ``MAP_ITEMID``, and a stay's failure time is its
    first in-stay reading below ``MAP_FAILURE_THRESHOLD``.

    The MAP readings, the parsed cohort frame and the failure labels are
    computed lazily on first use and cached on the instance, so building
    samples for many stays scans CHARTEVENTS once rather than once per stay.

    Args:
        root: Root directory of the MIMIC-III dataset.
        tables: Additional tables to load beyond the default cohort tables.
        dataset_name: Name of the dataset instance.
        config_path: Path to the dataset config YAML file.
        **kwargs: Additional keyword arguments passed to BaseDataset.
    """

    def __init__(
        self,
        root: str,
        tables: Optional[List[str]] = None,
        dataset_name: Optional[str] = None,
        config_path: Optional[str] = None,
        **kwargs,
    ) -> None:
        """Initializes the MIMIC-III circulatory failure dataset."""
        if config_path is None:
            logger.info("No config path provided, using default config")
            config_path = Path(__file__).parent / "configs" / "mimic3_cf.yaml"

        default_tables = ["patients", "admissions", "icustays"]
        # dict.fromkeys de-duplicates while keeping order, so a caller that
        # also lists a default table does not request it twice.
        tables = list(dict.fromkeys(default_tables + (tables or [])))

        super().__init__(
            root=root,
            tables=tables,
            dataset_name=dataset_name or "mimic3_cf",
            config_path=str(config_path),
            **kwargs,
        )

    def load_cohort(self) -> List[dict]:
        """Load the ICU-stay cohort from PATIENTS, ADMISSIONS and ICUSTAYS.

        The three gzipped CSVs are read from ``self.root`` on every call and
        inner-joined on ``SUBJECT_ID`` (and ``HADM_ID`` for ICU stays).

        Returns:
            One dict per ICU stay with the keys ``patient_id``, ``gender``,
            ``hadm_id``, ``icustay_id``, ``admittime``, ``intime`` and
            ``outtime``. Timestamps are left exactly as read from the CSV.
        """
        root = Path(self.root)

        # Only the columns that end up in the cohort records are read.
        patients_df = pd.read_csv(
            root / "PATIENTS.csv.gz",
            usecols=["SUBJECT_ID", "GENDER"],
        )
        admissions_df = pd.read_csv(
            root / "ADMISSIONS.csv.gz",
            usecols=["SUBJECT_ID", "HADM_ID", "ADMITTIME"],
        )
        icustays_df = pd.read_csv(
            root / "ICUSTAYS.csv.gz",
            usecols=["SUBJECT_ID", "HADM_ID", "ICUSTAY_ID", "INTIME", "OUTTIME"],
        )

        merged = patients_df.merge(admissions_df, on="SUBJECT_ID")
        merged = merged.merge(icustays_df, on=["SUBJECT_ID", "HADM_ID"])

        cohort = merged.rename(
            columns={
                "SUBJECT_ID": "patient_id",
                "GENDER": "gender",
                "HADM_ID": "hadm_id",
                "ICUSTAY_ID": "icustay_id",
                "ADMITTIME": "admittime",
                "INTIME": "intime",
                "OUTTIME": "outtime",
            }
        )[_COHORT_COLUMNS]
        return cohort.to_dict("records")

    def load_patients(self) -> List[dict]:
        """Backward-compatible wrapper for current development."""
        return self.load_cohort()

    def _cohort_frame(self) -> pd.DataFrame:
        """Return the cohort as a DataFrame with parsed times, cached."""
        cached = getattr(self, "_cohort_frame_cache", None)
        if cached is None:
            # Explicit columns keep the frame usable when the cohort is empty.
            cached = pd.DataFrame(self.load_cohort(), columns=_COHORT_COLUMNS)
            cached["intime"] = pd.to_datetime(cached["intime"])
            cached["outtime"] = pd.to_datetime(cached["outtime"])
            self._cohort_frame_cache = cached
        return cached

    def _compute_failure_labels(self) -> pd.DataFrame:
        """Compute the first in-stay MAP failure time for every ICU stay."""
        cohort = self._cohort_frame()[["icustay_id", "intime", "outtime"]]
        readings = self.load_map_cache()
        if cohort.empty or readings.empty:
            return pd.DataFrame(columns=_FAILURE_COLUMNS)

        merged = readings.merge(
            cohort,
            left_on="ICUSTAY_ID",
            right_on="icustay_id",
        )
        in_stay = (merged["CHARTTIME"] >= merged["intime"]) & (
            merged["CHARTTIME"] <= merged["outtime"]
        )
        failures = merged[in_stay & (merged["VALUENUM"] < MAP_FAILURE_THRESHOLD)]
        if failures.empty:
            return pd.DataFrame(columns=_FAILURE_COLUMNS)

        return (
            failures.groupby("ICUSTAY_ID")["CHARTTIME"]
            .min()
            .reset_index()
            .rename(columns={"CHARTTIME": "first_failure_time"})
        )

    def build_failure_labels(self) -> pd.DataFrame:
        """Build the first failure time per ICU stay (MAP < 65).

        Only readings charted between the stay's ``intime`` and ``outtime``
        (inclusive) are considered. The result is derived from the cached
        MAP readings and is itself cached, so repeated calls do not re-scan
        CHARTEVENTS.

        Returns:
            A DataFrame with columns ``ICUSTAY_ID`` and
            ``first_failure_time``; stays without a failure event are
            absent. A copy is returned so callers cannot mutate the cache.
        """
        cached = getattr(self, "_failure_cache", None)
        if cached is None:
            cached = self._compute_failure_labels()
            self._failure_cache = cached
        return cached.copy()

    def get_patient_by_icustay_id(self, icustay_id: int):
        """Build one task-ready patient dict for a given ICU stay.

        Args:
            icustay_id: Identifier of the ICU stay to assemble.

        Returns:
            ``None`` if the stay is not in the cohort. Otherwise a dict with
            ``patient_id``, ``icustay_id``, ``gender``, ``intime``,
            ``outtime``, ``time_series`` (chronological list of
            ``{"charttime", "map"}`` dicts for in-stay readings with a
            non-missing value) and ``first_failure_time`` (``None`` when the
            stay has no failure event).
        """
        # 1) locate the stay in the cached cohort
        cohort_df = self._cohort_frame()
        match = cohort_df[cohort_df["icustay_id"] == icustay_id]
        if match.empty:
            return None
        stay = match.iloc[0]

        # 2) look up the failure label
        first_failure = self.build_failure_labels()
        failure_row = first_failure[first_failure["ICUSTAY_ID"] == icustay_id]
        first_failure_time = None
        if not failure_row.empty:
            first_failure_time = failure_row.iloc[0]["first_failure_time"]

        # 3) collect the in-stay MAP time series
        map_df = self.load_map_cache()
        ts = map_df[map_df["ICUSTAY_ID"] == icustay_id]
        ts = ts[
            (ts["CHARTTIME"] >= stay["intime"])
            & (ts["CHARTTIME"] <= stay["outtime"])
        ]
        ts = ts.dropna(subset=["VALUENUM"]).sort_values("CHARTTIME")

        time_series = [
            {"charttime": charttime, "map": float(value)}
            for charttime, value in zip(ts["CHARTTIME"], ts["VALUENUM"])
        ]

        return {
            "patient_id": int(stay["patient_id"]),
            "icustay_id": int(stay["icustay_id"]),
            "gender": stay["gender"],
            "intime": stay["intime"],
            "outtime": stay["outtime"],
            "time_series": time_series,
            "first_failure_time": first_failure_time,
        }

    def load_map_cache(self) -> pd.DataFrame:
        """Load MAP (mean arterial pressure) data once and cache it.

        CHARTEVENTS is streamed in chunks; only rows whose ``ITEMID`` equals
        ``MAP_ITEMID`` are kept. ``CHARTTIME`` is parsed to datetimes and
        unparseable values become ``NaT``.

        Returns:
            The cached DataFrame of MAP readings (not a copy).
        """
        cached = getattr(self, "_map_cache", None)
        if cached is not None:
            return cached

        root = Path(self.root)
        logger.info("Loading MAP cache (this will take a bit, only once)...")

        chunks = pd.read_csv(
            root / "CHARTEVENTS.csv.gz",
            usecols=["ICUSTAY_ID", "ITEMID", "CHARTTIME", "VALUENUM"],
            chunksize=100000,
        )

        parts = []
        for chunk in chunks:
            chunk = chunk[chunk["ITEMID"] == MAP_ITEMID].copy()
            if chunk.empty:
                continue
            chunk["CHARTTIME"] = pd.to_datetime(
                chunk["CHARTTIME"],
                format="%Y-%m-%d %H:%M:%S",
                errors="coerce",
            )
            parts.append(chunk)

        if parts:
            df = pd.concat(parts, ignore_index=True)
        else:
            df = pd.DataFrame(columns=["ICUSTAY_ID", "CHARTTIME", "VALUENUM"])

        self._map_cache = df
        logger.info("MAP cache loaded: %d rows", len(df))
        return df

    def set_task(self, task, max_patients: int | None = None):
        """Apply a task function to the cohort and return task samples.

        NOTE(review): this takes different parameters than a generic
        dataset ``set_task`` would -- confirm it is meant to override the
        base-class method rather than coexist with it.

        Args:
            task: Callable that maps one patient dict to a list of samples.
            max_patients: If given, only the first ``max_patients`` cohort
                rows are processed. The cohort has one row per ICU stay, so
                this limits stays rather than distinct patients.

        Returns:
            A flat list with the samples of all processed stays.
        """
        cohort = self.load_cohort()
        if max_patients is not None:
            cohort = cohort[:max_patients]

        samples = []
        for stay in cohort:
            patient = self.get_patient_by_icustay_id(stay["icustay_id"])
            if patient is None:
                continue

            task_samples = task(patient)
            if task_samples:
                samples.extend(task_samples)

        return samples
class CirculatoryFailurePredictionTask(BaseTask):
    """Early-warning task for circulatory failure prediction.

    This task converts one ICU-stay patient record into multiple
    time-point prediction samples. At each timestamp t, the label is 1
    if the first circulatory failure event occurs strictly after t and no
    later than t + prediction window, and 0 otherwise. Measurements taken
    at or after the first failure are therefore labelled 0.

    Circulatory failure is defined upstream using a proxy event based on
    MAP < 65 mmHg.

    Attributes:
        task_name: Unique task identifier used by PyHealth.
        input_schema: Expected input feature schema.
        output_schema: Expected output label schema.
        prediction_window_hours: Number of hours used for early-warning label
            generation.
        include_stays_without_failure: Whether stays that never reach the
            failure event contribute (all-negative) samples.
    """

    task_name = "circulatory_failure_prediction"

    # NOTE(review): emitted samples nest the MAP value under
    # sample["features"]["map"], not as a top-level "map" key -- confirm
    # whether anything validates samples against this schema.
    input_schema = {
        "map": float,
        "timestamp": "datetime",
        "gender": str,
    }

    output_schema = {
        "label": int,
    }

    def __init__(
        self,
        prediction_window_hours: int = 12,
        include_stays_without_failure: bool = False,
    ):
        """Initializes the circulatory failure prediction task.

        Args:
            prediction_window_hours: Future prediction window in hours.
                A sample is labeled positive if the first failure event
                happens within this horizon.
            include_stays_without_failure: If True, a stay with no failure
                time yields one label-0 sample per measurement instead of
                being dropped. Defaults to False, which keeps the original
                behavior of skipping such stays.

        Raises:
            ValueError: If ``prediction_window_hours`` is not positive.
        """
        super().__init__()
        if prediction_window_hours <= 0:
            raise ValueError(
                "prediction_window_hours must be positive, "
                f"got {prediction_window_hours}"
            )
        self.prediction_window_hours = prediction_window_hours
        self.include_stays_without_failure = include_stays_without_failure

    def __call__(self, patient) -> List[Dict]:
        """Converts one patient record into task samples.

        Args:
            patient: A task-ready patient dictionary containing ICU-stay
                metadata (``patient_id``, ``icustay_id``, ``gender``),
                ``time_series`` (list of ``{"charttime", "map"}`` dicts) and
                ``first_failure_time``.

        Returns:
            A list of sample dictionaries, one per measurement. Each sample
            contains patient metadata, a timestamp, feature values, and a
            binary label. Returns an empty list if the patient has no
            time-series data, or if it has no failure time and
            ``include_stays_without_failure`` is False.
        """
        time_series = patient["time_series"]
        if not time_series:
            return []

        # Imported here because pandas is not a module-level import of this
        # file.
        import pandas as pd

        first_failure_time = patient["first_failure_time"]
        # A missing failure time may arrive as None or as NaT/NaN.
        has_failure = first_failure_time is not None and not pd.isna(
            first_failure_time
        )
        if not has_failure and not self.include_stays_without_failure:
            return []

        if has_failure:
            first_failure_time = pd.to_datetime(first_failure_time)
        prediction_window = pd.Timedelta(hours=self.prediction_window_hours)

        samples = []
        for row in time_series:
            t = pd.to_datetime(row["charttime"])
            # Positive only while the failure is still ahead and within
            # reach of the window; stays without a failure are all negative.
            label = int(
                has_failure and t < first_failure_time <= t + prediction_window
            )
            samples.append(
                {
                    "patient_id": patient["patient_id"],
                    "icustay_id": patient["icustay_id"],
                    "gender": patient["gender"],
                    "timestamp": t,
                    "features": {"map": row["map"]},
                    "label": label,
                }
            )

        return samples
檢查結果 + if samples: + print(f"--- 成功測試 ICU Stay ID: {chosen_icustay_id} ---") + print(f"成功產生樣本數: {len(samples)}") + print( + f"其中 Label=1 (未來12小時內會衰竭) 的數量: " + f"{sum(s['label'] for s in samples)}" + ) + print(f"第一筆樣本特徵: {samples[0]['features']}") + print(f"第一筆完整樣本: {samples[0]}") + else: + print("未找到任何可產生樣本的 ICU stay,請檢查資料與路徑。") + + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/tests/core/test_circulatory_failure_prediction.py b/tests/core/test_circulatory_failure_prediction.py new file mode 100644 index 000000000..e0daba760 --- /dev/null +++ b/tests/core/test_circulatory_failure_prediction.py @@ -0,0 +1,28 @@ +from pyhealth.tasks import CirculatoryFailurePredictionTask + + +def test_circulatory_failure_task_basic(): + task = CirculatoryFailurePredictionTask(prediction_window_hours=12) + + patient = { + "patient_id": 1, + "icustay_id": 1001, + "gender": "F", + "first_failure_time": "2150-01-01 11:00:00", + "time_series": [ + {"charttime": "2150-01-01 00:00:00", "map": 80.0}, + {"charttime": "2150-01-01 01:00:00", "map": 78.0}, + {"charttime": "2150-01-01 10:00:00", "map": 70.0}, + {"charttime": "2150-01-01 11:00:00", "map": 60.0}, + ], + } + + samples = task(patient) + + assert len(samples) == 4 + assert samples[0]["label"] == 1 + assert samples[1]["label"] == 1 + assert samples[2]["label"] == 1 + assert samples[3]["label"] == 0 + assert samples[0]["features"]["map"] == 80.0 + assert samples[0]["gender"] == "F" \ No newline at end of file diff --git a/tests/core/test_mimic3_cf.py b/tests/core/test_mimic3_cf.py new file mode 100644 index 000000000..c8fa7b4a1 --- /dev/null +++ b/tests/core/test_mimic3_cf.py @@ -0,0 +1,57 @@ +from pyhealth.datasets import MIMIC3CirculatoryFailureDataset +from pyhealth.tasks import CirculatoryFailurePredictionTask + + +class DummyMIMIC3CFDataset(MIMIC3CirculatoryFailureDataset): + """Small synthetic subclass for fast unit testing.""" + + def __init__(self): + # do not call super().__init__ because we 
don't want real files + pass + + def load_cohort(self): + return [ + { + "patient_id": 1, + "gender": "F", + "hadm_id": 100, + "icustay_id": 1001, + "admittime": "2150-01-01 00:00:00", + "intime": "2150-01-01 00:00:00", + "outtime": "2150-01-02 00:00:00", + } + ] + + def get_patient_by_icustay_id(self, icustay_id: int): + if icustay_id != 1001: + return None + + return { + "patient_id": 1, + "icustay_id": 1001, + "gender": "F", + "intime": "2150-01-01 00:00:00", + "outtime": "2150-01-02 00:00:00", + "first_failure_time": "2150-01-01 11:00:00", + "time_series": [ + {"charttime": "2150-01-01 00:00:00", "map": 80.0}, + {"charttime": "2150-01-01 01:00:00", "map": 78.0}, + {"charttime": "2150-01-01 10:00:00", "map": 70.0}, + {"charttime": "2150-01-01 11:00:00", "map": 60.0}, + ], + } + + +def test_set_task_returns_samples(): + dataset = DummyMIMIC3CFDataset() + task = CirculatoryFailurePredictionTask(prediction_window_hours=12) + + samples = dataset.set_task(task) + + assert isinstance(samples, list) + assert len(samples) == 4 + assert samples[0]["patient_id"] == 1 + assert samples[0]["icustay_id"] == 1001 + assert samples[0]["features"]["map"] == 80.0 + assert samples[0]["label"] == 1 + assert samples[-1]["label"] == 0 \ No newline at end of file