Files
mashzip/huffman_table.cpp
Andrey Gumirov b6a619303d WIP: MVP
2024-01-11 02:44:53 +07:00

221 lines
7.1 KiB
C++
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

#include "huffman_table.h"

#include <array>
#include <bitset>
#include <cstdint>
#include <iostream>
#include <iterator>
#include <map>
#include <memory>
#include <queue>
#include <set>
#include <stdexcept>
#include <unordered_map>
#include <utility>
#include <vector>

#include "node.h"
void initialize_table(const int sumLen,
const std::map<int, std::set<char> > &huffmanLengths,
std::unordered_map<char, std::pair<int, short> > &codingTable,
std::vector<char> &symbols,
std::array<int, MAX_LEN> &counts)
{
int nextbl = 0, offset = 0; // offset = total offset to symbols of current len
short code = 0;
symbols.resize(sumLen);
counts.fill(0);
// std::cerr << "Sum len " << sumLen << std::endl;
for (auto lenCodePairIt = huffmanLengths.begin(); lenCodePairIt != huffmanLengths.end(); lenCodePairIt++)
{
int cnt = 0; // counter of symbols of current code length
auto lenCodePair = *lenCodePairIt;
counts[lenCodePair.first] = lenCodePair.second.size();
// std::cerr << "Counts[" << lenCodePair.first << "] " << counts[lenCodePair.first] << std::endl;
for (auto it = lenCodePair.second.begin(); it != lenCodePair.second.end(); it++)
{
codingTable[*it].first = lenCodePair.first; // save current bit length for code
codingTable[*it].second = code;
// code := (code + 1) << ((bit length of the next symbol) (current bit length))
// code++;
// if last symbol of length and has symbols with greater length
if (next(it) == lenCodePair.second.end() && next(lenCodePairIt) != huffmanLengths.end()) {
nextbl = (*next(lenCodePairIt)).first;
} else {
nextbl = lenCodePair.first;
}
// std::cerr << "symbols[" << offset + cnt << "] =" << (*it) << " code " << std::bitset<16>(code) << std::endl;
symbols[offset + cnt] = *it;
code = (code + 1) << (nextbl - lenCodePair.first);
cnt++;
}
offset += cnt;
// code <<= 1;
}
}
// Rebuild a coding table from a serialized header (the format to_header() writes).
//
// Byte i packs the code lengths of two symbols: symbol 2*i in the high nibble
// and symbol 2*i+1 in the low nibble. A zero nibble means "no code".
// The codes themselves are then regenerated by initialize_table().
HuffmanTable::HuffmanTable(const char *header) {
    std::map<int, std::set<char> > huffmanLengths;
    // Number of symbols that have a code. The original summed the code
    // *lengths* here, which oversized the symbol table.
    int symbolCount = 0;
    for (int i = 0; i < HEADER_SIZE; i++) {
        const unsigned char packed = static_cast<unsigned char>(header[i]);
        const int highLen = packed >> 4;
        const int lowLen = packed & 0xf;
        if (highLen != 0) {
            huffmanLengths[highLen].insert(static_cast<char>(i * 2));
            symbolCount++;
        }
        if (lowLen != 0) {
            huffmanLengths[lowLen].insert(static_cast<char>(i * 2 + 1));
            symbolCount++;
        }
    }
    // build up codes
    initialize_table(symbolCount, huffmanLengths, this->huffmanCodes, this->symbols, this->counts);
}
// Record the depth (= code length) of every leaf of a Huffman (sub)tree.
//
// @param root            subtree to walk; nullptr is ignored.
// @param len             depth of `root` (0 when called on the tree root).
// @param cnt             incremented once for every leaf found.
// @param huffmanLengths  out: depth -> set of leaf characters at that depth.
void get_lengths(Node* root, int len, int &cnt,
                 std::map<int, std::set<char> > &huffmanLengths)
{
    if (!root)
        return;
    if (root->isLeaf()) {
        cnt++;
        // A leaf at depth 0 means the whole tree is a single symbol. A 0-bit
        // code can't be written, stored in the header (0 means "absent") or
        // decoded (decoding starts at length 1), so that symbol gets a 1-bit code.
        const int codeLen = (len == 0) ? 1 : len;
        huffmanLengths[codeLen].insert(root->getChar());
    }
    get_lengths(root->getLeft(), len + 1, cnt, huffmanLengths);
    get_lengths(root->getRight(), len + 1, cnt, huffmanLengths);
}
// Build a coding table from the character frequencies of `is`.
//
// Reads `is` until get() fails (end of stream). It then builds a Huffman tree
// with a priority queue ordered by NodeComp and turns the leaf depths into
// codes via get_lengths() and initialize_table().
// Empty input yields an empty table (no codes; all counts 0).
HuffmanTable::HuffmanTable(std::basic_istream<char> &is) {
    // Count how often each character appears.
    std::unordered_map<char, int> freq;
    char ch;
    while (is.get(ch)) {
        freq[ch]++;
    }
    // One leaf node per distinct character.
    std::priority_queue<Node*, std::vector<Node*>, NodeComp> pq;
    for (const auto &entry : freq) {
        pq.push(new Node(entry.first, entry.second, nullptr, nullptr));
    }
    std::map<int, std::set<char> > huffmanLengths;
    int total_cnt = 0;
    // With no input the queue is empty. The original loop then called
    // pq.top()/pop() on an empty queue, which is undefined behavior.
    if (!pq.empty()) {
        // Repeatedly merge the two highest-priority (lowest frequency, per
        // NodeComp) nodes into an internal node until only the root remains.
        while (pq.size() > 1) {
            Node *left = pq.top(); pq.pop();
            Node *right = pq.top(); pq.pop();
            int sum = left->getFreq() + right->getFreq();
            pq.push(new Node('\0', sum, left, right));
        }
        Node *root = pq.top();
        get_lengths(root, 0, total_cnt, huffmanLengths);
        // NOTE(review): the tree nodes are never freed. Whether `delete root`
        // also releases the children depends on Node's destructor (node.h).
        // TODO: confirm, then free the tree here.
    }
    initialize_table(total_cnt, huffmanLengths, this->huffmanCodes, this->symbols, this->counts);
}
// Look up the (code length, code bits) pair assigned to `c`.
//
// Returns {0, 0} when `c` has no code, the same value as before. Unlike
// unordered_map::operator[], this does not insert a placeholder entry. A
// placeholder would make write_symbol() accept `c` and emit a 0-bit code
// instead of throwing.
std::pair<int, short> HuffmanTable::operator[](const char &c) {
    const auto it = huffmanCodes.find(c);
    if (it == huffmanCodes.end()) return std::pair<int, short>(0, 0);
    return it->second;
}
// Append the code assigned to `c` to `os` (code bits, code length).
//
// @throws std::runtime_error if `c` has no code in this table.
void HuffmanTable::write_symbol(obitstream &os, const char &c) {
    const auto it = huffmanCodes.find(c);
    if (it == huffmanCodes.end()) throw std::runtime_error("No code in table for char!");
    // Per-symbol tracing was printed to std::cerr here and flushed with
    // std::endl on every call. It is disabled like the file's other debug output:
    // std::cerr << "Write code for " << c << " " << (int)c << " : " << std::bitset<16>(it->second.second) << " " << " len " << it->second.first << std::endl;
    os.writebits(it->second.second, it->second.first);
}
// Decode one symbol from `bs`, reading one bit at a time.
//
// For each code length `len`, the codes of that length form the contiguous
// range [first, first + counts[len]). `index` is the position in `symbols` of
// the first symbol with that length.
//
// @return the decoded byte as a value in [0, 255], or -1 if no code matches
//         within the lengths `counts` can hold.
int HuffmanTable::decode_one_symbol(ibitstream &bs)
{
    uint16_t code = 0;
    int first = 0, index = 0;
    // counts has counts.size() entries (indices 0..size()-1). The original
    // looped up to MAX_LEN inclusive and read counts[MAX_LEN] out of bounds.
    const int maxLen = static_cast<int>(this->counts.size()) - 1;
    for (int len = 1; len <= maxLen; len++) {
        // read one bit
        code |= static_cast<uint16_t>(bs.getbits(1));
        const int count = this->counts[len];
        if (code < first + count) {
            // Widen through unsigned char so byte 0xFF is not returned as
            // -1, which callers would mistake for the error value.
            return static_cast<unsigned char>(this->symbols[index + (code - first)]);
        }
        index += count;
        first += count;
        first <<= 1;
        code <<= 1;
    }
    return -1;
}
// Serialize the code lengths into a newly allocated HEADER_SIZE-byte buffer.
//
// Byte i holds the code length of symbol 2*i in its high nibble and of symbol
// 2*i+1 in its low nibble. A symbol without a code leaves its nibble 0.
// The caller owns the result and must release it with delete[].
//
// @throws std::runtime_error if any code is longer than 15 bits (0xf). The
//         buffer is not leaked in that case.
char *HuffmanTable::to_header() {
    // Value-initialize (zero) the buffer: every byte is assembled with |=
    // below. The original OR'ed into uninitialized memory.
    std::unique_ptr<char[]> header(new char[HEADER_SIZE]());
    for (size_t i = 0; i < HEADER_SIZE; i++)
    {
        const auto highIt = huffmanCodes.find(static_cast<char>(2 * i));
        if (highIt != huffmanCodes.end()) {
            const int len = highIt->second.first;
            if (len > 0xf) throw std::runtime_error("Codes longer than 0xf are not allowed!");
            header[i] |= (len & 0xf) << 4;
        }
        const auto lowIt = huffmanCodes.find(static_cast<char>(2 * i + 1));
        if (lowIt != huffmanCodes.end()) {
            const int len = lowIt->second.first;
            if (len > 0xf) throw std::runtime_error("Codes longer than 0xf are not allowed!");
            header[i] |= (len & 0xf);
        }
    }
    return header.release();
}