Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add timeout support #469

Open
wants to merge 3 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 27 additions & 4 deletions aiohttp_sse/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ def __init__(
reason: Optional[str] = None,
headers: Optional[Mapping[str, str]] = None,
sep: Optional[str] = None,
timeout: Optional[float] = None,
):
super().__init__(status=status, reason=reason)

Expand All @@ -54,6 +55,7 @@ def __init__(
self._ping_interval: float = self.DEFAULT_PING_INTERVAL
self._ping_task: Optional[asyncio.Task[None]] = None
self._sep = sep if sep is not None else self.DEFAULT_SEPARATOR
self._timeout = timeout

def is_connected(self) -> bool:
"""Check connection is prepared and ping task is not done."""
Expand Down Expand Up @@ -130,10 +132,16 @@ async def send(

buffer.write(self._sep)
try:
await self.write(buffer.getvalue().encode("utf-8"))
await asyncio.wait_for( # TODO(PY311): Use asyncio.timeout
self.write(buffer.getvalue().encode("utf-8")),
Dreamsorcerer marked this conversation as resolved.
Show resolved Hide resolved
timeout=self._timeout,
)
except ConnectionResetError:
self.stop_streaming()
raise
except asyncio.TimeoutError:
self.stop_streaming()
raise TimeoutError

async def wait(self) -> None:
"""EventSourceResponse object is used for streaming data to the client,
Expand Down Expand Up @@ -202,8 +210,16 @@ async def _ping(self) -> None:
while True:
await asyncio.sleep(self._ping_interval)
try:
await self.write(message)
except (ConnectionResetError, RuntimeError):
await asyncio.wait_for( # TODO(PY311): Use asyncio.timeout
self.write(message),
timeout=self._timeout,
)
except (
ConnectionResetError,
RuntimeError,
TimeoutError,
asyncio.TimeoutError,
):
# RuntimeError - on writing after EOF
break

Expand Down Expand Up @@ -256,12 +272,19 @@ def sse_response(
headers: Optional[Mapping[str, str]] = None,
sep: Optional[str] = None,
response_cls: Type[EventSourceResponse] = EventSourceResponse,
timeout: Optional[float] = None,
) -> Any:
if not issubclass(response_cls, EventSourceResponse):
raise TypeError(
"response_cls must be subclass of "
"aiohttp_sse.EventSourceResponse, got {}".format(response_cls)
)

sse = response_cls(status=status, reason=reason, headers=headers, sep=sep)
sse = response_cls(
status=status,
reason=reason,
headers=headers,
sep=sep,
timeout=timeout,
)
return _ContextManager(sse._prepare(request))
46 changes: 45 additions & 1 deletion tests/test_sse.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import asyncio
import sys
from typing import Awaitable, Callable, List
from typing import Awaitable, Callable, List, Optional

import pytest
from aiohttp import web
Expand Down Expand Up @@ -559,3 +559,47 @@ async def handler(request: web.Request) -> EventSourceResponse:

async with client.get("/") as response:
assert 200 == response.status


@pytest.mark.parametrize("timeout", (None, 0.1))
async def test_with_timeout(
aiohttp_client: ClientFixture,
monkeypatch: pytest.MonkeyPatch,
timeout: Optional[float],
) -> None:
"""Test write timeout.

Relates to this issue:
https://github.com/sysid/sse-starlette/issues/89
"""
timeout_raised = False

async def frozen_write(_data: bytes) -> None:
await asyncio.sleep(42)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Problem here, is that we're tampering with the server side of the connection. Is it possible to do something with the client to simulate the hanging connection? Then we can be sure this works correctly.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't know how to reproduce hanged connection, but the test covers any time-based issues.

Also I prepared this example to make sure solution helps directly to solve the issue:

import asyncio
from datetime import datetime

from aiohttp import web

from aiohttp_sse import sse_response

TIMEOUT = 5


async def hello(request: web.Request) -> web.StreamResponse:
    """Timeout example.

    How to reproduce the issue:
    1. Run this example
    2. Open console
    3. Executed the command below and then press Ctrl+Z (cmd+Z):
        curl -s -N localhost:8000/events > /dev/null
        
    4. Try to change TIMEOUT to None and repeat the steps above.
    """
    async with sse_response(request, timeout=TIMEOUT) as resp:
        i = 0
        try:
            while resp.is_connected():
                spaces = " " * 4096
                data = f"Server Time : {datetime.now()} {spaces}"

                i += 1
                if i % 100 == 0:
                    print(i, data)

                await resp.send(data)
                await asyncio.sleep(0.01)
        except Exception as exc:
            print(f"Exception: {type(exc).__name__} {exc}")
        finally:
            print("Disconnected")

    return resp


if __name__ == "__main__":
    app = web.Application()
    app.router.add_route("GET", "/events", hello)
    web.run_app(app, host="127.0.0.1", port=8000)

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So, my thinking is that we should be able to do something like resp.connection.transport.pause_reading() in the test to stop the client reading the connection. But, the test is not passing then.

I'm not yet convinced this fixes the reported issue.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

With timeout=None this test is not passed (as expected)

@pytest.mark.parametrize("timeout", (None, 1.0))
async def test_with_timeout(
    aiohttp_client: ClientFixture,
    monkeypatch: pytest.MonkeyPatch,
    timeout: Optional[float],
) -> None:
    """Test write timeout.

    Relates to this issue:
    https://github.com/sysid/sse-starlette/issues/89
    """
    sse_closed = asyncio.Event()

    async def handler(request: web.Request) -> EventSourceResponse:
        sse = EventSourceResponse(timeout=timeout)
        sse.ping_interval = 1
        await sse.prepare(request)

        try:
            async with sse:
                i = 0
                while sse.is_connected():
                    spaces = " " * 4096
                    data = f"Server Time : {datetime.now()} {spaces}"

                    i += 1
                    if i % 100 == 0:
                        print(i, data)

                        await sse.send(data)
                        await asyncio.sleep(0.01)
        finally:
            sse_closed.set()
        return sse  # pragma: no cover

    app = web.Application()
    app.router.add_route("GET", "/", handler)

    client = await aiohttp_client(app)
    async with client.get("/") as resp:
        assert resp.status == 200
        resp.connection.transport.pause_reading()
        print(
            f"Transport paused reading with "
            f"{resp.connection.transport.pause_reading}"
        )
        
        async with asyncio.timeout(10):
            await sse_closed.wait()

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The only thing it tests is that the status was 200?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So, if I add prints:

            print("A", time.time())
            await self.write(buffer.getvalue().encode("utf-8")),
            print("B", time.time())

I then need to add an await asyncio.sleep(0) to the original test:

            await asyncio.sleep(0)
            try:
                await sse.send("foo")

The send() call doesn't seem to yield, so without the sleep, the client code never runs and manages to pause the reading.

But, then my output looks like:

A 1714071179.1714509
B 1714071179.1715052

So, even after we pause reading, it's not waiting for the client...

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nevermind, I increased the amount of data sent in each message, as you did above. Now I can see it working correctly!

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I've pushed some changes to the test. I think that's probably good now. The assert for the connection being closed was failing, so I removed that. Feel free to play with it if you think it should work though.

I'd note from the original issue:

continued generating chunks to send on this connection, slowly saturating TCP buffers before finally simply hanging in the send call.

We are only detecting that final hang and cancelling then. As far as I can tell, the buffers must be around 10MB, so if you were sending a 100 byte message once per minute, then it'd take ~28 hours to detect the hung client and disconnect it...


async def handler(request: web.Request) -> EventSourceResponse:
sse = EventSourceResponse(timeout=timeout)
sse.ping_interval = 42
await sse.prepare(request)
monkeypatch.setattr(sse, "write", frozen_write)

async with sse:
try:
await sse.send("foo")
except TimeoutError:
nonlocal timeout_raised
timeout_raised = True
raise

return sse # pragma: no cover

app = web.Application()
app.router.add_route("GET", "/", handler)

client = await aiohttp_client(app)
async with client.get("/") as resp:
assert resp.status == 200
await asyncio.sleep(0.5)
assert resp.connection and resp.connection.closed is bool(timeout)

assert timeout_raised is bool(timeout)