Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions CRAFT/basenet/vgg16_bn.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,10 +27,10 @@ def init_weights(modules):


class vgg16_bn(torch.nn.Module):
def __init__(self, pretrained=True, freeze=True):
def __init__(self, weights=True, freeze=True):
super(vgg16_bn, self).__init__()
model_urls['vgg16_bn'] = model_urls['vgg16_bn'].replace('https://', 'http://')
vgg_pretrained_features = models.vgg16_bn(pretrained=pretrained).features
vgg_pretrained_features = models.vgg16_bn(weights=weights).features
self.slice1 = torch.nn.Sequential()
self.slice2 = torch.nn.Sequential()
self.slice3 = torch.nn.Sequential()
Expand All @@ -52,7 +52,7 @@ def __init__(self, pretrained=True, freeze=True):
nn.Conv2d(1024, 1024, kernel_size=1)
)

if not pretrained:
if not weights:
init_weights(self.slice1.modules())
init_weights(self.slice2.modules())
init_weights(self.slice3.modules())
Expand Down
4 changes: 2 additions & 2 deletions CRAFT/craft.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,11 +33,11 @@ def forward(self, x):

class CRAFT(nn.Module):

def __init__(self, pretrained=False, freeze=False):
def __init__(self, weights=False, freeze=False):
super(CRAFT, self).__init__()

""" Base network """
self.basenet = vgg16_bn(pretrained, freeze)
self.basenet = vgg16_bn(weights, freeze)

""" U network """
self.upconv1 = double_conv(1024, 512, 256)
Expand Down
73 changes: 70 additions & 3 deletions CRAFT/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,10 @@
from PIL import Image
import numpy as np
import cv2
from huggingface_hub import hf_hub_url, cached_download
from multiprocessing import Pool
import functools

from huggingface_hub import hf_hub_url, hf_hub_download
from CRAFT.craft import CRAFT, init_CRAFT_model
from CRAFT.refinenet import RefineNet, init_refiner_model
from CRAFT.craft_utils import adjustResultCoordinates, getDetBoxes
Expand Down Expand Up @@ -39,6 +41,16 @@ def preprocess_image(image: np.ndarray, canvas_size: int, mag_ratio: bool):
x = Variable(x.unsqueeze(0)) # [c, h, w] to [b, c, h, w]
return x, ratio_w, ratio_h

def process_single(args):
    """Extract and rescale the detection boxes for one image.

    ``args`` is a single 7-tuple so that the function can be used directly
    with ``Pool.map``::

        (text_score, link_score, ratio_w, ratio_h,
         text_threshold, link_threshold, low_text)

    The score maps and thresholds are forwarded to ``getDetBoxes``; the boxes
    it returns are then passed through ``adjustResultCoordinates`` together
    with the two ratios.  The second value returned by ``getDetBoxes`` is
    not used.
    """
    (score_text, score_link, scale_w, scale_h,
     thr_text, thr_link, thr_low) = args

    # The trailing False is presumably the polygon-mode flag of getDetBoxes
    # -- confirm against CRAFT.craft_utils.
    detected, _unused_polys = getDetBoxes(
        score_text, score_link, thr_text, thr_link, thr_low, False
    )
    return adjustResultCoordinates(detected, scale_w, scale_h)

class CRAFTModel:

