"""Disease signature construction.

Week 1 (PLAN.md §6). Builds a Tier-A sickle cell signature from GEO expression
data via differential expression on TWO studies, keeping only genes that are
concordant across both (the multi-source evidence that earns Tier A — see
project memory). Persists with full provenance to
``data/processed/sickle_cell_signature_v1.json``.

Pipeline (driven from ``notebooks/02_disease_signature.ipynb``)::

    per study:  compute_differential_expression -> collapse_probes_to_symbols
    across:     concordance_filter -> build_signature -> persist_signature

Conventions:
    - log_fc > 0 => up-regulated in DISEASE vs HEALTHY (Week 1 refinement R1)
    - cross-study join key is the HGNC gene symbol (R2)
"""

from __future__ import annotations

import warnings
from pathlib import Path

import numpy as np
import pandas as pd
from pydantic import BaseModel, Field
from scipy.stats import false_discovery_control, ttest_ind

from . import PIPELINE_VERSION, PROCESSED_DIR
from .provenance import ConfidenceTier

# Number of genes to take per direction (PLAN.md §6, Week 1 task 5).
TOP_N_PER_DIRECTION = 250
# Significance threshold that must be met in BOTH studies (see ``concordance_filter``).
QVALUE_CUTOFF = 0.05

# Group labels used throughout. The disease group is the "treatment" arm in DE contrasts.
DISEASE_LABEL = "disease"
HEALTHY_LABEL = "healthy"

# Fewer than two samples in a group leaves the per-gene variance undefined (Welch),
# so every p-value would come back NaN and the DE table would silently be empty.
_MIN_SAMPLES_PER_GROUP = 2


class GeneEntry(BaseModel):
    """A single differentially expressed gene in the signature."""

    gene: str = Field(..., description="HGNC gene symbol, e.g. 'HBG2'.")
    entrez_id: str | None = None
    ensembl_id: str | None = None
    log_fc: float
    qvalue: float


class StudyProvenance(BaseModel):
    """Provenance for one contributing GEO study (Week 1 refinement R6)."""

    geo_accession: str
    n_disease: int
    n_healthy: int
    platform: str
    tissue: str
    method: str = Field(..., description="DE method: 'welch' (microarray) or 'deseq2' (RNA-seq).")


class ConcordanceSummary(BaseModel):
    """How the two studies were reconciled into the signature (R6)."""

    rule: str = Field(
        default="q<0.05 in both studies AND same sign of log_fc in both",
        description="The concordance rule applied across studies.",
    )
    n_genes_tested: int = Field(..., description="Genes present in both studies after symbol collapse.")
    n_concordant: int
    n_up: int
    n_down: int


class SignatureProvenance(BaseModel):
    """Provenance block for a disease signature (revised for 2-study concordance, R6)."""

    studies: list[StudyProvenance]
    concordance: ConcordanceSummary
    created_date: str
    pipeline_version: str = PIPELINE_VERSION


class DiseaseSignature(BaseModel):
    """The persisted sickle cell disease signature (PLAN.md §6 schema)."""

    signature_id: str = "sickle_cell_v1"
    disease_mondo_id: str = "MONDO:0011382"
    pipeline_version: str = PIPELINE_VERSION
    up_regulated: list[GeneEntry]
    down_regulated: list[GeneEntry]
    provenance: SignatureProvenance
    confidence_tier: ConfidenceTier
    tier_rationale: str
    limitations: list[str]


def _looks_like_linear_intensity(expression: pd.DataFrame) -> bool:
    """Heuristic: microarray data still on a linear scale has a large dynamic range.

    log2-transformed intensities are typically <~20; raw intensities run into the
    thousands. Empty or all-missing input is treated as "not linear" (no evidence
    either way) instead of raising from ``np.nanmax``.
    """
    values = expression.to_numpy(dtype=float, na_value=np.nan)
    if values.size == 0 or np.isnan(values).all():
        return False
    return float(np.nanmax(values)) > 50.0


