Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
84 changes: 55 additions & 29 deletions telegram/ext/_updater.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,7 @@ class Updater(AsyncContextManager["Updater"]):
"__lock",
"__polling_cleanup_cb",
"__polling_task",
"__polling_task_stop_event",
"_httpd",
"_initialized",
"_last_update_id",
Expand All @@ -126,6 +127,7 @@ def __init__(
self._httpd: Optional[WebhookServer] = None
self.__lock = asyncio.Lock()
self.__polling_task: Optional[asyncio.Task] = None
self.__polling_task_stop_event: asyncio.Event = asyncio.Event()
Comment thread
Bibo-Joshi marked this conversation as resolved.
self.__polling_cleanup_cb: Optional[Callable[[], Coroutine[Any, Any, None]]] = None

async def __aenter__(self: _UpdaterType) -> _UpdaterType: # noqa: PYI019
Expand Down Expand Up @@ -417,6 +419,7 @@ def default_error_callback(exc: TelegramError) -> None:
on_err_cb=error_callback or default_error_callback,
description="getting Updates",
interval=poll_interval,
stop_event=self.__polling_task_stop_event,
),
name="Updater:start_polling:polling_task",
)
Expand Down Expand Up @@ -693,6 +696,7 @@ async def _network_loop_retry(
on_err_cb: Callable[[TelegramError], None],
description: str,
interval: float,
stop_event: Optional[asyncio.Event],
) -> None:
"""Perform a loop calling `action_cb`, retrying after network errors.

Expand All @@ -706,39 +710,58 @@ async def _network_loop_retry(
description (:obj:`str`): Description text to use for logs and exception raised.
interval (:obj:`float` | :obj:`int`): Interval to sleep between each call to
`action_cb`.
stop_event (:class:`asyncio.Event` | :obj:`None`): Event to wait on for stopping the
loop. Setting the event will make the loop exit even if `action_cb` is currently
running.

"""

async def do_action() -> bool:
if not stop_event:
return await action_cb()

action_cb_task = asyncio.create_task(action_cb())
stop_task = asyncio.create_task(stop_event.wait())
done, pending = await asyncio.wait(
(action_cb_task, stop_task), return_when=asyncio.FIRST_COMPLETED
)
Comment thread
Bibo-Joshi marked this conversation as resolved.
with contextlib.suppress(asyncio.CancelledError):
for task in pending:
task.cancel()

if stop_task in done:
_LOGGER.debug("Network loop retry %s was cancelled", description)
return False

return action_cb_task.result()

_LOGGER.debug("Start network loop retry %s", description)
cur_interval = interval
try:
while self.running:
try:
if not await action_cb():
break
except RetryAfter as exc:
_LOGGER.info("%s", exc)
cur_interval = 0.5 + exc.retry_after
except TimedOut as toe:
_LOGGER.debug("Timed out %s: %s", description, toe)
# If failure is due to timeout, we should retry asap.
cur_interval = 0
except InvalidToken as pex:
_LOGGER.error("Invalid token; aborting")
raise pex
except TelegramError as telegram_exc:
_LOGGER.error("Error while %s: %s", description, telegram_exc)
on_err_cb(telegram_exc)

# increase waiting times on subsequent errors up to 30secs
cur_interval = 1 if cur_interval == 0 else min(30, 1.5 * cur_interval)
else:
cur_interval = interval

if cur_interval:
await asyncio.sleep(cur_interval)
while self.running:
try:
if not await do_action():
break
except RetryAfter as exc:
_LOGGER.info("%s", exc)
cur_interval = 0.5 + exc.retry_after
except TimedOut as toe:
_LOGGER.debug("Timed out %s: %s", description, toe)
# If failure is due to timeout, we should retry asap.
cur_interval = 0
except InvalidToken as pex:
_LOGGER.error("Invalid token; aborting")
raise pex
except TelegramError as telegram_exc:
_LOGGER.error("Error while %s: %s", description, telegram_exc)
on_err_cb(telegram_exc)

# increase waiting times on subsequent errors up to 30secs
cur_interval = 1 if cur_interval == 0 else min(30, 1.5 * cur_interval)
else:
cur_interval = interval

except asyncio.CancelledError:
_LOGGER.debug("Network loop retry %s was cancelled", description)
if cur_interval:
await asyncio.sleep(cur_interval)

async def _bootstrap(
self,
Expand Down Expand Up @@ -804,6 +827,7 @@ def bootstrap_on_err_cb(exc: Exception) -> None:
bootstrap_on_err_cb,
"bootstrap del webhook",
bootstrap_interval,
stop_event=None,
)

# Reset the retries counter for the next _network_loop_retry call
Expand All @@ -817,6 +841,7 @@ def bootstrap_on_err_cb(exc: Exception) -> None:
bootstrap_on_err_cb,
"bootstrap set webhook",
bootstrap_interval,
stop_event=None,
)

