@@ -24,6 +24,48 @@ class TopKCategoricalAccuracy(Metric):
2424 device: specifies which device updates are accumulated on. Setting the
2525 metric's device to be the same as your ``update`` arguments ensures the ``update`` method is
2626 non-blocking. By default, CPU.
27+
28+ Examples:
29+ To use with ``Engine`` and ``process_function``, simply attach the metric instance to the engine.
30+ The output of the engine's ``process_function`` needs to be in the format of
31+        ``(y_pred, y)`` or ``{'y_pred': y_pred, 'y': y, ...}``. If not, ``output_transform`` can be added
32+ to the metric to transform the output into the form expected by the metric.
33+
34+ .. testcode::
35+
36+ def process_function(engine, batch):
37+ y_pred, y = batch
38+ return y_pred, y
39+
40+ def one_hot_to_binary_output_transform(output):
41+ y_pred, y = output
42+ y = torch.argmax(y, dim=1) # one-hot vector to label index vector
43+ return y_pred, y
44+
45+ engine = Engine(process_function)
46+ metric = TopKCategoricalAccuracy(
47+ k=2, output_transform=one_hot_to_binary_output_transform)
48+ metric.attach(engine, 'top_k_accuracy')
49+
50+ preds = torch.Tensor([
51+ [0.7, 0.2, 0.05, 0.05], # 1 is in the top 2
52+ [0.2, 0.3, 0.4, 0.1], # 0 is not in the top 2
53+ [0.4, 0.4, 0.1, 0.1], # 0 is in the top 2
54+ [0.7, 0.05, 0.2, 0.05] # 2 is in the top 2
55+ ])
56+ target = torch.Tensor([ # targets as one-hot vectors
57+ [0, 1, 0, 0],
58+ [1, 0, 0, 0],
59+ [1, 0, 0, 0],
60+ [0, 0, 1, 0]
61+ ])
62+
63+ state = engine.run([[preds, target]])
64+ print(state.metrics['top_k_accuracy'])
65+
66+ .. testoutput::
67+
68+ 0.75
2769 """
2870
2971 def __init__ (
0 commit comments