Skip to content

Commit 4951963

Browse files
authored
Merge pull request #1 from SymbolicML/nested-types
Fix behavior for nested types
2 parents 8360881 + d895e37 commit 4951963

File tree

3 files changed

+83
-34
lines changed

3 files changed

+83
-34
lines changed

README.md

Lines changed: 20 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -13,20 +13,30 @@ extracts the base numeric type from a numeric type `T`:
1313
For example,
1414

1515
| Input Type | Output Type |
16-
|---|---|
16+
|:-:|---|
1717
| `Float32` | `Float32` |
1818
| `ComplexF32` | `Float32` |
1919
| `Measurement{Float32}` | `Float32` |
2020
| `Dual{BigFloat}` | `BigFloat` |
21+
| `Dual{ComplexF32}` | `Float32` |
2122
| `Rational{Int8}` | `Int8` |
22-
| `Quantity{Float32,Dimensions}` | `Float32` |
23+
| `Quantity{Float32, ...}` | `Float32` |
24+
| `Quantity{Measurement{Float32}, ...}` | `Float32` |
2325

24-
Packages should write a method to `base_numeric_type`
25-
when the base type of a numeric type
26-
is not the first parametric type.
27-
For example, if you were to create a quantity-like type
28-
`Quantity{Dimensions,NumericType}`, you would need
29-
to write a custom interface.
26+
Package maintainers should write a specialized method for their type.
27+
For example, to define the base numeric type for a dual number, one could write:
3028

31-
But if the base type comes first,
32-
the default method will work.
29+
```julia
30+
import BaseType: base_numeric_type
31+
32+
base_numeric_type(::Type{Dual{T}}) where {T} = base_numeric_type(T)
33+
```
34+
35+
It is important to call `base_numeric_type` recursively like this to deal with
36+
nested numeric types such as `Quantity{Measurement{T}}`.
37+
38+
The fallback behavior of `base_numeric_type` is to return the *first* type parameter,
39+
or, if that type has parameters of its own (such as `Dual{Complex{Float32}}`),
40+
to recursively take the first type parameter until a non-parameterized type is found.
41+
This works for the vast majority of types, but it is still preferred
42+
if package maintainers write a specialized method.

src/BaseType.jl

Lines changed: 22 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -12,18 +12,35 @@ as a measurement or a quantity.
1212
For example,
1313
1414
| Input Type | Output Type |
15-
|---|---|
15+
|:-:|---|
1616
| `Float32` | `Float32` |
1717
| `ComplexF32` | `Float32` |
1818
| `Measurement{Float32}` | `Float32` |
19-
| `Rational{Int8}` | `Int8` |
2019
| `Dual{BigFloat}` | `BigFloat` |
21-
| `Quantity{Float32,Dimensions}` | `Float32` |
20+
| `Rational{Int8}` | `Int8` |
21+
| `Quantity{Float32, ...}` | `Float32` |
22+
| `Quantity{Measurement{Float32}, ...}` | `Float32` |
23+
| `Dual{Complex{Float32}}` | `Float32` |
24+
25+
The standard behavior is to return the *first* type parameter,
26+
or, if that type has parameters of its own (such as `Dual{Complex{Float32}}`),
27+
to recursively take the first type parameter until a non-parameterized type is found.
2228
"""
2329
@generated function base_numeric_type(::Type{T}) where {T}
24-
params = T isa UnionAll ? T.body.parameters : T.parameters
25-
return isempty(params) ? :($T) : :($(first(params)))
30+
# This uses a generated function for type stability in Julia <=1.9,
31+
# though in Julia >=1.10 it is not necessary.
32+
# TODO: switch to non-generated when Julia >= 1.10 is LTS.
33+
return :($(_base_numeric_type(T)))
2634
end
2735
base_numeric_type(x) = base_numeric_type(typeof(x))
2836

37+
function _base_numeric_type(::Type{T}) where {T}
38+
params = T isa UnionAll ? T.body.parameters : T.parameters
39+
if isempty(params)
40+
return T
41+
else
42+
return _base_numeric_type(first(params))
43+
end
44+
end
45+
2946
end

test/unittests.jl

Lines changed: 41 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1,27 +1,49 @@
1-
using Test: @test, @inferred
1+
using Test: @test, @testset, @inferred
22
using BaseType: base_numeric_type
3-
using DualNumbers: DualNumbers
3+
using DualNumbers: DualNumbers, Dual
44
using DynamicQuantities: DynamicQuantities
55
using Measurements: ±
66
using Unitful: Unitful
77

8-
expected_type_pairs = [
9-
Float32 => Float32,
10-
ComplexF64 => Float64,
11-
DualNumbers.Dual{Int64} => Int64,
12-
DynamicQuantities.Quantity{Float32} => Float32,
13-
typeof(1.5DynamicQuantities.u"km/s") => Float64,
14-
typeof(1.5f0Unitful.u"km/s") => Float32,
15-
BigFloat => BigFloat,
16-
typeof(1.5 ± 0.2) => Float64,
17-
typeof(1.5f0 ± 0.2f0) => Float32,
18-
]
8+
@testset "Basic usage" begin
9+
expected_type_pairs = [
10+
Float32 => Float32,
11+
ComplexF64 => Float64,
12+
DualNumbers.Dual{Int64} => Int64,
13+
DynamicQuantities.Quantity{Float32} => Float32,
14+
typeof(1.5DynamicQuantities.u"km/s") => Float64,
15+
typeof(1.5f0Unitful.u"km/s") => Float32,
16+
BigFloat => BigFloat,
17+
typeof(1.5 ± 0.2) => Float64,
18+
typeof(1.5f0 ± 0.2f0) => Float32,
19+
]
1920

20-
for (x, y) in expected_type_pairs
21-
@eval @test base_numeric_type($x) == $y
22-
# Make sure compiler can inline it:
23-
@eval @inferred $y base_numeric_type($x)
21+
for (x, y) in expected_type_pairs
22+
@eval @test base_numeric_type($x) == $y
23+
# Make sure compiler can inline it:
24+
@eval @inferred $y base_numeric_type($x)
25+
end
26+
27+
@test base_numeric_type(1.5DynamicQuantities.u"km/s") == base_numeric_type(typeof(1.5DynamicQuantities.u"km/s"))
28+
@inferred base_numeric_type(1.5DynamicQuantities.u"km/s")
29+
end
30+
31+
@testset "Nested types" begin
32+
# Quantity ∘ Measurement:
33+
x = 5Unitful.u"m/s" ± 0.1Unitful.u"m/s"
34+
@test base_numeric_type(x) == Float64
35+
36+
# Quantity ∘ Dual:
37+
y = Dual(1.0)Unitful.u"m/s"
38+
@test base_numeric_type(y) == Float64
2439
end
2540

26-
@test base_numeric_type(1.5DynamicQuantities.u"km/s") == base_numeric_type(typeof(1.5DynamicQuantities.u"km/s"))
27-
@inferred base_numeric_type(1.5DynamicQuantities.u"km/s")
41+
struct Node{T}
42+
child::Union{Node{T},Nothing}
43+
value::T
44+
end
45+
46+
@testset "Safe default behavior for recursive types" begin
47+
c = Node{Int}(Node{Int}(nothing, 1), 2)
48+
@test base_numeric_type(c) == Int
49+
end

0 commit comments

Comments
 (0)