Skip to content

Commit 665f836

Browse files
committed
Convert from/to
1 parent 6a90766 commit 665f836

File tree

5 files changed

+120
-47
lines changed

5 files changed

+120
-47
lines changed

docs/api/toc.json

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -47,13 +47,13 @@
4747
"api": {
4848
"SpinParam": "",
4949
"ValGrad": "",
50-
"length_to": "",
51-
"time_to": "",
52-
"freq_to": "",
53-
"ir_ints_to": "",
54-
"raman_ints_to": "",
55-
"edipole_to": "",
56-
"equadrupole_to": ""
50+
"convert_length": "",
51+
"convert_time": "",
52+
"convert_freq": "",
53+
"convert_ir_ints": "",
54+
"convert_raman_ints": "",
55+
"convert_edipole": "",
56+
"convert_equadrupole": ""
5757
}
5858
}
5959
}

dqc/api/properties.py

Lines changed: 12 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,9 @@
88
from dqc.qccalc.base_qccalc import BaseQCCalc
99
from dqc.utils.misc import memoize_method
1010
from dqc.utils.datastruct import SpinParam
11-
from dqc.utils.units import length_to, freq_to, edipole_to, equadrupole_to, ir_ints_to, \
12-
raman_ints_to
11+
from dqc.utils.units import convert_length, convert_freq, convert_edipole, \
12+
convert_equadrupole, convert_ir_ints, \
13+
convert_raman_ints
1314

1415
__all__ = ["hessian_pos", "vibration", "edipole", "equadrupole", "is_orb_min",
1516
"lowest_eival_orb_hessian", "ir_spectrum", "raman_spectrum"]
@@ -35,7 +36,7 @@ def hessian_pos(qc: BaseQCCalc, unit: Optional[str] = None) -> torch.Tensor:
3536
of the energy with respect to the atomic position
3637
"""
3738
hess = _hessian_pos(qc)
38-
hess = length_to(hess, unit)
39+
hess = convert_freq(hess, to_unit=unit)
3940
return hess
4041

4142
def vibration(qc: BaseQCCalc, freq_unit: Optional[str] = "cm^-1",
@@ -64,8 +65,8 @@ def vibration(qc: BaseQCCalc, freq_unit: Optional[str] = "cm^-1",
6465
to each axis sorted from the largest frequency to smallest frequency.
6566
"""
6667
freq, mode = _vibration(qc)
67-
freq = freq_to(freq, freq_unit)
68-
mode = length_to(mode, length_unit)
68+
freq = convert_freq(freq, to_unit=freq_unit)
69+
mode = convert_length(mode, to_unit=length_unit)
6970
return freq, mode
7071