def _welch_de(expression: pd.DataFrame, groups: pd.Series) -> pd.DataFrame:
    """Per-gene Welch t-test + Benjamini-Hochberg, the limma-equivalent for microarray.

    Assumes ``expression`` is genes (rows) x samples (columns), on (or coercible to)
    a log2 scale. ``log_fc`` is mean(disease) - mean(healthy) over non-missing values
    (R1 sign convention), matching the ``nan_policy="omit"`` used by the t-test.

    Returns:
        Table indexed like ``expression`` with ``log_fc``, ``pvalue``, ``qvalue``,
        sorted by ascending ``qvalue``; genes without a testable p-value are dropped.
    """
    if _looks_like_linear_intensity(expression):
        # log2(x+1) guards against zeros/negatives while keeping the transform standard.
        expr = np.log2(expression.clip(lower=0) + 1.0)
    else:
        expr = expression

    disease = expr[groups.index[groups == DISEASE_LABEL]].to_numpy()
    healthy = expr[groups.index[groups == HEALTHY_LABEL]].to_numpy()

    with warnings.catch_warnings():
        # Rows that are all-NaN in a group produce NaN means (and NaN p-values below);
        # they are dropped afterwards, so the "Mean of empty slice" warning is noise.
        warnings.simplefilter("ignore", category=RuntimeWarning)
        log_fc = np.nanmean(disease, axis=1) - np.nanmean(healthy, axis=1)

    # Welch's t-test (unequal variance) per gene across samples.
    _, pvalue = ttest_ind(disease, healthy, axis=1, equal_var=False, nan_policy="omit")

    out = pd.DataFrame({"log_fc": log_fc, "pvalue": pvalue}, index=expr.index)
    out = out.dropna(subset=["log_fc", "pvalue"])
    out["qvalue"] = false_discovery_control(out["pvalue"].to_numpy(), method="bh")
    return out.sort_values("qvalue", kind="mergesort")


def _deseq2_de(counts: pd.DataFrame, groups: pd.Series) -> pd.DataFrame:
    """RNA-seq differential expression via pydeseq2.

    ``counts`` is genes (rows) x samples (columns) of raw integer counts. Returns the
    same schema as ``_welch_de`` (log_fc = log2FC of disease vs healthy, plus
    pvalue/qvalue), sorted by ascending ``qvalue``; rows with no adjusted p-value are dropped.
    """
    from pydeseq2.dds import DeseqDataSet
    from pydeseq2.ds import DeseqStats

    # pydeseq2 wants samples (rows) x genes (cols), with an aligned metadata frame.
    counts_t = counts.T.round().astype(int)
    metadata = pd.DataFrame(
        {"condition": groups.reindex(counts_t.index).to_numpy()},
        index=counts_t.index,
    )
    dds = DeseqDataSet(counts=counts_t, metadata=metadata, design="~condition", quiet=True)
    dds.deseq2()
    stats = DeseqStats(dds, contrast=["condition", DISEASE_LABEL, HEALTHY_LABEL], quiet=True)
    stats.summary()
    res = stats.results_df.rename(
        columns={"log2FoldChange": "log_fc", "pvalue": "pvalue", "padj": "qvalue"}
    )
    return res[["log_fc", "pvalue", "qvalue"]].dropna(subset=["qvalue"]).sort_values("qvalue")


def compute_differential_expression(
    expression: pd.DataFrame,
    sample_groups: pd.Series,
    *,
    method: str,
) -> pd.DataFrame:
    """Compute gene-level log fold change and adjusted p-values for one study.

    Args:
        expression: Genes (rows) x samples (columns). Microarray intensities or RNA-seq counts.
        sample_groups: Per-sample 'disease'/'healthy' label, indexed by sample id (column).
        method: 'welch' (microarray, PLAN choice) or 'deseq2' (RNA-seq).

    Returns:
        Table indexed by gene/probe with ``log_fc``, ``pvalue``, ``qvalue`` (R1 sign convention).

    Raises:
        ValueError: If any expression column lacks a label, if either group has fewer
            than two samples, or if ``method`` is unknown.
    """
    groups = sample_groups.reindex(expression.columns)
    if groups.isna().any():
        missing = list(groups.index[groups.isna()])
        raise ValueError(f"sample_groups missing labels for samples: {missing[:5]}...")

    n_disease = int((groups == DISEASE_LABEL).sum())
    n_healthy = int((groups == HEALTHY_LABEL).sum())
    if min(n_disease, n_healthy) < _MIN_SAMPLES_PER_GROUP:
        raise ValueError(
            f"Need at least {_MIN_SAMPLES_PER_GROUP} samples per group; got "
            f"{n_disease} {DISEASE_LABEL!r} and {n_healthy} {HEALTHY_LABEL!r}."
        )

    if method == "welch":
        return _welch_de(expression, groups)
    if method == "deseq2":
        return _deseq2_de(expression, groups)
    raise ValueError(f"Unknown method {method!r}; expected 'welch' or 'deseq2'.")