async def stop(self) -> None:
Expand Down Expand Up @@ -852,14 +877,15 @@ async def _stop_polling(self) -> None:
"""Stops the polling task by awaiting it."""
if self.__polling_task:
_LOGGER.debug("Waiting background polling task to finish up.")
self.__polling_task.cancel()
self.__polling_task_stop_event.set()

with contextlib.suppress(asyncio.CancelledError):
await self.__polling_task
# It only fails in rare edge-cases, e.g. when `stop()` is called directly
# after start_polling(), but lets better be safe than sorry ...

self.__polling_task = None
self.__polling_task_stop_event.clear()

if self.__polling_cleanup_cb:
await self.__polling_cleanup_cb()
Expand Down
11 changes: 10 additions & 1 deletion tests/ext/test_application.py
Original file line number Diff line number Diff line change
Expand Up @@ -1432,15 +1432,22 @@ async def callback(update, context):
)
def test_run_polling_basic(self, app, monkeypatch, caplog):
exception_event = threading.Event()
exception_testing_done = threading.Event()
update_event = threading.Event()
exception = TelegramError("This is a test error")
assertions = {}

async def get_updates(*args, **kwargs):
if exception_event.is_set():
raise exception

# This makes sure that other coroutines have a chance of running as well
await asyncio.sleep(0)
if exception_testing_done.is_set() and app.updater.running:
# the longer sleep makes sure that we can exit also while get_updates is running
await asyncio.sleep(20)
Comment thread
harshil21 marked this conversation as resolved.
else:
await asyncio.sleep(0.01)

update_event.set()
return [self.message_update]

Expand All @@ -1466,10 +1473,12 @@ def thread_target():
exception_event.set()
time.sleep(0.05)
assertions["exception_handling"] = self.received == exception.message
exception_testing_done.set()

# So that the get_updates call on shutdown doesn't fail
exception_event.clear()

time.sleep(1)
os.kill(os.getpid(), signal.SIGINT)
time.sleep(0.1)

Expand Down
19 changes: 11 additions & 8 deletions tests/ext/test_updater.py
Original file line number Diff line number Diff line change
Expand Up @@ -234,7 +234,7 @@ async def get_updates(*args, **kwargs):
updates.task_done()
return [next_update]

await asyncio.sleep(0)
await asyncio.sleep(0.1)
return []

orig_del_webhook = updater.bot.delete_webhook
Expand Down Expand Up @@ -520,10 +520,13 @@ async def test_start_polling_exceptions_and_error_callback(
):
raise_exception = True
get_updates_event = asyncio.Event()
second_get_updates_event = asyncio.Event()

async def get_updates(*args, **kwargs):
# So that the main task has a chance to be called
await asyncio.sleep(0)
if get_updates_event.is_set():
second_get_updates_event.set()

if not raise_exception:
return []
Expand All @@ -548,6 +551,9 @@ async def get_updates(*args, **kwargs):

# Also makes sure that the error handler was called
await get_updates_event.wait()
# wait for get_updates to be called a second time - only now we can expect that
# all error handling for the previous call has finished
await second_get_updates_event.wait()

if callback_should_be_called:
# Make sure that the error handler was called
Expand Down Expand Up @@ -588,17 +594,14 @@ async def get_updates(*args, **kwargs):
async def test_start_polling_unexpected_shutdown(self, updater, monkeypatch, caplog):
update_queue = asyncio.Queue()
await update_queue.put(Update(update_id=1))
await update_queue.put(Update(update_id=2))
first_update_event = asyncio.Event()
second_update_event = asyncio.Event()

async def get_updates(*args, **kwargs):
self.message_count = kwargs.get("offset")
update = await update_queue.get()
if update.update_id == 1:
first_update_event.set()
else:
await second_update_event.wait()
first_update_event.set()
await second_update_event.wait()
return [update]

monkeypatch.setattr(updater.bot, "get_updates", get_updates)
Expand All @@ -611,8 +614,8 @@ async def get_updates(*args, **kwargs):
# Unfortunately we need to use the private attribute here to produce the problem
updater._running = False
second_update_event.set()
await asyncio.sleep(1)

await asyncio.sleep(0.1)
assert caplog.records
assert any(
"Updater stopped unexpectedly." in record.getMessage()
Expand All @@ -621,7 +624,7 @@ async def get_updates(*args, **kwargs):
)

# Make sure that the update_id offset wasn't increased
assert self.message_count == 2
assert self.message_count < 1

async def test_start_polling_not_running_after_failure(self, updater, monkeypatch):
# Unfortunately we have to use some internal logic to trigger an exception
Expand Down