--- MQTT generic protocol components module
-- @module mqtt.protocol
--[[
Here is a generic implementation of MQTT protocols of all supported versions.
MQTT v3.1.1 documentation (DOCv3.1.1):
DOC[1]: http://docs.oasis-open.org/mqtt/mqtt/v3.1.1/errata01/os/mqtt-v3.1.1-errata01-os-complete.html
MQTT v5.0 documentation (DOCv5.0):
DOC[2]: http://docs.oasis-open.org/mqtt/mqtt/v5.0/mqtt-v5.0.html
CONVENTIONS:
* read_func - function to read data from some stream-like object (like a network connection).
It is called with a single argument: the number of bytes to read.
Use currying/closures to pass any other arguments to this function.
On success it should return a string of exactly the requested size.
On failure it should return false/nil and an error message.
]]
-- module table
local protocol = {}
-- load required stuff
local type = type
local error = error
local assert = assert
local require = require
local _VERSION = _VERSION -- lua interpreter version, not a mqtt._VERSION
local tostring = tostring
local setmetatable = setmetatable
local table = require("table")
local tbl_concat = table.concat
local unpack = unpack or table.unpack
local string = require("string")
local str_sub = string.sub
local str_char = string.char
local str_byte = string.byte
local str_format = string.format
local const = require("mqtt.const")
local const_v311 = const.v311
local const_v50 = const.v50
local bit = require("mqtt.bitwrap")
local bor = bit.bor
local band = bit.band
local lshift = bit.lshift
local rshift = bit.rshift
local tools = require("mqtt.tools")
local div = tools.div
local sortedpairs = tools.sortedpairs
--- Encode a number as a single unsigned byte (uint8)
-- @tparam number val - integer value in the range 0 .. 0xFF
-- @treturn string one-byte string holding the value
-- @raise error when val is outside the uint8 range
function protocol.make_uint8(val)
	local out_of_range = val < 0 or val > 0xFF
	if out_of_range then
		error("value is out of range to encode as uint8: "..tostring(val))
	end
	return str_char(val)
end
local make_uint8 = protocol.make_uint8
--- Encode a number as a big-endian unsigned 16-bit integer (uint16)
-- @tparam number val - integer value in the range 0 .. 0xFFFF
-- @treturn string two-byte string, most significant byte first
-- @raise error when val is outside the uint16 range
function protocol.make_uint16(val)
	if val < 0 or val > 0xFFFF then
		error("value is out of range to encode as uint16: "..tostring(val))
	end
	local high = rshift(val, 8)
	local low = band(val, 0xFF)
	return str_char(high, low)
end
local make_uint16 = protocol.make_uint16
--- Encode a number as a big-endian unsigned 32-bit integer (uint32)
-- @tparam number val - integer value in the range 0 .. 0xFFFFFFFF
-- @treturn string four-byte string, most significant byte first
-- @raise error when val is outside the uint32 range
function protocol.make_uint32(val)
	if val < 0 or val > 0xFFFFFFFF then
		error("value is out of range to encode as uint32: "..tostring(val))
	end
	local b1 = rshift(val, 24)
	local b2 = band(rshift(val, 16), 0xFF)
	local b3 = band(rshift(val, 8), 0xFF)
	local b4 = band(val, 0xFF)
	return str_char(b1, b2, b3, b4)
end
--- Encode a string as an MQTT length-prefixed string.
-- The string bytes are preceded by their length encoded as a big-endian uint16.
-- For MQTT v3.1.1: 1.5.3 UTF-8 encoded strings,
-- For MQTT v5.0: 1.5.4 UTF-8 Encoded String.
-- @tparam string str - string value to convert to bytes (at most 0xFFFF bytes long)
-- @treturn string length prefix followed by the string itself
function protocol.make_string(str)
	local prefix = make_uint16(str:len())
	return prefix..str
