Source code for mythos.simulators.lammps.lammps_oxdna

"""LAMMPS-based OxDNA simulator for mythos."""

import re
from dataclasses import field
from pathlib import Path
from typing import Any

import chex
import jax.numpy as jnp
import numpy as np
from typing_extensions import override

from mythos.energy.base import EnergyFunction
from mythos.input.trajectory import NucleotideState, Trajectory, validate_box_size
from mythos.simulators.base import InputDirSimulator, SimulatorOutput
from mythos.simulators.io import SimulatorTrajectory
from mythos.utils.helpers import run_command
from mythos.utils.types import Params


@chex.dataclass(frozen=True, kw_only=True)
class LAMMPSoxDNASimulator(InputDirSimulator):
    """Simulator that runs oxDNA models through LAMMPS (the ``lmp`` executable).

    When using LAMMPS for oxDNA, leave BondedExcludedVolume out of the energy function:
    LAMMPS does not implement it (or does not implement it in a compatible way).

    Args:
        input_dir: Directory holding the LAMMPS input files.
        overwrite: Whether to modify the input directory in place instead of copying it to a
            temporary directory first.
        input_file_name: File name of the LAMMPS input script inside ``input_dir``
            (default ``"input"``).
        energy_fn: Energy function whose parameters are written into the input script.
        variables: Extra LAMMPS variables to overwrite before the run. Each one must already
            be declared in the input script with a command of the form
            ``variable name equal value``.
        temperature_variable: Name of the LAMMPS variable holding the simulation temperature
            in reduced units (kT). When ``variables`` sets it, its value fills the temperature
            field of the output trajectory.
    """

    energy_fn: EnergyFunction
    input_file_name: str = "input"
    variables: dict[str, Any] = field(default_factory=dict)
    temperature_variable: str = "kt"

    @override
    def __post_init__(self) -> None:
        """Fail early when the configured input script does not exist."""
        script_path = Path(self.input_dir) / self.input_file_name
        if script_path.is_file():
            return
        raise FileNotFoundError(f"LAMMPS input file not found: {self.input_file_name}")

    @override
    def run_simulation(self, input_dir: Path, params: Params, seed: int | None = None) -> SimulatorOutput:
        """Rewrite the input script, run ``lmp -in <input_file_name>`` and load ``trajectory.dat``.

        Args:
            input_dir: Directory in which LAMMPS is run and the script is rewritten.
            params: Parameters applied to ``energy_fn`` before writing the script.
            seed: Value for the ``seed`` variable; see ``_lammps_oxdna_replace_inputs``.

        Returns:
            A ``SimulatorOutput`` holding one ``SimulatorTrajectory``.
        """
        self._replace_parameters(input_dir, params, seed)
        run_command(["lmp", "-in", self.input_file_name], cwd=input_dir, log_prefix="lammps")
        traj = _read_lammps_output(input_dir / "trajectory.dat")

        kt_value = self.variables.get(self.temperature_variable)
        if kt_value is None:
            temperature = None
        else:
            # One (constant) temperature entry per stored frame.
            frame_count = traj.state_rigid_body.center.shape[0]
            temperature = jnp.full(frame_count, float(kt_value))

        observable = SimulatorTrajectory.from_rigid_body(traj.state_rigid_body, temperature=temperature)
        return SimulatorOutput(observables=[observable])

    def _replace_parameters(self, input_dir: Path, params: Params, seed: int | None) -> None:
        """Write optimizable ``params``, ``seed`` and ``variables`` into the input script in place."""
        optimizable = self.energy_fn.with_params(params).params_dict(exclude_non_optimizable=True)
        script_path = input_dir.joinpath(self.input_file_name)
        original_lines = script_path.read_text().splitlines()
        rewritten = _lammps_oxdna_replace_inputs(original_lines, optimizable, seed, variables=self.variables)
        script_path.write_text("\n".join(rewritten))
