Skip to content

Commit e6a1fdf

Browse files
DevPranjalsdesrozisvfdev-5
authored
[skip ci] Doctests for MeanAbsoluteError and MeanSquaredError (#2280)
* Add doctests for MeanAbsoluteError and MeanSquaredError * Change test output to lower precision * Update doctests to remove random values Co-authored-by: Sylvain Desroziers <sylvain.desroziers@gmail.com> Co-authored-by: vfdev <vfdev.5@gmail.com>
1 parent 4287069 commit e6a1fdf

File tree

2 files changed

+60
-0
lines changed

2 files changed

+60
-0
lines changed

ignite/metrics/mean_absolute_error.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,36 @@ class MeanAbsoluteError(Metric):
2626
device: specifies which device updates are accumulated on. Setting the
2727
metric's device to be the same as your ``update`` arguments ensures the ``update`` method is
2828
non-blocking. By default, CPU.
29+
30+
Examples:
31+
To use with ``Engine`` and ``process_function``, simply attach the metric instance to the engine.
32+
The output of the engine's ``process_function`` needs to be in the format of
33+
``(y_pred, y)`` or ``{'y_pred': y_pred, 'y': y, ...}``. If not, ``output_tranform`` can be added
34+
to the metric to transform the output into the form expected by the metric.
35+
36+
``y_pred`` and ``y`` should have the same shape.
37+
38+
.. testcode::
39+
40+
def process_function(engine, batch):
41+
y_pred, y = batch
42+
return y_pred, y
43+
engine = Engine(process_function)
44+
metric = MeanAbsoluteError()
45+
metric.attach(engine, 'mae')
46+
preds = torch.Tensor([
47+
[1, 2, 4, 1],
48+
[2, 3, 1, 5],
49+
[1, 3, 5, 1],
50+
[1, 5, 1 ,11]
51+
])
52+
target = preds * 0.75
53+
state = engine.run([[preds, target]])
54+
print(state.metrics['mae'])
55+
56+
.. testoutput::
57+
58+
2.9375
2959
"""
3060

3161
@reinit__is_reduced

ignite/metrics/mean_squared_error.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,36 @@ class MeanSquaredError(Metric):
2626
device: specifies which device updates are accumulated on. Setting the
2727
metric's device to be the same as your ``update`` arguments ensures the ``update`` method is
2828
non-blocking. By default, CPU.
29+
30+
Examples:
31+
To use with ``Engine`` and ``process_function``, simply attach the metric instance to the engine.
32+
The output of the engine's ``process_function`` needs to be in the format of
33+
``(y_pred, y)`` or ``{'y_pred': y_pred, 'y': y, ...}``. If not, ``output_tranform`` can be added
34+
to the metric to transform the output into the form expected by the metric.
35+
36+
``y_pred`` and ``y`` should have the same shape.
37+
38+
.. testcode::
39+
40+
def process_function(engine, batch):
41+
y_pred, y = batch
42+
return y_pred, y
43+
engine = Engine(process_function)
44+
metric = MeanSquaredError()
45+
metric.attach(engine, 'mse')
46+
preds = torch.Tensor([
47+
[1, 2, 4, 1],
48+
[2, 3, 1, 5],
49+
[1, 3, 5, 1],
50+
[1, 5, 1 ,11]
51+
])
52+
target = preds * 0.75
53+
state = engine.run([[preds, target]])
54+
print(state.metrics['mse'])
55+
56+
.. testoutput::
57+
58+
3.828125
2959
"""
3060

3161
@reinit__is_reduced

0 commit comments

Comments
 (0)