Files
Reverso/gpu/modal_app.py
Junior B. 0535886ce6 Phase 2 screen pilot: HDAC2 recovers the inhibitor class (P>=0.99)
Add the `screen` entrypoint (parallel ~10-wide, cached weights) and run a
24-drug pilot vs HDAC2 (+Zn), ranked by Boltz-2 P(binder). ~$1.3.

Result (recovery test at scale): top 9 are ALL HDAC inhibitors
(trichostatin-A/vorinostat/panobinostat/belinostat/scriptaid/mocetinostat/
entinostat/apicidin >=0.99; valproic-acid 0.91), clean drop-off to
hydroxyurea 0.78 and non-HDAC drugs to dexamethasone 0.03. Captures the
structure-activity gradient (hydroxamates > weak fatty-acid > non-HDAC).

Honest false negative: romidepsin (potent HDAC inhibitor) ranks low (0.43)
-- it's a depsipeptide PRODRUG, which co-folding doesn't model. The screen
mishandles non-standard chemotypes.

Screening pipeline validated; next is the full 300-drug discovery run.
max_containers=10 (parallel safe once weights cached).

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
2026-06-24 22:23:01 +02:00

275 lines
13 KiB
Python

"""Ephemeral GPU runner for AF3-class co-folding (PLAN §12, docs/gpu_plan.md).
Serverless: `modal run gpu/modal_app.py::main` provisions a GPU, runs Phase 1, releases the GPU — zero
idle cost. Model weights cache in a persistent Volume so we never re-pay GPU time to re-download.
Phase 1 (positive-control recovery test, §12.4): co-fold each known binder + a couple of negative
controls against each sickle target and rank by Boltz-2 predicted affinity. The known binder
should win its own target — the test Vina couldn't pass on metal/covalent/allosteric modes.
Affinity ranking avoids the receptor-alignment that pose-RMSD would need (a later refinement).
Setup (one-time): `pip install modal && modal token new`.
Run: `modal run gpu/modal_app.py::main` (Phase 1); `::pose` and `::screen` are the other entrypoints.
Helpers above the GPU function run locally (no GPU) and are import-safe for testing.
"""
from __future__ import annotations
import json
import subprocess
from pathlib import Path
import modal
app = modal.App("reverso-binding")

# GPU container image: Debian slim + Python 3.12. Boltz-2 needs the NVIDIA cuequivariance
# kernels (cuda 12) for inference, plus rdkit/numpy; git/wget come from apt.
_APT_PACKAGES = ("git", "wget")
_PIP_PACKAGES = (
    "boltz",
    "cuequivariance-torch",
    "cuequivariance-ops-torch-cu12",
    "rdkit",
    "numpy",
)
image = modal.Image.debian_slim(python_version="3.12")
image = image.apt_install(*_APT_PACKAGES)
image = image.pip_install(*_PIP_PACKAGES)

# Persistent Volume (created on first use) holding the model weights; the GPU function mounts
# it at WEIGHTS so the weight download is paid only once.
weights = modal.Volume.from_name("reverso-binding-weights", create_if_missing=True)
WEIGHTS = "/weights"

