Skip to content

Commit 570b91d

Browse files
elpransJames William Pye
authored andcommitted
driver: Refactor connectors and reduce code duplication
Connectors have some unnecessarily duplicated code. This refactors them a bit while also allowing to specify a custom socket factory.
1 parent 8222fbf commit 570b91d

1 file changed

Lines changed: 47 additions & 45 deletions

File tree

postgresql/driver/pq3.py

Lines changed: 47 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -2785,79 +2785,75 @@ def socket_factory_sequence(self):
27852785
to the target host.
27862786
"""
27872787

2788-
class IP4(SocketConnector):
2789-
'Connector for establishing IPv4 connections'
2790-
ipv = 4
2788+
def create_socket_factory(self, **params):
2789+
return SocketFactory(**params)
2790+
2791+
class IPConnector(SocketConnector):
27912792
def socket_factory_sequence(self):
27922793
return self._socketcreators
27932794

2794-
def __init__(self,
2795-
host : "IPv4 Address (str)" = None,
2796-
port : int = None,
2797-
ipv = 4,
2798-
**kw
2799-
):
2795+
def socket_factory_params(self, host, port, ipv, **kw):
28002796
if ipv != self.ipv:
2801-
raise TypeError("'ipv' keyword must be '4'")
2797+
raise TypeError("'ipv' keyword must be '%d'" % self.ipv)
28022798
if host is None:
28032799
raise TypeError("'host' is a required keyword and cannot be 'None'")
28042800
if port is None:
28052801
raise TypeError("'port' is a required keyword and cannot be 'None'")
2806-
self.host = host
2807-
self.port = int(port)
2802+
2803+
return {'socket_create': (self.address_family, socket.SOCK_STREAM),
2804+
'socket_connect': (host, int(port))}
2805+
2806+
def __init__(self, host, port, ipv, **kw):
2807+
params = self.socket_factory_params(host, port, ipv, **kw)
2808+
self.host, self.port = params['socket_connect']
28082809
# constant socket connector
2809-
self._socketcreator = SocketFactory(
2810-
(socket.AF_INET, socket.SOCK_STREAM),
2811-
(self.host, self.port)
2812-
)
2813-
self._socketcreators = (
2814-
self._socketcreator,
2815-
)
2810+
self._socketcreator = self.create_socket_factory(**params)
2811+
self._socketcreators = (self._socketcreator,)
28162812
super().__init__(**kw)
28172813

2818-
class IP6(SocketConnector):
2814+
class IP4(IPConnector):
2815+
'Connector for establishing IPv4 connections'
2816+
ipv = 4
2817+
address_family = socket.AF_INET
2818+
2819+
def __init__(self,
2820+
host : "IPv4 Address (str)" = None,
2821+
port : int = None,
2822+
ipv = 4,
2823+
**kw
2824+
):
2825+
super().__init__(host, port, ipv, **kw)
2826+
2827+
class IP6(IPConnector):
28192828
'Connector for establishing IPv6 connections'
28202829
ipv = 6
2821-
def socket_factory_sequence(self):
2822-
return self._socketcreators
2830+
address_family = socket.AF_INET6
28232831

28242832
def __init__(self,
28252833
host : "IPv6 Address (str)" = None,
28262834
port : int = None,
28272835
ipv = 6,
28282836
**kw
28292837
):
2830-
if ipv != self.ipv:
2831-
raise TypeError("'ipv' keyword must be '6'")
2832-
if host is None:
2833-
raise TypeError("'host' is a required keyword and cannot be 'None'")
2834-
if port is None:
2835-
raise TypeError("'port' is a required keyword and cannot be 'None'")
2836-
self.host = host
2837-
self.port = int(port)
2838-
# constant socket connector
2839-
self._socketcreator = SocketFactory(
2840-
(socket.AF_INET6, socket.SOCK_STREAM),
2841-
(self.host, self.port)
2842-
)
2843-
self._socketcreators = (
2844-
self._socketcreator,
2845-
)
2846-
super().__init__(**kw)
2838+
super().__init__(host, port, ipv, **kw)
28472839

28482840
class Unix(SocketConnector):
28492841
'Connector for establishing unix domain socket connections'
28502842
def socket_factory_sequence(self):
28512843
return self._socketcreators
28522844

2853-
def __init__(self, unix = None, **kw):
2845+
def socket_factory_params(self, unix):
28542846
if unix is None:
28552847
raise TypeError("'unix' is a required keyword and cannot be 'None'")
2856-
self.unix = unix
2848+
2849+
return {'socket_create': (socket.AF_UNIX, socket.SOCK_STREAM),
2850+
'socket_connect': unix}
2851+
2852+
def __init__(self, unix = None, **kw):
2853+
params = self.socket_factory_params(unix)
2854+
self.unix = params['socket_connect']
28572855
# constant socket connector
2858-
self._socketcreator = SocketFactory(
2859-
(socket.AF_UNIX, socket.SOCK_STREAM), self.unix
2860-
)
2856+
self._socketcreator = self.create_socket_factory(**params)
28612857
self._socketcreators = (self._socketcreator,)
28622858
super().__init__(**kw)
28632859

@@ -2874,12 +2870,18 @@ def socket_factory_sequence(self):
28742870
"""
28752871
return [
28762872
# (AF, socktype, proto), (IP, Port)
2877-
SocketFactory(x[0:3], x[4][:2], self._socket_secure)
2873+
self.create_socket_factory(**(self.socket_factory_params(x[0:3], x[4][:2],
2874+
self._socket_secure)))
28782875
for x in socket.getaddrinfo(
28792876
self.host, self.port, self._address_family, socket.SOCK_STREAM
28802877
)
28812878
]
28822879

2880+
def socket_factory_params(self, socktype, address, sslparams):
2881+
return {'socket_create': socktype,
2882+
'socket_connect': address,
2883+
'socket_secure': sslparams}
2884+
28832885
def __init__(self,
28842886
host : str = None,
28852887
port : (str, int) = None,

0 commit comments

Comments
 (0)