Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
326 changes: 281 additions & 45 deletions PKU-PosterLayout.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,37 @@
# Copyright 2024 Shunsuke Kitada and the current dataset script contributor.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# This script was generated from shunk031/cookiecutter-huggingface-datasets.
#

from __future__ import annotations

import ast
import pathlib
from typing import List, Optional, TypedDict, Union, cast
from dataclasses import dataclass
from typing import TYPE_CHECKING, List, Optional, Sequence, TypedDict, Union, cast

import datasets as ds
import pandas as pd
from datasets.utils.logging import get_logger
from PIL import Image
from PIL.Image import Image as PilImage

if TYPE_CHECKING:
from ralfpt.saliency_detection import SaliencyTester
from simple_lama_inpainting import SimpleLama

logger = get_logger(__name__)

_DESCRIPTION = (
Expand All @@ -16,11 +40,11 @@

_CITATION = """\
@inproceedings{hsu2023posterlayout,
title={PosterLayout: A New Benchmark and Approach for Content-aware Visual-Textual Presentation Layout},
author={Hsu, Hsiao Yuan and He, Xiangteng and Peng, Yuxin and Kong, Hao and Zhang, Qing},
booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition},
pages={6018--6026},
year={2023}
title={PosterLayout: A New Benchmark and Approach for Content-aware Visual-Textual Presentation Layout},
author={Hsu, Hsiao Yuan and He, Xiangteng and Peng, Yuxin and Kong, Hao and Zhang, Qing},
booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition},
pages={6018--6026},
year={2023}
}
"""

Expand Down Expand Up @@ -127,30 +151,196 @@ def get_canvas_files(base_dir: str) -> List[pathlib.Path]:
return sorted(canvas_dir.iterdir(), key=lambda f: int(f.stem))


def ralf_style_example(
    example,
    inpainter: SimpleLama,
    saliency_testers: List[SaliencyTester],
    old_saliency_maps: Sequence[str] = (
        "basnet_saliency_map",
        "pfpn_saliency_map",
    ),
    new_saliency_maps: Sequence[str] = (
        "saliency_map",
        "saliency_map_sub",
    ),
):
    """Convert one raw example into its "ralf-style" counterpart, in place.

    The keys listed in ``old_saliency_maps`` are always deleted from
    ``example``; the keys listed in ``new_saliency_maps`` are then filled:

    * Test example (``example["annotations"] is None``): saliency maps are
      computed from ``example["canvas"]``.
    * Train example: layout elements are built from the annotations, the
      ``original_poster`` image is passed through ``apply_inpainting`` with
      those elements, the result overwrites ``example["inpainted_poster"]``,
      and saliency maps are computed from that inpainted image.
    * Train example whose element extraction raises ``AssertionError``: a
      warning is logged and every new saliency-map key is set to ``None``
      (``inpainted_poster`` is left untouched).

    Args:
        example: Mutable mapping produced by the dataset generator.
        inpainter: Inpainting model forwarded to ``apply_inpainting``.
        saliency_testers: Saliency models forwarded to
            ``apply_saliency_detection``; one per entry of
            ``new_saliency_maps``.
        old_saliency_maps: Keys to remove from ``example``.
        new_saliency_maps: Keys under which the new maps are stored.

    Returns:
        The same ``example`` object, mutated.

    Raises:
        AssertionError: If the three sequences differ in length.
    """
    # Imported lazily so that ``ralfpt`` is only needed when this function
    # is actually called.
    from ralfpt.inpainting import apply_inpainting
    from ralfpt.saliency_detection import apply_saliency_detection
    from ralfpt.transforms import has_valid_area, load_from_pku_ltrb
    from ralfpt.typehints import Element

    def get_pku_layout_elements(
        annotations, image_w: int, image_h: int
    ) -> List[Element]:
        """Build the elements with a valid area from the annotation records."""
        # The element count is read from the first record only.
        head, *_ = annotations
        num_elements = head["total_elem"]

        valid_elements = []
        for index in range(num_elements):
            record = annotations[index]
            label = record["cls_elem"]
            box = record["box_elem"]

            coordinates = load_from_pku_ltrb(
                box=box, global_width=image_w, global_height=image_h
            )
            if not has_valid_area(**coordinates):
                continue

            element: Element = {"label": label, "coordinates": coordinates}
            valid_elements.append(element)

        return valid_elements

    def store_saliency_maps(source_image):
        """Run every saliency tester on the image and store the maps."""
        detected_maps = apply_saliency_detection(
            image=source_image,
            saliency_testers=saliency_testers,  # type: ignore
        )
        for map_name, detected_map in zip(new_saliency_maps, detected_maps):
            example[map_name] = detected_map

    # Old names, new names and testers must line up one-to-one.
    assert len(old_saliency_maps) == len(new_saliency_maps) == len(saliency_testers)

    annotations = example["annotations"]
    is_test = annotations is None

    # Drop the pre-computed saliency maps; they are replaced below.
    for stale_name in old_saliency_maps:
        del example[stale_name]

    # Examples without annotations are test data: there is nothing to
    # inpaint, so the saliency maps come straight from the canvas.
    if is_test:
        store_saliency_maps(example["canvas"])
        return example

    image = example["original_poster"]
    image_w, image_h = image.size

    # Only used to identify the poster in the warning below.
    poster_path = annotations[0]["poster_path"]

    # Get layout elements.
    try:
        elements = get_pku_layout_elements(
            annotations=example["annotations"], image_w=image_w, image_h=image_h
        )
    except AssertionError as e:
        logger.warning(f"[{poster_path}] Failed to get layout elements: {e}")

        # Without layout elements neither inpainting nor saliency detection
        # is performed; the new maps are explicitly marked as missing.
        for map_name in new_saliency_maps:
            example[map_name] = None

        return example

    # Apply RALF-style inpainting.
    inpainted_image = apply_inpainting(
        image=image, elements=elements, inpainter=inpainter
    )
    example["inpainted_poster"] = inpainted_image

    # Apply RALF-style saliency detection on the inpainted image.
    store_saliency_maps(inpainted_image)

    return example


@dataclass
class PosterLayoutConfig(ds.BuilderConfig):
    """Builder configuration that selects saliency-map columns and models.

    Two config names are supported:

    * ``"default"``: the ``basnet_saliency_map`` / ``pfpn_saliency_map``
      columns are used and no saliency tester is configured.
    * ``"ralf-style"``: the ``saliency_map`` / ``saliency_map_sub`` columns
      are used together with two saliency model identifiers.

    Any other name makes ``__post_init__`` raise ``ValueError``.
    """

    # Both fields are derived from ``name`` in ``__post_init__``; they only
    # default to ``None`` so that the dataclass can be built from ``name``.
    _saliency_maps: Optional[Sequence[str]] = None
    _saliency_testers: Optional[Sequence[str]] = None

    def get_default_salieny_maps(self) -> Sequence[str]:
        """Return the saliency-map column names of the "default" config.

        The misspelled name is kept for backward compatibility; prefer
        ``get_default_saliency_maps`` in new code.
        """
        return ["basnet_saliency_map", "pfpn_saliency_map"]

    def get_default_saliency_maps(self) -> Sequence[str]:
        """Correctly spelled alias of ``get_default_salieny_maps``."""
        return self.get_default_salieny_maps()

    def get_salient_maps(self) -> Sequence[str]:
        """Return the saliency-map column names for this config.

        Raises:
            ValueError: If ``self.name`` is not a supported config name.
        """
        if self.name == "default":
            return self.get_default_salieny_maps()
        elif self.name == "ralf-style":
            return ["saliency_map", "saliency_map_sub"]
        else:
            # Include the offending name, matching ``get_saliency_testers``.
            raise ValueError(f"Invalid config name: {self.name}")

    def get_saliency_testers(self) -> Optional[Sequence[str]]:
        """Return the saliency model identifiers for this config.

        Returns ``None`` for the "default" config.

        Raises:
            ValueError: If ``self.name`` is not a supported config name.
        """
        if self.name == "default":
            return None
        elif self.name == "ralf-style":
            return [
                "creative-graphic-design/ISNet-general-use",
                "creative-graphic-design/BASNet-SmartText",
            ]
        else:
            raise ValueError(f"Invalid config name: {self.name}")

    def __post_init__(self):
        """Run the parent hook, then derive both fields from ``name``."""
        super().__post_init__()
        self._saliency_maps = self.get_salient_maps()
        self._saliency_testers = self.get_saliency_testers()

    @property
    def saliency_maps(self) -> Sequence[str]:
        """Saliency-map column names selected for this config."""
        assert self._saliency_maps is not None
        return self._saliency_maps

    @property
    def saliency_testers(self) -> Sequence[str]:
        """Saliency model identifiers selected for this config.

        The "default" config stores ``None`` here, so reading this property
        on it fails the assertion below.
        """
        assert self._saliency_testers is not None
        return self._saliency_testers


class PosterLayoutDataset(ds.GeneratorBasedBuilder):
VERSION = ds.Version("1.0.0")
BUILDER_CONFIGS = [ds.BuilderConfig(version=VERSION)]
BUILDER_CONFIG_CLASS = PosterLayoutConfig
BUILDER_CONFIGS = [
PosterLayoutConfig(name="default", version=VERSION),
PosterLayoutConfig(name="ralf-style", version=VERSION),
]

def _info(self) -> ds.DatasetInfo:
features = ds.Features(
config: PosterLayoutConfig = self.config # type: ignore
base_features = {
"original_poster": ds.Image(),
"inpainted_poster": ds.Image(),
"canvas": ds.Image(),
}
saliency_map_features = (
{
"original_poster": ds.Image(),
"inpainted_poster": ds.Image(),
"basnet_saliency_map": ds.Image(),
"pfpn_saliency_map": ds.Image(),
"canvas": ds.Image(),
"annotations": ds.Sequence(
{
"poster_path": ds.Value("string"),
"total_elem": ds.Value("int32"),
"cls_elem": ds.ClassLabel(
num_classes=4, names=["text", "logo", "underlay", "INVALID"]
),
"box_elem": ds.Sequence(ds.Value("int32")),
}
),
}
if self.config.name == "default"
else {
**{col: ds.Image() for col in config.saliency_maps},
}
)
annotation_features = {
"annotations": ds.Sequence(
{
"poster_path": ds.Value("string"),
"total_elem": ds.Value("int32"),
"cls_elem": ds.ClassLabel(
num_classes=4, names=["text", "logo", "underlay", "INVALID"]
),
"box_elem": ds.Sequence(ds.Value("int32")),
}
),
}
features = ds.Features(
{**base_features, **saliency_map_features, **annotation_features}
)
return ds.DatasetInfo(
description=_DESCRIPTION,
citation=_CITATION,
Expand Down Expand Up @@ -224,6 +414,7 @@ def _generate_train_examples(
)

it = zip(poster_files, inpainted_files, basnet_map_files, pfpn_map_files)

for i, (
original_poster_path,
inpainted_poster_path,
Expand All @@ -234,18 +425,15 @@ def _generate_train_examples(
poster_anns = ann_df[ann_df["poster_path"] == poster_path]

annotations = poster_anns.to_dict(orient="records")

yield (
i,
{
"original_poster": load_image(original_poster_path),
"inpainted_poster": load_image(inpainted_poster_path),
"basnet_saliency_map": load_image(basnet_map_path),
"pfpn_saliency_map": load_image(pfpn_map_path),
"canvas": None,
"annotations": annotations,
},
)
example = {
"original_poster": load_image(original_poster_path),
"inpainted_poster": load_image(inpainted_poster_path),
"basnet_saliency_map": load_image(basnet_map_path),
"pfpn_saliency_map": load_image(pfpn_map_path),
"canvas": None,
"annotations": annotations,
}
yield i, example

def _generate_test_examples(self, poster: TestPoster, saliency_maps: SaliencyMaps):
canvas_files = get_canvas_files(base_dir=poster["canvas"])
Expand All @@ -255,20 +443,19 @@ def _generate_test_examples(self, poster: TestPoster, saliency_maps: SaliencyMap

assert len(canvas_files) == len(basnet_map_files) == len(pfpn_map_files)
it = zip(canvas_files, basnet_map_files, pfpn_map_files)

for i, (canvas_path, basnet_map_path, pfpn_map_path) in enumerate(it):
yield (
i,
{
"original_poster": None,
"inpainted_poster": None,
"basnet_saliency_map": load_image(basnet_map_path),
"pfpn_saliency_map": load_image(pfpn_map_path),
"canvas": load_image(canvas_path),
"annotations": None,
},
)
example = {
"original_poster": None,
"inpainted_poster": None,
"basnet_saliency_map": load_image(basnet_map_path),
"pfpn_saliency_map": load_image(pfpn_map_path),
"canvas": load_image(canvas_path),
"annotations": None,
}
yield i, example

def _generate_examples(
def _get_generator(
self,
poster: Union[TrainPoster, TestPoster],
saliency_maps: SaliencyMaps,
Expand All @@ -289,3 +476,52 @@ def _generate_examples(
)
else:
raise ValueError("Invalid dataset")

def _generate_examples(
    self,
    poster: Union[TrainPoster, TestPoster],
    saliency_maps: SaliencyMaps,
    annotation: Optional[str] = None,
):
    """Yield ``(index, example)`` pairs for the active builder config.

    The raw examples come from ``self._get_generator``. With the
    ``"default"`` config they are yielded unchanged; with ``"ralf-style"``
    each one is first passed through ``ralf_style_example``.

    Args:
        poster: Poster inputs, forwarded to ``self._get_generator``.
        saliency_maps: Saliency-map inputs, forwarded likewise.
        annotation: Optional annotation input, forwarded likewise.

    Raises:
        ValueError: If the config name is neither ``"default"`` nor
            ``"ralf-style"``.
    """
    config: PosterLayoutConfig = self.config  # type: ignore

    generator = self._get_generator(
        poster=poster,
        saliency_maps=saliency_maps,
        annotation=annotation,
    )

    if config.name == "default":
        yield from generator

    elif config.name == "ralf-style":
        # Imported lazily so the optional dependencies are only required
        # for the "ralf-style" config.
        from ralfpt.saliency_detection import SaliencyTester
        from simple_lama_inpainting import SimpleLama

        # The models are created once and shared by all examples.
        inpainter = SimpleLama()
        saliency_testers = [
            SaliencyTester(model_name=model) for model in config.saliency_testers
        ]

        # The column names do not depend on the example, so they are looked
        # up once instead of on every iteration.
        old_saliency_maps = config.get_default_salieny_maps()
        new_saliency_maps = config.saliency_maps

        for idx, example in generator:
            example = ralf_style_example(
                example,
                inpainter=inpainter,
                old_saliency_maps=old_saliency_maps,
                new_saliency_maps=new_saliency_maps,
                saliency_testers=saliency_testers,
            )
            yield idx, example

    else:
        raise ValueError(f"Invalid config name: {config.name}")
4 changes: 4 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,10 @@ The language data in PKU-PosterLayout is in Chinese ([BCP-47 zh](https://www.rfc
import datasets as ds

dataset = ds.load_dataset("creative-graphic-design/PKU-PosterLayout")

# Alternatively, load the RALF-style (https://arxiv.org/abs/2311.13602) preprocessed version of the dataset
dataset = ds.load_dataset("creative-graphic-design/PKU-PosterLayout", name="ralf-style")

```

### Data Fields
Expand Down
Loading