generated from guanyuankai/bonus-edge-proxy
304 lines
11 KiB
Python
304 lines
11 KiB
Python
"""
|
|
*******************************************************************
|
|
Copyright (c) 2013, 2020 IBM Corp.
|
|
|
|
All rights reserved. This program and the accompanying materials
|
|
are made available under the terms of the Eclipse Public License v2.0
|
|
and Eclipse Distribution License v1.0 which accompany this distribution.
|
|
|
|
The Eclipse Public License is available at
|
|
https://www.eclipse.org/legal/epl-2.0/
|
|
and the Eclipse Distribution License is available at
|
|
http://www.eclipse.org/org/documents/edl-v10.php.
|
|
|
|
Contributors:
|
|
Ian Craggs - initial implementation and/or documentation
|
|
Ian Craggs - add MQTTV5 support
|
|
*******************************************************************
|
|
"""
|
|
from __future__ import print_function
|
|
|
|
import socket
|
|
import sys
|
|
import select
|
|
import traceback
|
|
import datetime
|
|
import os
|
|
import base64
|
|
import hashlib
|
|
import logging
|
|
try:
|
|
import socketserver
|
|
import MQTTV311 # Trace MQTT traffic - Python 3 version
|
|
import MQTTV5
|
|
except:
|
|
traceback.print_exc()
|
|
import SocketServer as socketserver
|
|
import MQTTV3112 as MQTTV311 # Trace MQTT traffic - Python 2 version
|
|
import MQTTV5
|
|
|
|
# Default protocol module used to frame/decode packets; MyHandler.handle
# reassigns it when a CONNECT packet announces the protocol level.
MQTT = MQTTV311

