I am implementing a new compression algorithm for the weights of a neural network for the Leela Chess project. The weights are roughly 100 MB of float32 values, which I want to compress as small as possible. The error tolerance for this application is 2^-17, so lossy compression is clearly the right choice here. All of the weights are between -5 and 5, but 99.995% are in (-0.25, 0.25), and most are clustered fairly closely around zero.
The basic idea of this algorithm is to turn the floats into integer multiples of the error tolerance, and then use a UTF-8-inspired variable-length encoding (7 payload bits per byte plus a continuation bit) so that small values take only 1 byte.
import numpy as np
import bz2
def compress(in_path, out_path):
with open(in_path, 'rb') as array:
net = np.fromfile(in_path, dtype=np.float32)
# Quantize
net = np.asarray(net * 2**17, np.int32)
# Zigzag encode
net = (net >> 31) ^ (net << 1)
# To variable length
result = np.zeros(len(net)*3, dtype=np.uint8)
for i in range(3):
big = (net >= 128) << 7
result[i::3] = (net % 128) + big
net >>= 7
# Delete non-essential indices
zeroes = np.where(result == 0)[0]
zeroes = zeroes[np.where(zeroes % 3 != 0)]
result = np.delete(result, zeroes)
with bz2.open(out_path, 'wb') as out:
out.write(result.tobytes())
def decompress(in_path, out_path, *, scale=2**17, version=b''):
    """Invert ``compress``: decode a bz2 varint stream back to raw float32s.

    Args:
        in_path: bz2 file holding little-endian base-128 varints of
            zigzag-encoded quantized weights, as written by ``compress``.
        out_path: Destination for the reconstructed native-endian float32s.
        scale: Quantization multiplier used when compressing (same default).
        version: Bytes written before the weights. ``compress`` stores no
            header, so by default nothing is prepended.
            NOTE(review): pass the header explicitly if the target file
            format requires one.

    Raises:
        ValueError: If the stream is truncated or contains a varint longer
            than the 5 bytes a 32-bit code can need.
    """
    with bz2.open(in_path, 'rb') as src:
        payload = np.frombuffer(src.read(), dtype=np.uint8)
    weights = _zigzag_decode(_varint_decode(payload)).astype(np.float32)
    weights /= scale
    with open(out_path, 'wb') as out:
        out.write(version)
        out.write(weights.tobytes())


def _varint_decode(payload):
    """Decode LEB128 varints (7-bit groups, least significant first, high bit
    set on every byte except a value's last) into a uint32 array."""
    payload = np.asarray(payload, dtype=np.uint8)
    if payload.size == 0:
        return np.zeros(0, dtype=np.uint32)
    if payload[-1] & 0x80:
        raise ValueError('truncated stream: final byte has its continuation bit set')
    # A byte *without* the continuation bit is the LAST byte of its value.
    ends = np.flatnonzero(payload < 0x80)
    lengths = np.diff(ends, prepend=-1)
    longest = int(lengths.max())
    if longest > 5:
        raise ValueError('corrupt stream: varint longer than 5 bytes')
    # Start from the most significant group (the last byte) and fold in the
    # lower groups, stepping one byte back at a time; at most 4 passes.
    codes = (payload[ends] & 0x7F).astype(np.uint32)
    for back in range(1, longest):
        more = lengths > back
        codes[more] = (codes[more] << 7) | (payload[ends[more] - back] & 0x7F)
    return codes


def _zigzag_decode(codes):
    """Invert zigzag coding: 0, 1, 2, 3, 4, ... -> 0, -1, 1, -2, 2, ..."""
    codes = np.asarray(codes, dtype=np.uint32)
    # Negate in int32: negating an unsigned array wraps to 0xFFFFFFFF instead
    # of producing -1, which would turn negative weights into huge positives.
    return (codes >> 1).astype(np.int32) ^ -(codes & 1).astype(np.int32)
if __name__ == '__main__':
    # Round-trip the sample weight file; guarded so importing this module
    # does not read or write files.
    compress('diff.hex', 'diff.bz2')
    decompress('diff.bz2', 'round.hex')
I'm mainly looking for algorithm and performance advice, but suggestions for making the code more readable are always welcome.