diff --git a/CRAFT/basenet/vgg16_bn.py b/CRAFT/basenet/vgg16_bn.py index 6cad358..e401b06 100644 --- a/CRAFT/basenet/vgg16_bn.py +++ b/CRAFT/basenet/vgg16_bn.py @@ -27,10 +27,10 @@ def init_weights(modules): class vgg16_bn(torch.nn.Module): - def __init__(self, pretrained=True, freeze=True): + def __init__(self, weights=True, freeze=True): super(vgg16_bn, self).__init__() model_urls['vgg16_bn'] = model_urls['vgg16_bn'].replace('https://', 'http://') - vgg_pretrained_features = models.vgg16_bn(pretrained=pretrained).features + vgg_pretrained_features = models.vgg16_bn(weights=weights).features self.slice1 = torch.nn.Sequential() self.slice2 = torch.nn.Sequential() self.slice3 = torch.nn.Sequential() @@ -52,7 +52,7 @@ def __init__(self, pretrained=True, freeze=True): nn.Conv2d(1024, 1024, kernel_size=1) ) - if not pretrained: + if not weights: init_weights(self.slice1.modules()) init_weights(self.slice2.modules()) init_weights(self.slice3.modules()) diff --git a/CRAFT/craft.py b/CRAFT/craft.py index 782a98e..5ab2b11 100755 --- a/CRAFT/craft.py +++ b/CRAFT/craft.py @@ -33,11 +33,11 @@ def forward(self, x): class CRAFT(nn.Module): - def __init__(self, pretrained=False, freeze=False): + def __init__(self, weights=False, freeze=False): super(CRAFT, self).__init__() """ Base network """ - self.basenet = vgg16_bn(pretrained, freeze) + self.basenet = vgg16_bn(weights, freeze) """ U network """ self.upconv1 = double_conv(1024, 512, 256) diff --git a/CRAFT/model.py b/CRAFT/model.py index a2e2e4e..4bc93b7 100644 --- a/CRAFT/model.py +++ b/CRAFT/model.py @@ -6,8 +6,10 @@ from PIL import Image import numpy as np import cv2 -from huggingface_hub import hf_hub_url, cached_download +from multiprocessing import Pool +import functools +from huggingface_hub import hf_hub_url, hf_hub_download from CRAFT.craft import CRAFT, init_CRAFT_model from CRAFT.refinenet import RefineNet, init_refiner_model from CRAFT.craft_utils import adjustResultCoordinates, getDetBoxes @@ -39,6 +41,16 @@ def preprocess_image(image: np.ndarray, canvas_size: int, mag_ratio: bool): x = Variable(x.unsqueeze(0)) # [c, h, w] to [b, c, h, w] return x, ratio_w, ratio_h +def process_single(args): + text_score, link_score, ratio_w, ratio_h, text_threshold, link_threshold, low_text = args + + boxes, polys = getDetBoxes( + text_score, link_score, + text_threshold, link_threshold, + low_text, False + ) + boxes = adjustResultCoordinates(boxes, ratio_w, ratio_h) + return boxes class CRAFTModel: @@ -72,8 +84,11 @@ def __init__( config = HF_MODELS[model_name] paths[model_name] = os.path.join(cache_dir, config['filename']) if not local_files_only: - config_file_url = hf_hub_url(repo_id=config['repo_id'], filename=config['filename']) - cached_download(config_file_url, cache_dir=cache_dir, force_filename=config['filename']) + paths[model_name] = hf_hub_download( + repo_id=config['repo_id'], + filename=config['filename'], + cache_dir=cache_dir + ) self.net = init_CRAFT_model(paths['craft'], device, fp16=fp16) if self.use_refiner: @@ -100,6 +115,57 @@ def get_text_map(self, x: torch.Tensor, ratio_w: int, ratio_h: int) -> Tuple[np. return score_text, score_link + + def get_batch_polygons(self, batch_images: torch.Tensor, ratios_w: torch.Tensor, ratios_h: torch.Tensor): + """Batch process pre-normalized images on GPU""" + # Forward pass + #batch_images = batch_images.float() # Convert to float32 + if self.fp16: + batch_images = batch_images.half() # Convert to half if using fp16 + + with torch.no_grad(): + y, feature = self.net(batch_images.to(self.device)) + if self.refiner: + y_refiner = self.refiner(y, feature) + link_scores = y_refiner[..., 0] # [B, H, W] + else: + link_scores = y[..., 1] # [B, H, W] + + text_scores = y[..., 0] # [B, H, W] + + batch_size = batch_images.size(0) + text_scores = text_scores.detach().cpu().numpy() + link_scores = link_scores.detach().cpu().numpy() + + ratios_w = ratios_w.cpu().numpy() + ratios_h = ratios_h.cpu().numpy() + + with Pool(processes=os.cpu_count()) as pool: + batch_args = [(text_scores[i], link_scores[i], ratios_w[i], ratios_h[i], + self.text_threshold, self.link_threshold, self.low_text) + for i in range(batch_size)] + batch_polys = pool.map(process_single, batch_args) + + return batch_polys + + def _convex_hull(self, x_coords, y_coords): + """Simple convex hull approximation for GPU tensors""" + # For character detection, a simple bounding box is often sufficient + min_x = torch.min(x_coords) + max_x = torch.max(x_coords) + min_y = torch.min(y_coords) + max_y = torch.max(y_coords) + + # Create rectangle corners + pts = torch.tensor([ + [min_x, min_y], + [max_x, min_y], + [max_x, max_y], + [min_x, max_y] + ], device=x_coords.device) + + return pts + def get_polygons(self, image: Image.Image) -> List[List[List[int]]]: x, ratio_w, ratio_h = preprocess_image(np.array(image), self.canvas_size, self.mag_ratio) @@ -147,3 +213,4 @@ def get_boxes(self, image: Image.Image) -> List[List[List[int]]]: boxes_final = self._get_boxes_preproc(x, ratio_w, ratio_h) return boxes_final +