summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorSam Anthony <sam@samanthony.xyz>2024-11-23 11:34:42 -0500
committerSam Anthony <sam@samanthony.xyz>2024-11-23 11:34:42 -0500
commitd5a1ec8b54c1c3c516d07f1916276cd6e5a937e4 (patch)
tree34f4ed7975803f573d16a7215ae39a9b2791a9b9
parente3df4a078afd37314d330daa2de0883f8dd1811b (diff)
downloadsoen423-d5a1ec8b54c1c3c516d07f1916276cd6e5a937e4.zip
runicast: use DatagramChannel
-rw-r--r--src/main/java/derms/io/Serial.java24
-rw-r--r--src/main/java/derms/net/rmulticast/ReliableMulticast.java2
-rw-r--r--src/main/java/derms/net/runicast/Receive.java32
-rw-r--r--src/main/java/derms/net/runicast/ReceiveAcks.java27
-rw-r--r--src/main/java/derms/net/runicast/ReliableUnicastReceiver.java14
-rw-r--r--src/main/java/derms/net/runicast/ReliableUnicastSender.java36
-rw-r--r--src/main/java/derms/net/runicast/Retransmit.java21
-rw-r--r--src/main/java/derms/util/ThreadPool.java21
8 files changed, 110 insertions, 67 deletions
diff --git a/src/main/java/derms/io/Serial.java b/src/main/java/derms/io/Serial.java
new file mode 100644
index 0000000..b5b2299
--- /dev/null
+++ b/src/main/java/derms/io/Serial.java
@@ -0,0 +1,24 @@
+package derms.io;
+
+import java.io.*;
+import java.nio.ByteBuffer;
+
+public class Serial {
+ public static ByteBuffer encode(Serializable obj) throws IOException {
+ ByteArrayOutputStream baos = new ByteArrayOutputStream();
+ ObjectOutputStream oos = new ObjectOutputStream(baos);
+ oos.writeObject(obj);
+ oos.flush();
+ ByteBuffer buf = ByteBuffer.wrap(baos.toByteArray());
+ oos.close();
+ return buf;
+ }
+
+ public static <T extends Serializable> T decode(ByteBuffer buf, Class<T> clazz) throws IOException, ClassNotFoundException {
+ ObjectInputStream ois = new ObjectInputStream(
+ new ByteArrayInputStream(buf.array()));
+ T obj = clazz.cast(ois.readObject());
+ ois.close();
+ return obj;
+ }
+}
diff --git a/src/main/java/derms/net/rmulticast/ReliableMulticast.java b/src/main/java/derms/net/rmulticast/ReliableMulticast.java
index 3894021..c23baaa 100644
--- a/src/main/java/derms/net/rmulticast/ReliableMulticast.java
+++ b/src/main/java/derms/net/rmulticast/ReliableMulticast.java
@@ -63,7 +63,7 @@ public class ReliableMulticast<T extends MessagePayload> {
public void close() {
log.info("Shutting down...");
- ThreadPool.shutDown(pool, log);
+ ThreadPool.shutdownNow(pool, log);
outSock.close();
log.info("Finished shutting down.");
}
diff --git a/src/main/java/derms/net/runicast/Receive.java b/src/main/java/derms/net/runicast/Receive.java
index 584861b..8620ebd 100644
--- a/src/main/java/derms/net/runicast/Receive.java
+++ b/src/main/java/derms/net/runicast/Receive.java
@@ -1,11 +1,13 @@
package derms.net.runicast;
-import derms.net.ConcurrentDatagramSocket;
+import derms.io.Serial;
import derms.net.MessagePayload;
-import derms.net.Packet;
import java.io.IOException;
import java.net.*;
+import java.nio.ByteBuffer;
+import java.nio.channels.ClosedChannelException;
+import java.nio.channels.DatagramChannel;
import java.util.Queue;
import java.util.logging.Logger;
@@ -13,11 +15,11 @@ class Receive<T extends MessagePayload> implements Runnable {
private static final int bufSize = 8192;
private long seq; // Sequence number.
- private final ConcurrentDatagramSocket sock;
+ private final DatagramChannel sock;
private final Queue<T> delivered;
private final Logger log;
- Receive(ConcurrentDatagramSocket sock, Queue<T> delivered) {
+ Receive(DatagramChannel sock, Queue<T> delivered) {
this.seq = 0;
this.sock = sock;
this.delivered = delivered;
@@ -26,18 +28,15 @@ class Receive<T extends MessagePayload> implements Runnable {
@Override
public void run() {
- DatagramPacket pkt = new DatagramPacket(new byte[bufSize], bufSize);
for (;;) {
+ ByteBuffer buf = ByteBuffer.allocate(bufSize);
try {
- sock.receive(pkt);
- Message<T> msg = (Message<T>) Packet.decode(pkt, Message.class);
- SocketAddress sender = pkt.getSocketAddress();
+ SocketAddress sender = sock.receive(buf);
+ Message<T> msg = (Message<T>) Serial.decode(buf, Message.class);
recv(msg, sender);
- } catch (SocketTimeoutException e) {
- if (Thread.interrupted()) {
- log.info("Interrupted");
- return;
- }
+ } catch (ClosedChannelException e) {
+ log.info("Shutting down.");
+ return;
} catch (IOException | ClassNotFoundException | ClassCastException e) {
log.warning(e.getMessage());
}
@@ -45,16 +44,19 @@ class Receive<T extends MessagePayload> implements Runnable {
}
private void recv(Message<T> msg, SocketAddress sender) throws IOException {
+ log.info("Received " + msg);
if (msg.seq == seq) {
delivered.add(msg.payload);
+ log.info("Delivered " + msg);
ack(msg, sender);
+ log.info("Acked " + msg);
seq++;
}
}
private void ack(Message<T> msg, SocketAddress sender) throws IOException {
Ack ack = new Ack(msg.seq);
- DatagramPacket pkt = Packet.encode(ack, sender);
- sock.send(pkt);
+ ByteBuffer buf = Serial.encode(ack);
+ sock.send(buf, sender);
}
}
diff --git a/src/main/java/derms/net/runicast/ReceiveAcks.java b/src/main/java/derms/net/runicast/ReceiveAcks.java
index 0f585ff..9d7b7de 100644
--- a/src/main/java/derms/net/runicast/ReceiveAcks.java
+++ b/src/main/java/derms/net/runicast/ReceiveAcks.java
@@ -1,12 +1,12 @@
package derms.net.runicast;
-import derms.net.ConcurrentDatagramSocket;
+import derms.io.Serial;
import derms.net.MessagePayload;
-import derms.net.Packet;
import java.io.IOException;
-import java.net.DatagramPacket;
-import java.net.SocketTimeoutException;
+import java.nio.ByteBuffer;
+import java.nio.channels.ClosedChannelException;
+import java.nio.channels.DatagramChannel;
import java.util.Queue;
import java.util.concurrent.atomic.AtomicLong;
import java.util.logging.Logger;
@@ -17,10 +17,10 @@ class ReceiveAcks<T extends MessagePayload> implements Runnable {
private final AtomicLong unacked;
private final Queue<Message<T>> sent;
- private final ConcurrentDatagramSocket sock;
+ private final DatagramChannel sock;
private final Logger log;
- ReceiveAcks(AtomicLong unacked, Queue<Message<T>> sent, ConcurrentDatagramSocket sock) {
+ ReceiveAcks(AtomicLong unacked, Queue<Message<T>> sent, DatagramChannel sock) {
this.unacked = unacked;
this.sent = sent;
this.sock = sock;
@@ -29,17 +29,15 @@ class ReceiveAcks<T extends MessagePayload> implements Runnable {
@Override
public void run() {
- DatagramPacket pkt = new DatagramPacket(new byte[bufSize], bufSize);
for (;;) {
+ ByteBuffer buf = ByteBuffer.allocate(bufSize);
try {
- sock.receive(pkt);
- Ack ack = Packet.decode(pkt, Ack.class);
+ sock.receive(buf);
+ Ack ack = Serial.decode(buf, Ack.class);
recvAck(ack.seq);
- } catch (SocketTimeoutException e) {
- if (Thread.interrupted()) {
- log.info("Interrupted.");
- return;
- }
+ } catch (ClosedChannelException e) {
+ log.info("Shutting down.");
+ return;
} catch (IOException | ClassNotFoundException | ClassCastException e) {
log.warning(e.getMessage());
}
@@ -47,6 +45,7 @@ class ReceiveAcks<T extends MessagePayload> implements Runnable {
}
private void recvAck(long ack) {
+ log.info("Received ack: " + ack);
unacked.updateAndGet((unacked) -> {
if (ack >= unacked)
return ack+1;
diff --git a/src/main/java/derms/net/runicast/ReliableUnicastReceiver.java b/src/main/java/derms/net/runicast/ReliableUnicastReceiver.java
index 81e3502..ff0b72a 100644
--- a/src/main/java/derms/net/runicast/ReliableUnicastReceiver.java
+++ b/src/main/java/derms/net/runicast/ReliableUnicastReceiver.java
@@ -1,20 +1,18 @@
package derms.net.runicast;
-import derms.net.ConcurrentDatagramSocket;
import derms.net.MessagePayload;
import derms.util.ThreadPool;
import java.io.IOException;
import java.net.SocketAddress;
+import java.nio.channels.DatagramChannel;
import java.time.Duration;
import java.util.concurrent.*;
import java.util.logging.Logger;
/** The receiving end of a reliable unicast connection. */
public class ReliableUnicastReceiver<T extends MessagePayload> {
- private static final Duration soTimeout = Duration.ofMillis(500); // Socket timeout.
-
- private final ConcurrentDatagramSocket sock;
+ private final DatagramChannel sock;
private final BlockingQueue<T> delivered;
private final Logger log;
private final ExecutorService pool;
@@ -25,18 +23,18 @@ public class ReliableUnicastReceiver<T extends MessagePayload> {
* @param laddr The local IP address and port to listen on.
*/
public ReliableUnicastReceiver(SocketAddress laddr) throws IOException {
- this.sock = new ConcurrentDatagramSocket(laddr);
- this.sock.setSoTimeout(soTimeout);
+ this.sock = DatagramChannel.open();
+ sock.bind(laddr);
this.delivered = new LinkedBlockingQueue<T>();
this.log = Logger.getLogger(getClass().getName());
this.pool = Executors.newCachedThreadPool();
pool.execute(new Receive<T>(sock, delivered));
}
- public void close() {
+ public void close() throws IOException {
log.info("Shutting down");
- ThreadPool.shutDown(pool, log);
sock.close();
+ ThreadPool.shutdown(pool, log);
}
/** Receive a message, blocking if necessary until one arrives. */
diff --git a/src/main/java/derms/net/runicast/ReliableUnicastSender.java b/src/main/java/derms/net/runicast/ReliableUnicastSender.java
index 83408d5..1f3c5d4 100644
--- a/src/main/java/derms/net/runicast/ReliableUnicastSender.java
+++ b/src/main/java/derms/net/runicast/ReliableUnicastSender.java
@@ -1,14 +1,13 @@
package derms.net.runicast;
-import derms.net.ConcurrentDatagramSocket;
+import derms.io.Serial;
import derms.net.MessagePayload;
-import derms.net.Packet;
import derms.util.ThreadPool;
import java.io.IOException;
-import java.net.DatagramPacket;
-import java.net.InetSocketAddress;
-import java.time.Duration;
+import java.net.SocketAddress;
+import java.nio.ByteBuffer;
+import java.nio.channels.DatagramChannel;
import java.util.Queue;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
@@ -18,12 +17,10 @@ import java.util.logging.Logger;
/** The sending end of a reliable unicast connection. */
public class ReliableUnicastSender<T extends MessagePayload> {
- private static final Duration soTimeout = Duration.ofMillis(500); // Socket timeout.
-
private final AtomicLong next; // Next sequence number.
private final AtomicLong unacked; // Sequence number of first unacknowledged message.
private final Queue<Message<T>> sent;
- private final ConcurrentDatagramSocket sock;
+ private final DatagramChannel sock;
private final Logger log;
private final ExecutorService pool;
@@ -32,13 +29,12 @@ public class ReliableUnicastSender<T extends MessagePayload> {
*
* @param raddr The remote IP address to connect to.
*/
- public ReliableUnicastSender(InetSocketAddress raddr) throws IOException {
+ public ReliableUnicastSender(SocketAddress raddr) throws IOException {
this.next = new AtomicLong(0);
this.unacked = new AtomicLong(0);
this.sent = new LinkedBlockingQueue<Message<T>>();
- this.sock = new ConcurrentDatagramSocket();
- this.sock.connect(raddr);
- this.sock.setSoTimeout(soTimeout);
+ this.sock = DatagramChannel.open();
+ sock.connect(raddr);
this.log = Logger.getLogger(getClass().getName());
this.pool = Executors.newCachedThreadPool();
pool.execute(new ReceiveAcks<T>(unacked, sent, sock));
@@ -47,28 +43,32 @@ public class ReliableUnicastSender<T extends MessagePayload> {
public void send(T payload) throws IOException {
Message<T> msg = new Message<T>(next.get(), payload);
- DatagramPacket pkt = Packet.encode(msg, sock.getRemoteSocketAddress());
- sock.send(pkt);
+ ByteBuffer buf = Serial.encode(msg);
+ sock.send(buf, sock.getRemoteAddress());
sent.add(msg);
next.incrementAndGet();
+ log.info("Sent " + msg);
}
/** Wait for all messages to be acknowledged and close the connection. */
- public void close() throws InterruptedException {
+ public void close() throws InterruptedException, IOException {
// Wait for receiver to acknowledge all sent messages.
+ log.info("Waiting for acknowledgements...");
while (unacked.get() < next.get()) {
Thread.yield();
if (Thread.interrupted())
throw new InterruptedException();
}
- closeNow();
+ log.info("Shutting down.");
+ sock.close();
+ ThreadPool.shutdown(pool, log);
}
/** Close the connection immediately, without waiting for acknowledgements. */
- public void closeNow() {
+ public void closeNow() throws IOException {
log.info("Shutting down.");
- ThreadPool.shutDown(pool, log);
sock.close();
+ ThreadPool.shutdownNow(pool, log);
}
}
diff --git a/src/main/java/derms/net/runicast/Retransmit.java b/src/main/java/derms/net/runicast/Retransmit.java
index affd00c..16f8859 100644
--- a/src/main/java/derms/net/runicast/Retransmit.java
+++ b/src/main/java/derms/net/runicast/Retransmit.java
@@ -1,12 +1,13 @@
package derms.net.runicast;
-import derms.net.ConcurrentDatagramSocket;
+import derms.io.Serial;
import derms.net.MessagePayload;
-import derms.net.Packet;
import derms.util.Wait;
import java.io.IOException;
-import java.net.DatagramPacket;
+import java.nio.ByteBuffer;
+import java.nio.channels.ClosedChannelException;
+import java.nio.channels.DatagramChannel;
import java.time.Duration;
import java.util.Queue;
import java.util.concurrent.atomic.AtomicLong;
@@ -18,10 +19,10 @@ class Retransmit<T extends MessagePayload> implements Runnable {
private final AtomicLong unacked;
private final Queue<Message<T>> sent;
- private final ConcurrentDatagramSocket sock;
+ private final DatagramChannel sock;
private final Logger log;
- Retransmit(AtomicLong unacked, Queue<Message<T>> sent, ConcurrentDatagramSocket sock) {
+ Retransmit(AtomicLong unacked, Queue<Message<T>> sent, DatagramChannel sock) {
this.unacked = unacked;
this.sent = sent;
this.sock = sock;
@@ -40,15 +41,15 @@ class Retransmit<T extends MessagePayload> implements Runnable {
}
}
}
- } catch (InterruptedException e) {
- log.info("Interrupted.");
+ } catch (InterruptedException | ClosedChannelException e) {
+ log.info("Shutting down.");
}
}
- private void retransmit(Message<T> msg) {
+ private void retransmit(Message<T> msg) throws ClosedChannelException {
try {
- DatagramPacket pkt = Packet.encode(msg, sock.getRemoteSocketAddress());
- sock.send(pkt);
+ ByteBuffer buf = Serial.encode(msg);
+ sock.send(buf, sock.getRemoteAddress());
log.info("Retransmitted " + msg);
} catch (IOException e) {
log.warning("Failed to retransmit " + msg + ": " + e.getMessage());
diff --git a/src/main/java/derms/util/ThreadPool.java b/src/main/java/derms/util/ThreadPool.java
index 33588ff..ccd3afa 100644
--- a/src/main/java/derms/util/ThreadPool.java
+++ b/src/main/java/derms/util/ThreadPool.java
@@ -8,7 +8,26 @@ import java.util.logging.Logger;
public class ThreadPool {
public static final Duration timeout = Duration.ofSeconds(1);
- public static void shutDown(ExecutorService pool, Logger log) {
+ public static void shutdown(ExecutorService pool, Logger log) {
+ pool.shutdown();
+ try {
+ // Wait for existing threads to stop.
+ if (!pool.awaitTermination(timeout.toMillis(), TimeUnit.MILLISECONDS)) {
+ log.warning("Thread pool did not terminate after " + timeout + ". Forcefully shutting down...");
+ pool.shutdownNow(); // Cancel running tasks.
+ // Wait for tasks to stop.
+ if (!pool.awaitTermination(timeout.toMillis(), TimeUnit.MILLISECONDS))
+ log.warning("Thread pool did not terminate after " + timeout);
+ }
+ } catch (InterruptedException e) {
+ // (Re-)Cancel if current thread also interrupted.
+ pool.shutdownNow();
+ // Preserve interrupt status.
+ Thread.currentThread().interrupt();
+ }
+ }
+
+ public static void shutdownNow(ExecutorService pool, Logger log) {
pool.shutdownNow();
try {
if (!pool.awaitTermination(timeout.toMillis(), TimeUnit.MILLISECONDS))