216 lines
7.2 KiB
Python
Executable File

#!/usr/bin/env python3
# SPDX-License-Identifier: GPL-2.0-or-later
# Copyright (C) 2022, Ideas on Board Oy
#
# Generate color space conversion table coefficients with configurable
# fixed-point precision
import argparse
import enum
import numpy as np
import sys
# RGB to YCbCr conversion matrices, indexed by encoding name. The names are
# the values accepted for the 'encoding' command line argument.
#
# Each matrix has three rows of three coefficients, applied to the R, G and B
# inputs in that order:
#
# - The first row produces luma (Y). Its coefficients nominally add up to 1.0.
# - The second row is (B - Y), divided by a normalization factor that brings
#   the blue coefficient to 0.5 (Cb).
# - The third row is (R - Y), divided by a normalization factor that brings
#   the red coefficient to 0.5 (Cr).
#
# The coefficients are expressed as floating point values for full range
# signals. Quantization range scaling and fixed-point conversion are applied
# later, in main().
encodings = {
    'rec601': [
        [ 0.299, 0.587, 0.114 ],
        [ -0.299 / 1.772, -0.587 / 1.772, 0.886 / 1.772 ],
        [ 0.701 / 1.402, -0.587 / 1.402, -0.114 / 1.402 ]
    ],
    'rec709': [
        [ 0.2126, 0.7152, 0.0722 ],
        [ -0.2126 / 1.8556, -0.7152 / 1.8556, 0.9278 / 1.8556 ],
        [ 0.7874 / 1.5748, -0.7152 / 1.5748, -0.0722 / 1.5748 ]
    ],
    'rec2020': [
        [ 0.2627, 0.6780, 0.0593 ],
        [ -0.2627 / 1.8814, -0.6780 / 1.8814, 0.9407 / 1.8814 ],
        [ 0.7373 / 1.4746, -0.6780 / 1.4746, -0.0593 / 1.4746 ],
    ],
    'smpte240m': [
        [ 0.2122, 0.7013, 0.0865 ],
        [ -0.2122 / 1.8270, -0.7013 / 1.8270, 0.9135 / 1.8270 ],
        [ 0.7878 / 1.5756, -0.7013 / 1.5756, -0.0865 / 1.5756 ],
    ],
}
class Precision(object):
    """Fixed-point precision parsed from a string in Q notation.

    The string has the form ``Q<integer>.<fractional>``, for instance
    ``Q1.7``. The leading ``Q`` is case-insensitive. The sign bit is not part
    of the notation, it is only accounted for by the ``total`` property.

    Parameters
    ----------
    precision : str
        The precision in Q notation

    Raises
    ------
    RuntimeError
        If the string is empty, doesn't start with ``Q``, doesn't contain
        exactly two fields separated by a ``.``, or specifies a negative
        number of fractional bits
    ValueError
        If one of the two fields is not a valid integer
    """

    def __init__(self, precision):
        # Slice the first character instead of indexing it, so that an empty
        # string is rejected with the same RuntimeError as any other malformed
        # input instead of raising an IndexError.
        if precision[:1].upper() != 'Q':
            raise RuntimeError(f'Invalid precision `{precision}`')

        prec = precision[1:].split('.')
        if len(prec) != 2:
            raise RuntimeError(f'Invalid precision `{precision}`')

        # int() raises ValueError if a field is not a number.
        prec = [int(v) for v in prec]

        # The number of fractional bits is used as a shift count to scale the
        # coefficients. A negative shift count is an error in Python, so
        # reject it here where it can be reported as an invalid precision.
        if prec[1] < 0:
            raise RuntimeError(f'Invalid precision `{precision}`')

        self.__prec = prec

    @property
    def integer(self):
        """The number of integer bits, excluding the sign bit."""
        return self.__prec[0]

    @property
    def fractional(self):
        """The number of fractional bits."""
        return self.__prec[1]

    @property
    def total(self):
        """The total number of bits, including the sign bit."""
        # Add 1 for the sign bit
        return self.__prec[0] + self.__prec[1] + 1
@enum.unique
class Quantization(enum.Enum):
    """Quantization range of the YCbCr side of the conversion.

    The member names, in lower case, are the values accepted by the
    --quantization command line argument.
    """

    # The whole 8-bit range (0-255) is used for both luma and chroma.
    FULL = 0

    # Luma is restricted to 16-235 and chroma to 16-240.
    LIMITED = 1
def scale_coeff(coeff, quantization, luma):
    """Scale a matrix coefficient from the input range to the output range.

    The input is assumed to cover the full 8-bit range (0-255). The output
    range depends on the quantization: 0-255 for full range, and 16-235 for
    luma or 16-240 for chroma in limited range.

    Parameters
    ----------
    coeff : float
        The CSC matrix coefficient to scale
    quantization : Quantization
        The quantization, either FULL or LIMITED. Any value other than FULL
        is handled as LIMITED.
    luma : bool
        True if the coefficient corresponds to a luma value, False otherwise

    Returns
    -------
    float
        The coefficient multiplied by the ratio between the output range and
        the input range
    """
    input_span = 255 - 0

    if quantization != Quantization.FULL:
        # In limited range luma and chroma share the same lower bound but
        # have different upper bounds.
        ceiling = 235 if luma else 240
        output_span = ceiling - 16
    else:
        output_span = 255 - 0

    return coeff * output_span / input_span