Expand Down Expand Up @@ -72,8 +84,11 @@ def __init__(
config = HF_MODELS[model_name]
paths[model_name] = os.path.join(cache_dir, config['filename'])
if not local_files_only:
config_file_url = hf_hub_url(repo_id=config['repo_id'], filename=config['filename'])
cached_download(config_file_url, cache_dir=cache_dir, force_filename=config['filename'])
paths[model_name] = hf_hub_download(
repo_id=config['repo_id'],
filename=config['filename'],
cache_dir=cache_dir
)

self.net = init_CRAFT_model(paths['craft'], device, fp16=fp16)
if self.use_refiner:
Expand All @@ -100,6 +115,57 @@ def get_text_map(self, x: torch.Tensor, ratio_w: int, ratio_h: int) -> Tuple[np.

return score_text, score_link


def get_batch_polygons(self, batch_images: torch.Tensor, ratios_w: torch.Tensor, ratios_h: torch.Tensor) -> list:
    """Run detection on a batch of already-preprocessed images.

    The network (and the refiner, when one is configured) is evaluated once
    for the whole batch on ``self.device``; the per-image post-processing
    done by ``process_single`` is then distributed over worker processes.

    Args:
        batch_images: Batch tensor passed straight to ``self.net``; it is
            cast to half precision first when ``self.fp16`` is set.
            Assumed to be laid out [B, C, H, W] and already normalized --
            TODO confirm against the caller.
        ratios_w: Per-image width ratios, indexed along the batch dimension.
        ratios_h: Per-image height ratios, indexed along the batch dimension.

    Returns:
        A list with one entry per image in batch order, each entry being
        whatever ``process_single`` returns for that image.
    """
    if self.fp16:
        batch_images = batch_images.half()

    with torch.no_grad():
        y, feature = self.net(batch_images.to(self.device))
        text_scores = y[..., 0]  # [B, H, W]
        if self.refiner:
            # The refiner output replaces the network's own link channel.
            link_scores = self.refiner(y, feature)[..., 0]  # [B, H, W]
        else:
            link_scores = y[..., 1]  # [B, H, W]

    # Upcast before handing the maps to the CPU post-processing: under fp16
    # the network output is float16, and the upcast to float32 is lossless.
    # NOTE(review): assumes the OpenCV-based routines behind process_single
    # cannot consume float16 arrays -- confirm.
    text_maps = text_scores.detach().float().cpu().numpy()
    link_maps = link_scores.detach().float().cpu().numpy()

    ratios_w = ratios_w.cpu().numpy()
    ratios_h = ratios_h.cpu().numpy()

    batch_size = batch_images.size(0)
    batch_args = [
        (text_maps[i], link_maps[i], ratios_w[i], ratios_h[i],
         self.text_threshold, self.link_threshold, self.low_text)
        for i in range(batch_size)
    ]

    # Never start more workers than there are images; os.cpu_count() may
    # return None, hence the fallback to 1.
    worker_count = min(batch_size, os.cpu_count() or 1)
    if worker_count <= 1:
        # Empty or single-image batch (or a single core): starting a process
        # pool would only add overhead, so post-process in this process.
        return [process_single(item) for item in batch_args]

    with Pool(processes=worker_count) as pool:
        return pool.map(process_single, batch_args)

def _convex_hull(self, x_coords, y_coords):
    """Return the corners of the axis-aligned rectangle enclosing the points.

    Despite its name this does not compute a true convex hull: it takes the
    minimum and maximum of each coordinate tensor and returns the four
    rectangle corners as a 4x2 tensor on ``x_coords.device``, ordered
    (min_x, min_y), (max_x, min_y), (max_x, max_y), (min_x, max_y).
    """
    lo_x, hi_x = torch.min(x_coords), torch.max(x_coords)
    lo_y, hi_y = torch.min(y_coords), torch.max(y_coords)

    corners = [
        [lo_x, lo_y],
        [hi_x, lo_y],
        [hi_x, hi_y],
        [lo_x, hi_y],
    ]
    return torch.tensor(corners, device=x_coords.device)

def get_polygons(self, image: Image.Image) -> List[List[List[int]]]:
x, ratio_w, ratio_h = preprocess_image(np.array(image), self.canvas_size, self.mag_ratio)

Expand Down Expand Up @@ -147,3 +213,4 @@ def get_boxes(self, image: Image.Image) -> List[List[List[int]]]:

boxes_final = self._get_boxes_preproc(x, ratio_w, ratio_h)
return boxes_final