Skip to content

Commit d0c24e7

Browse files
committed
Ensure schema typing and datetime parsing lint compliance
1 parent c87bbcb commit d0c24e7

File tree

1 file changed

+20
-11
lines changed

1 file changed

+20
-11
lines changed

slurm_usage.py

Lines changed: 20 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
import os
2525
import re
2626
import subprocess
27+
import types
2728
import typing
2829
from collections import defaultdict
2930
from concurrent.futures import ThreadPoolExecutor, as_completed
@@ -678,7 +679,7 @@ def to_dict(self) -> dict[str, Any]:
678679
@classmethod
679680
def get_polars_schema(cls) -> dict[str, pl.DataType]:
680681
"""Get Polars schema derived from Pydantic model fields."""
681-
mapping: dict[str, pl.DataType] = {
682+
mapping: dict[type[Any], pl.DataType] = {
682683
str: pl.Utf8,
683684
int: pl.Int64,
684685
float: pl.Float64,
@@ -687,18 +688,24 @@ def get_polars_schema(cls) -> dict[str, pl.DataType]:
687688
_DATETIME_TYPE: pl.Datetime("us", "UTC"),
688689
}
689690

690-
schema = {}
691+
schema: dict[str, pl.DataType] = {}
691692
for field_name, field_info in cls.model_fields.items():
692693
annotation = field_info.annotation
693694

694695
# Handle Optional types (Union[T, None] or T | None)
695696
origin = typing.get_origin(annotation)
696-
if origin is typing.Union or origin is type(None | int):
697+
if origin in (typing.Union, types.UnionType):
697698
args = typing.get_args(annotation)
698-
# Get the non-None type
699-
annotation = next((arg for arg in args if arg is not type(None)), str)
700-
# Map Python types to Polars types
701-
schema[field_name] = mapping.get(annotation, pl.Utf8)
699+
non_none_args = [arg for arg in args if arg is not type(None)]
700+
if non_none_args:
701+
annotation = non_none_args[0]
702+
703+
mapped_type: pl.DataType | None = None
704+
if isinstance(annotation, type):
705+
mapped_type = mapping.get(annotation)
706+
707+
# Map Python types to Polars types (default to Utf8 for unknown types)
708+
schema[field_name] = mapped_type or pl.Utf8
702709

703710
return schema
704711

@@ -811,13 +818,15 @@ def _parse_datetime(date_str: str | None) -> datetime | None:
811818
try:
812819
# SLURM uses ISO format: 2025-08-19T10:30:00
813820
dt = datetime.fromisoformat(date_str)
814-
# Ensure timezone-aware (assume UTC if naive)
815-
if dt.tzinfo is None:
816-
dt = dt.replace(tzinfo=UTC)
817-
return dt
818821
except (ValueError, AttributeError):
819822
return None
820823

824+
# Ensure timezone-aware (assume UTC if naive)
825+
if dt.tzinfo is None:
826+
dt = dt.replace(tzinfo=UTC)
827+
828+
return dt
829+
821830

822831
def _parse_gpu_count(alloc_tres: str) -> int:
823832
"""Parse GPU count from AllocTRES string.

0 commit comments

Comments
 (0)