# Per target: (PDB id, crystal-ligand resname, known drug, cofactor/metal CCD codes to co-fold).
# The cofactors are the binding-mode determinants Vina could not model; co-folding them in as
# CCD ligands is the whole point of the AF3-class pivot. The same cofactors are also present
# when the negatives are co-folded into that target, so the comparison is fair.
TARGETS = {
    "hemoglobin": ("5E83", "5L7", "voxelotor", ["HEM"]),  # heme
    "PKR": ("8XFD", "WV2", "mitapivat", ["FBP", "MG"]),  # allosteric FBP + Mg
    "HDAC2": ("4LXZ", "SHH", "vorinostat", ["ZN"]),  # catalytic Zn (vorinostat chelates it)
}
# Negative controls, co-folded into every target alongside its known binder.
NEGATIVES = ["caffeine", "hydroxyurea"]
# Honest limitation: voxelotor binds hemoglobin at the tetramer centre through a covalent
# Schiff base, which a single-chain + heme model only approximates, so Hb is the weak case.
# HDAC2 (Zn chelation) and PKR (allosteric + cofactor) are the real tests of whether
# co-folding handles the binding modes classical docking could not.
# --------------------------------------------------------------------------- local helpers (CPU)
def fetch_pdb(pdb: str, struct_dir: Path = Path("data/raw/structures")) -> Path:
    """Return a local PDB path, downloading from RCSB if absent (keeps the GPU run self-contained).

    The file is cached as ``<struct_dir>/<pdb>.pdb``; an existing file is returned untouched
    without any network access.

    Raises:
        requests.HTTPError: RCSB answered with an error status (e.g. an unknown id). Nothing
            is written in that case, so an error page can never end up cached as a structure.
    """
    path = struct_dir / f"{pdb}.pdb"
    if path.exists():
        return path
    import requests  # only needed on a cache miss

    resp = requests.get(f"https://files.rcsb.org/download/{pdb}.pdb", timeout=60)
    resp.raise_for_status()
    path.parent.mkdir(parents=True, exist_ok=True)
    # Write to a side file, then rename: an interrupted write must not leave a truncated file
    # that the exists-check above would serve forever.
    tmp = path.with_suffix(".pdb.part")
    tmp.write_bytes(resp.content)
    tmp.replace(path)
    return path
def binding_chain_sequence(pdb: str, lig_resname: str) -> str:
    """One-letter sequence of the protein chain nearest the crystal ligand (the binding chain).

    Every copy of residue ``lig_resname`` in the first model gets its own centroid. The chosen
    chain is the polymer (>= 20 residues) whose polymer atoms come closest to any of those
    centroids. Residues mapped to 'X' are dropped from the returned sequence.

    Raises:
        ValueError: the ligand, or any polymer chain of >= 20 residues, is absent.
    """
    import gemmi

    st = gemmi.read_structure(str(fetch_pdb(pdb)))
    model = st[0]
    # One centroid per ligand copy. Averaging all copies in the asymmetric unit (one per
    # monomer is common) can put the "centre" between chains, far from any real site.
    lig_centers = []
    for ch in model:
        for res in ch:
            if res.name != lig_resname:
                continue
            pts = [a.pos for a in res]
            if not pts:
                continue
            n = len(pts)
            lig_centers.append(gemmi.Position(
                sum(p.x for p in pts) / n,
                sum(p.y for p in pts) / n,
                sum(p.z for p in pts) / n,
            ))
    if not lig_centers:
        raise ValueError(f"ligand {lig_resname} not found in {pdb}")
    best, best_d = None, float("inf")
    for ch in model:
        poly = ch.get_polymer()
        if len(poly) < 20:  # skip short peptides / fragments
            continue
        # Measure polymer atoms only: HETATM residues stored in the same chain (waters, the
        # ligand copy itself) must not decide which *protein* chain is closest.
        d = min(
            (a.pos.dist(c) for res in poly for a in res for c in lig_centers),
            default=float("inf"),
        )
        if d < best_d:
            best, best_d = poly, d
    if best is None:
        raise ValueError(f"no protein chain in {pdb}")
    return gemmi.one_letter_code(best.extract_sequence()).upper().replace("X", "")
def pubchem_smiles(name: str) -> str:
    """Return a SMILES string for compound ``name`` via PubChem's PUG-REST name lookup.

    The properties SMILES, ConnectivitySMILES, IsomericSMILES and CanonicalSMILES are requested
    in that order. The first one present on the first hit is returned.

    Raises:
        ValueError: no property yielded a SMILES; the last lookup error (if any) is chained.
    """
    import requests
    from urllib.parse import quote

    # Percent-encode the whole name (including "/" and spaces) so it stays one URL path segment.
    url_name = quote(name, safe="")
    last_err: Exception | None = None
    for prop in ("SMILES", "ConnectivitySMILES", "IsomericSMILES", "CanonicalSMILES"):
        url = (f"https://pubchem.ncbi.nlm.nih.gov/rest/pug/compound/name/{url_name}"
               f"/property/{prop}/JSON")
        try:
            d = requests.get(url, timeout=30).json()["PropertyTable"]["Properties"][0]
        except (requests.RequestException, ValueError, KeyError, IndexError, TypeError) as err:
            # Network error, non-JSON body, or a PubChem fault without PropertyTable:
            # try the next property name.
            last_err = err
            continue
        if prop in d:
            return d[prop]
    raise ValueError(f"no SMILES for {name}") from last_err