def _lammps_logical_lines(input_lines: list[str]) -> list[str]:
    """Join ``&``-continued physical lines and collapse runs of whitespace to single spaces."""
    logical: list[str] = []
    pending: list[str] = []
    for raw in input_lines:
        stripped = raw.strip()
        if stripped.endswith("&"):
            # Continuation: keep accumulating until a line without a trailing "&".
            pending.append(stripped.removesuffix("&"))
            continue
        pending.append(stripped)
        # Normalize *after* joining so that a command prefix split across a continuation
        # (e.g. "pair_coeff * * &" / "oxdna/stk ...") still reads with single spaces.
        logical.append(re.sub(r"\s+", " ", " ".join(pending)).strip())
        pending = []
    if pending:
        # Input ended inside a continuation; keep the text instead of silently dropping it.
        logical.append(re.sub(r"\s+", " ", " ".join(pending)).strip())
    return logical


def _lammps_oxdna_replace_inputs(
    input_lines: list[str],
    params: dict[str, float],
    seed: int | None,
    variables: dict[str, Any] | None = None,
) -> list[str]:
    """Rewrite LAMMPS input lines with updated oxDNA parameters, seed and variables.

    Physical lines are first joined at trailing ``&`` continuations and whitespace-normalized.
    Each resulting logical line is then processed:

    * ``variable NAME ...`` lines whose NAME is ``seed`` or a key of ``variables`` become
      ``variable NAME equal VALUE``. Only the first definition of each name is replaced.
    * Lines starting with a key of ``REPLACEMENT_MAP`` get their coefficients substituted from
      ``params`` via ``_replace_parts_in_line``.
    * ``dump`` lines are inspected to confirm that the required fields are written to
      ``trajectory.dat``.

    Args:
        input_lines: Physical lines of the LAMMPS input script.
        params: Mapping from parameter name to value.
        seed: Value for the ``seed`` variable. ``None`` or ``0`` draws a random seed in
            ``[1, 2**24)``.
        variables: Extra variables to overwrite. Each must be defined somewhere in the input.

    Returns:
        The rewritten logical lines. Input lines continued with ``&`` are emitted joined on a
        single line.

    Raises:
        ValueError: If no ``dump`` writes all ``LAMMPS_REQUIRED_FIELDS`` to ``trajectory.dat``,
            or if a requested variable is never defined in the input.
    """
    # LAMMPS random seeds must be positive integers, so 0 (like None) means "draw one",
    # and the draw starts at 1 so that it can never produce 0 itself.
    variable_replacements = {
        "seed": seed or np.random.default_rng().integers(1, 2**24),
        **(variables or {}),
    }

    new_lines = []
    has_trajectory_dump = False
    for line in _lammps_logical_lines(input_lines):
        if line.startswith("variable "):
            var = line.split()[1]
            if var in variable_replacements:
                line = f"variable {var} equal {variable_replacements.pop(var)}"
        elif line.startswith("dump "):
            # Layout: dump ID group style N file field1 field2 ... -> file is token 5.
            line_parts = line.split()
            if (
                len(line_parts) > 6  # noqa: PLR2004
                and line_parts[5] == "trajectory.dat"
                and LAMMPS_REQUIRED_FIELDS.issubset(line_parts[6:])
            ):
                has_trajectory_dump = True

        for key, replacements in REPLACEMENT_MAP.items():
            if line.startswith(key):
                new_args = _replace_parts_in_line(line.removeprefix(key), replacements, params)
                line = f"{key} {new_args}"
                break  # no key is a prefix of another, so at most one can match
        new_lines.append(line)

    if not has_trajectory_dump:
        raise ValueError(f"Required dump not found. Must dump to trajectory.dat fields {LAMMPS_REQUIRED_FIELDS}.")
    if variable_replacements:
        raise ValueError("Missing variable for replacements: " + ", ".join(variable_replacements.keys()))
    return new_lines
