160 lines
5.0 KiB
C++
160 lines
5.0 KiB
C++
#include "huffman_table.h"

#include <cstddef>
#include <cstdint>
#include <iostream>
#include <map>
#include <memory>
#include <queue>
#include <set>
#include <stdexcept>
#include <unordered_map>
#include <utility>
#include <vector>

#include "node.h"
|
||
|
||
#define HEADER_SIZE 128
|
||
|
||
// Assigns canonical Huffman codes from a table of code lengths.
//
// huffmanLengths maps a code length (in bits) to the set of symbols that get
// a code of that length. Lengths are visited in ascending order (std::map
// order) and, within one length, symbols in std::set<char> order, so any two
// callers holding the same length table derive identical codes.
//
// For every symbol, table[symbol] is set to {bit length, code}. The code is a
// running counter: it is incremented after each symbol and shifted left by the
// length difference whenever the length grows, i.e. the canonical rule
//   code := (code + 1) << (next_len - cur_len).
//
// The running counter is kept in a wide unsigned integer. The previous
// `short` counter wrapped negative for lengths >= 16 and was then
// left-shifted, which is undefined behaviour. Only the stored value is
// narrowed to `short`, so codes longer than 16 bits are truncated in the table.
void initialize_table(const std::map<int, std::set<char> > &huffmanLengths,
                      std::unordered_map<char, std::pair<int, short> > &table)
{
    if (huffmanLengths.empty())
        return;

    unsigned long long code = 0;
    int prevLen = huffmanLengths.begin()->first;

    // Bind by reference: copying each std::set per iteration is pure overhead.
    for (const auto &group : huffmanLengths) {
        const int len = group.first;

        // Moving to a longer length: append (len - prevLen) zero bits.
        code <<= (len - prevLen);
        prevLen = len;

        for (const char symbol : group.second) {
            table[symbol] = std::pair<int, short>(len, static_cast<short>(code));
            ++code;
        }
    }
}
|
||
|
||
// Rebuilds the code table from a serialized header of HEADER_SIZE bytes.
//
// Byte `slot` packs two 4-bit code lengths: the high nibble belongs to symbol
// (char)(2 * slot), the low nibble to symbol (char)(2 * slot + 1). A nibble of
// zero means that symbol has no code and is skipped.
HuffmanTable::HuffmanTable(uint8_t *header) {
    std::map<int, std::set<char> > symbolsByLength;

    for (int slot = 0; slot < HEADER_SIZE; ++slot) {
        const int evenLen = header[slot] >> 4;
        const int oddLen = header[slot] & 0x0F;

        if (evenLen) {
            symbolsByLength[evenLen].insert(static_cast<char>(2 * slot));
        }
        if (oddLen) {
            symbolsByLength[oddLen].insert(static_cast<char>(2 * slot + 1));
        }
    }

    // Derive the canonical codes from the recovered lengths.
    initialize_table(symbolsByLength, this->huffmanCodes);
}
|
||
|
||
// Walks the Huffman tree rooted at `root` and records, for every leaf, its
// depth `len` as that leaf's code length: huffmanLengths[depth] receives the
// leaf's character. Callers start the walk with len == 0 at the root.
//
// A leaf at depth 0 (a tree with a single symbol) is recorded with length 1.
// A 0-bit code could never be written to the bit stream. It would also be
// serialized as length 0, which the header format uses for "symbol absent".
void get_lengths(Node* root, int len,
                 std::map<int, std::set<char> > &huffmanLengths)
{
    if (!root)
        return;

    if (root->isLeaf()) {
        const int codeLen = (len == 0) ? 1 : len;
        huffmanLengths[codeLen].insert(root->getChar());
    }

    get_lengths(root->getLeft(), len + 1, huffmanLengths);
    get_lengths(root->getRight(), len + 1, huffmanLengths);
}
|
||
|
||
// Builds a Huffman code table from the symbol frequencies of everything
// readable from `is`. The stream is consumed until get() fails.
//
// Edge cases:
//  - empty input: there are no symbols, so the code table is left empty;
//  - one distinct symbol: the tree is a single leaf, which is given a 1-bit
//    code (a 0-bit code cannot be written nor represented in the header).
HuffmanTable::HuffmanTable(std::basic_istream<char> &is) {
    // Count how often each character occurs.
    std::unordered_map<char, int> freq;
    char ch;
    while (is.get(ch)) {
        freq[ch]++;
    }

    std::cerr << "Calculated freqs" << std::endl;

    // Live subtrees not yet merged, ordered by NodeComp
    // (lowest frequency first, per the merge loop's intent).
    std::priority_queue<Node*, std::vector<Node*>, NodeComp> pq;

    // One leaf per distinct character.
    for (const auto &pair : freq) {
        pq.push(new Node(pair.first, pair.second, nullptr, nullptr));
    }

    std::cerr << "Filled PQ: " << pq.size() << std::endl;

    // With no input, pq is empty and top()/pop() below would be undefined
    // behaviour. There is nothing to encode, so keep an empty table.
    if (pq.empty())
        return;

    // Repeatedly merge the two lowest-frequency subtrees until one remains.
    while (pq.size() > 1) {
        Node *left = pq.top(); pq.pop();
        Node *right = pq.top(); pq.pop();

        int sum = left->getFreq() + right->getFreq();
        pq.push(new Node('\0', sum, left, right));
    }

    Node *root = pq.top();

    std::map<int, std::set<char> > huffmanLengths;
    if (root->isLeaf()) {
        // Single distinct symbol: its depth would be 0, so assign 1 bit.
        huffmanLengths[1].insert(root->getChar());
    } else {
        get_lengths(root, 0, huffmanLengths);
    }

    initialize_table(huffmanLengths, this->huffmanCodes);

    // TODO(review): the Node objects allocated above are never deleted. Whether
    // deleting `root` alone suffices depends on Node's destructor (node.h,
    // not visible here). Confirm its ownership semantics before adding cleanup.
}
|
||
|
||
// Returns {bit length, code} for symbol `c`, or {0, 0} if `c` has no code.
//
// Uses find() rather than huffmanCodes[c]. The map subscript would silently
// insert a {0, 0} entry for an unknown symbol. After that, write_symbol()
// would no longer throw for it and would emit a 0-bit code instead.
std::pair<int, short> HuffmanTable::operator[](const char &c) {
    const auto it = huffmanCodes.find(c);
    if (it == huffmanCodes.end())
        return std::pair<int, short>(0, 0);
    return it->second;
}
|
||
|
||
// Writes the code for symbol `c` to `os` as `length` bits.
// Throws std::runtime_error if `c` has no entry in the code table.
//
// A single find() replaces the previous find() plus two operator[] lookups.
// This function runs once per encoded symbol.
void HuffmanTable::write_symbol(obitstream &os, const char &c) {
    const auto entry = huffmanCodes.find(c);
    if (entry == huffmanCodes.end()) throw std::runtime_error("No code in table for char!");

    std::pair<int, short> &lengthAndCode = entry->second;
    os.writebits(lengthAndCode.second, lengthAndCode.first);
}
|
||
|
||
uint8_t *HuffmanTable::to_header() {
|
||
uint8_t *header = new uint8_t[HEADER_SIZE];
|
||
|
||
for (size_t i = 0; i < HEADER_SIZE; i++)
|
||
{
|
||
if (huffmanCodes.find(2 * i) != huffmanCodes.end()) {
|
||
int len = huffmanCodes[2 * i].first;
|
||
if (len > 0xf) throw std::runtime_error("Codes longer than 0xf are not allowed!");
|
||
|
||
header[i] |= (len & 0xf) << 4;
|
||
}
|
||
|
||
if (huffmanCodes.find(2 * i + 1) != huffmanCodes.end()) {
|
||
int len = huffmanCodes[2 * i + 1].first;
|
||
if (len > 0xf) throw std::runtime_error("Codes longer than 0xf are not allowed!");
|
||
|
||
header[i] |= (len & 0xf);
|
||
}
|
||
}
|
||
|
||
return header;
|
||
} |