def build_boltz_yaml(protein_seq: str, ligand_smiles: str, cofactor_ccds: list[str] | None = None) -> str:
"""Boltz-2 YAML: protein + the drug ligand (affinity binder) + any cofactor/metal CCD ligands.
Cofactors/ions are added as `ligand` entries referencing their CCD code (e.g. ZN, MG, FBP,
HEM) so the model places the metal/cofactor that defines the binding mode. Affinity is
predicted on the drug ligand L only.
"""
lines = ["version: 1", "sequences:",
" - protein:", " id: A", f" sequence: {protein_seq}",
" - ligand:", " id: L", f" smiles: '{ligand_smiles}'"]
for i, ccd in enumerate(cofactor_ccds or []):
lines += [" - ligand:", f" id: M{i}", f" ccd: {ccd}"]
lines += ["properties:", " - affinity:", " binder: L"]
return "\n".join(lines) + "\n"
# ------------------------------------------------------------------------------- GPU function
# max_containers caps parallel fan-out (cost control). The download race that corrupts the
# checkpoint only happens on a COLD volume; once weights are cached+committed (Phase 1 did this),
# parallel containers just reload them, so a screen can safely run ~10-wide.
@app.function(gpu="L4", image=image, volumes={WEIGHTS: weights}, timeout=3600, max_containers=10)
def cofold(label: str, protein_seq: str, ligand_smiles: str, cofactor_ccds: list[str]) -> dict:
    """Co-fold one complex (protein + drug + cofactors) on the GPU; return affinity + P(binder).

    Weights persist on the mounted Volume (HF_HOME and the boltz --cache both live under
    /weights), so run 2+ skips the download. GPU is released the moment this returns.

    Returns:
        dict with keys ``label`` (echoed back), ``affinity`` (Boltz ``affinity_pred_value``),
        ``prob_binder`` (Boltz ``affinity_probability_binary``) and ``structure`` (text of the
        model-0 PDB). Any missing value is None.

    Raises:
        subprocess.CalledProcessError: ``boltz predict`` exited non-zero.
    """
    import os
    import shutil
    import tempfile

    os.environ["HF_HOME"] = f"{WEIGHTS}/hf"
    boltz_cache = f"{WEIGHTS}/boltz"
    os.makedirs(boltz_cache, exist_ok=True)
    weights.reload()
    # Fresh scratch dir per call. Modal reuses warm containers for later inputs, so a fixed
    # /tmp/<label> could still hold an earlier call's predictions for the globs below to return.
    # Keeping the label out of the path also makes names containing "/" harmless.
    work = Path(tempfile.mkdtemp(prefix="cofold_"))
    try:
        in_yaml = work / "in.yaml"
        in_yaml.write_text(build_boltz_yaml(protein_seq, ligand_smiles, cofactor_ccds))
        out = work / "out"
        try:
            subprocess.run(
                ["boltz", "predict", str(in_yaml), "--use_msa_server",
                 "--cache", boltz_cache, "--out_dir", str(out), "--output_format", "pdb"],
                check=True,
            )
        finally:
            weights.commit()  # persist downloaded weights/CCD even if this run fails, so retries skip it
        # Affinity is written to a JSON under out/predictions/<name>/; parse defensively (keys vary).
        aff = {"affinity_pred_value": None, "affinity_probability_binary": None}
        aff_file = next(out.rglob("affinity*.json"), None)
        if aff_file is not None:
            data = json.loads(aff_file.read_text())
            for k in aff:
                if k in data:
                    aff[k] = data[k]
        pdb_file = next(out.rglob("*_model_0.pdb"), None) or next(out.rglob("*.pdb"), None)
        structure = pdb_file.read_text() if pdb_file else None
    finally:
        shutil.rmtree(work, ignore_errors=True)  # don't let warm containers accumulate scratch dirs
    return {"label": label, "affinity": aff["affinity_pred_value"],
            "prob_binder": aff["affinity_probability_binary"],
            "structure": structure}
