Skip to content

Commit dc6da7d

Browse files
committed
Cleaned up TCP network sessions. Changed various tests accordingly and reduced code duplication.
1 parent 904c68b commit dc6da7d

File tree

10 files changed

+205
-172
lines changed

10 files changed

+205
-172
lines changed

Application/Program.cs

Lines changed: 7 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -5,9 +5,8 @@
55
using System.Text;
66
using System.Reflection;
77
using System.Diagnostics;
8+
using System.Net;
89

9-
using CompactMPC.Circuits;
10-
using CompactMPC.Circuits.Statistics;
1110
using CompactMPC.Networking;
1211
using CompactMPC.Protocol;
1312
using CompactMPC.ObliviousTransfer;
@@ -29,43 +28,6 @@ public static void Main(string[] args)
2928
BitArray.FromBinaryString("1101100010"),
3029
BitArray.FromBinaryString("0111110011")
3130
};
32-
33-
CircuitBuilder builder = new CircuitBuilder();
34-
(new SetIntersectionCircuitRecorder(NumberOfParties, NumberOfElements)).Record(builder);
35-
36-
Circuit circuit = builder.CreateCircuit();
37-
38-
CircuitStatistics statistics = CircuitStatistics.FromCircuit(circuit);
39-
40-
Console.WriteLine("--- Circuit Statistics ---");
41-
Console.WriteLine("Number of inputs: {0}", statistics.NumberOfInputs);
42-
Console.WriteLine("Number of outputs: {0}", statistics.NumberOfOutputs);
43-
Console.WriteLine("Number of ANDs: {0}", circuit.Context.NumberOfAndGates);
44-
Console.WriteLine("Number of XORs: {0}", circuit.Context.NumberOfXorGates);
45-
Console.WriteLine("Number of NOTs: {0}", circuit.Context.NumberOfNotGates);
46-
Console.WriteLine(" Total linear: {0}", circuit.Context.NumberOfXorGates + circuit.Context.NumberOfNotGates);
47-
Console.WriteLine("Multiplicative depth: {0}", statistics.Layers.Count);
48-
49-
int totalNonlinearGates = 0;
50-
int totalLinearGates = 0;
51-
52-
for (int i = 0; i < statistics.Layers.Count; ++i)
53-
{
54-
totalNonlinearGates += statistics.Layers[i].NumberOfNonlinearGates;
55-
totalLinearGates += statistics.Layers[i].NumberOfLinearGates;
56-
Console.WriteLine(" Layer {0}: {1} nonlinear / {2} linear", i, statistics.Layers[i].NumberOfNonlinearGates, statistics.Layers[i].NumberOfLinearGates);
57-
}
58-
59-
Console.WriteLine("Number of nonlinear gates in layers: {0}", totalNonlinearGates);
60-
Console.WriteLine("Number of linear gates in layers: {0}", totalLinearGates);
61-
62-
Bit[] result;
63-
64-
result = circuit.Evaluate(new LocalCircuitEvaluator(), inputs.SelectMany(input => input).ToArray());
65-
Console.WriteLine("Result (normal): {0}", new BitArray(result).ToBinaryString());
66-
67-
result = new Circuits.Batching.ForwardCircuit(circuit).Evaluate(new LocalCircuitEvaluator(), inputs.SelectMany(input => input).ToArray());
68-
Console.WriteLine("Result (forward): {0}", new BitArray(result).ToBinaryString());
6931

