Skip to content

Commit 7719ef0

Browse files
committed
Finished scala tests
1 parent 7f24745 commit 7719ef0

File tree

8 files changed

+863
-12
lines changed

8 files changed

+863
-12
lines changed

src/main/scala/io/github/hexagonnico/cmplxlib/matrix/Mat2c.scala

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -90,7 +90,10 @@ case class Mat2c(m00: Complex, m01: Complex, m10: Complex, m11: Complex) extends
9090
* @param v The vector by which this matrix is multiplied
9191
* @return The product of this matrix by the given vector
9292
*/
93-
override def *(v: Vec2c): Vec2c = Vec2c(this.row0 dot v, this.row1 dot v)
93+
override def *(v: Vec2c): Vec2c = Vec2c(
94+
this.m00 * v.x + this.m01 * v.y,
95+
this.m10 * v.x + this.m11 * v.y
96+
)
9497

9598
/**
9699
* Returns the product of this matrix by the vector with the given components.
@@ -119,8 +122,8 @@ case class Mat2c(m00: Complex, m01: Complex, m10: Complex, m11: Complex) extends
119122
* @return The product between this matrix and the given one
120123
*/
121124
override def *(m: Mat2c): Mat2c = Mat2c(
122-
this.row0 dot m.col0, this.row0 dot m.col1,
123-
this.row1 dot m.col0, this.row1 dot m.col1
125+
this.m00 * m.m00 + this.m01 * m.m10, this.m00 * m.m01 + this.m01 * m.m11,
126+
this.m10 * m.m00 + this.m11 * m.m10, this.m10 * m.m01 + this.m11 * m.m11
124127
)
125128

126129
/**

src/main/scala/io/github/hexagonnico/cmplxlib/matrix/Mat3c.scala

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -117,7 +117,11 @@ case class Mat3c(
117117
* @param v The vector by which this matrix is multiplied
118118
* @return The product of this matrix by the given vector
119119
*/
120-
override def *(v: Vec3c): Vec3c = Vec3c(this.row0 dot v, this.row1 dot v, this.row2 dot v)
120+
override def *(v: Vec3c): Vec3c = Vec3c(
121+
this.m00 * v.x + this.m01 * v.y + this.m02 * v.z,
122+
this.m10 * v.x + this.m11 * v.y + this.m12 * v.z,
123+
this.m20 * v.x + this.m21 * v.y + this.m22 * v.z
124+
)
121125

122126
/**
123127
* Returns the product of this matrix by the vector with the given components.
@@ -148,9 +152,15 @@ case class Mat3c(
148152
* @return The product between this matrix and the given one
149153
*/
150154
override def *(m: Mat3c): Mat3c = Mat3c(
151-
this.row0 dot m.col0, this.row0 dot m.col1, this.row0 dot m.col2,
152-
this.row1 dot m.col0, this.row1 dot m.col1, this.row1 dot m.col2,
153-
this.row2 dot m.col0, this.row2 dot m.col1, this.row2 dot m.col2
155+
this.m00 * m.m00 + this.m01 * m.m10 + this.m02 * m.m20,
156+
this.m00 * m.m01 + this.m01 * m.m11 + this.m02 * m.m21,
157+
this.m00 * m.m02 + this.m01 * m.m12 + this.m02 * m.m22,
158+
this.m10 * m.m00 + this.m11 * m.m10 + this.m12 * m.m20,
159+
this.m10 * m.m01 + this.m11 * m.m11 + this.m12 * m.m21,
160+
this.m10 * m.m02 + this.m11 * m.m12 + this.m12 * m.m22,
161+
this.m20 * m.m00 + this.m21 * m.m10 + this.m22 * m.m20,
162+
this.m20 * m.m01 + this.m21 * m.m11 + this.m22 * m.m21,
163+
this.m20 * m.m02 + this.m21 * m.m12 + this.m22 * m.m22
154164
)
155165

156166
/**

src/main/scala/io/github/hexagonnico/cmplxlib/matrix/Mat4c.scala

Lines changed: 22 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -143,7 +143,12 @@ case class Mat4c(
143143
* @param v The vector by which this matrix is multiplied
144144
* @return The product of this matrix by the given vector
145145
*/
146-
override def *(v: Vec4c): Vec4c = Vec4c(this.row0 dot v, this.row1 dot v, this.row2 dot v, this.row3 dot v)
146+
override def *(v: Vec4c): Vec4c = Vec4c(
147+
this.m00 * v.x + this.m01 * v.y + this.m02 * v.z + this.m03 * v.w,
148+
this.m10 * v.x + this.m11 * v.y + this.m12 * v.z + this.m13 * v.w,
149+
this.m20 * v.x + this.m21 * v.y + this.m22 * v.z + this.m23 * v.w,
150+
this.m30 * v.x + this.m31 * v.y + this.m32 * v.z + this.m33 * v.w
151+
)
147152

