Skip to content

Commit 676b02a

Browse files
make QGF with irreducible_poly=None compatible with QGF with the correct poly (#1540)
1 parent 1d9ecc7 commit 676b02a

File tree

2 files changed

+25
-4
lines changed

2 files changed

+25
-4
lines changed

qualtran/_infra/data_types.py

Lines changed: 17 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@
5151
import abc
5252
from enum import Enum
5353
from functools import cached_property
54-
from typing import Any, Iterable, List, Optional, Sequence, TYPE_CHECKING, Union
54+
from typing import Any, Iterable, List, Literal, Optional, Sequence, TYPE_CHECKING, Union
5555

5656
import attrs
5757
import numpy as np
@@ -906,8 +906,19 @@ class QGF(QDType):
906906

907907
characteristic: SymbolicInt
908908
degree: SymbolicInt
909-
irreducible_poly: Optional['galois.Poly'] = None
910-
element_repr: str = 'int'
909+
irreducible_poly: Optional['galois.Poly'] = attrs.field()
910+
element_repr: Literal["int", "poly", "power"] = attrs.field(default='int')
911+
912+
@irreducible_poly.default
913+
def _irreducible_poly_default(self):
914+
if is_symbolic(self.characteristic, self.degree):
915+
return None
916+
917+
from galois import GF
918+
919+
return GF( # type: ignore[call-overload]
920+
int(self.characteristic), int(self.degree), compile='python-calculate'
921+
).irreducible_poly
911922

912923
@cached_property
913924
def order(self) -> SymbolicInt:
@@ -936,10 +947,12 @@ def _quint_equivalent(self) -> QUInt:
936947
def gf_type(self):
937948
from galois import GF
938949

950+
poly = self.irreducible_poly if self.degree > 1 else None
951+
939952
return GF( # type: ignore[call-overload]
940953
int(self.characteristic),
941954
int(self.degree),
942-
irreducible_poly=self.irreducible_poly,
955+
irreducible_poly=poly,
943956
repr=self.element_repr,
944957
compile='python-calculate',
945958
)

qualtran/_infra/data_types_test.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -540,3 +540,11 @@ def test_montgomery_bit_conversion(bitsize):
540540
dtype = QMontgomeryUInt(bitsize)
541541
for v in range(1 << bitsize):
542542
assert v == dtype.from_bits(dtype.to_bits(v))
543+
544+
545+
def test_qgf_with_default_poly_is_compatible():
546+
qgf_one = QGF(2, 4)
547+
548+
qgf_two = QGF(2, 4, irreducible_poly=qgf_one.gf_type.irreducible_poly)
549+
550+
assert qgf_one == qgf_two

0 commit comments

Comments
 (0)