# NOTE(review): ``logging`` rebinds the name of the stdlib module imported
# above, so that module is unusable by this name in this file. Neither of these
# two flags is read anywhere else in the visible code -- consider removing.
logging, myWindow = True, None
|
|
|
|
|
|
class BufferedSockets:
    """Wrap a client socket, adding push-back (``rebuffer``) and optional
    WebSocket framing.

    While ``websockets`` is False, ``recv``/``send`` pass bytes straight
    through (after draining any pushed-back data).  Once it is True, reads
    unwrap RFC 6455 frames sent by the client and writes wrap the data in a
    single binary frame.  Any other attribute (``fileno``, ...) is delegated
    to the wrapped socket, so instances can be passed to ``select.select``.
    """

    def __init__(self, socket):
        self.socket = socket
        self.buffer = bytearray()  # received but not yet returned to a caller
        self.websockets = False

    def close(self):
        """Shut down and close the wrapped socket.

        The shutdown is best-effort: it raises if the peer has already gone
        away, but the descriptor must still be closed in that case.
        """
        try:
            self.socket.shutdown(socket.SHUT_RDWR)
        except socket.error:
            pass
        self.socket.close()

    def rebuffer(self, data):
        """Push *data* back so that the next ``recv`` returns it first."""
        self.buffer = data + self.buffer

    def _recvexact(self, count):
        """Read *count* bytes from the wrapped socket.

        Returns a bytearray that is shorter than *count* only when the peer
        closed the connection (``recv`` returned no data) first.
        """
        data = bytearray()
        while len(data) < count:
            chunk = self.socket.recv(count - len(data))
            if not chunk:
                break
            data += chunk
        return data

    def _frameheader(self, opcode, length):
        """Build an unmasked (server-to-client) WebSocket frame header."""
        header = bytearray([opcode])
        if length < 126:
            header.append(length)
        elif length < 65536:
            # 126: the following 2 bytes hold the payload length (big-endian).
            header += bytearray([126, length >> 8, length & 0xff])
        else:
            # 127: the following 8 bytes hold the payload length (big-endian,
            # most significant bit 0).
            header.append(127)
            header += bytearray((length >> (8 * shift)) & 0xff
                                for shift in range(7, -1, -1))
        return header

    def wsrecv(self):
        """Read one WebSocket frame from the client.

        Data frames are unmasked and their payload appended to
        ``self.buffer``.  Control frames are never buffered: a ping is
        answered with a pong carrying the same payload, and other control
        frames are dropped.

        Returns True when a frame was consumed.  Returns False when no more
        data can arrive: a socket error while reading the frame header, EOF
        (including a truncated frame), or a close frame.
        """
        try:
            header = self._recvexact(2)
        except socket.error:
            return False
        if len(header) < 2:
            return False
        opcode = header[0] & 0x0f
        maskbit = (header[1] & 0x80) == 0x80
        length = header[1] & 0x7f  # 0..125 is the length itself
        extra = {126: 2, 127: 8}.get(length, 0)  # extended length bytes
        if extra:
            extended = self._recvexact(extra)
            if len(extended) < extra:
                return False
            length = 0
            for byte in extended:
                length = (length << 8) | byte
        # RFC 6455: every client-to-server frame must be masked.
        assert maskbit == True
        if maskbit:
            mask = self._recvexact(4)
            if len(mask) < 4:
                return False
        payload = self._recvexact(length)
        if len(payload) < length:
            return False
        if maskbit:
            payload = bytearray(byte ^ mask[index % 4]
                                for index, byte in enumerate(payload))
        if opcode & 0x08:  # control frame: close, ping or pong
            if opcode == 0x8:  # close
                return False
            if opcode == 0x9:  # ping -> pong echoing the payload
                self.socket.sendall(
                    self._frameheader(0x8A, len(payload)) + payload)
            return True
        self.buffer += payload
        return True

    def recv(self, bufsize):
        """Return up to *bufsize* bytes, serving pushed-back data first.

        In WebSocket mode, frames are read until *bufsize* bytes are buffered
        or the connection ends.  In the latter case, whatever remains buffered
        (possibly nothing) is returned instead of looping forever.
        """
        if self.websockets:
            while len(self.buffer) < bufsize:
                if not self.wsrecv():
                    break
            out = self.buffer[:bufsize]
            self.buffer = self.buffer[bufsize:]
        elif bufsize <= len(self.buffer):
            out = self.buffer[:bufsize]
            self.buffer = self.buffer[bufsize:]
        else:
            out = self.buffer + \
                self.socket.recv(bufsize - len(self.buffer))
            self.buffer = bytes()
        return out

    def __getattr__(self, name):
        # Only reached for attributes not defined here: use the raw socket.
        return getattr(self.socket, name)

    def send(self, data):
        """Send all of *data*.

        In WebSocket mode, the data goes out as one binary frame
        (opcode 0x2, FIN set).  Returns the number of bytes written to the
        socket, including any frame header.
        """
        if self.websockets:
            data = self._frameheader(0x82, len(data)) + data
        self.socket.sendall(data)
        return len(data)
|
|
|
|
|
|
def timestamp(now=None):
    """Format a time for trace output as ``YYYYMMDD HHMMSS.ffffff``.

    *now* defaults to the current local time.  The fractional part is
    always six zero-padded digits.  For example, 5000 microseconds prints
    as ".005000", not ".5".
    """
    if now is None:
        now = datetime.datetime.now()
    return now.strftime('%Y%m%d %H%M%S.%f')
