|
| 1 | +# SPDX-License-Identifier: LGPL-3.0-or-later |
| 2 | +from typing import ( |
| 3 | + Optional, |
| 4 | +) |
| 5 | + |
| 6 | +import array_api_compat |
| 7 | +import numpy as np |
| 8 | + |
| 9 | +from deepmd.dpmodel.loss.loss import ( |
| 10 | + Loss, |
| 11 | +) |
| 12 | +from deepmd.utils.data import ( |
| 13 | + DataRequirementItem, |
| 14 | +) |
| 15 | +from deepmd.utils.version import ( |
| 16 | + check_version_compatibility, |
| 17 | +) |
| 18 | + |
| 19 | + |
class EnergyLoss(Loss):
    """Weighted mean-squared-error loss on energy, force, virial and related labels.

    Every loss term has a pair of prefactors ``(start_pref, limit_pref)``.
    The prefactor actually applied at a given step is interpolated linearly
    in the learning rate::

        pref = limit_pref + (start_pref - limit_pref) * lr / starter_learning_rate

    so it equals ``start_pref`` when ``lr == starter_learning_rate`` and
    approaches ``limit_pref`` as ``lr`` goes to zero.  A term is enabled when
    at least one prefactor of its pair is non-zero.

    Parameters
    ----------
    starter_learning_rate : float
        Learning rate that the current learning rate is divided by when the
        prefactors are interpolated.
    start_pref_e, limit_pref_e : float
        Prefactors of the energy loss.
    start_pref_f, limit_pref_f : float
        Prefactors of the force loss.
    start_pref_v, limit_pref_v : float
        Prefactors of the virial loss.
    start_pref_ae, limit_pref_ae : float
        Prefactors of the atomic energy loss.
    start_pref_pf, limit_pref_pf : float
        Prefactors of the force loss weighted by the ``atom_pref`` label.
    relative_f : float, optional
        If not None, every per-atom force difference is divided by
        ``|f_label| + relative_f`` before it enters the force losses.
    enable_atom_ener_coeff : bool
        If True, the predicted energy is recomputed as
        ``sum_i atom_ener_coeff_i * atom_ener_i``.
    start_pref_gf, limit_pref_gf : float
        Prefactors of the generalized force loss.
    numb_generalized_coord : int
        Dimension of the generalized coordinates; must be at least 1 when the
        generalized force loss is enabled.
    **kwargs
        Accepted and ignored.

    Raises
    ------
    RuntimeError
        If the generalized force loss is enabled but
        ``numb_generalized_coord`` is smaller than 1.
    """

    def __init__(
        self,
        starter_learning_rate: float,
        start_pref_e: float = 0.02,
        limit_pref_e: float = 1.00,
        start_pref_f: float = 1000,
        limit_pref_f: float = 1.00,
        start_pref_v: float = 0.0,
        limit_pref_v: float = 0.0,
        start_pref_ae: float = 0.0,
        limit_pref_ae: float = 0.0,
        start_pref_pf: float = 0.0,
        limit_pref_pf: float = 0.0,
        relative_f: Optional[float] = None,
        enable_atom_ener_coeff: bool = False,
        start_pref_gf: float = 0.0,
        limit_pref_gf: float = 0.0,
        numb_generalized_coord: int = 0,
        **kwargs,
    ) -> None:
        self.starter_learning_rate = starter_learning_rate
        self.start_pref_e = start_pref_e
        self.limit_pref_e = limit_pref_e
        self.start_pref_f = start_pref_f
        self.limit_pref_f = limit_pref_f
        self.start_pref_v = start_pref_v
        self.limit_pref_v = limit_pref_v
        self.start_pref_ae = start_pref_ae
        self.limit_pref_ae = limit_pref_ae
        self.start_pref_pf = start_pref_pf
        self.limit_pref_pf = limit_pref_pf
        self.relative_f = relative_f
        self.enable_atom_ener_coeff = enable_atom_ener_coeff
        self.start_pref_gf = start_pref_gf
        self.limit_pref_gf = limit_pref_gf
        self.numb_generalized_coord = numb_generalized_coord
        # A term takes part in the loss when either of its prefactors is non-zero.
        self.has_e = self.start_pref_e != 0.0 or self.limit_pref_e != 0.0
        self.has_f = self.start_pref_f != 0.0 or self.limit_pref_f != 0.0
        self.has_v = self.start_pref_v != 0.0 or self.limit_pref_v != 0.0
        self.has_ae = self.start_pref_ae != 0.0 or self.limit_pref_ae != 0.0
        self.has_pf = self.start_pref_pf != 0.0 or self.limit_pref_pf != 0.0
        self.has_gf = self.start_pref_gf != 0.0 or self.limit_pref_gf != 0.0
        if self.has_gf and self.numb_generalized_coord < 1:
            raise RuntimeError(
                "When generalized force loss is used, the dimension of generalized coordinates should be larger than 0"
            )

    @staticmethod
    def _scheduled_pref(start_pref: float, limit_pref: float, lr_ratio: float) -> float:
        """Interpolate a prefactor between its limit (ratio 0) and start (ratio 1) value."""
        return limit_pref + (start_pref - limit_pref) * lr_ratio

    def call(
        self,
        learning_rate: float,
        natoms: int,
        model_dict: dict[str, np.ndarray],
        label_dict: dict[str, np.ndarray],
    ) -> tuple[np.ndarray, dict[str, np.ndarray]]:
        """Calculate loss from model results and labeled results.

        Parameters
        ----------
        learning_rate : float
            Current learning rate, used to interpolate the prefactors.
        natoms : int
            Number of atoms; used to normalize the energy and virial terms
            and to reshape forces per frame for the generalized force term.
        model_dict : dict[str, np.ndarray]
            Model predictions.  The keys ``energy``, ``force``, ``virial`` and
            ``atom_ener`` are always read, whether or not the matching loss
            term is enabled.
        label_dict : dict[str, np.ndarray]
            Labels.  The keys ``energy``, ``force``, ``virial``, ``atom_ener``,
            ``atom_pref`` and their ``find_*`` companions are always read;
            ``atom_ener_coeff``, ``drdq`` and ``find_drdq`` are read only when
            the corresponding feature is enabled.

        Returns
        -------
        l2_loss
            The total weighted loss.
        more_loss : dict
            The unweighted individual loss terms, each passed through
            ``display_if_exist`` with its ``find_*`` flag.
        """
        energy = model_dict["energy"]
        force = model_dict["force"]
        virial = model_dict["virial"]
        atom_ener = model_dict["atom_ener"]
        energy_hat = label_dict["energy"]
        force_hat = label_dict["force"]
        virial_hat = label_dict["virial"]
        atom_ener_hat = label_dict["atom_ener"]
        atom_pref = label_dict["atom_pref"]
        find_energy = label_dict["find_energy"]
        find_force = label_dict["find_force"]
        find_virial = label_dict["find_virial"]
        find_atom_ener = label_dict["find_atom_ener"]
        find_atom_pref = label_dict["find_atom_pref"]
        xp = array_api_compat.array_namespace(
            energy,
            force,
            virial,
            atom_ener,
            energy_hat,
            force_hat,
            virial_hat,
            atom_ener_hat,
            atom_pref,
        )

        if self.enable_atom_ener_coeff:
            # when ener_coeff (\nu) is defined, the energy is defined as
            # E = \sum_i \nu_i E_i
            # instead of the sum of atomic energies.
            #
            # A case is that we want to train reaction energy
            # A + B -> C + D
            # E = - E(A) - E(B) + E(C) + E(D)
            # A, B, C, D could be put far away from each other
            atom_ener_coeff = label_dict["atom_ener_coeff"]
            # The array API has no ``xp.shape`` function; use the attribute.
            atom_ener_coeff = xp.reshape(atom_ener_coeff, atom_ener.shape)
            # ``axis`` is keyword-only in the array API standard.
            energy = xp.sum(atom_ener_coeff * atom_ener, axis=1)

        # ``relative_f`` is tested with ``is not None`` here as well, so that
        # ``relative_f == 0.0`` does not leave ``diff_f`` undefined below.
        if (
            self.has_f
            or self.has_pf
            or self.has_gf
            or self.relative_f is not None
        ):
            diff_f = xp.reshape(force_hat, (-1,)) - xp.reshape(force, (-1,))
        else:
            diff_f = None

        if self.relative_f is not None:
            force_hat_3 = xp.reshape(force_hat, (-1, 3))
            # Per-atom Euclidean norm of the label force, kept as a column so
            # that it broadcasts over the three Cartesian components.
            norm_f = (
                xp.sqrt(xp.sum(xp.square(force_hat_3), axis=1, keepdims=True))
                + self.relative_f
            )
            diff_f = xp.reshape(xp.reshape(diff_f, (-1, 3)) / norm_f, (-1,))

        atom_norm = 1.0 / natoms
        atom_norm_ener = 1.0 / natoms
        lr_ratio = learning_rate / self.starter_learning_rate
        # Multiplying by the ``find_*`` flag scales each prefactor by whatever
        # value the data provides for that label.
        pref_e = find_energy * self._scheduled_pref(
            self.start_pref_e, self.limit_pref_e, lr_ratio
        )
        pref_f = find_force * self._scheduled_pref(
            self.start_pref_f, self.limit_pref_f, lr_ratio
        )
        pref_v = find_virial * self._scheduled_pref(
            self.start_pref_v, self.limit_pref_v, lr_ratio
        )
        pref_ae = find_atom_ener * self._scheduled_pref(
            self.start_pref_ae, self.limit_pref_ae, lr_ratio
        )
        pref_pf = find_atom_pref * self._scheduled_pref(
            self.start_pref_pf, self.limit_pref_pf, lr_ratio
        )

        l2_loss = 0
        more_loss = {}
        if self.has_e:
            l2_ener_loss = xp.mean(xp.square(energy - energy_hat))
            l2_loss += atom_norm_ener * (pref_e * l2_ener_loss)
            more_loss["l2_ener_loss"] = self.display_if_exist(l2_ener_loss, find_energy)
        if self.has_f:
            l2_force_loss = xp.mean(xp.square(diff_f))
            l2_loss += pref_f * l2_force_loss
            more_loss["l2_force_loss"] = self.display_if_exist(
                l2_force_loss, find_force
            )
        if self.has_v:
            virial_reshape = xp.reshape(virial, (-1,))
            virial_hat_reshape = xp.reshape(virial_hat, (-1,))
            l2_virial_loss = xp.mean(
                xp.square(virial_hat_reshape - virial_reshape),
            )
            l2_loss += atom_norm * (pref_v * l2_virial_loss)
            more_loss["l2_virial_loss"] = self.display_if_exist(
                l2_virial_loss, find_virial
            )
        if self.has_ae:
            atom_ener_reshape = xp.reshape(atom_ener, (-1,))
            atom_ener_hat_reshape = xp.reshape(atom_ener_hat, (-1,))
            l2_atom_ener_loss = xp.mean(
                xp.square(atom_ener_hat_reshape - atom_ener_reshape),
            )
            l2_loss += pref_ae * l2_atom_ener_loss
            more_loss["l2_atom_ener_loss"] = self.display_if_exist(
                l2_atom_ener_loss, find_atom_ener
            )
        if self.has_pf:
            atom_pref_reshape = xp.reshape(atom_pref, (-1,))
            l2_pref_force_loss = xp.mean(
                xp.multiply(xp.square(diff_f), atom_pref_reshape),
            )
            l2_loss += pref_pf * l2_pref_force_loss
            more_loss["l2_pref_force_loss"] = self.display_if_exist(
                l2_pref_force_loss, find_atom_pref
            )
        if self.has_gf:
            find_drdq = label_dict["find_drdq"]
            drdq = label_dict["drdq"]
            # ``natoms`` is a plain int (see ``atom_norm`` above), so it is
            # used directly rather than indexed.
            ncoord = natoms * 3
            drdq_reshape = xp.reshape(
                drdq, (-1, ncoord, self.numb_generalized_coord)
            )
            force_column = xp.reshape(force, (-1, ncoord, 1))
            force_hat_column = xp.reshape(force_hat, (-1, ncoord, 1))
            # Contract over the Cartesian index: out[b, j] = sum_i drdq[b, i, j] * f[b, i].
            # Written as multiply + sum because ``einsum`` is not part of the
            # array API standard.
            gen_force_hat = xp.sum(drdq_reshape * force_hat_column, axis=1)
            gen_force = xp.sum(drdq_reshape * force_column, axis=1)
            diff_gen_force = gen_force_hat - gen_force
            l2_gen_force_loss = xp.mean(xp.square(diff_gen_force))
            pref_gf = find_drdq * self._scheduled_pref(
                self.start_pref_gf, self.limit_pref_gf, lr_ratio
            )
            l2_loss += pref_gf * l2_gen_force_loss
            more_loss["l2_gen_force_loss"] = self.display_if_exist(
                l2_gen_force_loss, find_drdq
            )

        # Keep the latest results on the instance, as the original code does.
        self.l2_l = l2_loss
        self.l2_more = more_loss
        return l2_loss, more_loss

    @property
    def label_requirement(self) -> list[DataRequirementItem]:
        """Return data label requirements needed for this loss calculation."""
        label_requirement = []
        if self.has_e:
            label_requirement.append(
                DataRequirementItem(
                    "energy",
                    ndof=1,
                    atomic=False,
                    must=False,
                    high_prec=True,
                )
            )
        if self.has_f:
            label_requirement.append(
                DataRequirementItem(
                    "force",
                    ndof=3,
                    atomic=True,
                    must=False,
                    high_prec=False,
                )
            )
        if self.has_v:
            label_requirement.append(
                DataRequirementItem(
                    "virial",
                    ndof=9,
                    atomic=False,
                    must=False,
                    high_prec=False,
                )
            )
        if self.has_ae:
            label_requirement.append(
                DataRequirementItem(
                    "atom_ener",
                    ndof=1,
                    atomic=True,
                    must=False,
                    high_prec=False,
                )
            )
        if self.has_pf:
            label_requirement.append(
                DataRequirementItem(
                    "atom_pref",
                    ndof=1,
                    atomic=True,
                    must=False,
                    high_prec=False,
                    repeat=3,
                )
            )
        if self.has_gf:
            label_requirement.append(
                DataRequirementItem(
                    "drdq",
                    ndof=self.numb_generalized_coord * 3,
                    atomic=True,
                    must=False,
                    high_prec=False,
                )
            )
        if self.enable_atom_ener_coeff:
            label_requirement.append(
                DataRequirementItem(
                    "atom_ener_coeff",
                    ndof=1,
                    atomic=True,
                    must=False,
                    high_prec=False,
                    default=1.0,
                )
            )
        return label_requirement

    def serialize(self) -> dict:
        """Serialize the loss module.

        Returns
        -------
        dict
            The serialized loss module
        """
        return {
            "@class": "EnergyLoss",
            "@version": 1,
            "starter_learning_rate": self.starter_learning_rate,
            "start_pref_e": self.start_pref_e,
            "limit_pref_e": self.limit_pref_e,
            "start_pref_f": self.start_pref_f,
            "limit_pref_f": self.limit_pref_f,
            "start_pref_v": self.start_pref_v,
            "limit_pref_v": self.limit_pref_v,
            "start_pref_ae": self.start_pref_ae,
            "limit_pref_ae": self.limit_pref_ae,
            "start_pref_pf": self.start_pref_pf,
            "limit_pref_pf": self.limit_pref_pf,
            "relative_f": self.relative_f,
            "enable_atom_ener_coeff": self.enable_atom_ener_coeff,
            "start_pref_gf": self.start_pref_gf,
            "limit_pref_gf": self.limit_pref_gf,
            "numb_generalized_coord": self.numb_generalized_coord,
        }

    @classmethod
    def deserialize(cls, data: dict) -> "Loss":
        """Deserialize the loss module.

        Parameters
        ----------
        data : dict
            The serialized loss module

        Returns
        -------
        Loss
            The deserialized loss module
        """
        # Work on a copy so the caller's dict keeps its metadata keys.
        data = data.copy()
        check_version_compatibility(data.pop("@version"), 1, 1)
        data.pop("@class")
        return cls(**data)
0 commit comments