Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 35 additions & 3 deletions telegram/ext/_precheckoutqueryhandler.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,16 @@
"""This module contains the PreCheckoutQueryHandler class."""


import re
from typing import Optional, Pattern, TypeVar, Union

from telegram import Update
from telegram._utils.defaultvalue import DEFAULT_TRUE
from telegram._utils.types import DVType
from telegram.ext._basehandler import BaseHandler
from telegram.ext._utils.types import CCT
from telegram.ext._utils.types import CCT, HandlerCallback

RT = TypeVar("RT")


class PreCheckoutQueryHandler(BaseHandler[Update, CCT]):
Expand All @@ -48,14 +55,32 @@ async def callback(update: Update, context: CallbackContext)
:meth:`telegram.ext.Application.process_update`. Defaults to :obj:`True`.

.. seealso:: :wiki:`Concurrency`
pattern (:obj:`str` | :func:`re.Pattern <re.compile>`, optional): Optional. Regex pattern
to test :attr:`telegram.PreCheckoutQuery.invoice_payload` against.

.. versionadded:: NEXT.VERSION

Attributes:
callback (:term:`coroutine function`): The callback function for this handler.
block (:obj:`bool`): Determines whether the callback will run in a blocking way..
pattern (:obj:`str` | :func:`re.Pattern <re.compile>`, optional): Optional. Regex pattern
to test :attr:`telegram.PreCheckoutQuery.invoice_payload` against.

.. versionadded:: NEXT.VERSION

"""

__slots__ = ()
__slots__ = ("pattern",)

def __init__(
self,
callback: HandlerCallback[Update, CCT, RT],
block: DVType[bool] = DEFAULT_TRUE,
pattern: Optional[Union[str, Pattern[str]]] = None,
):
super().__init__(callback, block=block)

self.pattern: Optional[Pattern[str]] = re.compile(pattern) if pattern is not None else None

def check_update(self, update: object) -> bool:
"""Determines whether an update should be passed to this handler's :attr:`callback`.
Expand All @@ -67,4 +92,11 @@ def check_update(self, update: object) -> bool:
:obj:`bool`

"""
return isinstance(update, Update) and bool(update.pre_checkout_query)
if isinstance(update, Update) and update.pre_checkout_query:
invoice_payload = update.pre_checkout_query.invoice_payload
if self.pattern:
if self.pattern.match(invoice_payload):
return True
else:
return True
return False
40 changes: 36 additions & 4 deletions telegram/ext/filters.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,7 @@
"Sticker",
"STORY",
"SUCCESSFUL_PAYMENT",
"SuccessfulPayment",
"SenderChat",
"StatusUpdate",
"TEXT",
Expand Down Expand Up @@ -2265,14 +2266,45 @@ def filter(self, message: Message) -> bool:
"""


class _SuccessfulPayment(MessageFilter):
__slots__ = ()
class SuccessfulPayment(MessageFilter):
"""Successful Payment Messages. If a list of invoice payloads is passed, it filters
messages to only allow those whose `invoice_payload` is appearing in the given list.

Examples:
`MessageHandler(filters.SuccessfulPayment(['Custom-Payload']), callback_method)`

.. seealso::
:attr:`telegram.ext.filters.SUCCESSFUL_PAYMENT`

Args:
invoice_payloads (List[:obj:`str`] | Tuple[:obj:`str`], optional): Which
invoice payloads to allow. Only exact matches are allowed. If not
specified, will allow any invoice payload.

.. versionadded:: NEXT.VERSION
"""

__slots__ = ("invoice_payloads",)

def __init__(self, invoice_payloads: Optional[Union[List[str], Tuple[str, ...]]] = None):
self.invoice_payloads: Optional[Sequence[str]] = invoice_payloads
super().__init__(
name=f"filters.SuccessfulPayment({invoice_payloads})"
if invoice_payloads
else "filters.SUCCESSFUL_PAYMENT"
)

