"""Phase A: calibrate sickle connectivity against a REAL disease-signature reference population. The v1.1 tau saturated because it used random gene-set nulls. Proper specificity calibration asks: does a drug reverse SICKLE more than it reverses diseases in general? We download a library of real disease signatures (Enrichr "Disease Signatures from GEO", up+down), compute each drug's connectivity to every reference disease, and express its sickle connectivity as a z-score within that per-drug reference distribution. Broad-effect drugs (reverse everything) -> z~0 -> down-ranked. Re-runs the recovery test and compares to v1.1 (random-null). Writes nothing committed. """ from __future__ import annotations from pathlib import Path import numpy as np import pandas as pd import requests import sys sys.path.insert(0, str(Path(__file__).resolve().parent.parent)) from src.scoring import _ks_connectivity # noqa: E402 PROCESSED = Path("data/processed") RAW = Path("data/raw/disease_sigs") ENRICHR = "https://maayanlab.cloud/Enrichr/geneSetLibrary?mode=text&libraryName={}" LIBS = {"up": "Disease_Signatures_from_GEO_up_2014", "down": "Disease_Signatures_from_GEO_down_2014"} NEG5 = ["clotrimazole", "astemizole", "azithromycin", "ethinyl-estradiol", "caffeine"] def fetch_gmt(name: str) -> dict[str, list[str]]: RAW.mkdir(parents=True, exist_ok=True) path = RAW / f"{name}.gmt" if not path.exists(): r = requests.get(ENRICHR.format(name), timeout=120) r.raise_for_status() path.write_text(r.text) out = {} for line in path.read_text().splitlines(): parts = line.split("\t") if len(parts) < 3: continue term = parts[0] genes = [g.split(",")[0].strip().upper() for g in parts[2:] if g.strip()] out[term] = genes print(f" {name}: {path.stat().st_size/1e6:.1f} MB, {len(out)} terms") return out def build_reference() -> list[dict]: up = fetch_gmt(LIBS["up"]) down = fetch_gmt(LIBS["down"]) shared = set(up) & set(down) refs = [] for term in shared: if "sickle" in term.lower(): continue # exclude the target disease from its own reference population refs.append({"name": term, "up": up[term], "down": down[term]}) print(f"reference disease signatures (paired up+down, sickle excluded): {len(refs)}") return refs def cols_for(genes, gene_to_col): return np.array([gene_to_col[g] for g in set(genes) if g in gene_to_col], dtype=int) def main(): import json sig = json.loads((PROCESSED / "sickle_cell_signature_v1.json").read_text()) sk_up = [g["gene"] for g in sig["up_regulated"]] sk_down = [g["gene"] for g in sig["down_regulated"]] lincs = pd.read_parquet(PROCESSED / "lincs_signatures_v1.parquet") genes = list(lincs.columns) gene_to_col = {g: i for i, g in enumerate(genes)} n = len(genes) R = lincs.rank(axis=1, ascending=False).to_numpy() refs = build_reference() # connectivity of every drug to every reference disease -> (n_drugs, n_refs) C = np.empty((R.shape[0], len(refs))) for j, d in enumerate(refs): C[:, j] = _ks_connectivity(R, cols_for(d["up"], gene_to_col), cols_for(d["down"], gene_to_col), n) # drop reference diseases with too few mapped genes (degenerate columns) mapped = np.array([len(cols_for(d["up"], gene_to_col)) + len(cols_for(d["down"], gene_to_col)) for d in refs]) keep = mapped >= 10 C = C[:, keep] print(f"usable reference diseases (>=10 mapped genes): {keep.sum()}") real = _ks_connectivity(R, cols_for(sk_up, gene_to_col), cols_for(sk_down, gene_to_col), n) ref_mean, ref_std = C.mean(axis=1), C.std(axis=1) ref_std[ref_std == 0] = np.nan spec_z = (real - ref_mean) / ref_std # negative = reverses sickle more than diseases-in-general ranked = pd.DataFrame({"spec_z": spec_z, "connectivity": real}, index=lincs.index).sort_values("spec_z") ranked.insert(0, "rank", range(1, len(ranked) + 1)) profiles = pd.read_parquet(PROCESSED / "drug_profiles_v1.parquet").set_index("name") ranked = ranked.join(profiles[["inclusion_reason"]]) N = len(ranked) top10, top25, half = int(N * .10), int(N * .25), N // 2 hu, glut = int(ranked.loc["hydroxyurea", "rank"]), int(ranked.loc["glutamine", "rank"]) negs = {d: int(ranked.loc[d, "rank"]) for d in NEG5 if d in ranked.index} n_bottom = sum(r > half for r in negs.values()) print("\n==== RECOVERY TEST (reference-calibrated tau) ====") print(f" hydroxyurea rank {hu}/{N} (top {100*hu/N:.1f}%) top-10%? {hu <= top10} z={ranked.loc['hydroxyurea','spec_z']:.2f}") print(f" L-glutamine rank {glut}/{N} (top {100*glut/N:.1f}%) top-25%? {glut <= top25}") print(f" neg controls bottom-half: {n_bottom}/5 {negs}") crit = (hu <= top10) and (glut <= top25) and (n_bottom >= 4) print(f" OVERALL: {'PASS' if crit else 'FAIL'}") print("\n top 12 by reference-calibrated z:") for name, r in ranked.nsmallest(12, "spec_z").iterrows(): print(f" {int(r['rank']):2d} {name:20s} z={r['spec_z']:6.2f} [{r['inclusion_reason']}]") if __name__ == "__main__": main()