Description
Right now, the test framework (#1841) only checks whether the metadata is a dict. We should add metadata tests that verify the datamodule's metadata is computed correctly and passed correctly to the dataloaders and to the model for initialisation.
The blocker is that different datamodules can have different metadata configurations, so these tests cannot be added directly to the generic integration test.
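To make the blocker concrete, here is a minimal sketch contrasting the current generic check with a datamodule-aware one. The function names and the key sets in EXPECTED_KEYS are hypothetical placeholders, not the real metadata layout of EncoderDecoderDataModule or TSLibDataModule.

```python
# Hypothetical expected metadata keys per datamodule (placeholders only,
# not the actual pytorch-forecasting metadata).
EXPECTED_KEYS = {
    "EncoderDecoderDataModule": {"max_encoder_length", "max_prediction_length"},
    "TSLibDataModule": {"context_length", "prediction_length"},
}


def generic_metadata_check(metadata):
    """What the framework checks today: only that metadata is a dict."""
    return isinstance(metadata, dict)


def specific_metadata_check(datamodule_name, metadata):
    """Datamodule-aware check: the required keys depend on the datamodule."""
    missing = EXPECTED_KEYS[datamodule_name] - set(metadata)
    return not missing
```

The same dict passes the generic check for every datamodule, while the specific check accepts it for one datamodule and rejects it for another, which is why a single integration test cannot encode the expected keys.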
One idea proposed by @fkiraly:
Create an intermediate model class (between _BasePtForecasterV2 and the ModelMetadata class) with a method _check_metadata that holds the metadata test specific to the datamodule the model uses. For example, if the model uses EncoderDecoderDataModule, _check_metadata tests the metadata of EncoderDecoderDataModule. Since this class is a parent of ModelMetadata, there is no duplication of code.
So the idea is:
```python
from skbase.base import BaseObject as _SkbaseBaseObject

class _BaseObject(_SkbaseBaseObject):
    pass

class _BasePtForecaster(_BaseObject):
    pass

class _BasePtForecasterV2(_BasePtForecaster):
    _tags = {
        "object_type": "forecaster_pytorch_v2",
    }

class _EncoderDecoderConfigBase(_BasePtForecasterV2):
    def _check_metadata(self):
        # metadata tests specific to EncoderDecoderDataModule
        pass

class TFTMetadata(_EncoderDecoderConfigBase):
    # testing logic
    pass
```

If the model uses a TSLibDataModule instead, the ConfigBase class changes accordingly and contains tests specific to TSLibDataModule.
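A minimal sketch of how the test framework could stay generic once there is one config base per datamodule: the test body only calls _check_metadata on the model's metadata class, and each config base decides what to verify. It assumes _check_metadata is a classmethod that receives the datamodule's metadata dict and raises on a mismatch; _ConfigBaseSketch, _TSLibConfigBase, SomeTSLibModelMetadata, run_metadata_test, and all key names are hypothetical, and in the real code the shared parent would be _BasePtForecasterV2.

```python
class _ConfigBaseSketch:
    """Hypothetical shared parent; stands in for _BasePtForecasterV2."""

    # Subclasses override this with the keys their datamodule must provide.
    _required_metadata_keys = ()

    @classmethod
    def _check_metadata(cls, metadata):
        # Check shared by every datamodule: metadata must be a dict.
        if not isinstance(metadata, dict):
            raise TypeError("metadata must be a dict")
        # Datamodule-specific check driven by the subclass attribute.
        missing = [k for k in cls._required_metadata_keys if k not in metadata]
        if missing:
            raise ValueError(f"{cls.__name__}: missing metadata keys {missing}")


class _EncoderDecoderConfigBase(_ConfigBaseSketch):
    # Hypothetical keys for EncoderDecoderDataModule metadata.
    _required_metadata_keys = ("max_encoder_length", "max_prediction_length")


class _TSLibConfigBase(_ConfigBaseSketch):
    # Hypothetical keys for TSLibDataModule metadata.
    _required_metadata_keys = ("context_length", "prediction_length")


class TFTMetadata(_EncoderDecoderConfigBase):
    pass


class SomeTSLibModelMetadata(_TSLibConfigBase):
    pass


def run_metadata_test(model_metadata_cls, datamodule_metadata):
    """Generic test body: no datamodule-specific branching is needed here."""
    model_metadata_cls._check_metadata(datamodule_metadata)
```

With this layout, adding a model for a new datamodule only means choosing (or adding) the matching config base; the integration test itself does not change.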