|
1 | 1 | struct CoupledDofsRestriction{TM} <: AbstractRestriction |
2 | | - coupling_matrix::TM |
| 2 | + coupling_matrices::Vector{TM} |
3 | 3 | parameters::Dict{Symbol, Any} |
4 | 4 | end |
5 | 5 |
|
6 | 6 |
|
7 | 7 | """ |
8 | 8 | CoupledDofsRestriction(matrix::AbstractMatrix) |
9 | 9 |
|
10 | | - Creates an restriction that couples dofs together. |
| 10 | + Creates a restriction that couples dofs together. |
11 | 11 |
|
12 | 12 | The coupling is stored in a CSC Matrix `matrix`, s.t., |
13 | 13 |
|
|
16 | 16 | The matrix can be obtained from, e.g., `get_periodic_coupling_matrix`. |
17 | 17 | """ |
18 | 18 | function CoupledDofsRestriction(matrix::AbstractMatrix) |
19 | | - return CoupledDofsRestriction(matrix, Dict{Symbol, Any}(:name => "CoupledDofsRestriction")) |
| 19 | + return CoupledDofsRestriction([matrix], Dict{Symbol, Any}(:name => "CoupledDofsRestriction")) |
20 | 20 | end |
21 | 21 |
|
22 | 22 |
|
23 | | -function assemble!(R::CoupledDofsRestriction, sol, SC; kwargs...) |
| 23 | +""" |
| 24 | + CoupledDofsRestriction(matrices::Vector{AM}) where {AM <: AbstractMatrix} |
24 | 25 |
|
25 | | - # extract all col indices |
26 | | - _, J, _ = findnz(R.coupling_matrix) |
| 26 | + Creates a `CoupledDofsRestriction` from multiple given coupling matrices. |
| 27 | +""" |
| 28 | +function CoupledDofsRestriction(matrices::Vector{AM}) where {AM <: AbstractMatrix} |
| 29 | + return CoupledDofsRestriction(matrices, Dict{Symbol, Any}(:name => "CoupledDofsRestriction")) |
| 30 | +end |
27 | 31 |
|
28 | | - # remove duplicates |
29 | | - unique_cols = unique(J) |
30 | 32 |
|
| 33 | +function assemble!(R::CoupledDofsRestriction, sol, SC; kwargs...) |
| 34 | + |
| 35 | + # extract all col indices and remove duplicates |
31 | 36 | # subtract diagonal and shrink matrix to non-empty cols |
32 | | - B = (R.coupling_matrix - LinearAlgebra.I)[:, unique_cols] |
| 37 | + Bs = [ (matrix - LinearAlgebra.I)[:, unique(findnz(matrix)[2])] for matrix in R.coupling_matrices ] |
| 38 | + |
| 39 | + # combine all into one matrix |
| 40 | + B = hcat(Bs...) |
| 41 | + |
| 42 | + # eliminate redundant cols by QR: |
| 43 | + qr_result = qr(B) |
| 44 | + |
| 45 | + # pick minimal number of cols that are rank preserving |
| 46 | + cols_of_interest = qr_result.pcol[1:rank(qr_result)] |
| 47 | + B = B[:, cols_of_interest] |
33 | 48 |
|
34 | 49 | R.parameters[:matrix] = B |
35 | | - R.parameters[:rhs] = Zeros(length(unique_cols)) |
| 50 | + R.parameters[:rhs] = Zeros(size(B, 2)) |
36 | 51 |
|
37 | 52 | # fixed dofs are all active rows of B |
38 | | - I, _, _ = findnz(B) |
39 | | - R.parameters[:fixed_dofs] = unique(I) |
| 53 | + R.parameters[:fixed_dofs] = unique(findnz(B)[1]) |
40 | 54 |
|
41 | 55 | return nothing |
42 | 56 | end |
0 commit comments