360 lines
8.1 KiB
Lua

-- Module table; returned at the end of this file.
local mqtt = {}
-- Socket library providing connect() and sleep() as used below
-- (presumably LuaSocket -- confirm against the project's dependencies).
local socket = require("socket")
-- Method table shared by all clients; installed as `__index` in mqtt.client.
local clientMt = {}
--- Number of bytes needed to encode `value` as a base-128 variable-length
-- integer (7 payload bits per byte), capped at four bytes.
-- @tparam number value value to measure
-- @treturn[1] number byte count in the range 1..4, followed by nil
-- @treturn[2] nil
-- @treturn[2] string "invalid byte length" when more than four bytes are needed
local function bytesUsedByVarIntForValue(value)
  -- `limit` is 128^width: the first value that no longer fits in `width` bytes.
  local limit = 128
  for width = 1, 4 do
    if value <= limit - 1 then
      return width, nil
    end
    limit = limit * 128
  end
  return nil, "invalid byte length"
end
--- Size in bytes of a length-prefixed string on the wire: a two-byte length
-- field followed by the string's own bytes.
-- @tparam string string text to measure (note: the parameter name shadows the
--   `string` library inside this function)
-- @treturn number total encoded size in bytes
local function bytesUsedByString(string)
  local lengthPrefixBytes = 2
  local bodyBytes = #string
  return lengthPrefixBytes + bodyBytes
end
-- Empty placeholder. A second `clientMt:receiveByte` defined further down in
-- this file is assigned to the same field and replaces this stub.
function clientMt:receiveByte()
end
--- Write everything held in `self.buffer` to the connection.
-- A timed-out send that still made progress is resumed from the byte after
-- the last one sent. On any other failure the bytes already sent are dropped
-- from the buffer so a later retry cannot transmit them twice.
-- @treturn[1] boolean true when the buffer is empty afterwards
-- @treturn[2] nil
-- @treturn[2] string error message ("not connected" or the socket's error)
function clientMt:flush()
  if #self.buffer == 0 then
    return true, nil
  end
  local conn = self.connection
  if not conn then
    -- The connection may already have been torn down by handleError.
    return nil, "not connected"
  end
  print("flushing data")
  local start = 1
  while start <= #self.buffer do
    -- send() returns the index of the last byte sent, or
    -- nil, err, index-of-last-byte-sent on failure.
    local lastSent, err, partialIndex = conn:send(self.buffer, start)
    if lastSent then
      start = lastSent + 1
    else
      local progressed = partialIndex ~= nil and partialIndex >= start
      if progressed then
        start = partialIndex + 1
      end
      if not (progressed and err == "timeout") then
        self.buffer = self.buffer:sub(start)
        print("Error: " .. tostring(err))
        return nil, err
      end
    end
  end
  self.buffer = ""
  return true, nil
end
--- Append one byte to the outgoing buffer, flushing once the buffer grows
-- past 1024 bytes.
-- @tparam number byte value in 0..255 (passed to string.char)
-- @treturn[1] boolean true on success (buffered or flushed)
-- @treturn[2] nil
-- @treturn[2] string error reported by flush
function clientMt:sendByte(byte)
  self.buffer = self.buffer .. string.char(byte)
  if #self.buffer > 1024 then
    return self:flush()
  end
  -- Report success explicitly so the buffered path agrees with the flushed one.
  return true, nil
end
--- Send a block of raw bytes, bypassing the per-byte buffering.
-- Pending buffered bytes are flushed first so ordering on the wire is kept;
-- `data` then becomes the buffer and is flushed exactly once.
-- @tparam string data bytes to transmit
-- @treturn[1] boolean true on success
-- @treturn[2] nil
-- @treturn[2] string error reported by flush
function clientMt:sendData(data)
  local ok, err = self:flush()
  if err then return ok, err end
  self.buffer = data
  ok, err = self:flush()
  if err then return ok, err end
  return true, nil
end
--- Emit `value` as a base-128 variable-length integer: seven bits per byte,
-- least-significant group first, with bit 0x80 set on every byte except the
-- last one.
-- @tparam number value integer to encode
-- @treturn[1] boolean true, followed by nil, once every byte was queued
-- @treturn[2] nil
-- @treturn[2] string error reported by sendByte
function clientMt:sendVarInt(value)
  while true do
    local group = value & 0x7F
    value = value >> 7
    if value > 0 then
      -- More groups follow: mark this byte with the continuation bit.
      group = group | 128
    end
    local _, sendErr = self:sendByte(group)
    if sendErr then
      return nil, sendErr
    end
    if value <= 0 then
      break
    end
  end
  return true, nil
end
--- Queue a 16-bit unsigned integer in big-endian order (high byte first).
-- Only the low 16 bits of `value` are transmitted.
-- @tparam number value integer to encode
-- @treturn[1] boolean true on success
-- @treturn[2] nil
-- @treturn[2] string error reported by sendByte
function clientMt:sendShort(value)
  local _, err = self:sendByte((value >> 8) & 0xFF)
  if err then return nil, err end
  _, err = self:sendByte(value & 0xFF)
  if err then return nil, err end
  -- Report success explicitly, matching sendVarInt and flush.
  return true, nil
end
--- Send a length-prefixed string: a two-byte big-endian byte count followed
-- by the string's bytes.
-- @tparam string text string to send; at most 65535 bytes
-- @treturn[1] boolean true on success
-- @treturn[2] nil
-- @treturn[2] string error message
function clientMt:sendString(text)
  -- The length prefix is only 16 bits wide; a longer string would be sent
  -- with a truncated length and corrupt the packet.
  if #text > 0xFFFF then
    return nil, "string exceeds 65535 bytes"
  end
  local _, err = self:sendShort(#text)
  if err then return nil, err end
  _, err = self:sendData(text)
  if err then return nil, err end
  return true, nil
end
--- Tear the connection down when `err` is set, then pass both values through.
-- Intended to be used as `return self:handleError(err, result)`.
-- @param err error value, or nil/false when nothing went wrong
-- @param result value to hand back as the first return value
-- @return result (unchanged)
-- @return err (unchanged)
function clientMt:handleError(err, result)
  if err then
    print("Got error")
    if self.connection then
      self.connection:close()
    end
    self.connection = nil
    -- Unsent bytes belong to the session that just died; keeping them would
    -- prepend them to the first packet of the next connection.
    self.buffer = ""
  end
  return result, err
end
--- Finish the packet being assembled by flushing the outgoing buffer; the
-- flush outcome is routed through handleError.
-- @return the flush result
-- @return the flush error, if any
function clientMt:sendPacket()
  local flushed, flushErr = self:flush()
  return self:handleError(flushErr, flushed)
end
--- Open the TCP connection (port 1883) if needed and send an MQTT CONNECT
-- packet. Does nothing when a connection already exists.
-- @treturn[1] boolean true on success
-- @treturn[2] nil
-- @treturn[2] string error message; the connection is closed on failure
function clientMt:connect()
  if self.connection then
    return true, nil
  end
  local conn, connectErr = socket.connect(self.uri, 1883)
  if not conn then
    return nil, "failed to connect: " .. tostring(connectErr)
  end
  conn:setoption("tcp-nodelay", true)
  conn:setoption("linger", {on = true, timeout = 100})
  conn:settimeout(nil)
  self.connection = conn
  -- Anything still buffered was meant for a previous connection.
  self.buffer = ""

  local protocolName = "MQTT"
  local protocolVersion = 4 -- 1 byte
  local connectFlag = 0x02  -- 1 byte
  local keepAlive = 0       -- 2 bytes
  -- Remaining length = variable header + payload (the client id).
  local length = bytesUsedByString(protocolName) + 1 + 1 + 2
    + bytesUsedByString(self.id)

  -- handleError(err) already returns `nil, err`; wrapping it in
  -- `return nil, ...` would shift the message into a third return value.
  local _, err = self:sendByte(0x10)
  if err then return self:handleError(err) end
  _, err = self:sendVarInt(length)
  if err then return self:handleError(err) end
  _, err = self:sendString(protocolName)
  if err then return self:handleError(err) end
  _, err = self:sendByte(protocolVersion)
  if err then return self:handleError(err) end
  _, err = self:sendByte(connectFlag)
  if err then return self:handleError(err) end
  _, err = self:sendShort(keepAlive)
  if err then return self:handleError(err) end
  _, err = self:sendString(self.id)
  if err then return self:handleError(err) end
  return self:sendPacket()
end
--- Send an MQTT DISCONNECT packet and close the socket.
-- Does nothing when there is no connection.
-- @param args unused; kept for interface compatibility
-- @return result of sending the packet
-- @return error message, if sending failed
function clientMt:disconnect(args)
  if not self.connection then
    return true, nil
  end
  -- handleError(err) already returns `nil, err`.
  local _, err = self:sendByte(0xE0)
  if err then return self:handleError(err) end
  _, err = self:sendByte(0x00)
  if err then return self:handleError(err) end
  local result
  result, err = self:sendPacket()
  -- sendPacket closes and clears the connection itself when the flush fails.
  local conn = self.connection
  if conn then
    conn:shutdown("both")
    -- Give the peer a brief moment before closing. This is a single check:
    -- looping until getpeername() returns a value never terminates once the
    -- socket reports no peer.
    if conn:getpeername() then
      socket.sleep(0.02)
    end
    conn:close()
    self.connection = nil
  end
  return result, err
end
--- Publish a message (fixed header 0x30, optionally with the retain bit),
-- connecting first if necessary.
-- @tparam table args fields: `topic` (string), `payload` (string),
--   `retain` (optional; truthy sets the retain bit)
-- @treturn[1] boolean true on success
-- @treturn[2] nil
-- @treturn[2] string error message; the connection is closed on failure
function clientMt:publish(args)
  local topic = args.topic
  local payload = args.payload
  -- handleError(err) already returns `nil, err`; wrapping it in
  -- `return nil, ...` would shift the message into a third return value.
  local _, err = self:connect()
  if err then return self:handleError(err) end
  local retain = args.retain and 0x01 or 0x00
  _, err = self:sendByte(0x30 | retain)
  if err then return self:handleError(err) end
  -- Remaining length = length-prefixed topic + raw payload bytes.
  local topicLength = bytesUsedByString(topic)
  local payloadLength = #payload
  _, err = self:sendVarInt(topicLength + payloadLength)
  if err then return self:handleError(err) end
  _, err = self:sendString(topic)
  if err then return self:handleError(err) end
  _, err = self:sendData(payload)
  if err then return self:handleError(err) end
  return self:sendPacket()
end
--- Send a SUBSCRIBE packet (fixed header 0x82) for a single topic,
-- connecting first if necessary.
-- @tparam table args fields: `topic` (string)
-- @treturn[1] boolean true on success
-- @treturn[2] nil
-- @treturn[2] string error message; the connection is closed on failure
function clientMt:subscribe(args)
  local topic = args.topic
  -- handleError(err) already returns `nil, err`; wrapping it in
  -- `return nil, ...` would shift the message into a third return value.
  local _, err = self:connect()
  if err then return self:handleError(err) end
  _, err = self:sendByte(0x82)
  if err then return self:handleError(err) end
  -- Take the current packet identifier and advance the counter, wrapping
  -- back to 1 once it passes 0xFF00.
  local packetIdentifier = self.packetIdentifier
  self.packetIdentifier = self.packetIdentifier + 1
  if self.packetIdentifier > 0xFF00 then
    self.packetIdentifier = 1
  end
  -- Single byte sent after the topic string.
  local topicFilter = 0
  local topicLength = bytesUsedByString(topic)
  -- Remaining length = packet id (2) + length-prefixed topic + trailing byte.
  local length = 2 + topicLength + 1
  _, err = self:sendVarInt(length)
  if err then return self:handleError(err) end
  _, err = self:sendShort(packetIdentifier)
  if err then return self:handleError(err) end
  _, err = self:sendString(topic)
  if err then return self:handleError(err) end
  _, err = self:sendByte(topicFilter)
  if err then return self:handleError(err) end
  return self:sendPacket()
end
--- Invoke the handler registered under `event`, forwarding the extra
-- arguments. Silently does nothing when no handler table or no handler for
-- that event exists.
-- @param event key into `self.eventHandlers`
-- @param ... arguments passed to the handler
function clientMt:fireEvent(event, ...)
  local handlers = self.eventHandlers
  local handler = handlers and handlers[event]
  if handler then
    handler(...)
  end
end
--- Read exactly `count` bytes from the connection.
-- On a "timeout" error the coroutine yields and the read resumes later,
-- keeping whatever bytes had already arrived.
-- @tparam number count number of bytes to read
-- @treturn[1] string the bytes read
-- @treturn[2] nil
-- @treturn[2] string error message ("not connected" or the socket's error)
function clientMt:receiveBytes(count)
  -- Partial data is accumulated here rather than handed back to receive()
  -- as a prefix, so each call simply asks for what is still missing.
  local collected = ""
  while true do
    -- Re-read the connection every round: it can be torn down while this
    -- coroutine is suspended.
    local conn = self.connection
    if not conn then
      return nil, "not connected"
    end
    local chunk, err, partial = conn:receive(count - #collected)
    if chunk then
      return collected .. chunk
    end
    if partial and #partial > 0 then
      collected = collected .. partial
    end
    if err == "timeout" then
      coroutine.yield()
    else
      return nil, err
    end
  end
end
--- Read a single byte from the connection.
-- @treturn[1] number the byte value (0..255)
-- @treturn[2] nil
-- @treturn[2] string error reported by receiveBytes
function clientMt:receiveByte()
  local data, err = self:receiveBytes(1)
  if not data then
    -- Hand the error back instead of passing nil to string.byte, which raises.
    return nil, err
  end
  return string.byte(data)
end
--- Read a base-128 variable-length integer (as used for the MQTT remaining
-- length): seven bits per byte, least-significant group first, bit 0x80
-- meaning "another byte follows". At most four bytes are accepted.
-- @treturn[1] number the decoded value
-- @treturn[2] nil
-- @treturn[2] string error message
function clientMt:receiveVarInt()
  local multiplier = 1
  local value = 0
  for _ = 1, 4 do
    local encodedByte, err = self:receiveByte()
    if not encodedByte then
      -- handleError returns `nil, err` itself.
      return self:handleError(err or "receive failed")
    end
    value = value + (encodedByte & 127) * multiplier
    if (encodedByte & 128) == 0 then
      -- Continuation bit clear: this was the final byte.
      return value
    end
    multiplier = multiplier * 128
  end
  -- A fifth byte would be required, which the encoding does not allow.
  return nil, "malformed remaining length"
end
--- Read one packet from the connection and act on it.
-- Only CONNACK (packet type 2) is understood. A malformed, refused or
-- unsupported packet is reported through handleError, which closes the
-- connection.
-- @treturn[1] nil nothing is returned on success
-- @treturn[2] nil
-- @treturn[2] string error message
function clientMt:receivePacket()
  local firstByte, err = self:receiveByte()
  if err then return self:handleError(err) end
  local remainingLength
  remainingLength, err = self:receiveVarInt()
  if err then return self:handleError(err) end
  local packetType = (firstByte >> 4) & 0xF
  print("Got packet of type " .. packetType)
  if packetType ~= 2 then
    -- Unsupported or error
    return self:handleError("Invalid packet type " .. packetType)
  end
  -- CONNACK
  if remainingLength ~= 2 then
    -- Data from the network is reported as an error rather than asserted on,
    -- so a bad packet cannot kill the receiving coroutine.
    return self:handleError("Invalid CONNACK length")
  end
  local flags, returnCode
  flags, err = self:receiveByte()
  if err then return self:handleError(err) end
  returnCode, err = self:receiveByte()
  if err then return self:handleError(err) end
  if returnCode ~= 0 then
    -- A non-zero CONNACK return code means the broker refused the connection.
    return self:handleError("Connection refused, return code " .. returnCode)
  end
  print("Connected")
  -- Compare against 0 explicitly: in Lua the number 0 is truthy, so
  -- `not (flags & 1)` can never be true.
  local sessionPresent = (flags & 1) ~= 0
  if not sessionPresent then
    self:fireEvent("connect")
  end
end
--- Body of the receive coroutine: keep reading packets while connected and
-- yield to the caller while there is no connection. Never returns.
function clientMt:threadReceive()
  while true do
    if self.connection then
      -- receivePacket reports failures as its second return value.
      local _, err = self:receivePacket()
      if err then
        self:handleError(err)
      end
    else
      coroutine.yield()
    end
  end
end
--- Drive the receive coroutine in a loop.
-- Runs until the coroutine can no longer be resumed (it raised an error or
-- finished); the failure is then returned instead of being discarded.
-- @treturn nil
-- @treturn string message from the failed coroutine.resume
function clientMt:runForever()
  while true do
    local ok, err = coroutine.resume(self.thread)
    if not ok then
      -- A dead coroutine fails every later resume immediately; looping on
      -- would spin forever and hide the error.
      return nil, err
    end
  end
end
--- Install the table of event handlers, replacing any previous one.
-- fireEvent looks handlers up in this table by event key.
-- @param eventHandlers table mapping event keys to handler functions
function clientMt:on(eventHandlers)
self.eventHandlers = eventHandlers
end
--- Create a new client object. No connection is opened here.
-- The object starts with an empty send buffer, packet identifier 1, and a
-- suspended coroutine (`thread`) that will run `threadReceive` when resumed.
-- @tparam table args fields: `uri` and `id`, copied onto the client
-- @treturn table the new client, with `clientMt` methods available
function mqtt.client(args)
  local instance = setmetatable({
    uri = args.uri,
    id = args.id,
    reconnect = 5,
    packetIdentifier = 1,
    buffer = "",
  }, {__index = clientMt})
  instance.thread = coroutine.create(function()
    instance:threadReceive()
  end)
  return instance
end
return mqtt