diff --git a/controller-host/mymqtt.lua b/controller-host/mymqtt.lua index 21ff70c..b2a5ff8 100644 --- a/controller-host/mymqtt.lua +++ b/controller-host/mymqtt.lua @@ -88,21 +88,20 @@ function clientMt:sendString(text) if err then return nil, err end end -function clientMt:handleError(result, err) +function clientMt:handleError(err, result) if err then print("Got error") if self.connection then self.connection:close() end self.connection = nil - else - assert(result, "Missing result") end return result, err end function clientMt:sendPacket() - return self:handleError(self:flush()) + local result, err = self:flush() + return self:handleError(err, result) end function clientMt:connect() @@ -242,6 +241,16 @@ function clientMt:subscribe(args) return self:sendPacket() end +function clientMt:fireEvent(event, ...) + if not self.eventHandlers then + return + end + if not self.eventHandlers[event] then + return + end + self.eventHandlers[event](...) +end + function clientMt:receiveBytes(count) local result, err, partial = nil, nil, "" while true do @@ -260,14 +269,59 @@ function clientMt:receiveByte() return string.byte(self:receiveBytes(1))) end +function clientMt:receiveVarInt() + local multiplier = 1 + local value = 0 + local encodedByte + repeat + encodedByte, err = receiveByte() + if err then return nil, self:handleError(err) end + value = value + (encodedByte & 127) * multiplier + multiplier = multiplier * 128 + if multiplier > 128*128*128 then + return nil, "malformed remaining length" + end + until (encodedByte & 128) ~= 0 + return value +end + function clientMt:receivePacket() local firstByte, err = self:receiveByte() - if err then return nil, self:handleError(err) end + if err then return self:handleError(err) end + + local remainingLength, err = self:receiveVarInt() + if err then return self:handleError(err) end + + local packetType = (firstByte >> 4) & 0xF + if packetType == 2 then + -- CONNACK + assert(remainingLength == 2, "Invalid CONNACK length") + local flags, err = self:receiveByte() + if err then return self:handleError(err) end + local returnCode, err = self:receiveByte() + if err then return self:handleError(err) end + + print("Connected") + local sessionPresent = flags & 1 + if not sessionPresent then + self:fireEvent("connect") + end + else + -- Unsupported or error + self:handleError("Invalid packet type " .. packetType) + end end function clientMt:threadReceive() while true do - self:receiveAndHandlePacket() + if self.connection then + local err = self:receiveAndHandlePacket() + if err then + self:handleError(err) + end + else + coroutine.yield() + end end end