Coding Projects Python

asyncio: an interlude

Published · 3min

This post is part 4 of the “Starting with asyncio” series:

  1. Starting with asyncio
  2. Exploring asyncio protocols
  3. A simple asyncio client
  4. asyncio: an interlude
  5. A high-level asyncio server and client

Yesterday’s entry contained some issues I didn’t notice until later. The main one is that, after the initial successful test, I noticed that the client would start to disconnect. There’s no explicit timeout on the connections, so this might be either some underlying socket timeout or something else going on.

Essentially, it was working partly by accident, as can be seen by partially fixing the main function:

async def main():
    loop = asyncio.get_running_loop()
    transport, _ = await loop.create_connection(
        protocol_factory=ChatClientProtocol,
        host="localhost",
        port=8007,
    )
    # Nothing is awaited between connecting and closing, so the connection is
    # torn down straight away -- this is the problem being demonstrated.
    transport.close()

The transport closes immediately! Essentially, there was an implicit timeout in asyncio.run() caused by the main function returning: once the coroutine finishes, asyncio.run() shuts the event loop down. Awaiting on a future will fix that. Here are the updates to the ChatProtocol class:

import asyncio
import typing as t


class ChatProtocol(asyncio.Protocol):
    """Line-oriented protocol that can signal when its connection is lost."""

    def __init__(
        self,
        delimiter: bytes = b"\r\n",
        on_exit: t.Optional[asyncio.Future] = None,
    ):
        self.buffer = bytearray()
        self.start = 0  # Where we'll start the delimiter search from
        self.delimiter = delimiter
        # Optional future resolved with True once the connection is lost.
        self.on_exit = on_exit

    ...

    def connection_lost(self, exc: t.Optional[Exception]) -> None:
        # Only resolve the future if it is still pending: it may already be
        # done (e.g. cancelled along with the task awaiting it), and calling
        # set_result() on a done future raises InvalidStateError.
        if self.on_exit is not None and not self.on_exit.done():
            self.on_exit.set_result(True)

It can now optionally accept a future, which it resolves when the connection to the server is lost. The existing ChatClientProtocol class requires no additional updates, but here’s the new main function for the client:

async def main():
    """Connect to the chat server and keep running until the connection drops."""
    loop = asyncio.get_running_loop()
    # Resolved by the protocol's connection_lost() handler.
    closed = loop.create_future()

    def make_protocol():
        return ChatClientProtocol(on_exit=closed)

    transport, _protocol = await loop.create_connection(
        make_protocol, host="localhost", port=8007
    )
    try:
        await closed
    finally:
        transport.close()

This makes sure the client doesn’t exit until it receives the signal that the server has closed the connection.

Here’s the finished client:

import asyncio

from server import ChatProtocol


class ChatClientProtocol(ChatProtocol):
    """Interactive client: sends lines typed on stdin and prints each reply.

    Note: ``input()`` blocks the event loop while waiting for the user; that
    is tolerable here only because the client has nothing else to do.
    """

    def write_line(self, line: str) -> None:
        """Send *line* followed by the protocol delimiter."""
        # A single write keeps the line and its terminator together. This
        # relies on the transport to flush appropriately.
        self.transport.write(line.encode() + self.delimiter)

    def _send_user_input(self) -> None:
        """Read one line from stdin and send it; close on end of input."""
        try:
            line = input()
        except EOFError:
            # stdin was closed (e.g. Ctrl-D). Close the connection instead of
            # letting the exception escape the protocol callback, which would
            # otherwise leave the client hanging or kill the transport.
            self.transport.close()
            return
        self.write_line(line)

    def connection_made(self, transport: asyncio.BaseTransport) -> None:
        super().connection_made(transport)
        self._send_user_input()

    def line_received(self, line: bytes) -> None:
        # errors="replace" so a malformed reply can't crash data_received().
        print(line.decode(errors="replace"))
        self._send_user_input()