def collapse_probes_to_symbols(
    de_table: pd.DataFrame,
    probe_to_symbol: pd.Series,
    expression_for_ranking: pd.DataFrame | None = None,
) -> pd.DataFrame:
    """Collapse a probe-level DE table to one row per HGNC symbol (R2).

    When multiple probes map to the same symbol, keep the probe with the highest
    mean expression (the standard collapseRows heuristic) if
    ``expression_for_ranking`` is given; otherwise keep the probe with the smallest
    qvalue. Sorting is stable, so ties keep ``de_table``'s existing order.

    Symbols are whitespace-stripped; probes with a missing or blank symbol are
    dropped (a blank symbol would otherwise become a bogus "" gene that joins
    across studies).

    Args:
        de_table: DE table indexed by probe id (from ``compute_differential_expression``).
        probe_to_symbol: Probe id -> HGNC symbol mapping.
        expression_for_ranking: Optional genes x samples matrix to rank duplicate probes.

    Returns:
        DE table indexed by HGNC symbol.
    """
    df = de_table.copy()
    symbols = probe_to_symbol.reindex(df.index).map(
        lambda s: s.strip() if isinstance(s, str) else s
    )
    df["symbol"] = symbols.mask(symbols.eq(""))
    df = df.dropna(subset=["symbol"])

    if expression_for_ranking is not None:
        rank_key = expression_for_ranking.reindex(df.index).mean(axis=1)
        df = df.assign(_rank=rank_key).sort_values("_rank", ascending=False, kind="mergesort")
    else:
        df = df.sort_values("qvalue", ascending=True, kind="mergesort")

    collapsed = df[~df["symbol"].duplicated(keep="first")].set_index("symbol")
    return collapsed.drop(columns=[c for c in ("_rank",) if c in collapsed.columns])


def _rank_by_qvalue(table: pd.DataFrame) -> pd.DataFrame:
    """Order rows by ascending ``qvalue``, breaking ties by larger ``|log_fc|``.

    Benjamini-Hochberg adjustment produces runs of identical q-values, so without
    a deterministic tie-break the top-N cut in ``build_signature`` would depend on
    sort internals. Remaining exact ties keep their input order.
    """
    return (
        table.assign(_abs_fc=table["log_fc"].abs())
        .sort_values(["qvalue", "_abs_fc"], ascending=[True, False])
        .drop(columns="_abs_fc")
    )


def concordance_filter(
    de_a: pd.DataFrame,
    de_b: pd.DataFrame,
    *,
    qvalue_cutoff: float = QVALUE_CUTOFF,
) -> tuple[pd.DataFrame, ConcordanceSummary]:
    """Keep only genes concordant across two symbol-level DE tables (R3).

    A gene qualifies iff q<``qvalue_cutoff`` in BOTH studies AND log_fc has the same
    strictly non-zero sign in both (a zero or missing fold change has no direction).
    Reported ``log_fc`` = mean of the two; ``qvalue`` = max of the two (worst case).
    Ranked by ``max(q_a, q_b)`` ascending, ties broken by larger ``|log_fc|``.

    Returns:
        (concordant_table indexed by symbol with log_fc/qvalue, ConcordanceSummary).
        The summary's ``rule`` text reflects the actual ``qvalue_cutoff`` used, and
        ``n_up + n_down == n_concordant``.
    """
    merged = de_a.join(de_b, lsuffix="_a", rsuffix="_b", how="inner")
    n_tested = len(merged)

    sig_both = (merged["qvalue_a"] < qvalue_cutoff) & (merged["qvalue_b"] < qvalue_cutoff)
    # A positive product of signs means both are non-zero and point the same way;
    # zeros and NaNs fail the comparison and are excluded.
    same_sign = (np.sign(merged["log_fc_a"]) * np.sign(merged["log_fc_b"])) > 0

    keep = merged.loc[sig_both & same_sign].copy()
    keep["log_fc"] = (keep["log_fc_a"] + keep["log_fc_b"]) / 2.0
    keep["qvalue"] = keep[["qvalue_a", "qvalue_b"]].max(axis=1)
    keep = _rank_by_qvalue(keep[["log_fc", "qvalue"]])

    summary = ConcordanceSummary(
        rule=f"q<{qvalue_cutoff:g} in both studies AND same sign of log_fc in both",
        n_genes_tested=n_tested,
        n_concordant=len(keep),
        n_up=int((keep["log_fc"] > 0).sum()),
        n_down=int((keep["log_fc"] < 0).sum()),
    )
    return keep, summary


