# Source code for flashscenic.cistarget

"""
GPU-accelerated cisTarget pruning for flashscenic.

All operations are tensor-native for GPU acceleration.
"""

import torch
import numpy as np
import csv
from typing import Tuple, List, Dict, Optional, Union
from collections import defaultdict

ArrayLike = Union[np.ndarray, torch.Tensor]

def compute_recovery_aucs( rankings: torch.Tensor, module_gene_indices: torch.Tensor, rank_threshold: int, auc_threshold: float, weights: Optional[torch.Tensor] = None ) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Compute recovery curves and AUCs for every motif over a gene module.

    Fully vectorized: all motifs are handled in one scatter-add pass.

    Args:
        rankings: (n_motifs, n_genes) rank of each gene per motif (0-indexed).
        module_gene_indices: (n_module_genes,) indices of module genes.
        rank_threshold: Maximum rank considered in the recovery curve.
        auc_threshold: Fraction of the genome used for the AUC integral.
        weights: Optional (n_module_genes,) per-gene weights.

    Returns:
        rccs: (n_motifs, rank_threshold) cumulative recovery curves.
        aucs: (n_motifs,) normalized AUC values.
    """
    device = rankings.device
    n_motifs, total_genes = rankings.shape
    n_module_genes = len(module_gene_indices)

    if weights is None:
        weights = torch.ones(n_module_genes, device=device, dtype=torch.float32)

    # Cutoff derivation mirrors pySCENIC, clamped into [1, rank_threshold - 1].
    rank_cutoff = int(round(auc_threshold * total_genes)) - 1
    rank_cutoff = max(1, min(rank_cutoff, rank_threshold - 1))

    # Ranks of the module genes only: (n_motifs, n_module_genes).
    module_rankings = rankings[:, module_gene_indices]

    # Keep only hits inside the recovery window.
    within = module_rankings < rank_threshold  # (n_motifs, n_module_genes)

    # Flattened (motif, rank) positions for a single 1-D scatter-add.
    motif_ids = torch.arange(n_motifs, device=device).unsqueeze(1).expand(-1, n_module_genes)
    flat_pos = motif_ids[within] * rank_threshold + module_rankings[within].long()
    flat_w = weights.unsqueeze(0).expand(n_motifs, -1)[within]

    # Histogram of weights per (motif, rank), then cumulative sum -> recovery curves.
    hist = torch.zeros(n_motifs * rank_threshold, device=device, dtype=torch.float32)
    hist.scatter_add_(0, flat_pos, flat_w)
    rccs = hist.view(n_motifs, rank_threshold).cumsum(dim=1)

    # AUC normalized by the rectangle (rank_cutoff + 1) * total weight.
    denom = float((rank_cutoff + 1) * weights.sum().item())
    aucs = rccs[:, :rank_cutoff].sum(dim=1) / denom

    return rccs, aucs
def compute_nes(aucs: torch.Tensor) -> torch.Tensor:
    """
    Normalize AUC values into Normalized Enrichment Scores (NES).

    NES = (AUC - mean(AUC)) / std(AUC), using the population standard
    deviation (ddof=0) so results match pySCENIC. A degenerate (zero or
    non-finite) std yields an all-zero NES vector.
    """
    sigma = aucs.std(unbiased=False)
    # `not sigma > 0` also covers NaN std, matching the original guard.
    if not sigma > 0:
        return torch.zeros_like(aucs)
    return (aucs - aucs.mean()) / sigma
class MotifAnnotation:
    """
    Lightweight motif annotation storage without pandas.

    Stores motif annotations in a dictionary for fast lookup.
    Matches pyscenic's annotation filtering behavior.
    """

    def __init__(self):
        # Core data structure: dict[(tf, motif_id)] -> annotation_info
        self.annotations: Dict[Tuple[str, str], Dict] = {}
        # All motif IDs (for fast lookup)
        self.all_motif_ids: set = set()
        # All TF names (for reference)
        self.all_tf_names: set = set()

    @classmethod
    def load_from_file(
        cls,
        fname: str,
        motif_similarity_fdr: float = 0.001,
        orthologous_identity_threshold: float = 0.0,
        column_names: Optional[Tuple[str, ...]] = None
    ) -> 'MotifAnnotation':
        """
        Load motif annotations from a motif2TF snapshot file.

        Args:
            fname: Path to TSV annotation file
            motif_similarity_fdr: Maximum FDR threshold (default: 0.001)
            orthologous_identity_threshold: Minimum orthologous identity (default: 0.0)
            column_names: Optional tuple of column names to use. If None, reads from header.
                NOTE(review): currently unused -- column indices are always
                resolved from the file header below.

        Returns:
            MotifAnnotation instance
        """
        instance = cls()

        # Read file and collect all annotations
        annotations_list = []
        with open(fname, 'r', encoding='utf-8') as f:
            reader = csv.reader(f, delimiter='\t')

            # Read header
            header = next(reader, None)
            if header is None:
                raise ValueError("Empty annotation file")

            # Clean header: strip whitespace (keep # prefix for motif_id)
            header_clean = [col.strip() for col in header]

            # Find column indices
            try:
                # motif_id may have # prefix - check original header first
                if "#motif_id" in header_clean:
                    motif_id_idx = header_clean.index("#motif_id")
                elif "motif_id" in header_clean:
                    motif_id_idx = header_clean.index("motif_id")
                else:
                    raise ValueError("motif_id column not found")
                # Other columns don't have # prefix
                gene_name_idx = header_clean.index("gene_name")
                similarity_qval_idx = header_clean.index("motif_similarity_qvalue")
                ortho_identity_idx = header_clean.index("orthologous_identity")
                description_idx = header_clean.index("description")
            except ValueError as e:
                raise ValueError(f"Required column not found in header: {e}. "
                                 f"Header: {header_clean}")

            print(f"Reading annotation file: {fname}")
            print(f"Column indices: motif_id={motif_id_idx}, gene_name={gene_name_idx}, "
                  f"similarity_qvalue={similarity_qval_idx}, ortho_identity={ortho_identity_idx}, "
                  f"description={description_idx}")

            row_count = 0
            for row in reader:
                row_count += 1
                # Skip rows too short to contain every required column.
                if len(row) <= max(motif_id_idx, gene_name_idx, similarity_qval_idx,
                                   ortho_identity_idx, description_idx):
                    continue
                try:
                    motif_id = row[motif_id_idx].strip()
                    gene_name = row[gene_name_idx].strip()

                    # Skip empty rows
                    if not motif_id or not gene_name:
                        continue

                    # Parse similarity_qvalue (may be 0 or scientific notation)
                    similarity_qval_str = row[similarity_qval_idx].strip()
                    if not similarity_qval_str or similarity_qval_str.lower() == 'none':
                        similarity_qval = 0.0  # Treat 0 as direct annotation (best case)
                    else:
                        similarity_qval = float(similarity_qval_str)

                    # Parse orthologous_identity (may be empty or 0)
                    ortho_identity_str = row[ortho_identity_idx].strip()
                    if not ortho_identity_str or ortho_identity_str.lower() == 'none':
                        ortho_identity = float('nan')
                    else:
                        ortho_identity = float(ortho_identity_str)

                    description = row[description_idx].strip() if description_idx < len(row) else ""

                    # Apply filters
                    if similarity_qval > motif_similarity_fdr:
                        continue
                    # NaN orthologous identity means "not applicable": it passes the filter.
                    if not np.isnan(ortho_identity) and ortho_identity < orthologous_identity_threshold:
                        continue

                    annotations_list.append({
                        'key': (gene_name, motif_id),
                        'motif_similarity_qvalue': similarity_qval,
                        'orthologous_identity': ortho_identity,
                        'annotation': description,
                    })
                except (ValueError, IndexError) as e:
                    # Skip malformed rows
                    if row_count <= 5:  # Only warn for first few rows
                        print(f"Warning: Skipping malformed row {row_count}: {e}")
                    continue

        print(f"Read {row_count} rows, {len(annotations_list)} passed filters")

        # Sort and deduplicate: keep best annotation for each (TF, motif_id)
        # Match pyscenic: sort_values([COLUMN_NAME_MOTIF_SIMILARITY_QVALUE, COLUMN_NAME_ORTHOLOGOUS_IDENTITY],
        # ascending=[False, True]) then keep last (best)
        # ascending=[False, True] means:
        # - qvalue DESC: higher qvalue first, lower qvalue last (we want lower, so last is better)
        # - ortho ASC: lower ortho first, higher ortho last (we want higher, so last is better)
        # Then keep="last" means keep the last one after sorting, which is the best
        annotations_list.sort(
            key=lambda x: (
                -x['motif_similarity_qvalue'],  # DESC (negative for descending)
                x['orthologous_identity'] if not np.isnan(x['orthologous_identity']) else float('inf')  # ASC
            )
        )

        # Deduplicate: keep last (best) for each (TF, motif_id) pair
        # pyscenic uses: ~annotated_features.index.duplicated(keep="last")
        # After sorting DESC/ASC, the last entry for each key is the best
        seen_keys = set()
        for ann in reversed(annotations_list):
            # Process in reverse to keep first best (which is last after sort)
            key = ann['key']
            if key not in seen_keys:
                seen_keys.add(key)
                instance.annotations[key] = {
                    'motif_similarity_qvalue': ann['motif_similarity_qvalue'],
                    'orthologous_identity': ann['orthologous_identity'],
                    'annotation': ann['annotation'],
                }
                instance.all_motif_ids.add(key[1])  # motif_id
                instance.all_tf_names.add(key[0])  # TF name

        print(f"Loaded {len(instance.annotations)} motif annotations "
              f"({len(instance.all_motif_ids)} unique motifs, "
              f"{len(instance.all_tf_names)} unique TFs)")

        return instance

    def has_annotation(self, motif_id: str, tf_name: Optional[str] = None) -> bool:
        """
        Check if a motif has annotation.

        Args:
            motif_id: Motif ID
            tf_name: Optional TF name (if provided, checks (TF, motif) pair)

        Returns:
            True if annotation exists
        """
        if tf_name is not None:
            return (tf_name, motif_id) in self.annotations
        return motif_id in self.all_motif_ids

    def get_annotation(
        self,
        motif_id: str,
        tf_name: Optional[str] = None
    ) -> Optional[Dict]:
        """
        Get annotation for a motif.

        Args:
            motif_id: Motif ID
            tf_name: Optional TF name

        Returns:
            Annotation dict or None
        """
        if tf_name is not None:
            return self.annotations.get((tf_name, motif_id))
        # If no TF specified, return first match (or None).
        # NOTE(review): dict iteration order is insertion order, so "first"
        # here means first inserted, not best-scoring -- confirm acceptable.
        for key, ann in self.annotations.items():
            if key[1] == motif_id:
                return ann
        return None
def filter_by_annotations( result: Dict[str, torch.Tensor], motif_names: List[str], motif_annotations: Optional[MotifAnnotation], filter_for_annotation: bool = True, tf_name: Optional[str] = None ) -> Dict[str, torch.Tensor]:
    """
    Restrict enriched motifs to those carrying annotations (CPU-side lookup).

    Mirrors pyscenic: when `tf_name` is given, a motif survives only if it is
    annotated for that specific TF; otherwise any annotation suffices.

    Args:
        result: Pruning result dict with 'enriched_mask', 'nes', 'aucs', etc.
        motif_names: List of motif names (from database).
        motif_annotations: MotifAnnotation object (None = no filtering).
        filter_for_annotation: If True, only keep motifs with annotations.
        tf_name: TF name to filter for; None keeps any annotated motif.

    Returns:
        Filtered result dict (tensors stay on their original device).
    """
    # No-op when filtering is disabled or no annotations are loaded.
    if motif_annotations is None or not filter_for_annotation:
        return result

    prev_enriched = torch.where(result['enriched_mask'])[0].cpu().numpy()
    if len(prev_enriched) == 0:
        return result

    device = result['enriched_mask'].device

    # Build (on CPU) the mask of motifs that carry an acceptable annotation.
    annotated = torch.zeros(len(motif_names), dtype=torch.bool)
    for idx in prev_enriched:
        if motif_annotations.has_annotation(motif_names[idx], tf_name=tf_name):
            annotated[idx] = True

    kept_mask = result['enriched_mask'] & annotated.to(device)

    def _empty_leading_edge(width: int) -> None:
        # Reset per-motif outputs to empty tensors of the right width.
        result['leading_edge_masks'] = torch.zeros((0, width), dtype=torch.bool, device=device)
        result['rank_at_max'] = torch.zeros(0, dtype=torch.long, device=device)

    if kept_mask.sum().item() == 0:
        # Nothing survives: clear the enriched set and its per-motif outputs.
        result['enriched_mask'] = kept_mask
        width = 0
        if 'leading_edge_masks' in result and len(result['leading_edge_masks']) > 0:
            width = result['leading_edge_masks'].shape[1]
        _empty_leading_edge(width)
        return result

    result['enriched_mask'] = kept_mask

    # Rows of leading_edge_masks/rank_at_max are ordered by the OLD enriched
    # set; keep only the rows whose motif survived the annotation filter.
    surviving = set(torch.where(kept_mask)[0].cpu().numpy().tolist())
    if 'leading_edge_masks' in result and len(result['leading_edge_masks']) > 0:
        keep_rows = [i for i, old_idx in enumerate(prev_enriched) if old_idx in surviving]
        if keep_rows:
            result['leading_edge_masks'] = torch.stack(
                [result['leading_edge_masks'][i] for i in keep_rows])
            result['rank_at_max'] = torch.stack(
                [result['rank_at_max'][i] for i in keep_rows])
        else:
            _empty_leading_edge(result['leading_edge_masks'].shape[1])

    return result
def compute_leading_edge( rcc: torch.Tensor, avg2std_rcc: torch.Tensor, rankings: torch.Tensor, weights: torch.Tensor ) -> Tuple[torch.Tensor, torch.Tensor, int]:
    """
    Identify the leading-edge genes of an enriched motif.

    The leading edge is every module gene whose rank lies at or before the
    point where the motif's recovery curve most exceeds the avg + 2*std curve.

    Args:
        rcc: (rank_threshold,) recovery curve for this motif.
        avg2std_rcc: (rank_threshold,) average + 2*std recovery curve.
        rankings: (n_module_genes,) ranks of the module genes.
        weights: (n_module_genes,) gene weights.

    Returns:
        mask: (n_module_genes,) boolean leading-edge membership.
        leading_weights: weights of the leading-edge genes.
        rank_at_max: rank index of the maximum curve difference.
    """
    # Peak of the enrichment signal over the background curve.
    peak = int((rcc - avg2std_rcc).argmax().item())
    in_edge = rankings <= peak
    return in_edge, weights[in_edge], peak
def prune_single_module( rankings: torch.Tensor, module_gene_indices: torch.Tensor, rank_threshold: int = 5000, auc_threshold: float = 0.05, nes_threshold: float = 3.0, weights: Optional[torch.Tensor] = None ) -> Dict[str, torch.Tensor]:
    """
    Run cisTarget pruning for one gene module.

    All inputs and outputs are tensors on the same device.

    Args:
        rankings: (n_motifs, n_genes) ranking database tensor.
        module_gene_indices: (n_module_genes,) gene indices for this module.
        rank_threshold: Maximum rank for recovery curve.
        auc_threshold: Fraction of genome for AUC.
        nes_threshold: NES threshold for enrichment.
        weights: Optional (n_module_genes,) gene weights.

    Returns:
        Dict with keys:
            - enriched_mask: (n_motifs,) bool, which motifs are enriched
            - nes: (n_motifs,) NES scores
            - aucs: (n_motifs,) AUC scores
            - rccs: (n_motifs, rank_threshold) recovery curves
            - leading_edge_masks: (n_enriched, n_module_genes) leading edge per enriched motif
            - rank_at_max: (n_enriched,) rank at max per enriched motif
    """
    device = rankings.device
    n_module_genes = len(module_gene_indices)

    if weights is None:
        weights = torch.ones(n_module_genes, device=device, dtype=torch.float32)

    # Recovery curves + AUCs, then normalize AUCs into NES scores.
    rccs, aucs = compute_recovery_aucs(
        rankings, module_gene_indices, rank_threshold, auc_threshold, weights
    )
    nes = compute_nes(aucs)

    # Enrichment call: motifs whose NES clears the threshold.
    enriched_mask = nes >= nes_threshold
    enriched_indices = torch.where(enriched_mask)[0]

    result = {
        'enriched_mask': enriched_mask,
        'nes': nes,
        'aucs': aucs,
        'rccs': rccs,
    }

    if len(enriched_indices) == 0:
        # No enriched motifs: emit empty leading-edge outputs.
        result['leading_edge_masks'] = torch.zeros((0, n_module_genes), dtype=torch.bool, device=device)
        result['rank_at_max'] = torch.zeros(0, dtype=torch.long, device=device)
        return result

    # Background curve: mean + 2*std (population std) across all motifs.
    avg2std = rccs.mean(dim=0) + 2.0 * rccs.std(dim=0, unbiased=False)

    # Vectorized leading edge, gathering only the enriched motifs.
    selected_rankings = rankings[enriched_indices][:, module_gene_indices]
    gap = rccs[enriched_indices] - avg2std.unsqueeze(0)  # (n_enriched, rank_threshold)
    rank_at_max = gap.argmax(dim=1)                      # (n_enriched,)

    result['leading_edge_masks'] = selected_rankings <= rank_at_max.unsqueeze(1)
    result['rank_at_max'] = rank_at_max
    return result
class CisTargetPruner:
    """
    GPU-accelerated cisTarget pruning with support for single or multiple databases.

    Example (single database):
    ```python
    pruner = CisTargetPruner(device='cuda')
    pruner.load_database('rankings.feather')
    pruner.load_annotations('motifs.tbl', filter_for_annotation=True)
    # Prune with tensor input
    result = pruner.prune(module_gene_indices)
    ```

    Example (multiple databases):
    ```python
    pruner = CisTargetPruner(device='cuda')
    pruner.load_database(['db_500bp.feather', 'db_10kb.feather'])
    pruner.load_annotations('motifs.tbl')
    # Prune modules across all databases
    regulon_info = pruner.prune_modules(modules, tf_names, gene_names)
    ```
    """

    def __init__(
        self,
        rank_threshold: int = 5000,  # Match pyscenic CLI default
        auc_threshold: float = 0.05,
        nes_threshold: float = 3.0,
        device: str = 'cuda',
        min_genes_per_regulon: int = 0,  # Minimum genes per regulon (for multi-db mode)
        merge_strategy: str = 'union'  # 'union' or 'best' - how to merge regulons from multiple DBs
    ):
        self.rank_threshold = rank_threshold
        self.auc_threshold = auc_threshold
        self.nes_threshold = nes_threshold
        self.device = device
        self.min_genes_per_regulon = min_genes_per_regulon
        self.merge_strategy = merge_strategy

        # Single database mode
        self.rankings: Optional[torch.Tensor] = None
        self.n_motifs: int = 0
        self.n_genes: int = 0
        self.database_name: Optional[str] = None
        self.motif_names: Optional[List[str]] = None
        self.gene_names: Optional[List[str]] = None
        self.gene_to_idx: Optional[Dict[str, int]] = None

        # Multi-database mode: one nested single-db pruner per database
        self.pruners: List['CisTargetPruner'] = []
        self.database_names: List[str] = []
        self._multi_db_mode: bool = False

        # Motif annotations (shared across all databases)
        self.motif_annotations: Optional[MotifAnnotation] = None
        self.filter_for_annotation: bool = True

    def load_database(
        self,
        paths: Union[str, List[str]],
        database_names: Optional[Union[str, List[str]]] = None
    ):
        """
        Load ranking database(s) from feather file(s).

        A single path string selects single-database mode; a list of paths
        selects multi-database mode (one nested pruner per file).

        Args:
            paths: Path to .feather ranking database, or list of paths for multiple databases
            database_names: Optional name(s) for database(s) (defaults to filename(s))
        """
        import pyarrow.feather as pf
        import os

        # Check if single or multiple databases
        if isinstance(paths, str):
            # Single database mode
            self._multi_db_mode = False
            path = paths
            database_name = database_names if isinstance(database_names, str) else None
            if database_name is None:
                database_name = os.path.splitext(os.path.basename(path))[0]
            self.database_name = database_name

            table = pf.read_table(path)
            columns = table.column_names

            # The LAST column holds the index (motif/region names);
            # all preceding columns are genes.
            # NOTE(review): this assumes the feather layout puts the index
            # column last -- confirm against the database files in use.
            index_col = columns[-1]
            gene_cols = columns[:-1]

            # Extract data
            self.motif_names = table.column(index_col).to_pylist()
            self.gene_names = list(gene_cols)
            self.gene_to_idx = {g: i for i, g in enumerate(self.gene_names)}

            # Load rankings as numpy then convert to tensor
            ranking_data = table.select(gene_cols).to_pandas().values
            self.rankings = torch.tensor(
                ranking_data, device=self.device, dtype=torch.int32
            )
            self.n_motifs, self.n_genes = self.rankings.shape

            print(f"Loaded database '{database_name}': {self.n_motifs} motifs × {self.n_genes} genes")
        else:
            # Multiple databases mode
            self._multi_db_mode = True
            database_paths = paths
            if database_names is None:
                database_names = [os.path.splitext(os.path.basename(p))[0] for p in database_paths]
            elif isinstance(database_names, str):
                database_names = [database_names]

            self.pruners = []
            self.database_names = []
            for path, name in zip(database_paths, database_names):
                pruner = CisTargetPruner(
                    rank_threshold=self.rank_threshold,
                    auc_threshold=self.auc_threshold,
                    nes_threshold=self.nes_threshold,
                    device=self.device
                )
                # For single database, pass database_name as string
                pruner.load_database(path, database_names=name)
                # Share motif annotations if loaded
                if self.motif_annotations is not None:
                    pruner.motif_annotations = self.motif_annotations
                    pruner.filter_for_annotation = self.filter_for_annotation
                self.pruners.append(pruner)
                self.database_names.append(name)
            print(f"Loaded {len(self.pruners)} databases")

    def load_annotations(
        self,
        annotation_file: str,
        filter_for_annotation: bool = True,
        motif_similarity_fdr: float = 0.001,
        orthologous_identity_threshold: float = 0.0
    ):
        """
        Load motif annotations and enable filtering.

        Args:
            annotation_file: Path to motif annotation TSV file
            filter_for_annotation: If True, filter enriched motifs to keep only those with annotations
            motif_similarity_fdr: Maximum FDR threshold (default: 0.001)
            orthologous_identity_threshold: Minimum orthologous identity (default: 0.0)
        """
        self.motif_annotations = MotifAnnotation.load_from_file(
            annotation_file,
            motif_similarity_fdr=motif_similarity_fdr,
            orthologous_identity_threshold=orthologous_identity_threshold
        )
        self.filter_for_annotation = filter_for_annotation

        # Update sub-pruners if in multi-db mode
        if self._multi_db_mode:
            for pruner in self.pruners:
                pruner.motif_annotations = self.motif_annotations
                pruner.filter_for_annotation = self.filter_for_annotation

        print(f"Loaded motif annotations (filter_for_annotation={filter_for_annotation})")

    def load_from_tensor(
        self,
        rankings: ArrayLike,
        motif_names: Optional[List[str]] = None,
        gene_names: Optional[List[str]] = None
    ):
        """
        Load database from tensor/array directly.

        Args:
            rankings: (n_motifs, n_genes) ranking matrix
            motif_names: Optional list of motif names
            gene_names: Optional list of gene names
        """
        if isinstance(rankings, np.ndarray):
            rankings = torch.from_numpy(rankings)
        self.rankings = rankings.to(device=self.device, dtype=torch.int32)
        self.n_motifs, self.n_genes = self.rankings.shape
        self.motif_names = motif_names
        self.gene_names = gene_names
        if gene_names:
            self.gene_to_idx = {g: i for i, g in enumerate(gene_names)}

    def genes_to_indices(self, genes: List[str]) -> torch.Tensor:
        """Convert gene names to indices tensor (silently skipping unknown genes)."""
        if self.gene_to_idx is None:
            raise ValueError("Gene names not loaded. Use load_from_tensor with gene_names.")
        indices = [self.gene_to_idx[g] for g in genes if g in self.gene_to_idx]
        return torch.tensor(indices, device=self.device, dtype=torch.long)

    def prune(
        self,
        module_gene_indices: ArrayLike,
        weights: Optional[ArrayLike] = None,
        tf_name: Optional[str] = None
    ) -> Dict[str, torch.Tensor]:
        """
        Prune a single module (single database mode only).

        Args:
            module_gene_indices: (n_module_genes,) indices of genes in module
            weights: Optional (n_module_genes,) gene weights
            tf_name: TF name for TF-specific annotation filtering (matching
                pySCENIC behavior). If None, keeps motifs with any annotation.

        Returns:
            Dict with pruning results (all tensors)
        """
        if self._multi_db_mode:
            raise ValueError("prune() is for single database mode. Use prune_modules() for multiple databases.")
        if self.rankings is None:
            raise ValueError("Database not loaded")

        # Convert to tensor if needed
        if isinstance(module_gene_indices, np.ndarray):
            module_gene_indices = torch.from_numpy(module_gene_indices)
        module_gene_indices = module_gene_indices.to(device=self.device, dtype=torch.long)

        if weights is not None:
            if isinstance(weights, np.ndarray):
                weights = torch.from_numpy(weights)
            weights = weights.to(device=self.device, dtype=torch.float32)

        result = prune_single_module(
            self.rankings, module_gene_indices, self.rank_threshold, self.auc_threshold, self.nes_threshold, weights
        )

        # Apply annotation filtering if enabled
        if self.motif_annotations is not None and self.filter_for_annotation:
            result = filter_by_annotations(
                result, self.motif_names, self.motif_annotations,
                filter_for_annotation=self.filter_for_annotation, tf_name=tf_name
            )

        return result

    def prune_batch(
        self,
        modules: List[torch.Tensor],
        weights_list: Optional[List[torch.Tensor]] = None
    ) -> List[Dict[str, torch.Tensor]]:
        """
        Prune multiple modules sequentially.

        Args:
            modules: List of (n_genes_i,) tensors with gene indices
            weights_list: Optional list of weight tensors

        Returns:
            List of pruning result dicts
        """
        results = []
        for i, module_indices in enumerate(modules):
            weights = weights_list[i] if weights_list else None
            results.append(self.prune(module_indices, weights))
        return results

    def get_enriched_motif_names(self, result: Dict[str, torch.Tensor]) -> List[str]:
        """Get names of enriched motifs from pruning result."""
        if self.motif_names is None:
            raise ValueError("Motif names not loaded")
        enriched_indices = torch.where(result['enriched_mask'])[0].cpu().numpy()
        return [self.motif_names[i] for i in enriched_indices]

    def get_leading_edge_genes(
        self,
        result: Dict[str, torch.Tensor],
        module_gene_indices: torch.Tensor
    ) -> List[List[str]]:
        """
        Get leading edge gene names for each enriched motif.

        Args:
            result: Pruning result dict
            module_gene_indices: Original module gene indices

        Returns:
            List of gene name lists, one per enriched motif
        """
        if self.gene_names is None:
            raise ValueError("Gene names not loaded")

        leading_edges = []
        for mask in result['leading_edge_masks']:
            gene_indices = module_gene_indices[mask].cpu().numpy()
            genes = [self.gene_names[i] for i in gene_indices]
            leading_edges.append(genes)
        return leading_edges

    def prune_modules(
        self,
        modules: List[torch.Tensor],
        tf_names: List[str],
        gene_names: List[str],
        weights_list: Optional[List[torch.Tensor]] = None
    ) -> List[Dict]:
        """
        Prune modules across all databases and merge results (multi-database mode only).

        Args:
            modules: List of (n_genes_i,) tensors with gene indices for each TF module
            tf_names: List of TF names corresponding to modules
            gene_names: List of all gene names
            weights_list: Optional list of weight tensors for each module

        Returns:
            List of regulon dictionaries with keys:
            name, tf, motif, n_genes, genes, context, nes, auc
        """
        if not self._multi_db_mode:
            raise ValueError("prune_modules() is for multi-database mode. Use prune() for single database.")

        all_regulons = []

        # Pre-transfer all modules to CPU numpy once (avoid repeated GPU→CPU copies)
        modules_np = [m.cpu().numpy() for m in modules]

        # Prune modules for each database
        for db_idx, (pruner, db_name) in enumerate(zip(self.pruners, self.database_names)):
            print(f"\nPruning with database {db_idx + 1}/{len(self.pruners)}: {db_name}")

            # Precompute gene name → DB index mapping once per database
            gene_to_db_mapping = np.full(len(gene_names), -1, dtype=np.int64)
            for i, name in enumerate(gene_names):
                if name in pruner.gene_to_idx:
                    gene_to_db_mapping[i] = pruner.gene_to_idx[name]

            for module_idx, (module_np, tf_name) in enumerate(zip(modules_np, tf_names)):
                # Vectorized gene index conversion using precomputed mapping
                mapped_indices = gene_to_db_mapping[module_np]
                mapped_indices = mapped_indices[mapped_indices >= 0]

                if len(mapped_indices) < 20:  # Skip if too few genes mapped
                    continue

                db_indices = torch.tensor(mapped_indices, device=self.device, dtype=torch.long)

                # Get weights if provided
                weights = weights_list[module_idx] if weights_list else None

                # Prune (pass tf_name for TF-specific annotation filtering)
                result = pruner.prune(db_indices, weights, tf_name=tf_name)

                # Check for enriched motifs
                n_enriched = result['enriched_mask'].sum().item()
                if n_enriched > 0:
                    # Get enriched motifs and their info
                    enriched_indices = torch.where(result['enriched_mask'])[0].cpu().numpy()
                    enriched_motif_names = pruner.get_enriched_motif_names(result)
                    leading_edge_genes = pruner.get_leading_edge_genes(result, db_indices)
                    nes_values = result['nes'][enriched_indices].cpu().numpy()
                    auc_values = result['aucs'][enriched_indices].cpu().numpy()

                    # Create regulon for each enriched motif
                    for motif_name, le_genes, nes, auc in zip(
                        enriched_motif_names, leading_edge_genes, nes_values, auc_values
                    ):
                        # Filter by minimum genes if specified
                        if len(le_genes) < self.min_genes_per_regulon:
                            continue

                        regulon_name = f"{tf_name}_{motif_name}"
                        all_regulons.append({
                            'name': regulon_name,
                            'tf': tf_name,
                            'motif': motif_name,
                            'n_genes': len(le_genes),
                            'genes': le_genes,
                            'context': db_name,  # Store database name in context
                            'nes': float(nes),
                            'auc': float(auc),
                            'database': db_name
                        })

        # Merge regulons from multiple databases
        print(f"\nTotal regulons before merging: {len(all_regulons)}")
        merged_regulons = self._merge_regulons(all_regulons)
        print(f"Total regulons after merging: {len(merged_regulons)}")

        # Merge regulons by TF (matching pyscenic behavior)
        # pyscenic groups by (TF, Type) and merges all motifs for each TF into one regulon
        final_regulons = self._merge_regulons_by_tf(merged_regulons)
        print(f"Total regulons after TF merging: {len(final_regulons)}")

        return final_regulons

    def _merge_regulons(self, regulons: List[Dict]) -> List[Dict]:
        """
        Merge regulons from multiple databases.

        For the same TF+motif combination:
        - If merge_strategy='union': keep all (they may have different genes from different DBs)
        - If merge_strategy='best': keep the one with highest NES

        Note: pyscenic uses union strategy - it merges genes from all databases
        for the same TF+motif combination.
        """
        if self.merge_strategy == 'best':
            # Group by TF+motif, keep best NES
            regulon_dict = {}
            for reg in regulons:
                key = (reg['tf'], reg['motif'])
                if key not in regulon_dict or reg['nes'] > regulon_dict[key]['nes']:
                    regulon_dict[key] = reg
            return list(regulon_dict.values())
        elif self.merge_strategy == 'union':
            # Group by TF+motif, merge genes (union)
            regulon_groups = defaultdict(list)
            for reg in regulons:
                key = (reg['tf'], reg['motif'])
                regulon_groups[key].append(reg)

            merged = []
            for (tf, motif), group in regulon_groups.items():
                # Merge genes (union)
                all_genes = set()
                all_contexts = set()
                best_nes = max(reg['nes'] for reg in group)
                best_auc = max(reg['auc'] for reg in group)
                for reg in group:
                    all_genes.update(reg['genes'])
                    all_contexts.add(reg['context'])
                merged.append({
                    'name': f"{tf}_{motif}",
                    'tf': tf,
                    'motif': motif,
                    'n_genes': len(all_genes),
                    'genes': list(all_genes),
                    'context': ','.join(sorted(all_contexts)),  # All database names
                    'nes': best_nes,
                    'auc': best_auc,
                    'database': ','.join(sorted(all_contexts))
                })
            return merged
        else:
            raise ValueError(f"Unknown merge_strategy: {self.merge_strategy}")

    def _merge_regulons_by_tf(self, regulons: List[Dict]) -> List[Dict]:
        """
        Merge regulons by TF, matching pyscenic's df2regulons behavior.

        pyscenic groups by (TF, Type) and uses Regulon.union to merge all motifs
        for each TF into a single regulon. This function implements the same logic.

        Args:
            regulons: List of regulon dictionaries

        Returns:
            Merged regulons (one per TF)
        """
        # Local import is redundant with the module-level one; kept as-is.
        from collections import defaultdict

        # Group by TF
        tf_to_regulons = defaultdict(list)
        for reg in regulons:
            tf_to_regulons[reg['tf']].append(reg)

        merged = []
        for tf, regs in tf_to_regulons.items():
            # Merge genes (union) - matching Regulon.union behavior
            all_genes = set()
            all_motifs = set()
            best_nes = max(reg['nes'] for reg in regs)
            best_auc = max(reg['auc'] for reg in regs)

            # Collect all genes and motifs
            for reg in regs:
                all_genes.update(reg['genes'])
                all_motifs.add(reg['motif'])

            # Use TF name as regulon name (matching pyscenic: "{tf}(+)")
            # pyscenic uses the highest NES motif's name in context, but regulon name is just TF
            merged.append({
                'name': f"{tf}(+)",  # pyscenic format: TF name + interaction type
                'tf': tf,
                'motif': ','.join(sorted(all_motifs)),  # All motifs (for reference)
                'n_genes': len(all_genes),
                'genes': list(all_genes),
                'nes': best_nes,  # Keep best NES (pyscenic keeps max combined score)
                'auc': best_auc,
                'context': regs[0].get('context', ''),  # Keep context from first regulon
                'database': regs[0].get('database', '')  # Keep database info
            })
        return merged

    def clear_gpu_memory(self):
        """Release GPU memory held by the ranking tensor(s)."""
        if self._multi_db_mode:
            for pruner in self.pruners:
                pruner.clear_gpu_memory()
        else:
            if self.rankings is not None:
                del self.rankings
                self.rankings = None
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
# Backwards compatibility aliases MultiDatabaseCisTargetPruner = CisTargetPruner GPUCisTargetPruner = CisTargetPruner