# ------------------------------------------------------------------------------- driver (local)
@app.local_entrypoint()
def pose() -> None:
    """Write the predicted HDAC2/vorinostat complex to disk for local pose-RMSD validation.

    Invoke with `modal run gpu/modal_app.py::pose`. Weights are cached, so this is one fast GPU
    call. The saved PDB (protein + vorinostat + Zn) is compared against crystal 4LXZ by
    scripts/pose_rmsd.py, which aligns the predicted protein to the crystal and compares the ligand.
    """
    target = "HDAC2"
    pdb_id, lig_resname, drug, cofactors = TARGETS[target]
    sequence = binding_chain_sequence(pdb_id, lig_resname)
    result = cofold.remote(f"{target}_{drug}_pose", sequence, pubchem_smiles(drug), cofactors)

    out_dir = Path("data/processed/binding")
    out_dir.mkdir(parents=True, exist_ok=True)
    structure = result.get("structure")
    if not structure:
        print("no structure returned")
        return
    dest = out_dir / f"{target}_{drug}_pred.pdb"
    dest.write_text(structure)
    print(f"saved {dest}; affinity={result['affinity']}, P(binder)={result['prob_binder']}")
@app.local_entrypoint()
def screen(limit: int = 0) -> None:
    """Phase 2: co-fold the drug set against the validated target (HDAC2 + Zn), rank by P(binder).

    `modal run gpu/modal_app.py::screen --limit 24` (pilot; omit --limit for the full set).
    Recovery check at scale: the known HDAC inhibitors (related_mechanism) should rank top.
    Weights are cached, so this fans out ~10-wide.

    Reads data/processed/drug_set_v1.csv and skips rows whose canonical_smiles is missing or
    "-666". Screens each distinct pert_iname once and writes
    data/processed/binding/screen_HDAC2.csv ranked by P(binder). Drugs whose co-fold call
    fails are reported and left out of the ranking, as are calls that return no P(binder).

    Args:
        limit: if non-zero, screen only this many drugs, taking the ground_truth /
            related_mechanism / negative_control rows first; 0 screens the whole set.
    """
    import csv
    import pandas as pd

    target = "HDAC2"
    pdb, res, _drug, cofactors = TARGETS[target]
    seq = binding_chain_sequence(pdb, res)
    df = pd.read_csv("data/processed/drug_set_v1.csv")
    df = df[df["canonical_smiles"].notna() & (df["canonical_smiles"] != "-666")]
    # One job per drug name: a duplicate would be paid for on the GPU, yet the ranking can
    # only hold one row per drug.
    df = df.drop_duplicates(subset="pert_iname").copy()
    if limit:  # pilot: prioritise mechanism + controls (incl. the HDAC inhibitors) then fill
        pri = df[df["inclusion_reason"].isin(["ground_truth", "related_mechanism", "negative_control"])]
        df = pd.concat([pri, df.drop(pri.index)]).head(limit)
    drugs = df["pert_iname"].tolist()
    reasons = df["inclusion_reason"].tolist()
    jobs = [(f"{target}__{drug}", seq, smiles, cofactors)
            for drug, smiles in zip(drugs, df["canonical_smiles"])]
    print(f"screening {len(jobs)} drugs vs {target} (+{cofactors})")
    # return_exceptions=True: one ligand that Boltz chokes on must not abort the whole batch
    # and throw away every other result; failed calls come back as exception objects.
    results = list(cofold.starmap(jobs, return_exceptions=True))
    # Pair results with the drug list directly (starmap keeps input order) instead of
    # re-parsing labels, which would mangle names that themselves contain "__".
    failed = [drug for drug, r in zip(drugs, results) if not isinstance(r, dict)]
    rows = [
        {"drug": drug, "P_binder": r.get("prob_binder"),
         "affinity": r.get("affinity"), "inclusion_reason": why}
        for drug, why, r in zip(drugs, reasons, results)
        if isinstance(r, dict) and r.get("prob_binder") is not None
    ]
    rows.sort(key=lambda x: x["P_binder"], reverse=True)
    out = Path("data/processed/binding")
    out.mkdir(parents=True, exist_ok=True)
    with (out / f"screen_{target}.csv").open("w", newline="") as f:
        w = csv.DictWriter(f, fieldnames=["rank", "drug", "P_binder", "affinity", "inclusion_reason"])
        w.writeheader()
        for i, x in enumerate(rows, 1):
            w.writerow({"rank": i, **x})
    print(f"\nscreened {len(rows)} drugs vs {target}; top 15 by P(binder):")
    for i, x in enumerate(rows[:15], 1):
        print(f" {i:2d} {x['drug']:20s} P={x['P_binder']:.3f} [{x['inclusion_reason']}]")
    hdac_like = [i for i, x in enumerate(rows, 1) if x["inclusion_reason"] == "related_mechanism"]
    if hdac_like:
        print(f"\nrelated-mechanism (HDAC inhibitors etc.) ranks: {hdac_like[:10]}")
    if failed:
        print(f"\n{len(failed)} co-fold call(s) failed and were not ranked: {failed}")
