diff options
| author | Sam Anthony <sam@samanthony.xyz> | 2024-11-23 11:34:42 -0500 |
|---|---|---|
| committer | Sam Anthony <sam@samanthony.xyz> | 2024-11-23 11:34:42 -0500 |
| commit | d5a1ec8b54c1c3c516d07f1916276cd6e5a937e4 (patch) | |
| tree | 34f4ed7975803f573d16a7215ae39a9b2791a9b9 | |
| parent | e3df4a078afd37314d330daa2de0883f8dd1811b (diff) | |
| download | soen423-d5a1ec8b54c1c3c516d07f1916276cd6e5a937e4.zip | |
runicast: use DatagramChannel
| -rw-r--r-- | src/main/java/derms/io/Serial.java | 24 | ||||
| -rw-r--r-- | src/main/java/derms/net/rmulticast/ReliableMulticast.java | 2 | ||||
| -rw-r--r-- | src/main/java/derms/net/runicast/Receive.java | 32 | ||||
| -rw-r--r-- | src/main/java/derms/net/runicast/ReceiveAcks.java | 27 | ||||
| -rw-r--r-- | src/main/java/derms/net/runicast/ReliableUnicastReceiver.java | 14 | ||||
| -rw-r--r-- | src/main/java/derms/net/runicast/ReliableUnicastSender.java | 36 | ||||
| -rw-r--r-- | src/main/java/derms/net/runicast/Retransmit.java | 21 | ||||
| -rw-r--r-- | src/main/java/derms/util/ThreadPool.java | 21 |
8 files changed, 110 insertions, 67 deletions
diff --git a/src/main/java/derms/io/Serial.java b/src/main/java/derms/io/Serial.java new file mode 100644 index 0000000..b5b2299 --- /dev/null +++ b/src/main/java/derms/io/Serial.java @@ -0,0 +1,24 @@ +package derms.io; + +import java.io.*; +import java.nio.ByteBuffer; + +public class Serial { + public static ByteBuffer encode(Serializable obj) throws IOException { + ByteArrayOutputStream baos = new ByteArrayOutputStream(); + ObjectOutputStream oos = new ObjectOutputStream(baos); + oos.writeObject(obj); + oos.flush(); + ByteBuffer buf = ByteBuffer.wrap(baos.toByteArray()); + oos.close(); + return buf; + } + + public static <T extends Serializable> T decode(ByteBuffer buf, Class<T> clazz) throws IOException, ClassNotFoundException { + ObjectInputStream ois = new ObjectInputStream( + new ByteArrayInputStream(buf.array())); + T obj = clazz.cast(ois.readObject()); + ois.close(); + return obj; + } +} diff --git a/src/main/java/derms/net/rmulticast/ReliableMulticast.java b/src/main/java/derms/net/rmulticast/ReliableMulticast.java index 3894021..c23baaa 100644 --- a/src/main/java/derms/net/rmulticast/ReliableMulticast.java +++ b/src/main/java/derms/net/rmulticast/ReliableMulticast.java @@ -63,7 +63,7 @@ public class ReliableMulticast<T extends MessagePayload> { public void close() { log.info("Shutting down..."); - ThreadPool.shutDown(pool, log); + ThreadPool.shutdownNow(pool, log); outSock.close(); log.info("Finished shutting down."); } diff --git a/src/main/java/derms/net/runicast/Receive.java b/src/main/java/derms/net/runicast/Receive.java index 584861b..8620ebd 100644 --- a/src/main/java/derms/net/runicast/Receive.java +++ b/src/main/java/derms/net/runicast/Receive.java @@ -1,11 +1,13 @@ package derms.net.runicast; -import derms.net.ConcurrentDatagramSocket; +import derms.io.Serial; import derms.net.MessagePayload; -import derms.net.Packet; import java.io.IOException; import java.net.*; +import java.nio.ByteBuffer; +import java.nio.channels.ClosedChannelException; +import java.nio.channels.DatagramChannel; import java.util.Queue; import java.util.logging.Logger; @@ -13,11 +15,11 @@ class Receive<T extends MessagePayload> implements Runnable { private static final int bufSize = 8192; private long seq; // Sequence number. - private final ConcurrentDatagramSocket sock; + private final DatagramChannel sock; private final Queue<T> delivered; private final Logger log; - Receive(ConcurrentDatagramSocket sock, Queue<T> delivered) { + Receive(DatagramChannel sock, Queue<T> delivered) { this.seq = 0; this.sock = sock; this.delivered = delivered; @@ -26,18 +28,15 @@ class Receive<T extends MessagePayload> implements Runnable { @Override public void run() { - DatagramPacket pkt = new DatagramPacket(new byte[bufSize], bufSize); for (;;) { + ByteBuffer buf = ByteBuffer.allocate(bufSize); try { - sock.receive(pkt); - Message<T> msg = (Message<T>) Packet.decode(pkt, Message.class); - SocketAddress sender = pkt.getSocketAddress(); + SocketAddress sender = sock.receive(buf); + Message<T> msg = (Message<T>) Serial.decode(buf, Message.class); recv(msg, sender); - } catch (SocketTimeoutException e) { - if (Thread.interrupted()) { - log.info("Interrupted"); - return; - } + } catch (ClosedChannelException e) { + log.info("Shutting down."); + return; } catch (IOException | ClassNotFoundException | ClassCastException e) { log.warning(e.getMessage()); } @@ -45,16 +44,19 @@ class Receive<T extends MessagePayload> implements Runnable { } private void recv(Message<T> msg, SocketAddress sender) throws IOException { + log.info("Received " + msg); if (msg.seq == seq) { delivered.add(msg.payload); + log.info("Delivered " + msg); ack(msg, sender); + log.info("Acked " + msg); seq++; } } private void ack(Message<T> msg, SocketAddress sender) throws IOException { Ack ack = new Ack(msg.seq); - DatagramPacket pkt = Packet.encode(ack, sender); - sock.send(pkt); + ByteBuffer buf = Serial.encode(ack); + sock.send(buf, sender); } } diff --git a/src/main/java/derms/net/runicast/ReceiveAcks.java b/src/main/java/derms/net/runicast/ReceiveAcks.java index 0f585ff..9d7b7de 100644 --- a/src/main/java/derms/net/runicast/ReceiveAcks.java +++ b/src/main/java/derms/net/runicast/ReceiveAcks.java @@ -1,12 +1,12 @@ package derms.net.runicast; -import derms.net.ConcurrentDatagramSocket; +import derms.io.Serial; import derms.net.MessagePayload; -import derms.net.Packet; import java.io.IOException; -import java.net.DatagramPacket; -import java.net.SocketTimeoutException; +import java.nio.ByteBuffer; +import java.nio.channels.ClosedChannelException; +import java.nio.channels.DatagramChannel; import java.util.Queue; import java.util.concurrent.atomic.AtomicLong; import java.util.logging.Logger; @@ -17,10 +17,10 @@ class ReceiveAcks<T extends MessagePayload> implements Runnable { private final AtomicLong unacked; private final Queue<Message<T>> sent; - private final ConcurrentDatagramSocket sock; + private final DatagramChannel sock; private final Logger log; - ReceiveAcks(AtomicLong unacked, Queue<Message<T>> sent, ConcurrentDatagramSocket sock) { + ReceiveAcks(AtomicLong unacked, Queue<Message<T>> sent, DatagramChannel sock) { this.unacked = unacked; this.sent = sent; this.sock = sock; @@ -29,17 +29,15 @@ class ReceiveAcks<T extends MessagePayload> implements Runnable { @Override public void run() { - DatagramPacket pkt = new DatagramPacket(new byte[bufSize], bufSize); for (;;) { + ByteBuffer buf = ByteBuffer.allocate(bufSize); try { - sock.receive(pkt); - Ack ack = Packet.decode(pkt, Ack.class); + sock.receive(buf); + Ack ack = Serial.decode(buf, Ack.class); recvAck(ack.seq); - } catch (SocketTimeoutException e) { - if (Thread.interrupted()) { - log.info("Interrupted."); - return; - } + } catch (ClosedChannelException e) { + log.info("Shutting down."); + return; } catch (IOException | ClassNotFoundException | ClassCastException e) { log.warning(e.getMessage()); } @@ -47,6 +45,7 @@ class ReceiveAcks<T extends MessagePayload> implements Runnable { } private void recvAck(long ack) { + log.info("Received ack: " + ack); unacked.updateAndGet((unacked) -> { if (ack >= unacked) return ack+1; diff --git a/src/main/java/derms/net/runicast/ReliableUnicastReceiver.java b/src/main/java/derms/net/runicast/ReliableUnicastReceiver.java index 81e3502..ff0b72a 100644 --- a/src/main/java/derms/net/runicast/ReliableUnicastReceiver.java +++ b/src/main/java/derms/net/runicast/ReliableUnicastReceiver.java @@ -1,20 +1,18 @@ package derms.net.runicast; -import derms.net.ConcurrentDatagramSocket; import derms.net.MessagePayload; import derms.util.ThreadPool; import java.io.IOException; import java.net.SocketAddress; +import java.nio.channels.DatagramChannel; import java.time.Duration; import java.util.concurrent.*; import java.util.logging.Logger; /** The receiving end of a reliable unicast connection. */ public class ReliableUnicastReceiver<T extends MessagePayload> { - private static final Duration soTimeout = Duration.ofMillis(500); // Socket timeout. - - private final ConcurrentDatagramSocket sock; + private final DatagramChannel sock; private final BlockingQueue<T> delivered; private final Logger log; private final ExecutorService pool; @@ -25,18 +23,18 @@ public class ReliableUnicastReceiver<T extends MessagePayload> { * @param laddr The local IP address and port to listen on. */ public ReliableUnicastReceiver(SocketAddress laddr) throws IOException { - this.sock = new ConcurrentDatagramSocket(laddr); - this.sock.setSoTimeout(soTimeout); + this.sock = DatagramChannel.open(); + sock.bind(laddr); this.delivered = new LinkedBlockingQueue<T>(); this.log = Logger.getLogger(getClass().getName()); this.pool = Executors.newCachedThreadPool(); pool.execute(new Receive<T>(sock, delivered)); } - public void close() { + public void close() throws IOException { log.info("Shutting down"); - ThreadPool.shutDown(pool, log); sock.close(); + ThreadPool.shutdown(pool, log); } /** Receive a message, blocking if necessary until one arrives. */ diff --git a/src/main/java/derms/net/runicast/ReliableUnicastSender.java b/src/main/java/derms/net/runicast/ReliableUnicastSender.java index 83408d5..1f3c5d4 100644 --- a/src/main/java/derms/net/runicast/ReliableUnicastSender.java +++ b/src/main/java/derms/net/runicast/ReliableUnicastSender.java @@ -1,14 +1,13 @@ package derms.net.runicast; -import derms.net.ConcurrentDatagramSocket; +import derms.io.Serial; import derms.net.MessagePayload; -import derms.net.Packet; import derms.util.ThreadPool; import java.io.IOException; -import java.net.DatagramPacket; -import java.net.InetSocketAddress; -import java.time.Duration; +import java.net.SocketAddress; +import java.nio.ByteBuffer; +import java.nio.channels.DatagramChannel; import java.util.Queue; import java.util.concurrent.ExecutorService; import java.util.concurrent.Executors; @@ -18,12 +17,10 @@ import java.util.logging.Logger; /** The sending end of a reliable unicast connection. */ public class ReliableUnicastSender<T extends MessagePayload> { - private static final Duration soTimeout = Duration.ofMillis(500); // Socket timeout. - private final AtomicLong next; // Next sequence number. private final AtomicLong unacked; // Sequence number of first unacknowledged message. private final Queue<Message<T>> sent; - private final ConcurrentDatagramSocket sock; + private final DatagramChannel sock; private final Logger log; private final ExecutorService pool; @@ -32,13 +29,12 @@ public class ReliableUnicastSender<T extends MessagePayload> { * * @param raddr The remote IP address to connect to. */ - public ReliableUnicastSender(InetSocketAddress raddr) throws IOException { + public ReliableUnicastSender(SocketAddress raddr) throws IOException { this.next = new AtomicLong(0); this.unacked = new AtomicLong(0); this.sent = new LinkedBlockingQueue<Message<T>>(); - this.sock = new ConcurrentDatagramSocket(); - this.sock.connect(raddr); - this.sock.setSoTimeout(soTimeout); + this.sock = DatagramChannel.open(); + sock.connect(raddr); this.log = Logger.getLogger(getClass().getName()); this.pool = Executors.newCachedThreadPool(); pool.execute(new ReceiveAcks<T>(unacked, sent, sock)); @@ -47,28 +43,32 @@ public class ReliableUnicastSender<T extends MessagePayload> { public void send(T payload) throws IOException { Message<T> msg = new Message<T>(next.get(), payload); - DatagramPacket pkt = Packet.encode(msg, sock.getRemoteSocketAddress()); - sock.send(pkt); + ByteBuffer buf = Serial.encode(msg); + sock.send(buf, sock.getRemoteAddress()); sent.add(msg); next.incrementAndGet(); + log.info("Sent " + msg); } /** Wait for all messages to be acknowledged and close the connection. */ - public void close() throws InterruptedException { + public void close() throws InterruptedException, IOException { // Wait for receiver to acknowledge all sent messages. + log.info("Waiting for acknowledgements..."); while (unacked.get() < next.get()) { Thread.yield(); if (Thread.interrupted()) throw new InterruptedException(); } - closeNow(); + log.info("Shutting down."); + sock.close(); + ThreadPool.shutdown(pool, log); } /** Close the connection immediately, without waiting for acknowledgements. */ - public void closeNow() { + public void closeNow() throws IOException { log.info("Shutting down."); - ThreadPool.shutDown(pool, log); sock.close(); + ThreadPool.shutdownNow(pool, log); } } diff --git a/src/main/java/derms/net/runicast/Retransmit.java b/src/main/java/derms/net/runicast/Retransmit.java index affd00c..16f8859 100644 --- a/src/main/java/derms/net/runicast/Retransmit.java +++ b/src/main/java/derms/net/runicast/Retransmit.java @@ -1,12 +1,13 @@ package derms.net.runicast; -import derms.net.ConcurrentDatagramSocket; +import derms.io.Serial; import derms.net.MessagePayload; -import derms.net.Packet; import derms.util.Wait; import java.io.IOException; -import java.net.DatagramPacket; +import java.nio.ByteBuffer; +import java.nio.channels.ClosedChannelException; +import java.nio.channels.DatagramChannel; import java.time.Duration; import java.util.Queue; import java.util.concurrent.atomic.AtomicLong; @@ -18,10 +19,10 @@ class Retransmit<T extends MessagePayload> implements Runnable { private final AtomicLong unacked; private final Queue<Message<T>> sent; - private final ConcurrentDatagramSocket sock; + private final DatagramChannel sock; private final Logger log; - Retransmit(AtomicLong unacked, Queue<Message<T>> sent, ConcurrentDatagramSocket sock) { + Retransmit(AtomicLong unacked, Queue<Message<T>> sent, DatagramChannel sock) { this.unacked = unacked; this.sent = sent; this.sock = sock; @@ -40,15 +41,15 @@ class Retransmit<T extends MessagePayload> implements Runnable { } } } - } catch (InterruptedException e) { - log.info("Interrupted."); + } catch (InterruptedException | ClosedChannelException e) { + log.info("Shutting down."); } } - private void retransmit(Message<T> msg) { + private void retransmit(Message<T> msg) throws ClosedChannelException { try { - DatagramPacket pkt = Packet.encode(msg, sock.getRemoteSocketAddress()); - sock.send(pkt); + ByteBuffer buf = Serial.encode(msg); + sock.send(buf, sock.getRemoteAddress()); log.info("Retransmitted " + msg); } catch (IOException e) { log.warning("Failed to retransmit " + msg + ": " + e.getMessage()); diff --git a/src/main/java/derms/util/ThreadPool.java b/src/main/java/derms/util/ThreadPool.java index 33588ff..ccd3afa 100644 --- a/src/main/java/derms/util/ThreadPool.java +++ b/src/main/java/derms/util/ThreadPool.java @@ -8,7 +8,26 @@ import java.util.logging.Logger; public class ThreadPool { public static final Duration timeout = Duration.ofSeconds(1); - public static void shutDown(ExecutorService pool, Logger log) { + public static void shutdown(ExecutorService pool, Logger log) { + pool.shutdown(); + try { + // Wait for existing threads to stop. + if (!pool.awaitTermination(timeout.toMillis(), TimeUnit.MILLISECONDS)) { + log.warning("Thread pool did not terminate after " + timeout + ". Forcefully shutting down..."); + pool.shutdownNow(); // Cancel running tasks. + // Wait for tasks to stop. + if (!pool.awaitTermination(timeout.toMillis(), TimeUnit.MILLISECONDS)) + log.warning("Thread pool did not terminate after " + timeout); + } + } catch (InterruptedException e) { + // (Re-)Cancel if current thread also interrupted. + pool.shutdownNow(); + // Preserve interrupt status. + Thread.currentThread().interrupt(); + } + } + + public static void shutdownNow(ExecutorService pool, Logger log) { pool.shutdownNow(); try { if (!pool.awaitTermination(timeout.toMillis(), TimeUnit.MILLISECONDS)) |