Fix screen: compute target MSA once, reuse it, tolerate failures
The full 300-drug screen hammered the public ColabFold MSA server (one redundant query per drug for the same HDAC2 sequence) -> timeouts, and the default return_exceptions=False risked aborting the whole run on a single failure. Corrected: - cache_msa(): compute the target MSA ONCE via the server, cache the a3m on the Volume; doubles as the weight/CCD warmup. - build_boltz_yaml(msa_path): protein reuses the cached a3m. - cofold(msa_path): when given, skip --use_msa_server (no server query). - screen(): cache MSA once, then cofold.starmap(..., return_exceptions=True) so a bad drug is skipped, not fatal. Turns a fragile ~2.5 hr run into a fast, robust one. Validation pilot (6 drugs) running; full 300 to run tomorrow. Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
This commit is contained in:
@@ -102,16 +102,20 @@ def pubchem_smiles(name: str) -> str:
|
|||||||
raise ValueError(f"no SMILES for {name}")
|
raise ValueError(f"no SMILES for {name}")
|
||||||
|
|
||||||
|
|
||||||
def build_boltz_yaml(protein_seq: str, ligand_smiles: str, cofactor_ccds: list[str] | None = None) -> str:
|
def build_boltz_yaml(protein_seq: str, ligand_smiles: str, cofactor_ccds: list[str] | None = None,
|
||||||
|
msa_path: str | None = None) -> str:
|
||||||
"""Boltz-2 YAML: protein + the drug ligand (affinity binder) + any cofactor/metal CCD ligands.
|
"""Boltz-2 YAML: protein + the drug ligand (affinity binder) + any cofactor/metal CCD ligands.
|
||||||
|
|
||||||
Cofactors/ions are added as `ligand` entries referencing their CCD code (e.g. ZN, MG, FBP,
|
Cofactors/ions are added as `ligand` entries referencing their CCD code (e.g. ZN, MG, FBP,
|
||||||
HEM) so the model places the metal/cofactor that defines the binding mode. Affinity is
|
HEM) so the model places the metal/cofactor that defines the binding mode. Affinity is
|
||||||
predicted on the drug ligand L only.
|
predicted on the drug ligand L only. If ``msa_path`` is given, the protein reuses a precomputed
|
||||||
|
MSA (so a screen against one target queries the MSA server ONCE, not per drug).
|
||||||
"""
|
"""
|
||||||
lines = ["version: 1", "sequences:",
|
prot = [" - protein:", " id: A", f" sequence: {protein_seq}"]
|
||||||
" - protein:", " id: A", f" sequence: {protein_seq}",
|
if msa_path:
|
||||||
" - ligand:", " id: L", f" smiles: '{ligand_smiles}'"]
|
prot.append(f" msa: {msa_path}")
|
||||||
|
lines = ["version: 1", "sequences:"] + prot + \
|
||||||
|
[" - ligand:", " id: L", f" smiles: '{ligand_smiles}'"]
|
||||||
for i, ccd in enumerate(cofactor_ccds or []):
|
for i, ccd in enumerate(cofactor_ccds or []):
|
||||||
lines += [" - ligand:", f" id: M{i}", f" ccd: {ccd}"]
|
lines += [" - ligand:", f" id: M{i}", f" ccd: {ccd}"]
|
||||||
lines += ["properties:", " - affinity:", " binder: L"]
|
lines += ["properties:", " - affinity:", " binder: L"]
|
||||||
@@ -123,12 +127,46 @@ def build_boltz_yaml(protein_seq: str, ligand_smiles: str, cofactor_ccds: list[s
|
|||||||
# max_containers caps parallel fan-out (cost control). The download race that corrupts the
|
# max_containers caps parallel fan-out (cost control). The download race that corrupts the
|
||||||
# checkpoint only happens on a COLD volume; once weights are cached+committed (Phase 1 did this),
|
# checkpoint only happens on a COLD volume; once weights are cached+committed (Phase 1 did this),
|
||||||
# parallel containers just reload them, so a screen can safely run ~10-wide.
|
# parallel containers just reload them, so a screen can safely run ~10-wide.
|
||||||
@app.function(gpu="L4", image=image, volumes={WEIGHTS: weights}, timeout=3600, max_containers=10)
|
@app.function(gpu="L4", image=image, volumes={WEIGHTS: weights}, timeout=1800)
|
||||||
def cofold(label: str, protein_seq: str, ligand_smiles: str, cofactor_ccds: list[str]) -> dict:
|
def cache_msa(label: str, protein_seq: str, ligand_smiles: str, cofactor_ccds: list[str]) -> str:
|
||||||
|
"""Compute the target's MSA ONCE (via the server) and cache the a3m on the Volume.
|
||||||
|
|
||||||
|
A screen reuses one target protein for all drugs, so we query the MSA server a single time
|
||||||
|
here, then every cofold() reuses the cached a3m (no server hammering). Returns the Volume path
|
||||||
|
of the cached a3m. Doubles as the weight/CCD warmup.
|
||||||
|
"""
|
||||||
|
import os
|
||||||
|
import shutil
|
||||||
|
os.environ["HF_HOME"] = f"{WEIGHTS}/hf"
|
||||||
|
boltz_cache = f"{WEIGHTS}/boltz"
|
||||||
|
os.makedirs(boltz_cache, exist_ok=True)
|
||||||
|
weights.reload()
|
||||||
|
|
||||||
|
work = Path("/tmp") / f"{label}_msa"
|
||||||
|
work.mkdir(parents=True, exist_ok=True)
|
||||||
|
(work / "in.yaml").write_text(build_boltz_yaml(protein_seq, ligand_smiles, cofactor_ccds))
|
||||||
|
out = work / "out"
|
||||||
|
subprocess.run(["boltz", "predict", str(work / "in.yaml"), "--use_msa_server",
|
||||||
|
"--cache", boltz_cache, "--out_dir", str(out), "--output_format", "pdb"], check=True)
|
||||||
|
a3m = next(out.rglob("*.a3m"), None)
|
||||||
|
if a3m is None:
|
||||||
|
weights.commit()
|
||||||
|
raise RuntimeError("MSA generation produced no .a3m")
|
||||||
|
msa_dir = Path(WEIGHTS) / "msa"
|
||||||
|
msa_dir.mkdir(parents=True, exist_ok=True)
|
||||||
|
dest = msa_dir / f"{label}.a3m"
|
||||||
|
shutil.copy(a3m, dest)
|
||||||
|
weights.commit()
|
||||||
|
return str(dest)
|
||||||
|
|
||||||
|
|
||||||
|
@app.function(gpu="L4", image=image, volumes={WEIGHTS: weights}, timeout=3600, max_containers=20)
|
||||||
|
def cofold(label: str, protein_seq: str, ligand_smiles: str, cofactor_ccds: list[str],
|
||||||
|
msa_path: str | None = None) -> dict:
|
||||||
"""Co-fold one complex (protein + drug + cofactors) on the GPU; return affinity + P(binder).
|
"""Co-fold one complex (protein + drug + cofactors) on the GPU; return affinity + P(binder).
|
||||||
|
|
||||||
Weights persist on the mounted Volume (HF_HOME/boltz --cache under /weights), so run 2+ skips
|
Weights persist on the mounted Volume. If ``msa_path`` is given, reuse the cached target MSA
|
||||||
the download. GPU is released the moment this returns.
|
(no server query); else query the MSA server. GPU is released the moment this returns.
|
||||||
"""
|
"""
|
||||||
import os
|
import os
|
||||||
os.environ["HF_HOME"] = f"{WEIGHTS}/hf"
|
os.environ["HF_HOME"] = f"{WEIGHTS}/hf"
|
||||||
@@ -138,14 +176,14 @@ def cofold(label: str, protein_seq: str, ligand_smiles: str, cofactor_ccds: list
|
|||||||
|
|
||||||
work = Path("/tmp") / label
|
work = Path("/tmp") / label
|
||||||
work.mkdir(parents=True, exist_ok=True)
|
work.mkdir(parents=True, exist_ok=True)
|
||||||
(work / "in.yaml").write_text(build_boltz_yaml(protein_seq, ligand_smiles, cofactor_ccds))
|
(work / "in.yaml").write_text(build_boltz_yaml(protein_seq, ligand_smiles, cofactor_ccds, msa_path))
|
||||||
out = work / "out"
|
out = work / "out"
|
||||||
|
cmd = ["boltz", "predict", str(work / "in.yaml"),
|
||||||
|
"--cache", boltz_cache, "--out_dir", str(out), "--output_format", "pdb"]
|
||||||
|
if not msa_path:
|
||||||
|
cmd.insert(3, "--use_msa_server")
|
||||||
try:
|
try:
|
||||||
subprocess.run(
|
subprocess.run(cmd, check=True)
|
||||||
["boltz", "predict", str(work / "in.yaml"), "--use_msa_server",
|
|
||||||
"--cache", boltz_cache, "--out_dir", str(out), "--output_format", "pdb"],
|
|
||||||
check=True,
|
|
||||||
)
|
|
||||||
finally:
|
finally:
|
||||||
weights.commit() # persist downloaded weights/CCD even if this run fails, so retries skip it
|
weights.commit() # persist downloaded weights/CCD even if this run fails, so retries skip it
|
||||||
|
|
||||||
@@ -206,11 +244,20 @@ def screen(limit: int = 0) -> None:
|
|||||||
if limit: # pilot: prioritise mechanism + controls (incl. the HDAC inhibitors) then fill
|
if limit: # pilot: prioritise mechanism + controls (incl. the HDAC inhibitors) then fill
|
||||||
pri = df[df["inclusion_reason"].isin(["ground_truth", "related_mechanism", "negative_control"])]
|
pri = df[df["inclusion_reason"].isin(["ground_truth", "related_mechanism", "negative_control"])]
|
||||||
df = pd.concat([pri, df.drop(pri.index)]).head(limit)
|
df = pd.concat([pri, df.drop(pri.index)]).head(limit)
|
||||||
jobs = [(f"{target}__{r.pert_iname}", seq, r.canonical_smiles, cofactors) for r in df.itertuples()]
|
# 1) compute the target MSA ONCE (also warms weights/CCD), then reuse it for every drug
|
||||||
print(f"screening {len(jobs)} drugs vs {target} (+{cofactors})")
|
print(f"computing {target} MSA once (server) ...")
|
||||||
|
msa_path = cache_msa.remote(target, seq, pubchem_smiles("vorinostat"), cofactors)
|
||||||
|
print(f"cached MSA: {msa_path}")
|
||||||
|
|
||||||
results = list(cofold.starmap(jobs))
|
# 2) screen all drugs reusing the cached MSA; tolerate per-drug failures (no abort)
|
||||||
by = {j[0].split("__")[1]: r for j, r in zip(jobs, results)}
|
jobs = [(f"{target}__{r.pert_iname}", seq, r.canonical_smiles, cofactors, msa_path)
|
||||||
|
for r in df.itertuples()]
|
||||||
|
print(f"screening {len(jobs)} drugs vs {target} (+{cofactors}), reusing cached MSA")
|
||||||
|
results = list(cofold.starmap(jobs, return_exceptions=True))
|
||||||
|
by = {j[0].split("__")[1]: (r if isinstance(r, dict) else None) for j, r in zip(jobs, results)}
|
||||||
|
n_fail = sum(1 for v in by.values() if v is None)
|
||||||
|
if n_fail:
|
||||||
|
print(f" ({n_fail}/{len(jobs)} drugs failed and were skipped)")
|
||||||
reason = dict(zip(df["pert_iname"], df["inclusion_reason"]))
|
reason = dict(zip(df["pert_iname"], df["inclusion_reason"]))
|
||||||
|
|
||||||
rows = [{"drug": d, "P_binder": (r or {}).get("prob_binder"),
|
rows = [{"drug": d, "P_binder": (r or {}).get("prob_binder"),
|
||||||
|
|||||||
Reference in New Issue
Block a user