88from PIL import Image
99import numpy as np
1010
11- __all__ = [' Detectron2LayoutModel' ]
11+ __all__ = [" Detectron2LayoutModel" ]
1212
1313
1414class BaseLayoutModel (ABC ):
15-
1615 @abstractmethod
17- def detect (self ): pass
16+ def detect (self ):
17+ pass
1818
1919
2020class Detectron2LayoutModel (BaseLayoutModel ):
2121 """Create a Detectron2-based Layout Detection Model
2222
2323 Args:
24- config_path (:obj:`str`):
25- The path to the configuration file.
26- model_path (:obj:`str`, None):
27- The path to the saved weights of the model.
28- If set, overwrite the weights in the configuration file.
24+ config_path (:obj:`str`):
25+ The path to the configuration file.
26+ model_path (:obj:`str`, None):
27+ The path to the saved weights of the model.
28+ If set, overwrite the weights in the configuration file.
2929 Defaults to `None`.
30- label_map (:obj:`dict`, optional):
30+ label_map (:obj:`dict`, optional):
3131 The map from the model prediction (ids) to real
32- word labels (strings).
32+ word labels (strings).
3333 Defaults to `None`.
34- extra_config (:obj:`list`, optional):
35- Extra configuration passed to the Detectron2 model
36- configuration. The argument will be used in the `merge_from_list
34+ extra_config (:obj:`list`, optional):
35+ Extra configuration passed to the Detectron2 model
36+ configuration. The argument will be used in the `merge_from_list
3737 <https://detectron2.readthedocs.io/modules/config.html
38- #detectron2.config.CfgNode.merge_from_list>`_ function.
38+ #detectron2.config.CfgNode.merge_from_list>`_ function.
3939 Defaults to `[]`.
4040
4141 Examples::
@@ -45,10 +45,7 @@ class Detectron2LayoutModel(BaseLayoutModel):
4545
4646 """
4747
48- def __init__ (self , config_path ,
49- model_path = None ,
50- label_map = None ,
51- extra_config = []):
48+ def __init__ (self , config_path , model_path = None , label_map = None , extra_config = []):
5249
5350 cfg = get_cfg ()
5451 config_path = PathManager .get_local_path (config_path )
@@ -57,15 +54,15 @@ def __init__(self, config_path,
5754
5855 if model_path is not None :
5956 cfg .MODEL .WEIGHTS = model_path
60- cfg .MODEL .DEVICE = ' cuda' if torch .cuda .is_available () else ' cpu'
57+ cfg .MODEL .DEVICE = " cuda" if torch .cuda .is_available () else " cpu"
6158 self .cfg = cfg
6259
6360 self .label_map = label_map
6461 self ._create_model ()
6562
6663 def gather_output (self , outputs ):
6764
68- instance_pred = outputs [' instances' ].to ("cpu" )
65+ instance_pred = outputs [" instances" ].to ("cpu" )
6966
7067 layout = Layout ()
7168 scores = instance_pred .scores .tolist ()
@@ -79,9 +76,8 @@ def gather_output(self, outputs):
7976 label = self .label_map .get (label , label )
8077
8178 cur_block = TextBlock (
82- Rectangle (x_1 , y_1 , x_2 , y_2 ),
83- type = label ,
84- score = score )
79+ Rectangle (x_1 , y_1 , x_2 , y_2 ), type = label , score = score
80+ )
8581 layout .append (cur_block )
8682
8783 return layout
@@ -101,8 +97,8 @@ def detect(self, image):
10197
10298 # Convert PIL Image Input
10399 if isinstance (image , Image .Image ):
104- if image .mode != ' RGB' :
105- image = image .convert (' RGB' )
100+ if image .mode != " RGB" :
101+ image = image .convert (" RGB" )
106102 image = np .array (image )
107103
108104 outputs = self .model (image )
0 commit comments