Skip to content

Commit 8badc05

Browse files
committed
simplification to PyMC.print_coefficients
1 parent b6f5ca8 commit 8badc05

File tree

3 files changed

+38
-37
lines changed

3 files changed

+38
-37
lines changed

causalpy/pymc_models.py

Lines changed: 22 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -237,6 +237,23 @@ def print_row(
237237
formatted_val = f"{round_num(coeff_samples.mean().data, round_to)}, 94% HDI [{round_num(coeff_samples.quantile(0.03).data, round_to)}, {round_num(coeff_samples.quantile(1 - 0.03).data, round_to)}]" # noqa: E501
238238
print(f" {formatted_name} {formatted_val}")
239239

240+
def print_coefficients_for_unit(
241+
unit_coeffs: xr.DataArray,
242+
unit_sigma: xr.DataArray,
243+
labels: list,
244+
round_to: int,
245+
) -> None:
246+
"""Print coefficients for a single unit"""
247+
# Determine the width of the longest label
248+
max_label_length = max(len(name) for name in labels + ["sigma"])
249+
250+
for name in labels:
251+
coeff_samples = unit_coeffs.sel(coeffs=name)
252+
print_row(max_label_length, name, coeff_samples, round_to)
253+
254+
# Add coefficient for measurement std
255+
print_row(max_label_length, "sigma", unit_sigma, round_to)
256+
240257
print("Model coefficients:")
241258
coeffs = az.extract(self.idata.posterior, var_names="beta")
242259

@@ -247,32 +264,16 @@ def print_row(
247264
for unit in treated_units:
248265
print(f"\nTreated unit: {unit}")
249266
unit_coeffs = coeffs.sel(treated_units=unit)
250-
251-
# Determine the width of the longest label
252-
max_label_length = max(len(name) for name in labels + ["sigma"])
253-
254-
for name in labels:
255-
coeff_samples = unit_coeffs.sel(coeffs=name)
256-
print_row(max_label_length, name, coeff_samples, round_to or 2)
257-
258-
# Add coefficient for measurement std for this unit
259267
unit_sigma = az.extract(self.idata.posterior, var_names="sigma").sel(
260268
treated_units=unit
261269
)
262-
print_row(max_label_length, "sigma", unit_sigma, round_to or 2)
270+
print_coefficients_for_unit(
271+
unit_coeffs, unit_sigma, labels, round_to or 2
272+
)
263273
else:
264274
# Single treated unit case (backward compatibility)
265-
# Determine the width of the longest label
266-
max_label_length = max(len(name) for name in labels + ["sigma"])
267-
268-
for name in labels:
269-
coeff_samples = coeffs.sel(coeffs=name)
270-
print_row(max_label_length, name, coeff_samples, round_to or 2)
271-
272-
# Add coefficient for measurement std
273-
coeff_samples = az.extract(self.idata.posterior, var_names="sigma")
274-
name = "sigma"
275-
print_row(max_label_length, name, coeff_samples, round_to or 2)
275+
unit_sigma = az.extract(self.idata.posterior, var_names="sigma")
276+
print_coefficients_for_unit(coeffs, unit_sigma, labels, round_to or 2)
276277

277278

278279
class LinearRegression(PyMCModel):

docs/source/_static/interrogate_badge.svg

Lines changed: 3 additions & 3 deletions
Loading

docs/source/notebooks/multi_cell_geolift.ipynb

Lines changed: 13 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -626,7 +626,7 @@
626626
{
627627
"data": {
628628
"application/vnd.jupyter.widget-view+json": {
629-
"model_id": "23d7ea6ad64243c5ab00ac8b1ad0d19b",
629+
"model_id": "d7dc0c0854004ee8b87dbe88f4484e83",
630630
"version_major": 2,
631631
"version_minor": 0
632632
},
@@ -848,7 +848,7 @@
848848
{
849849
"data": {
850850
"application/vnd.jupyter.widget-view+json": {
851-
"model_id": "31dce179e1a047ddbfa198b8f82d97d4",
851+
"model_id": "8d6a96eecd9d4e4aba5894120d647af4",
852852
"version_major": 2,
853853
"version_minor": 0
854854
},
@@ -873,7 +873,7 @@
873873
"name": "stderr",
874874
"output_type": "stream",
875875
"text": [
876-
"Sampling 4 chains for 1_000 tune and 1_000 draw iterations (4_000 + 4_000 draws total) took 36 seconds.\n",
876+
"Sampling 4 chains for 1_000 tune and 1_000 draw iterations (4_000 + 4_000 draws total) took 35 seconds.\n",
877877
"Sampling: [beta, sigma, y_hat]\n",
878878
"Sampling: [y_hat]\n",
879879
"Sampling: [y_hat]\n",
@@ -941,7 +941,7 @@
941941
"<text text-anchor=\"middle\" x=\"131\" y=\"-404.74\" font-family=\"Times,serif\" font-size=\"14.00\">Data</text>\n",
942942
"</g>\n",
943943
"<!-- mu -->\n",
944-
"<g id=\"node4\" class=\"node\">\n",
944+
"<g id=\"node2\" class=\"node\">\n",
945945
"<title>mu</title>\n",
946946
"<polygon fill=\"none\" stroke=\"black\" points=\"193.12,-325.23 102.88,-325.23 102.88,-267.73 193.12,-267.73 193.12,-325.23\"/>\n",
947947
"<text text-anchor=\"middle\" x=\"148\" y=\"-307.93\" font-family=\"Times,serif\" font-size=\"14.00\">mu</text>\n",
@@ -955,15 +955,21 @@
955955
"<polygon fill=\"black\" stroke=\"black\" points=\"146.23,-337.34 144.08,-326.96 139.29,-336.41 146.23,-337.34\"/>\n",
956956
"</g>\n",
957957
"<!-- y_hat -->\n",
958-
"<g id=\"node2\" class=\"node\">\n",
958+
"<g id=\"node3\" class=\"node\">\n",
959959
"<title>y_hat</title>\n",
960960
"<ellipse fill=\"lightgrey\" stroke=\"black\" cx=\"150\" cy=\"-174.66\" rx=\"41.01\" ry=\"40.66\"/>\n",
961961
"<text text-anchor=\"middle\" x=\"150\" y=\"-186.11\" font-family=\"Times,serif\" font-size=\"14.00\">y_hat</text>\n",
962962
"<text text-anchor=\"middle\" x=\"150\" y=\"-169.61\" font-family=\"Times,serif\" font-size=\"14.00\">~</text>\n",
963963
"<text text-anchor=\"middle\" x=\"150\" y=\"-153.11\" font-family=\"Times,serif\" font-size=\"14.00\">Normal</text>\n",
964964
"</g>\n",
965+
"<!-- mu&#45;&gt;y_hat -->\n",
966+
"<g id=\"edge5\" class=\"edge\">\n",
967+
"<title>mu&#45;&gt;y_hat</title>\n",
968+
"<path fill=\"none\" stroke=\"black\" d=\"M148.47,-267.38C148.67,-255.28 148.91,-240.73 149.15,-226.83\"/>\n",
969+
"<polygon fill=\"black\" stroke=\"black\" points=\"152.64,-226.94 149.31,-216.88 145.65,-226.82 152.64,-226.94\"/>\n",
970+
"</g>\n",
965971
"<!-- y -->\n",
966-
"<g id=\"node3\" class=\"node\">\n",
972+
"<g id=\"node4\" class=\"node\">\n",
967973
"<title>y</title>\n",
968974
"<path fill=\"lightgrey\" stroke=\"black\" d=\"M165,-98C165,-98 135,-98 135,-98 129,-98 123,-92 123,-86 123,-86 123,-52.5 123,-52.5 123,-46.5 129,-40.5 135,-40.5 135,-40.5 165,-40.5 165,-40.5 171,-40.5 177,-46.5 177,-52.5 177,-52.5 177,-86 177,-86 177,-92 171,-98 165,-98\"/>\n",
969975
"<text text-anchor=\"middle\" x=\"150\" y=\"-80.7\" font-family=\"Times,serif\" font-size=\"14.00\">y</text>\n",
@@ -976,12 +982,6 @@
976982
"<path fill=\"none\" stroke=\"black\" d=\"M150,-133.68C150,-125.83 150,-117.6 150,-109.76\"/>\n",
977983
"<polygon fill=\"black\" stroke=\"black\" points=\"153.5,-109.79 150,-99.79 146.5,-109.79 153.5,-109.79\"/>\n",
978984
"</g>\n",
979-
"<!-- mu&#45;&gt;y_hat -->\n",
980-
"<g id=\"edge5\" class=\"edge\">\n",
981-
"<title>mu&#45;&gt;y_hat</title>\n",
982-
"<path fill=\"none\" stroke=\"black\" d=\"M148.47,-267.38C148.67,-255.28 148.91,-240.73 149.15,-226.83\"/>\n",
983-
"<polygon fill=\"black\" stroke=\"black\" points=\"152.64,-226.94 149.31,-216.88 145.65,-226.82 152.64,-226.94\"/>\n",
984-
"</g>\n",
985985
"<!-- beta -->\n",
986986
"<g id=\"node5\" class=\"node\">\n",
987987
"<title>beta</title>\n",
@@ -1014,7 +1014,7 @@
10141014
"</svg>\n"
10151015
],
10161016
"text/plain": [
1017-
"<graphviz.graphs.Digraph at 0x1273e0ec0>"
1017+
"<graphviz.graphs.Digraph at 0x106560ec0>"
10181018
]
10191019
},
10201020
"execution_count": 15,

0 commit comments

Comments
 (0)