--- Minimal MQTT client (protocol name "MQTT", level 4) on top of LuaSocket.
--
-- Only QoS 0 publishing/subscribing is implemented. Outgoing packets are
-- assembled in `client.buffer` and written to the socket by `flush`;
-- incoming packets are read by the coroutine stored in `client.thread`,
-- which `runForever` keeps resuming.
--
-- Uses the native bitwise operators, so it requires Lua 5.3 or newer.
-- Functions follow the `value` / `nil, err` return convention.
local mqtt = {}
local socket = require("socket")

local clientMt = {}
local clientMeta = { __index = clientMt }

-- TCP port used when `mqtt.client` is not given one.
local DEFAULT_PORT = 1883
-- The outgoing buffer is flushed early once it grows beyond this many bytes.
local FLUSH_THRESHOLD = 1024
-- Upper bound on polls (0.02 s apart) while waiting for the peer on disconnect.
local DISCONNECT_POLLS = 10
-- Pause between resumes in `runForever` while there is no connection, so the
-- loop does not burn a full CPU core.
local IDLE_SLEEP = 0.01

-- Control packet types (upper nibble of the fixed-header byte).
local PACKET_CONNACK = 2
local PACKET_PUBLISH = 3
local PACKET_RESERVED_LOW = 0
local PACKET_RESERVED_HIGH = 15

--- Number of bytes the variable-length "remaining length" field needs.
-- @tparam integer value length to encode
-- @treturn[1] integer 1 to 4
-- @treturn[2] nil
-- @treturn[2] string "invalid byte length" when value is negative or too big
local function bytesUsedByVarIntForValue(value)
  if value < 0 then
    return nil, "invalid byte length"
  elseif value <= 128 - 1 then
    return 1, nil
  elseif value <= 128 * 128 - 1 then
    return 2, nil
  elseif value <= 128 * 128 * 128 - 1 then
    return 3, nil
  elseif value <= 128 * 128 * 128 * 128 - 1 then
    return 4, nil
  end
  return nil, "invalid byte length"
end

--- Encoded size of a length-prefixed string (2-byte length + bytes).
-- @tparam string text
-- @treturn integer
local function bytesUsedByString(text)
  return 2 + #text
end

--- Write the whole pending buffer to the socket.
-- Partial writes are resumed from the index LuaSocket reports; on failure
-- the unsent tail stays in the buffer.
-- @return true on success, or nil plus an error message
function clientMt:flush()
  local start = 1
  while start <= #self.buffer do
    if not self.connection then
      return nil, "not connected"
    end
    print("flushing data")
    local last, err, partial = self.connection:send(self.buffer, start)
    if last then
      start = last + 1
    else
      -- `partial` is the index of the last byte that did go out.
      self.buffer = self.buffer:sub((partial or start - 1) + 1)
      print("Error: " .. err)
      return nil, err
    end
  end
  self.buffer = ""
  return true, nil
end

--- Append raw bytes to the outgoing buffer, flushing if it grew too large.
-- @tparam string data
-- @return true on success, or nil plus an error message
function clientMt:sendData(data)
  self.buffer = self.buffer .. data
  if #self.buffer > FLUSH_THRESHOLD then
    return self:flush()
  end
  return true, nil
end

--- Append a single byte (0-255) to the outgoing buffer.
-- @tparam integer byte
-- @return true on success, or nil plus an error message
function clientMt:sendByte(byte)
  return self:sendData(string.char(byte))
end

--- Append `value` in the MQTT variable-length encoding (7 bits per byte,
-- high bit set on every byte except the last).
-- @tparam integer value 0 .. 268435455
-- @return true on success, or nil plus an error message
function clientMt:sendVarInt(value)
  local _, rangeErr = bytesUsedByVarIntForValue(value)
  if rangeErr then
    return nil, rangeErr
  end
  repeat
    local encoded = value & 0x7F
    value = value >> 7
    if value > 0 then
      encoded = encoded | 0x80
    end
    local ok, err = self:sendByte(encoded)
    if not ok then
      return nil, err
    end
  until value == 0
  return true, nil
end

