diff --git a/scripts/phaseA_reference_tau.py b/scripts/phaseA_reference_tau.py new file mode 100644 index 0000000..d24f5d6 --- /dev/null +++ b/scripts/phaseA_reference_tau.py @@ -0,0 +1,118 @@ +"""Phase A: calibrate sickle connectivity against a REAL disease-signature reference population. + +The v1.1 tau saturated because it used random gene-set nulls. Proper specificity calibration +asks: does a drug reverse SICKLE more than it reverses diseases in general? We download a library +of real disease signatures (Enrichr "Disease Signatures from GEO", up+down), compute each drug's +connectivity to every reference disease, and express its sickle connectivity as a z-score within +that per-drug reference distribution. Broad-effect drugs (reverse everything) -> z~0 -> down-ranked. + +Re-runs the recovery test and compares to v1.1 (random-null). Writes nothing committed. +""" + +from __future__ import annotations + +from pathlib import Path + +import numpy as np +import pandas as pd +import requests + +import sys +sys.path.insert(0, str(Path(__file__).resolve().parent.parent)) +from src.scoring import _ks_connectivity # noqa: E402 + +PROCESSED = Path("data/processed") +RAW = Path("data/raw/disease_sigs") +ENRICHR = "https://maayanlab.cloud/Enrichr/geneSetLibrary?mode=text&libraryName={}" +LIBS = {"up": "Disease_Signatures_from_GEO_up_2014", "down": "Disease_Signatures_from_GEO_down_2014"} +NEG5 = ["clotrimazole", "astemizole", "azithromycin", "ethinyl-estradiol", "caffeine"] + + +def fetch_gmt(name: str) -> dict[str, list[str]]: + RAW.mkdir(parents=True, exist_ok=True) + path = RAW / f"{name}.gmt" + if not path.exists(): + r = requests.get(ENRICHR.format(name), timeout=120) + r.raise_for_status() + path.write_text(r.text) + out = {} + for line in path.read_text().splitlines(): + parts = line.split("\t") + if len(parts) < 3: + continue + term = parts[0] + genes = [g.split(",")[0].strip().upper() for g in parts[2:] if g.strip()] + out[term] = genes + print(f" {name}: {path.stat().st_size/1e6:.1f} MB, {len(out)} terms") + return out + + +def build_reference() -> list[dict]: + up = fetch_gmt(LIBS["up"]) + down = fetch_gmt(LIBS["down"]) + shared = set(up) & set(down) + refs = [] + for term in shared: + if "sickle" in term.lower(): + continue # exclude the target disease from its own reference population + refs.append({"name": term, "up": up[term], "down": down[term]}) + print(f"reference disease signatures (paired up+down, sickle excluded): {len(refs)}") + return refs + + +def cols_for(genes, gene_to_col): + return np.array([gene_to_col[g] for g in set(genes) if g in gene_to_col], dtype=int) + + +def main(): + import json + sig = json.loads((PROCESSED / "sickle_cell_signature_v1.json").read_text()) + sk_up = [g["gene"] for g in sig["up_regulated"]] + sk_down = [g["gene"] for g in sig["down_regulated"]] + + lincs = pd.read_parquet(PROCESSED / "lincs_signatures_v1.parquet") + genes = list(lincs.columns) + gene_to_col = {g: i for i, g in enumerate(genes)} + n = len(genes) + R = lincs.rank(axis=1, ascending=False).to_numpy() + + refs = build_reference() + # connectivity of every drug to every reference disease -> (n_drugs, n_refs) + C = np.empty((R.shape[0], len(refs))) + for j, d in enumerate(refs): + C[:, j] = _ks_connectivity(R, cols_for(d["up"], gene_to_col), cols_for(d["down"], gene_to_col), n) + # drop reference diseases with too few mapped genes (degenerate columns) + mapped = np.array([len(cols_for(d["up"], gene_to_col)) + len(cols_for(d["down"], gene_to_col)) for d in refs]) + keep = mapped >= 10 + C = C[:, keep] + print(f"usable reference diseases (>=10 mapped genes): {keep.sum()}") + + real = _ks_connectivity(R, cols_for(sk_up, gene_to_col), cols_for(sk_down, gene_to_col), n) + ref_mean, ref_std = C.mean(axis=1), C.std(axis=1) + ref_std[ref_std == 0] = np.nan + spec_z = (real - ref_mean) / ref_std # negative = reverses sickle more than diseases-in-general + + ranked = pd.DataFrame({"spec_z": spec_z, "connectivity": real}, index=lincs.index).sort_values("spec_z") + ranked.insert(0, "rank", range(1, len(ranked) + 1)) + profiles = pd.read_parquet(PROCESSED / "drug_profiles_v1.parquet").set_index("name") + ranked = ranked.join(profiles[["inclusion_reason"]]) + + N = len(ranked) + top10, top25, half = int(N * .10), int(N * .25), N // 2 + hu, glut = int(ranked.loc["hydroxyurea", "rank"]), int(ranked.loc["glutamine", "rank"]) + negs = {d: int(ranked.loc[d, "rank"]) for d in NEG5 if d in ranked.index} + n_bottom = sum(r > half for r in negs.values()) + + print("\n==== RECOVERY TEST (reference-calibrated tau) ====") + print(f" hydroxyurea rank {hu}/{N} (top {100*hu/N:.1f}%) top-10%? {hu <= top10} z={ranked.loc['hydroxyurea','spec_z']:.2f}") + print(f" L-glutamine rank {glut}/{N} (top {100*glut/N:.1f}%) top-25%? {glut <= top25}") + print(f" neg controls bottom-half: {n_bottom}/5 {negs}") + crit = (hu <= top10) and (glut <= top25) and (n_bottom >= 4) + print(f" OVERALL: {'PASS' if crit else 'FAIL'}") + print("\n top 12 by reference-calibrated z:") + for name, r in ranked.nsmallest(12, "spec_z").iterrows(): + print(f" {int(r['rank']):2d} {name:20s} z={r['spec_z']:6.2f} [{r['inclusion_reason']}]") + + +if __name__ == "__main__": + main()