From fc4dabe50d1199f1cd9ac79e92339ba8827283ae Mon Sep 17 00:00:00 2001 From: happypeepo Date: Sat, 25 Oct 2025 23:09:39 +0530 Subject: [PATCH] Added type hints to decode functions and ReportUtils --- pytm/__init__.py | 9 ++++++--- pytm/json.py | 24 +++++++++++++----------- pytm/report_util.py | 9 +++++---- 3 files changed, 24 insertions(+), 18 deletions(-) diff --git a/pytm/__init__.py b/pytm/__init__.py index dbe277d6..f5bc7b63 100644 --- a/pytm/__init__.py +++ b/pytm/__init__.py @@ -24,6 +24,9 @@ ] import sys +from types import ModuleType +from typing import Any, Dict + from .json import load, loads from .pytm import ( @@ -51,8 +54,8 @@ ) -def pdoc_overrides(): - result = {"pytm": False, "json": False, "template_engine": False} +def pdoc_overrides() -> Dict[str, Any]: + result: Dict[str, Any] = {"pytm": False, "json": False, "template_engine": False} mod = sys.modules[__name__] for name, klass in mod.__dict__.items(): if not isinstance(klass, type): @@ -60,7 +63,7 @@ def pdoc_overrides(): for i in dir(klass): if i in ("check", "dfd", "seq"): result[f"{name}.{i}"] = False - attr = getattr(klass, i, {}) + attr: Any = getattr(klass, i, {}) if isinstance(attr, var) and attr.doc != "": result[f"{name}.{i}"] = attr.doc return result diff --git a/pytm/json.py b/pytm/json.py index db0c7af7..1e22253c 100644 --- a/pytm/json.py +++ b/pytm/json.py @@ -1,5 +1,7 @@ import json import sys +from typing import Any, TextIO, Dict, Union, List + from .pytm import ( TM, @@ -18,23 +20,23 @@ ) -def loads(s): +def loads(s: str) -> "TM": """Load a TM object from a JSON string *s*.""" - result = json.loads(s, object_hook=decode) + result: Any = json.loads(s, object_hook=decode) if not isinstance(result, TM): raise ValueError("Failed to decode JSON input as TM") return result -def load(fp): +def load(fp: TextIO) -> "TM": """Load a TM object from an open file containing JSON.""" - result = json.load(fp, object_hook=decode) + result: Any = json.load(fp, object_hook=decode) if not isinstance(result, TM): raise ValueError("Failed to decode JSON input as TM") return result -def decode(data): +def decode(data: Dict[str, Any]) -> Union[Dict[str, Any], TM]: if "elements" not in data and "flows" not in data and "boundaries" not in data: return data @@ -49,9 +51,9 @@ def decode(data): return TM(data.pop("name"), **data) -def decode_boundaries(flat): - boundaries = {} - refs = {} +def decode_boundaries(flat: List[Dict[str, Any]]) -> Dict[str, Boundary]: + boundaries: Dict[str, Boundary] = {} + refs: Dict[str, str] = {} for i, e in enumerate(flat): name = e.pop("name", None) if name is None: @@ -70,8 +72,8 @@ def decode_boundaries(flat): return boundaries -def decode_elements(flat, boundaries): - elements = {} +def decode_elements(flat: List[Dict[str, Any]], boundaries: Dict[str, Boundary]) -> Dict[str, Any]: + elements: Dict[str, Any] = {} for i, e in enumerate(flat): klass = getattr(sys.modules[__name__], e.pop("__class__", "Asset")) name = e.pop("name", None) @@ -89,7 +91,7 @@ def decode_elements(flat, boundaries): return elements -def decode_flows(flat, elements): +def decode_flows(flat: List[Dict[str, Any]], elements: Dict[str, Any]) -> None: for i, e in enumerate(flat): name = e.pop("name", None) if name is None: diff --git a/pytm/report_util.py b/pytm/report_util.py index 90df7de6..1fb24594 100644 --- a/pytm/report_util.py +++ b/pytm/report_util.py @@ -1,7 +1,8 @@ +from typing import Any, List, Union class ReportUtils: @staticmethod - def getParentName(element): + def getParentName(element: Any) -> str: from pytm import Boundary if (isinstance(element, Boundary)): parent = element.inBoundary @@ -14,7 +15,7 @@ def getParentName(element): @staticmethod - def getNamesOfParents(element): + def getNamesOfParents(element: Any) -> Union[List[str], str]: from pytm import Boundary if (isinstance(element, Boundary)): parents = [p.name for p in element.parents()] @@ -23,7 +24,7 @@ def getNamesOfParents(element): return "ERROR: getNamesOfParents method is not valid for " + element.__class__.__name__ @staticmethod - def getFindingCount(element): + def getFindingCount(element: Any) -> str: from pytm import Element if (isinstance(element, Element)): return str(len(list(element.findings))) @@ -31,7 +32,7 @@ def getFindingCount(element): return "ERROR: getFindingCount method is not valid for " + element.__class__.__name__ @staticmethod - def getElementType(element): + def getElementType(element: Any) -> str: from pytm import Element if (isinstance(element, Element)): return str(element.__class__.__name__)