-- Listing metadata (web paste residue): 347 lines, 8.0 KiB, Lua
-- `mqtt` is the module table this file returns; `clientMt` holds the
-- methods shared by every client (mqtt.client installs it as __index).
local mqtt, clientMt = {}, {}

-- LuaSocket supplies the TCP transport (socket.connect, socket.sleep).
local socket = require("socket")
|
|
|
|
--- Number of bytes an MQTT variable byte integer needs to encode `value`.
-- Each byte carries 7 value bits, and at most 4 bytes are allowed, so the
-- valid range is 0 .. 128^4 - 1 (268435455).
-- @tparam number value
-- @treturn[1] number 1..4
-- @treturn[2] nil
-- @treturn[2] string "invalid byte length" for negative or too-large values
local function bytesUsedByVarIntForValue(value)
  if value < 0 then
    return nil, "invalid byte length"
  end
  local limit = 128
  for bytes = 1, 4 do
    if value <= limit - 1 then
      return bytes, nil
    end
    limit = limit * 128
  end
  return nil, "invalid byte length"
end
|
|
|
|
--- Bytes an MQTT length-prefixed string occupies on the wire: a 2-byte
-- length followed by the raw bytes of the string.
-- (The parameter is named `text` so it no longer shadows the `string`
-- standard library inside this function.)
-- @tparam string text
-- @treturn number
local function bytesUsedByString(text)
  return 2 + #text
end
|
|
|
|
-- clientMt:receiveByte() is implemented after clientMt:receiveBytes() further
-- down; table methods need no forward declaration, so no stub is kept here.
|
|
|
|
--- Write every buffered byte to the socket, then empty the buffer.
-- LuaSocket's send() may write only part of the data. It returns the index
-- of the last byte written on success, or nil, err and that index on
-- failure. The loop therefore resumes from the next byte, and keeps going
-- on "timeout".
-- @treturn[1] true
-- @treturn[2] nil
-- @treturn[2] string error message; bytes already written are removed from
--   the buffer and the rest stay queued
function clientMt:flush()
  local buffer = self.buffer
  if #buffer == 0 then
    return true, nil
  end
  local conn = self.connection
  if not conn then
    return nil, "not connected"
  end

  local start = 1
  while start <= #buffer do
    local last, err, lastSent = conn:send(buffer, start)
    if last then
      start = last + 1
    elseif err == "timeout" then
      -- Timed socket: continue right after the last byte that made it out.
      start = (lastSent or start - 1) + 1
    else
      self.buffer = buffer:sub((lastSent or start - 1) + 1)
      return nil, err
    end
  end

  self.buffer = ""
  return true, nil
end
|
|
|
|
--- Append one byte to the outgoing buffer, flushing once it exceeds 1024 bytes.
-- @tparam number byte value in 0..255 (string.char raises otherwise)
-- @treturn[1] true
-- @treturn[2] nil
-- @treturn[2] string error returned by flush()
function clientMt:sendByte(byte)
  self.buffer = self.buffer .. string.char(byte)
  if #self.buffer > 1024 then
    return self:flush()
  end
  return true, nil
end
|
|
|
|
--- Send a raw string right away. Pending buffered bytes are flushed first,
-- so bytes reach the socket in the order they were queued.
-- @tparam string data
-- @treturn[1] true
-- @treturn[2] nil
-- @treturn[2] string error returned by flush()
function clientMt:sendData(data)
  local ok, err = self:flush()
  if not ok then
    -- Leave `data` unqueued: the earlier bytes could not be written.
    return nil, err
  end
  self.buffer = data
  return self:flush()
end
|
|
|
|
--- Queue `value` as an MQTT variable byte integer.
-- Each byte holds 7 bits of the value, least-significant group first. The
-- high bit is set on every byte except the last.
-- @tparam number value integer in 0 .. 268435455 (the 4-byte maximum)
-- @treturn[1] true
-- @treturn[2] nil
-- @treturn[2] string "invalid remaining length" or an error from sendByte()
function clientMt:sendVarInt(value)
  -- Without this guard a negative value makes `>>` (a logical shift) emit
  -- a long run of garbage bytes.
  if value < 0 or value > 268435455 then
    return nil, "invalid remaining length"
  end
  repeat
    local encoded = value & 0x7F
    value = value >> 7
    if value > 0 then
      encoded = encoded | 0x80 -- continuation bit: more bytes follow
    end
    local _, err = self:sendByte(encoded)
    if err then
      return nil, err
    end
  until value == 0
  return true, nil
