diff --git a/README.md b/README.md index c78b47ea..e6a37327 100644 --- a/README.md +++ b/README.md @@ -87,17 +87,25 @@ The cell lines and perturbations specified in the TOML should match the values a you can use the `tx predict` command: ```bash -state tx predict --output_dir $HOME/state/test/ --checkpoint final.ckpt +state tx predict \ + --output-dir $HOME/state/test/ \ + --checkpoint final.ckpt ``` -It will look in the `output_dir` above, for a `checkpoints` folder. +It will look in the `output-dir` above, for a `checkpoints` folder. If you instead want to use a trained checkpoint for inference (e.g. on data not specified) in the TOML file: ```bash -state tx infer --output $HOME/state/test/ --output_dir /path/to/model/ --checkpoint /path/to/model/final.ckpt --adata /path/to/anndata/processed.h5 --pert_col gene --embed_key X_hvg +state tx infer \ + --output $HOME/state/test/ \ + --model-dir /path/to/model/ \ + --checkpoint /path/to/model/final.ckpt \ + --adata /path/to/anndata/processed.h5 \ + --pert-col gene \ + --embed-key X_hvg ``` Here, `/path/to/model/` is the folder downloaded from [HuggingFace](https://huggingface.co/arcinstitute). 
@@ -108,13 +116,13 @@ State provides two preprocessing commands to prepare data for training and infer #### Training Data Preprocessing -Use `preprocess_train` to normalize, log-transform, and select highly variable genes from your training data: +Use `preprocess-train` to normalize, log-transform, and select highly variable genes from your training data: ```bash -state tx preprocess_train \ +state tx preprocess-train \ --adata /path/to/raw_data.h5ad \ --output /path/to/preprocessed_training_data.h5ad \ - --num_hvgs 2000 + --num-hvgs 2000 ``` This command: @@ -125,14 +133,14 @@ This command: #### Inference Data Preprocessing -Use `preprocess_infer` to create a "control template" for model inference: +Use `preprocess-infer` to create a "control template" for model inference: ```bash -state tx preprocess_infer \ +state tx preprocess-infer \ --adata /path/to/real_data.h5ad \ --output /path/to/control_template.h5ad \ - --control_condition "DMSO" \ - --pert_col "treatment" \ + --control-condition "DMSO" \ + --pert-col "treatment" \ --seed 42 ``` @@ -301,16 +309,19 @@ state emb transform \ ``` Running this command multiple times with the same lancedb appends the new data to the provided database. +Existing cell records will be updated with the new embeddings. #### Query the database +> For this example, we will use the same dataset (SRX27532045), so the top hit should be the same cell. 
+ Obtain the embeddings: ```bash state emb transform \ --model-folder /large_storage/ctc/userspace/aadduri/SE-600M \ - --input /large_storage/ctc/public/scBasecamp/GeneFull_Ex50pAS/GeneFull_Ex50pAS/Homo_sapiens/SRX27532046.h5ad \ - --output tmp/SRX27532046.h5ad \ + --input /large_storage/ctc/public/scBasecamp/GeneFull_Ex50pAS/GeneFull_Ex50pAS/Homo_sapiens/SRX27532045.h5ad \ + --output tmp/SRX27532045.h5ad \ --gene-column gene_symbols ``` @@ -319,9 +330,41 @@ Query the database with the embeddings: ```bash state emb query \ --lancedb tmp/state_embeddings.lancedb \ - --input tmp/SRX27532046.h5ad \ + --input tmp/SRX27532045.h5ad \ --output tmp/similar_cells.csv \ --k 3 +``` + +Output: + - `query_cell_id` : The cell id of the query cell + - `subject_rank` : The rank of the hit cell for this query (1 = smallest distance to the query cell) + - `query_subject_distance` : The distance between the query and subject cell vectors + - `subject_cell_id` : The cell id of the hit cell + - `subject_dataset` : The dataset of the hit cell + - `embedding_key` : The embedding key of the hit cell + - `...` : Other metadata columns stored in the database for the hit cell + +#### Summarize the vector database + +Get comprehensive statistics about your vector database: + +```bash +state emb vectordb \ + --lancedb tmp/state_embeddings.lancedb \ + --format table +``` + +Output formats: + - `table` (default): Human-readable table format with emojis + - `json`: Machine-readable JSON format + - `yaml`: YAML format + +The summary includes: + - Total number of cells and datasets + - Number of unique embedding keys + - Embedding vector dimensions + - Cell count breakdown by dataset + - List of all embedding keys # Singularity diff --git a/src/state/__main__.py b/src/state/__main__.py index 12b2b600..5392ea88 100644 --- a/src/state/__main__.py +++ b/src/state/__main__.py @@ -3,12 +3,14 @@ from hydra import compose, initialize from omegaconf import DictConfig +from ._cli._utils import CustomFormatter from ._cli import ( add_arguments_emb, 
add_arguments_tx, run_emb_fit, run_emb_transform, run_emb_query, + run_emb_vectordb, run_tx_infer, run_tx_predict, run_tx_preprocess_infer, @@ -19,10 +21,25 @@ def get_args() -> tuple[ap.Namespace, list[str]]: """Parse known args and return remaining args for Hydra overrides""" - parser = ap.ArgumentParser() + desc = """description: + Entry point for the STATE command line interface. + Use these commands to train models, compute embeddings, and run inference. + Run `state --help` for details on each command.""" + parser = ap.ArgumentParser(description=desc, formatter_class=CustomFormatter) + parser.add_argument("--log-level", type=str, default="INFO", choices=["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"], help="Logging level") subparsers = parser.add_subparsers(required=True, dest="command") - add_arguments_emb(subparsers.add_parser("emb")) - add_arguments_tx(subparsers.add_parser("tx")) + + # emb + desc = """description: + Commands for generating and querying STATE embeddings. + See `state emb --help` for subcommand options.""" + add_arguments_emb(subparsers.add_parser("emb", description=desc, formatter_class=CustomFormatter)) + + # tx + desc = """description: + Train and evaluate perturbation models with Hydra configuration. 
+ Overrides can be passed via `state tx param=value`.""" + add_arguments_tx(subparsers.add_parser("tx", description=desc, formatter_class=CustomFormatter)) # Use parse_known_args to get both known args and remaining args return parser.parse_args() @@ -62,21 +79,20 @@ def show_hydra_help(method: str): print() print("Usage examples:") print(" Override single parameter:") - print(f" uv run state tx train data.batch_size=64") + print(" uv run state tx train data.batch_size=64") print() print(" Override nested parameter:") - print(f" uv run state tx train model.kwargs.hidden_dim=512") + print(" uv run state tx train model.kwargs.hidden_dim=512") print() print(" Override multiple parameters:") - print(f" uv run state tx train data.batch_size=64 training.lr=0.001") + print(" uv run state tx train data.batch_size=64 training.lr=0.001") print() print(" Change config group:") - print(f" uv run state tx train data=custom_data model=custom_model") + print(" uv run state tx train data=custom_data model=custom_model") print() print("Available config groups:") # Show available config groups - import os from pathlib import Path config_dir = Path(__file__).parent / "configs" @@ -103,6 +119,8 @@ def main(): run_emb_transform(args) case "query": run_emb_query(args) + case "vectordb": + run_emb_vectordb(args) case "tx": match args.subcommand: case "train": @@ -112,19 +130,19 @@ def main(): else: # Load Hydra config with overrides for sets training cfg = load_hydra_config("tx", args.hydra_overrides) - run_tx_train(cfg) + run_tx_train(cfg, args) case "predict": # For now, predict uses argparse and not hydra run_tx_predict(args) case "infer": # Run inference using argparse, similar to predict run_tx_infer(args) - case "preprocess_train": + case "preprocess-train": # Run preprocessing using argparse - run_tx_preprocess_train(args.adata, args.output, args.num_hvgs) - case "preprocess_infer": + run_tx_preprocess_train(args.adata, args.output, args.num_hvgs, args.log_level) + case 
"preprocess-infer": # Run inference preprocessing using argparse - run_tx_preprocess_infer(args.adata, args.output, args.control_condition, args.pert_col, args.seed) + run_tx_preprocess_infer(args.adata, args.output, args.control_condition, args.pert_col, args.seed, args.log_level) if __name__ == "__main__": diff --git a/src/state/_cli/__init__.py b/src/state/_cli/__init__.py index a17d4ac3..fcfc9bd5 100644 --- a/src/state/_cli/__init__.py +++ b/src/state/_cli/__init__.py @@ -1,4 +1,4 @@ -from ._emb import add_arguments_emb, run_emb_fit, run_emb_transform, run_emb_query +from ._emb import add_arguments_emb, run_emb_fit, run_emb_transform, run_emb_query, run_emb_vectordb from ._tx import add_arguments_tx, run_tx_infer, run_tx_predict, run_tx_preprocess_infer, run_tx_preprocess_train, run_tx_train __all__ = [ @@ -12,4 +12,5 @@ "run_emb_fit", "run_emb_query", "run_emb_transform", + "run_emb_vectordb", ] diff --git a/src/state/_cli/_emb/__init__.py b/src/state/_cli/_emb/__init__.py index 1cda4d4d..5cbfa50b 100644 --- a/src/state/_cli/_emb/__init__.py +++ b/src/state/_cli/_emb/__init__.py @@ -3,13 +3,42 @@ from ._fit import add_arguments_fit, run_emb_fit from ._transform import add_arguments_transform, run_emb_transform from ._query import add_arguments_query, run_emb_query +from ._vectordb import add_arguments_vectordb, run_emb_vectordb +from .._utils import CustomFormatter -__all__ = ["run_emb_fit", "run_emb_transform", "run_emb_query", "add_arguments_emb"] +__all__ = ["run_emb_fit", "run_emb_transform", "run_emb_query", "run_emb_vectordb", "add_arguments_emb"] def add_arguments_emb(parser: ap.ArgumentParser): - """""" + """Add embedding commands to the parser""" subparsers = parser.add_subparsers(required=True, dest="subcommand") - add_arguments_fit(subparsers.add_parser("fit")) - add_arguments_transform(subparsers.add_parser("transform")) - add_arguments_query(subparsers.add_parser("query")) + + # fit + desc = """description: + Train an embedding model on a 
reference dataset. + Provide Hydra overrides to adjust training parameters.""" + add_arguments_fit( + subparsers.add_parser("fit", description=desc, formatter_class=CustomFormatter) + ) + + # transform + desc = """description: + Encode an input dataset with a trained embedding model. + Results can be saved locally and inserted into a LanceDB vector store.""" + add_arguments_transform( + subparsers.add_parser("transform", description=desc, formatter_class=CustomFormatter) + ) + + # query + desc = """description: + Search a LanceDB vector store (created with `transform`) for cells with similar embeddings.""" + add_arguments_query( + subparsers.add_parser("query", description=desc, formatter_class=CustomFormatter) + ) + + # vectordb + desc = """description: + Get summary statistics about a LanceDB vector database including datasets, cell counts, and embeddings.""" + add_arguments_vectordb( + subparsers.add_parser("vectordb", description=desc, formatter_class=CustomFormatter) + ) diff --git a/src/state/_cli/_emb/_fit.py b/src/state/_cli/_emb/_fit.py index fddbb032..4e010458 100644 --- a/src/state/_cli/_emb/_fit.py +++ b/src/state/_cli/_emb/_fit.py @@ -21,6 +21,7 @@ def run_emb_fit(cfg, args): from ...emb.train.trainer import main as trainer_main + logging.basicConfig(level=getattr(logging, args.log_level, logging.INFO)) log = logging.getLogger(__name__) # Load the base configuration diff --git a/src/state/_cli/_emb/_query.py b/src/state/_cli/_emb/_query.py index a10fe1e1..39c26849 100644 --- a/src/state/_cli/_emb/_query.py +++ b/src/state/_cli/_emb/_query.py @@ -1,3 +1,4 @@ +import os import argparse as ap import logging import pandas as pd @@ -14,15 +15,16 @@ def add_arguments_query(parser: ap.ArgumentParser): parser.add_argument("--embed-key", default="X_state", help="Key containing embeddings in input file") parser.add_argument("--exclude-distances", action="store_true", help="Exclude vector distances in results") - parser.add_argument("--filter", type=str, 
help="Filter expression (e.g., 'cell_type==\"B cell\"')") - parser.add_argument("--batch-size", type=int, default=100, - help="Batch size for query operations") + parser.add_argument("--filter", type=str, + help="Filter expression (e.g., 'cell_type==\"B cell\"', assuming a 'cell_type' column exists in the database)") + parser.add_argument("--batch-size", type=int, default=100, help="Batch size for query operations") + parser.add_argument("--max-workers", type=int, default=os.cpu_count(), help="Maximum number of workers for parallel processing") def run_emb_query(args: ap.ArgumentParser): """ Query a LanceDB database for similar cells. """ - logging.basicConfig(level=logging.INFO) + logging.basicConfig(level=getattr(logging, args.log_level, logging.INFO)) logger = logging.getLogger(__name__) from ...emb.vectordb import StateVectorDB @@ -59,6 +61,7 @@ def run_emb_query(args: ap.ArgumentParser): filter=args.filter, include_distance=not args.exclude_distances, batch_size=args.batch_size, + max_workers=args.max_workers, show_progress=True ) @@ -66,11 +69,21 @@ def run_emb_query(args: ap.ArgumentParser): all_results = [] for query_idx, result_df in enumerate(results_list): result_df['query_cell_id'] = query_adata.obs.index[query_idx] - result_df['query_rank'] = range(1, len(result_df) + 1) + result_df['subject_rank'] = range(1, len(result_df) + 1) all_results.append(result_df) # Combine results final_results = pd.concat(all_results, ignore_index=True) + + # Format the results table + ## Move certain columns to the start, if they exist + to_move = ['query_cell_id', 'subject_rank', 'query_subject_distance', 'cell_id', 'dataset', 'embedding_key'] + to_move = [col for col in to_move if col in final_results.columns] + final_results = final_results[to_move + [col for col in final_results.columns if col not in to_move]] + ## Rename `cell_id` to 'subject_cell_id' + rn_dict = {'cell_id': 'subject_cell_id', 'dataset': 'subject_dataset'} + rn_dict = {k:v for k,v in rn_dict.items() 
if k in final_results.columns} + final_results = final_results.rename(columns=rn_dict) # Save results output_path = Path(args.output) @@ -96,11 +109,11 @@ def create_result_anndata(query_adata, results_df, k): cell_ids_array = np.array(cell_ids_pivot.values, dtype=str) # Handle distances - convert to float64 and handle missing values - if 'vector_distance' in results_df: + if 'query_subject_distance' in results_df: distances_pivot = results_df.pivot( index='query_cell_id', columns='query_rank', - values='vector_distance' + values='query_subject_distance' ) distances_array = np.array(distances_pivot.values, dtype=np.float64) else: @@ -118,5 +131,5 @@ def create_result_anndata(query_adata, results_df, k): # Create result anndata result_adata = query_adata.copy() result_adata.uns['lancedb_query_results'] = uns_data - - return result_adata \ No newline at end of file + + return result_adata diff --git a/src/state/_cli/_emb/_transform.py b/src/state/_cli/_emb/_transform.py index 5d4616a3..b04e6014 100644 --- a/src/state/_cli/_emb/_transform.py +++ b/src/state/_cli/_emb/_transform.py @@ -9,13 +9,12 @@ def add_arguments_transform(parser: ap.ArgumentParser): parser.add_argument("--output", required=False, help="Path to output embedded anndata file (h5ad)") parser.add_argument("--embed-key", default="X_state", help="Name of key to store embeddings") parser.add_argument("--gene-column", default="gene_name", help="Name of column in var dataframe to use for gene names") - parser.add_argument("--lancedb", type=str, help="Path to LanceDB database for vector storage") - parser.add_argument("--lancedb-update", action="store_true", - help="Update existing entries in LanceDB (default: append)") - parser.add_argument("--lancedb-batch-size", type=int, default=1000, - help="Batch size for LanceDB operations") - - + parser.add_argument("--dataset-name", type=str, default=None, help="Name of the dataset. 
If None, the input file name will be used.") + lancedb_group = parser.add_argument_group("Vector database options") + lancedb_group.add_argument("--lancedb", type=str, help="Path to LanceDB database for vector storage") + lancedb_group.add_argument("--lancedb-batch-size", type=int, default=1000, + help="Batch size for LanceDB operations") + def run_emb_transform(args: ap.ArgumentParser): """ Compute embeddings for an input anndata file using a pre-trained VCI model checkpoint. @@ -27,7 +26,7 @@ def run_emb_transform(args: ap.ArgumentParser): import torch from omegaconf import OmegaConf - logging.basicConfig(level=logging.INFO) + logging.basicConfig(level=getattr(logging, args.log_level, logging.INFO)) logger = logging.getLogger(__name__) from ...emb.inference import Inference @@ -79,8 +78,8 @@ def run_emb_transform(args: ap.ArgumentParser): output_adata_path=args.output, emb_key=args.embed_key, gene_column=args.gene_column, + dataset_name=args.dataset_name, lancedb_path=args.lancedb, - update_lancedb=args.lancedb_update, lancedb_batch_size=args.lancedb_batch_size, ) diff --git a/src/state/_cli/_emb/_vectordb.py b/src/state/_cli/_emb/_vectordb.py new file mode 100644 index 00000000..af159ef2 --- /dev/null +++ b/src/state/_cli/_emb/_vectordb.py @@ -0,0 +1,79 @@ +import argparse as ap +import logging +import json +import yaml +from typing import Dict, Any + + +def add_arguments_vectordb(parser: ap.ArgumentParser): + """Add arguments for state embedding vectordb CLI.""" + parser.add_argument("--lancedb", required=True, help="Path to existing LanceDB database") + parser.add_argument("--format", choices=["json", "yaml", "table"], default="table", + help="Output format for database summary") + + +def run_emb_vectordb(args: ap.ArgumentParser): + """ + Get summary statistics about a LanceDB vector database. 
+ """ + logging.basicConfig(level=getattr(logging, args.log_level, logging.INFO)) + logger = logging.getLogger(__name__) + + from ...emb.vectordb import StateVectorDB + + # Connect to database + logger.info(f"Connecting to database at {args.lancedb}") + vector_db = StateVectorDB(args.lancedb) + + # Get database summary + summary = vector_db.get_database_summary() + + # Output in requested format + if args.format == "json": + print(json.dumps(summary, indent=2)) + elif args.format == "yaml": + print(yaml.dump(summary, default_flow_style=False)) + elif args.format == "table": + _print_table_summary(summary) + + logger.info("Database summary completed successfully!") + + +def _print_table_summary(summary: Dict[str, Any]) -> None: + """Print database summary in a nice table format.""" + if not summary["table_exists"]: + print("❌ Database table does not exist") + return + + if summary["num_cells"] == 0: + print("⚠️ Database table exists but is empty") + return + + # Print header + print("=" * 60) + print("📊 STATE VECTOR DATABASE SUMMARY") + print("=" * 60) + + # Basic stats + print(f"🔢 Total cells: {summary['num_cells']:,}") + print(f"📦 Total datasets: {summary['num_datasets']}") + print(f"🔑 Embedding keys: {summary['num_embedding_keys']}") + print(f"📐 Embedding dimension: {summary['embedding_dim']}") + print() + + # Datasets breakdown + if summary["datasets"]: + print("📂 DATASETS:") + for dataset in summary["datasets"]: + cell_count = summary["cells_per_dataset"].get(dataset, 0) + print(f" • {dataset}: {cell_count:,} cells") + print() + + # Embedding keys + if summary["embedding_keys"]: + print("🗝️ EMBEDDING KEYS:") + for key in summary["embedding_keys"]: + print(f" • {key}") + print() + + print("=" * 60) \ No newline at end of file diff --git a/src/state/_cli/_tx/__init__.py b/src/state/_cli/_tx/__init__.py index e44446cd..f901958c 100644 --- a/src/state/_cli/_tx/__init__.py +++ b/src/state/_cli/_tx/__init__.py @@ -5,15 +5,45 @@ from ._preprocess_infer import 
add_arguments_preprocess_infer, run_tx_preprocess_infer from ._preprocess_train import add_arguments_preprocess_train, run_tx_preprocess_train from ._train import add_arguments_train, run_tx_train +from .._utils import CustomFormatter __all__ = ["run_tx_train", "run_tx_predict", "run_tx_infer", "run_tx_preprocess_train", "run_tx_preprocess_infer", "add_arguments_tx"] def add_arguments_tx(parser: ap.ArgumentParser): - """""" + """Add transcriptomic commands to the parser""" subparsers = parser.add_subparsers(required=True, dest="subcommand") - add_arguments_train(subparsers.add_parser("train", add_help=False)) - add_arguments_predict(subparsers.add_parser("predict")) - add_arguments_infer(subparsers.add_parser("infer")) - add_arguments_preprocess_train(subparsers.add_parser("preprocess_train")) - add_arguments_preprocess_infer(subparsers.add_parser("preprocess_infer")) + + # Train + desc = """description: + Train a perturbation model using a Hydra configuration. + Provide overrides to customize training, e.g.: + `state tx train data.batch_size=32`""" + add_arguments_train( + subparsers.add_parser("train", description=desc, formatter_class=CustomFormatter) + ) + + # Predict + desc = """description: + Generate predictions from a trained model and optionally compute evaluation metrics.""" + add_arguments_predict( + subparsers.add_parser("predict", description=desc, formatter_class=CustomFormatter) + ) + + # Infer + desc = """description: + Run inference on new samples using a trained model.""" + add_arguments_infer( + subparsers.add_parser("infer", description=desc, formatter_class=CustomFormatter) + ) + + # Preprocess: train + desc = """description: + Preprocess a dataset for training.""" + add_arguments_preprocess_train(subparsers.add_parser("preprocess-train", description=desc, formatter_class=CustomFormatter)) + + # Preprocess: infer + desc = """description: + Preprocess a dataset for inference.""" + 
add_arguments_preprocess_infer(subparsers.add_parser("preprocess-infer", description=desc, formatter_class=CustomFormatter)) + diff --git a/src/state/_cli/_tx/_infer.py b/src/state/_cli/_tx/_infer.py index 3babe7bc..4bb185a1 100644 --- a/src/state/_cli/_tx/_infer.py +++ b/src/state/_cli/_tx/_infer.py @@ -9,24 +9,24 @@ def add_arguments_infer(parser: argparse.ArgumentParser): help="Path to model checkpoint (.ckpt). If not provided, will use model_dir/checkpoints/final.ckpt", ) parser.add_argument("--adata", type=str, required=True, help="Path to input AnnData file (.h5ad)") - parser.add_argument("--embed_key", type=str, default=None, help="Key in adata.obsm for input features") + parser.add_argument("--embed-key", type=str, default=None, help="Key in adata.obsm for input features") parser.add_argument( - "--pert_col", type=str, default="drugname_drugconc", help="Column in adata.obs for perturbation labels" + "--pert-col", type=str, default="drugname_drugconc", help="Column in adata.obs for perturbation labels" ) parser.add_argument("--output", type=str, default=None, help="Path to output AnnData file (.h5ad)") parser.add_argument( - "--model_dir", + "--model-dir", type=str, required=True, help="Path to the model_dir containing the config.yaml file and the pert_onehot_map.pt file that was saved during training.", ) parser.add_argument( - "--celltype_col", type=str, default=None, help="Column in adata.obs for cell type labels (optional)" + "--celltype-col", type=str, default=None, help="Column in adata.obs for cell type labels (optional)" ) parser.add_argument( "--celltypes", type=str, default=None, help="Comma-separated list of cell types to include (optional)" ) - parser.add_argument("--batch_size", type=int, default=1000, help="Batch size for inference (default: 1000)") + parser.add_argument("--batch-size", type=int, default=1000, help="Batch size for inference") def run_tx_infer(args): @@ -42,7 +42,7 @@ def run_tx_infer(args): from ...tx.models.state_transition 
import StateTransitionPerturbationModel - logging.basicConfig(level=logging.INFO) + logging.basicConfig(level=getattr(logging, args.log_level, logging.INFO)) logger = logging.getLogger(__name__) def load_config(cfg_path: str) -> dict: diff --git a/src/state/_cli/_tx/_predict.py b/src/state/_cli/_tx/_predict.py index 2115a815..7bdd1129 100644 --- a/src/state/_cli/_tx/_predict.py +++ b/src/state/_cli/_tx/_predict.py @@ -59,7 +59,7 @@ def run_tx_predict(args: ap.ArgumentParser): from cell_load.data_modules import PerturbationDataModule from tqdm import tqdm - logging.basicConfig(level=logging.INFO) + logging.basicConfig(level=getattr(logging, args.log_level, logging.INFO)) logger = logging.getLogger(__name__) torch.multiprocessing.set_sharing_strategy("file_system") diff --git a/src/state/_cli/_tx/_preprocess_infer.py b/src/state/_cli/_tx/_preprocess_infer.py index ed15b529..9f0aa44e 100644 --- a/src/state/_cli/_tx/_preprocess_infer.py +++ b/src/state/_cli/_tx/_preprocess_infer.py @@ -22,13 +22,13 @@ def add_arguments_preprocess_infer(parser: ap.ArgumentParser): help="Path to output preprocessed AnnData file (.h5ad)", ) parser.add_argument( - "--control_condition", + "--control-condition", type=str, required=True, help="Control condition identifier (e.g., \"[('DMSO_TF', 0.0, 'uM')]\")", ) parser.add_argument( - "--pert_col", + "--pert-col", type=str, required=True, help="Column name containing perturbation information (e.g., 'drugname_drugconc')", @@ -37,7 +37,7 @@ def add_arguments_preprocess_infer(parser: ap.ArgumentParser): "--seed", type=int, default=42, - help="Random seed for reproducibility (default: 42)", + help="Random seed for reproducibility", ) @@ -46,7 +46,8 @@ def run_tx_preprocess_infer( output_path: str, control_condition: str, pert_col: str, - seed: int = 42 + seed: int = 42, + log_level: str = "INFO" ): """ Preprocess inference data by replacing perturbed cells with control expression. 
@@ -62,6 +63,7 @@ def run_tx_preprocess_infer( pert_col: Column name containing perturbation information seed: Random seed for reproducibility """ + logging.basicConfig(level=getattr(logging, log_level, logging.INFO)) logger.info(f"Loading AnnData from {adata_path}") adata = ad.read_h5ad(adata_path) diff --git a/src/state/_cli/_tx/_preprocess_train.py b/src/state/_cli/_tx/_preprocess_train.py index f2c0b304..987e516f 100644 --- a/src/state/_cli/_tx/_preprocess_train.py +++ b/src/state/_cli/_tx/_preprocess_train.py @@ -22,14 +22,14 @@ def add_arguments_preprocess_train(parser: ap.ArgumentParser): help="Path to output preprocessed AnnData file (.h5ad)", ) parser.add_argument( - "--num_hvgs", + "--num-hvgs", type=int, required=True, help="Number of highly variable genes to select", ) -def run_tx_preprocess_train(adata_path: str, output_path: str, num_hvgs: int): +def run_tx_preprocess_train(adata_path: str, output_path: str, num_hvgs: int, log_level: str): """ Preprocess training data by normalizing, log-transforming, and selecting highly variable genes. 
@@ -38,6 +38,7 @@ def run_tx_preprocess_train(adata_path: str, output_path: str, num_hvgs: int): output_path: Path to save preprocessed AnnData file num_hvgs: Number of highly variable genes to select """ + logging.basicConfig(level=getattr(logging, log_level, logging.INFO)) logger.info(f"Loading AnnData from {adata_path}") adata = ad.read_h5ad(adata_path) diff --git a/src/state/_cli/_tx/_train.py b/src/state/_cli/_tx/_train.py index 7e1b3531..0ae05f5f 100644 --- a/src/state/_cli/_tx/_train.py +++ b/src/state/_cli/_tx/_train.py @@ -6,11 +6,9 @@ def add_arguments_train(parser: ap.ArgumentParser): # Allow remaining args to be passed through to Hydra parser.add_argument("hydra_overrides", nargs="*", help="Hydra configuration overrides (e.g., data.batch_size=32)") - # Add custom help handler - parser.add_argument("--help", "-h", action="store_true", help="Show configuration help with all parameters") -def run_tx_train(cfg: DictConfig): +def run_tx_train(cfg: DictConfig, args: ap.ArgumentParser): import json import logging import os @@ -29,6 +27,7 @@ def run_tx_train(cfg: DictConfig): from ...tx.callbacks import BatchSpeedMonitorCallback from ...tx.utils import get_checkpoint_callbacks, get_lightning_module, get_loggers + logging.basicConfig(level=getattr(logging, args.log_level, logging.INFO)) logger = logging.getLogger(__name__) torch.set_float32_matmul_precision("medium") diff --git a/src/state/_cli/_utils.py b/src/state/_cli/_utils.py new file mode 100644 index 00000000..0ef6f032 --- /dev/null +++ b/src/state/_cli/_utils.py @@ -0,0 +1,8 @@ +import argparse + +class CustomFormatter( + argparse.ArgumentDefaultsHelpFormatter, argparse.RawDescriptionHelpFormatter +): + """Combine default and raw formatting styles.""" + + pass diff --git a/src/state/emb/inference.py b/src/state/emb/inference.py index d6fc8571..20a45d22 100644 --- a/src/state/emb/inference.py +++ b/src/state/emb/inference.py @@ -150,17 +150,15 @@ def encode_adata( output_adata_path: str | None = None, 
emb_key: str = "X_emb", dataset_name: str | None = None, - batch_size: int = 32, lancedb_path: str | None = None, - update_lancedb: bool = False, lancedb_batch_size: int = 1000, gene_column: str = "gene_name", ): - shape_dict = self.__load_dataset_meta(input_adata_path) - adata = anndata.read_h5ad(input_adata_path) if dataset_name is None: dataset_name = Path(input_adata_path).stem - + shape_dict = self.__load_dataset_meta(input_adata_path) + adata = anndata.read_h5ad(input_adata_path) + # Convert to CSR format if needed adata = self._convert_to_csr(adata) @@ -202,7 +200,7 @@ def encode_adata( if lancedb_path is not None: from .vectordb import StateVectorDB - log.info(f"Saving embeddings to LanceDB at {lancedb_path}") + log.info(f"Saving embeddings to LanceDB at {lancedb_path} using dataset name: {dataset_name}") vector_db = StateVectorDB(lancedb_path) # Extract relevant metadata @@ -213,8 +211,8 @@ def encode_adata( embeddings=all_embeddings, metadata=metadata, embedding_key=emb_key, - dataset_name=dataset_name or Path(input_adata_path).stem, - batch_size=lancedb_batch_size + dataset_name=dataset_name, + batch_size=lancedb_batch_size, ) log.info(f"Successfully saved {len(all_embeddings)} embeddings to LanceDB") diff --git a/src/state/emb/vectordb.py b/src/state/emb/vectordb.py index 073866df..a57ce2ae 100644 --- a/src/state/emb/vectordb.py +++ b/src/state/emb/vectordb.py @@ -3,6 +3,8 @@ import pandas as pd from typing import Optional, List from pathlib import Path +from concurrent.futures import ThreadPoolExecutor, as_completed +from functools import partial class StateVectorDB: """Manages LanceDB operations for State embeddings.""" @@ -22,8 +24,8 @@ def create_or_update_table( metadata: pd.DataFrame, embedding_key: str = "X_state", dataset_name: Optional[str] = None, - batch_size: int = 1000 - ): + batch_size: int = 1000, + ) -> None: """Create or update the embeddings table. 
Args: @@ -40,9 +42,10 @@ def create_or_update_table( batch_data = [] for j in range(i, batch_end): + cell_id = metadata.index[j] record = { "vector": embeddings[j].tolist(), - "cell_id": metadata.index[j], + "cell_id": cell_id, "embedding_key": embedding_key, "dataset": dataset_name or "unknown", **{col: metadata.iloc[j][col] for col in metadata.columns} @@ -54,7 +57,12 @@ def create_or_update_table( # Create or append to table if self.table_name in self.db.table_names(): table = self.db.open_table(self.table_name) - table.add(data) + ( + table.merge_insert(["cell_id", "dataset"]) + .when_matched_update_all() + .when_not_matched_insert_all() + .execute(data) + ) else: self.db.create_table(self.table_name, data=data) @@ -90,16 +98,17 @@ def search( if columns: query = query.select(columns + ['_distance'] if include_distance else columns) + # convert to pandas results = query.to_pandas() # deal with _distance column if '_distance' in results.columns: if include_distance: - results = results.rename(columns={'_distance': 'query_distance'}) + results = results.rename(columns={'_distance': 'query_subject_distance'}) else: results = results.drop('_distance', axis=1) elif include_distance: - results['query_distance'] = 0.0 + results['query_subject_distance'] = 0.0 # drop vector column if include_vector is False if not include_vector and 'vector' in results.columns: @@ -107,17 +116,29 @@ def search( return results + def _search_single(self, query_vector: np.ndarray, k: int, filter: str | None, + include_distance: bool, include_vector: bool): + """Helper method for parallel search.""" + return self.search( + query_vector=query_vector, + k=k, + filter=filter, + include_distance=include_distance, + include_vector=include_vector, + ) + def batch_search( self, query_vectors: np.ndarray, k: int = 10, filter: str | None = None, include_distance: bool = True, - batch_size: int = 100, + include_vector: bool = False, + max_workers: int = 4, + batch_size: int = 1000, show_progress: 
bool = True, - include_vector: bool = False ): - """Batch search for multiple query vectors. + """Parallel batch search for multiple query vectors using ThreadPoolExecutor. Args: query_vectors: Array of query embedding vectors @@ -125,35 +146,57 @@ def batch_search( filter: Optional filter expression include_distance: Whether to include distances include_vector: Whether to include the query vector in the results - batch_size: Number of queries to process at once + max_workers: Maximum number of worker threads + batch_size: Number of queries to submit to executor at once show_progress: Show progress bar Returns: List of DataFrames with search results """ from tqdm import tqdm - results = [] - iterator = range(0, len(query_vectors), batch_size) + # Create a partial function with fixed parameters + search_func = partial( + self._search_single, + k=k, + filter=filter, + include_distance=include_distance, + include_vector=include_vector, + ) - if show_progress: - iterator = tqdm(iterator, desc="Searching") + results = [None] * len(query_vectors) - for i in iterator: - batch_end = min(i + batch_size, len(query_vectors)) - batch_queries = query_vectors[i:batch_end] + # Process in batches to manage memory and avoid overwhelming the database + with ThreadPoolExecutor(max_workers=max_workers) as executor: + total_processed = 0 - batch_results = [] - for query_vec in batch_queries: - result = self.search( - query_vector=query_vec, - k=k, - filter=filter, - include_distance=include_distance, - include_vector=include_vector, - ) - batch_results.append(result) + if show_progress: + pbar = tqdm(total=len(query_vectors), desc="Searching") + + for batch_start in range(0, len(query_vectors), batch_size): + batch_end = min(batch_start + batch_size, len(query_vectors)) + batch_vectors = query_vectors[batch_start:batch_end] + + # Submit batch to executor + future_to_index = { + executor.submit(search_func, batch_vectors[i]): batch_start + i + for i in range(len(batch_vectors)) + } + + 
# Collect results for this batch + for future in as_completed(future_to_index): + index = future_to_index[future] + try: + results[index] = future.result() + except Exception as e: + print(f"Query {index} failed: {e}") + results[index] = pd.DataFrame() # Empty result on error + + total_processed += 1 + if show_progress: + pbar.update(1) - results.extend(batch_results) + if show_progress: + pbar.close() return results @@ -167,4 +210,62 @@ def get_table_info(self): "num_rows": len(table), "columns": table.schema.names, "embedding_dim": len(table.to_pandas().iloc[0]['vector']) if len(table) > 0 else 0 - } \ No newline at end of file + } + + def get_database_summary(self) -> dict: + """Get comprehensive summary statistics about the database contents. + + Returns: + Dictionary containing database statistics including: + - num_cells: Total number of cells stored + - num_datasets: Number of unique datasets + - num_embedding_keys: Number of unique embedding keys + - datasets: List of dataset names + - embedding_keys: List of embedding key names + - cells_per_dataset: Dictionary mapping dataset to cell count + """ + if self.table_name not in self.db.table_names(): + return { + "num_cells": 0, + "num_datasets": 0, + "num_embedding_keys": 0, + "datasets": [], + "embedding_keys": [], + "cells_per_dataset": {}, + "table_exists": False + } + + table = self.db.open_table(self.table_name) + + # Get the full dataset to compute statistics + # For large tables, we might want to optimize this with SQL-like queries + df = table.to_pandas() + + if len(df) == 0: + return { + "num_cells": 0, + "num_datasets": 0, + "num_embedding_keys": 0, + "datasets": [], + "embedding_keys": [], + "cells_per_dataset": {}, + "table_exists": True + } + + # Calculate summary statistics + datasets = df['dataset'].unique().tolist() + embedding_keys = df['embedding_key'].unique().tolist() + cells_per_dataset = df['dataset'].value_counts().to_dict() + + summary = { + "num_cells": len(df), + "num_datasets": 
len(datasets), + "num_embedding_keys": len(embedding_keys), + "datasets": sorted(datasets), + "embedding_keys": sorted(embedding_keys), + "cells_per_dataset": cells_per_dataset, + "table_exists": True, + "embedding_dim": len(df.iloc[0]['vector']) if 'vector' in df.columns else 0 + } + + return summary \ No newline at end of file