[docs] def _replace_parts_in_line(inputs: str, replacements: tuple[str], params: dict[str, float]) -> str: parts = inputs.split() def repl(part: str, replacement: str | None) -> str: if replacement is None or replacement not in params: return part return f"{_transform_param(replacement, params[replacement]):f}" return " ".join([repl(part, r_param) for part, r_param in zip(parts, replacements, strict=True)])
# For each LAMMPS command whose line starts with a dict key, the mythos parameter name for each
# positional argument after that key. ``None`` marks an argument that is never substituted.
# Consumed by ``_lammps_oxdna_replace_inputs`` through ``_replace_parts_in_line``.

# Shared tail of every oxdna/hbond ``pair_coeff`` line (all arguments after the first two).
# Building all hbond entries from this one tuple keeps their argument order identical; the
# "pair_coeff 2 3" entry previously listed the theta7 block before theta8, unlike the others.
_HBOND_PARAMS = (
    "a_hb",
    "dr0_hb",
    "dr_c_hb",
    "dr_low_hb",
    "dr_high_hb",
    "a_hb_1",
    "theta0_hb_1",
    "delta_theta_star_hb_1",
    "a_hb_2",
    "theta0_hb_2",
    "delta_theta_star_hb_2",
    "a_hb_3",
    "theta0_hb_3",
    "delta_theta_star_hb_3",
    "a_hb_4",
    "theta0_hb_4",
    "delta_theta_star_hb_4",
    "a_hb_8",  # 8 and 7 swapped in lammps input
    "theta0_hb_8",
    "delta_theta_star_hb_8",
    "a_hb_7",
    "theta0_hb_7",
    "delta_theta_star_hb_7",
)

# Leading arguments shared by the oxdna and oxdna2 coaxial-stacking ``pair_coeff`` lines.
_COAXSTK_COMMON_PARAMS = (
    "k_coax",
    "dr0_coax",
    "dr_c_coax",
    "dr_low_coax",
    "dr_high_coax",
    "a_coax_1",
    "theta0_coax_1",
    "delta_theta_star_coax_1",
    "a_coax_4",
    "theta0_coax_4",
    "delta_theta_star_coax_4",
    "a_coax_5",
    "theta0_coax_5",
    "delta_theta_star_coax_5",
    "a_coax_6",
    "theta0_coax_6",
    "delta_theta_star_coax_6",
)