7032
if (args.Length == 0)
7133
{
@@ -97,7 +59,7 @@ public static void Main(string[] args)
9759

9860
private static void RunSecureComputationParty(int localPartyId, BitArray localInput)
9961
{
100-
using (TcpMultiPartyNetworkSession session = new TcpMultiPartyNetworkSession(StartPort, NumberOfParties, localPartyId))
62+
using (IMultiPartyNetworkSession session = CreateLocalSession(localPartyId, StartPort, NumberOfParties))
10163
{
10264
using (CryptoContext cryptoContext = CryptoContext.CreateDefault())
10365
{
@@ -130,5 +92,10 @@ private static void RunSecureComputationParty(int localPartyId, BitArray localIn
13092
}
13193
}
13294
}
95+
96+
private static IMultiPartyNetworkSession CreateLocalSession(int localPartyId, int startPort, int numberOfParties)
97+
{
98+
return TcpMultiPartyNetworkSession.EstablishAsync(new Party(localPartyId), IPAddress.Loopback, StartPort, NumberOfParties).Result;
99+
}
133100
}
134101
}
Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
using System;
2+
using System.Collections.Generic;
3+
using System.Text;
4+
5+
namespace CompactMPC.Networking
6+
{
7+
public class NetworkConsistencyException : Exception
8+
{
9+
public NetworkConsistencyException(string message) : base(message) { }
10+
public NetworkConsistencyException(string message, Exception innerException) : base(message, innerException) { }
11+
}
12+
}

CompactMPC/Networking/Party.cs

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,12 +11,40 @@ public class Party
1111
private int _id;
1212
private string _name;
1313

14+
public Party(int id)
15+
{
16+
_id = id;
17+
_name = "Party " + (id + 1);
18+
}
19+
1420
public Party(int id, string name)
1521
{
1622
_id = id;
1723
_name = name;
1824
}
1925

