|
@@ -4,6 +4,7 @@ import hashlib
|
|
|
import logging
|
|
|
import ssl
|
|
|
import socket
|
|
|
+import time
|
|
|
|
|
|
import click
|
|
|
|
|
@@ -14,16 +15,44 @@ logging.basicConfig(level=logging.INFO,
|
|
|
logger = logging.getLogger('certo')
|
|
|
|
|
|
|
|
|
+# The following two functions from
|
|
|
+# https://stackoverflow.com/questions/17667903/python-socket-receive-large-amount-of-data
|
|
|
+def recv_msg(sock, timeout):
|
|
|
+ # Read message length and unpack it into an integer
|
|
|
+ raw_msglen = recvall(sock, 1024, timeout)
|
|
|
+ if not raw_msglen:
|
|
|
+ return None
|
|
|
+ msglen = struct.unpack('>I', raw_msglen)[0]
|
|
|
+ # Read the message data
|
|
|
+ return recvall(sock, msglen, timeout)
|
|
|
+
|
|
|
+
|
|
|
+def recvall(sock, n, timeout):
|
|
|
+ # Helper function to recv n bytes or return None if EOF is hit
|
|
|
+ data = b''
|
|
|
+ begin = time.time()
|
|
|
+ while len(data) < n:
|
|
|
+ packet = sock.recv(n - len(data))
|
|
|
+ if not packet:
|
|
|
+ return None
|
|
|
+ data += packet
|
|
|
+ if time.time() - begin > timeout:
|
|
|
+ break
|
|
|
+ return data
|
|
|
+
|
|
|
+
|
|
|
def establish_conn(addr, port, starttls):
|
|
|
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
|
|
|
- sock.settimeout(1)
|
|
|
+ sock.settimeout(5)
|
|
|
try:
|
|
|
if starttls:
|
|
|
logger.debug("Using STARTTLS")
|
|
|
logger.debug("Connecting to %s:%s" % (addr, port))
|
|
|
sock.connect((addr, port))
|
|
|
- sock.send(b"STARTTLS\n")
|
|
|
- sock.recv(1000)
|
|
|
+ sock.send(b"STARTTLS\r\n")
|
|
|
+ data = recv_msg(sock, 5)
|
|
|
+ if data is None:
|
|
|
+ raise socket.error
|
|
|
wrapped_socket = ssl.wrap_socket(sock)
|
|
|
else:
|
|
|
wrapped_socket = ssl.wrap_socket(sock)
|