REPLACEMENT_MAP = {
    "bond_coeff *": ("eps_backbone", "delta_backbone", "r0_backbone"),
    "pair_coeff * * oxdna/excv": (
        "eps_exc",
        "sigma_backbone",
        "dr_star_backbone",
        "eps_exc",
        "sigma_back_base",
        "dr_star_back_base",
        "eps_exc",
        "sigma_base",
        "dr_star_base",
    ),
    "pair_coeff * * oxdna/stk": (
        None,
        None,
        "eps_stack_base",
        "eps_stack_kt_coeff",
        "a_stack",
        "dr0_stack",
        "dr_c_stack",
        "dr_low_stack",
        "dr_high_stack",
        "a_stack_4",
        "theta0_stack_4",
        "delta_theta_star_stack_4",
        "a_stack_5",
        "theta0_stack_5",
        "delta_theta_star_stack_5",
        "a_stack_6",
        "theta0_stack_6",
        "delta_theta_star_stack_6",
        "a_stack_1",
        "neg_cos_phi1_star_stack",
        "a_stack_2",
        "neg_cos_phi2_star_stack",
    ),
    # "HYDR_F1": no mythos parameter of this name is expected, so the argument is normally kept.
    "pair_coeff * * oxdna/hbond": (None, "HYDR_F1", *_HBOND_PARAMS),
    "pair_coeff 1 4 oxdna/hbond": (None, "eps_hb", *_HBOND_PARAMS),
    "pair_coeff 2 3 oxdna/hbond": (None, "eps_hb", *_HBOND_PARAMS),
    "pair_coeff * * oxdna/xstk": (
        "k_cross",
        "r0_cross",
        "dr_c_cross",
        "dr_low_cross",
        "dr_high_cross",
        "a_cross_1",
        "theta0_cross_1",
        "delta_theta_star_cross_1",
        "a_cross_3",  # 3 and 2 swapped in lammps input
        "theta0_cross_3",
        "delta_theta_star_cross_3",
        "a_cross_2",
        "theta0_cross_2",
        "delta_theta_star_cross_2",
        "a_cross_4",
        "theta0_cross_4",
        "delta_theta_star_cross_4",
        "a_cross_8",  # 8 and 7 swapped in lammps input
        "theta0_cross_8",
        "delta_theta_star_cross_8",
        "a_cross_7",
        "theta0_cross_7",
        "delta_theta_star_cross_7",
    ),
    "pair_coeff * * oxdna/coaxstk": (
        *_COAXSTK_COMMON_PARAMS,
        "a_coax_3p",
        "cos_phi3_star_coax",
        "a_coax_4p",
        "cos_phi4_star_coax",
    ),
}
# Copy common oxdna2 parameters providing overrides where needed
REPLACEMENT_MAP = {
    **REPLACEMENT_MAP,
    **{k.replace("oxdna/", "oxdna2/"): v for k, v in REPLACEMENT_MAP.items() if "oxdna/" in k},
    "pair_coeff * * oxdna2/coaxstk": (*_COAXSTK_COMMON_PARAMS, "a_coax_1_f6", "b_coax_1_f6"),
    "pair_coeff * * oxdna2/dh": (None, "salt_conc", "q_eff"),
}
[docs] def _transform_param(param: str, value: float) -> float: if param in ["neg_cos_phi1_star_stack", "neg_cos_phi2_star_stack"]: return -value return value
# Dump columns that must be present: positions, velocities, the four components of the
# ``quat`` compute, and angular momenta. Checked when rewriting the input's ``dump`` line and
# when reading ``trajectory.dat``.
LAMMPS_REQUIRED_FIELDS = {
    *("x", "y", "z"),
    *("vx", "vy", "vz"),
    *(f"c_quat[{i}]" for i in range(1, 5)),
    *("angmomx", "angmomy", "angmomz"),
}
def _transform_lammps_state(state: np.ndarray, fields: list[str]) -> np.ndarray:
    """Convert one atom row of a LAMMPS dump into a mythos nucleotide state vector.

    Args:
        state: Values parsed from a single data line of an ``ITEM: ATOMS`` section.
        fields: Column names of that section, in the same order as ``state``. This is a list
            of names, not a string, because ``str.index`` would match substrings. It must
            contain every name in ``LAMMPS_REQUIRED_FIELDS``; otherwise ``list.index`` raises
            ``ValueError``.

    Returns:
        The concatenation of:

        * position (3),
        * the two orientation axes from ``_transform_lammps_quat`` (6),
        * velocity multiplied by ``sqrt(3.1575)`` (3),
        * angular momentum divided by ``sqrt(0.435179)`` (3).
    """
    # NOTE(review): 3.1575 / 0.435179 look like the nucleotide mass and moment of inertia used
    # in LAMMPS oxDNA inputs, with the scaling converting to mythos units — confirm.
    velocity_scale = np.sqrt(3.1575)
    angmom_scale = np.sqrt(0.435179)

    def pick(*names: str) -> np.ndarray:
        # Fancy indexing returns a copy, so the scaling below never mutates ``state``.
        return state[[fields.index(name) for name in names]]

    pos = pick("x", "y", "z")
    quat = pick("c_quat[1]", "c_quat[2]", "c_quat[3]", "c_quat[4]")
    vel = pick("vx", "vy", "vz") * velocity_scale
    angmom = pick("angmomx", "angmomy", "angmomz") / angmom_scale
    return np.concatenate([pos, _transform_lammps_quat(quat), vel, angmom])