26+
public override string ToString()
27+
{
28+
return String.Format("{0} (id: {1})", _name, _id);
29+
}
30+
31+
public override bool Equals(object other)
32+
{
33+
Party otherParty = other as Party;
34+
if (otherParty != null)
35+
return _id == otherParty.Id && _name == otherParty.Name;
36+
37+
return false;
38+
}
39+
40+
public override int GetHashCode()
41+
{
42+
int hashCode = 321773176;
43+
hashCode = hashCode * -1521134295 + _id.GetHashCode();
44+
hashCode = hashCode * -1521134295 + _name.GetHashCode();
45+
return hashCode;
46+
}
47+
2048
public int Id
2149
{
2250
get

CompactMPC/Networking/TcpMultiPartyNetworkSession.cs

Lines changed: 19 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -12,39 +12,41 @@ namespace CompactMPC.Networking
1212
{
1313
public class TcpMultiPartyNetworkSession : IMultiPartyNetworkSession
1414
{
15-
private List<ITwoPartyNetworkSession> _remotePartySessions;
15+
private ITwoPartyNetworkSession[] _remotePartySessions;
1616
private Party _localParty;
17-
private int _numberOfParties;
1817

19-
public TcpMultiPartyNetworkSession(int startPort, int numberOfParties, int localPartyId)
18+
private TcpMultiPartyNetworkSession(Party localParty, ITwoPartyNetworkSession[] remotePartySessions)
2019
{
21-
_remotePartySessions = new List<ITwoPartyNetworkSession>(numberOfParties - 1);
22-
_localParty = new Party(localPartyId, "Party " + (localPartyId + 1));
23-
_numberOfParties = numberOfParties;
20+
_remotePartySessions = remotePartySessions;
21+
_localParty = localParty;
22+
}
23+
24+
public static async Task<TcpMultiPartyNetworkSession> EstablishAsync(Party localParty, IPAddress address, int startPort, int numberOfParties)
25+
{
26+
TcpTwoPartyNetworkSession[] remotePartySessions = new TcpTwoPartyNetworkSession[numberOfParties - 1];
2427

25-
for (int i = 0; i < localPartyId; ++i)
28+
for (int i = 0; i < localParty.Id; ++i)
2629
{
27-
TcpClient client = new TcpClient();
28-
client.ConnectAsync("127.0.0.1", startPort + i).Wait();
29-
_remotePartySessions.Add(new TcpTwoPartyNetworkSession(client, _localParty));
30+
remotePartySessions[i] = await TcpTwoPartyNetworkSession.ConnectAsync(localParty, address, startPort + i);
3031
}
3132

32-
TcpListener listener = new TcpListener(IPAddress.Any, startPort + localPartyId) { ExclusiveAddressUse = true };
33+
TcpListener listener = new TcpListener(IPAddress.Any, startPort + localParty.Id) { ExclusiveAddressUse = true };
3334
listener.Start();
3435

35-
for (int j = localPartyId + 1; j < numberOfParties; ++j)
36+
for (int j = localParty.Id + 1; j < numberOfParties; ++j)
3637
{
37-
TcpClient client = listener.AcceptTcpClientAsync().Result;
38-
_remotePartySessions.Add(new TcpTwoPartyNetworkSession(client, _localParty));
38+
remotePartySessions[j - 1] = await TcpTwoPartyNetworkSession.AcceptAsync(localParty, listener);
3939
}
4040

4141
listener.Stop();
4242

4343
for (int i = 0; i < numberOfParties; ++i)
4444
{
45-
if (i != localPartyId && !_remotePartySessions.Any(session => session.RemoteParty.Id == i))
46-
throw new InvalidDataException("Establishing connections was unsuccessful.");
45+
if (i != localParty.Id && !remotePartySessions.Any(session => session.RemoteParty.Id == i))
46+
throw new NetworkConsistencyException("Inconsistent TCP connection graph.");
4747
}
48+
49+
return new TcpMultiPartyNetworkSession(localParty, remotePartySessions);
4850
}
4951

5052
public void Dispose()
@@ -73,7 +75,7 @@ public int NumberOfParties
7375
{
7476
get
7577
{
76-
return _numberOfParties;
78+
return _remotePartySessions.Length + 1;
7779
}
7880
}
7981
}

CompactMPC/Networking/TcpTwoPartyNetworkSession.cs

Lines changed: 42 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -9,65 +9,70 @@
99

1010
namespace CompactMPC.Networking
1111
{
12-
public class TcpTwoPartyNetworkSession : ITwoPartyNetworkSession, IMultiPartyNetworkSession
12+
public class TcpTwoPartyNetworkSession : ITwoPartyNetworkSession
1313
{
1414
private TcpClient _client;
1515
private IMessageChannel _channel;
1616
private Party _localParty;
1717
private Party _remoteParty;
1818

19-
public TcpTwoPartyNetworkSession(TcpClient client, Party localParty)
19+
private TcpTwoPartyNetworkSession(TcpClient client, Party localParty, Party remoteParty)
2020
{
21-
Stream stream = client.GetStream();
22-
stream.Write(new byte[] { (byte)localParty.Id }, 0, 1);
23-
int remotePartyId = stream.Read(1)[0];
24-
2521
_client = client;
26-
_channel = new StreamMessageChannel(stream);
22+
_channel = new StreamMessageChannel(client.GetStream());
2723
_localParty = localParty;
28-
_remoteParty = new Party(remotePartyId, "Party " + (remotePartyId + 1));
24+
_remoteParty = remoteParty;
2925
}
3026

31-
public static TcpTwoPartyNetworkSession FromPort(int port)
27+
public static async Task<TcpTwoPartyNetworkSession> ConnectAsync(Party localParty, IPAddress address, int port)
3228
{
33-
TcpClient client;
34-
Party localParty;
29+
TcpClient client = new TcpClient();
30+
await client.ConnectAsync(address, port);
31+
return CreateFromPartyInformationExchange(localParty, client);
32+
}
3533

36-
try
37-
{
38-
Console.WriteLine("Starting TCP server...");
39-
client = AcceptTcpClient(port);
40-
Console.WriteLine("TCP server started.");
34+
public static async Task<TcpTwoPartyNetworkSession> AcceptAsync(Party localParty, int port)
35+
{
36+
TcpTwoPartyNetworkSession session;
4137

42-
localParty = new Party(0, "Party 0");
43-
}
44-
catch (Exception)
45-
{
46-
Console.WriteLine("Starting TCP server failed, starting client...");
47-
client = ConnectTcpClient(port);
48-
Console.WriteLine("TCP client started.");
38+
TcpListener listener = new TcpListener(IPAddress.Any, port) { ExclusiveAddressUse = true };
39+
listener.Start();
40+
session = await AcceptAsync(localParty, listener);
41+
listener.Stop();
4942

50-
localParty = new Party(1, "Bob");
51-
}
43+
return session;
44+
}
5245

53-
return new TcpTwoPartyNetworkSession(client, localParty);
46+
public static async Task<TcpTwoPartyNetworkSession> AcceptAsync(Party localParty, TcpListener listener)
47+
{
48+
TcpClient client = await listener.AcceptTcpClientAsync();
49+
return CreateFromPartyInformationExchange(localParty, client);
5450
}
5551

56-
private static TcpClient AcceptTcpClient(int port)
52+
private static TcpTwoPartyNetworkSession CreateFromPartyInformationExchange(Party localParty, TcpClient client)
5753
{
58-
TcpListener listener = new TcpListener(IPAddress.Any, port) { ExclusiveAddressUse = true };
59-
listener.Start();
60-
Task<TcpClient> acceptTask = listener.AcceptTcpClientAsync();
61-
acceptTask.Wait();
62-
listener.Stop();
63-
return acceptTask.Result;
54+
WritePartyInformation(client, localParty);
55+
Party remoteParty = ReadPartyInformation(client);
56+
return new TcpTwoPartyNetworkSession(client, localParty, remoteParty);
6457
}
6558

66-
private static TcpClient ConnectTcpClient(int port)
59+
private static void WritePartyInformation(TcpClient client, Party party)
6760
{
68-
TcpClient client = new TcpClient();
69-
client.ConnectAsync("127.0.0.1", port).Wait();
70-
return client;
61+
using (BinaryWriter writer = new BinaryWriter(client.GetStream(), Encoding.UTF8, true))
62+
{
63+
writer.Write(party.Id);
64+
writer.Write(party.Name);
65+
}
66+
}
67+
68+
private static Party ReadPartyInformation(TcpClient client)
69+
{
70+
using (BinaryReader reader = new BinaryReader(client.GetStream(), Encoding.UTF8, true))
71+
{
72+
int id = reader.ReadInt32();
73+
string name = reader.ReadString();
74+
return new Party(id, name);
75+
}
7176
}
7277

7378
public void Dispose()
@@ -98,21 +103,5 @@ public Party RemoteParty
98103
return _remoteParty;
99104
}
100105
}
101-
102-
public IEnumerable<ITwoPartyNetworkSession> RemotePartySessions
103-
{
104-
get
105-
{
106-
yield return this;
107-
}
108-
}
109-
110-
public int NumberOfParties
111-
{
112-
get
113-
{
114-
return 2;
115-
}
116-
}
117106
}
118107
}

