Skip to content

Commit a397ffa

Browse files
authored
[skip ci] Doctest for TopKCategoricalAccuracy (#2284)
* Add doctests for TopKCategoricalAccuracy * Remove extra line
1 parent 245525a commit a397ffa

File tree

1 file changed

+42
-0
lines changed

1 file changed

+42
-0
lines changed

ignite/metrics/top_k_categorical_accuracy.py

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,48 @@ class TopKCategoricalAccuracy(Metric):
2424
device: specifies which device updates are accumulated on. Setting the
2525
metric's device to be the same as your ``update`` arguments ensures the ``update`` method is
2626
non-blocking. By default, CPU.
27+
28+
Examples:
29+
To use with ``Engine`` and ``process_function``, simply attach the metric instance to the engine.
30+
The output of the engine's ``process_function`` needs to be in the format of
31+
``(y_pred, y)`` or ``{'y_pred': y_pred, 'y': y, ...}``. If not, ``output_tranform`` can be added
32+
to the metric to transform the output into the form expected by the metric.
33+
34+
.. testcode::
35+
36+
def process_function(engine, batch):
37+
y_pred, y = batch
38+
return y_pred, y
39+
40+
def one_hot_to_binary_output_transform(output):
41+
y_pred, y = output
42+
y = torch.argmax(y, dim=1) # one-hot vector to label index vector
43+
return y_pred, y
44+
45+
engine = Engine(process_function)
46+
metric = TopKCategoricalAccuracy(
47+
k=2, output_transform=one_hot_to_binary_output_transform)
48+
metric.attach(engine, 'top_k_accuracy')
49+
50+
preds = torch.Tensor([
51+
[0.7, 0.2, 0.05, 0.05], # 1 is in the top 2
52+
[0.2, 0.3, 0.4, 0.1], # 0 is not in the top 2
53+
[0.4, 0.4, 0.1, 0.1], # 0 is in the top 2
54+
[0.7, 0.05, 0.2, 0.05] # 2 is in the top 2
55+
])
56+
target = torch.Tensor([ # targets as one-hot vectors
57+
[0, 1, 0, 0],
58+
[1, 0, 0, 0],
59+
[1, 0, 0, 0],
60+
[0, 0, 1, 0]
61+
])
62+
63+
state = engine.run([[preds, target]])
64+
print(state.metrics['top_k_accuracy'])
65+
66+
.. testoutput::
67+
68+
0.75
2769
"""
2870

2971
def __init__(

0 commit comments

Comments
 (0)