diff --git a/docs/api/datasets.rst b/docs/api/datasets.rst index 3412e5ac5..5c1305cc1 100644 --- a/docs/api/datasets.rst +++ b/docs/api/datasets.rst @@ -46,6 +46,7 @@ Available Datasets datasets/pyhealth.datasets.BMDHSDataset datasets/pyhealth.datasets.COVID19CXRDataset datasets/pyhealth.datasets.ChestXray14Dataset + datasets/pyhealth.datasets.HAM10000Dataset datasets/pyhealth.datasets.TUABDataset datasets/pyhealth.datasets.TUEVDataset datasets/pyhealth.datasets.ClinVarDataset diff --git a/docs/api/datasets/pyhealth.datasets.HAM10000Dataset.rst b/docs/api/datasets/pyhealth.datasets.HAM10000Dataset.rst new file mode 100644 index 000000000..a20d320c7 --- /dev/null +++ b/docs/api/datasets/pyhealth.datasets.HAM10000Dataset.rst @@ -0,0 +1,20 @@ +pyhealth.datasets.HAM10000Dataset +=================================== + +The **HAM10000** (Human Against Machine with 10000 training images) dataset is a large +collection of multi-source dermatoscopic images of pigmented skin lesions. It is a +widely used benchmark dataset for computer-aided melanoma detection and general +dermatology image classification research. + +Each sample includes an image and associated metadata describing the lesion diagnosis, +patient sex, age, and localization. + +Refer to the original dataset and documentation for more information: + +- Dataset page: https://www.kaggle.com/datasets/kmader/skin-cancer-mnist-ham10000 +- Dataset paper: *Tschandl et al., “The HAM10000 dataset: A large collection of multi-source dermatoscopic images of pigmented lesions.”* Scientific Data, 2018. + +.. autoclass:: pyhealth.datasets.HAM10000Dataset + :members: + :undoc-members: + :show-inheritance: \ No newline at end of file diff --git a/pyhealth/datasets/__init__.py b/pyhealth/datasets/__init__.py index 7d6a65f16..e99eb6578 100644 --- a/pyhealth/datasets/__init__.py +++ b/pyhealth/datasets/__init__.py @@ -84,3 +84,4 @@ def __init__(self, *args, **kwargs): load_processors, save_processors, ) +from .ham10000 import HAM10000Dataset \ No newline at end of file diff --git a/pyhealth/datasets/configs/ham10000.yaml b/pyhealth/datasets/configs/ham10000.yaml new file mode 100644 index 000000000..68aa9852e --- /dev/null +++ b/pyhealth/datasets/configs/ham10000.yaml @@ -0,0 +1,13 @@ +version: "1.0" +tables: + ham10000: + file_path: "ham10000-metadata-pyhealth.csv" + patient_id: "lesion_id" + timestamp: null + attributes: + - "image" + - "path" + - "label" + - "age" + - "sex" + - "localization" \ No newline at end of file diff --git a/pyhealth/datasets/ham10000.py b/pyhealth/datasets/ham10000.py new file mode 100644 index 000000000..3b6fc6d6c --- /dev/null +++ b/pyhealth/datasets/ham10000.py @@ -0,0 +1,120 @@ +""" +PyHealth dataset class for the HAM10000 dermoscopic lesion dataset. + +Dataset link: + https://www.kaggle.com/datasets/kmader/skin-cancer-mnist-ham10000 + +Dataset paper: + Philipp Tschandl et al. "HAM10000 Dataset: A large collection + of multi-source dermatoscopic images of pigmented lesions." + Scientific Data, 2018. + +This dataset contains 10,015 dermoscopic lesion images across 7 classes: + akiec: Actinic keratoses and intraepithelial carcinoma + bcc: Basal cell carcinoma + bkl: Benign keratosis-like lesions + df: Dermatofibroma + mel: Melanoma + nv: Melanocytic nevi + vasc: Vascular lesions + +Author: + Kacper Dural +""" + +import os +import logging +from pathlib import Path +from typing import Optional, List + +import pandas as pd + +from pyhealth.datasets import BaseDataset +from pyhealth.processors import ImageProcessor +from pyhealth.tasks import BaseTask + +logger = logging.getLogger(__name__) + + +class HAM10000Dataset(BaseDataset): + """Dataset class for the HAM10000 dermoscopic image dataset.""" + + classes: List[str] = ["akiec", "bcc", "bkl", "df", "mel", "nv", "vasc"] + + def __init__( + self, + root: str, + config_path: Optional[str] = None, + ): + """ + Args: + root: dataset root directory containing: + - images/ folder with .jpg images + - metadata.csv describing samples + + config_path: path to dataset config (optional) + """ + self.root = root + self.image_dir = os.path.join(root, "images") + self.meta_path = os.path.join(root, "HAM10000_metadata.csv") + + if not os.path.exists(self.meta_path): + raise FileNotFoundError("metadata.csv not found in dataset root.") + + if not os.path.exists(self.image_dir): + raise FileNotFoundError("images/ folder not found in dataset root.") + + # build internal metadata table + self._index_data() + + super().__init__( + root=root, + tables=["ham10000"], + dataset_name="HAM10000", + config_path=config_path, + ) + + @property + def default_task(self) -> BaseTask: + """Default task: multiclass dermatology label classification.""" + from pyhealth.tasks import image_multiclass_classification_fn + return image_multiclass_classification_fn + + def set_task(self, *args, **kwargs): + """Attach an ImageProcessor if user does not supply one.""" + input_processors = kwargs.get("input_processors", None) or {} + if "image" not in input_processors: + input_processors["image"] = ImageProcessor( + image_size=224, + mode="RGB", + ) + kwargs["input_processors"] = input_processors + return super().set_task(*args, **kwargs) + + def _index_data(self) -> None: + """Reads metadata.csv and builds ham10000-metadata-pyhealth.csv.""" + + df = pd.read_csv(self.meta_path) + + if "image_id" not in df.columns or "dx" not in df.columns: + raise ValueError("metadata.csv must contain image_id and dx.") + + # Full paths + df["path"] = df["image_id"].apply( + lambda x: os.path.join(self.image_dir, f"{x}.jpg") + ) + + # Verify images exist + missing = df[~df["path"].apply(os.path.exists)] + if len(missing) > 0: + logger.warning(f"{len(missing)} images listed in metadata were not found.") + + df.rename(columns={ + "dx": "label", + "image_id": "image", + }, inplace=True) + + # Save cleaned metadata file for reproducibility + out_path = os.path.join(self.root, "ham10000-metadata-pyhealth.csv") + df.to_csv(out_path, index=False) + logger.info(f"Saved processed metadata to {out_path}") diff --git a/tests/core/test_ham10000.py b/tests/core/test_ham10000.py new file mode 100644 index 000000000..5ef7b4c7a --- /dev/null +++ b/tests/core/test_ham10000.py @@ -0,0 +1,110 @@ +""" +Unit tests for the HAM10000Dataset class and its associated task. + + +""" + +import os +import shutil +import unittest + +import numpy as np +from PIL import Image + +from pyhealth.datasets import HAM10000Dataset +from pyhealth.tasks import ham10000_multiclass_fn + + +class TestHAM10000Dataset(unittest.TestCase): + def setUp(self): + # Reset test directory + if os.path.exists("test_ham"): + shutil.rmtree("test_ham") + os.makedirs("test_ham/images") + + # Create mock metadata.csv + # Two lesions (lesion_id): L001 and L002 + # Three images total + lines = [ + "lesion_id,image_id,dx,dx_type,age,sex,localization", + "L001,ISIC_0000001,mel,histo,60.0,male,back", + "L001,ISIC_0000002,nv,histo,60.0,male,back", + "L002,ISIC_0000003,bkl,histo,45.0,female,lower extremity", + ] + + meta_path = os.path.join("test_ham", "metadata.csv") + with open(meta_path, "w") as f: + f.write("\n".join(lines)) + + # Create synthetic dermoscopic images + for row in lines[1:]: + image_id = row.split(",")[1] + img_path = os.path.join("test_ham/images", f"{image_id}.jpg") + + # Random RGB image 224x224 + img = Image.fromarray( + np.random.randint(0, 256, (224, 224, 3), dtype=np.uint8), + mode="RGB", + ) + img.save(img_path) + + # Create dataset + self.dataset = HAM10000Dataset(root="test_ham") + + def tearDown(self): + if os.path.exists("test_ham"): + shutil.rmtree("test_ham") + + + def test_stats(self): + """Ensure stats() runs without error.""" + self.dataset.stats() + + def test_num_patients(self): + """lesion_id maps to 'patient_id'; expect 2 unique lesion groups.""" + self.assertEqual(len(self.dataset.unique_patient_ids), 2) + + def test_get_patient_L001(self): + """Lesion L001 has two images/samples.""" + events = self.dataset.get_patient("L001").get_events() + self.assertEqual(len(events), 2) + + first = events[0] + self.assertIn(first["label"], ["mel", "nv"]) # one of the two + + def test_get_patient_L002(self): + """Lesion L002 has one image sample.""" + events = self.dataset.get_patient("L002").get_events() + self.assertEqual(len(events), 1) + self.assertEqual(events[0]["label"], "bkl") + self.assertEqual(events[0]["sex"], "female") + + def test_default_task_present(self): + """Ensure default_task returns a callable task function.""" + t = self.dataset.default_task + self.assertTrue(callable(t)) + + def test_set_task_multiclass(self): + """Apply the HAM10000 multiclass task and verify sample count & labels.""" + samples = self.dataset.set_task(ham10000_multiclass_fn) + + # Expect 3 samples total + self.assertEqual(len(samples), 3) + + # Extract labels (string labels mapped to integers in the task) + labels = [s["label"] for s in samples] + # DX classes present in mock data: mel, nv, bkl + self.assertCountEqual(labels, labels) + + def test_image_loading(self): + """Ensure the image processor hook works and images load correctly.""" + samples = self.dataset.set_task(ham10000_multiclass_fn) + # Just fetch the first image; it should be a transformed tensor + sample = samples[0] + img = sample["image"] + self.assertIsNotNone(img) + # Expect shape (3, H, W) + self.assertEqual(len(img.shape), 3) + +if __name__ == "__main__": + unittest.main()