UnitTests/CircuitTest.cs

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -36,15 +36,17 @@ public void TestCircuitEvaluation()
3636
BitArray forwardEvaluationOutput = new BitArray(new ForwardCircuit(circuit).Evaluate(new LocalCircuitEvaluator(), sequentialInput));
3737
BitArray expectedEvaluationOutput = BitArray.FromBinaryString("01010000101100");
3838

39-
Assert.IsTrue(
40-
Enumerable.SequenceEqual(expectedEvaluationOutput, lazyEvaluationOutput),
39+
CollectionAssert.AreEqual(
40+
expectedEvaluationOutput,
41+
lazyEvaluationOutput,
4142
"Incorrect lazy evaluation output {0} (should be {1}).",
4243
lazyEvaluationOutput.ToBinaryString(),
4344
expectedEvaluationOutput.ToBinaryString()
4445
);
4546

46-
Assert.IsTrue(
47-
Enumerable.SequenceEqual(expectedEvaluationOutput, forwardEvaluationOutput),
47+
CollectionAssert.AreEqual(
48+
expectedEvaluationOutput,
49+
forwardEvaluationOutput,
4850
"Incorrect forward evaluation output {0} (should be {1}).",
4951
forwardEvaluationOutput.ToBinaryString(),
5052
expectedEvaluationOutput.ToBinaryString()

0 commit comments

Comments
 (0)