148153
/**
149154
* Returns the product of this matrix by the vector with the given components.
@@ -176,10 +181,22 @@ case class Mat4c(
176181
* @return The product between this matrix and the given one
177182
*/
178183
override def *(m: Mat4c): Mat4c = Mat4c(
179-
this.row0 dot m.col0, this.row0 dot m.col1, this.row0 dot m.col2, this.row0 dot m.col3,
180-
this.row1 dot m.col0, this.row1 dot m.col1, this.row1 dot m.col2, this.row1 dot m.col3,
181-
this.row2 dot m.col0, this.row2 dot m.col1, this.row2 dot m.col2, this.row2 dot m.col3,
182-
this.row3 dot m.col0, this.row3 dot m.col1, this.row3 dot m.col2, this.row3 dot m.col3
184+
this.m00 * m.m00 + this.m01 * m.m10 + this.m02 * m.m20 + this.m03 * m.m30,
185+
this.m00 * m.m01 + this.m01 * m.m11 + this.m02 * m.m21 + this.m03 * m.m31,
186+
this.m00 * m.m02 + this.m01 * m.m12 + this.m02 * m.m22 + this.m03 * m.m32,
187+
this.m00 * m.m03 + this.m01 * m.m13 + this.m02 * m.m23 + this.m03 * m.m33,
188+
this.m10 * m.m00 + this.m11 * m.m10 + this.m12 * m.m20 + this.m13 * m.m30,
189+
this.m10 * m.m01 + this.m11 * m.m11 + this.m12 * m.m21 + this.m13 * m.m31,
190+
this.m10 * m.m02 + this.m11 * m.m12 + this.m12 * m.m22 + this.m13 * m.m32,
191+
this.m10 * m.m03 + this.m11 * m.m13 + this.m12 * m.m23 + this.m13 * m.m33,
192+
this.m20 * m.m00 + this.m21 * m.m10 + this.m22 * m.m20 + this.m23 * m.m30,
193+
this.m20 * m.m01 + this.m21 * m.m11 + this.m22 * m.m21 + this.m23 * m.m31,
194+
this.m20 * m.m02 + this.m21 * m.m12 + this.m22 * m.m22 + this.m23 * m.m32,
195+
this.m20 * m.m03 + this.m21 * m.m13 + this.m22 * m.m23 + this.m23 * m.m33,
196+
this.m30 * m.m00 + this.m31 * m.m10 + this.m32 * m.m20 + this.m33 * m.m30,
197+
this.m30 * m.m01 + this.m31 * m.m11 + this.m32 * m.m21 + this.m33 * m.m31,
198+
this.m30 * m.m02 + this.m31 * m.m12 + this.m32 * m.m22 + this.m33 * m.m32,
199+
this.m30 * m.m03 + this.m31 * m.m13 + this.m32 * m.m23 + this.m33 * m.m33
183200
)
184201

