Skip to content

Commit 6cbe5fc

Browse files
authored
Allow running pyro_sim.py from anywhere (#121)
I changed setup.py to install an autogenerated shim script that loads pyro.pyro_sim as a module, as otherwise `__file__` would point to the bin directory it gets installed to, rather than the pyro/ directory. With a minor change to test.py, this lets us run the regression tests from outside pyro/ as well.
1 parent 9210631 commit 6cbe5fc

File tree

5 files changed

+24
-16
lines changed

5 files changed

+24
-16
lines changed

.github/workflows/regtest.yml

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,5 @@ jobs:
4444
run: python setup.py install --user
4545

4646
- name: Run tests via test.py
47-
run: |
48-
cd pyro
49-
./test.py
47+
run: ./pyro/test.py
5048

pyro/pyro_sim.py

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -296,8 +296,8 @@ def compare_to_benchmark(self, rtol):
296296
""" Are we comparing to a benchmark? """
297297

298298
basename = self.rp.get_param("io.basename")
299-
compare_file = "{}/tests/{}{:04d}".format(
300-
self.solver_name, basename, self.sim.n)
299+
compare_file = "{}{}/tests/{}{:04d}".format(
300+
self.pyro_home, self.solver_name, basename, self.sim.n)
301301
msg.warning(f"comparing to: {compare_file} ")
302302
try:
303303
sim_bench = io.read(compare_file)
@@ -317,9 +317,9 @@ def compare_to_benchmark(self, rtol):
317317
def store_as_benchmark(self):
318318
""" Are we storing a benchmark? """
319319

320-
if not os.path.isdir(self.solver_name + "/tests/"):
320+
if not os.path.isdir(self.pyro_home + self.solver_name + "/tests/"):
321321
try:
322-
os.mkdir(self.solver_name + "/tests/")
322+
os.mkdir(self.pyro_home + self.solver_name + "/tests/")
323323
except (FileNotFoundError, PermissionError):
324324
msg.fail(
325325
"ERROR: unable to create the solver's tests/ directory")
@@ -357,7 +357,7 @@ def parse_args():
357357
return p.parse_args()
358358

359359

360-
if __name__ == "__main__":
360+
def main():
361361
args = parse_args()
362362

363363
if args.compare_benchmark or args.make_benchmark:
@@ -371,3 +371,7 @@ def parse_args():
371371
inputs_file=args.param[0],
372372
other_commands=args.other)
373373
pyro.run_sim()
374+
375+
376+
if __name__ == "__main__":
377+
main()

pyro/test.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -70,23 +70,24 @@ def do_tests(out_file,
7070

7171
# standalone tests
7272
if single is None:
73-
err = mg_test_simple.test_poisson_dirichlet(256, comp_bench=True, bench_dir="multigrid/tests/",
73+
bench_dir = os.path.dirname(os.path.realpath(__file__)) + "/multigrid/tests/"
74+
err = mg_test_simple.test_poisson_dirichlet(256, comp_bench=True, bench_dir=bench_dir,
7475
store_bench=store_all_benchmarks, verbose=0)
7576
results["mg_poisson_dirichlet"] = err
7677

7778
err = mg_test_vc_dirichlet.test_vc_poisson_dirichlet(512,
78-
comp_bench=True, bench_dir="multigrid/tests/",
79+
comp_bench=True, bench_dir=bench_dir,
7980
store_bench=store_all_benchmarks, verbose=0)
8081
results["mg_vc_poisson_dirichlet"] = err
8182

82-
err = mg_test_vc_periodic.test_vc_poisson_periodic(512, comp_bench=True, bench_dir="multigrid/tests/",
83+
err = mg_test_vc_periodic.test_vc_poisson_periodic(512, comp_bench=True, bench_dir=bench_dir,
8384
store_bench=store_all_benchmarks,
8485
verbose=0)
8586
results["mg_vc_poisson_periodic"] = err
8687

8788
err = mg_test_general_inhomogeneous.test_general_poisson_inhomogeneous(512,
8889
comp_bench=True,
89-
bench_dir="multigrid/tests/",
90+
bench_dir=bench_dir,
9091
store_bench=store_all_benchmarks,
9192
verbose=0)
9293
results["mg_general_poisson_inhomogeneous"] = err

pyro/util/runparams.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,7 @@
4747
import os
4848
import re
4949
import textwrap
50+
from pathlib import Path
5051

5152
from pyro.util import msg
5253

@@ -118,7 +119,7 @@ def load_params(self, pfile, no_new=0):
118119

119120
# check to see whether the file exists
120121
if not os.path.isfile(pfile):
121-
pfile = "{}/{}".format(os.environ["PYRO_HOME"], pfile)
122+
pfile = str(Path(__file__).resolve().parents[1] / pfile)
122123

123124
try:
124125
f = open(pfile)

setup.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
1-
from pathlib import Path
21
import glob
2+
from pathlib import Path
33

4-
from setuptools import setup, find_packages
4+
from setuptools import find_packages, setup
55

66
# find all of the "_default" files
77
defaults = []
@@ -40,7 +40,11 @@
4040
url='https://github.com/python-hydro/pyro2',
4141
license='BSD',
4242
packages=find_packages(),
43-
scripts=["pyro/pyro_sim.py"],
43+
entry_points={
44+
"console_scripts": [
45+
"pyro_sim.py = pyro.pyro_sim:main",
46+
]
47+
},
4448
package_data={"pyro": benchmarks + defaults + inputs},
4549
install_requires=['numpy', 'numba', 'matplotlib', 'h5py'],
4650
use_scm_version={"version_scheme": "post-release",

0 commit comments

Comments
 (0)