[docs] def _transform_lammps_quat(quat: np.ndarray) -> np.ndarray: q_2 = quat**2 i = 1 / q_2.sum() a0 = (q_2[0] + q_2[1] - q_2[2] - q_2[3]) * i a1 = 2 * (quat[1] * quat[2] + quat[0] * quat[3]) * i a2 = 2 * (quat[1] * quat[3] - quat[0] * quat[2]) * i b0 = 2 * (quat[1] * quat[3] + quat[0] * quat[2]) * i b1 = 2 * (quat[2] * quat[3] - quat[0] * quat[1]) * i b2 = (q_2[0] + q_2[3] - q_2[1] - q_2[2]) * i return np.array([a0, a1, a2, b0, b1, b2])
def _read_lammps_output(output_file: Path) -> Trajectory:
    """Read a LAMMPS trajectory dump file into a mythos ``Trajectory``.

    The file must have been created by a LAMMPS dump command similar to:

        compute quat all property/atom quatw quati quatj quatk
        dump {name} all custom {freq} trajectory.dat x y z vx vy vz &
            c_quat[1] c_quat[2] c_quat[3] c_quat[4] angmomx angmomy angmomz

    The above fields are required, but other fields may also be present in the dump.

    Frames at timestep 0 (the initial configuration) are skipped wherever they occur. If an
    ``id`` column is dumped, the atoms of each frame are ordered by id. This matters because
    LAMMPS only guarantees a consistent atom order in dump files when ``dump_modify sort id``
    is used.

    Energies are not parsed (the returned energies are zeros), and all atoms are reported as
    a single strand.

    Args:
        output_file: Path to the LAMMPS trajectory dump file.

    Returns:
        A Trajectory object in mythos format.

    Raises:
        ValueError: If a frame lacks one of ``LAMMPS_REQUIRED_FIELDS``, or if the number of
            atoms changes between kept frames.
    """
    ts: list[float] = []
    bs: list[np.ndarray] = []
    states: list[np.ndarray] = []
    num_atoms: int | None = None
    keep_frame = False  # nothing is recorded before the first kept TIMESTEP

    with output_file.open() as f:
        for line in f:
            if line.startswith("ITEM: TIMESTEP"):
                t = float(next(f))
                # Decide per frame, so a timestep-0 frame appearing later in the file (e.g. after
                # reset_timestep) is skipped entirely rather than leaving its box and atoms
                # without a matching time entry.
                keep_frame = t != 0
                if keep_frame:
                    ts.append(t)
                continue
            if not keep_frame:
                continue

            if line.startswith("ITEM: NUMBER OF ATOMS"):
                frame_atoms = int(next(f))
                if num_atoms is None:
                    num_atoms = frame_atoms
                elif frame_atoms != num_atoms:
                    raise ValueError(
                        f"LAMMPS output file has a varying number of atoms ({num_atoms} vs {frame_atoms})."
                    )
            elif line.startswith("ITEM: BOX BOUNDS"):
                bounds = " ".join([next(f).replace("\n", " ") for _ in range(3)])
                bx1, bx2, by1, by2, bz1, bz2 = np.fromstring(bounds, dtype=np.float64, sep=" ")
                bs.append(np.array([bx2 - bx1, by2 - by1, bz2 - bz1]))
            elif line.startswith("ITEM: ATOMS"):
                state_fields = line.removeprefix("ITEM: ATOMS").split()
                missing = LAMMPS_REQUIRED_FIELDS - set(state_fields)
                if missing:
                    raise ValueError(
                        "LAMMPS output file missing required fields: " + ", ".join(sorted(missing)) + "."
                    )
                rows = [np.fromstring(next(f), dtype=np.float64, sep=" ") for _ in range(num_atoms)]
                if "id" in state_fields:
                    # Dump order can change between frames; order atoms by their LAMMPS id.
                    id_col = state_fields.index("id")
                    rows.sort(key=lambda row: row[id_col])
                states.append(np.array([_transform_lammps_state(row, state_fields) for row in rows]))

    validate_box_size(bs)
    return Trajectory(
        n_nucleotides=num_atoms,
        strand_lengths=[num_atoms],  # this is not actually correct
        times=np.array(ts, dtype=np.float64),
        energies=np.zeros((len(ts), 3), dtype=np.float64),  # energies are not parsed here
        states=[NucleotideState(array=s) for s in states],
    )