Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 32 additions & 1 deletion telegram/ext/_applicationbuilder.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,7 @@
("connect_timeout", "connect_timeout"),
("read_timeout", "read_timeout"),
("write_timeout", "write_timeout"),
("media_write_timeout", "media_write_timeout"),
("http_version", "http_version"),
("get_updates_connection_pool_size", "get_updates_connection_pool_size"),
("get_updates_proxy", "get_updates_proxy"),
Expand Down Expand Up @@ -152,6 +153,7 @@ class ApplicationBuilder(Generic[BT, CCT, UD, CD, BD, JQ]):
"_http_version",
"_job_queue",
"_local_mode",
"_media_write_timeout",
"_persistence",
"_pool_timeout",
"_post_init",
Expand Down Expand Up @@ -181,6 +183,7 @@ def __init__(self: "InitApplicationBuilder"):
self._connect_timeout: ODVInput[float] = DEFAULT_NONE
self._read_timeout: ODVInput[float] = DEFAULT_NONE
self._write_timeout: ODVInput[float] = DEFAULT_NONE
self._media_write_timeout: ODVInput[float] = DEFAULT_NONE
self._pool_timeout: ODVInput[float] = DEFAULT_NONE
self._request: DVInput[BaseRequest] = DEFAULT_NONE
self._get_updates_connection_pool_size: DVInput[int] = DEFAULT_NONE
Expand Down Expand Up @@ -243,6 +246,10 @@ def _build_request(self, get_updates: bool) -> BaseRequest:
"write_timeout": getattr(self, f"{prefix}write_timeout"),
"pool_timeout": getattr(self, f"{prefix}pool_timeout"),
}

if not get_updates:
timeouts["media_write_timeout"] = self._media_write_timeout