|
|
|
|
|
|
# Client connections whose client-to-broker traffic is no longer relayed
# (see the TERMINATE_SERVER command); shared by all handler threads.
suspended = list()
|
|
|
|
|
|
class MyHandler(socketserver.StreamRequestHandler):
    """Per-connection handler that relays MQTT traffic between a client and
    the broker at ``brokerhost:brokerport``, printing every packet.

    Clients may speak raw MQTT over TCP or MQTT over WebSockets; WebSockets
    are detected by an initial ``G`` (an HTTP ``GET``).  PUBLISH messages on
    the topic ``"MQTTSAS topic"`` act as commands.  Payload ``TERMINATE``
    closes both connections; ``TERMINATE_SERVER`` stops relaying that
    client's traffic.
    """

    def getheaders(self, data):
        """Parse the header lines of an HTTP request string.

        The first (request) line is skipped and lines without a colon are
        ignored.  Keys are upper-cased so lookups are case insensitive.
        Whitespace around keys and values is stripped, so both ``Key: value``
        and ``Key:value`` are accepted.
        """
        headers = {}
        for curline in data.splitlines()[1:]:
            key, sep, value = curline.partition(":")
            if sep:
                headers[key.strip().upper()] = value.strip()
        return headers

    def handshake(self, client):
        """Answer a WebSocket opening handshake (RFC 6455 section 4.2.2).

        Reads the HTTP upgrade request from *client* up to the blank line
        that ends its headers.  Then replies ``101 Switching Protocols`` with
        the ``mqtt`` sub-protocol and the ``Sec-WebSocket-Accept`` value
        derived from the client's key.  Returns the result of
        ``client.send``.  Raises KeyError when the request has no
        Sec-WebSocket-Key header.
        """
        GUID = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11"
        request = bytearray()
        # One recv() may return only part of the request, so keep reading
        # until the end of the headers (bounded, in case it never comes).
        while b"\r\n\r\n" not in request and len(request) < 65536:
            chunk = client.recv(1024)
            if not chunk:
                break
            request += chunk
        headers = self.getheaders(bytes(request).decode('utf-8'))
        digest = base64.b64encode(hashlib.sha1(
            (headers['SEC-WEBSOCKET-KEY'] + GUID).encode("utf-8")).digest())
        resp = b"HTTP/1.1 101 Switching Protocols\r\n" + \
            b"Upgrade: websocket\r\n" + \
            b"Connection: Upgrade\r\n" + \
            b"Sec-WebSocket-Protocol: mqtt\r\n" + \
            b"Sec-WebSocket-Accept: " + digest + b"\r\n\r\n"
        return client.send(resp)

    @staticmethod
    def _connect_version(inbuf):
        """Return the protocol level byte of a CONNECT packet, else None.

        The layout read here is: fixed header byte (packet type 1), the
        variable-length remaining-length field, a 2-byte big-endian
        protocol-name length, the name (e.g. b'MQTT' or b'MQIsdp'), and then
        the level byte.  Returns None for other packet types and for
        truncated packets.
        """
        data = bytearray(inbuf or b"")
        if not data or data[0] >> 4 != 1:
            return None
        pos = 1
        # Skip the remaining-length field: bytes with the 0x80 bit continue it.
        while pos < len(data) and data[pos] & 0x80:
            pos += 1
        pos += 1
        if pos + 2 > len(data):
            return None
        namelen = (data[pos] << 8) | data[pos + 1]
        levelpos = pos + 2 + namelen
        if levelpos >= len(data):
            return None
        return data[levelpos]

    def _cleanup(self, clients, brokers):
        """Release per-connection state once ``handle`` has finished.

        Closes the broker socket, removes the client from ``suspended`` and
        forgets its id/version entries.  The client socket itself is shut
        down by socketserver after ``handle`` returns.
        """
        if brokers is not None:
            brokers.close()  # harmless if TERMINATE already closed it
        if clients is None:
            return
        while clients in suspended:
            suspended.remove(clients)
        self.ids.pop(id(clients), None)
        self.versions.pop(id(clients), None)

    def handle(self):
        """Relay packets between the client (``self.request``) and the broker.

        Each iteration waits for either side to become readable, reads one
        MQTT packet, prints a trace line and forwards it to the other side.
        The first client byte chooses between raw MQTT and WebSockets.  The
        loop ends on a TERMINATE command or when ``getPacket`` returns None
        for either side (presumably because the peer closed).  Unexpected
        errors are printed, and resources are always released.
        """
        global MQTT
        if not hasattr(self, "ids"):
            self.ids = {}       # id(client wrapper) -> MQTT client identifier
        if not hasattr(self, "versions"):
            self.versions = {}  # id(client wrapper) -> 3 or 5
        # Protocol module for THIS connection.  The module-level MQTT is still
        # updated for backward compatibility but is not read again here, so
        # handler threads of clients with different versions cannot switch
        # each other's decoder.
        mqtt = MQTT
        inbuf = True
        first = True
        i = o = e = None
        clients = brokers = None
        try:
            clients = BufferedSockets(self.request)
            sock_no = clients.fileno()
            brokers = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
            brokers.connect((brokerhost, brokerport))
            terminated = False
            while inbuf is not None and not terminated:
                if clients.buffer and clients not in suspended:
                    # A WebSocket frame may carry several MQTT packets.  The
                    # rest is already buffered and select() would not report
                    # it, so serve it first.
                    (i, o, e) = ([clients], [], [])
                else:
                    (i, o, e) = select.select([clients, brokers], [], [])
                for s in i:
                    if s in suspended:
                        print("suspended")
                    if s is clients and s not in suspended:
                        if first:
                            first = False  # sniff the transport only once
                            char = clients.recv(1)
                            clients.rebuffer(char)
                            if char == b"G":  # should be websocket connection
                                self.handshake(clients)
                                clients.websockets = True
                                print("Switching to websockets for socket %d" % sock_no)
                        inbuf = mqtt.getPacket(clients)  # get one packet
                        if inbuf is None:
                            break
                        try:
                            # if connect, this could be MQTTV3 or MQTTV5
                            version = self._connect_version(inbuf)
                            if version is not None:
                                mqtt = MQTTV5 if version == 5 else MQTTV311
                                MQTT = mqtt
                            packet = mqtt.unpackPacket(inbuf)
                            if hasattr(packet.fh, "MessageType"):
                                packet_type = packet.fh.MessageType
                                publish_type = mqtt.PUBLISH
                                connect_type = mqtt.CONNECT
                            else:
                                packet_type = packet.fh.PacketType
                                publish_type = mqtt.PacketTypes.PUBLISH
                                connect_type = mqtt.PacketTypes.CONNECT
                            is_command = (packet_type == publish_type and
                                          packet.topicName == "MQTTSAS topic")
                            if is_command and packet.data == b"TERMINATE":
                                print("Terminating client", self.ids.get(id(clients)))
                                brokers.close()
                                clients.close()
                                terminated = True
                                break
                            elif is_command and packet.data == b"TERMINATE_SERVER":
                                print("Suspending client ", self.ids.get(id(clients)))
                                suspended.append(clients)
                            elif packet_type == connect_type:
                                self.ids[id(clients)] = packet.ClientIdentifier
                                self.versions[id(clients)] = 5 if mqtt is MQTTV5 else 3
                            print(timestamp(), "C to S",
                                  self.ids.get(id(clients)), str(packet))
                        except Exception:
                            traceback.print_exc()
                        brokers.sendall(inbuf)  # pass it on, even if a trace failed
                    elif s is brokers:
                        inbuf = mqtt.getPacket(brokers)  # get one packet
                        if inbuf is None:
                            break
                        try:
                            print(timestamp(), "S to C", self.ids.get(id(clients)),
                                  str(mqtt.unpackPacket(inbuf)))
                        except Exception:
                            traceback.print_exc()
                        clients.send(inbuf)
            print(timestamp() + " client " + str(self.ids.get(id(clients))) +
                  " connection closing")
        except Exception:
            print(repr((i, o, e)), repr(inbuf))
            traceback.print_exc()
        finally:
            self._cleanup(clients, brokers)
