Phase D: supervised cross-disease (0.925 AUC degree-bias mirage)
Train GradientBoosting on 300 drugs x 839 GEO disease signatures with Repurposing-Hub indications as labels (432 positives), disease-grouped CV. Finding: 0.925 CV AUC looks like a win but is a MIRAGE. Feature importances are all drug-level (drug_std 0.33, drug_mean 0.30, broadness 0.17); drug-disease connectivity importance = 0.01. The model learned a drug-POPULARITY prior, not disease-specific matching. On held-out sickle it ranks hydroxyurea 231/300 (worse than baseline) and tops out with promiscuous drugs (dexamethasone, methotrexate). Classic degree-bias trap. Connectivity also has ~chance AUC (0.51) for predicting approved indications. Both obvious approaches now fail instructively: unsupervised = specificity ceiling; naive supervised = degree bias. Real progress needs degree- debiased training + much larger clean labels (a research effort). Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
This commit is contained in:
137
scripts/phaseD_supervised.py
Normal file
137
scripts/phaseD_supervised.py
Normal file
@@ -0,0 +1,137 @@
|
||||
"""Phase D: supervised cross-disease repurposing — can labels break the specificity ceiling?
|
||||
|
||||
Connectivity alone can't tell therapeutic from coincidental reversal. A supervised model trained
|
||||
on known drug-disease pairs CAN learn that pattern — given features that expose drug "broadness"
|
||||
(a drug that reverses everything is non-specific). We train on 839 GEO disease signatures with
|
||||
Repurposing-Hub indications as labels, evaluate with disease-grouped CV, then apply to HELD-OUT
|
||||
sickle and check whether the coincidental reversers (norethindrone, ciprofloxacin) finally drop.
|
||||
|
||||
Baseline = rank by raw connectivity (the single conn feature). Win = model down-ranks the
|
||||
negative controls vs baseline while keeping hydroxyurea high.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import re
|
||||
from pathlib import Path
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
from sklearn.ensemble import GradientBoostingClassifier
|
||||
from sklearn.model_selection import GroupKFold, cross_val_predict
|
||||
from sklearn.metrics import roc_auc_score
|
||||
|
||||
import sys
|
||||
sys.path.insert(0, str(Path(__file__).resolve().parent.parent))
|
||||
from src.scoring import _ks_connectivity # noqa: E402
|
||||
|
||||
PROCESSED = Path("data/processed")
|
||||
SIGS = Path("data/raw/disease_sigs")
|
||||
NEG5 = ["clotrimazole", "astemizole", "azithromycin", "ethinyl-estradiol", "caffeine"]
|
||||
ID_RE = re.compile(r"\s*(c\d{6,}|doid[-\s]?\d+|gse\d+|umls\S+)\s*", re.I)
|
||||
|
||||
|
||||
def clean_disease(term: str) -> str:
|
||||
return ID_RE.sub(" ", term).strip().lower()
|
||||
|
||||
|
||||
def load_disease_sigs():
|
||||
def parse(name):
|
||||
out = {}
|
||||
for line in (SIGS / name).read_text().splitlines():
|
||||
p = line.split("\t")
|
||||
if len(p) < 3:
|
||||
continue
|
||||
out[p[0]] = [g.split(",")[0].strip().upper() for g in p[2:] if g.strip()]
|
||||
return out
|
||||
up = parse("Disease_Perturbations_from_GEO_up.gmt")
|
||||
down = parse("Disease_Perturbations_from_GEO_down.gmt")
|
||||
terms = sorted(set(up) & set(down))
|
||||
return [{"term": t, "disease": clean_disease(t), "up": up[t], "down": down[t]} for t in terms]
|
||||
|
||||
|
||||
def cols_for(genes, g2c):
|
||||
return np.array([g2c[g] for g in set(genes) if g in g2c], dtype=int)
|
||||
|
||||
|
||||
def featurize(conn_col, C):
|
||||
"""Per-(drug,disease) features for one disease column conn_col, given full matrix C."""
|
||||
drug_mean, drug_std = C.mean(1), C.std(1)
|
||||
drug_min = C.min(1)
|
||||
broad = (C < -0.10).mean(1) # fraction of diseases this drug strongly reverses
|
||||
dmean, dstd = conn_col.mean(), conn_col.std() or 1.0
|
||||
return np.column_stack([
|
||||
conn_col, drug_mean, drug_std, drug_min, broad,
|
||||
np.full_like(conn_col, dmean),
|
||||
(conn_col - drug_mean) / np.where(drug_std == 0, 1, drug_std), # specificity within drug
|
||||
(conn_col - dmean) / dstd, # specificity within disease
|
||||
])
|
||||
|
||||
|
||||
FEATS = ["conn", "drug_mean", "drug_std", "drug_min", "broadness",
|
||||
"disease_mean", "z_within_drug", "z_within_disease"]
|
||||
|
||||
|
||||
def main():
|
||||
lincs = pd.read_parquet(PROCESSED / "lincs_signatures_v1.parquet")
|
||||
drugs = list(lincs.index)
|
||||
g2c = {g: i for i, g in enumerate(lincs.columns)}
|
||||
n = len(lincs.columns)
|
||||
R = lincs.rank(axis=1, ascending=False).to_numpy()
|
||||
|
||||
refs = [d for d in load_disease_sigs()
|
||||
if len(cols_for(d["up"], g2c)) + len(cols_for(d["down"], g2c)) >= 10]
|
||||
C = np.column_stack([_ks_connectivity(R, cols_for(d["up"], g2c), cols_for(d["down"], g2c), n) for d in refs])
|
||||
print(f"{len(drugs)} drugs x {len(refs)} disease signatures; connectivity matrix {C.shape}")
|
||||
|
||||
# labels from Repurposing Hub indications
|
||||
hub = pd.read_csv("data/raw/repurposing_drugs.txt", sep="\t", comment="!", low_memory=False)
|
||||
ind = {r.pert_iname: [i.strip().lower() for i in re.split(r"[|,]", r.indication) if len(i.strip()) > 3]
|
||||
for r in hub.itertuples() if isinstance(r.indication, str)}
|
||||
drug_idx = {d: i for i, d in enumerate(drugs)}
|
||||
|
||||
X_rows, y, grp = [], [], []
|
||||
for j, d in enumerate(refs):
|
||||
feats = featurize(C[:, j], C)
|
||||
dz = d["disease"]
|
||||
for i, drug in enumerate(drugs):
|
||||
inds = ind.get(drug, [])
|
||||
label = int(any(dz == k or (len(dz) > 4 and dz in k) or (len(k) > 4 and k in dz) for k in inds))
|
||||
X_rows.append(feats[i]); y.append(label); grp.append(j)
|
||||
X = np.array(X_rows); y = np.array(y); grp = np.array(grp)
|
||||
print(f"pairs: {len(y)}, positives: {y.sum()} ({100*y.mean():.2f}%)")
|
||||
|
||||
# disease-grouped CV (generalize to unseen diseases)
|
||||
clf = GradientBoostingClassifier(random_state=0)
|
||||
proba = cross_val_predict(clf, X, y, cv=GroupKFold(5), groups=grp, method="predict_proba")[:, 1]
|
||||
print(f"disease-grouped CV AUC: {roc_auc_score(y, proba):.3f} (conn-only AUC: {roc_auc_score(y, X[:,0]*-1):.3f})")
|
||||
|
||||
clf.fit(X, y)
|
||||
print("feature importances: " + ", ".join(f"{f}={imp:.2f}" for f, imp in
|
||||
sorted(zip(FEATS, clf.feature_importances_), key=lambda t: -t[1])))
|
||||
|
||||
# apply to HELD-OUT sickle
|
||||
sig = json.loads((PROCESSED / "sickle_cell_signature_v1.json").read_text())
|
||||
sk = _ks_connectivity(R, cols_for([g["gene"] for g in sig["up_regulated"]], g2c),
|
||||
cols_for([g["gene"] for g in sig["down_regulated"]], g2c), n)
|
||||
Xs = featurize(sk, C)
|
||||
p_sickle = clf.predict_proba(Xs)[:, 1]
|
||||
|
||||
res = pd.DataFrame({"model_p": p_sickle, "conn": sk}, index=drugs)
|
||||
res["model_rank"] = res["model_p"].rank(ascending=False).astype(int)
|
||||
res["conn_rank"] = res["conn"].rank(ascending=True).astype(int) # most negative = best
|
||||
N = len(res)
|
||||
print(f"\n{'drug':14s} {'model_rank':>10s} {'conn_rank(base)':>16s}")
|
||||
for d in ["hydroxyurea", "glutamine"] + NEG5 + ["norethindrone", "ciprofloxacin"]:
|
||||
if d in res.index:
|
||||
r = res.loc[d]
|
||||
print(f" {d:14s} {int(r['model_rank']):6d}/{N} {int(r['conn_rank']):6d}/{N}")
|
||||
nb_model = sum(res.loc[d, "model_rank"] > N/2 for d in NEG5 if d in res.index)
|
||||
nb_conn = sum(res.loc[d, "conn_rank"] > N/2 for d in NEG5 if d in res.index)
|
||||
print(f"\nneg controls bottom-half: model {nb_model}/5 vs baseline {nb_conn}/5")
|
||||
print("model top 10:", ", ".join(res.sort_values('model_p', ascending=False).head(10).index))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
Reference in New Issue
Block a user