Skip to content

Commit cc3c34d

Browse files
[formrecognizer] Adding to_dict() on custom models (Azure#18402)
* adding to_dict() on custom models * add tests * add to_dict to models * fix TextAppearance to_dict() * add remaining tests and fixes * review feedback * fix value transform and add service test * add to_dict test in custom forms * fix tests Co-authored-by: Krista Pratico <krpratic@microsoft.com>
1 parent 68e38c6 commit cc3c34d

File tree

6 files changed

+1281
-2
lines changed

6 files changed

+1281
-2
lines changed

sdk/formrecognizer/azure-ai-formrecognizer/azure/ai/formrecognizer/_models.py

Lines changed: 187 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -144,6 +144,9 @@ class Point(namedtuple("Point", "x y")):
144144
def __new__(cls, x, y):
145145
return super(Point, cls).__new__(cls, x, y)
146146

147+
def to_dict(self):
148+
return {"x": self.x, "y": self.y}
149+
147150

148151
class FormPageRange(namedtuple("FormPageRange", "first_page_number last_page_number")):
149152
"""The 1-based page range of the form.
@@ -159,6 +162,12 @@ def __new__(cls, first_page_number, last_page_number):
159162
cls, first_page_number, last_page_number
160163
)
161164

165+
def to_dict(self):
166+
return {
167+
"first_page_number": self.first_page_number,
168+
"last_page_number": self.last_page_number,
169+
}
170+
162171

163172
class FormElement(object):
164173
"""Base type which includes properties for a form element.
@@ -183,6 +192,14 @@ def __init__(self, **kwargs):
183192
self.text = kwargs.get("text", None)
184193
self.kind = kwargs.get("kind", None)
185194

195+
def to_dict(self):
196+
return {
197+
"text": self.text,
198+
"bounding_box": [f.to_dict() for f in self.bounding_box] if self.bounding_box else [],
199+
"page_number": self.page_number,
200+
"kind": self.kind,
201+
}
202+
186203

187204
class RecognizedForm(object):
188205
"""Represents a form that has been recognized by a trained or prebuilt model.
@@ -234,6 +251,16 @@ def __repr__(self):
234251
)[:1024]
235252
)
236253

254+
def to_dict(self):
255+
return {
256+
"fields": {k: v.to_dict() for k, v in self.fields.items()} if self.fields else {},
257+
"form_type": self.form_type,
258+
"pages": [v.to_dict() for v in self.pages] if self.pages else [],
259+
"model_id": self.model_id,
260+
"form_type_confidence": self.form_type_confidence,
261+
"page_range": self.page_range.to_dict() if self.page_range else None
262+
}
263+
237264

238265
class FormField(object):
239266
"""Represents a field recognized in an input form.
@@ -305,6 +332,20 @@ def __repr__(self):
305332
:1024
306333
]
307334

335+
def to_dict(self):
336+
value = self.value
337+
if isinstance(self.value, dict):
338+
value = {k: v.to_dict() for k, v in self.value.items()}
339+
elif isinstance(self.value, list):
340+
value = [v.to_dict() for v in self.value]
341+
return {
342+
"value_type": self.value_type,
343+
"name": self.name,
344+
"value": value,
345+
"confidence": self.confidence,
346+
"label_data": self.label_data.to_dict() if self.label_data else None,
347+
"value_data": self.value_data.to_dict() if self.value_data else None,
348+
}
308349

309350
class FieldData(object):
310351
"""Contains the data for the form field. This includes the text,
@@ -374,6 +415,14 @@ def __repr__(self):
374415
:1024
375416
]
376417

418+
def to_dict(self):
419+
return {
420+
"text": self.text,
421+
"bounding_box": [f.to_dict() for f in self.bounding_box] if self.bounding_box else [],
422+
"page_number": self.page_number,
423+
"field_elements": [f.to_dict() for f in self.field_elements] if self.field_elements else []
424+
}
425+
377426

378427
class FormPage(object):
379428
"""Represents a page recognized from the input document. Contains lines,
@@ -433,6 +482,17 @@ def __repr__(self):
433482
)[:1024]
434483
)
435484