end
--- Maximum integer value (268435455) that can be encoded using variable-length encoding
protocol.max_variable_length = 268435455
local max_variable_length = protocol.max_variable_length
--- Encode an integer value as a variable length field.
-- For MQTT v3.1.1: 2.2.3 Remaining Length,
-- For MQTT v5.0: 2.1.4 Remaining Length (1.5.5 Variable Byte Integer).
-- NOTE: unlike the other make_* functions this one does NOT return a string. It returns the
-- encoded bytes as multiple numeric values (one to four of them, least significant 7-bit group
-- first), intended to be passed straight to string.char(), e.g. str_char(make_var_length(len)).
-- @tparam number len - integer value to be encoded, from 0 to max_variable_length
-- @treturn number... byte values of the encoded field
-- @raise error when len is outside the encodable range (NaN included)
function protocol.make_var_length(len)
	-- written as "not (in range)" so that NaN is rejected too; a NaN would otherwise
	-- never satisfy the loop exit condition below
	if not (len >= 0 and len <= max_variable_length) then
		error("value is invalid for encoding as variable length field: "..tostring(len))
	end
	local bytes = {}
	repeat
		local byte = len % 128
		len = div(len, 128)
		if len > 0 then
			-- more bytes follow: set the continuation bit
			byte = bor(byte, 128)
		end
		bytes[#bytes + 1] = byte
	until len <= 0
	return unpack(bytes)
end
local make_var_length = protocol.make_var_length
--- Encode a one-byte value that may only be 0 or 1
-- @tparam number value - 0 or 1
-- @treturn string one-byte string holding the value
-- @raise error for any other value
function protocol.make_uint8_0_or_1(value)
	if value == 0 or value == 1 then
		return make_uint8(value)
	end
	error("expecting 0 or 1 as value")
end
--- Encode a uint16 value that must not be zero
-- @tparam number value - integer value in the range 1 .. 0xFFFF
-- @treturn string two-byte big-endian string holding the value
-- @raise error when value is zero or out of the uint16 range
function protocol.make_uint16_nonzero(value)
	if value ~= 0 then
		return make_uint16(value)
	end
	error("expecting nonzero value")
end
--- Encode a variable length field value that must not be zero.
-- Uses the same encoding as protocol.make_var_length(): the result is NOT a string but the
-- encoded byte values returned as multiple numbers, ready to be passed to string.char().
-- @tparam number value - integer value to encode, from 1 to protocol.max_variable_length
-- @treturn number... byte values of the encoded field
-- @raise error when value is zero or outside the encodable range
function protocol.make_var_length_nonzero(value)
	if value == 0 then
		error("expecting nonzero value")
	end
	return make_var_length(value)
end
--- Read an MQTT length-prefixed string (or binary data) using given read_func.
-- The data is preceded by its length encoded as a big-endian uint16.
-- @tparam function read_func - function to read some bytes from the network layer
-- @treturn string parsed string (or bytes) on success
-- @return OR false and error message on failure
function protocol.parse_string(read_func)
	assert(type(read_func) == "function", "expecting read_func to be a function")
	local prefix, read_err = read_func(2)
	if not prefix then
		return false, "failed to read string length: "..read_err
	end
	local hi, lo = str_byte(prefix, 1, 2)
	local size = bor(lshift(hi, 8), lo)
	-- hand back whatever read_func returns for the payload (data, or failure values)
	return read_func(size)
end
local parse_string = protocol.parse_string
--- Read a single unsigned byte using given read_func
-- @tparam function read_func - function to read some bytes from the network layer
-- @treturn number parsed value
-- @return OR false and error message on failure
function protocol.parse_uint8(read_func)
	assert(type(read_func) == "function", "expecting read_func to be a function")
	local chunk, read_err = read_func(1)
	if chunk then
		return str_byte(chunk, 1, 1)
	end
	return false, "failed to read 1 byte for uint8: "..read_err
end
local parse_uint8 = protocol.parse_uint8
--- Read a single unsigned byte that may only be 0 or 1
-- @tparam function read_func - function to read some bytes from the network layer
-- @treturn number parsed value (0 or 1)
-- @return OR false and error message on failure or on any other value
function protocol.parse_uint8_0_or_1(read_func)
	local byte_val, parse_err = parse_uint8(read_func)
	if not byte_val then
		return false, parse_err
	end
	if byte_val == 0 or byte_val == 1 then
		return byte_val
	end
	return false, "expecting only 0 or 1 but got: "..byte_val
end
--- Read a big-endian unsigned 16-bit integer using given read_func
-- @tparam function read_func - function to read some bytes from the network layer
-- @treturn number parsed value
-- @return OR false and error message on failure
function protocol.parse_uint16(read_func)
	assert(type(read_func) == "function", "expecting read_func to be a function")
	local chunk, read_err = read_func(2)
	if not chunk then
		return false, "failed to read 2 byte for uint16: "..read_err
	end
	local msb, lsb = str_byte(chunk, 1, 2)
	return lshift(msb, 8) + lsb
end
local parse_uint16 = protocol.parse_uint16
--- Read a big-endian uint16 value that must not be zero
-- @tparam function read_func - function to read some bytes from the network layer
-- @treturn number parsed value
-- @return OR false and error message on failure or when the value is zero
function protocol.parse_uint16_nonzero(read_func)
	local num, parse_err = parse_uint16(read_func)
	if not num then
		return false, parse_err
	elseif num == 0 then
		return false, "expecting non-zero value"
	end
	return num
end
--- Read a big-endian unsigned 32-bit integer using given read_func
-- @tparam function read_func - function to read some bytes from the network layer
-- @treturn number parsed value
-- @return OR false and error message on failure
function protocol.parse_uint32(read_func)
	assert(type(read_func) == "function", "expecting read_func to be a function")
	local chunk, read_err = read_func(4)
	if not chunk then
		return false, "failed to read 4 byte for uint32: "..read_err
	end
	local b1, b2, b3, b4 = str_byte(chunk, 1, 4)
	-- On interpreters older than Lua 5.3 the top byte is scaled with floating-point
	-- multiplication instead of a bit shift, presumably to keep the result unsigned
	-- when the bit library works on signed 32-bit integers.
	local top
	if _VERSION < "Lua 5.3" then
		top = b1 * (2 ^ 24)
	else
		top = lshift(b1, 24)
	end
	return top + lshift(b2, 16) + lshift(b3, 8) + b4
end
-- Largest multiplier allowed while decoding a variable length field:
-- at most 4 bytes are allowed, and the 4th one is scaled by 128^3
local max_mult = 128 * 128 * 128
--- Parse variable length field value using given read_func.
-- For MQTT v3.1.1: 2.2.3 Remaining Length,
-- For MQTT v5.0: 2.1.4 Remaining Length.
-- @tparam function read_func - function to read some bytes from the network layer
-- @treturn number parsed value
-- @return OR false and error message on failure
function protocol.parse_var_length(read_func)
	-- DOC[1]: 2.2.3 Remaining Length
	-- DOC[2]: 1.5.5 Variable Byte Integer
	assert(type(read_func) == "function", "expecting read_func to be a function")
	local value, multiplier = 0, 1
	while true do
		local chunk, err = read_func(1)
		if not chunk then
			return false, err
		end
		local encoded = str_byte(chunk, 1, 1)
		-- low 7 bits carry data, least significant group first
		value = value + band(encoded, 127) * multiplier
		if multiplier > max_mult then
			-- a 5th byte was present: longer than the 4 bytes allowed
			return false, "malformed variable length field data"
		end
		-- a clear high bit marks the last byte of the field
		if band(encoded, 128) == 0 then
			return value
		end
		multiplier = multiplier * 128
	end
end
local parse_var_length = protocol.parse_var_length
--- Parse variable length field value using given read_func, rejecting zero.
-- For MQTT v3.1.1: 2.2.3 Remaining Length,
-- For MQTT v5.0: 2.1.4 Remaining Length.
-- @tparam function read_func - function to read some bytes from the network layer
-- @treturn number parsed value
-- @return OR false and error message on failure or when the value is zero
function protocol.parse_var_length_nonzero(read_func)
	local num, parse_err = parse_var_length(read_func)
	if not num then
		return false, parse_err
	elseif num == 0 then
		return false, "expecting non-zero value"
	end
	return num
end
--- Create bytes of the MQTT fixed packet header
-- For MQTT v3.1.1: 2.2 Fixed header,
-- For MQTT v5.0: 2.1.1 Fixed Header.
-- @tparam number ptype - MQTT packet type (stored in the high nibble of the first byte)
-- @tparam number flags - MQTT packet flags (only the low 4 bits are used)
-- @tparam number len - MQTT packet length (remaining length)
-- @treturn string bytes of the fixed packet header
function protocol.make_header(ptype, flags, len)
	local high_nibble = lshift(ptype, 4)
	local low_nibble = band(flags, 0x0F)
	-- make_var_length() yields the remaining length as separate byte values
	return str_char(bor(high_nibble, low_nibble), make_var_length(len))
end
-- Set of QoS levels accepted by protocol.check_qos()
local valid_qos_levels = {[0] = true, [1] = true, [2] = true}
--- Check if given value is a valid PUBLISH message QoS value (0, 1 or 2)
-- @tparam number val - QoS value
-- @treturn boolean true for valid QoS value, otherwise false
function protocol.check_qos(val)
	return valid_qos_levels[val] == true
end
--- Check if given value is a valid Packet Identifier (1 .. 0xFFFF)
-- For MQTT v3.1.1: 2.3.1 Packet Identifier,
-- For MQTT v5.0: 2.2.1 Packet Identifier.
-- @tparam number val - Packet ID value
-- @treturn boolean true for valid Packet ID value, otherwise false
function protocol.check_packet_id(val)
	local min_id, max_id = 1, 0xFFFF
	return val >= min_id and val <= max_id
end
--- Returns the next Packet Identifier value relative to given current value.
-- If current is nil (or false) - returns 1 as the first possible Packet ID.
-- After 0xFFFF the sequence wraps around to 1, because 0 is not a valid Packet ID.
-- For MQTT v3.1.1: 2.3.1 Packet Identifier,
-- For MQTT v5.0: 2.2.1 Packet Identifier.
-- @tparam[opt] number curr - current Packet ID value
-- @treturn number next Packet ID value
function protocol.next_packet_id(curr)
	if not curr then
		return 1
	end
	assert(type(curr) == "number", "expecting curr to be a number")
	assert(curr >= 1, "expecting curr to be >= 1")
	local nxt = curr + 1
	if nxt <= 0xFFFF then
		return nxt
	end
	return 1
end
--- MQTT protocol fixed header packet types.
-- For MQTT v3.1.1: 2.2.1 MQTT Control Packet type,
-- For MQTT v5.0: 2.1.2 MQTT Control Packet type.
-- The table maps both ways: name -> numeric code (e.g. CONNECT = 1)
-- and numeric code -> name (e.g. [1] = "CONNECT").
protocol.packet_type = {}
do
	-- packet type names listed in the order of their numeric codes, 1 .. 15
	local names = {
		"CONNECT", "CONNACK", "PUBLISH", "PUBACK", "PUBREC",
		"PUBREL", "PUBCOMP", "SUBSCRIBE", "SUBACK", "UNSUBSCRIBE",
		"UNSUBACK", "PINGREQ", "PINGRESP", "DISCONNECT", "AUTH",
	}
	for code = 1, #names do
		local name = names[code]
		protocol.packet_type[name] = code
		protocol.packet_type[code] = name
	end
end
local packet_type = protocol.packet_type
-- Set of packet types that always carry a Packet Identifier field
-- (PUBLISH carries one only for QoS > 0, which is handled in protocol.packet_id_required()).
-- DOCv3.1.1: 2.3.1 Packet Identifier
-- DOCv5.0: 2.2.1 Packet Identifier
local packets_requiring_packet_id = {}
for _, name in ipairs({
	"PUBACK", "PUBREC", "PUBREL", "PUBCOMP",
	"SUBSCRIBE", "SUBACK", "UNSUBSCRIBE", "UNSUBACK",
}) do
	packets_requiring_packet_id[packet_type[name]] = true
end
-- CONNACK return code/reason code strings, indexed by numeric code.
-- MQTT v3.1.1 return codes (0x00 .. 0x05) and MQTT v5.0 reason codes (0x80 and up)
-- do not overlap, so both are merged into one lookup table.
protocol.connack_rc = {}
do
	-- MQTT v3.1.1 Connect return codes, DOCv3.1.1: 3.2.2.3 Connect Return code
	local v311_codes = {
		[0x00] = "Connection Accepted",
		[0x01] = "Connection Refused, unacceptable protocol version",
		[0x02] = "Connection Refused, identifier rejected",
		[0x03] = "Connection Refused, Server unavailable",
		[0x04] = "Connection Refused, bad user name or password",
		[0x05] = "Connection Refused, not authorized",
	}
	-- MQTT v5.0 Connect reason codes, DOCv5.0: 3.2.2.2 Connect Reason Code
	local v50_codes = {
		[0x80] = "Unspecified error",
		[0x81] = "Malformed Packet",
		[0x82] = "Protocol Error",
		[0x83] = "Implementation specific error",
		[0x84] = "Unsupported Protocol Version",
		[0x85] = "Client Identifier not valid",
		[0x86] = "Bad User Name or Password",
		[0x87] = "Not authorized",
		[0x88] = "Server unavailable",
		[0x89] = "Server busy",
		[0x8A] = "Banned",
		[0x8C] = "Bad authentication method",
		[0x90] = "Topic Name invalid",
		[0x95] = "Packet too large",
		[0x97] = "Quota exceeded",
		[0x99] = "Payload format invalid",
		[0x9A] = "Retain not supported",
		[0x9B] = "QoS not supported",
		[0x9C] = "Use another server",
		[0x9D] = "Server moved",
		[0x9F] = "Connection rate exceeded",
	}
	for _, codes in ipairs({v311_codes, v50_codes}) do
		for code, text in pairs(codes) do
			protocol.connack_rc[code] = text
		end
	end
end
local connack_rc = protocol.connack_rc
--- Check whether a Packet Identifier field is required for the given packet
-- @tparam table args - args for creating packet; .type is the packet type, .qos is used for PUBLISH
-- @return true if a Packet Identifier is required; otherwise a falsy value (nil)
function protocol.packet_id_required(args)
	assert(type(args) == "table", "expecting args to be a table")
	assert(type(args.type) == "number", "expecting .type to be a number")
	local ptype = args.type
	if ptype ~= packet_type.PUBLISH then
		return packets_requiring_packet_id[ptype]
	end
	-- PUBLISH carries a Packet Identifier only for QoS above 0
	if args.qos and args.qos > 0 then
		return true
	end
	return packets_requiring_packet_id[ptype]
end
-- Metatable for combined data packets: a sequence of parts (strings or nested combined packets)
-- that should look like a string (tostring() support plus :len() and :append() methods)
local combined_packet_mt = {}
-- Render the combined packet by concatenating the string form of every part, in order
function combined_packet_mt.__tostring(self)
	local pieces = {}
	for idx, part in ipairs(self) do
		pieces[idx] = tostring(part)
	end
	return tbl_concat(pieces)
end
-- Total length of all parts; every part must provide a :len() method (strings and combined packets do)
function combined_packet_mt.len(self)
	local total = 0
	for _, part in ipairs(self) do
		total = total + part:len()
	end
	return total
end
-- Add a part after the last existing one
function combined_packet_mt.append(self, part)
	self[#self + 1] = part
end
-- Resolve missing keys of instances (the method names) through the metatable itself,
-- so that combined_packet_mt works like a class
combined_packet_mt.__index = function(_, key)
	return combined_packet_mt[key]
end
--- Combine several data parts into one
-- @tparam combined_packet_mt/string ... any amount of strings or combined_packet_mt tables to combine into one packet
-- @treturn combined_packet_mt table suitable to append packet parts or to stringify it into raw packet bytes
function protocol.combine(...)
	local parts = {...}
	setmetatable(parts, combined_packet_mt)
	return parts
end
-- Render any value as a human-readable string: strings are quoted with %q,
-- tables are rendered recursively as {...} in sortedpairs() order (numeric keys are omitted,
-- identifier-like string keys are written as key=value, other keys as ["key"]=value),
-- everything else goes through tostring()
local function value_tostring(value)
	local vtype = type(value)
	if vtype == "string" then
		return str_format("%q", value)
	end
	if vtype ~= "table" then
		return tostring(value)
	end
	local items = {}
	for key, item in sortedpairs(value) do
		if type(key) == "number" then
			items[#items + 1] = value_tostring(item)
		elseif key:match("^[a-zA-Z_][_%w]*$") then
			items[#items + 1] = str_format("%s=%s", key, value_tostring(item))
		else
			items[#items + 1] = str_format("[%q]=%s", key, value_tostring(item))
		end
	end
	return str_format("{%s}", tbl_concat(items, ", "))
end
--- Render packet to string representation, e.g. CONNACK{rc=0, type=2}
-- @tparam packet_mt packet table to convert to string
-- @treturn string human-readable string representation of the packet
function protocol.packet_tostring(packet)
	local fields = {}
	for key, val in sortedpairs(packet) do
		fields[#fields + 1] = str_format("%s=%s", key, value_tostring(val))
	end
	local type_name = tostring(packet_type[packet.type])
	return str_format("%s{%s}", type_name, tbl_concat(fields, ", "))
end
local packet_tostring = protocol.packet_tostring
--- Parsed packet metatable
protocol.packet_mt = {}
-- render packets as human-readable strings via protocol.packet_tostring()
protocol.packet_mt.__tostring = packet_tostring
--- Parsed CONNACK packet metatable
protocol.connack_packet_mt = {
	-- render packets as human-readable strings via protocol.packet_tostring()
	__tostring = packet_tostring,
}
-- Returns the reason string for the CONNACK packet's rc field,
-- or "Unknown: <rc>" when the code is not listed in protocol.connack_rc
function protocol.connack_packet_mt.reason_string(self)
	return connack_rc[self.rc] or ("Unknown: "..self.rc)
end
-- let packet instances find reason_string() through the metatable
protocol.connack_packet_mt.__index = protocol.connack_packet_mt
--- Start parsing a new packet: read the fixed header and the whole rest of the packet
-- @tparam function read_func - function to read data from the network connection
-- @treturn number packet_type
-- @treturn number flags
-- @treturn table input - a table with fields "read_func" and "available" representing a stream-like object
-- to read already received packet data in chunks
-- @return OR false and error_message on failure
function protocol.start_parse_packet(read_func)
	assert(type(read_func) == "function", "expecting read_func to be a function")
	-- fixed header, first byte: packet type (high nibble) and flags (low nibble)
	-- DOC[v3.1.1]: http://docs.oasis-open.org/mqtt/mqtt/v3.1.1/errata01/os/mqtt-v3.1.1-errata01-os-complete.html#_Toc442180832
	-- DOC[v5.0]: https://docs.oasis-open.org/mqtt/mqtt/v5.0/os/mqtt-v5.0-os.html#_Toc3901020
	local first, err = read_func(1)
	if not first then
		return false, err
	end
	first = str_byte(first, 1, 1)
	local ptype, flags = rshift(first, 4), band(first, 0xF)
	-- fixed header, rest: remaining length of the packet
	local remaining
	remaining, err = parse_var_length(read_func)
	if not remaining then
		return false, err
	end
	-- read the whole rest of the packet at once
	local data = ""
	if remaining > 0 then
		data, err = read_func(remaining)
		if not data then
			return false, err
		end
	end
	-- in-memory reader over the received data:
	-- input[1] is the offset of the next unread byte, input.available is the number of unread bytes
	local input = {1, available = data:len()}
	-- on a request for more bytes than available it returns false and the requested size
	input.read_func = function(size)
		if size > input.available then
			return false, size
		end
		local offset = input[1]
		input[1] = offset + size
		input.available = input.available - size
		return str_sub(data, offset, offset + size - 1)
	end
	return ptype, flags, input
end
--- Parse CONNECT packet with read_func
-- @tparam function read_func - function to read data from the network connection
-- @tparam[opt] number version - expected protocol version constant or nil to accept both versions
-- @return packet on success or false and error message on failure
function protocol.parse_packet_connect(read_func, version)
	-- DOC[v3.1.1]: 3.1 CONNECT – Client requests a connection to a Server
	-- DOC[v5.0]: 3.1 CONNECT – Connection Request
	local ptype, flags, input = protocol.start_parse_packet(read_func)
	if not ptype then
		-- start_parse_packet() failed: it returned false and the error message (now in flags);
		-- propagate it instead of concatenating the boolean ptype below
		return false, flags
	end
	if ptype ~= packet_type.CONNECT then
		return false, "expecting CONNECT (1) packet type but got "..ptype
	end
	if flags ~= 0 then
		return false, "expecting CONNECT flags to be 0 but got "..flags
	end
	return protocol.parse_packet_connect_input(input, version)
end
--- Parse CONNECT packet from already received stream-like packet input table
-- Reads the version-independent part of the CONNECT variable header (protocol name,
-- protocol level/version, connect flags, keep alive), validates it, then hands the remaining
-- input over to "mqtt.protocol4" (v3.1.1) or "mqtt.protocol5" (v5.0) to finish parsing.
-- @tparam table input - a table with fields "read_func" and "available" representing a stream-like object
-- @tparam[opt] number version - expected protocol version constant or nil to accept both versions
-- @return packet on success or false and error message on failure
function protocol.parse_packet_connect_input(input, version)
	-- DOC[v3.1.1]: 3.1 CONNECT – Client requests a connection to a Server
	-- DOC[v5.0]: 3.1 CONNECT – Connection Request
	local read_func = input.read_func
	local err, protocol_name, protocol_ver, connect_flags, keep_alive
	-- DOC: 3.1.2.1 Protocol Name
	protocol_name, err = parse_string(read_func)
	if not protocol_name then
		return false, "failed to parse protocol name: "..err
	end
	if protocol_name ~= "MQTT" then
		return false, "expecting 'MQTT' as protocol name but received '"..protocol_name.."'"
	end
	-- DOC[v3.1.1]: 3.1.2.2 Protocol Level
	-- DOC[v5.0]: 3.1.2.2 Protocol Version
	protocol_ver, err = parse_uint8(read_func)
	if not protocol_ver then
		return false, "failed to parse protocol level/version: "..err
	end
	if version ~= nil and version ~= protocol_ver then
		return false, "expecting protocol version "..version.." but received "..protocol_ver
	end
	-- DOC: 3.1.2.3 Connect Flags
	connect_flags, err = parse_uint8(read_func)
	if not connect_flags then
		return false, "failed to parse connect flags: "..err
	end
	if band(connect_flags, 0x1) ~= 0 then
		return false, "reserved 1st bit in connect flags are set"
	end
	local clean = (band(connect_flags, 0x2) ~= 0)
	local will = (band(connect_flags, 0x4) ~= 0)
	local will_qos = band(rshift(connect_flags, 3), 0x3)
	local will_retain = (band(connect_flags, 0x20) ~= 0)
	local password_flag = (band(connect_flags, 0x40) ~= 0)
	local username_flag = (band(connect_flags, 0x80) ~= 0)
	-- DOC: 3.1.2.6 Will QoS, 3.1.2.7 Will Retain
	-- Will QoS 3 is never allowed, and without the Will Flag both Will QoS and Will Retain
	-- must be 0; a CONNECT packet violating this is malformed.
	if will then
		if will_qos == 3 then
			return false, "invalid will QoS in connect flags: 3"
		end
	elseif will_qos ~= 0 or will_retain then
		return false, "will QoS and will retain must be 0 when will flag is not set"
	end
	-- DOC: 3.1.2.10 Keep Alive
	keep_alive, err = parse_uint16(read_func)
	if not keep_alive then
		return false, "failed to parse keep alive field: "..err
	end
	-- continue parsing based on the protocol_ver
	-- preparing common connect packet fields
	local packet = {
		type = packet_type.CONNECT,
		version = protocol_ver,
		clean = clean,
		password = password_flag, -- NOTE: will be replaced
		username = username_flag, -- NOTE: will be replaced
		keep_alive = keep_alive,
	}
	if will then
		packet.will = {
			qos = will_qos,
			retain = will_retain,
			topic = "", -- NOTE: will be replaced
			payload = "", -- NOTE: will be replaced
		}
	end
	-- NOTE(review): the version-specific modules are required lazily here rather than at the
	-- top of the file, presumably to avoid a circular require — confirm before moving them
	if protocol_ver == const_v311 then
		return require("mqtt.protocol4")._parse_packet_connect_continue(input, packet)
	elseif protocol_ver == const_v50 then
		return require("mqtt.protocol5")._parse_packet_connect_continue(input, packet)
	else
		return false, "unexpected protocol version to continue parsing: "..protocol_ver
	end
end
-- export module table
return protocol
-- vim: ts=4 sts=4 sw=4 noet ft=lua