|
| 1 | +# SPDX-License-Identifier: LGPL-3.0-or-later |
| 2 | + |
| 3 | + |
| 4 | +class RepFlowArgs: |
| 5 | + def __init__( |
| 6 | + self, |
| 7 | + n_dim: int = 128, |
| 8 | + e_dim: int = 64, |
| 9 | + a_dim: int = 64, |
| 10 | + nlayers: int = 6, |
| 11 | + e_rcut: float = 6.0, |
| 12 | + e_rcut_smth: float = 5.0, |
| 13 | + e_sel: int = 120, |
| 14 | + a_rcut: float = 4.0, |
| 15 | + a_rcut_smth: float = 3.5, |
| 16 | + a_sel: int = 20, |
| 17 | + axis_neuron: int = 4, |
| 18 | + node_has_conv: bool = False, |
| 19 | + update_angle: bool = True, |
| 20 | + update_style: str = "res_residual", |
| 21 | + update_residual: float = 0.1, |
| 22 | + update_residual_init: str = "const", |
| 23 | + ) -> None: |
| 24 | + r"""The constructor for the RepFlowArgs class which defines the parameters of the repflow block in DPA3 descriptor. |
| 25 | +
|
| 26 | + Parameters |
| 27 | + ---------- |
| 28 | + n_dim : int, optional |
| 29 | + The dimension of node representation. |
| 30 | + e_dim : int, optional |
| 31 | + The dimension of edge representation. |
| 32 | + a_dim : int, optional |
| 33 | + The dimension of angle representation. |
| 34 | + nlayers : int, optional |
| 35 | + Number of repflow layers. |
| 36 | + e_rcut : float, optional |
| 37 | + The edge cut-off radius. |
| 38 | + e_rcut_smth : float, optional |
| 39 | + Where to start smoothing for edge. For example the 1/r term is smoothed from rcut to rcut_smth. |
| 40 | + e_sel : int, optional |
| 41 | + Maximally possible number of selected edge neighbors. |
| 42 | + a_rcut : float, optional |
| 43 | + The angle cut-off radius. |
| 44 | + a_rcut_smth : float, optional |
| 45 | + Where to start smoothing for angle. For example the 1/r term is smoothed from rcut to rcut_smth. |
| 46 | + a_sel : int, optional |
| 47 | + Maximally possible number of selected angle neighbors. |
| 48 | + axis_neuron : int, optional |
| 49 | + The number of dimension of submatrix in the symmetrization ops. |
| 50 | + update_angle : bool, optional |
| 51 | + Where to update the angle rep. If not, only node and edge rep will be used. |
| 52 | + update_style : str, optional |
| 53 | + Style to update a representation. |
| 54 | + Supported options are: |
| 55 | + -'res_avg': Updates a rep `u` with: u = 1/\\sqrt{n+1} (u + u_1 + u_2 + ... + u_n) |
| 56 | + -'res_incr': Updates a rep `u` with: u = u + 1/\\sqrt{n} (u_1 + u_2 + ... + u_n) |
| 57 | + -'res_residual': Updates a rep `u` with: u = u + (r1*u_1 + r2*u_2 + ... + r3*u_n) |
| 58 | + where `r1`, `r2` ... `r3` are residual weights defined by `update_residual` |
| 59 | + and `update_residual_init`. |
| 60 | + update_residual : float, optional |
| 61 | + When update using residual mode, the initial std of residual vector weights. |
| 62 | + update_residual_init : str, optional |
| 63 | + When update using residual mode, the initialization mode of residual vector weights. |
| 64 | + """ |
| 65 | + self.n_dim = n_dim |
| 66 | + self.e_dim = e_dim |
| 67 | + self.a_dim = a_dim |
| 68 | + self.nlayers = nlayers |
| 69 | + self.e_rcut = e_rcut |
| 70 | + self.e_rcut_smth = e_rcut_smth |
| 71 | + self.e_sel = e_sel |
| 72 | + self.a_rcut = a_rcut |
| 73 | + self.a_rcut_smth = a_rcut_smth |
| 74 | + self.a_sel = a_sel |
| 75 | + self.axis_neuron = axis_neuron |
| 76 | + self.node_has_conv = node_has_conv # tmp |
| 77 | + self.update_angle = update_angle |
| 78 | + self.update_style = update_style |
| 79 | + self.update_residual = update_residual |
| 80 | + self.update_residual_init = update_residual_init |
| 81 | + |
| 82 | + def __getitem__(self, key): |
| 83 | + if hasattr(self, key): |
| 84 | + return getattr(self, key) |
| 85 | + else: |
| 86 | + raise KeyError(key) |
| 87 | + |
| 88 | + def serialize(self) -> dict: |
| 89 | + return { |
| 90 | + "n_dim": self.n_dim, |
| 91 | + "e_dim": self.e_dim, |
| 92 | + "a_dim": self.a_dim, |
| 93 | + "nlayers": self.nlayers, |
| 94 | + "e_rcut": self.e_rcut, |
| 95 | + "e_rcut_smth": self.e_rcut_smth, |
| 96 | + "e_sel": self.e_sel, |
| 97 | + "a_rcut": self.a_rcut, |
| 98 | + "a_rcut_smth": self.a_rcut_smth, |
| 99 | + "a_sel": self.a_sel, |
| 100 | + "axis_neuron": self.axis_neuron, |
| 101 | + "node_has_conv": self.node_has_conv, # tmp |
| 102 | + "update_angle": self.update_angle, |
| 103 | + "update_style": self.update_style, |
| 104 | + "update_residual": self.update_residual, |
| 105 | + "update_residual_init": self.update_residual_init, |
| 106 | + } |
| 107 | + |
| 108 | + @classmethod |
| 109 | + def deserialize(cls, data: dict) -> "RepFlowArgs": |
| 110 | + return cls(**data) |
0 commit comments