diff --git a/corelib/src/libs/SireIO/amberprm.cpp b/corelib/src/libs/SireIO/amberprm.cpp index 031f85e00..67a4ff563 100644 --- a/corelib/src/libs/SireIO/amberprm.cpp +++ b/corelib/src/libs/SireIO/amberprm.cpp @@ -5396,6 +5396,63 @@ System AmberPrm::startSystem(const PropertyMap &map) const bar.success(); } + if (not in_expected_order) + { + // Collect all residues in loaded (mol-index) order and check for + // residue-number inversions. These arise when covalent bonds cross + // AMBER's molecule boundaries (e.g. a metal ion bonded to a small + // molecule), forcing Sire to group the bonded atoms into the same + // molecule and thereby placing some residues out of their original + // residue-number sequence. + struct ResEntry + { + int resnum; + QString resname; + }; + + QVector all_res; + + for (const auto &mol : mols) + { + const auto &molinfo = mol.info(); + const int nres = molinfo.nResidues(); + + for (int j = 0; j < nres; ++j) + { + all_res.append({molinfo.number(ResIdx(j)).value(), + molinfo.name(ResIdx(j)).value()}); + } + } + + QStringList out_of_order; + int prev_resnum = 0; + + for (const auto &res : all_res) + { + if (res.resnum < prev_resnum) + { + out_of_order.append( + QString("%1(%2)").arg(res.resname).arg(res.resnum)); + } + + prev_resnum = res.resnum; + } + + if (not out_of_order.isEmpty()) + { + qWarning().noquote() << QObject::tr( + "WARNING: One or more residues have been reordered relative to the " + "original topology when loading this file. This happens when covalent " + "bonds cross AMBER's molecule boundaries (e.g. a metal ion bonded to a " + "small molecule ligand), which forces Sire to group the bonded atoms " + "into the same molecule. The following residues appear out of " + "residue-number order in the loaded system: %1. " + "To avoid ordering issues, access residues by name rather than by " + "position (e.g. mols.residues()[\"resname ML1\"] instead of mols.residues()[1]).") + .arg(out_of_order.join(", ")); + } + } + MoleculeGroup molgroup("all"); for (auto mol : mols) diff --git a/corelib/src/libs/SireMM/bondrestraints.cpp b/corelib/src/libs/SireMM/bondrestraints.cpp index ff63cd678..b87aba1f1 100644 --- a/corelib/src/libs/SireMM/bondrestraints.cpp +++ b/corelib/src/libs/SireMM/bondrestraints.cpp @@ -188,6 +188,7 @@ BondRestraint &BondRestraint::operator=(const BondRestraint &other) { if (this != &other) { + Property::operator=(other); atms0 = other.atms0; atms1 = other.atms1; _k = other._k; diff --git a/corelib/src/libs/SireMM/inversebondrestraints.cpp b/corelib/src/libs/SireMM/inversebondrestraints.cpp index f88acc3fe..c110c0a70 100644 --- a/corelib/src/libs/SireMM/inversebondrestraints.cpp +++ b/corelib/src/libs/SireMM/inversebondrestraints.cpp @@ -188,6 +188,7 @@ InverseBondRestraint &InverseBondRestraint::operator=(const InverseBondRestraint { if (this != &other) { + Property::operator=(other); atms0 = other.atms0; atms1 = other.atms1; _k = other._k; diff --git a/corelib/src/libs/SireMM/morsepotentialrestraints.cpp b/corelib/src/libs/SireMM/morsepotentialrestraints.cpp index 1322e08b1..666f7426c 100644 --- a/corelib/src/libs/SireMM/morsepotentialrestraints.cpp +++ b/corelib/src/libs/SireMM/morsepotentialrestraints.cpp @@ -190,6 +190,7 @@ MorsePotentialRestraint &MorsePotentialRestraint::operator=(const MorsePotential { if (this != &other) { + Property::operator=(other); atms0 = other.atms0; atms1 = other.atms1; _k = other._k; diff --git a/corelib/src/libs/SireMM/positionalrestraints.cpp b/corelib/src/libs/SireMM/positionalrestraints.cpp index 101ce0f57..cd3ac6902 100644 --- a/corelib/src/libs/SireMM/positionalrestraints.cpp +++ b/corelib/src/libs/SireMM/positionalrestraints.cpp @@ -146,6 +146,7 @@ PositionalRestraint &PositionalRestraint::operator=(const PositionalRestraint &o { if (this != &other) { + Property::operator=(other); atms = other.atms; pos = other.pos; _k = other._k; diff --git a/doc/source/changelog.rst b/doc/source/changelog.rst index acf30dd44..da07068a3 100644 --- a/doc/source/changelog.rst +++ b/doc/source/changelog.rst @@ -36,6 +36,15 @@ organisation on `GitHub `__. * Add support for 4- and 5-point water models in the OpenMM conversion layer. +* Add per-force-group energy caching to the OpenMM integration layer. Named + forces (``clj``, ``bond``, ``angle``, ``torsion``, ``cmap``, ghost forces, + restraints, etc.) are each assigned a unique OpenMM force group when the + system is built. :meth:`~sire.mol.Dynamics.get_potential_energy` now + re-evaluates only the groups whose parameters changed since the last lambda + update, using cached values for all others. Call + :meth:`~sire.mol.Dynamics.clear_energy_cache` to force a full re-evaluation + (e.g. after a replica-exchange position swap). + * Add functionality for coupling one lambda lever to another. * Added support for Direct Morse Replacement (DMR) feature in ``sire.restraints.morse_potential`` @@ -45,6 +54,8 @@ organisation on `GitHub `__. * Store OpenMM state at start of a dynamics run to use for crash recovery. +* Print warning when ``sire.legacy.IO.AmberPrm`` parser re-orders residues due to covalent bonds between molecules. + `2025.4.0 `__ - February 2026 --------------------------------------------------------------------------------------------- diff --git a/src/sire/mol/_dynamics.py b/src/sire/mol/_dynamics.py index 7b3ac4dbf..6878ce91f 100644 --- a/src/sire/mol/_dynamics.py +++ b/src/sire/mol/_dynamics.py @@ -483,6 +483,10 @@ def _exit_dynamics_block( nrg_sim_lambda_value = nrg if lambda_windows is not None: + # Positions have just changed (dynamics completed), so + # invalidate all cached per-group energies before the scan. + self._omm_mols.clear_energy_cache() + # get the index of the simulation lambda value in the # lambda windows list try: @@ -542,6 +546,10 @@ def _exit_dynamics_block( self._nrgs = nrgs self._nrgs_array = nrgs_array + # Repex synchronisation point: a peer replica may push new + # positions into this context, so the cache must be invalidated. + self._omm_mols.clear_energy_cache() + # update the interpolation lambda value if self._is_interpolate: if delta_lambda: @@ -853,14 +861,11 @@ def current_potential_energy(self): if self.is_null(): return 0 else: - from openmm.unit import kilocalorie_per_mole as _omm_kcal_mol - from ..units import kcal_per_mol as _sire_kcal_mol + return self._omm_mols.get_potential_energy(to_sire_units=True) - state = self._get_current_state() - - nrg = state.getPotentialEnergy() - - return nrg.value_in_unit(_omm_kcal_mol) * _sire_kcal_mol + def clear_energy_cache(self): + if not self.is_null(): + self._omm_mols.clear_energy_cache() def current_kinetic_energy(self): if self.is_null(): @@ -989,6 +994,9 @@ def run_minimisation( timeout=timeout.to(second), ) + # Positions changed during minimisation; invalidate the energy cache. + self._omm_mols.clear_energy_cache() + def _rebuild_and_minimise(self): if self.is_null(): return @@ -1030,38 +1038,28 @@ def _rebuild_and_minimise(self): if self._save_crash_report: import openmm import numpy as np - from copy import deepcopy from uuid import uuid4 # Create a unique identifier for this crash report. crash_id = str(uuid4())[:8] - # Get the current context and system. context = self._omm_mols - system = deepcopy(context.getSystem()) - - # Add each force to a unique group. - for i, f in enumerate(system.getForces()): - f.setForceGroup(i) - - # Create a new context. - new_context = openmm.Context(system, deepcopy(context.getIntegrator())) - new_context.setPositions(context.getState(getPositions=True).getPositions()) - # Write the energies for each force group. + # Write per-force-group energies using the groups already assigned + # by sire_to_openmm_system. with open(f"crash_{crash_id}.log", "w") as f: f.write(f"Current lambda: {str(self.get_lambda())}\n") - for i, force in enumerate(system.getForces()): - state = new_context.getState(getEnergy=True, groups={i}) - f.write(f"{force.getName()}, {state.getPotentialEnergy()}\n") + for name, grp in context._force_group_map.items(): + state = context.getState(getEnergy=True, groups=(1 << grp)) + f.write(f"{name}, {state.getPotentialEnergy()}\n") # Save the serialised system. with open(f"system_{crash_id}.xml", "w") as f: - f.write(openmm.XmlSerializer.serialize(system)) + f.write(openmm.XmlSerializer.serialize(context.getSystem())) # Save the positions. positions = ( - new_context.getState(getPositions=True).getPositions(asNumpy=True) + context.getState(getPositions=True).getPositions(asNumpy=True) / openmm.unit.nanometer ) np.savetxt(f"positions_{crash_id}.txt", positions) @@ -2252,6 +2250,14 @@ def _current_energy_array(self): """ return self._d._current_energy_array() + def clear_energy_cache(self): + """ + Invalidate the per-force-group energy cache. Call this whenever + positions have been changed externally (e.g. after a replica-exchange + swap) so that the next energy evaluation fully re-computes all groups. + """ + self._d.clear_energy_cache() + def to_xml(self, f=None): """ Save the current state of the dynamics to XML. diff --git a/tests/convert/test_openmm_force_groups.py b/tests/convert/test_openmm_force_groups.py new file mode 100644 index 000000000..c67abe80e --- /dev/null +++ b/tests/convert/test_openmm_force_groups.py @@ -0,0 +1,572 @@ +""" +Tests for per-force-group energy caching in SOMMContext / LambdaLever. + +Checks: + - Named force groups are assigned and retrievable via get_force_group/get_force_names. + - The cached per-group sum equals the full potential energy at several lambda values. + - clear_energy_cache() forces a full re-evaluation on the next get_potential_energy() call. + - Repeated get_potential_energy() calls without a lambda change return the same value + (i.e. unchanged groups are served from cache, not re-computed). + - A REST2 scale change marks the appropriate groups as dirty. +""" + +import pytest +import sire as sr + + +@pytest.fixture(scope="module") +def perturbable_omm(merged_ethane_methanol, openmm_platform): + """Return a SOMMContext built from the merged ethane/methanol molecule.""" + mols = merged_ethane_methanol.clone() + + # Pin coordinates to lambda-0 end state so energies are well-behaved. + c = mols.cursor() + c["molidx 0"]["coordinates"] = c["molidx 0"]["coordinates0"] + c["molidx 0"]["coordinates1"] = c["molidx 0"]["coordinates0"] + mols = c.commit() + + l = sr.cas.LambdaSchedule() + l.add_stage("morph", (1 - l.lam()) * l.initial() + l.lam() * l.final()) + + map = { + "platform": openmm_platform, + "schedule": l, + "constraint": "h-bonds-not-perturbed", + "include_constrained_energies": True, + "dynamic_constraints": False, + } + + return sr.convert.to(mols[0], "openmm", map=map) + + +@pytest.mark.skipif( + "openmm" not in sr.convert.supported_formats(), + reason="openmm support is not available", +) +def test_force_groups_assigned(perturbable_omm): + """Every expected named force has a non-negative force group index.""" + omm = perturbable_omm + lever = omm.get_lambda_lever() + + force_names = lever.get_force_names() + assert len(force_names) > 0, "No force names registered on LambdaLever" + + for name in force_names: + grp = lever.get_force_group(name) + assert grp >= 0, f"Force '{name}' has invalid group index {grp}" + + # The force_group_map built in Python must be consistent. + assert len(omm._force_group_map) > 0 + for name, grp in omm._force_group_map.items(): + assert grp >= 0 + + +@pytest.mark.skipif( + "openmm" not in sr.convert.supported_formats(), + reason="openmm support is not available", +) +def test_per_group_sum_equals_full_energy(perturbable_omm): + """ + Sum of per-group energies must equal the full potential energy + (within numerical noise) at lambda=0, 0.5, and 1. + """ + import openmm + + omm = perturbable_omm + + for lam in (0.0, 0.5, 1.0): + omm.set_lambda(lam) + omm.clear_energy_cache() + + # Full evaluation (groups bitmask = all 32 groups). + full_state = omm.getState(getEnergy=True) + full_kj = full_state.getPotentialEnergy().value_in_unit( + openmm.unit.kilojoule_per_mole + ) + + # Per-group sum. + group_sum_kj = 0.0 + for grp in omm._force_group_map.values(): + s = omm.getState(getEnergy=True, groups=(1 << grp)) + group_sum_kj += s.getPotentialEnergy().value_in_unit( + openmm.unit.kilojoule_per_mole + ) + + assert group_sum_kj == pytest.approx(full_kj, abs=1e-3), ( + f"Group sum {group_sum_kj:.6f} kJ/mol != full energy {full_kj:.6f} kJ/mol " + f"at lambda={lam}" + ) + + +@pytest.mark.skipif( + "openmm" not in sr.convert.supported_formats(), + reason="openmm support is not available", +) +def test_cached_energy_matches_full(perturbable_omm): + """ + get_potential_energy() via the cache must match a direct full getState() + at lambda=0, 0.5, and 1. + """ + import openmm + + omm = perturbable_omm + + for lam in (0.0, 0.5, 1.0): + omm.set_lambda(lam) + omm.clear_energy_cache() + + cached_kj = omm.get_potential_energy(to_sire_units=False).value_in_unit( + openmm.unit.kilojoule_per_mole + ) + + full_state = omm.getState(getEnergy=True) + full_kj = full_state.getPotentialEnergy().value_in_unit( + openmm.unit.kilojoule_per_mole + ) + + assert cached_kj == pytest.approx(full_kj, abs=1e-3), ( + f"Cached energy {cached_kj:.6f} kJ/mol != full energy {full_kj:.6f} kJ/mol " + f"at lambda={lam}" + ) + + +@pytest.mark.skipif( + "openmm" not in sr.convert.supported_formats(), + reason="openmm support is not available", +) +def test_cache_stable_without_lambda_change(perturbable_omm): + """ + Calling get_potential_energy() twice without a lambda change returns + the same value (second call served entirely from cache). + """ + omm = perturbable_omm + omm.set_lambda(0.5) + omm.clear_energy_cache() + + nrg1 = omm.get_potential_energy(to_sire_units=True).value() + nrg2 = omm.get_potential_energy(to_sire_units=True).value() + + assert nrg1 == pytest.approx( + nrg2, rel=1e-10 + ), "Energy changed between two consecutive calls with no lambda change" + + +@pytest.mark.skipif( + "openmm" not in sr.convert.supported_formats(), + reason="openmm support is not available", +) +def test_clear_cache_marks_all_dirty(perturbable_omm): + """ + After clear_energy_cache(), _dirty_groups contains every group in + _force_group_map, and get_potential_energy() returns the correct value. + """ + import openmm + + omm = perturbable_omm + omm.set_lambda(0.0) + + # Populate cache. + _ = omm.get_potential_energy(to_sire_units=False) + + # Clear. + omm.clear_energy_cache() + + assert omm._dirty_groups == set( + omm._force_group_map.values() + ), "After clear_energy_cache(), not all groups are marked dirty" + assert ( + len(omm._energy_cache) == 0 + ), "After clear_energy_cache(), energy_cache should be empty" + + # Energy should still be correct after re-evaluation. + full_state = omm.getState(getEnergy=True) + full_kj = full_state.getPotentialEnergy().value_in_unit( + openmm.unit.kilojoule_per_mole + ) + cached_kj = omm.get_potential_energy(to_sire_units=False).value_in_unit( + openmm.unit.kilojoule_per_mole + ) + assert cached_kj == pytest.approx(full_kj, abs=1e-3) + + +@pytest.mark.skipif( + "openmm" not in sr.convert.supported_formats(), + reason="openmm support is not available", +) +def test_set_positions_invalidates_cache(perturbable_omm): + """ + Calling setPositions() must invalidate the energy cache so that the next + get_potential_energy() call re-evaluates all force groups. + """ + omm = perturbable_omm + omm.set_lambda(0.0) + + # Populate the cache. + _ = omm.get_potential_energy(to_sire_units=False) + assert len(omm._dirty_groups) == 0, "Cache should be clean after evaluation" + + # Retrieve current positions and set them back — content unchanged but + # the override must still invalidate the cache. + positions = omm.getState(getPositions=True).getPositions() + omm.setPositions(positions) + + assert omm._dirty_groups == set( + omm._force_group_map.values() + ), "All groups should be dirty after setPositions()" + + +@pytest.mark.skipif( + "openmm" not in sr.convert.supported_formats(), + reason="openmm support is not available", +) +def test_set_state_invalidates_cache(perturbable_omm): + """ + Calling setState() must invalidate the energy cache so that the next + get_potential_energy() call re-evaluates all force groups. + """ + omm = perturbable_omm + omm.set_lambda(0.0) + + # Populate the cache. + _ = omm.get_potential_energy(to_sire_units=False) + assert len(omm._dirty_groups) == 0, "Cache should be clean after evaluation" + + # Round-trip through setState using the current state. + state = omm.getState(getPositions=True) + omm.setState(state) + + assert omm._dirty_groups == set( + omm._force_group_map.values() + ), "All groups should be dirty after setState()" + + +@pytest.mark.skipif( + "openmm" not in sr.convert.supported_formats(), + reason="openmm support is not available", +) +def test_set_periodic_box_vectors_invalidates_cache(perturbable_omm): + """ + Calling setPeriodicBoxVectors() must invalidate the energy cache since + a box change affects the PME energy. + """ + omm = perturbable_omm + omm.set_lambda(0.0) + + # Populate the cache. + _ = omm.get_potential_energy(to_sire_units=False) + assert len(omm._dirty_groups) == 0, "Cache should be clean after evaluation" + + # Set the same box vectors back — content unchanged but the override must + # still invalidate the cache. + box = omm.getState(getPositions=True).getPeriodicBoxVectors() + omm.setPeriodicBoxVectors(*box) + + assert omm._dirty_groups == set( + omm._force_group_map.values() + ), "All groups should be dirty after setPeriodicBoxVectors()" + + +# --------------------------------------------------------------------------- +# Levers that belong to each named OpenMM force. When ALL levers for a force +# are pinned to l.initial(), that force's parameters cannot change between +# lambda steps, so it must NOT be marked dirty. +# --------------------------------------------------------------------------- +_FORCE_LEVERS = { + "bond": ["bond_k", "bond_length"], + "angle": ["angle_k", "angle_size"], + # Note: fixing torsion_k also implicitly fixes cmap_grid via the default + # coupling, but merged_ethane_methanol has no CMAP so this has no effect. + "torsion": ["torsion_k", "torsion_phase"], + # Fixing these without a force argument pins them for clj, ghost/ghost, + # ghost/non-ghost and ghost-14 simultaneously. + "clj": ["charge", "sigma", "epsilon", "alpha", "kappa", "charge_scale", "lj_scale"], + # cmap_grid is the only lever for the CMAPTorsionForce. By default it is + # coupled to torsion_k, so without an explicit equation it would morph + # whenever torsion_k does. _make_fixed_schedule sets an explicit equation + # (l.initial()) which breaks that coupling and pins CMAP independently. + # Tested with a molecule that actually has perturbable CMAP terms + # (merged_molecule_cmap.s3). + "cmap": ["cmap_grid"], +} + +# Forces whose dirty-state is tied together with "clj" (they share levers). +_CLJ_RELATED = {"clj", "ghost/ghost", "ghost/non-ghost", "ghost-14"} + + +def _make_fixed_schedule(fixed_levers): + """ + Return a single-stage LambdaSchedule that morphs all parameters + linearly, except for the levers listed in *fixed_levers* which are + pinned to their lambda=0 (initial) values. + """ + l = sr.cas.LambdaSchedule() + l.add_stage("morph", (1 - l.lam()) * l.initial() + l.lam() * l.final()) + for lever in fixed_levers: + l.set_equation(stage="morph", lever=lever, equation=l.initial()) + return l + + +@pytest.mark.skipif( + "openmm" not in sr.convert.supported_formats(), + reason="openmm support is not available", +) +@pytest.mark.parametrize("fixed_force", list(_FORCE_LEVERS.keys())) +def test_fixed_lever_not_dirty(merged_ethane_methanol, openmm_platform, fixed_force): + """ + When all levers controlling *fixed_force* are pinned to their initial + values, that force must not be marked dirty after a lambda step. + All other forces (whose levers still morph) must be dirty. + The cached energy must still match the full OpenMM energy. + """ + import openmm as mm + + fixed_levers = _FORCE_LEVERS[fixed_force] + schedule = _make_fixed_schedule(fixed_levers) + + if fixed_force == "cmap": + # merged_molecule_cmap.s3 is a perturbable ubiquitin chain (T9A mutation) + # that carries genuine CMAP backbone correction terms at both end states. + mols = sr.load_test_files("merged_molecule_cmap.s3") + mols = sr.morph.link_to_reference(mols) + omm = sr.convert.to( + mols, + "openmm", + map={ + "platform": openmm_platform or "CPU", + "schedule": schedule, + "constraint": "none", + "cutoff": "none", + "cutoff_type": "none", + }, + ) + else: + mols = merged_ethane_methanol.clone() + + # Pin coordinates to lambda-0 so energies are well-behaved. + c = mols.cursor() + c["molidx 0"]["coordinates"] = c["molidx 0"]["coordinates0"] + c["molidx 0"]["coordinates1"] = c["molidx 0"]["coordinates0"] + mols = c.commit() + + omm = sr.convert.to( + mols[0], + "openmm", + map={ + "platform": openmm_platform, + "schedule": schedule, + "constraint": "h-bonds-not-perturbed", + "include_constrained_energies": True, + "dynamic_constraints": False, + }, + ) + + lever = omm.get_lambda_lever() + + # Step 1: prime the cache at lambda=0 (first call — all forces dirty + # because there is no previous cached state to compare against). + omm.set_lambda(0.0) + _ = omm.get_potential_energy(to_sire_units=False) # clears dirty_groups + + # Step 2: advance lambda — now hasChanged() compares against the + # lambda=0 values stored in prev_cache. + omm.set_lambda(0.5) + + # The pinned force must NOT be dirty. + if fixed_force == "clj": + # All CLJ-related forces share the same levers. + for name in _CLJ_RELATED: + if name in omm._force_group_map: + assert not lever.was_force_changed(name), ( + f"'{name}' should not be changed when all its levers are " + f"pinned to initial (fixed_force='{fixed_force}')" + ) + assert ( + omm._force_group_map[name] not in omm._dirty_groups + ), f"Force group for '{name}' should not be dirty" + else: + assert not lever.was_force_changed(fixed_force), ( + f"'{fixed_force}' should not be changed when all its levers are " + f"pinned to initial" + ) + if fixed_force in omm._force_group_map: + assert ( + omm._force_group_map[fixed_force] not in omm._dirty_groups + ), f"Force group for '{fixed_force}' should not be dirty" + + # All OTHER morphing forces must be dirty. + # Exclude "cmap": molecules without CMAP terms have no CMAP parameters to + # change, so was_force_changed("cmap") is correctly False regardless of + # pinning. The reverse direction (CMAP not dirty when pinned) is covered + # by the fixed_force="cmap" parametrize case which uses a CMAP molecule. + other_forces = set(_FORCE_LEVERS.keys()) - {fixed_force, "cmap"} + for other in other_forces: + if other == "clj": + # Check at least the primary clj force. + if "clj" in omm._force_group_map: + assert lever.was_force_changed( + "clj" + ), f"'clj' should be changed (fixed_force='{fixed_force}')" + else: + if other in omm._force_group_map: + assert lever.was_force_changed(other), ( + f"'{other}' should be changed when it is not pinned " + f"(fixed_force='{fixed_force}')" + ) + + # Energy correctness: cached sum must match full OpenMM evaluation. + full_kj = ( + omm.getState(getEnergy=True) + .getPotentialEnergy() + .value_in_unit(mm.unit.kilojoule_per_mole) + ) + cached_kj = omm.get_potential_energy(to_sire_units=False).value_in_unit( + mm.unit.kilojoule_per_mole + ) + assert cached_kj == pytest.approx(full_kj, abs=1e-3), ( + f"Cached energy {cached_kj:.6f} kJ/mol != full energy {full_kj:.6f} kJ/mol " + f"(fixed_force='{fixed_force}')" + ) + + +@pytest.mark.skipif( + "openmm" not in sr.convert.supported_formats(), + reason="openmm support is not available", +) +def test_rest2_scale_change_dirties_correct_groups( + merged_ethane_methanol, openmm_platform +): + """ + REST2 scale changes must dirty CLJ and torsion groups (and their related + ghost forces) but must NOT dirty bond or angle groups. + + Uses a schedule where every morphed lever is pinned to its initial value, + so morphed parameter vectors never change between lambda steps. This + isolates the REST2 scale as the sole source of cache invalidation. + + Three scenarios are tested: + 1. Lambda changes with REST2 scale held at 1.0 → no groups dirtied + (cache entirely reused because neither morphed values nor scale changed). + 2. Lambda held constant, REST2 scale changes from 1.0 → 2.0 → CLJ and + torsion groups (including ghost variants) are dirtied; bond and angle + groups are NOT. + 3. The cached energy after the REST2 scale change still matches the full + OpenMM potential energy. + """ + import openmm + + # Pin every lever to its initial value so morphed vectors never change. + all_levers = ( + list(_FORCE_LEVERS["bond"]) + + list(_FORCE_LEVERS["angle"]) + + list(_FORCE_LEVERS["torsion"]) + + list(_FORCE_LEVERS["clj"]) + ) + schedule = _make_fixed_schedule(all_levers) + + mols = merged_ethane_methanol.clone() + c = mols.cursor() + c["molidx 0"]["coordinates"] = c["molidx 0"]["coordinates0"] + c["molidx 0"]["coordinates1"] = c["molidx 0"]["coordinates0"] + mols = c.commit() + + omm = sr.convert.to( + mols[0], + "openmm", + map={ + "platform": openmm_platform, + "schedule": schedule, + "constraint": "h-bonds-not-perturbed", + "include_constrained_energies": True, + "dynamic_constraints": False, + }, + ) + + # ----------------------------------------------------------------------- + # Scenario 1: prime the cache at lambda=0, REST2 scale=1.0. + # ----------------------------------------------------------------------- + omm.set_lambda(0.0, rest2_scale=1.0) + _ = omm.get_potential_energy(to_sire_units=False) # clears _dirty_groups + assert len(omm._dirty_groups) == 0, "Cache should be fully clean after priming" + + # Advance lambda — morphed values are all pinned so only REST2 scale + # changes could dirty anything; scale is still 1.0, so nothing is dirty. + omm.set_lambda(0.5, rest2_scale=1.0) + assert len(omm._dirty_groups) == 0, ( + "No groups should be dirty after a lambda change when all levers are " + "pinned and REST2 scale is unchanged" + ) + + # Consume the (still-clean) cache so subsequent checks start fresh. + _ = omm.get_potential_energy(to_sire_units=False) + + # ----------------------------------------------------------------------- + # Scenario 2: lambda stays at 0.5, REST2 scale changes 1.0 → 2.0. + # ----------------------------------------------------------------------- + omm.set_lambda(0.5, rest2_scale=2.0) + + rest2_affected = {"clj", "torsion", "ghost/ghost", "ghost/non-ghost"} + rest2_unaffected = {"bond", "angle"} + + for name in rest2_affected: + if name in omm._force_group_map: + assert ( + omm._force_group_map[name] in omm._dirty_groups + ), f"Force group '{name}' should be dirty after a REST2 scale change" + + for name in rest2_unaffected: + if name in omm._force_group_map: + assert ( + omm._force_group_map[name] not in omm._dirty_groups + ), f"Force group '{name}' should NOT be dirty after a REST2 scale change" + + # ----------------------------------------------------------------------- + # Scenario 3: cached energy still matches the full OpenMM evaluation. + # ----------------------------------------------------------------------- + full_kj = ( + omm.getState(getEnergy=True) + .getPotentialEnergy() + .value_in_unit(openmm.unit.kilojoule_per_mole) + ) + cached_kj = omm.get_potential_energy(to_sire_units=False).value_in_unit( + openmm.unit.kilojoule_per_mole + ) + assert cached_kj == pytest.approx(full_kj, abs=1e-3), ( + f"Cached energy {cached_kj:.6f} kJ/mol != full energy {full_kj:.6f} kJ/mol " + "after REST2 scale change" + ) + + +@pytest.mark.skipif( + "openmm" not in sr.convert.supported_formats(), + reason="openmm support is not available", +) +def test_lambda_change_dirties_correct_groups(perturbable_omm): + """ + After set_lambda(), only the groups whose parameters actually changed + are marked dirty. Groups that are unchanged are not in _dirty_groups. + """ + omm = perturbable_omm + omm.set_lambda(0.5) + omm.clear_energy_cache() + + # Populate cache at lambda=0.5. + _ = omm.get_potential_energy(to_sire_units=False) + assert len(omm._dirty_groups) == 0, "Cache should be fully clean after evaluation" + + # Move to a new lambda — some groups must become dirty. + omm.set_lambda(0.6) + assert ( + len(omm._dirty_groups) > 0 + ), "At least one group should be dirty after a lambda change" + + # The cached energy must still be correct. + import openmm + + full_state = omm.getState(getEnergy=True) + full_kj = full_state.getPotentialEnergy().value_in_unit( + openmm.unit.kilojoule_per_mole + ) + cached_kj = omm.get_potential_energy(to_sire_units=False).value_in_unit( + openmm.unit.kilojoule_per_mole + ) + assert cached_kj == pytest.approx(full_kj, abs=1e-3) diff --git a/wrapper/Convert/SireOpenMM/CMakeLists.txt b/wrapper/Convert/SireOpenMM/CMakeLists.txt index b0df4b54d..368cb0e7a 100644 --- a/wrapper/Convert/SireOpenMM/CMakeLists.txt +++ b/wrapper/Convert/SireOpenMM/CMakeLists.txt @@ -42,25 +42,6 @@ if (${SIRE_USE_OPENMM}) endif() endif() - # Check to see if we have support for updating some parameters in context - include(CheckCXXSourceCompiles) - check_cxx_source_compiles( "#include - int main() { - OpenMM::CustomNonbondedForce *force; - OpenMM::Context *context; - force->updateSomeParametersInContext(0, 0, *context); - return 0; - }" - SIREOPENMM_HAS_UPDATESOMEPARAMETERSINCONTEXT ) - - if ( ${SIREOPENMM_HAS_UPDATESOMEPARAMETERSINCONTEXT} ) - message( STATUS "OpenMM has support for updating some parameters in context") - add_definitions("-DSIRE_HAS_UPDATE_SOME_IN_CONTEXT") - else() - message( STATUS "OpenMM does not have support for updating some parameters in context") - message( STATUS "The free energy code will be a little slower.") - endif() - # Get the list of autogenerated files include(CMakeAutogenFile.txt) diff --git a/wrapper/Convert/SireOpenMM/LambdaLever.pypp.cpp b/wrapper/Convert/SireOpenMM/LambdaLever.pypp.cpp index 891398949..2e4d52185 100644 --- a/wrapper/Convert/SireOpenMM/LambdaLever.pypp.cpp +++ b/wrapper/Convert/SireOpenMM/LambdaLever.pypp.cpp @@ -217,17 +217,68 @@ void register_LambdaLever_class(){ } { //::SireOpenMM::LambdaLever::setForceIndex - + typedef void ( ::SireOpenMM::LambdaLever::*setForceIndex_function_type)( ::QString const &,int ) ; setForceIndex_function_type setForceIndex_function_value( &::SireOpenMM::LambdaLever::setForceIndex ); - - LambdaLever_exposer.def( + + LambdaLever_exposer.def( "setForceIndex" , setForceIndex_function_value , ( bp::arg("force"), bp::arg("index") ) , bp::release_gil_policy() , "Set the index of the force called force in the OpenMM System.\n There can only be one force with this name. Attempts to add\n a duplicate will cause an error to be raised.\n" ); - + + } + { //::SireOpenMM::LambdaLever::setForceGroup + + typedef void ( ::SireOpenMM::LambdaLever::*setForceGroup_function_type)( ::QString const &,int ) ; + setForceGroup_function_type setForceGroup_function_value( &::SireOpenMM::LambdaLever::setForceGroup ); + + LambdaLever_exposer.def( + "setForceGroup" + , setForceGroup_function_value + , ( bp::arg("name"), bp::arg("group_idx") ) + , bp::release_gil_policy() + , "Set the force group index for the named force." ); + + } + { //::SireOpenMM::LambdaLever::getForceGroup + + typedef int ( ::SireOpenMM::LambdaLever::*getForceGroup_function_type)( ::QString const & ) const; + getForceGroup_function_type getForceGroup_function_value( &::SireOpenMM::LambdaLever::getForceGroup ); + + LambdaLever_exposer.def( + "getForceGroup" + , getForceGroup_function_value + , ( bp::arg("name") ) + , bp::release_gil_policy() + , "Get the force group index for the named force. Returns -1 if not found." ); + + } + { //::SireOpenMM::LambdaLever::getForceNames + + typedef ::QStringList ( ::SireOpenMM::LambdaLever::*getForceNames_function_type)( ) const; + getForceNames_function_type getForceNames_function_value( &::SireOpenMM::LambdaLever::getForceNames ); + + LambdaLever_exposer.def( + "getForceNames" + , getForceNames_function_value + , bp::release_gil_policy() + , "Return the names of all forces and restraints that have been assigned a force group index." ); + + } + { //::SireOpenMM::LambdaLever::wasForceChanged + + typedef bool ( ::SireOpenMM::LambdaLever::*wasForceChanged_function_type)( ::QString const & ) const; + wasForceChanged_function_type wasForceChanged_function_value( &::SireOpenMM::LambdaLever::wasForceChanged ); + + LambdaLever_exposer.def( + "wasForceChanged" + , wasForceChanged_function_value + , ( bp::arg("name") ) + , bp::release_gil_policy() + , "Return whether the named force had parameters changed in the last setLambda call." ); + } { //::SireOpenMM::LambdaLever::setLambda diff --git a/wrapper/Convert/SireOpenMM/_sommcontext.py b/wrapper/Convert/SireOpenMM/_sommcontext.py index 8d990cc46..244159be3 100644 --- a/wrapper/Convert/SireOpenMM/_sommcontext.py +++ b/wrapper/Convert/SireOpenMM/_sommcontext.py @@ -96,11 +96,28 @@ def __init__( ) self._map = map + + # Build the force group map from the lambda lever and initialise + # the per-group energy cache. + self._force_group_map = {} # force name → group index + for name in self._lambda_lever.get_force_names(): + grp = self._lambda_lever.get_force_group(name) + if grp >= 0: + self._force_group_map[name] = grp + + self._energy_cache = {} # group index → energy in kJ/mol + # All groups are dirty on first call. + self._dirty_groups = set(self._force_group_map.values()) + self._prev_rest2_scale = self._rest2_scale else: self._atom_index = None self._lambda_lever = None self._lambda_value = 0.0 self._map = None + self._force_group_map = {} + self._energy_cache = {} + self._dirty_groups = set() + self._prev_rest2_scale = 1.0 self._is_non_pert_rest2 = False @@ -278,10 +295,28 @@ def set_lambda( update_constraints=update_constraints, ) + # Mark force groups whose parameters changed. + for name, grp in self._force_group_map.items(): + if self._lambda_lever.was_force_changed(name): + self._dirty_groups.add(grp) + + # A REST2 scale change also affects CLJ and torsion even if the + # perturbable lambda parameters didn't change. + if rest2_scale != self._prev_rest2_scale: + for name in ("clj", "torsion", "ghost/ghost", "ghost/non-ghost"): + if name in self._force_group_map: + self._dirty_groups.add(self._force_group_map[name]) + self._prev_rest2_scale = rest2_scale + # Update any additional parameters in the REST2 region. if self._is_non_pert_rest2 and rest2_scale != self._rest2_scale: self._update_rest2(lambda_value, rest2_scale) self._rest2_scale = rest2_scale + # _update_rest2 modifies nonbonded and torsion forces directly; + # mark those groups as dirty. + for name in ("clj", "torsion"): + if name in self._force_group_map: + self._dirty_groups.add(self._force_group_map[name]) def get_rest2_scale(self): """ @@ -317,18 +352,86 @@ def set_surface_tension(self, surface_tension): def get_potential_energy(self, to_sire_units: bool = True): """ - Calculate and return the potential energy of the system + Calculate and return the potential energy of the system. + + Uses energy caching: if no force groups have been marked dirty since + the last call (i.e. neither lambda nor positions changed), the cached + total is returned without any GPU call. Otherwise a single full + getState() evaluation is performed and the result cached. + + Falls back to a full getState() evaluation when no force group map is + available (null context or no perturbable forces). """ - s = self.getState(getEnergy=True) - nrg = s.getPotentialEnergy() + import openmm + + if not self._force_group_map: + # No force group information available; fall back to full evaluation. + s = self.getState(getEnergy=True) + nrg = s.getPotentialEnergy() + if to_sire_units: + from ...units import kcal_per_mol + + return ( + nrg.value_in_unit(openmm.unit.kilocalorie_per_mole) * kcal_per_mol + ) + else: + return nrg + + if self._dirty_groups: + # One or more groups have changed so re-evaluate with a single + # full getState call rather than N per-group calls. Multiple + # small masked calls carry per-call GPU synchronisation overhead + # that outweighs any saving from skipping clean groups. + total_kj = ( + self.getState(getEnergy=True) + .getPotentialEnergy() + .value_in_unit(openmm.unit.kilojoule_per_mole) + ) + self._energy_cache = {"_total": total_kj} + self._dirty_groups.clear() + else: + total_kj = self._energy_cache["_total"] if to_sire_units: - import openmm from ...units import kcal_per_mol - return nrg.value_in_unit(openmm.unit.kilocalorie_per_mole) * kcal_per_mol + return (total_kj / 4.184) * kcal_per_mol else: - return nrg + return total_kj * openmm.unit.kilojoule_per_mole + + def setPositions(self, positions, *args, **kwargs): + """ + Set the positions of all particles. Overridden to automatically + invalidate the per-force-group energy cache. + """ + super().setPositions(positions, *args, **kwargs) + self.clear_energy_cache() + + def setState(self, state, *args, **kwargs): + """ + Set the complete state of the context (positions, velocities, box + vectors). Overridden to automatically invalidate the per-force-group + energy cache. + """ + super().setState(state, *args, **kwargs) + self.clear_energy_cache() + + def setPeriodicBoxVectors(self, a, b, c, *args, **kwargs): + """ + Set the periodic box vectors. Overridden to automatically invalidate + the per-force-group energy cache, since a box change affects PME energy. + """ + super().setPeriodicBoxVectors(a, b, c, *args, **kwargs) + self.clear_energy_cache() + + def clear_energy_cache(self): + """ + Invalidate the energy cache. Call this whenever positions change + (e.g. after dynamics steps) so that the next get_potential_energy() + call fully re-evaluates the system. + """ + self._energy_cache.clear() + self._dirty_groups = set(self._force_group_map.values()) def get_energy(self, to_sire_units: bool = True): """ diff --git a/wrapper/Convert/SireOpenMM/lambdalever.cpp b/wrapper/Convert/SireOpenMM/lambdalever.cpp index 1c701c85a..04ce9bfc0 100644 --- a/wrapper/Convert/SireOpenMM/lambdalever.cpp +++ b/wrapper/Convert/SireOpenMM/lambdalever.cpp @@ -53,8 +53,15 @@ MolLambdaCache::MolLambdaCache(double lam) { } +MolLambdaCache::MolLambdaCache(double lam, const MolLambdaCache &prev) + : lam_val(lam) +{ + QReadLocker lkr(&(const_cast(&prev)->lock)); + prev_cache = prev.cache; +} + MolLambdaCache::MolLambdaCache(const MolLambdaCache &other) - : lam_val(other.lam_val), cache(other.cache) + : lam_val(other.lam_val), cache(other.cache), prev_cache(other.prev_cache) { } @@ -68,11 +75,43 @@ MolLambdaCache &MolLambdaCache::operator=(const MolLambdaCache &other) { lam_val = other.lam_val; cache = other.cache; + prev_cache = other.prev_cache; } return *this; } +bool MolLambdaCache::hasChanged(const QString &force, const QString &key) const +{ + return this->hasChanged(force, key, QString()); +} + +bool MolLambdaCache::hasChanged(const QString &force, const QString &key, + const QString &subkey) const +{ + if (prev_cache.isEmpty()) + return true; + + QString cache_key = key; + if (not subkey.isEmpty()) + cache_key += ("::" + subkey); + + const QString force_key = force + "::" + cache_key; + + const auto prev_it = prev_cache.constFind(force_key); + if (prev_it == prev_cache.constEnd()) + return true; + + auto nonconst_this = const_cast(this); + QReadLocker lkr(&(nonconst_this->lock)); + + const auto curr_it = cache.constFind(force_key); + if (curr_it == cache.constEnd()) + return true; + + return curr_it.value() != prev_it.value(); +} + const QVector &MolLambdaCache::morph(const LambdaSchedule &schedule, const QString &force, const QString &key, @@ -155,7 +194,8 @@ LeverCache::LeverCache() { } -LeverCache::LeverCache(const LeverCache &other) : cache(other.cache) +LeverCache::LeverCache(const LeverCache &other) + : cache(other.cache), prev_lam_vals(other.prev_lam_vals) { } @@ -168,6 +208,7 @@ LeverCache &LeverCache::operator=(const LeverCache &other) if (this != &other) { cache = other.cache; + prev_lam_vals = other.prev_lam_vals; } return *this; @@ -188,10 +229,25 @@ const MolLambdaCache &LeverCache::get(int molidx, double lam_val) const if (it == mol_cache.constEnd()) { - // need to create a new cache for this lambda value - it = mol_cache.insert(lam_val, MolLambdaCache(lam_val)); + // Create a new cache for this lambda value, initialising prev_cache + // from the previous lambda's computed values so hasChanged() works. + auto prev_it = nonconst_this->prev_lam_vals.constFind(molidx); + if (prev_it != nonconst_this->prev_lam_vals.constEnd()) + { + auto old_it = mol_cache.constFind(prev_it.value()); + if (old_it != mol_cache.constEnd()) + it = mol_cache.insert(lam_val, MolLambdaCache(lam_val, old_it.value())); + else + it = mol_cache.insert(lam_val, MolLambdaCache(lam_val)); + } + else + { + it = mol_cache.insert(lam_val, MolLambdaCache(lam_val)); + } } + nonconst_this->prev_lam_vals[molidx] = lam_val; + return it.value(); } @@ -204,7 +260,9 @@ void LeverCache::clear() ////// Implementation of LambdaLever ////// -LambdaLever::LambdaLever() : SireBase::ConcreteProperty() +LambdaLever::LambdaLever() + : SireBase::ConcreteProperty(), + last_rest2_scale(-1.0) { } @@ -212,11 +270,13 @@ LambdaLever::LambdaLever(const LambdaLever &other) : SireBase::ConcreteProperty(other), name_to_ffidx(other.name_to_ffidx), name_to_restraintidx(other.name_to_restraintidx), + name_to_groupidx(other.name_to_groupidx), lambda_schedule(other.lambda_schedule), perturbable_mols(other.perturbable_mols), start_indices(other.start_indices), perturbable_maps(other.perturbable_maps), - lambda_cache(other.lambda_cache) + lambda_cache(other.lambda_cache), + last_rest2_scale(-1.0) { } @@ -230,6 +290,7 @@ LambdaLever &LambdaLever::operator=(const LambdaLever &other) { name_to_ffidx = other.name_to_ffidx; name_to_restraintidx = other.name_to_restraintidx; + name_to_groupidx = other.name_to_groupidx; lambda_schedule = other.lambda_schedule; perturbable_mols = other.perturbable_mols; start_indices = other.start_indices; @@ -324,6 +385,56 @@ QString LambdaLever::getForceType(const QString &name, return QString::fromStdString(force.getName()); } +/** Set the force group index for the force called 'name'. */ +void LambdaLever::setForceGroup(const QString &name, int group_idx) +{ + name_to_groupidx.insert(name, group_idx); +} + +/** Set the force group index for the restraint called 'name'. + * Unlike setForceGroup, this is a no-op if the name is already registered, + * since multiple restraint forces can share the same name and group. + */ +void LambdaLever::setRestraintForceGroup(const QString &name, int group_idx) +{ + if (!name_to_groupidx.contains(name)) + name_to_groupidx.insert(name, group_idx); +} + +/** Get the force group index for the force called 'name'. + * Returns -1 if there is no force with this name. + */ +int LambdaLever::getForceGroup(const QString &name) const +{ + auto it = name_to_groupidx.constFind(name); + + if (it == name_to_groupidx.constEnd()) + return -1; + + return it.value(); +} + +/** Return the names of all forces and restraints that have been assigned + * a force group index. + */ +QStringList LambdaLever::getForceNames() const +{ + return name_to_groupidx.keys(); +} + +/** Return whether the named force had parameters changed in the last + * setLambda call. Returns false if the name is not recognised. + */ +bool LambdaLever::wasForceChanged(const QString &name) const +{ + auto it = last_changed_forces.constFind(name); + + if (it == last_changed_forces.constEnd()) + return false; + + return it.value(); +} + boost::tuple get_exception(int atom0, int atom1, int start_index, double coul_14_scl, double lj_14_scl, @@ -1161,6 +1272,12 @@ double LambdaLever::setLambda(OpenMM::Context &context, // scale factor. rest2_scale = 1.0 / rest2_scale; + // Detect whether REST2 scaling changed since the last setLambda call. + // REST2 is applied on top of morphed parameters, so a change in scale + // requires re-uploading parameters even if morphed values are unchanged. + const bool rest2_changed = (rest2_scale != last_rest2_scale); + last_rest2_scale = rest2_scale; + // Store the REST charge scaling factor for non-bonded interactions. const auto sqrt_rest2_scale = std::sqrt(rest2_scale); @@ -1185,6 +1302,8 @@ double LambdaLever::setLambda(OpenMM::Context &context, // whether the constraints have changed bool have_constraints_changed = false; + // whether any CMAP map parameters were set (tracked to defer updateParametersInContext) + std::vector custom_params = {0.0, 0.0, 0.0, 0.0, 0.0}; if (qmff != 0) @@ -1193,18 +1312,14 @@ double LambdaLever::setLambda(OpenMM::Context &context, qmff->setLambda(lam); } - // record the range of indices of the atoms, bonds, angles, - // torsions which change - int start_change_atom = -1; - int end_change_atom = -1; - int start_change_14 = -1; - int end_change_14 = -1; - int start_change_bond = -1; - int end_change_bond = -1; - int start_change_angle = -1; - int end_change_angle = -1; - int start_change_torsion = -1; - int end_change_torsion = -1; + // track whether parameters actually changed for each force, so we only + // call updateParametersInContext when necessary + bool has_changed_cljff = false; + bool has_changed_ghost14ff = false; + bool has_changed_bondff = false; + bool has_changed_angff = false; + bool has_changed_dihff = false; + bool has_changed_cmap = false; // change the parameters for all of the perturbable molecules for (int i = 0; i < this->perturbable_mols.count(); ++i) @@ -1370,15 +1485,10 @@ double LambdaLever::setLambda(OpenMM::Context &context, const int nparams = morphed_charges.count(); - if (start_change_atom == -1) - { - start_change_atom = start_index; - end_change_atom = start_index + nparams; - } - else if (start_index >= end_change_atom) - { - end_change_atom = start_index + nparams; - } + // Detect whether any CLJ or ghost-14 parameters changed + has_changed_cljff |= rest2_changed || cache.hasChanged("clj", "charge") || cache.hasChanged("clj", "sigma") || cache.hasChanged("clj", "epsilon") || cache.hasChanged("clj", "alpha") || cache.hasChanged("clj", "kappa") || cache.hasChanged("clj", "charge_scale") || cache.hasChanged("clj", "lj_scale") || cache.hasChanged("ghost/ghost", "charge") || cache.hasChanged("ghost/ghost", "sigma") || cache.hasChanged("ghost/ghost", "epsilon") || cache.hasChanged("ghost/ghost", "alpha") || cache.hasChanged("ghost/ghost", "kappa") || cache.hasChanged("ghost/non-ghost", "charge") || cache.hasChanged("ghost/non-ghost", "sigma") || cache.hasChanged("ghost/non-ghost", "epsilon") || cache.hasChanged("ghost/non-ghost", "alpha") || cache.hasChanged("ghost/non-ghost", "kappa"); + + has_changed_ghost14ff |= rest2_changed || cache.hasChanged("ghost-14", "charge") || cache.hasChanged("ghost-14", "sigma") || cache.hasChanged("ghost-14", "epsilon") || cache.hasChanged("ghost-14", "alpha") || cache.hasChanged("ghost-14", "kappa") || cache.hasChanged("ghost-14", "charge_scale") || cache.hasChanged("ghost-14", "lj_scale"); if (have_ghost_atoms) { @@ -1455,7 +1565,7 @@ double LambdaLever::setLambda(OpenMM::Context &context, else { cljff->setParticleParameters( - start_index + j, sqrt_scale* morphed_charges[j], + start_index + j, sqrt_scale * morphed_charges[j], morphed_sigmas[j], scale * morphed_epsilons[j]); } } @@ -1475,7 +1585,7 @@ double LambdaLever::setLambda(OpenMM::Context &context, } cljff->setParticleParameters(start_index + j, sqrt_scale * morphed_charges[j], - morphed_sigmas[j], scale * morphed_epsilons[j]); + morphed_sigmas[j], scale * morphed_epsilons[j]); } } @@ -1538,8 +1648,8 @@ double LambdaLever::setLambda(OpenMM::Context &context, if (nbidx < 0) throw SireError::program_bug(QObject::tr( - "Unset NB14 index for a ghost atom?"), - CODELOC); + "Unset NB14 index for a ghost atom?"), + CODELOC); coul_14_scale = morphed_ghost14_charge_scale[j]; lj_14_scale = morphed_ghost14_lj_scale[j]; @@ -1558,22 +1668,7 @@ double LambdaLever::setLambda(OpenMM::Context &context, boost::get<3>(p), 4.0 * boost::get<4>(p) * scale, boost::get<5>(p), - boost::get<6>(p) - }; - - if (start_change_14 == -1) - { - start_change_14 = nbidx; - end_change_14 = nbidx + 1; - } - else - { - if (nbidx < start_change_14) - start_change_14 = nbidx; - - if (nbidx + 1 > end_change_14) - end_change_14 = nbidx + 1; - } + boost::get<6>(p)}; ghost_14ff->setBondParameters(nbidx, boost::get<0>(p), @@ -1647,20 +1742,7 @@ double LambdaLever::setLambda(OpenMM::Context &context, const int nparams = morphed_bond_k.count(); - if (start_change_bond == -1) - { - start_change_bond = start_index; - end_change_bond = start_index + nparams; - } - else if (start_index < start_change_bond) - { - start_change_bond = start_index; - } - - if (start_index + nparams > end_change_bond) - { - end_change_bond = start_index + nparams; - } + has_changed_bondff |= cache.hasChanged("bond", "bond_k") || cache.hasChanged("bond", "bond_length"); for (int j = 0; j < nparams; ++j) { @@ -1670,11 +1752,11 @@ double LambdaLever::setLambda(OpenMM::Context &context, double length, k; bondff->getBondParameters(index, particle1, particle2, - length, k); + length, k); bondff->setBondParameters(index, particle1, particle2, - morphed_bond_length[j], - morphed_bond_k[j]); + morphed_bond_length[j], + morphed_bond_k[j]); } } @@ -1696,20 +1778,7 @@ double LambdaLever::setLambda(OpenMM::Context &context, const int nparams = morphed_angle_k.count(); - if (start_change_angle == -1) - { - start_change_angle = start_index; - end_change_angle = start_index + nparams; - } - else if (start_index < start_change_angle) - { - start_change_angle = start_index; - } - - if (start_index + nparams > end_change_angle) - { - end_change_angle = start_index + nparams; - } + has_changed_angff |= cache.hasChanged("angle", "angle_k") || cache.hasChanged("angle", "angle_size"); for (int j = 0; j < nparams; ++j) { @@ -1719,13 +1788,13 @@ double LambdaLever::setLambda(OpenMM::Context &context, double size, k; angff->getAngleParameters(index, - particle1, particle2, particle3, - size, k); + particle1, particle2, particle3, + size, k); angff->setAngleParameters(index, - particle1, particle2, particle3, - morphed_angle_size[j], - morphed_angle_k[j]); + particle1, particle2, particle3, + morphed_angle_size[j], + morphed_angle_k[j]); } } @@ -1749,20 +1818,7 @@ double LambdaLever::setLambda(OpenMM::Context &context, const auto is_improper = perturbable_mol.getIsImproper(); - if (start_change_torsion == -1) - { - start_change_torsion = start_index; - end_change_torsion = start_index + nparams; - } - else if (start_index < start_change_torsion) - { - start_change_torsion = start_index; - } - - if (start_index + nparams > end_change_torsion) - { - end_change_torsion = start_index + nparams; - } + has_changed_dihff |= rest2_changed || cache.hasChanged("torsion", "torsion_k") || cache.hasChanged("torsion", "torsion_phase"); for (int j = 0; j < nparams; ++j) { @@ -1821,101 +1877,73 @@ double LambdaLever::setLambda(OpenMM::Context &context, grid0, grid1); - int offset = 0; - - for (int j = 0; j < sizes.count(); ++j) + if (rest2_changed or cache.hasChanged("cmap", "cmap_grid")) { - const int N = sizes[j]; - const int map_size = N * N; + has_changed_cmap = true; - // CMAP is always a proper backbone dihedral pair, never improper. - // Apply REST2 scaling if all 5 atoms are within the REST2 region. - double scale = 1.0; + int offset = 0; - const auto &cmap_atms = atoms[j]; - - if (perturbable_mol.isRest2(boost::get<0>(cmap_atms)) and - perturbable_mol.isRest2(boost::get<1>(cmap_atms)) and - perturbable_mol.isRest2(boost::get<2>(cmap_atms)) and - perturbable_mol.isRest2(boost::get<3>(cmap_atms)) and - perturbable_mol.isRest2(boost::get<4>(cmap_atms))) + for (int j = 0; j < sizes.count(); ++j) { - scale = rest2_scale; - } + const int N = sizes[j]; + const int map_size = N * N; - std::vector energy(map_size); + // CMAP is always a proper backbone dihedral pair, never improper. + // Apply REST2 scaling if all 5 atoms are within the REST2 region. + double scale = 1.0; - for (int k = 0; k < map_size; ++k) - { - energy[k] = morphed_grids[offset + k] * scale; - } + const auto &cmap_atms = atoms[j]; - cmapff->setMapParameters(start_index + j, N, energy); - offset += map_size; - } + if (perturbable_mol.isRest2(boost::get<0>(cmap_atms)) and + perturbable_mol.isRest2(boost::get<1>(cmap_atms)) and + perturbable_mol.isRest2(boost::get<2>(cmap_atms)) and + perturbable_mol.isRest2(boost::get<3>(cmap_atms)) and + perturbable_mol.isRest2(boost::get<4>(cmap_atms))) + { + scale = rest2_scale; + } + + std::vector energy(map_size); - cmapff->updateParametersInContext(context); + for (int k = 0; k < map_size; ++k) + { + energy[k] = morphed_grids[offset + k] * scale; + } + + cmapff->setMapParameters(start_index + j, N, energy); + offset += map_size; + } + } } } - // update the parameters in the context - const auto num_changed_atoms = end_change_atom - start_change_atom; - const auto num_changed_bonds = end_change_bond - start_change_bond; - const auto num_changed_angles = end_change_angle - start_change_angle; - const auto num_changed_torsions = end_change_torsion - start_change_torsion; - const auto num_changed_14 = end_change_14 - start_change_14; - - if (num_changed_atoms > 0) + // update the parameters in the context for forces whose parameters changed + if (has_changed_cljff) { if (cljff) -#ifdef SIRE_HAS_UPDATE_SOME_IN_CONTEXT - cljff->updateSomeParametersInContext(start_change_atom, num_changed_atoms, context); -#else cljff->updateParametersInContext(context); -#endif if (ghost_ghostff) -#ifdef SIRE_HAS_UPDATE_SOME_IN_CONTEXT - ghost_ghostff->updateSomeParametersInContext(start_change_atom, num_changed_atoms, context); -#else ghost_ghostff->updateParametersInContext(context); -#endif if (ghost_nonghostff) -#ifdef SIRE_HAS_UPDATE_SOME_IN_CONTEXT - ghost_nonghostff->updateSomeParametersInContext(start_change_atom, num_changed_atoms, context); -#else ghost_nonghostff->updateParametersInContext(context); -#endif } - if (ghost_14ff and num_changed_14 > 0) -#ifdef SIRE_HAS_UPDATE_SOME_IN_CONTEXT - ghost_14ff->updateSomeParametersInContext(start_change_14, num_changed_14, context); -#else + if (ghost_14ff and has_changed_ghost14ff) ghost_14ff->updateParametersInContext(context); -#endif - if (bondff and num_changed_bonds > 0) -#ifdef SIRE_HAS_UPDATE_SOME_IN_CONTEXT - bondff->updateSomeParametersInContext(start_change_bond, num_changed_bonds, context); -#else + if (bondff and has_changed_bondff) bondff->updateParametersInContext(context); -#endif - if (angff and num_changed_angles > 0) -#ifdef SIRE_HAS_UPDATE_SOME_IN_CONTEXT - angff->updateSomeParametersInContext(start_change_angle, num_changed_angles, context); -#else + if (angff and has_changed_angff) angff->updateParametersInContext(context); -#endif - if (dihff and num_changed_torsions > 0) -#ifdef SIRE_HAS_UPDATE_SOME_IN_CONTEXT - dihff->updateSomeParametersInContext(start_change_torsion, num_changed_torsions, context); -#else + if (dihff and has_changed_dihff) dihff->updateParametersInContext(context); -#endif + + if (cmapff and has_changed_cmap) + cmapff->updateParametersInContext(context); // now update any restraints that are scaled for (const auto &restraint : this->name_to_restraintidx.keys()) @@ -1928,11 +1956,18 @@ double LambdaLever::setLambda(OpenMM::Context &context, 1.0, 1.0, lambda_value); - for (auto &ff : this->getRestraints(restraint, system)) + const double prev_rho = last_restraint_rho.value(restraint, -1.0); + last_restraint_rho[restraint] = rho; + last_changed_forces[restraint] = (rho != prev_rho); + + if (rho != prev_rho) { - if (ff != 0) + for (auto &ff : this->getRestraints(restraint, system)) { - this->updateRestraintInContext(*ff, rho, context); + if (ff != 0) + { + this->updateRestraintInContext(*ff, rho, context); + } } } } @@ -1946,6 +1981,17 @@ double LambdaLever::setLambda(OpenMM::Context &context, context.reinitialize(true); } + // record which named forces had parameters changed in this call + last_changed_forces["clj"] = has_changed_cljff; + last_changed_forces["ghost/ghost"] = has_changed_cljff; + last_changed_forces["ghost/non-ghost"] = has_changed_cljff; + last_changed_forces["ghost-14"] = has_changed_ghost14ff; + last_changed_forces["bond"] = has_changed_bondff; + last_changed_forces["angle"] = has_changed_angff; + last_changed_forces["torsion"] = has_changed_dihff; + last_changed_forces["cmap"] = has_changed_cmap; + last_changed_forces["qmff"] = false; + return lambda_value; } diff --git a/wrapper/Convert/SireOpenMM/lambdalever.h b/wrapper/Convert/SireOpenMM/lambdalever.h index b37a446e1..f80ed6077 100644 --- a/wrapper/Convert/SireOpenMM/lambdalever.h +++ b/wrapper/Convert/SireOpenMM/lambdalever.h @@ -49,6 +49,7 @@ namespace SireOpenMM public: MolLambdaCache(); MolLambdaCache(double lam_val); + MolLambdaCache(double lam_val, const MolLambdaCache &prev); MolLambdaCache(const MolLambdaCache &other); ~MolLambdaCache(); @@ -65,9 +66,14 @@ namespace SireOpenMM const QVector &initial, const QVector &final) const; + bool hasChanged(const QString &force, const QString &key) const; + bool hasChanged(const QString &force, const QString &key, + const QString &subkey) const; + private: QHash> cache; - QReadWriteLock lock; + QHash> prev_cache; + mutable QReadWriteLock lock; double lam_val; }; @@ -86,6 +92,7 @@ namespace SireOpenMM private: QHash> cache; + QHash prev_lam_vals; }; /** This is a lever that is used to change the parameters in an OpenMM @@ -152,6 +159,12 @@ namespace SireOpenMM QString getForceType(const QString &name, const OpenMM::System &system) const; + void setForceGroup(const QString &name, int group_idx); + void setRestraintForceGroup(const QString &name, int group_idx); + int getForceGroup(const QString &name) const; + QStringList getForceNames() const; + bool wasForceChanged(const QString &name) const; + protected: void updateRestraintInContext(OpenMM::Force &ff, double rho, OpenMM::Context &context) const; @@ -163,6 +176,10 @@ namespace SireOpenMM * Note that multiple restraints can have the same name */ QMultiHash name_to_restraintidx; + /** Map from a force or restraint name to its OpenMM force group index. + * Multiple restraint forces sharing the same name share one group. */ + QHash name_to_groupidx; + /** The schedule used to set lambda */ SireCAS::LambdaSchedule lambda_schedule; @@ -178,6 +195,19 @@ namespace SireOpenMM /** Cache of the parameters for different lambda values */ LeverCache lambda_cache; + + /** Records which forces had parameters changed in the last setLambda + * call. Mutable so it can be updated from the const setLambda method. */ + mutable QHash last_changed_forces; + + /** Records the rho value used for each restraint in the last setLambda + * call, so we can detect when restraint parameters actually change. */ + mutable QHash last_restraint_rho; + + /** Records the REST2 scale factor used in the last setLambda call, + * so we can detect when it changes (REST2 scaling is applied on top + * of the morphed parameters, so a change requires re-setting params). */ + mutable double last_rest2_scale; }; #ifndef SIRE_SKIP_INLINE_FUNCTION diff --git a/wrapper/Convert/SireOpenMM/sire_to_openmm_system.cpp b/wrapper/Convert/SireOpenMM/sire_to_openmm_system.cpp index 7a9978632..fdb09c3e7 100644 --- a/wrapper/Convert/SireOpenMM/sire_to_openmm_system.cpp +++ b/wrapper/Convert/SireOpenMM/sire_to_openmm_system.cpp @@ -72,7 +72,7 @@ using namespace SireOpenMM; */ void _add_boresch_restraints(const SireMM::BoreschRestraints &restraints, OpenMM::System &system, LambdaLever &lambda_lever, - int natoms) + int natoms, int &force_group_counter) { if (restraints.isEmpty()) return; @@ -137,8 +137,10 @@ void _add_boresch_restraints(const SireMM::BoreschRestraints &restraints, restraintff->setUsesPeriodicBoundaryConditions(restraints.usesPbc()); + restraintff->setForceGroup(force_group_counter); lambda_lever.addRestraintIndex(restraints.name(), system.addForce(restraintff)); + lambda_lever.setRestraintForceGroup(restraints.name(), force_group_counter++); const double internal_to_nm = (1 * SireUnits::angstrom).to(SireUnits::nanometer); const double internal_to_k = (1 * SireUnits::kcal_per_mol / (SireUnits::angstrom2)).to(SireUnits::kJ_per_mol / SireUnits::nanometer2); @@ -197,7 +199,7 @@ void _add_boresch_restraints(const SireMM::BoreschRestraints &restraints, */ void _add_bond_restraints(const SireMM::BondRestraints &restraints, OpenMM::System &system, LambdaLever &lambda_lever, - int natoms) + int natoms, int &force_group_counter) { if (restraints.isEmpty()) return; @@ -224,8 +226,10 @@ void _add_bond_restraints(const SireMM::BondRestraints &restraints, restraintff->setUsesPeriodicBoundaryConditions(restraints.usesPbc()); + restraintff->setForceGroup(force_group_counter); lambda_lever.addRestraintIndex(restraints.name(), system.addForce(restraintff)); + lambda_lever.setRestraintForceGroup(restraints.name(), force_group_counter++); const auto atom_restraints = restraints.atomRestraints(); @@ -263,7 +267,7 @@ void _add_bond_restraints(const SireMM::BondRestraints &restraints, void _add_inverse_bond_restraints(const SireMM::InverseBondRestraints &restraints, OpenMM::System &system, LambdaLever &lambda_lever, - int natoms) + int natoms, int &force_group_counter) { if (restraints.isEmpty()) return; @@ -290,8 +294,10 @@ void _add_inverse_bond_restraints(const SireMM::InverseBondRestraints &restraint restraintff->setUsesPeriodicBoundaryConditions(restraints.usesPbc()); + restraintff->setForceGroup(force_group_counter); lambda_lever.addRestraintIndex(restraints.name(), system.addForce(restraintff)); + lambda_lever.setRestraintForceGroup(restraints.name(), force_group_counter++); const auto atom_restraints = restraints.atomRestraints(); @@ -334,7 +340,7 @@ void _add_inverse_bond_restraints(const SireMM::InverseBondRestraints &restraint */ void _add_morse_potential_restraints(const SireMM::MorsePotentialRestraints &restraints, OpenMM::System &system, LambdaLever &lambda_lever, - int natoms) + int natoms, int &force_group_counter) { if (restraints.isEmpty()) return; @@ -368,8 +374,10 @@ void _add_morse_potential_restraints(const SireMM::MorsePotentialRestraints &res restraintff->setUsesPeriodicBoundaryConditions(restraints.usesPbc()); + restraintff->setForceGroup(force_group_counter); lambda_lever.addRestraintIndex(restraints.name(), system.addForce(restraintff)); + lambda_lever.setRestraintForceGroup(restraints.name(), force_group_counter++); const auto atom_restraints = restraints.atomRestraints(); const double internal_to_nm = (1 * SireUnits::angstrom).to(SireUnits::nanometer); @@ -415,7 +423,7 @@ void _add_morse_potential_restraints(const SireMM::MorsePotentialRestraints &res void _add_positional_restraints(const SireMM::PositionalRestraints &restraints, OpenMM::System &system, LambdaLever &lambda_lever, std::vector &anchor_coords, - int natoms) + int natoms, int &force_group_counter) { if (restraints.isEmpty()) return; @@ -442,8 +450,10 @@ void _add_positional_restraints(const SireMM::PositionalRestraints &restraints, restraintff->setUsesPeriodicBoundaryConditions(restraints.usesPbc()); + restraintff->setForceGroup(force_group_counter); lambda_lever.addRestraintIndex(restraints.name(), system.addForce(restraintff)); + lambda_lever.setRestraintForceGroup(restraints.name(), force_group_counter++); const auto atom_restraints = restraints.atomRestraints(); @@ -549,7 +559,7 @@ void _add_positional_restraints(const SireMM::PositionalRestraints &restraints, void _add_rmsd_restraints(const SireMM::RMSDRestraints &restraints, OpenMM::System &system, LambdaLever &lambda_lever, - int natoms) + int natoms, int &force_group_counter) { if (restraints.isEmpty()) return; @@ -631,8 +641,17 @@ void _add_rmsd_restraints(const SireMM::RMSDRestraints &restraints, auto *rmsdCV = new OpenMM::RMSDForce(referencePositions, particles); restraintff->addCollectiveVariable(rmsd_unique, rmsdCV); + // All sub-restraints with the same name share a single force group so + // that one getState(groups=...) call sums their energies correctly. + int grp = lambda_lever.getForceGroup(restraints.name()); + if (grp < 0) + { + grp = force_group_counter++; + } + restraintff->setForceGroup(grp); lambda_lever.addRestraintIndex(restraints.name(), system.addForce(restraintff)); + lambda_lever.setRestraintForceGroup(restraints.name(), grp); // Update the counter for number of CustomCVForces n_CVForces++; @@ -645,7 +664,7 @@ void _add_rmsd_restraints(const SireMM::RMSDRestraints &restraints, */ void _add_angle_restraints(const SireMM::AngleRestraints &restraints, OpenMM::System &system, LambdaLever &lambda_lever, - int natoms) + int natoms, int &force_group_counter) { if (restraints.isEmpty()) return; @@ -672,8 +691,10 @@ void _add_angle_restraints(const SireMM::AngleRestraints &restraints, restraintff->setUsesPeriodicBoundaryConditions(restraints.usesPbc()); + restraintff->setForceGroup(force_group_counter); lambda_lever.addRestraintIndex(restraints.name(), system.addForce(restraintff)); + lambda_lever.setRestraintForceGroup(restraints.name(), force_group_counter++); const double internal_to_ktheta = (1 * SireUnits::kcal_per_mol / (SireUnits::radian2)).to(SireUnits::kJ_per_mol / SireUnits::radian2); @@ -703,7 +724,7 @@ void _add_angle_restraints(const SireMM::AngleRestraints &restraints, void _add_dihedral_restraints(const SireMM::DihedralRestraints &restraints, OpenMM::System &system, LambdaLever &lambda_lever, - int natoms) + int natoms, int &force_group_counter) { if (restraints.isEmpty()) return; @@ -735,8 +756,10 @@ void _add_dihedral_restraints(const SireMM::DihedralRestraints &restraints, restraintff->setUsesPeriodicBoundaryConditions(restraints.usesPbc()); + restraintff->setForceGroup(force_group_counter); lambda_lever.addRestraintIndex(restraints.name(), system.addForce(restraintff)); + lambda_lever.setRestraintForceGroup(restraints.name(), force_group_counter++); const double internal_to_ktheta = (1 * SireUnits::kcal_per_mol / (SireUnits::radian2)).to(SireUnits::kJ_per_mol / SireUnits::radian2); @@ -1209,17 +1232,26 @@ OpenMMMetaData SireOpenMM::sire_to_openmm_system(OpenMM::System &system, LambdaSchedule::standard_morph()); } + // Each named force is placed into its own force group so that energies + // can be queried and cached per-group. The counter starts at 0 and + // increments for each named force added to the system. + int force_group_counter = 0; + // Add any QM force first so that we can guarantee that it is index zero. if (qmff != 0) { + qmff->setForceGroup(force_group_counter); lambda_lever.setForceIndex("qmff", system.addForce(qmff)); + lambda_lever.setForceGroup("qmff", force_group_counter++); lambda_lever.addLever("qm_scale"); } // We can now add the standard forces to the OpenMM::System. // We do this here, so that we can capture the index of the // force and associate it with a name in the lever. + cljff->setForceGroup(force_group_counter); lambda_lever.setForceIndex("clj", system.addForce(cljff)); + lambda_lever.setForceGroup("clj", force_group_counter++); // We also want to name the levers available for this force, // e.g. we can change the charge, sigma and epsilon parameters @@ -1232,21 +1264,24 @@ OpenMMMetaData SireOpenMM::sire_to_openmm_system(OpenMM::System &system, lambda_lever.addLever("lj_scale"); // Do the same for the bond, angle and torsion forces + bondff->setForceGroup(force_group_counter); lambda_lever.setForceIndex("bond", system.addForce(bondff)); + lambda_lever.setForceGroup("bond", force_group_counter++); lambda_lever.addLever("bond_length"); lambda_lever.addLever("bond_k"); + angff->setForceGroup(force_group_counter); lambda_lever.setForceIndex("angle", system.addForce(angff)); + lambda_lever.setForceGroup("angle", force_group_counter++); lambda_lever.addLever("angle_size"); lambda_lever.addLever("angle_k"); + dihff->setForceGroup(force_group_counter); lambda_lever.setForceIndex("torsion", system.addForce(dihff)); + lambda_lever.setForceGroup("torsion", force_group_counter++); lambda_lever.addLever("torsion_phase"); lambda_lever.addLever("torsion_k"); - lambda_lever.setForceIndex("cmap", system.addForce(cmapff)); - lambda_lever.addLever("cmap_grid"); - /// /// Stage 4 - define the forces for ghost atoms /// @@ -1521,9 +1556,17 @@ OpenMMMetaData SireOpenMM::sire_to_openmm_system(OpenMM::System &system, ghost_nonghostff->setNonbondedMethod(OpenMM::CustomNonbondedForce::NoCutoff); } + ghost_ghostff->setForceGroup(force_group_counter); lambda_lever.setForceIndex("ghost/ghost", system.addForce(ghost_ghostff)); + lambda_lever.setForceGroup("ghost/ghost", force_group_counter++); + + ghost_nonghostff->setForceGroup(force_group_counter); lambda_lever.setForceIndex("ghost/non-ghost", system.addForce(ghost_nonghostff)); + lambda_lever.setForceGroup("ghost/non-ghost", force_group_counter++); + + ghost_14ff->setForceGroup(force_group_counter); lambda_lever.setForceIndex("ghost-14", system.addForce(ghost_14ff)); + lambda_lever.setForceGroup("ghost-14", force_group_counter++); } // Stage 4 is complete. We now have all(*) of the forces we need to run @@ -2000,6 +2043,22 @@ OpenMMMetaData SireOpenMM::sire_to_openmm_system(OpenMM::System &system, /// Stage 5 is complete. We have added all of the parameter data /// for the molecules to the OpenMM forces + // Only register the CMAP force if terms were actually added during the + // molecule loop. An empty CMAPTorsionForce wastes a force-group slot and + // launches a zero-work kernel on every step. + if (cmapff->getNumMaps() > 0) + { + cmapff->setForceGroup(force_group_counter); + lambda_lever.setForceIndex("cmap", system.addForce(cmapff)); + lambda_lever.setForceGroup("cmap", force_group_counter++); + lambda_lever.addLever("cmap_grid"); + } + else + { + delete cmapff; + cmapff = nullptr; + } + /// /// Stage 6 - Set up the exceptions and perturbable constraints /// @@ -2231,42 +2290,43 @@ OpenMMMetaData SireOpenMM::sire_to_openmm_system(OpenMM::System &system, if (prop.read().isA()) { _add_dihedral_restraints(prop.read().asA(), - system, lambda_lever, start_index); + system, lambda_lever, start_index, force_group_counter); } else if (prop.read().isA()) { _add_angle_restraints(prop.read().asA(), - system, lambda_lever, start_index); + system, lambda_lever, start_index, force_group_counter); } else if (prop.read().isA()) { _add_positional_restraints(prop.read().asA(), - system, lambda_lever, anchor_coords, start_index); + system, lambda_lever, anchor_coords, start_index, + force_group_counter); } else if (prop.read().isA()) { _add_morse_potential_restraints(prop.read().asA(), - system, lambda_lever, start_index); + system, lambda_lever, start_index, force_group_counter); } else if (prop.read().isA()) { _add_rmsd_restraints(prop.read().asA(), - system, lambda_lever, start_index); + system, lambda_lever, start_index, force_group_counter); } else if (prop.read().isA()) { _add_bond_restraints(prop.read().asA(), - system, lambda_lever, start_index); + system, lambda_lever, start_index, force_group_counter); } else if (prop.read().isA()) { _add_inverse_bond_restraints(prop.read().asA(), - system, lambda_lever, start_index); + system, lambda_lever, start_index, force_group_counter); } else if (prop.read().isA()) { _add_boresch_restraints(prop.read().asA(), - system, lambda_lever, start_index); + system, lambda_lever, start_index, force_group_counter); } } } @@ -2291,7 +2351,7 @@ OpenMMMetaData SireOpenMM::sire_to_openmm_system(OpenMM::System &system, if (prop.read().isA()) { _add_inverse_bond_restraints(prop.read().asA(), - system, lambda_lever, start_index); + system, lambda_lever, start_index, force_group_counter); } } }