485+
def to_dict(self):
486+
return {
487+
"page_number": self.page_number,
488+
"text_angle": self.text_angle,
489+
"width": self.width,
490+
"height": self.height,
491+
"unit": self.unit,
492+
"tables": [table.to_dict() for table in self.tables] if self.tables else [],
493+
"lines": [line.to_dict() for line in self.lines] if self.lines else [],
494+
"selection_marks": [mark.to_dict() for mark in self.selection_marks] if self.selection_marks else []
495+
}
436496

437497
class FormLine(FormElement):
438498
"""An object representing an extracted line of text.
@@ -489,6 +549,16 @@ def __repr__(self):
489549
:1024
490550
]
491551

552+
def to_dict(self):
553+
return {
554+
"text": self.text,
555+
"bounding_box": [f.to_dict() for f in self.bounding_box] if self.bounding_box else [],
556+
"words": [f.to_dict() for f in self.words] if self.words else [],
557+
"page_number": self.page_number,
558+
"kind": self.kind,
559+
"appearance": self.appearance.to_dict() if self.appearance else None
560+
}
561+
492562

493563
class FormWord(FormElement):
494564
"""Represents a word recognized from the input document.
@@ -526,6 +596,15 @@ def __repr__(self):
526596
:1024
527597
]
528598

599+
def to_dict(self):
600+
return {
601+
"text": self.text,
602+
"bounding_box": [f.to_dict() for f in self.bounding_box] if self.bounding_box else [],
603+
"confidence": self.confidence,
604+
"page_number": self.page_number,
605+
"kind": self.kind,
606+
}
607+
529608

530609
class FormSelectionMark(FormElement):
531610
"""Information about the extracted selection mark.
@@ -560,12 +639,22 @@ def _from_generated(cls, mark, page):
560639
)
561640

562641
def __repr__(self):
563-
return "FormSelectionMark(text={}, bounding_box={}, confidence={}, page_number={}, state={})".format(
564-
self.text, self.bounding_box, self.confidence, self.page_number, self.state
642+
return "FormSelectionMark(text={}, bounding_box={}, confidence={}, page_number={}, state={}, kind={})".format(
643+
self.text, self.bounding_box, self.confidence, self.page_number, self.state, self.kind
565644
)[
566645
:1024
567646
]
568647

648+
def to_dict(self):
649+
return {
650+
"text": self.text,
651+
"bounding_box": [f.to_dict() for f in self.bounding_box] if self.bounding_box else [],
652+
"confidence": self.confidence,
653+
"state": self.state,
654+
"page_number": self.page_number,
655+
"kind": self.kind,
656+
}
657+
569658

570659
class FormTable(object):
571660
"""Information about the extracted table contained on a page.
@@ -606,6 +695,15 @@ def __repr__(self):
606695
:1024
607696
]
608697

698+
def to_dict(self):
699+
return {
700+
"page_number": self.page_number,
701+
"row_count": self.row_count,
702+
"column_count": self.column_count,
703+
"cells": [cell.to_dict() for cell in self.cells],
704+
"bounding_box": [box.to_dict() for box in self.bounding_box] if self.bounding_box else []
705+
}
706+
609707

610708
class FormTableCell(object): # pylint:disable=too-many-instance-attributes
611709
"""Represents a cell contained in a table recognized from the input document.
@@ -691,6 +789,22 @@ def __repr__(self):
691789
]
692790
)
693791

792+
def to_dict(self):
793+
return {
794+
"text": self.text,
795+
"row_index": self.row_index,
796+
"column_index": self.column_index,
797+
"row_span": self.row_span,
798+
"column_span": self.column_span,
799+
"confidence": self.confidence,
800+
"is_header": self.is_header,
801+
"is_footer": self.is_footer,
802+
"page_number": self.page_number,
803+
"bounding_box": [box.to_dict() for box in self.bounding_box] if self.bounding_box else [],
804+
"field_elements": [element.to_dict() for element in self.field_elements]
805+
if self.field_elements else None
806+
}
807+
694808

695809
class CustomFormModel(object):
696810
"""Represents a model trained from custom forms.
@@ -793,6 +907,18 @@ def __repr__(self):
793907
]
794908
)
795909

