diff --git a/untls_proxy.py b/untls_proxy.py index 02957f3..44b2adb 100755 --- a/untls_proxy.py +++ b/untls_proxy.py @@ -117,6 +117,7 @@ # import base64 +import enum import select import socket import ssl @@ -124,6 +125,103 @@ import sys import time import threading +class contexts(enum.Enum): + text, tagname, attributename, after_attributename, after_equals, attributevalue, attributevalue_sq, attributevalue_dq = range(8) + +class HtmlProcessor: + def __init__(self): + self.context = contexts.text + self.tag = None + self.attribute = None + self.value = None + + def process_attribute(self): + tag = self.tag.lower() + attribute = self.attribute.lower() + # TODO: handle more attributes + if tag == b'a' and attribute == b'href' or tag == b'img' and attribute == b'src': + if self.value.strip().lower().startswith(b'https://'): + # Space is to keep the response size constant + return b' http://' + self.value.strip()[len(b'https://'):] + else: + return self.value + else: + return self.value + + def process(self, data): + processed = bytearray() + for char in data: + if self.context == contexts.text and char == ord('<'): + self.context = contexts.tagname + self.tag = bytearray() + self.attribute = None + self.value = None + elif self.context not in (contexts.attributevalue_sq, contexts.attributevalue_dq) and char == ord('>'): + if self.context == contexts.attributevalue: processed.extend(self.process_attribute()) + self.context = contexts.text + self.tag = None + self.attribute = None + self.value = None + elif self.context in (contexts.tagname, contexts.attributevalue) and chr(char).isspace(): + if self.context == contexts.attributevalue: processed.extend(self.process_attribute()) + self.context = contexts.attributename + self.attribute = bytearray() + self.value = None + elif self.context == contexts.attributename and chr(char).isspace(): + self.context = contexts.after_attributename + elif self.context == contexts.after_attributename and chr(char).isspace(): + pass + elif self.context in (contexts.attributename, contexts.after_attributename) and char == ord('='): + self.context = contexts.after_equals + elif self.context == contexts.after_equals and chr(char).isspace(): + pass + elif self.context == contexts.after_equals and char == ord("'"): + self.context = contexts.attributevalue_sq + self.value = bytearray() + elif self.context == contexts.after_equals and char == ord('"'): + self.context = contexts.attributevalue_dq + self.value = bytearray() + + elif self.context == contexts.attributevalue_sq and char == ord("'"): + processed.extend(self.process_attribute()) + self.context = contexts.attributename + elif self.context == contexts.attributevalue_dq and char == ord('"'): + processed.extend(self.process_attribute()) + self.context = contexts.attributename + + elif self.context == contexts.tagname: + self.tag.append(char) + elif self.context == contexts.attributename: + self.attribute.append(char) + elif self.context == contexts.after_attributename: + self.context = contexts.attributename + self.attribute = bytearray([char]) + self.value = None + elif self.context == contexts.after_equals: + self.context = contexts.attributevalue + self.value = bytearray([char]) + elif self.context in (contexts.attributevalue, contexts.attributevalue_sq, contexts.attributevalue_dq): + self.value.append(char) + + elif self.context == contexts.text: + pass + + if self.context == contexts.attributevalue: + pass + elif self.context == contexts.attributevalue_sq and char != ord("'"): + pass + elif self.context == contexts.attributevalue_dq and char != ord('"'): + pass + else: + processed.append(char) + + return processed + + def finalize(self): + if self.context in (contexts.attributevalue, contexts.attributevalue_sq, contexts.attributevalue_dq): + return self.process_attribute() + return b'' + def connect(host, port): try: for res in socket.getaddrinfo(host, port, socket.AF_UNSPEC, socket.SOCK_STREAM): @@ -228,9 +326,10 @@ def proxy(sock, host): del password # Remove headers that don't need forwarding or are overwritten - headers = dict((key, value) for key, value in headers.items() if not key.startswith(b'proxy-') and not key in (b'connection', b'keep-alive')) + headers = dict((key, value) for key, value in headers.items() if not key.startswith(b'proxy-') and not key in (b'connection', b'accept-encoding', b'keep-alive')) headers[b'connection'] = b'close' + headers[b'accept-enoding'] = b'identity' # Split url into its constituents fields = url.split(b'://', 1) @@ -327,22 +426,17 @@ def proxy(sock, host): remote_sock.settimeout(None) response, _, response_data = response.partition(b'\r\n\r\n') - # Figure out if this is a redirect to HTTPS - # If it is, rewrite to HTTP + # Process response headers + # Figure out if this is a redirect to HTTPS and if so, rewrite to HTTP + # Figure out whether response is html tls_redirect = False - fields = response.split(b'\r\n')[0].split(b' ') - rewritten_response = None - if len(fields) > 1 and fields[1] in (b'301', b'302', b'303', b'307', b'308'): - rewritten_response = bytearray() - rewritten_response.extend(response.split(b'\r\n')[0]) # Include response line as-is - rewritten_response.extend(b'\r\n') - for line in response.split(b'\r\n')[1:]: - fields = line.split(b':', 1) - if len(fields) != 2 or fields[0].lower() != b'location': - rewritten_response.extend(line) - rewritten_response.extend(b'\r\n') - continue - + is_html = True + rewritten_response = bytearray() + rewritten_response.extend(response.split(b'\r\n')[0]) # Include response line as-is + rewritten_response.extend(b'\r\n') + for line in response.split(b'\r\n')[1:]: + fields = line.split(b':', 1) + if len(fields) == 2 and fields[0].lower() == b'location': destination_url = fields[1].strip() if destination_url.startswith(b'https://'): destination_url = b'http://' + destination_url[len(b'https://'):] @@ -355,6 +449,16 @@ def proxy(sock, host): # This redirect is of the current URL but TLS tls_redirect = True + elif len(fields) == 2 and fields[0].lower() == b'content-type': + mimetype = fields[1].split(b';')[0].strip().lower() + is_html = mimetype == b'text/html' + rewritten_response.extend(line) + rewritten_response.extend(b'\r\n') + + else: + rewritten_response.extend(line) + rewritten_response.extend(b'\r\n') + if tls_redirect and not tls: # Do upgrade to TLS transparently to client print('TLS', file=sys.stderr, end=' ') @@ -370,58 +474,69 @@ def proxy(sock, host): continue # Forward response to client - if rewritten_response is not None: - sock.sendall(rewritten_response) - else: - sock.sendall(response) - sock.sendall(b'\r\n\r\n') - sock.sendall(response_data) + sock.sendall(rewritten_response) + sock.sendall(b'\r\n') break del request_data - # TODO: Un-https links + if is_html: + htmlprocessor = HtmlProcessor() + sock.sendall(htmlprocessor.process(response_data)) + else: + sock.sendall(response_data) + print('', file=sys.stderr) sock.settimeout(60) remote_sock.settimeout(60) last_transfer = time.monotonic() - while True: + ending_connection = False + while not ending_connection: events = poll.poll(60_000) if len(events) == 0 and time.monotonic() - last_transfer > 60: - remote_sock.close() - return + break for fd, _ in events: if fd == sock.fileno(): try: data = sock.recv(1024) - except (ConnectionResetError): - return + except ConnectionResetError: + ending_connection = True + break if data != b'': try: remote_sock.sendall(data) - except (ConnectionResetError, BrokenPipeError): - return - except socket.timeout: + except (ConnectionResetError, BrokenPipeError, socket.timeout): pass else: try: data = remote_sock.recv(1024) except (ConnectionResetError, socket.timeout): - return + ending_connection = True + break if data == b'': - remote_sock.close() - return + ending_connection = True + break + if is_html: + data = htmlprocessor.process(data) try: sock.sendall(data) except (ConnectionResetError, BrokenPipeError, socket.timeout): - remote_sock.close() - return + ending_connection = True + break last_transfer = time.monotonic() + remote_sock.close() + + if is_html: + try: + sock.sendall(htmlprocessor.finalize()) + except (ConnectionResetError, BrokenPipeError, socket.timeout): + pass + class ProxyThread(threading.Thread): def __init__(self, sock, host): self.sock = sock