7172
def ir_spectrum(qc: BaseQCCalc, freq_unit: Optional[str] = "cm^-1",
@@ -93,8 +94,8 @@ def ir_spectrum(qc: BaseQCCalc, freq_unit: Optional[str] = "cm^-1",
9394
tensor is the IR intensity with the same order as the frequency.
9495
"""
9596
freq, ir_ints = _ir_spectrum(qc)
96-
freq = freq_to(freq, freq_unit)
97-
ir_ints = ir_ints_to(ir_ints, ints_unit)
97+
freq = convert_freq(freq, to_unit=freq_unit)
98+
ir_ints = convert_ir_ints(ir_ints, to_unit=ints_unit)
9899
return freq, ir_ints
99100

100101
def raman_spectrum(qc: BaseQCCalc, freq_unit: Optional[str] = "cm^-1",
@@ -121,8 +122,8 @@ def raman_spectrum(qc: BaseQCCalc, freq_unit: Optional[str] = "cm^-1",
121122
tensor is the IR intensity with the same order as the frequency.
122123
"""
123124
freq, raman_ints = _raman_spectrum(qc)
124-
freq = freq_to(freq, freq_unit)
125-
raman_ints = raman_ints_to(raman_ints, ints_unit)
125+
freq = convert_freq(freq, to_unit=freq_unit)
126+
raman_ints = convert_raman_ints(raman_ints, to_unit=ints_unit)
126127
return freq, raman_ints
127128

128129
def edipole(qc: BaseQCCalc, unit: Optional[str] = "Debye") -> torch.Tensor:
@@ -144,7 +145,7 @@ def edipole(qc: BaseQCCalc, unit: Optional[str] = "Debye") -> torch.Tensor:
144145
Tensor representing the dipole moment in atomic unit with shape ``(ndim,)``
145146
"""
146147
edip = _edipole(qc)
147-
edip = edipole_to(edip, unit)
148+
edip = convert_edipole(edip, to_unit=unit)
148149
return edip
149150

150151
def equadrupole(qc: BaseQCCalc, unit: Optional[str] = "Debye*Angst") -> torch.Tensor:
@@ -165,7 +166,7 @@ def equadrupole(qc: BaseQCCalc, unit: Optional[str] = "Debye*Angst") -> torch.Te
165166
Tensor representing the quadrupole moment in atomic unit in ``(ndim, ndim)``
166167
"""
167168
equad = _equadrupole(qc)
168-
equad = equadrupole_to(equad, unit)
169+
equad = convert_equadrupole(equad, to_unit=unit)
169170
return equad
170171

171172
@memoize_method

dqc/test/test_utils.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,25 @@
1+
import torch
2+
import dqc.utils
13
from dqc.utils.config import config
24
from dqc.utils.misc import logger
5+
from dqc.test.utils import assert_fail
6+
7+
def test_converter_length():
8+
a = torch.tensor([1.0])
9+
10+
# convert to itself
11+
assert torch.allclose(dqc.utils.convert_length(a), a)
12+
# convert from atomic unit to angstrom
13+
assert torch.allclose(dqc.utils.convert_length(a, to_unit="angst"), a * 5.29177210903e-1)
14+
# convert from angstrom to atomic unit
15+
assert torch.allclose(dqc.utils.convert_length(a, from_unit="angst"), a / 5.29177210903e-1)
16+
# convert from angstrom to angstrom
17+
assert torch.allclose(dqc.utils.convert_length(a, from_unit="angst", to_unit="angst"), a)
18+
assert torch.allclose(dqc.utils.convert_length(a, from_unit="angst", to_unit="angstrom"), a)
19+
20+
def test_converter_wrong_unit():
21+
a = torch.tensor([1.0])
22+
assert_fail(lambda: dqc.utils.convert_length(a, from_unit="adsfa"), ValueError, ["'angst'"])
323

424
def test_logger(capsys):
525
# test if logger behaves correctly

dqc/test/utils.py

Lines changed: 18 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,24 @@
11
import gc
22
import torch
3-
from typing import Callable
3+
from typing import Callable, Optional, List, Union
44

5-
__all__ = ["assert_no_memleak_tensor"]
5+
__all__ = ["assert_fail", "assert_no_memleak_tensor"]
6+
7+
def assert_fail(fcn: Callable, err: Exception = Exception, contains: Optional[Union[str, List[str]]] = None):
8+
try:
9+
fcn()
10+
except err as e:
11+
if isinstance(contains, str):
12+
assert contains in str(e), f"The error message must contain '{contains}'. Got {str(e)} instead."
13+
elif isinstance(contains, list) or isinstance(contains, tuple):
14+
for c in contains:
15+
assert c in str(e), f"The error message must contain '{c}'. Got {str(e)} instead."
16+
return
17+
except Exception as e:
18+
assert False, f"Expected {err} to be raised, got {type(e)} instead:\n{e}"
19+
return
20+
21+
assert False, f"Expected {err} to be raised"
622

723
# memory test functions
824
def assert_no_memleak_tensor(fcn: Callable, strict: bool = True, gccollect: bool = False):

dqc/utils/units.py

Lines changed: 63 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,8 @@
44
# This file contains various physical constants and functions to convert units
55
# from the atomic units
66

7-
__all__ = ["length_to", "time_to", "freq_to", "ir_ints_to", "raman_ints_to",
8-
"edipole_to", "equadrupole_to"]
7+
__all__ = ["convert_length", "convert_time", "convert_freq", "convert_ir_ints",
8+
"convert_raman_ints", "convert_edipole", "convert_equadrupole"]
99

1010
# 1 atomic unit in SI
1111
LENGTH = 5.29177210903e-11 # m
@@ -81,61 +81,97 @@
8181

8282
def _avail_keys(converter: Dict[str, float]) -> str:
8383
# returns the available keys in a string of list of string
84-
return str(list(_length_converter.keys()))
84+
return str(list(converter.keys()))
8585

8686
def _add_docstr_to(phys: str, converter: Dict[str, float]) -> Callable:
8787
# automatically add docstring for converter functions
8888

8989
def decorator(callable: Callable):
9090
callable.__doc__ = f"""
91-
Convert the {phys} from atomic unit to the given unit.
92-
Available units are (case-insensitive): {_avail_keys(converter)}
91+
Convert the {phys} from a unit to another unit.
92+
Available units are (case-insensitive): ``{_avail_keys(converter)}``
93+
94+
Arguments
95+
---------
96+
a: torch.Tensor
97+
The tensor to be converter.
98+
from_unit: str or None
99+
The unit of ``a``. If ``None``, it is assumed to be in atomic unit.
100+
to_unit: str or None
101+
The unit for ``a`` to be converted to. If ``None``, it is assumed
102+
to be converted to the atomic unit.
103+
104+
Returns
105+
-------
106+
torch.Tensor
107+
The tensor in the new unit.
93108
"""
94109
return callable
95110
return decorator
96111

97112
@_add_docstr_to("time", _time_converter)
98-
def time_to(a: PhysVarType, unit: UnitType) -> PhysVarType:
113+
def convert_time(a: PhysVarType, from_unit: UnitType = None,
114+
to_unit: UnitType = None) -> PhysVarType:
99115
# convert unit time from atomic unit to the given unit
100-
return _converter_to(a, unit, _time_converter)
116+
return _converter(a, from_unit, to_unit, _time_converter)
101117

102118
@_add_docstr_to("frequency", _freq_converter)
103-
def freq_to(a: PhysVarType, unit: UnitType) -> PhysVarType:
119+
def convert_freq(a: PhysVarType, from_unit: UnitType = None,
120+
to_unit: UnitType = None) -> PhysVarType:
104121
# convert unit frequency from atomic unit to the given unit
105-
return _converter_to(a, unit, _freq_converter)
122+
return _converter(a, from_unit, to_unit, _freq_converter)
106123

107124
@_add_docstr_to("IR intensity", _ir_ints_converter)
108-
def ir_ints_to(a: PhysVarType, unit: UnitType) -> PhysVarType:
125+
def convert_ir_ints(a: PhysVarType, from_unit: UnitType = None,
126+
to_unit: UnitType = None) -> PhysVarType:
109127
# convert unit IR intensity from atomic unit to the given unit
110-
return _converter_to(a, unit, _ir_ints_converter)
128+
return _converter(a, from_unit, to_unit, _ir_ints_converter)
111129

112130
@_add_docstr_to("Raman intensity", _raman_ints_converter)
113-
def raman_ints_to(a: PhysVarType, unit: UnitType) -> PhysVarType:
131+
def convert_raman_ints(a: PhysVarType, from_unit: UnitType = None,
132+
to_unit: UnitType = None) -> PhysVarType:
114133
# convert unit IR intensity from atomic unit to the given unit
115-
return _converter_to(a, unit, _raman_ints_converter)
134+
return _converter(a, from_unit, to_unit, _raman_ints_converter)
116135

117136
@_add_docstr_to("length", _length_converter)
118-
def length_to(a: PhysVarType, unit: UnitType) -> PhysVarType:
137+
def convert_length(a: PhysVarType, from_unit: UnitType = None,
138+
to_unit: UnitType = None) -> PhysVarType:
119139
# convert unit length from atomic unit to the given unit
120-
return _converter_to(a, unit, _length_converter)
140+
return _converter(a, from_unit, to_unit, _length_converter)
121141

122142
@_add_docstr_to("electric dipole", _edipole_converter)
123-
def edipole_to(a: PhysVarType, unit: UnitType) -> PhysVarType:
143+
def convert_edipole(a: PhysVarType, from_unit: UnitType = None,
144+
to_unit: UnitType = None) -> PhysVarType:
124145
# convert unit electric dipole from atomic unit to the given unit
125-
return _converter_to(a, unit, _edipole_converter)
146+
return _converter(a, from_unit, to_unit, _edipole_converter)
126147

127148
@_add_docstr_to("electric quadrupole", _equadrupole_converter)
128-
def equadrupole_to(a: PhysVarType, unit: UnitType) -> PhysVarType:
149+
def convert_equadrupole(a: PhysVarType, from_unit: UnitType = None,
150+
to_unit: UnitType = None) -> PhysVarType:
129151
# convert unit electric dipole from atomic unit to the given unit
130-
return _converter_to(a, unit, _equadrupole_converter)
131-
132-
def _converter_to(a: PhysVarType, unit: UnitType, converter: Dict[str, float]) -> PhysVarType:
133-
# converter from the atomic unit
134-
if unit is None:
152+
return _converter(a, from_unit, to_unit, _equadrupole_converter)
153+
154+
def _converter(a: PhysVarType, from_unit: UnitType, to_unit: UnitType,
155+
converter: Dict[str, float]) -> PhysVarType:
156+
# converter from a unit to another unit
157+
from_unit = _preproc_unit(from_unit)
158+
to_unit = _preproc_unit(to_unit)
159+
if from_unit == to_unit:
135160
return a
136-
u = unit.lower()
137-
try:
138-
return a * converter[u]
139-
except KeyError:
161+
if from_unit is not None:
162+
a = a / _get_converter_value(converter, from_unit)
163+
if to_unit is not None:
164+
a = a * _get_converter_value(converter, to_unit)
165+
return a
166+
167+
def _get_converter_value(converter: Dict[str, float], unit: UnitType) -> float:
168+
if unit not in converter:
140169
avail_units = _avail_keys(converter)
141170
raise ValueError(f"Unknown unit: {unit}. Available units are: {avail_units}")
171+
return converter[unit]
172+
173+
def _preproc_unit(unit: UnitType):
174+
if unit is None:
175+
return unit
176+
else:
177+
return ''.join(unit.lower().split())

0 commit comments

Comments
 (0)