|
|
|
|
|
|
class ThreadingTCPServer(socketserver.ThreadingMixIn, socketserver.TCPServer):
    """TCP server that runs each connection's handler in its own thread.

    ThreadingMixIn is listed first so its ``process_request`` takes
    precedence over the one inherited through TCPServer.
    """
|
|
|
|
|
|
def run():
|
|
global brokerhost, brokerport
|
|
myhost = '127.0.0.1'
|
|
if len(sys.argv) > 1:
|
|
brokerhost = sys.argv[1]
|
|
else:
|
|
brokerhost = '127.0.0.1'
|
|
|
|
if len(sys.argv) > 2:
|
|
brokerport = int(sys.argv[2])
|
|
else:
|
|
brokerport = 1883
|
|
|
|
if len(sys.argv) > 3:
|
|
myport = int(sys.argv[3])
|
|
else:
|
|
if brokerhost == myhost:
|
|
myport = brokerport + 1
|
|
else:
|
|
myport = 1883
|
|
|
|
print("Listening on port", str(myport)+", broker on port", brokerport)
|
|
s = ThreadingTCPServer(("127.0.0.1", myport), MyHandler)
|
|
s.serve_forever()
|
|
|
|
|
|
if __name__ == "__main__":
    # Script entry point. Optional arguments: brokerhost brokerport listenport.
    run()
|