--- Append a big-endian 16-bit unsigned integer.
-- @tparam integer value 0 .. 0xFFFF
-- @return true on success, or nil plus an error message
function clientMt:sendShort(value)
  if value < 0 or value > 0xFFFF then
    return nil, "value does not fit in 16 bits"
  end
  local ok, err = self:sendByte((value >> 8) & 0xFF)
  if not ok then
    return nil, err
  end
  return self:sendByte(value & 0xFF)
end

--- Append a length-prefixed string (2-byte big-endian length, then bytes).
-- @tparam string text at most 65535 bytes
-- @return true on success, or nil plus an error message
function clientMt:sendString(text)
  if #text > 0xFFFF then
    return nil, "string too long"
  end
  local ok, err = self:sendShort(#text)
  if not ok then
    return nil, err
  end
  return self:sendData(text)
end

--- If `err` is set, close and forget the connection and drop any
-- half-built packet so it is not replayed on the next connection.
-- @param err error message, or nil
-- @param result value to pass through
-- @return result, err (unchanged)
function clientMt:handleError(err, result)
  if err then
    print("Got error: " .. tostring(err))
    if self.connection then
      self.connection:close()
    end
    self.connection = nil
    self.buffer = ""
  end
  return result, err
end

--- Flush the current packet to the socket, closing the connection on error.
-- @return true on success, or nil plus an error message
function clientMt:sendPacket()
  local result, err = self:flush()
  return self:handleError(err, result)
end

-- Run `{sendMethod, argument}` steps in order. The first failure goes
-- through handleError; otherwise the assembled packet is flushed.
local function sendSequence(self, steps)
  for _, step in ipairs(steps) do
    local send, argument = step[1], step[2]
    local ok, err = send(self, argument)
    if not ok then
      return self:handleError(err)
    end
  end
  return self:sendPacket()
end

-- Return the current packet identifier and advance the counter, wrapping
-- back to 1 past 0xFF00 (identifier 0 is never produced).
local function nextPacketIdentifier(self)
  local id = self.packetIdentifier
  self.packetIdentifier = id + 1
  if self.packetIdentifier > 0xFF00 then
    self.packetIdentifier = 1
  end
  return id
end

--- Open the TCP connection and send CONNECT (clean session, keep-alive 0).
-- Does nothing if already connected. CONNACK is processed later by the
-- receive coroutine.
-- @return true on success, or nil plus an error message
function clientMt:connect()
  if self.connection then
    return true, nil
  end
  local conn, connErr = socket.connect(self.uri, self.port or DEFAULT_PORT)
  if not conn then
    return nil, "failed to connect: " .. tostring(connErr)
  end
  conn:setoption("tcp-nodelay", true)
  conn:setoption("linger", { on = true, timeout = 100 })
  conn:settimeout(nil)
  self.connection = conn
  self.buffer = ""

  local protocolName = "MQTT"
  local protocolVersion = 4
  local connectFlags = 0x02 -- clean session
  local keepAlive = 0
  local clientId = self.id or ""

  -- protocol name + version byte + flags byte + 2-byte keep-alive + client id
  local length = bytesUsedByString(protocolName) + 1 + 1 + 2
    + bytesUsedByString(clientId)

  return sendSequence(self, {
    { clientMt.sendByte, 0x10 },
    { clientMt.sendVarInt, length },
    { clientMt.sendString, protocolName },
    { clientMt.sendByte, protocolVersion },
    { clientMt.sendByte, connectFlags },
    { clientMt.sendShort, keepAlive },
    { clientMt.sendString, clientId },
  })
end

--- Send DISCONNECT and close the socket. No-op when not connected.
-- @param args currently ignored
-- @return true on success, or nil plus an error message
function clientMt:disconnect(args)
  if not self.connection then
    return true, nil
  end
  local result, err = sendSequence(self, {
    { clientMt.sendByte, 0xE0 },
    { clientMt.sendByte, 0x00 },
  })
  if not result then
    -- handleError has already closed and cleared the connection.
    return nil, err
  end
  local conn = self.connection
  conn:shutdown("both")
  -- Give the peer a short, bounded grace period to drop the connection.
  for _ = 1, DISCONNECT_POLLS do
    if not conn:getpeername() then
      break
    end
    socket.sleep(0.02)
  end
  conn:close()
  self.connection = nil
  return result, err
end

--- Publish a QoS 0 message, connecting first if needed.
-- @tparam table args `topic` (string), `payload` (string, default ""),
--   `retain` (truthy to set the RETAIN flag)
-- @return true on success, or nil plus an error message
function clientMt:publish(args)
  local topic = args.topic
  local payload = args.payload or ""
  if type(topic) ~= "string" then
    return nil, "topic must be a string"
  end
  local ok, err = self:connect()
  if not ok then
    return self:handleError(err)
  end
  local retain = args.retain and 0x01 or 0x00
  -- QoS 0 carries no packet identifier: topic string, then raw payload.
  local length = bytesUsedByString(topic) + #payload
  return sendSequence(self, {
    { clientMt.sendByte, 0x30 | retain },
    { clientMt.sendVarInt, length },
    { clientMt.sendString, topic },
    { clientMt.sendData, payload },
  })
end

--- Subscribe to one topic filter at QoS 0, connecting first if needed.
-- @tparam table args `topic` (string)
-- @return true on success, or nil plus an error message
function clientMt:subscribe(args)
  local topic = args.topic
  if type(topic) ~= "string" then
    return nil, "topic must be a string"
  end
  local ok, err = self:connect()
  if not ok then
    return self:handleError(err)
  end
  local packetIdentifier = nextPacketIdentifier(self)
  local requestedQos = 0
  -- packet identifier + topic filter + requested-QoS byte
  local length = 2 + bytesUsedByString(topic) + 1
  return sendSequence(self, {
    { clientMt.sendByte, 0x82 },
    { clientMt.sendVarInt, length },
    { clientMt.sendShort, packetIdentifier },
    { clientMt.sendString, topic },
    { clientMt.sendByte, requestedQos },
  })
end

--- Call the handler registered via `on` for `event`, if any.
function clientMt:fireEvent(event, ...)
  if not self.eventHandlers then
    return
  end
  if not self.eventHandlers[event] then
    return
  end
  self.eventHandlers[event](...)
end

--- Read exactly `count` bytes from the connection.
-- On a socket "timeout" the bytes received so far are kept and the
-- coroutine yields, so this must run inside `client.thread`.
-- @tparam integer count
-- @return the bytes, or nil plus an error message
function clientMt:receiveBytes(count)
  local chunks = {}
  local received = 0
  while received < count do
    if not self.connection then
      return nil, "not connected"
    end
    -- Accumulate locally instead of using receive's prefix argument, whose
    -- meaning differs between LuaSocket 2.0 and 3.0.
    local data, err, partial = self.connection:receive(count - received)
    local chunk = data or partial
    if chunk and #chunk > 0 then
      chunks[#chunks + 1] = chunk
      received = received + #chunk
    end
    if not data then
      if err == "timeout" then
        coroutine.yield()
      else
        return nil, err
      end
    end
  end
  return table.concat(chunks)
end

--- Read one byte and return its numeric value.
-- @return integer 0-255, or nil plus an error message
function clientMt:receiveByte()
  local data, err = self:receiveBytes(1)
  if not data then
    return nil, err
  end
  return string.byte(data), nil
end

--- Decode a "remaining length" field (at most 4 bytes).
-- @return integer, or nil plus an error message
function clientMt:receiveVarInt()
  local value = 0
  local multiplier = 1
  for _ = 1, 4 do
    local encodedByte, err = self:receiveByte()
    if not encodedByte then
      return self:handleError(err)
    end
    value = value + (encodedByte & 0x7F) * multiplier
    if (encodedByte & 0x80) == 0 then
      return value, nil
    end
    multiplier = multiplier * 128
  end
  -- A fifth byte would be needed: the field is malformed.
  return nil, "malformed remaining length"
end

-- Handle a CONNACK body. Fires "connect" when the session-present flag is
-- clear; a non-zero return code is treated as a refused connection.
local function handleConnack(self, remainingLength)
  if remainingLength ~= 2 then
    return self:handleError("Invalid CONNACK length")
  end
  local flags, err = self:receiveByte()
  if not flags then
    return self:handleError(err)
  end
  local returnCode
  returnCode, err = self:receiveByte()
  if not returnCode then
    return self:handleError(err)
  end
  if returnCode ~= 0 then
    return self:handleError("connection refused, return code " .. returnCode)
  end
  print("Connected")
  local sessionPresent = (flags & 0x01) ~= 0
  if not sessionPresent then
    self:fireEvent("connect")
  end
  return true, nil
end

-- Handle a PUBLISH body and fire "message" with (topic, payload).
-- QoS 1/2 packet identifiers are skipped; no PUBACK/PUBREC is sent, which is
-- adequate only because this client subscribes at QoS 0.
local function handlePublish(self, firstByte, remainingLength)
  local body, err = self:receiveBytes(remainingLength)
  if not body then
    return self:handleError(err)
  end
  if #body < 2 then
    return self:handleError("malformed PUBLISH packet")
  end
  local topicLength = (body:byte(1) << 8) | body:byte(2)
  local payloadStart = 3 + topicLength
  local qos = (firstByte >> 1) & 0x03
  if qos > 0 then
    payloadStart = payloadStart + 2
  end
  if payloadStart - 1 > #body then
    return self:handleError("malformed PUBLISH packet")
  end
  local topic = body:sub(3, 2 + topicLength)
  local payload = body:sub(payloadStart)
  self:fireEvent("message", topic, payload)
  return true, nil
end

--- Read and dispatch one incoming packet.
-- CONNACK and PUBLISH are handled. Reserved types 0/15 close the connection.
-- Other packets (e.g. SUBACK) are read and discarded.
-- @return true on success, or nil plus an error message
function clientMt:receivePacket()
  local firstByte, err = self:receiveByte()
  if not firstByte then
    return self:handleError(err)
  end
  local remainingLength
  remainingLength, err = self:receiveVarInt()
  if not remainingLength then
    return self:handleError(err)
  end
  local packetType = (firstByte >> 4) & 0xF
  print("Got packet of type " .. packetType)
  if packetType == PACKET_CONNACK then
    return handleConnack(self, remainingLength)
  elseif packetType == PACKET_PUBLISH then
    return handlePublish(self, firstByte, remainingLength)
  elseif packetType == PACKET_RESERVED_LOW
      or packetType == PACKET_RESERVED_HIGH then
    return self:handleError("Invalid packet type " .. packetType)
  end
  -- Unsupported but valid packet: skip its body to stay in sync.
  local _, skipErr = self:receiveBytes(remainingLength)
  if skipErr then
    return self:handleError(skipErr)
  end
  return true, nil
end

--- Body of `client.thread`: process packets while connected, yield otherwise.
function clientMt:threadReceive()
  while true do
    if self.connection then
      local _, err = self:receivePacket()
      -- receivePacket normally closes the connection itself; this is a guard.
      if err and self.connection then
        self:handleError(err)
      end
    else
      coroutine.yield()
    end
  end
end

--- Resume the receive coroutine forever.
-- Raises an error if the coroutine dies, since resuming a dead coroutine
-- would otherwise spin silently.
function clientMt:runForever()
  while true do
    local ok, err = coroutine.resume(self.thread)
    if not ok then
      error("receive thread failed: " .. tostring(err), 0)
    end
    if not self.connection then
      socket.sleep(IDLE_SLEEP)
    end
  end
end

--- Register event handlers, e.g. `{ connect = fn, message = fn }`.
-- Replaces any previously registered table.
function clientMt:on(eventHandlers)
  self.eventHandlers = eventHandlers
end

--- Create a client. No connection is made until `connect`/`publish`/`subscribe`.
-- @tparam table args `uri` (broker host), `id` (client id),
--   `port` (optional, defaults to 1883)
-- @return client object
function mqtt.client(args)
  local client = setmetatable({
    uri = args.uri,
    port = args.port or DEFAULT_PORT,
    id = args.id,
    reconnect = 5, -- NOTE(review): not read anywhere in this module
    connection = nil,
    packetIdentifier = 1,
    buffer = "",
  }, clientMeta)
  client.thread = coroutine.create(function()
    client:threadReceive()
  end)
  return client
end

return mqtt