@@ -504,6 +504,75 @@ def _parse_connect_arguments(*, dsn, host, port, user, password, passfile,
504
504
return addrs , params , config
505
505
506
506
507
+ async def _create_ssl_connection2 (protocol_factory , host , port , * ,
508
+ loop , ssl_context , ssl_is_advisory = False ):
509
+
510
+ class TLSUpgradeProto (asyncio .Protocol ):
511
+ def __init__ (self ):
512
+ self .on_data = loop .create_future ()
513
+
514
+ def data_received (self , data ):
515
+ if data == b'S' :
516
+ self .on_data .set_result (True )
517
+ elif (ssl_is_advisory and
518
+ ssl_context .verify_mode == ssl_module .CERT_NONE and
519
+ data == b'N' ):
520
+ # ssl_is_advisory will imply that ssl.verify_mode == CERT_NONE,
521
+ # since the only way to get ssl_is_advisory is from
522
+ # sslmode=prefer (or sslmode=allow). But be extra sure to
523
+ # disallow insecure connections when the ssl context asks for
524
+ # real security.
525
+ self .on_data .set_result (False )
526
+ else :
527
+ self .on_data .set_exception (
528
+ ConnectionError (
529
+ 'PostgreSQL server at "{}:{}" '
530
+ 'rejected SSL upgrade' .format (host , port )))
531
+
532
+ def connection_lost (self , exc ):
533
+ if not self .on_data .done ():
534
+ if exc is None :
535
+ exc = ConnectionError ('unexpected connection_lost() call' )
536
+ self .on_data .set_exception (exc )
537
+
538
+ if hasattr (loop , 'start_tls' ):
539
+ tr , pr = await loop .create_connection (TLSUpgradeProto , host , port )
540
+ tr .write (struct .pack ('!ll' , 8 , 80877103 )) # SSLRequest message.
541
+
542
+ try :
543
+ ssl_upgrade = await pr .on_data
544
+ except (Exception , asyncio .CancelledError ):
545
+ tr .close ()
546
+ raise
547
+
548
+ if ssl_upgrade :
549
+ if ssl_context is True :
550
+ ssl_context = ssl_module .create_default_context ()
551
+
552
+ try :
553
+ new_tr = await loop .start_tls (
554
+ tr , pr , ssl_context , server_hostname = host )
555
+ except (Exception , asyncio .CancelledError ):
556
+ tr .close ()
557
+ raise
558
+ else :
559
+ new_tr = tr
560
+
561
+ pg_proto = protocol_factory ()
562
+ pg_proto .connection_made (new_tr )
563
+ new_tr .set_protocol (pg_proto )
564
+
565
+ return new_tr , pg_proto
566
+ else :
567
+ return await _negotiate_ssl_connection (
568
+ host , port ,
569
+ functools .partial (loop .create_connection , protocol_factory ),
570
+ loop = loop ,
571
+ ssl = ssl_context ,
572
+ server_hostname = host ,
573
+ ssl_is_advisory = ssl_is_advisory )
574
+
575
+
507
576
async def _connect_addr (* , addr , loop , timeout , params , config ,
508
577
connection_class ):
509
578
assert loop is not None
@@ -520,7 +589,7 @@ async def _connect_addr(*, addr, loop, timeout, params, config,
520
589
assert not params .ssl
521
590
connector = loop .create_unix_connection (proto_factory , addr )
522
591
elif params .ssl :
523
- connector = _create_ssl_connection (
592
+ connector = _create_ssl_connection2 (
524
593
proto_factory , * addr , loop = loop , ssl_context = params .ssl ,
525
594
ssl_is_advisory = params .ssl_is_advisory )
526
595
else :
0 commit comments