def filter(self, message: Message) -> bool:
return bool(message.successful_payment)
if self.invoice_payloads is None:
return bool(message.successful_payment)
return (
payment.invoice_payload in self.invoice_payloads
if (payment := message.successful_payment)
else False
)


SUCCESSFUL_PAYMENT = _SuccessfulPayment(name="filters.SUCCESSFUL_PAYMENT")
SUCCESSFUL_PAYMENT = SuccessfulPayment()
"""Messages that contain :attr:`telegram.Message.successful_payment`."""


Expand Down
19 changes: 19 additions & 0 deletions tests/ext/test_filters.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
Message,
MessageEntity,
Sticker,
SuccessfulPayment,
Update,
User,
)
Expand Down Expand Up @@ -1877,6 +1878,24 @@ def test_filters_successful_payment(self, update):
update.message.successful_payment = "test"
assert filters.SUCCESSFUL_PAYMENT.check_update(update)

def test_filters_successful_payment_payloads(self, update):
assert not filters.SuccessfulPayment(("custom-payload",)).check_update(update)
assert not filters.SuccessfulPayment().check_update(update)

update.message.successful_payment = SuccessfulPayment(
"USD", 100, "custom-payload", "123", "123"
)
assert filters.SuccessfulPayment(("custom-payload",)).check_update(update)
assert filters.SuccessfulPayment().check_update(update)
assert not filters.SuccessfulPayment(["test1"]).check_update(update)
Comment thread
Bibo-Joshi marked this conversation as resolved.

def test_filters_successful_payment_repr(self):
f = filters.SuccessfulPayment()
assert str(f) == "filters.SUCCESSFUL_PAYMENT"

f = filters.SuccessfulPayment(["payload1", "payload2"])
assert str(f) == "filters.SuccessfulPayment(['payload1', 'payload2'])"

def test_filters_passport_data(self, update):
assert not filters.PASSPORT_DATA.check_update(update)
update.message.passport_data = "test"
Expand Down
23 changes: 22 additions & 1 deletion tests/ext/test_precheckoutqueryhandler.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
# You should have received a copy of the GNU Lesser Public License
# along with this program. If not, see [http://www.gnu.org/licenses/].
import asyncio
import re

import pytest

Expand Down Expand Up @@ -69,12 +70,15 @@ def false_update(request):

@pytest.fixture(scope="class")
def pre_checkout_query():
return Update(
update = Update(
1,
pre_checkout_query=PreCheckoutQuery(
"id", User(1, "test user", False), "EUR", 223, "invoice_payload"
),
)
update._unfreeze()
update.pre_checkout_query._unfreeze()
return update


class TestPreCheckoutQueryHandler:
Expand Down Expand Up @@ -103,6 +107,23 @@ async def callback(self, update, context):
and isinstance(update.pre_checkout_query, PreCheckoutQuery)
)

def test_with_pattern(self, pre_checkout_query):
handler = PreCheckoutQueryHandler(self.callback, pattern=".*voice.*")
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

please also test passing a compiled pattern :)


assert handler.check_update(pre_checkout_query)

pre_checkout_query.pre_checkout_query.invoice_payload = "nothing here"
assert not handler.check_update(pre_checkout_query)

def test_with_compiled_pattern(self, pre_checkout_query):
Comment thread
Bibo-Joshi marked this conversation as resolved.
handler = PreCheckoutQueryHandler(self.callback, pattern=re.compile(r".*payload"))

pre_checkout_query.pre_checkout_query.invoice_payload = "invoice_payload"
assert handler.check_update(pre_checkout_query)

pre_checkout_query.pre_checkout_query.invoice_payload = "nothing here"
assert not handler.check_update(pre_checkout_query)

def test_other_update_types(self, false_update):
handler = PreCheckoutQueryHandler(self.callback)
assert not handler.check_update(false_update)
Expand Down