end
|
|
|
|
--- Queue a 16-bit unsigned integer, most-significant byte first.
-- Out-of-range values are rejected; masking them would silently corrupt
-- the field.
-- @tparam number value 0 .. 65535
-- @treturn[1] true
-- @treturn[2] nil
-- @treturn[2] string error message
function clientMt:sendShort(value)
  if value < 0 or value > 0xFFFF then
    return nil, "value does not fit in 16 bits"
  end
  local _, err = self:sendByte((value >> 8) & 0xFF)
  if err then return nil, err end
  _, err = self:sendByte(value & 0xFF)
  if err then return nil, err end
  return true, nil
end
|
|
|
|
--- Send an MQTT string: its 16-bit byte length, then the raw bytes.
-- @tparam string text at most 65535 bytes (the prefix is only 16 bits)
-- @treturn[1] true
-- @treturn[2] nil
-- @treturn[2] string error message
function clientMt:sendString(text)
  if #text > 0xFFFF then
    return nil, "string too long for MQTT (max 65535 bytes)"
  end
  local _, err = self:sendShort(#text)
  if err then return nil, err end
  _, err = self:sendData(text)
  if err then return nil, err end
  return true, nil
end
|
|
|
|
--- Central failure handler.
-- When `err` is set, closes and forgets the connection and discards any
-- queued bytes. It always returns (result, err), so callers can write
-- `return self:handleError(err)` to get the conventional (nil, err).
-- @param err error message, or nil when nothing went wrong
-- @param result value passed back as the first return
function clientMt:handleError(err, result)
  if err then
    print("Got error")
    if self.connection then
      self.connection:close()
    end
    self.connection = nil
    -- Queued bytes belong to a packet on the dead connection. If kept, they
    -- would be sent ahead of the next connection's CONNECT packet.
    self.buffer = ""
  end
  return result, err
end
|
|
|
|
--- Push everything queued so far onto the wire.
-- A failed flush is routed through handleError, which drops the connection.
-- @return flush()'s results in (result, err) order
function clientMt:sendPacket()
  local flushed, flushErr = self:flush()
  return self:handleError(flushErr, flushed)
end
|
|
|
|
--- Open the TCP connection (unless already open) and send an MQTT CONNECT.
-- The packet carries:
--   * protocol name "MQTT" and protocol level 4
--   * connect flags 0x02 (clean session)
--   * keep-alive 0
--   * the client id (`self.id`, or "" when unset)
-- The broker host is `self.uri`. The port is `self.port` when set,
-- otherwise 1883. This does not wait for CONNACK; receivePacket() handles
-- that reply.
-- @treturn[1] true
-- @treturn[2] nil
-- @treturn[2] string error message (after a send failure the connection
--   has already been closed by handleError)
function clientMt:connect()
  if self.connection then
    return true, nil
  end

  local conn, err = socket.connect(self.uri, self.port or 1883)
  if not conn then
    return nil, "failed to connect: " .. tostring(err)
  end
  conn:setoption("tcp-nodelay", true)
  conn:setoption("linger", {on = true, timeout = 100})
  conn:settimeout(nil) -- blocking I/O
  self.connection = conn

  local protocolName = "MQTT"
  local protocolVersion = 4
  local connectFlag = 0x02 -- clean session
  local keepAlive = 0      -- 0 disables keep-alive
  local clientId = self.id or ""

  -- Remaining length = variable header (protocol name string, level byte,
  -- flags byte, 2-byte keep-alive) + payload (client id string).
  local length = bytesUsedByString(protocolName) + 1 + 1 + 2
    + bytesUsedByString(clientId)

  local _
  _, err = self:sendByte(0x10) -- fixed header: CONNECT
  if err then return self:handleError(err) end

  _, err = self:sendVarInt(length)
  if err then return self:handleError(err) end

  _, err = self:sendString(protocolName)
  if err then return self:handleError(err) end

  _, err = self:sendByte(protocolVersion)
  if err then return self:handleError(err) end

  _, err = self:sendByte(connectFlag)
  if err then return self:handleError(err) end

  _, err = self:sendShort(keepAlive)
  if err then return self:handleError(err) end

  _, err = self:sendString(clientId)
  if err then return self:handleError(err) end

  return self:sendPacket()
end
|
|
|
|
--- Send an MQTT DISCONNECT packet (0xE0 0x00), then close the socket.
-- Returns true immediately when not connected.
-- @param args currently unused
-- @treturn[1] true
-- @treturn[2] nil
-- @treturn[2] string error if the packet could not be sent (the connection
--   is closed in every case)
function clientMt:disconnect(args)
  local conn = self.connection
  if not conn then
    return true, nil
  end

  local _, err = self:sendByte(0xE0) -- fixed header: DISCONNECT
  if err then return self:handleError(err) end
  _, err = self:sendByte(0x00)       -- remaining length 0
  if err then return self:handleError(err) end

  local result
  result, err = self:sendPacket()
  if not self.connection then
    -- sendPacket() failed and handleError() already closed the socket.
    return result, err
  end

  conn:shutdown("both")
  -- Brief grace period while the peer is still reported. Never poll
  -- getpeername() in a loop: once it returns nil, that would spin forever.
  if conn:getpeername() then
    socket.sleep(0.02)
  end

  conn:close()
  self.connection = nil
  return result, err
end
|
|
|
|
--- Publish a message at QoS 0, connecting first if necessary.
-- @tparam table args
--   topic   (string) topic name
--   payload (string, optional) message body; nil is sent as ""
--   retain  (optional) truthy sets the RETAIN flag
-- @treturn[1] true
-- @treturn[2] nil
-- @treturn[2] string error message
function clientMt:publish(args)
  local topic = args.topic
  if type(topic) ~= "string" then
    return nil, "publish: topic must be a string"
  end
  local payload = args.payload
  if payload == nil then
    payload = ""
  end

  local _, err = self:connect()
  if err then return self:handleError(err) end

  -- Fixed header: PUBLISH (0x30) with QoS 0 and DUP 0; RETAIN is bit 0.
  local retain = args.retain and 0x01 or 0x00
  _, err = self:sendByte(0x30 | retain)
  if err then return self:handleError(err) end

  -- QoS 0 carries no packet identifier: remaining length = topic + payload.
  _, err = self:sendVarInt(bytesUsedByString(topic) + #payload)
  if err then return self:handleError(err) end

  _, err = self:sendString(topic)
  if err then return self:handleError(err) end

  _, err = self:sendData(payload)
  if err then return self:handleError(err) end

  return self:sendPacket()
end
|
|
|
|
--- Send an MQTT SUBSCRIBE for one topic filter, requesting QoS 0.
-- Connects first if necessary. Does not wait for the SUBACK.
-- @tparam table args args.topic: topic filter string
-- @treturn[1] true
-- @treturn[2] nil
-- @treturn[2] string error message
function clientMt:subscribe(args)
  local topic = args.topic
  if type(topic) ~= "string" then
    return nil, "subscribe: topic must be a string"
  end

  local _, err = self:connect()
  if err then return self:handleError(err) end

  -- Fixed header 0x82: SUBSCRIBE with its mandatory flag bits 0010.
  _, err = self:sendByte(0x82)
  if err then return self:handleError(err) end

  local packetIdentifier = self.packetIdentifier
  self.packetIdentifier = packetIdentifier + 1
  if self.packetIdentifier > 0xFF00 then
    self.packetIdentifier = 1 -- restart at 1; identifiers must be non-zero
  end

  local requestedQos = 0
  -- packet identifier (2) + topic filter string + requested QoS byte (1)
  local length = 2 + bytesUsedByString(topic) + 1

  _, err = self:sendVarInt(length)
  if err then return self:handleError(err) end

  _, err = self:sendShort(packetIdentifier)
  if err then return self:handleError(err) end

  _, err = self:sendString(topic)
  if err then return self:handleError(err) end

  _, err = self:sendByte(requestedQos)
  if err then return self:handleError(err) end

  return self:sendPacket()
end
|
|
|
|
--- Call the handler registered for `event` via clientMt:on, forwarding any
-- extra arguments. Does nothing when no handler table or no handler for
-- that event is registered. The handler's return values are discarded.
function clientMt:fireEvent(event, ...)
  local registry = self.eventHandlers
  local callback = registry and registry[event]
  if callback then
    callback(...)
  end
end
|
|
|
|
--- Read exactly `count` bytes from the connection.
-- On a "timeout" error, the bytes read so far are kept and the coroutine
-- yields. The method must therefore run inside a coroutine; reading
-- continues on the next resume. The prefix argument of receive() is not
-- used, so the result does not depend on how it interacts with the byte
-- count.
-- @tparam number count number of bytes wanted (>= 0)
-- @treturn[1] string exactly `count` bytes
-- @treturn[2] nil
-- @treturn[2] string error message ("not connected" or a socket error)
function clientMt:receiveBytes(count)
  local chunks = {}
  local remaining = count
  while remaining > 0 do
    -- Re-check each round: another coroutine may have dropped the
    -- connection while this one was suspended.
    local conn = self.connection
    if not conn then
      return nil, "not connected"
    end
    local data, err, partial = conn:receive(remaining)
    if data then
      chunks[#chunks + 1] = data
      remaining = remaining - #data
    elseif err == "timeout" then
      if partial and #partial > 0 then
        chunks[#chunks + 1] = partial
        remaining = remaining - #partial
      end
      coroutine.yield()
    else
      return nil, err
    end
  end
  return table.concat(chunks)
end
|
|
|
|
--- Read one byte from the connection.
-- @treturn[1] number byte value 0..255
-- @treturn[2] nil
-- @treturn[2] string error returned by receiveBytes()
function clientMt:receiveByte()
  local data, err = self:receiveBytes(1)
  if not data then
    return nil, err
  end
  return string.byte(data)
end
|
|
|
|
--- Read an MQTT variable byte integer (the packet's remaining length).
-- The encoding uses at most 4 bytes of 7 value bits each, least-significant
-- group first. A set high bit means another byte follows.
-- @treturn[1] number decoded value (0 .. 268435455)
-- @treturn[2] nil
-- @treturn[2] string "malformed remaining length" when the 4th byte still
--   has its continuation bit set, or the receive error (after handleError
--   has dropped the connection)
function clientMt:receiveVarInt()
  local value, multiplier = 0, 1
  for _ = 1, 4 do
    local encodedByte, err = self:receiveByte()
    if not encodedByte then
      return self:handleError(err or "receive failed")
    end
    value = value + (encodedByte & 0x7F) * multiplier
    if (encodedByte & 0x80) == 0 then
      return value
    end
    multiplier = multiplier * 128
  end
  return nil, "malformed remaining length"
end
|
|
|
|
--- Read one packet from the broker and act on it.
-- Only CONNACK (packet type 2) is understood. For a zero return code it
-- prints "Connected" and fires the "connect" event when the broker reports
-- no existing session. A bad CONNACK length, a non-zero return code, or
-- any other packet type is an error, and handleError drops the connection.
-- @treturn[1] true
-- @treturn[2] nil
-- @treturn[2] string error message
function clientMt:receivePacket()
  local firstByte, err = self:receiveByte()
  if not firstByte then return self:handleError(err or "receive failed") end

  local remainingLength
  remainingLength, err = self:receiveVarInt()
  if not remainingLength then return self:handleError(err or "receive failed") end

  local packetType = (firstByte >> 4) & 0xF
  if packetType ~= 2 then
    -- Unsupported or error
    return self:handleError("Invalid packet type " .. packetType)
  end

  -- CONNACK body: acknowledge-flags byte, then return-code byte.
  if remainingLength ~= 2 then
    return self:handleError("Invalid CONNACK length")
  end

  local flags
  flags, err = self:receiveByte()
  if not flags then return self:handleError(err or "receive failed") end

  local returnCode
  returnCode, err = self:receiveByte()
  if not returnCode then return self:handleError(err or "receive failed") end

  if returnCode ~= 0 then
    return self:handleError("connection refused, return code " .. returnCode)
  end

  print("Connected")
  -- Compare explicitly: `flags & 1` is a number, and 0 is truthy in Lua.
  local sessionPresent = (flags & 0x01) ~= 0
  if not sessionPresent then
    self:fireEvent("connect")
  end
  return true, nil
end
|
|
|
|
--- Receive loop meant to run as the body of a coroutine; it never returns.
-- While connected, it reads and handles one packet per iteration. While
-- disconnected, it yields to the resumer.
-- NOTE(review): with a blocking socket, receivePacket() only yields on
-- "timeout", so this loop may hold the scheduler while connected.
-- Confirm this is intended.
function clientMt:threadReceive()
  while true do
    if self.connection then
      local ok, err = self:receivePacket()
      if not ok and err then
        self:handleError(err)
      end
    else
      coroutine.yield()
    end
  end
end
|
|
|
|
--- Install the event-handler table that fireEvent() consults, replacing any
-- previously installed one.
-- @tparam table eventHandlers map of event name to handler function
function clientMt:on(eventHandlers)
  self["eventHandlers"] = eventHandlers
end
|
|
|
|
--- Create a new, not-yet-connected MQTT client.
-- @tparam table args
--   uri (string) broker host handed to socket.connect
--   id  (string) MQTT client identifier
-- @treturn table client object whose methods come from clientMt.
--   `client.thread` is a coroutine running client:threadReceive(); the
--   caller drives it with coroutine.resume(client.thread).
function mqtt.client(args)
  local client = setmetatable({
    uri = args.uri,
    id = args.id,
    reconnect = 5,
    connection = nil,
    packetIdentifier = 1,
    buffer = "",
  }, {__index = clientMt})

  -- threadReceive() calls coroutine.yield(), so it must run inside a
  -- coroutine. Keep the handle so the caller can resume it.
  client.thread = coroutine.create(function() client:threadReceive() end)
  return client
end
|
|
|
|
-- Module export: require() of this file yields the mqtt table (mqtt.client).
return mqtt
|