def _format_id(value: object) -> str | None:
    """Render an external identifier as a string, or ``None`` when missing.

    Integer IDs held in a float column (pandas upcasts a column when any entry is
    missing) would otherwise serialize as e.g. ``"3048.0"`` instead of ``"3048"``.
    """
    if pd.isna(value):
        return None
    if isinstance(value, (float, np.floating)) and float(value).is_integer():
        return str(int(value))
    return str(value)


def _gene_entries(rows: pd.DataFrame, id_map: pd.DataFrame | None) -> list[GeneEntry]:
    """Convert signature rows (indexed by symbol) into ``GeneEntry`` objects.

    ``id_map`` must have a unique index here; optional 'entrez_id'/'ensembl_id'
    columns are looked up per symbol when present.
    """
    entries: list[GeneEntry] = []
    for symbol, row in rows.iterrows():
        ids = id_map.loc[symbol] if (id_map is not None and symbol in id_map.index) else None
        entries.append(
            GeneEntry(
                gene=str(symbol),
                entrez_id=_format_id(ids.get("entrez_id")) if ids is not None else None,
                ensembl_id=_format_id(ids.get("ensembl_id")) if ids is not None else None,
                log_fc=float(row["log_fc"]),
                qvalue=float(row["qvalue"]),
            )
        )
    return entries


def build_signature(
    concordant_table: pd.DataFrame,
    provenance: SignatureProvenance,
    *,
    tier: ConfidenceTier,
    tier_rationale: str,
    limitations: list[str],
    id_map: pd.DataFrame | None = None,
    top_n: int = TOP_N_PER_DIRECTION,
) -> DiseaseSignature:
    """Assemble a ``DiseaseSignature`` from the concordant gene table.

    Takes the top ``top_n`` up- and down-regulated genes by qvalue per direction
    (R3), breaking q-value ties by larger ``|log_fc|``. If fewer than ``top_n``
    qualify in a direction, takes all available (the caller should have logged the
    count). ``id_map`` optionally provides 'entrez_id'/'ensembl_id' columns indexed
    by symbol (from ``mygene``); if a symbol appears more than once, the first row wins.

    Args:
        concordant_table: Output of ``concordance_filter`` (indexed by symbol).
        provenance: Fully populated ``SignatureProvenance`` (both studies + concordance).
        tier: Confidence tier (Tier A only if genuinely multi-source).
        tier_rationale: One-line justification of the tier.
        limitations: Honest limitations list (R8).
        id_map: Optional symbol -> {entrez_id, ensembl_id} table.
        top_n: Max genes per direction; must be non-negative.

    Returns:
        A populated ``DiseaseSignature``.

    Raises:
        ValueError: If ``top_n`` is negative (``head`` would silently drop rows instead).
    """
    if top_n < 0:
        raise ValueError(f"top_n must be non-negative; got {top_n}.")

    unique_ids = None
    if id_map is not None:
        # Duplicate symbols would make ``.loc`` return a frame rather than a row.
        unique_ids = id_map[~id_map.index.duplicated(keep="first")]

    up = _rank_by_qvalue(concordant_table[concordant_table["log_fc"] > 0]).head(top_n)
    down = _rank_by_qvalue(concordant_table[concordant_table["log_fc"] < 0]).head(top_n)
    return DiseaseSignature(
        up_regulated=_gene_entries(up, unique_ids),
        down_regulated=_gene_entries(down, unique_ids),
        provenance=provenance,
        confidence_tier=tier,
        tier_rationale=tier_rationale,
        limitations=limitations,
    )


def persist_signature(signature: DiseaseSignature, out_path: Path | None = None) -> Path:
    """Write a signature as UTF-8 JSON (default ``data/processed/sickle_cell_signature_v1.json``).

    Parent directories are created as needed. The encoding is explicit so free-text
    fields (rationale, limitations) with non-ASCII characters write identically on
    every platform.

    Returns:
        The path written.
    """
    if out_path is None:
        out_path = PROCESSED_DIR / "sickle_cell_signature_v1.json"
    out_path.parent.mkdir(parents=True, exist_ok=True)
    out_path.write_text(signature.model_dump_json(indent=2), encoding="utf-8")
    return out_path