def round_array(values):
    """Round floating point values to integers while preserving their sum.

    Every value is rounded to the closest integer. If the sum of the rounded
    values then differs from the rounded sum of the original values, the
    elements for which it costs the least are rounded in the opposite
    direction until both sums match.

    Parameters
    ----------
    values : list of float
        The signed values to round

    Returns
    -------
    list of int
        The rounded values, in the same order as the input, whose sum equals
        the rounded sum of the input values
    """
    nearest = [round(value) for value in values]
    target = round(sum(values))

    # The deficit is an integer (positive or negative) that tells how many
    # elements must be bumped up or down by one for the sums to match.
    deficit = target - sum(nearest)
    if deficit == 0:
        return nearest

    # Order the element indices by absolute rounding error (rounded minus
    # original). The front of the list holds the elements that were rounded
    # down the most, the back the ones that were rounded up the most. The sort
    # is stable, equal errors keep their original relative order.
    #
    # Ordering by relative error instead could be argued for, as identical
    # absolute rounding errors have a different impact on coefficients of
    # widely different magnitudes. This is a topic for further research.
    errors = [rounded - value for value, rounded in zip(values, nearest)]
    order = sorted(range(len(errors)), key=errors.__getitem__)

    if deficit > 0:
        # The sum is too low, round up the elements that lost the most.
        for index in order[:deficit]:
            nearest[index] += 1
    else:
        # The sum is too high, round down the elements that gained the most.
        for index in order[deficit:]:
            nearest[index] -= 1

    # The total rounding error must now be zero.
    assert sum(nearest) == target
    return nearest
def main(argv):
    """Generate a CSC coefficient table and print it as a C array.

    The coefficients of the selected encoding are scaled for the quantization
    range, optionally inverted (YUV to RGB), converted to fixed-point with
    the requested precision and printed to stdout as a C array definition.

    Parameters
    ----------
    argv : list of str
        The command line, including the program name in argv[0]

    Returns
    -------
    int
        The process exit status, 0 on success or 1 if the precision is
        invalid
    """
    # Parse command line arguments.
    parser = argparse.ArgumentParser(
        description='Generate color space conversion table coefficients with '
        'configurable fixed-point precision.'
    )
    parser.add_argument('--invert', '-i', action='store_true',
                        help='Invert the color space conversion (YUV -> RGB)')
    parser.add_argument('--precision', '-p', default='Q1.7',
                        help='The output fixed point precision in Q notation (sign bit excluded)')
    parser.add_argument('--quantization', '-q', choices=['full', 'limited'],
                        default='limited', help='Quantization range')
    parser.add_argument('encoding', choices=encodings.keys(), help='YCbCr encoding')
    args = parser.parse_args(argv[1:])

    # Any failure to parse the precision string is reported the same way,
    # regardless of the exception type raised by the parser.
    try:
        precision = Precision(args.precision)
    except Exception:
        print(f'Invalid precision `{args.precision}`')
        return 1

    encoding = encodings[args.encoding]
    quantization = Quantization[args.quantization.upper()]

    # Scale the encoding coefficients based on the quantization range. Only
    # the first line of the matrix produces luma.
    scaled_coeffs = [
        [scale_coeff(coeff, quantization, row == 0) for coeff in line]
        for row, line in enumerate(encoding)
    ]

    if args.invert:
        scaled_coeffs = np.linalg.inv(scaled_coeffs)

    scale = 1 << precision.fractional
    mask = (1 << precision.total) - 1

    # Range of the signed integers that can be represented on precision.total
    # bits in 2's complement.
    lowest = -((mask + 1) >> 1)
    highest = mask >> 1

    overflows = 0
    rounded_coeffs = []

    for line in scaled_coeffs:
        line = [coeff * scale for coeff in line]

        # For the RGB to YUV conversion, use a rounding method that preserves
        # the rounded sum of each line to avoid biases and overflow, as the sum
        # of luma and chroma coefficients should be 1.0 and 0.0 respectively
        # (in full range). For the YUV to RGB conversion, there is no such
        # constraint, so use simple rounding.
        if args.invert:
            # The inverted matrix holds NumPy scalars. int() guarantees plain
            # Python integers for the bit masking below, regardless of the
            # type round() returns for those scalars.
            line = [int(round(coeff)) for coeff in line]
        else:
            line = round_array(line)

        # The masking below silently wraps values that don't fit in the
        # selected precision, which changes their value and possibly their
        # sign. Count them to warn the user.
        overflows += sum(1 for coeff in line if not lowest <= coeff <= highest)

        # Convert coefficients to the number of bits selected by the precision.
        # Negative values will be turned into positive integers using 2's
        # complement.
        rounded_coeffs.append([coeff & mask for coeff in line])

    if overflows:
        print(f'Warning: {overflows} coefficient(s) do not fit in precision '
              f'`{args.precision}` and have been truncated', file=sys.stderr)

    # Print the result as C code. The element type is the smallest unsigned
    # type whose width is a power of two and can hold precision.total bits.
    # There is no type narrower than 8 bits, so use u8 for small precisions.
    nbits = max(8, 1 << (precision.total - 1).bit_length())
    # Each hexadecimal digit encodes 4 bits.
    ndigits = nbits // 4

    direction = 'yuv2rgb' if args.invert else 'rgb2yuv'
    name = f'{direction}_{args.encoding}_{quantization.name.lower()}_coeffs'

    print(f'static const u{nbits} {name}[] = {{')
    for line in rounded_coeffs:
        cells = [f'0x{coeff:0{ndigits}x}' for coeff in line]
        print('\t' + ', '.join(cells) + ',')
    print('};')

    return 0
# Run the generator when executed as a script, and use the value returned by
# main() as the process exit status.
if __name__ == '__main__':
    raise SystemExit(main(sys.argv))