# Get timeouts that were actually set-
effective_timeouts = {
key: value for key, value in timeouts.items() if not isinstance(value, DefaultValue)
Expand Down Expand Up @@ -424,9 +431,13 @@ def _request_check(self, get_updates: bool) -> None:
prefix = "get_updates_" if get_updates else ""
name = prefix + "request"

timeouts = ["connect_timeout", "read_timeout", "write_timeout", "pool_timeout"]
if not get_updates:
timeouts.append("media_write_timeout")

# Code below tests if it's okay to set a Request object. Only okay if no other request args
# or instances containing a Request were set previously
for attr in ("connect_timeout", "read_timeout", "write_timeout", "pool_timeout"):
for attr in timeouts:
if not isinstance(getattr(self, f"_{prefix}{attr}"), DefaultValue):
raise RuntimeError(_TWO_ARGS_REQ.format(name, attr))

Expand Down Expand Up @@ -617,6 +628,26 @@ def write_timeout(self: BuilderType, write_timeout: Optional[float]) -> BuilderT
self._write_timeout = write_timeout
return self

def media_write_timeout(
self: BuilderType, media_write_timeout: Optional[float]
) -> BuilderType:
"""Sets the media write operation timeout for the
Comment thread
harshil21 marked this conversation as resolved.
:paramref:`~telegram.request.HTTPXRequest.media_write_timeout` parameter of
:attr:`telegram.Bot.request`. Defaults to ``20``.

.. versionadded:: NEXT.VERSION

Args:
media_write_timeout (:obj:`float`): See
:paramref:`telegram.request.HTTPXRequest.media_write_timeout` for more information.

Returns:
:class:`ApplicationBuilder`: The same builder with the updated argument.
"""
self._request_param_check(name="media_write_timeout", get_updates=False)
self._media_write_timeout = media_write_timeout
return self

def pool_timeout(self: BuilderType, pool_timeout: Optional[float]) -> BuilderType:
"""Sets the connection pool's connection freeing timeout for the
:paramref:`~telegram.request.HTTPXRequest.pool_timeout` parameter of
Expand Down
20 changes: 14 additions & 6 deletions telegram/request/_httpxrequest.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,10 @@ class HTTPXRequest(BaseRequest):
a network socket; i.e. POSTing a request or uploading a file).
This value is used unless a different value is passed to :meth:`do_request`.
Defaults to ``5``.

Hint:
This timeout is used for all requests except for those that upload media/files.
For the latter, :paramref:`media_write_timeout` is used.
connect_timeout (:obj:`float` | :obj:`None`, optional): If passed, specifies the
maximum amount of time (in seconds) to wait for a connection attempt to a server
to succeed. This value is used unless a different value is passed to
Expand Down Expand Up @@ -112,10 +116,16 @@ class HTTPXRequest(BaseRequest):
.. _the docs of httpx: https://www.python-httpx.org/environment_variables/#proxies

.. versionadded:: 20.7
media_write_timeout (:obj:`float` | :obj:`None`, optional): Like :paramref:`write_timeout`,
but used only for requests that upload media/files. This value is used unless a
different value is passed to :paramref:`do_request.write_timeout` of
:meth:`do_request`. Defaults to ``20`` seconds.

.. versionadded:: NEXT.VERSION

"""

__slots__ = ("_client", "_client_kwargs", "_http_version")
__slots__ = ("_client", "_client_kwargs", "_http_version", "_media_write_timeout")

def __init__(
self,
Expand All @@ -128,6 +138,7 @@ def __init__(
http_version: HTTPVersion = "1.1",
socket_options: Optional[Collection[SocketOpt]] = None,
proxy: Optional[Union[str, httpx.Proxy, httpx.URL]] = None,
media_write_timeout: Optional[float] = 20.0,
):
if proxy_url is not None and proxy is not None:
raise ValueError("The parameters `proxy_url` and `proxy` are mutually exclusive.")
Expand All @@ -142,6 +153,7 @@ def __init__(
)

self._http_version = http_version
self._media_write_timeout = media_write_timeout
Comment thread
Bibo-Joshi marked this conversation as resolved.
timeout = httpx.Timeout(
connect=connect_timeout,
read=read_timeout,
Expand Down Expand Up @@ -251,11 +263,7 @@ async def do_request(
pool_timeout = self._client.timeout.pool

if isinstance(write_timeout, DefaultValue):
# Making the networking backend decide on the proper timeout values instead of doing
# it via the default values of the Bot methods was introduced in version 20.7.
# We hard-code the value here for now until we add additional parameters to this
# class to control the media_write_timeout separately.
write_timeout = self._client.timeout.write if not files else 20
write_timeout = self._client.timeout.write if not files else self._media_write_timeout

timeout = httpx.Timeout(
connect=connect_timeout,
Expand Down
19 changes: 18 additions & 1 deletion tests/ext/test_applicationbuilder.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,9 @@ def test_all_methods_request(self, builder, get_updates):
for argument in arguments:
if argument == "self":
continue
if argument == "media_write_timeout" and get_updates:
# get_updates never makes media requests
continue
assert hasattr(builder, prefix + argument), f"missing method {prefix}{argument}"

@pytest.mark.parametrize("bot_class", [Bot, ExtBot])
Expand Down Expand Up @@ -202,6 +205,7 @@ def test_mutually_exclusive_for_bot(self, builder, method, description):
"pool_timeout",
"read_timeout",
"write_timeout",
"media_write_timeout",
"proxy",
"proxy_url",
"socket_options",
Expand Down Expand Up @@ -272,6 +276,7 @@ def test_mutually_exclusive_for_get_updates_request(self, builder, method):
"pool_timeout",
"read_timeout",
"write_timeout",
"media_write_timeout",
"proxy",
"proxy_url",
"socket_options",
Expand Down Expand Up @@ -316,6 +321,7 @@ def test_mutually_exclusive_for_updater(self, builder, method):
"pool_timeout",
"read_timeout",
"write_timeout",
"media_write_timeout",
"proxy",
"proxy_url",
"socket_options",
Expand Down Expand Up @@ -384,12 +390,20 @@ class Client:
http2: object
transport: object = None

original_init = HTTPXRequest.__init__
media_write_timeout = []

def init_httpx_request(self_, *args, **kwargs):
media_write_timeout.append(kwargs.get("media_write_timeout"))
original_init(self_, *args, **kwargs)

monkeypatch.setattr(httpx, "AsyncClient", Client)
monkeypatch.setattr(HTTPXRequest, "__init__", init_httpx_request)

builder = ApplicationBuilder().token(bot.token)
builder.connection_pool_size(1).connect_timeout(2).pool_timeout(3).read_timeout(
4
).write_timeout(5).http_version("1.1")
).write_timeout(5).media_write_timeout(6).http_version("1.1")
getattr(builder, proxy_method)("proxy")
app = builder.build()
client = app.bot.request._client
Expand All @@ -399,7 +413,9 @@ class Client:
assert client.proxy == "proxy"
assert client.http1 is True
assert client.http2 is False
assert media_write_timeout == [6, None]

media_write_timeout.clear()
builder = ApplicationBuilder().token(bot.token)
builder.get_updates_connection_pool_size(1).get_updates_connect_timeout(
2
Expand All @@ -417,6 +433,7 @@ class Client:
assert client.proxy == "get_updates_proxy"
assert client.http1 is True
assert client.http2 is False
assert media_write_timeout == [None, None]

def test_custom_socket_options(self, builder, monkeypatch, bot):
httpx_request_kwargs = []
Expand Down
33 changes: 33 additions & 0 deletions tests/request/test_request.py
Original file line number Diff line number Diff line change
Expand Up @@ -738,6 +738,39 @@ async def request(_, **kwargs):
# other than HTTPXRequest
assert len(recwarn) == 0

@pytest.mark.parametrize("init", [True, False])
async def test_setting_media_write_timeout(
self, monkeypatch, init, input_media_photo, recwarn # noqa: F811
):
httpx_request = HTTPXRequest(media_write_timeout=42) if init else HTTPXRequest()

async def request(_, **kwargs):
self.test_flag = kwargs["timeout"].write
return httpx.Response(HTTPStatus.OK, content=b'{"ok": "True", "result": {}}')

monkeypatch.setattr(httpx.AsyncClient, "request", request)

data = {"string": "string", "int": 1, "float": 1.0, "media": input_media_photo}
request_data = RequestData(
parameters=[RequestParameter.from_input(key, value) for key, value in data.items()],
)

# First make sure that custom timeouts are always respected
await httpx_request.post(
"url",
request_data,
write_timeout=43,
)
assert self.test_flag == 43

# Now also ensure that the init value is respected
await httpx_request.post("url", request_data)
assert self.test_flag == 42 if init else 20

# Just for double-checking, since warnings are issued for implementations of BaseRequest
# other than HTTPXRequest
assert len(recwarn) == 0

async def test_socket_opts(self, monkeypatch):
transport_kwargs = {}
transport_init = AsyncHTTPTransport.__init__
Expand Down