async def main():
    """Connect to the chat server and keep running until the connection drops."""
    loop = asyncio.get_running_loop()
    # Resolved by the protocol's connection_lost() handler.
    closed = loop.create_future()

    def make_protocol():
        return ChatClientProtocol(on_exit=closed)

    transport, _protocol = await loop.create_connection(
        make_protocol, host="localhost", port=8007
    )
    try:
        await closed
    finally:
        transport.close()


if __name__ == "__main__":
    asyncio.run(main())

And the server:

import asyncio
import typing as t


class ChatProtocol(asyncio.Protocol):
    """Line-oriented base protocol.

    Incoming bytes are accumulated in ``buffer`` and split on ``delimiter``;
    each complete line (without the delimiter) is passed to
    :meth:`line_received`, which subclasses override.
    """

    # asyncio.Protocol itself declares ``__slots__ = ()``, so instances of this
    # class have no ``__dict__``: every attribute assigned below must be listed.
    __slots__ = ("transport", "delimiter", "buffer", "start", "on_exit")

    def __init__(
        self,
        delimiter: bytes = b"\r\n",
        on_exit: t.Optional[asyncio.Future] = None,
    ):
        """Create a protocol instance.

        :param delimiter: non-empty byte sequence that terminates each line.
        :param on_exit: optional future resolved with ``True`` when the
            connection is lost, so a caller can wait for disconnection.
        :raises ValueError: if *delimiter* is empty (splitting on it would
            never make progress).
        """
        if not delimiter:
            raise ValueError("delimiter must be a non-empty byte string")
        self.buffer = bytearray()
        self.start = 0  # Where we'll start the delimiter search from
        self.delimiter = delimiter
        self.on_exit = on_exit

    def connection_made(self, transport: asyncio.BaseTransport) -> None:
        """Remember the transport so subclasses can write to it."""
        # Stream connections are handed a full Transport; narrow the type.
        self.transport = t.cast(asyncio.Transport, transport)

    def data_received(self, data: bytes) -> None:
        """Append *data* to the buffer and dispatch every complete line."""
        self.buffer.extend(data)
        delimiter_len = len(self.delimiter)
        # The data that's received may contain multiple delimiters, so we try
        # to find each.
        while True:
            pos = self.buffer.find(self.delimiter, self.start)
            if pos == -1:
                # The delimiter hasn't been found, so next time we start
                # checking from the end of the buffer, just far enough back to
                # match a delimiter that straddles old and new data.
                self.start = max(0, len(self.buffer) - delimiter_len + 1)
                break
            line = bytes(self.buffer[:pos])
            # Drop the line and its delimiter in place rather than copying the
            # remainder into a new bytearray for every line.
            del self.buffer[: pos + delimiter_len]
            self.start = 0
            self.line_received(line)

    def line_received(self, line: bytes) -> None:
        """Handle one complete line (delimiter stripped); ignored by default."""

    def connection_lost(self, exc: t.Optional[Exception]) -> None:
        """Resolve ``on_exit`` (if given) to signal that the connection closed.

        The future may already be done -- e.g. cancelled because the task
        awaiting it was cancelled -- and ``set_result`` on a done future raises
        ``InvalidStateError``, so it is only resolved while still pending.
        """
        if self.on_exit is not None and not self.on_exit.done():
            self.on_exit.set_result(True)


class ChatEchoProtocol(ChatProtocol):
    """Echo each received line back, prefixed with ``You sent: ``.

    A line consisting of a single ``.`` closes the connection instead.
    """

    def line_received(self, line: bytes) -> None:
        if line == b".":
            self.transport.close()
            return
        # Build the reply once so it is handed to the transport in a single
        # write rather than three separate ones.
        self.transport.write(b"You sent: " + line + self.delimiter)


async def main():
    """Serve ChatEchoProtocol on localhost:8007 until the process is stopped."""
    loop = asyncio.get_running_loop()
    # NOTE(review): asyncio documents reuse_port as unsupported on Windows;
    # confirm this flag is really wanted.
    server = await loop.create_server(
        ChatEchoProtocol,
        host="localhost",
        port=8007,
        reuse_port=True,
    )
    # Leaving the context manager closes the listening server.
    async with server:
        await server.serve_forever()


if __name__ == "__main__":
    asyncio.run(main())