@app.local_entrypoint()
def main() -> None:
    """Phase 1 positive-control test: co-fold each target's known binder and the negatives.

    Fans out one GPU call per (target, ligand) pair and prints an affinity / P(binder) table.
    Marks each target PASS when its known binder has the highest P(binder), and writes
    data/processed/binding/phase1_affinity.csv. A failed co-fold call is reported and appears
    as NA instead of aborting the run.
    """
    import csv

    smiles_by_lig: dict[str, str] = {}  # negatives recur for every target: look each up once
    jobs = []  # (label, protein_seq, smiles, cofactor_ccds)
    for target, (pdb, res, drug, cofactors) in TARGETS.items():
        seq = binding_chain_sequence(pdb, res)
        for lig in [drug, *NEGATIVES]:
            if lig not in smiles_by_lig:
                smiles_by_lig[lig] = pubchem_smiles(lig)
            jobs.append((f"{target}_{lig}", seq, smiles_by_lig[lig], cofactors))
    cofactor_summary = {t: c[3] for t, c in TARGETS.items()}
    print(f"co-folding {len(jobs)} complexes ({len(TARGETS)} targets x {1+len(NEGATIVES)} ligands); "
          f"cofactors per target: {cofactor_summary}")
    # return_exceptions=True: one failed complex must not discard the other results.
    results = list(cofold.starmap(jobs, return_exceptions=True))
    by = {}
    for job, r in zip(jobs, results):
        if isinstance(r, dict):
            by[job[0]] = r
        else:
            print(f"co-fold failed for {job[0]}: {r!r}")
    print(f"\n{'target':12s}{'ligand':14s}{'affinity':>10s}{'P(binder)':>11s}")
    out_rows = []
    for target, (_pdb, _res, drug, _cofactors) in TARGETS.items():
        # Rank by P(binder): unambiguous (higher = more likely a binder). Boltz's
        # affinity_pred_value is ~log(IC50) (lower = stronger) and its sign is
        # version-dependent, so it is reported but never ranked on.
        prob = {}
        for lig in [drug, *NEGATIVES]:
            r = by.get(f"{target}_{lig}", {})
            a, p = r.get("affinity"), r.get("prob_binder")
            if p is not None:
                prob[lig] = p
            print(f"{target:12s}{lig:14s}{(f'{a:.2f}' if a is not None else 'NA'):>10s}"
                  f"{(f'{p:.2f}' if p is not None else 'NA'):>11s}")
            out_rows.append({"target": target, "ligand": lig, "affinity": a, "prob_binder": p,
                             "is_known_binder": lig == drug})
        if prob:
            best = max(prob, key=prob.get)  # highest P(binder)
            print(f" -> {target}: best P(binder) = {best} (known = {drug}) "
                  f"{'PASS' if best == drug else 'FAIL'}\n")
    outdir = Path("data/processed/binding")
    outdir.mkdir(parents=True, exist_ok=True)
    with (outdir / "phase1_affinity.csv").open("w", newline="") as f:
        w = csv.DictWriter(f, fieldnames=["target", "ligand", "affinity", "prob_binder", "is_known_binder"])
        w.writeheader()
        w.writerows(out_rows)
    print(f"wrote {outdir/'phase1_affinity.csv'}")