This commit is contained in:
Andrey Gumirov
2024-01-08 02:39:45 +07:00
parent fdcba0b869
commit 606ce553ac
9 changed files with 1594 additions and 0 deletions

160
huffman_table.cpp Normal file
View File

@ -0,0 +1,160 @@
#include "huffman_table.h"
#include "node.h"
#include <map>
#include <set>
#include <queue>
#include <iostream>
#define HEADER_SIZE 128
void initialize_table(const std::map<int, std::set<char> > &huffmanLengths,
std::unordered_map<char, std::pair<int, short> > &table)
{
int nextbl = 0;
short code = 0;
for (auto lenCodePairIt = huffmanLengths.begin(); lenCodePairIt != huffmanLengths.end(); lenCodePairIt++)
{
auto lenCodePair = *lenCodePairIt;
for (auto it = lenCodePair.second.begin(); it != lenCodePair.second.end(); it++)
{
table[*it].first = lenCodePair.first; // save current bit length for code
table[*it].second = code;
// code := (code + 1) << ((bit length of the next symbol) (current bit length))
// code++;
// if last symbol of length and has symbols with greater length
if (next(it) == lenCodePair.second.end() && next(lenCodePairIt) != huffmanLengths.end()) {
nextbl = (*next(lenCodePairIt)).first;
} else {
nextbl = lenCodePair.first;
}
code = (code + 1) << (nextbl - lenCodePair.first);
}
// code <<= 1;
}
}
HuffmanTable::HuffmanTable(uint8_t *header) {
int cnt1, cnt2;
std::map<int, std::set<char> > huffmanLengths;
for (int i = 0; i < HEADER_SIZE; i++) {
cnt1 = ((header[i] & 0b11110000) >> 4);
cnt2 = (header[i] & 0b1111);
if (cnt1 != 0) huffmanLengths[cnt1].insert((char)(i * 2));
if (cnt2 != 0) huffmanLengths[cnt2].insert((char)(i * 2 + 1));
}
// build up codes
initialize_table(huffmanLengths, this->huffmanCodes);
}
void get_lengths(Node* root, int len,
std::map<int, std::set<char> > &huffmanLengths)
{
if (!root)
return;
// std::cerr << "Get lengths node: " << root->getChar() << " " << root->getFreq() << root->getLeft() << " " << root->getRight() << std::endl;
// found a leaf node
if (root->isLeaf()) {
// huffmanCode[root->ch] = str;
// std::cerr << "Got leaf: " << root->getChar() << std::endl;
huffmanLengths[len].insert(root->getChar());
}
get_lengths(root->getLeft(), len + 1, huffmanLengths);
get_lengths(root->getRight(), len + 1, huffmanLengths);
}
HuffmanTable::HuffmanTable(std::basic_istream<char> &is) {
// count frequency of appearance of each character
// and store it in a map
std::unordered_map<char, int> freq;
char ch;
while (is.get(ch)) {
freq[ch]++;
}
std::cerr << "Calculated freqs" << std::endl;
// Create a priority queue to store live nodes of
// Huffman tree;
std::priority_queue<Node*, std::vector<Node*>, NodeComp> pq;
// Create a leaf node for each character and add it
// to the priority queue.
for (auto pair: freq) {
Node *new_node = new Node(pair.first, pair.second, nullptr, nullptr);
pq.push(new_node);
}
std::cerr << "Filled PQ: " << pq.size() << std::endl;
// do till there is more than one node in the queue
while (pq.size() != 1)
{
// Remove the two nodes of highest priority
// (lowest frequency) from the queue
Node *left = pq.top(); pq.pop();
Node *right = pq.top(); pq.pop();
// Create a new internal node with these two nodes
// as children and with frequency equal to the sum
// of the two nodes' frequencies. Add the new node
// to the priority queue.
int sum = left->getFreq() + right->getFreq();
Node *new_node = new Node('\0', sum, left, right);
pq.push(new_node);
}
// std::cerr << "Built tree: " << pq.size() << " " << pq.top()->getFreq() << std::endl;
// root stores pointer to root of Huffman Tree
Node* root = pq.top();
std::map<int, std::set<char> > huffmanLengths;
get_lengths(root, 0, huffmanLengths);
// std::cerr << "Got lengths: " << huffmanLengths.size() << std::endl;
initialize_table(huffmanLengths, this->huffmanCodes);
}
std::pair<int, short> HuffmanTable::operator[](const char &c) {
return huffmanCodes[c];
}
void HuffmanTable::write_symbol(obitstream &os, const char &c) {
if (huffmanCodes.find(c) == huffmanCodes.end()) throw std::runtime_error("No code in table for char!");
os.writebits(huffmanCodes[c].second, huffmanCodes[c].first);
}
uint8_t *HuffmanTable::to_header() {
uint8_t *header = new uint8_t[HEADER_SIZE];
for (size_t i = 0; i < HEADER_SIZE; i++)
{
if (huffmanCodes.find(2 * i) != huffmanCodes.end()) {
int len = huffmanCodes[2 * i].first;
if (len > 0xf) throw std::runtime_error("Codes longer than 0xf are not allowed!");
header[i] |= (len & 0xf) << 4;
}
if (huffmanCodes.find(2 * i + 1) != huffmanCodes.end()) {
int len = huffmanCodes[2 * i + 1].first;
if (len > 0xf) throw std::runtime_error("Codes longer than 0xf are not allowed!");
header[i] |= (len & 0xf);
}
}
return header;
}