"""Phase D: supervised cross-disease repurposing — can labels break the specificity ceiling? Connectivity alone can't tell therapeutic from coincidental reversal. A supervised model trained on known drug-disease pairs CAN learn that pattern — given features that expose drug "broadness" (a drug that reverses everything is non-specific). We train on 839 GEO disease signatures with Repurposing-Hub indications as labels, evaluate with disease-grouped CV, then apply to HELD-OUT sickle and check whether the coincidental reversers (norethindrone, ciprofloxacin) finally drop. Baseline = rank by raw connectivity (the single conn feature). Win = model down-ranks the negative controls vs baseline while keeping hydroxyurea high. """ from __future__ import annotations import json import re from pathlib import Path import numpy as np import pandas as pd from sklearn.ensemble import GradientBoostingClassifier from sklearn.model_selection import GroupKFold, cross_val_predict from sklearn.metrics import roc_auc_score import sys sys.path.insert(0, str(Path(__file__).resolve().parent.parent)) from src.scoring import _ks_connectivity # noqa: E402 PROCESSED = Path("data/processed") SIGS = Path("data/raw/disease_sigs") NEG5 = ["clotrimazole", "astemizole", "azithromycin", "ethinyl-estradiol", "caffeine"] ID_RE = re.compile(r"\s*(c\d{6,}|doid[-\s]?\d+|gse\d+|umls\S+)\s*", re.I) def clean_disease(term: str) -> str: return ID_RE.sub(" ", term).strip().lower() def load_disease_sigs(): def parse(name): out = {} for line in (SIGS / name).read_text().splitlines(): p = line.split("\t") if len(p) < 3: continue out[p[0]] = [g.split(",")[0].strip().upper() for g in p[2:] if g.strip()] return out up = parse("Disease_Perturbations_from_GEO_up.gmt") down = parse("Disease_Perturbations_from_GEO_down.gmt") terms = sorted(set(up) & set(down)) return [{"term": t, "disease": clean_disease(t), "up": up[t], "down": down[t]} for t in terms] def cols_for(genes, g2c): return np.array([g2c[g] for g in set(genes) if g in g2c], dtype=int) def featurize(conn_col, C): """Per-(drug,disease) features for one disease column conn_col, given full matrix C.""" drug_mean, drug_std = C.mean(1), C.std(1) drug_min = C.min(1) broad = (C < -0.10).mean(1) # fraction of diseases this drug strongly reverses dmean, dstd = conn_col.mean(), conn_col.std() or 1.0 return np.column_stack([ conn_col, drug_mean, drug_std, drug_min, broad, np.full_like(conn_col, dmean), (conn_col - drug_mean) / np.where(drug_std == 0, 1, drug_std), # specificity within drug (conn_col - dmean) / dstd, # specificity within disease ]) FEATS = ["conn", "drug_mean", "drug_std", "drug_min", "broadness", "disease_mean", "z_within_drug", "z_within_disease"] def main(): lincs = pd.read_parquet(PROCESSED / "lincs_signatures_v1.parquet") drugs = list(lincs.index) g2c = {g: i for i, g in enumerate(lincs.columns)} n = len(lincs.columns) R = lincs.rank(axis=1, ascending=False).to_numpy() refs = [d for d in load_disease_sigs() if len(cols_for(d["up"], g2c)) + len(cols_for(d["down"], g2c)) >= 10] C = np.column_stack([_ks_connectivity(R, cols_for(d["up"], g2c), cols_for(d["down"], g2c), n) for d in refs]) print(f"{len(drugs)} drugs x {len(refs)} disease signatures; connectivity matrix {C.shape}") # labels from Repurposing Hub indications hub = pd.read_csv("data/raw/repurposing_drugs.txt", sep="\t", comment="!", low_memory=False) ind = {r.pert_iname: [i.strip().lower() for i in re.split(r"[|,]", r.indication) if len(i.strip()) > 3] for r in hub.itertuples() if isinstance(r.indication, str)} drug_idx = {d: i for i, d in enumerate(drugs)} X_rows, y, grp = [], [], [] for j, d in enumerate(refs): feats = featurize(C[:, j], C) dz = d["disease"] for i, drug in enumerate(drugs): inds = ind.get(drug, []) label = int(any(dz == k or (len(dz) > 4 and dz in k) or (len(k) > 4 and k in dz) for k in inds)) X_rows.append(feats[i]); y.append(label); grp.append(j) X = np.array(X_rows); y = np.array(y); grp = np.array(grp) print(f"pairs: {len(y)}, positives: {y.sum()} ({100*y.mean():.2f}%)") # disease-grouped CV (generalize to unseen diseases) clf = GradientBoostingClassifier(random_state=0) proba = cross_val_predict(clf, X, y, cv=GroupKFold(5), groups=grp, method="predict_proba")[:, 1] print(f"disease-grouped CV AUC: {roc_auc_score(y, proba):.3f} (conn-only AUC: {roc_auc_score(y, X[:,0]*-1):.3f})") clf.fit(X, y) print("feature importances: " + ", ".join(f"{f}={imp:.2f}" for f, imp in sorted(zip(FEATS, clf.feature_importances_), key=lambda t: -t[1]))) # apply to HELD-OUT sickle sig = json.loads((PROCESSED / "sickle_cell_signature_v1.json").read_text()) sk = _ks_connectivity(R, cols_for([g["gene"] for g in sig["up_regulated"]], g2c), cols_for([g["gene"] for g in sig["down_regulated"]], g2c), n) Xs = featurize(sk, C) p_sickle = clf.predict_proba(Xs)[:, 1] res = pd.DataFrame({"model_p": p_sickle, "conn": sk}, index=drugs) res["model_rank"] = res["model_p"].rank(ascending=False).astype(int) res["conn_rank"] = res["conn"].rank(ascending=True).astype(int) # most negative = best N = len(res) print(f"\n{'drug':14s} {'model_rank':>10s} {'conn_rank(base)':>16s}") for d in ["hydroxyurea", "glutamine"] + NEG5 + ["norethindrone", "ciprofloxacin"]: if d in res.index: r = res.loc[d] print(f" {d:14s} {int(r['model_rank']):6d}/{N} {int(r['conn_rank']):6d}/{N}") nb_model = sum(res.loc[d, "model_rank"] > N/2 for d in NEG5 if d in res.index) nb_conn = sum(res.loc[d, "conn_rank"] > N/2 for d in NEG5 if d in res.index) print(f"\nneg controls bottom-half: model {nb_model}/5 vs baseline {nb_conn}/5") print("model top 10:", ", ".join(res.sort_values('model_p', ascending=False).head(10).index)) if __name__ == "__main__": main()