|
6 | 6 | import ml.comet.experiment.context.ExperimentContext; |
7 | 7 | import ml.comet.experiment.exception.CometGeneralException; |
8 | 8 | import ml.comet.experiment.impl.asset.LoggedExperimentAssetImpl; |
| 9 | +import ml.comet.experiment.model.Curve; |
9 | 10 | import org.apache.commons.lang3.StringUtils; |
10 | 11 | import org.junit.jupiter.api.DisplayName; |
11 | 12 | import org.junit.jupiter.api.Tag; |
|
20 | 21 | import static ml.comet.experiment.impl.ExperimentTestFactory.WORKSPACE_NAME; |
21 | 22 | import static ml.comet.experiment.impl.ExperimentTestFactory.createApiExperiment; |
22 | 23 | import static ml.comet.experiment.impl.ExperimentTestFactory.createOnlineExperiment; |
| 24 | +import static ml.comet.experiment.impl.TestUtils.SOME_FULL_CONTEXT; |
| 25 | +import static ml.comet.experiment.impl.TestUtils.createCurve; |
| 26 | +import static ml.comet.experiment.impl.asset.AssetType.CURVE; |
23 | 27 | import static ml.comet.experiment.impl.asset.AssetType.TEXT_SAMPLE; |
24 | 28 | import static org.junit.jupiter.api.Assertions.assertEquals; |
25 | 29 | import static org.junit.jupiter.api.Assertions.assertThrows; |
@@ -102,4 +106,52 @@ public void testLogTextShort() throws Exception { |
102 | 106 | assertTrue(StringUtils.isBlank(assetContext.getContext()), "no context ID expected"); |
103 | 107 | } |
104 | 108 | } |
| 109 | + |
| 110 | + @Test |
| 111 | + public void testLogCurve() throws Exception { |
| 112 | + try (ApiExperiment apiExperiment = createApiExperiment()) { |
| 113 | + String fileName = "someCurve"; |
| 114 | + Curve curve = createCurve(fileName, 10); |
| 115 | + apiExperiment.logCurve(curve, true, SOME_FULL_CONTEXT); |
| 116 | + |
| 117 | + // check that CURVE asset was saved as expected |
| 118 | + List<LoggedExperimentAsset> assets = apiExperiment.getAssetList(CURVE.type()); |
| 119 | + assertEquals(1, assets.size(), "wrong number of assets returned"); |
| 120 | + |
| 121 | + LoggedExperimentAsset asset = assets.get(0); |
| 122 | + assertEquals(CURVE.type(), asset.getType(), "wrong asset type"); |
| 123 | + assertEquals(0, asset.getMetadata().size(), "no metadata expected"); |
| 124 | + ExperimentContext assetContext = ((LoggedExperimentAssetImpl) asset).getContext(); |
| 125 | + assertEquals(SOME_FULL_CONTEXT.getStep(), assetContext.getStep(), "wrong context step"); |
| 126 | + assertEquals(SOME_FULL_CONTEXT.getContext(), assetContext.getContext(), "wrong context ID"); |
| 127 | + } |
| 128 | + } |
| 129 | + |
| 130 | + @Test |
| 131 | + public void testLogCurveOverwrite() throws Exception { |
| 132 | + try (ApiExperiment apiExperiment = createApiExperiment()) { |
| 133 | + String fileName = "someCurve"; |
| 134 | + int pointsCount = 10; |
| 135 | + Curve curve = createCurve(fileName, pointsCount); |
| 136 | + apiExperiment.logCurve(curve, false, SOME_FULL_CONTEXT); |
| 137 | + |
| 138 | + // check that CURVE asset was saved as expected |
| 139 | + List<LoggedExperimentAsset> assets = apiExperiment.getAssetList(CURVE.type()); |
| 140 | + assertEquals(1, assets.size(), "wrong number of assets returned"); |
| 141 | + |
| 142 | + long size = assets.get(0).getSize().orElse((long) -1); |
| 143 | + assertTrue(size > 0, "wrong asset size"); |
| 144 | + |
| 145 | + // overwrite created curve with bigger ones |
| 146 | + // |
| 147 | + curve = createCurve(fileName, pointsCount * 2); |
| 148 | + apiExperiment.logCurve(curve, true, SOME_FULL_CONTEXT); |
| 149 | + |
| 150 | + assets = apiExperiment.getAssetList(CURVE.type()); |
| 151 | + assertEquals(1, assets.size(), "wrong number of assets returned"); |
| 152 | + |
| 153 | + long newSize = assets.get(0).getSize().orElse((long) -1); |
| 154 | + assertTrue(newSize > size); |
| 155 | + } |
| 156 | + } |
105 | 157 | } |
0 commit comments