From 066e0096d7c8ba77251211165540c544b0762a83 Mon Sep 17 00:00:00 2001 From: "Junior B." Date: Thu, 25 Jun 2026 00:41:51 +0200 Subject: [PATCH] Fix screen: compute target MSA once, reuse it, tolerate failures The full 300-drug screen hammered the public ColabFold MSA server (one redundant query per drug for the same HDAC2 sequence) -> timeouts, and the default return_exceptions=False risked aborting the whole run on a single failure. Corrected: - cache_msa(): compute the target MSA ONCE via the server, cache the a3m on the Volume; doubles as the weight/CCD warmup. - build_boltz_yaml(msa_path): protein reuses the cached a3m. - cofold(msa_path): when given, skip --use_msa_server (no server query). - screen(): cache MSA once, then cofold.starmap(..., return_exceptions=True) so a bad drug is skipped, not fatal. Turns a fragile ~2.5 hr run into a fast, robust one. Validation pilot (6 drugs) running; full 300 to run tomorrow. Co-Authored-By: Claude Opus 4.8 (1M context) --- gpu/modal_app.py | 85 +++++++++++++++++++++++++++++++++++++----------- 1 file changed, 66 insertions(+), 19 deletions(-) diff --git a/gpu/modal_app.py b/gpu/modal_app.py index 035278f..231e50a 100644 --- a/gpu/modal_app.py +++ b/gpu/modal_app.py @@ -102,16 +102,20 @@ def pubchem_smiles(name: str) -> str: raise ValueError(f"no SMILES for {name}") -def build_boltz_yaml(protein_seq: str, ligand_smiles: str, cofactor_ccds: list[str] | None = None) -> str: +def build_boltz_yaml(protein_seq: str, ligand_smiles: str, cofactor_ccds: list[str] | None = None, + msa_path: str | None = None) -> str: """Boltz-2 YAML: protein + the drug ligand (affinity binder) + any cofactor/metal CCD ligands. Cofactors/ions are added as `ligand` entries referencing their CCD code (e.g. ZN, MG, FBP, HEM) so the model places the metal/cofactor that defines the binding mode. Affinity is - predicted on the drug ligand L only. + predicted on the drug ligand L only. If ``msa_path`` is given, the protein reuses a precomputed + MSA (so a screen against one target queries the MSA server ONCE, not per drug). """ - lines = ["version: 1", "sequences:", - " - protein:", " id: A", f" sequence: {protein_seq}", - " - ligand:", " id: L", f" smiles: '{ligand_smiles}'"] + prot = [" - protein:", " id: A", f" sequence: {protein_seq}"] + if msa_path: + prot.append(f" msa: {msa_path}") + lines = ["version: 1", "sequences:"] + prot + \ + [" - ligand:", " id: L", f" smiles: '{ligand_smiles}'"] for i, ccd in enumerate(cofactor_ccds or []): lines += [" - ligand:", f" id: M{i}", f" ccd: {ccd}"] lines += ["properties:", " - affinity:", " binder: L"] @@ -123,12 +127,46 @@ def build_boltz_yaml(protein_seq: str, ligand_smiles: str, cofactor_ccds: list[s # max_containers caps parallel fan-out (cost control). The download race that corrupts the # checkpoint only happens on a COLD volume; once weights are cached+committed (Phase 1 did this), # parallel containers just reload them, so a screen can safely run ~10-wide. -@app.function(gpu="L4", image=image, volumes={WEIGHTS: weights}, timeout=3600, max_containers=10) -def cofold(label: str, protein_seq: str, ligand_smiles: str, cofactor_ccds: list[str]) -> dict: +@app.function(gpu="L4", image=image, volumes={WEIGHTS: weights}, timeout=1800) +def cache_msa(label: str, protein_seq: str, ligand_smiles: str, cofactor_ccds: list[str]) -> str: + """Compute the target's MSA ONCE (via the server) and cache the a3m on the Volume. + + A screen reuses one target protein for all drugs, so we query the MSA server a single time + here, then every cofold() reuses the cached a3m (no server hammering). Returns the Volume path + of the cached a3m. Doubles as the weight/CCD warmup. + """ + import os + import shutil + os.environ["HF_HOME"] = f"{WEIGHTS}/hf" + boltz_cache = f"{WEIGHTS}/boltz" + os.makedirs(boltz_cache, exist_ok=True) + weights.reload() + + work = Path("/tmp") / f"{label}_msa" + work.mkdir(parents=True, exist_ok=True) + (work / "in.yaml").write_text(build_boltz_yaml(protein_seq, ligand_smiles, cofactor_ccds)) + out = work / "out" + subprocess.run(["boltz", "predict", str(work / "in.yaml"), "--use_msa_server", + "--cache", boltz_cache, "--out_dir", str(out), "--output_format", "pdb"], check=True) + a3m = next(out.rglob("*.a3m"), None) + if a3m is None: + weights.commit() + raise RuntimeError("MSA generation produced no .a3m") + msa_dir = Path(WEIGHTS) / "msa" + msa_dir.mkdir(parents=True, exist_ok=True) + dest = msa_dir / f"{label}.a3m" + shutil.copy(a3m, dest) + weights.commit() + return str(dest) + + +@app.function(gpu="L4", image=image, volumes={WEIGHTS: weights}, timeout=3600, max_containers=20) +def cofold(label: str, protein_seq: str, ligand_smiles: str, cofactor_ccds: list[str], + msa_path: str | None = None) -> dict: """Co-fold one complex (protein + drug + cofactors) on the GPU; return affinity + P(binder). - Weights persist on the mounted Volume (HF_HOME/boltz --cache under /weights), so run 2+ skips - the download. GPU is released the moment this returns. + Weights persist on the mounted Volume. If ``msa_path`` is given, reuse the cached target MSA + (no server query); else query the MSA server. GPU is released the moment this returns. """ import os os.environ["HF_HOME"] = f"{WEIGHTS}/hf" @@ -138,14 +176,14 @@ def cofold(label: str, protein_seq: str, ligand_smiles: str, cofactor_ccds: list work = Path("/tmp") / label work.mkdir(parents=True, exist_ok=True) - (work / "in.yaml").write_text(build_boltz_yaml(protein_seq, ligand_smiles, cofactor_ccds)) + (work / "in.yaml").write_text(build_boltz_yaml(protein_seq, ligand_smiles, cofactor_ccds, msa_path)) out = work / "out" + cmd = ["boltz", "predict", str(work / "in.yaml"), + "--cache", boltz_cache, "--out_dir", str(out), "--output_format", "pdb"] + if not msa_path: + cmd.insert(3, "--use_msa_server") try: - subprocess.run( - ["boltz", "predict", str(work / "in.yaml"), "--use_msa_server", - "--cache", boltz_cache, "--out_dir", str(out), "--output_format", "pdb"], - check=True, - ) + subprocess.run(cmd, check=True) finally: weights.commit() # persist downloaded weights/CCD even if this run fails, so retries skip it @@ -206,11 +244,20 @@ def screen(limit: int = 0) -> None: if limit: # pilot: prioritise mechanism + controls (incl. the HDAC inhibitors) then fill pri = df[df["inclusion_reason"].isin(["ground_truth", "related_mechanism", "negative_control"])] df = pd.concat([pri, df.drop(pri.index)]).head(limit) - jobs = [(f"{target}__{r.pert_iname}", seq, r.canonical_smiles, cofactors) for r in df.itertuples()] - print(f"screening {len(jobs)} drugs vs {target} (+{cofactors})") + # 1) compute the target MSA ONCE (also warms weights/CCD), then reuse it for every drug + print(f"computing {target} MSA once (server) ...") + msa_path = cache_msa.remote(target, seq, pubchem_smiles("vorinostat"), cofactors) + print(f"cached MSA: {msa_path}") - results = list(cofold.starmap(jobs)) - by = {j[0].split("__")[1]: r for j, r in zip(jobs, results)} + # 2) screen all drugs reusing the cached MSA; tolerate per-drug failures (no abort) + jobs = [(f"{target}__{r.pert_iname}", seq, r.canonical_smiles, cofactors, msa_path) + for r in df.itertuples()] + print(f"screening {len(jobs)} drugs vs {target} (+{cofactors}), reusing cached MSA") + results = list(cofold.starmap(jobs, return_exceptions=True)) + by = {j[0].split("__")[1]: (r if isinstance(r, dict) else None) for j, r in zip(jobs, results)} + n_fail = sum(1 for v in by.values() if v is None) + if n_fail: + print(f" ({n_fail}/{len(jobs)} drugs failed and were skipped)") reason = dict(zip(df["pert_iname"], df["inclusion_reason"])) rows = [{"drug": d, "P_binder": (r or {}).get("prob_binder"),