Fix screen: compute target MSA once, reuse it, tolerate failures

The full 300-drug screen hammered the public ColabFold MSA server (one
redundant query per drug for the same HDAC2 sequence) -> timeouts, and
the default return_exceptions=False risked aborting the whole run on a
single failure.

Corrected:
- cache_msa(): compute the target MSA ONCE via the server, cache the a3m
  on the Volume; doubles as the weight/CCD warmup.
- build_boltz_yaml(msa_path): protein reuses the cached a3m.
- cofold(msa_path): when given, skip --use_msa_server (no server query).
- screen(): cache MSA once, then cofold.starmap(..., return_exceptions=True)
  so a bad drug is skipped, not fatal.

Turns a fragile ~2.5 hr run into a fast, robust one. Validation pilot
(6 drugs) running; full 300 to run tomorrow.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
This commit is contained in:
2026-06-25 00:41:51 +02:00
parent 0535886ce6
commit 066e0096d7

View File

@@ -102,16 +102,20 @@ def pubchem_smiles(name: str) -> str:
raise ValueError(f"no SMILES for {name}")
def build_boltz_yaml(protein_seq: str, ligand_smiles: str, cofactor_ccds: list[str] | None = None) -> str:
def build_boltz_yaml(protein_seq: str, ligand_smiles: str, cofactor_ccds: list[str] | None = None,
msa_path: str | None = None) -> str:
"""Boltz-2 YAML: protein + the drug ligand (affinity binder) + any cofactor/metal CCD ligands.
Cofactors/ions are added as `ligand` entries referencing their CCD code (e.g. ZN, MG, FBP,
HEM) so the model places the metal/cofactor that defines the binding mode. Affinity is
predicted on the drug ligand L only.
predicted on the drug ligand L only. If ``msa_path`` is given, the protein reuses a precomputed
MSA (so a screen against one target queries the MSA server ONCE, not per drug).
"""
lines = ["version: 1", "sequences:",
" - protein:", " id: A", f" sequence: {protein_seq}",
" - ligand:", " id: L", f" smiles: '{ligand_smiles}'"]
prot = [" - protein:", " id: A", f" sequence: {protein_seq}"]
if msa_path:
prot.append(f" msa: {msa_path}")
lines = ["version: 1", "sequences:"] + prot + \
[" - ligand:", " id: L", f" smiles: '{ligand_smiles}'"]
for i, ccd in enumerate(cofactor_ccds or []):
lines += [" - ligand:", f" id: M{i}", f" ccd: {ccd}"]
lines += ["properties:", " - affinity:", " binder: L"]
@@ -123,12 +127,46 @@ def build_boltz_yaml(protein_seq: str, ligand_smiles: str, cofactor_ccds: list[s
# max_containers caps parallel fan-out (cost control). The download race that corrupts the
# checkpoint only happens on a COLD volume; once weights are cached+committed (Phase 1 did this),
# parallel containers just reload them, so a screen can safely run ~10-wide.
@app.function(gpu="L4", image=image, volumes={WEIGHTS: weights}, timeout=3600, max_containers=10)
def cofold(label: str, protein_seq: str, ligand_smiles: str, cofactor_ccds: list[str]) -> dict:
@app.function(gpu="L4", image=image, volumes={WEIGHTS: weights}, timeout=1800)
def cache_msa(label: str, protein_seq: str, ligand_smiles: str, cofactor_ccds: list[str]) -> str:
"""Compute the target's MSA ONCE (via the server) and cache the a3m on the Volume.
A screen reuses one target protein for all drugs, so we query the MSA server a single time
here, then every cofold() reuses the cached a3m (no server hammering). Returns the Volume path
of the cached a3m. Doubles as the weight/CCD warmup.
"""
import os
import shutil
os.environ["HF_HOME"] = f"{WEIGHTS}/hf"
boltz_cache = f"{WEIGHTS}/boltz"
os.makedirs(boltz_cache, exist_ok=True)
weights.reload()
work = Path("/tmp") / f"{label}_msa"
work.mkdir(parents=True, exist_ok=True)
(work / "in.yaml").write_text(build_boltz_yaml(protein_seq, ligand_smiles, cofactor_ccds))
out = work / "out"
subprocess.run(["boltz", "predict", str(work / "in.yaml"), "--use_msa_server",
"--cache", boltz_cache, "--out_dir", str(out), "--output_format", "pdb"], check=True)
a3m = next(out.rglob("*.a3m"), None)
if a3m is None:
weights.commit()
raise RuntimeError("MSA generation produced no .a3m")
msa_dir = Path(WEIGHTS) / "msa"
msa_dir.mkdir(parents=True, exist_ok=True)
dest = msa_dir / f"{label}.a3m"
shutil.copy(a3m, dest)
weights.commit()
return str(dest)
@app.function(gpu="L4", image=image, volumes={WEIGHTS: weights}, timeout=3600, max_containers=20)
def cofold(label: str, protein_seq: str, ligand_smiles: str, cofactor_ccds: list[str],
msa_path: str | None = None) -> dict:
"""Co-fold one complex (protein + drug + cofactors) on the GPU; return affinity + P(binder).
Weights persist on the mounted Volume (HF_HOME/boltz --cache under /weights), so run 2+ skips
the download. GPU is released the moment this returns.
Weights persist on the mounted Volume. If ``msa_path`` is given, reuse the cached target MSA
(no server query); else query the MSA server. GPU is released the moment this returns.
"""
import os
os.environ["HF_HOME"] = f"{WEIGHTS}/hf"
@@ -138,14 +176,14 @@ def cofold(label: str, protein_seq: str, ligand_smiles: str, cofactor_ccds: list
work = Path("/tmp") / label
work.mkdir(parents=True, exist_ok=True)
(work / "in.yaml").write_text(build_boltz_yaml(protein_seq, ligand_smiles, cofactor_ccds))
(work / "in.yaml").write_text(build_boltz_yaml(protein_seq, ligand_smiles, cofactor_ccds, msa_path))
out = work / "out"
cmd = ["boltz", "predict", str(work / "in.yaml"),
"--cache", boltz_cache, "--out_dir", str(out), "--output_format", "pdb"]
if not msa_path:
cmd.insert(3, "--use_msa_server")
try:
subprocess.run(
["boltz", "predict", str(work / "in.yaml"), "--use_msa_server",
"--cache", boltz_cache, "--out_dir", str(out), "--output_format", "pdb"],
check=True,
)
subprocess.run(cmd, check=True)
finally:
weights.commit() # persist downloaded weights/CCD even if this run fails, so retries skip it
@@ -206,11 +244,20 @@ def screen(limit: int = 0) -> None:
if limit: # pilot: prioritise mechanism + controls (incl. the HDAC inhibitors) then fill
pri = df[df["inclusion_reason"].isin(["ground_truth", "related_mechanism", "negative_control"])]
df = pd.concat([pri, df.drop(pri.index)]).head(limit)
jobs = [(f"{target}__{r.pert_iname}", seq, r.canonical_smiles, cofactors) for r in df.itertuples()]
print(f"screening {len(jobs)} drugs vs {target} (+{cofactors})")
# 1) compute the target MSA ONCE (also warms weights/CCD), then reuse it for every drug
print(f"computing {target} MSA once (server) ...")
msa_path = cache_msa.remote(target, seq, pubchem_smiles("vorinostat"), cofactors)
print(f"cached MSA: {msa_path}")
results = list(cofold.starmap(jobs))
by = {j[0].split("__")[1]: r for j, r in zip(jobs, results)}
# 2) screen all drugs reusing the cached MSA; tolerate per-drug failures (no abort)
jobs = [(f"{target}__{r.pert_iname}", seq, r.canonical_smiles, cofactors, msa_path)
for r in df.itertuples()]
print(f"screening {len(jobs)} drugs vs {target} (+{cofactors}), reusing cached MSA")
results = list(cofold.starmap(jobs, return_exceptions=True))
by = {j[0].split("__")[1]: (r if isinstance(r, dict) else None) for j, r in zip(jobs, results)}
n_fail = sum(1 for v in by.values() if v is None)
if n_fail:
print(f" ({n_fail}/{len(jobs)} drugs failed and were skipped)")
reason = dict(zip(df["pert_iname"], df["inclusion_reason"]))
rows = [{"drug": d, "P_binder": (r or {}).get("prob_binder"),