910+
def to_dict(self):
911+
return {
912+
"model_id": self.model_id,
913+
"status": self.status,
914+
"training_started_on": self.training_started_on,
915+
"training_completed_on": self.training_completed_on,
916+
"submodels": [submodel.to_dict() for submodel in self.submodels] if self.submodels else [],
917+
"errors": [err.to_dict() for err in self.errors] if self.errors else [],
918+
"training_documents": [doc.to_dict() for doc in self.training_documents] if self.training_documents else [],
919+
"model_name": self.model_name,
920+
"properties": self.properties.to_dict() if self.properties else None
921+
}
796922

797923
class CustomFormSubmodel(object):
798924
"""Represents a submodel that extracts fields from a specific type of form.
@@ -887,6 +1013,14 @@ def __repr__(self):
8871013
:1024
8881014
]
8891015

1016+
def to_dict(self):
1017+
return {
1018+
"model_id": self.model_id,
1019+
"accuracy": self.accuracy,
1020+
"fields": {k: v.to_dict() for k, v in self.fields.items()} if self.fields else {},
1021+
"form_type": self.form_type
1022+
}
1023+
8901024

8911025
class CustomFormModelField(object):
8921026
"""A field that the model will extract from forms it analyzes.
@@ -920,6 +1054,13 @@ def __repr__(self):
9201054
self.label, self.name, self.accuracy
9211055
)[:1024]
9221056

1057+
def to_dict(self):
1058+
return {
1059+
"label": self.label,
1060+
"accuracy": self.accuracy,
1061+
"name": self.name
1062+
}
1063+
9231064

9241065
class TrainingDocumentInfo(object):
9251066
"""Report for an individual document used for training
@@ -991,6 +1132,15 @@ def __repr__(self):
9911132
:1024
9921133
]
9931134

1135+
def to_dict(self):
1136+
return {
1137+
"name": self.name,
1138+
"status": self.status,
1139+
"page_count": self.page_count,
1140+
"errors": [err.to_dict() for err in self.errors],
1141+
"model_id": self.model_id
1142+
}
1143+
9941144

9951145
class FormRecognizerError(object):
9961146
"""Represents an error that occurred while training.
@@ -1016,6 +1166,12 @@ def __repr__(self):
10161166
self.code, self.message
10171167
)[:1024]
10181168

1169+
def to_dict(self):
1170+
return {
1171+
"code": self.code,
1172+
"message": self.message
1173+
}
1174+
10191175

10201176
class CustomFormModelInfo(object):
10211177
"""Custom model information.
@@ -1081,6 +1237,16 @@ def __repr__(self):
10811237
)[:1024]
10821238
)
10831239

1240+
def to_dict(self):
1241+
return {
1242+
"model_id": self.model_id,
1243+
"status": self.status,
1244+
"training_started_on": self.training_started_on,
1245+
"training_completed_on": self.training_completed_on,
1246+
"model_name": self.model_name,
1247+
"properties": self.properties.to_dict() if self.properties else None
1248+
}
1249+
10841250

10851251
class AccountProperties(object):
10861252
"""Summary of all the custom models on the account.
@@ -1105,6 +1271,12 @@ def __repr__(self):
11051271
self.custom_model_count, self.custom_model_limit
11061272
)[:1024]
11071273

1274+
def to_dict(self):
1275+
return {
1276+
"custom_model_count": self.custom_model_count,
1277+
"custom_model_limit": self.custom_model_limit
1278+
}
1279+
11081280

11091281
class CustomFormModelProperties(object):
11101282
"""Optional model properties.
@@ -1126,6 +1298,11 @@ def __repr__(self):
11261298
self.is_composed_model
11271299
)
11281300

1301+
def to_dict(self):
1302+
return {
1303+
"is_composed_model": self.is_composed_model
1304+
}
1305+
11291306

11301307
class TextAppearance(object):
11311308
"""An object representing the appearance of the text line.
@@ -1150,6 +1327,11 @@ def _from_generated(cls, appearance):
11501327
def __repr__(self):
11511328
return "TextAppearance(style={})".format(repr(self.style))
11521329

