Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 0 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -58,8 +58,6 @@ source = ["eynollah"]

[tool.ruff]
line-length = 120
# TODO: Reenable and fix after release v0.6.0
exclude = ['src/eynollah/training']

[tool.ruff.lint]
ignore = [
Expand Down
29 changes: 14 additions & 15 deletions src/eynollah/training/gt_gen_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -252,6 +252,7 @@ def get_textline_contours_for_visualization(xml_file):



x_len, y_len = 0, 0
for jj in root1.iter(link+'Page'):
y_len=int(jj.attrib['imageHeight'])
x_len=int(jj.attrib['imageWidth'])
Expand Down Expand Up @@ -293,6 +294,7 @@ def get_textline_contours_and_ocr_text(xml_file):



x_len, y_len = 0, 0
for jj in root1.iter(link+'Page'):
y_len=int(jj.attrib['imageHeight'])
x_len=int(jj.attrib['imageWidth'])
Expand Down Expand Up @@ -362,7 +364,7 @@ def get_layout_contours_for_visualization(xml_file):
link=alltags[0].split('}')[0]+'}'


x_len, y_len = 0, 0
for jj in root1.iter(link+'Page'):
y_len=int(jj.attrib['imageHeight'])
x_len=int(jj.attrib['imageWidth'])
Expand Down Expand Up @@ -637,23 +639,20 @@ def get_images_of_ground_truth(gt_list, dir_in, output_dir, output_type, config_
link=alltags[0].split('}')[0]+'}'


x_len, y_len = 0, 0
for jj in root1.iter(link+'Page'):
y_len=int(jj.attrib['imageHeight'])
x_len=int(jj.attrib['imageWidth'])

if 'columns_width' in list(config_params.keys()):
columns_width_dict = config_params['columns_width']
metadata_element = root1.find(link+'Metadata')
comment_is_sub_element = False
num_col = None
for child in metadata_element:
tag2 = child.tag
if tag2.endswith('}Comments') or tag2.endswith('}comments'):
text_comments = child.text
num_col = int(text_comments.split('num_col')[1])
comment_is_sub_element = True
if not comment_is_sub_element:
num_col = None

if num_col:
x_new = columns_width_dict[str(num_col)]
Expand Down Expand Up @@ -1739,15 +1738,15 @@ def read_xml(xml_file):



def bounding_box(cnt, color, corr_order_index):
    """Return the scaled bounding rectangle of ``cnt`` plus its label fields.

    The rectangle comes from ``cv2.boundingRect(cnt)``; horizontal values are
    multiplied by ``scale_w`` and vertical values by ``scale_h`` before being
    truncated to ``int``.

    Returns:
        ``[x, y, w, h, int(color), int(corr_order_index) + 1]``.
    """
    left, top, width, height = cv2.boundingRect(cnt)

    # NOTE(review): scale_w / scale_h are neither parameters nor locals here;
    # they must exist at module scope for this to run — TODO confirm.
    return [
        int(left * scale_w),
        int(top * scale_h),
        int(width * scale_w),
        int(height * scale_h),
        int(color),
        int(corr_order_index) + 1,
    ]
# def bounding_box(cnt,color, corr_order_index ):
# x, y, w, h = cv2.boundingRect(cnt)
# x = int(x*scale_w)
# y = int(y*scale_h)
#
# w = int(w*scale_w)
# h = int(h*scale_h)
#
# return [x,y,w,h,int(color), int(corr_order_index)+1]

def resize_image(seg_in, input_height, input_width):
    """Resize ``seg_in`` to the requested height and width.

    Uses ``cv2.INTER_NEAREST`` interpolation and returns the result of
    ``cv2.resize``.
    """
    # cv2.resize receives the target size as (width, height).
    target_size = (input_width, input_height)
    return cv2.resize(seg_in, target_size, interpolation=cv2.INTER_NEAREST)
Expand Down
53 changes: 32 additions & 21 deletions src/eynollah/training/inference.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,15 @@
import sys
import os
from typing import Tuple
import warnings
import json

import numpy as np
import cv2
from tensorflow.keras.models import load_model
from numpy._typing import NDArray
import tensorflow as tf
from tensorflow.keras import backend as K
from tensorflow.keras.layers import *
from keras.models import Model, load_model
from keras import backend as K
import click
from tensorflow.python.keras import backend as tensorflow_backend
import xml.etree.ElementTree as ET
Expand All @@ -34,6 +35,7 @@
"""

class sbb_predict:

def __init__(self,image, dir_in, model, task, config_params_model, patches, save, save_layout, ground_truth, xml_file, out, min_area):
self.image=image
self.dir_in=dir_in
Expand Down Expand Up @@ -77,7 +79,7 @@ def otsu_copy_binary(self,img):
#print(img[:,:,0].min())
#blur = cv2.GaussianBlur(img,(5,5))
#ret3,th3 = cv2.threshold(blur,0,255,cv2.THRESH_BINARY+cv2.THRESH_OTSU)
retval1, threshold1 = cv2.threshold(img1, 0, 255, cv2.THRESH_BINARY+cv2.THRESH_OTSU)
_, threshold1 = cv2.threshold(img1, 0, 255, cv2.THRESH_BINARY+cv2.THRESH_OTSU)



Expand Down Expand Up @@ -116,19 +118,19 @@ def soft_dice_loss(self,y_true, y_pred, epsilon=1e-6):
denominator = K.sum(K.square(y_pred) + K.square(y_true), axes)
return 1.00 - K.mean(numerator / (denominator + epsilon)) # average over classes and batch

def weighted_categorical_crossentropy(self, weights=None):
    """Build a per-pixel sigmoid cross-entropy loss with optional weighting.

    Args:
        weights: Optional sequence of weights, broadcast against the last
            axis of ``y_true``. When given, each pixel's loss is multiplied
            by the maximum of ``weights * y_true`` over the last axis, with a
            floor of 1.0.

    Returns:
        A ``loss(y_true, y_pred)`` callable returning the mean of the
        (optionally weighted) per-pixel loss.
    """

    def loss(y_true, y_pred):
        labels_floats = tf.cast(y_true, tf.float32)
        # y_pred is passed to TensorFlow as logits.
        per_pixel_loss = tf.nn.sigmoid_cross_entropy_with_logits(
            labels=labels_floats, logits=y_pred
        )

        if weights is not None:
            # Shape the weights as (1, 1, 1, n) so they broadcast over the labels.
            class_weights = tf.constant(
                np.array(weights, dtype=np.float32)[None, None, None]
            )
            weight_mask = tf.maximum(
                tf.reduce_max(class_weights * labels_floats, axis=-1), 1.0
            )
            per_pixel_loss = per_pixel_loss * weight_mask[:, :, :, None]
        return tf.reduce_mean(per_pixel_loss)

    # Return the closure defined above; this method never sets an attribute
    # named ``loss`` on ``self``.
    return loss
# def weighted_categorical_crossentropy(self,weights=None):
#
# def loss(y_true, y_pred):
# labels_floats = tf.cast(y_true, tf.float32)
# per_pixel_loss = tf.nn.sigmoid_cross_entropy_with_logits(labels=labels_floats,logits=y_pred)
#
# if weights is not None:
# weight_mask = tf.maximum(tf.reduce_max(tf.constant(
# np.array(weights, dtype=np.float32)[None, None, None])
# * labels_floats, axis=-1), 1.0)
# per_pixel_loss = per_pixel_loss * weight_mask[:, :, :, None]
# return tf.reduce_mean(per_pixel_loss)
# return self.loss


def IoU(self,Yi,y_predi):
Expand Down Expand Up @@ -177,12 +179,13 @@ def start_new_session_and_model(self):
##if self.weights_dir!=None:
##self.model.load_weights(self.weights_dir)

assert isinstance(self.model, Model)
if self.task != 'classification' and self.task != 'reading_order':
self.img_height=self.model.layers[len(self.model.layers)-1].output_shape[1]
self.img_width=self.model.layers[len(self.model.layers)-1].output_shape[2]
self.n_classes=self.model.layers[len(self.model.layers)-1].output_shape[3]

def visualize_model_output(self, prediction, img, task):
def visualize_model_output(self, prediction, img, task) -> Tuple[NDArray, NDArray]:
if task == "binarization":
prediction = prediction * -1
prediction = prediction + 1
Expand Down Expand Up @@ -226,9 +229,12 @@ def visualize_model_output(self, prediction, img, task):

added_image = cv2.addWeighted(img,0.5,layout_only,0.1,0)

assert isinstance(added_image, np.ndarray)
assert isinstance(layout_only, np.ndarray)
return added_image, layout_only

def predict(self, image_dir):
assert isinstance(self.model, Model)
if self.task == 'classification':
classes_names = self.config_params_model['classification_classes_name']
img_1ch = img=cv2.imread(image_dir, 0)
Expand All @@ -240,7 +246,7 @@ def predict(self, image_dir):
img_in[0, :, :, 1] = img_1ch[:, :]
img_in[0, :, :, 2] = img_1ch[:, :]

label_p_pred = self.model.predict(img_in, verbose=0)
label_p_pred = self.model.predict(img_in, verbose='0')
index_class = np.argmax(label_p_pred[0])

print("Predicted Class: {}".format(classes_names[str(int(index_class))]))
Expand Down Expand Up @@ -361,7 +367,7 @@ def predict(self, image_dir):
#input_1[:,:,1] = img3[:,:,0]/5.

if batch_counter==inference_bs or ( (tot_counter//inference_bs)==full_bs_ite and tot_counter%inference_bs==last_bs):
y_pr = self.model.predict(input_1 , verbose=0)
y_pr = self.model.predict(input_1 , verbose='0')
scalibility_num = scalibility_num+1

if batch_counter==inference_bs:
Expand Down Expand Up @@ -395,6 +401,7 @@ def predict(self, image_dir):
name_space = name_space.split('{')[1]

page_element = root_xml.find(link+'Page')
assert isinstance(page_element, ET.Element)

"""
ro_subelement = ET.SubElement(page_element, 'ReadingOrder')
Expand Down Expand Up @@ -489,14 +496,16 @@ def predict(self, image_dir):

img_patch = img[index_y_d:index_y_u, index_x_d:index_x_u, :]
label_p_pred = self.model.predict(img_patch.reshape(1, img_patch.shape[0], img_patch.shape[1], img_patch.shape[2]),
verbose=0)
verbose='0')

if self.task == 'enhancement':
seg = label_p_pred[0, :, :, :]
seg = seg * 255
elif self.task == 'segmentation' or self.task == 'binarization':
seg = np.argmax(label_p_pred, axis=3)[0]
seg = np.repeat(seg[:, :, np.newaxis], 3, axis=2)
else:
raise ValueError(f"Unhandled task {self.task}")


if i == 0 and j == 0:
Expand Down Expand Up @@ -551,6 +560,8 @@ def predict(self, image_dir):
elif self.task == 'segmentation' or self.task == 'binarization':
seg = np.argmax(label_p_pred, axis=3)[0]
seg = np.repeat(seg[:, :, np.newaxis], 3, axis=2)
else:
raise ValueError(f"Unhandled task {self.task}")

prediction_true = seg.astype(int)

Expand Down
Loading