185202
/**
Lines changed: 160 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,160 @@
1+
package io.github.hexagonnico.cmplxlib.matrix
2+
3+
import io.github.hexagonnico.cmplxlib.Complex
4+
import io.github.hexagonnico.cmplxlib.vector.Vec2c
5+
import org.scalatest.funsuite.AnyFunSuite
6+
7+
class Mat2cSuite extends AnyFunSuite {
8+
9+
test("Sum of matrices") {
10+
val a = Mat2c(
11+
Complex(1.0, 2.0), Complex(1.5, 1.0),
12+
Complex(0.5, 1.5), Complex(3.0, 2.0)
13+
)
14+
val b = Mat2c(
15+
Complex(2.0, 2.5), Complex(3.0, 0.5),
16+
Complex(1.0, 1.5), Complex(4.0, 3.5)
17+
)
18+
assert(a + b == Mat2c(
19+
Complex(3.0, 4.5), Complex(4.5, 1.5),
20+
Complex(1.5, 3.0), Complex(7.0, 5.5)
21+
))
22+
}
23+
24+
test("Negative matrix") {
25+
assert(Mat2c(
26+
Complex(1.0, 2.0), Complex(1.5, 1.0),
27+
Complex(0.5, 1.5), Complex(3.0, 2.0)
28+
).negated == Mat2c(
29+
Complex(-1.0, -2.0), Complex(-1.5, -1.0),
30+
Complex(-0.5, -1.5), Complex(-3.0, -2.0)
31+
))
32+
}
33+
34+
test("Subtraction of matrices") {
35+
val a = Mat2c(
36+
Complex(1.0, 2.0), Complex(1.5, 1.0),
37+
Complex(0.5, 1.5), Complex(3.0, 2.0)
38+
)
39+
val b = Mat2c(
40+
Complex(2.0, 2.5), Complex(3.0, 0.5),
41+
Complex(1.0, 1.5), Complex(4.0, 3.5)
42+
)
43+
assert(a - b == Mat2c(
44+
Complex(-1.0, -0.5), Complex(-1.5, 0.5),
45+
Complex(-0.5, 0.0), Complex(-1.0, -1.5)
46+
))
47+
}
48+
49+
test("Matrix multiplied by a scalar") {
50+
assert(Mat2c(
51+
Complex(1.0, 2.0), Complex(1.5, 1.0),
52+
Complex(0.5, 1.5), Complex(3.0, 2.0)
53+
) * 1.5 == Mat2c(
54+
Complex(1.5, 3.0), Complex(2.25, 1.5),
55+
Complex(0.75, 2.25), Complex(4.5, 3.0)
56+
))
57+
}
58+
59+
test("Matrix multiplied by a scalar commutativity") {
60+
assert(1.5 * Mat2c(
61+
Complex(1.0, 2.0), Complex(1.5, 1.0),
62+
Complex(0.5, 1.5), Complex(3.0, 2.0)
63+
) == Mat2c(
64+
Complex(1.5, 3.0), Complex(2.25, 1.5),
65+
Complex(0.75, 2.25), Complex(4.5, 3.0)
66+
))
67+
}
68+
69+
test("Matrix-vector product") {
70+
val mat = Mat2c(
71+
Complex(1.0, 2.0), Complex(1.5, 1.0),
72+
Complex(0.5, 1.5), Complex(3.0, 2.0)
73+
)
74+
val vec = Vec2c(Complex(1.0, 1.0), Complex(2.0, 3.0))
75+
assert(mat * vec == Vec2c(Complex(-1.0, 9.5), Complex(-1.0, 15.0)))
76+
}
77+
78+
test("Matrix-vector product by values") {
79+
val mat = Mat2c(
80+
Complex(1.0, 2.0), Complex(1.5, 1.0),
81+
Complex(0.5, 1.5), Complex(3.0, 2.0)
82+
)
83+
assert(mat * (Complex(1.0, 1.0), Complex(2.0, 3.0)) == Vec2c(Complex(-1.0, 9.5), Complex(-1.0, 15.0)))
84+
}
85+
86+
test("Transposed") {
87+
assert(Mat2c(
88+
Complex(1.0, 2.0), Complex(1.5, 1.0),
89+
Complex(0.5, 1.5), Complex(3.0, 2.0)
90+
).transposed == Mat2c(
91+
Complex(1.0, 2.0), Complex(0.5, 1.5),
92+
Complex(1.5, 1.0), Complex(3.0, 2.0)
93+
))
94+
}
95+
96+
test("Symmetric matrix") {
97+
assert(Mat2c(
98+
Complex.Zero, Complex.One,
99+
Complex.One, Complex.Zero
100+
).isSymmetric)
101+
}
102+
103+
test("Skew symmetric matrix") {
104+
assert(Mat2c(
105+
Complex.Zero, Complex.One,
106+
-Complex.One, Complex.Zero
107+
).isSkewSymmetric)
108+
}
109+
110+
test("Complex conjugate of a matrix") {
111+
assert(Mat2c(
112+
Complex(1.0, 1.0), Complex(2.0, -2.0),
113+
Complex(-3.0, 3.0), Complex(4.0, 4.0)
114+
).conjugate == Mat2c(
115+
Complex(1.0, -1.0), Complex(2.0, 2.0),
116+
Complex(-3.0, -3.0), Complex(4.0, -4.0)
117+
))
118+
}
119+
120+
test("Conjugate transposed of a matrix") {
121+
assert(Mat2c(
122+
Complex(1.0, 1.0), Complex(2.0, -2.0),
123+
Complex(-3.0, 3.0), Complex(4.0, 4.0)
124+
).hermitian == Mat2c(
125+
Complex(1.0, -1.0), Complex(-3.0, -3.0),
126+
Complex(2.0, 2.0), Complex(4.0, -4.0)
127+
))
128+
}
129+
130+
test("Matrix product") {
131+
val a = Mat2c(
132+
Complex(1.0, 2.0), Complex(1.5, 1.0),
133+
Complex(0.5, 1.5), Complex(3.0, 2.0)
134+
)
135+
val b = Mat2c(
136+
Complex(2.0, 2.5), Complex(3.0, 0.5),
137+
Complex(1.0, 1.5), Complex(4.0, 3.5)
138+
)
139+
assert(a * b == Mat2c(
140+
Complex(-3.0, 9.75), Complex(4.5, 15.75),
141+
Complex(-2.75, 10.75), Complex(5.75, 23.25)
142+
))
143+
}
144+
145+
test("Matrix power") {
146+
val a = Mat2c(
147+
Complex(1.0, 2.0), Complex(1.5, 1.0),
148+
Complex(0.5, 1.5), Complex(3.0, 2.0)
149+
)
150+
assert((a power 3) == (a * a * a))
151+
}
152+
153+
test("Matrix determinant") {
154+
val a = Mat2c(
155+
Complex(1.0, 2.0), Complex(1.5, 1.0),
156+
Complex(0.5, 1.5), Complex(3.0, 2.0)
157+
)
158+
assert(a.determinant == Complex(-0.25, 5.25))
159+
}
160+
}

0 commit comments

Comments
 (0)