1330+
def to_dict(self):
1331+
return {
1332+
"style": self.style.to_dict() if self.style else None
1333+
}
1334+
11531335

11541336
class TextStyle(object):
11551337
"""An object representing the style of the text line.
@@ -1167,3 +1349,6 @@ def __init__(self, **kwargs):
11671349

11681350
def __repr__(self):
11691351
return "TextStyle(name={}, confidence={})".format(self.name, self.confidence)
1352+
1353+
def to_dict(self):
1354+
return {"name": self.name, "confidence": self.confidence}

sdk/formrecognizer/azure-ai-formrecognizer/tests/test_compose_model.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,10 @@ def test_compose_model_with_model_name(self, client, formrecognizer_storage_cont
3333
self.assertEqual(composed_model.model_name, "my composed model")
3434
self.assertComposedModelHasValues(composed_model, model_1, model_2)
3535

36+
composed_model_dict = composed_model.to_dict()
37+
self.assertEqual(composed_model_dict.get("model_name"), "my composed model")
38+
self.assertIsNotNone(composed_model_dict.get("model_id"))
39+
3640
@FormRecognizerPreparer()
3741
@GlobalClientPreparer()
3842
def test_compose_model_no_model_name(self, client, formrecognizer_storage_container_sas_url):

sdk/formrecognizer/azure-ai-formrecognizer/tests/test_compose_model_async.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,10 @@ async def test_compose_model_with_model_name(self, client, formrecognizer_storag
3333
self.assertEqual(composed_model.model_name, "my composed model")
3434
self.assertComposedModelHasValues(composed_model, model_1, model_2)
3535

36+
composed_model_dict = composed_model.to_dict()
37+
self.assertEqual(composed_model_dict.get("model_name"), "my composed model")
38+
self.assertIsNotNone(composed_model_dict.get("model_id"))
39+
3640
@FormRecognizerPreparer()
3741
@GlobalClientPreparer()
3842
async def test_compose_model_no_model_name(self, client, formrecognizer_storage_container_sas_url):

sdk/formrecognizer/azure-ai-formrecognizer/tests/test_custom_forms.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -205,6 +205,11 @@ def callback(raw_response, _, headers):
205205
self.assertIsNotNone(recognized_form[0].model_id)
206206
self.assertUnlabeledFormFieldDictTransformCorrect(recognized_form[0].fields, actual_fields, read_results)
207207

208+
recognized_form_dict = [v.to_dict() for v in recognized_form]
209+
self.assertIsNone(recognized_form_dict[0].get("form_type_confidence"))
210+
self.assertIsNotNone(recognized_form_dict[0].get("model_id"))
211+
self.assertEqual(recognized_form_dict[0].get("form_type"), "form-0")
212+
208213
@FormRecognizerPreparer()
209214
@GlobalClientPreparer()
210215
def test_custom_form_multipage_unlabeled_transform(self, client, formrecognizer_multipage_storage_container_sas_url):

sdk/formrecognizer/azure-ai-formrecognizer/tests/test_custom_forms_async.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -224,6 +224,11 @@ def callback(raw_response, _, headers):
224224
self.assertIsNotNone(recognized_form[0].model_id)
225225
self.assertUnlabeledFormFieldDictTransformCorrect(recognized_form[0].fields, actual_fields, read_results)
226226

227+
recognized_form_dict = [v.to_dict() for v in recognized_form]
228+
self.assertIsNone(recognized_form_dict[0].get("form_type_confidence"))
229+
self.assertIsNotNone(recognized_form_dict[0].get("model_id"))
230+
self.assertEqual(recognized_form_dict[0].get("form_type"), "form-0")
231+
227232
@FormRecognizerPreparer()
228233
@GlobalClientPreparer()
229234
async def test_custom_forms_multipage_unlabeled_transform(self, client, formrecognizer_multipage_storage_container_sas_url):

0 commit comments

Comments
 (0)