"""Ephemeral GPU runner for AF3-class co-folding (PLAN §12, docs/gpu_plan.md). Serverless: `modal run gpu/modal_app.py` provisions a GPU, runs Phase 1, releases the GPU — zero idle cost. Model weights cache in a persistent Volume so we never re-pay GPU time to re-download. Phase 1 (positive-control recovery test, §12.4): co-fold each known binder + a couple of negative controls against each sickle target and rank by Boltz-2 predicted affinity. The known binder should win its own target — the test Vina couldn't pass on metal/covalent/allosteric modes. Affinity ranking avoids the receptor-alignment that pose-RMSD would need (a later refinement). Setup (one-time): `pip install modal && modal token new`. Run: `modal run gpu/modal_app.py`. Helpers below the GPU function run locally (no GPU) and are import-safe for testing. """ from __future__ import annotations import json import subprocess from pathlib import Path import modal app = modal.App("reverso-binding") image = ( modal.Image.debian_slim(python_version="3.12") .apt_install("git", "wget") # Boltz-2 needs NVIDIA cuequivariance kernels (cuda 12) for inference, plus rdkit/numpy. .pip_install("boltz", "cuequivariance-torch", "cuequivariance-ops-torch-cu12", "rdkit", "numpy") ) weights = modal.Volume.from_name("reverso-binding-weights", create_if_missing=True) WEIGHTS = "/weights" # target -> (PDB id, crystal-ligand resname, drug name, cofactor/metal CCD codes to co-fold). # The cofactors are the binding-mode determinants Vina couldn't model: HDAC2 needs the catalytic # Zn (vorinostat chelates it), PKR the allosteric FBP + Mg, hemoglobin the heme. Co-folding them # in as CCD ligands is the whole point of the AF3-class pivot. The same cofactors are present when # co-folding the negatives into that target, for a fair comparison. 
TARGETS = {
    "hemoglobin": ("5E83", "5L7", "voxelotor", ["HEM"]),
    "PKR": ("8XFD", "WV2", "mitapivat", ["FBP", "MG"]),
    "HDAC2": ("4LXZ", "SHH", "vorinostat", ["ZN"]),
}
NEGATIVES = ["caffeine", "hydroxyurea"]
# Honest limitation: hemoglobin's voxelotor site sits at the tetramer centre and the bond is
# covalent (Schiff base) — a single-chain + heme model only approximates it, so Hb is the weak
# case. HDAC2 (Zn chelation) and PKR (allosteric + cofactor) are the real tests of whether
# co-folding handles the modes classical docking could not.


# --------------------------------------------------------------------------- local helpers (CPU)
def fetch_pdb(pdb: str, struct_dir: Path = Path("data/raw/structures")) -> Path:
    """Return a local PDB path, downloading from RCSB if absent (keeps the GPU run self-contained).

    The HTTP status is checked and the body is written to a temporary file that is renamed into
    place only once complete, so an error page or a truncated transfer is never cached as
    ``<pdb>.pdb`` and silently reused on later runs.

    Raises:
        requests.HTTPError: if RCSB answers with an error status.
    """
    import requests

    p = struct_dir / f"{pdb}.pdb"
    if not p.exists():
        p.parent.mkdir(parents=True, exist_ok=True)
        resp = requests.get(f"https://files.rcsb.org/download/{pdb}.pdb", timeout=60)
        resp.raise_for_status()
        tmp = p.with_name(p.name + ".part")
        tmp.write_bytes(resp.content)
        tmp.replace(p)  # publish the cache file only after the full download is on disk
    return p


def binding_chain_sequence(pdb: str, lig_resname: str) -> str:
    """One-letter sequence of the protein chain nearest the crystal ligand (the binding chain).

    Only the first copy of ``lig_resname`` in model 0 is used: oligomeric entries often carry one
    ligand copy per chain, and averaging all copies would place the reference point between
    binding sites. Distances are measured to polymer atoms only, so waters, ions and the ligand
    residue itself (which the PDB file assigns to some chain) cannot decide the choice.

    Raises:
        ValueError: if the ligand is absent, or no polymer chain of >= 20 residues exists.
    """
    import gemmi

    st = gemmi.read_structure(str(fetch_pdb(pdb)))
    model = st[0]
    lig = next((res for ch in model for res in ch if res.name == lig_resname), None)
    if lig is None or len(lig) == 0:
        raise ValueError(f"ligand {lig_resname} not found in {pdb}")
    n = len(lig)
    lig_center = gemmi.Position(
        sum(a.pos.x for a in lig) / n,
        sum(a.pos.y for a in lig) / n,
        sum(a.pos.z for a in lig) / n,
    )
    best, best_d = None, float("inf")
    for ch in model:
        poly = ch.get_polymer()
        if len(poly) < 20:  # skip short peptides / fragments
            continue
        d = min((a.pos.dist(lig_center) for res in poly for a in res), default=float("inf"))
        if d < best_d:
            best, best_d = poly, d
    if best is None:
        raise ValueError(f"no protein chain in {pdb}")
    # upper() + dropping "X" normalises non-standard / unknown residue letters from gemmi.
    return gemmi.one_letter_code(best.extract_sequence()).upper().replace("X", "")


def pubchem_smiles(name: str) -> str:
    """Look up a SMILES string for a compound name via PubChem PUG-REST.

    Tries the ``SMILES`` property first, then the other property spellings, returning the first
    one PubChem reports. The name is URL-quoted. Network errors, non-JSON bodies and PubChem fault
    replies (no ``PropertyTable``) just move on to the next property.

    Raises:
        ValueError: if no property yields a SMILES.
    """
    from urllib.parse import quote

    import requests

    base = f"https://pubchem.ncbi.nlm.nih.gov/rest/pug/compound/name/{quote(name, safe='')}"
    for prop in ("SMILES", "ConnectivitySMILES", "IsomericSMILES", "CanonicalSMILES"):
        try:
            resp = requests.get(f"{base}/property/{prop}/JSON", timeout=30)
            d = resp.json()["PropertyTable"]["Properties"][0]
        except (requests.RequestException, ValueError, KeyError, IndexError, TypeError):
            continue
        if prop in d:
            return d[prop]
    raise ValueError(f"no SMILES for {name}")


def build_boltz_yaml(protein_seq: str, ligand_smiles: str, cofactor_ccds: list[str] | None = None) -> str:
    """Boltz-2 YAML: protein + the drug ligand (affinity binder) + any cofactor/metal CCD ligands.

    Cofactors/ions are added as `ligand` entries referencing their CCD code (e.g. ZN, MG, FBP,
    HEM) so the model places the metal/cofactor that defines the binding mode. Affinity is
    predicted on the drug ligand L only. The SMILES goes in a YAML single-quoted scalar, so any
    ``'`` in it is doubled (YAML's escape for that quoting style).
    """
    smiles_yaml = ligand_smiles.replace("'", "''")
    lines = [
        "version: 1",
        "sequences:",
        "  - protein:",
        "      id: A",
        f"      sequence: {protein_seq}",
        "  - ligand:",
        "      id: L",
        f"      smiles: '{smiles_yaml}'",
    ]
    for i, ccd in enumerate(cofactor_ccds or []):
        lines += ["  - ligand:", f"      id: M{i}", f"      ccd: {ccd}"]
    lines += ["properties:", "  - affinity:", "      binder: L"]
    return "\n".join(lines) + "\n"


# ------------------------------------------------------------------------------- GPU function
# max_containers=1: run the inputs serially on one warm container so the weights download ONCE
# (no concurrent-download race that corrupts the checkpoint) and are reused for the rest.
@app.function(gpu="L4", image=image, volumes={WEIGHTS: weights}, timeout=3600, max_containers=1)
def cofold(label: str, protein_seq: str, ligand_smiles: str, cofactor_ccds: list[str]) -> dict:
    """Co-fold one complex (protein + drug + cofactors) on the GPU; return affinity + P(binder).

    Weights persist on the mounted Volume (HF_HOME/boltz --cache under /weights), so run 2+ skips
    the download. GPU is released the moment this returns.

    Returns:
        dict with ``label``; ``affinity`` and ``prob_binder`` (boltz's ``affinity_pred_value`` /
        ``affinity_probability_binary``, None when not found); ``structure`` (text of the
        predicted PDB, or None when no PDB was written).

    Raises:
        subprocess.CalledProcessError: if ``boltz predict`` exits non-zero.
    """
    import os
    import shutil

    os.environ["HF_HOME"] = f"{WEIGHTS}/hf"
    boltz_cache = f"{WEIGHTS}/boltz"
    os.makedirs(boltz_cache, exist_ok=True)
    weights.reload()  # see whatever earlier containers committed to the Volume

    work = Path("/tmp") / label
    # The container is reused between inputs; start from a clean directory so leftovers from an
    # earlier attempt with this label are never parsed below (boltz may also skip inputs it
    # finds already predicted).
    shutil.rmtree(work, ignore_errors=True)
    work.mkdir(parents=True, exist_ok=True)
    in_yaml = work / "in.yaml"
    in_yaml.write_text(build_boltz_yaml(protein_seq, ligand_smiles, cofactor_ccds))
    out = work / "out"
    try:
        subprocess.run(
            ["boltz", "predict", str(in_yaml), "--use_msa_server", "--cache", boltz_cache,
             "--out_dir", str(out), "--output_format", "pdb"],
            check=True,
        )
    finally:
        weights.commit()  # persist downloaded weights/CCD even if this run fails, so retries skip it

    # Boltz writes affinity to an affinity*.json somewhere under out/; parse defensively (keys vary).
    aff = {"affinity_pred_value": None, "affinity_probability_binary": None}
    for jf in out.rglob("affinity*.json"):
        data = json.loads(jf.read_text())
        for k in aff:
            if k in data:
                aff[k] = data[k]
        break  # only the first affinity file is read
    pdb_path = next(out.rglob("*_model_0.pdb"), None) or next(out.rglob("*.pdb"), None)
    return {
        "label": label,
        "affinity": aff["affinity_pred_value"],
        "prob_binder": aff["affinity_probability_binary"],
        "structure": pdb_path.read_text() if pdb_path else None,
    }


# ------------------------------------------------------------------------------- driver (local)
@app.local_entrypoint()
def pose() -> None:
    """Save the predicted HDAC2/vorinostat complex for local pose-RMSD validation.

    Run: `modal run gpu/modal_app.py::pose`. Weights are cached, so this is one fast GPU call.
    The returned PDB (protein + vorinostat + Zn) is scored locally against 4LXZ by
    scripts/pose_rmsd.py (align predicted protein to crystal, compare ligand).
    """
    target = "HDAC2"
    pdb, res, drug, cofactors = TARGETS[target]
    seq = binding_chain_sequence(pdb, res)
    r = cofold.remote(f"{target}_{drug}_pose", seq, pubchem_smiles(drug), cofactors)
    out = Path("data/processed/binding")
    out.mkdir(parents=True, exist_ok=True)
    if r.get("structure"):
        dest = out / f"{target}_{drug}_pred.pdb"
        dest.write_text(r["structure"])
        print(f"saved {dest}; affinity={r['affinity']}, P(binder)={r['prob_binder']}")
    else:
        print("no structure returned")


@app.local_entrypoint()
def main() -> None:
    """Fan out one GPU call per (target, ligand) pair; tabulate affinities; positive-control test.

    Verdict per target:
    - PASS only if the known binder's P(binder) is strictly higher than every scored negative
      (a tie is not a win).
    - FAIL otherwise.
    - INCONCLUSIVE if the known binder, or all of the negatives, came back without a P(binder).

    A co-fold that raises is reported and shown as NA instead of discarding the other results.
    Writes data/processed/binding/phase1_affinity.csv.
    """
    import csv

    smiles: dict[str, str] = {}  # name -> SMILES; the negatives recur for every target, fetch once
    jobs = []  # (label, protein_seq, smiles, cofactor_ccds)
    for target, (pdb, res, drug, cofactors) in TARGETS.items():
        seq = binding_chain_sequence(pdb, res)
        for lig in [drug, *NEGATIVES]:
            if lig not in smiles:
                smiles[lig] = pubchem_smiles(lig)
            jobs.append((f"{target}_{lig}", seq, smiles[lig], cofactors))
    cofactor_summary = {t: c[3] for t, c in TARGETS.items()}
    print(f"co-folding {len(jobs)} complexes ({len(TARGETS)} targets x {1+len(NEGATIVES)} ligands); "
          f"cofactors per target: {cofactor_summary}")

    # return_exceptions=True: one failed co-fold (e.g. a transient MSA-server error) must not throw
    # away the GPU work already done for the others. Outputs come back in job order.
    results = list(cofold.starmap(jobs, return_exceptions=True))
    by = {}
    for (label, *_), r in zip(jobs, results):
        if isinstance(r, BaseException):
            print(f"  ! {label} failed: {r!r}")
        else:
            by[r["label"]] = r

    print(f"\n{'target':12s}{'ligand':14s}{'affinity':>10s}{'P(binder)':>11s}")
    out_rows = []
    for target, (_, _, drug, _) in TARGETS.items():
        # Rank by P(binder): unambiguous (higher = more likely a binder). Boltz affinity_pred_value
        # is ~log(IC50) (lower = stronger) -- its sign is version-dependent, so don't rank on it.
        prob = {}
        for lig in [drug, *NEGATIVES]:
            r = by.get(f"{target}_{lig}", {})
            a, p = r.get("affinity"), r.get("prob_binder")
            if p is not None:
                prob[lig] = p
            print(f"{target:12s}{lig:14s}{(f'{a:.2f}' if a is not None else 'NA'):>10s}"
                  f"{(f'{p:.2f}' if p is not None else 'NA'):>11s}")
            out_rows.append({"target": target, "ligand": lig, "affinity": a, "prob_binder": p,
                             "is_known_binder": lig == drug})
        neg_probs = [p for lig, p in prob.items() if lig != drug]
        if drug not in prob or not neg_probs:
            verdict = "INCONCLUSIVE"
        elif all(prob[drug] > p for p in neg_probs):
            verdict = "PASS"
        else:
            verdict = "FAIL"
        best = max(prob, key=prob.get) if prob else "NA"  # highest P(binder)
        print(f"  -> {target}: best P(binder) = {best} (known = {drug})  {verdict}\n")

    outdir = Path("data/processed/binding")
    outdir.mkdir(parents=True, exist_ok=True)
    with (outdir / "phase1_affinity.csv").open("w", newline="") as f:
        w = csv.DictWriter(f, fieldnames=["target", "ligand", "affinity", "prob_binder", "is_known_binder"])
        w.writeheader()
        w.writerows(out_rows)
    print(f"wrote {outdir / 'phase1_affinity.csv'}")