221 lines
7.1 KiB
C++
221 lines
7.1 KiB
C++
#include "huffman_table.h"
|
||
#include "node.h"
|
||
|
||
#include <map>
|
||
#include <set>
|
||
#include <queue>
|
||
#include <iostream>
|
||
|
||
void initialize_table(const int sumLen,
|
||
const std::map<int, std::set<char> > &huffmanLengths,
|
||
std::unordered_map<char, std::pair<int, short> > &codingTable,
|
||
std::vector<char> &symbols,
|
||
std::array<int, MAX_LEN> &counts)
|
||
{
|
||
int nextbl = 0, offset = 0; // offset = total offset to symbols of current len
|
||
short code = 0;
|
||
|
||
symbols.resize(sumLen);
|
||
counts.fill(0);
|
||
|
||
// std::cerr << "Sum len " << sumLen << std::endl;
|
||
|
||
for (auto lenCodePairIt = huffmanLengths.begin(); lenCodePairIt != huffmanLengths.end(); lenCodePairIt++)
|
||
{
|
||
int cnt = 0; // counter of symbols of current code length
|
||
auto lenCodePair = *lenCodePairIt;
|
||
|
||
counts[lenCodePair.first] = lenCodePair.second.size();
|
||
// std::cerr << "Counts[" << lenCodePair.first << "] " << counts[lenCodePair.first] << std::endl;
|
||
|
||
for (auto it = lenCodePair.second.begin(); it != lenCodePair.second.end(); it++)
|
||
{
|
||
codingTable[*it].first = lenCodePair.first; // save current bit length for code
|
||
codingTable[*it].second = code;
|
||
|
||
// code := (code + 1) << ((bit length of the next symbol) − (current bit length))
|
||
// code++;
|
||
|
||
// if last symbol of length and has symbols with greater length
|
||
if (next(it) == lenCodePair.second.end() && next(lenCodePairIt) != huffmanLengths.end()) {
|
||
nextbl = (*next(lenCodePairIt)).first;
|
||
} else {
|
||
nextbl = lenCodePair.first;
|
||
}
|
||
|
||
// std::cerr << "symbols[" << offset + cnt << "] =" << (*it) << " code " << std::bitset<16>(code) << std::endl;
|
||
symbols[offset + cnt] = *it;
|
||
|
||
code = (code + 1) << (nextbl - lenCodePair.first);
|
||
cnt++;
|
||
}
|
||
|
||
offset += cnt;
|
||
// code <<= 1;
|
||
}
|
||
}
|
||
|
||
// Rebuilds a Huffman table from a serialized header.
//
// `header` must point to at least HEADER_SIZE readable bytes. Byte i carries
// two 4-bit code lengths: the high nibble belongs to symbol 2*i, the low
// nibble to symbol 2*i + 1. A nibble of 0 means the symbol has no code.
//
// The number of symbols with a non-zero length is what gets passed to
// initialize_table as the size of the symbol list. The previous code passed
// the sum of the nibble values (i.e. of the code lengths), which merely
// over-allocated the list and left trailing unused '\0' entries.
HuffmanTable::HuffmanTable(const char *header) {
    int symbol_cnt = 0;
    std::map<int, std::set<char> > huffmanLengths;

    for (int i = 0; i < HEADER_SIZE; i++) {
        const int hiLen = ((header[i] & 0b11110000) >> 4);
        const int loLen = (header[i] & 0b1111);

        if (hiLen != 0) {
            huffmanLengths[hiLen].insert((char)(i * 2));
            symbol_cnt++;
        }
        if (loLen != 0) {
            huffmanLengths[loLen].insert((char)(i * 2 + 1));
            symbol_cnt++;
        }
    }

    // build up codes
    initialize_table(symbol_cnt, huffmanLengths, this->huffmanCodes, this->symbols, this->counts);
}
|
||
|
||
// Walks the tree rooted at `root` and records, for every leaf, the depth at
// which it was found.
//
// root            subtree to visit; a null pointer is ignored
// len             depth assigned to `root` itself; children get len + 1
// cnt             in/out: incremented once per leaf encountered
// huffmanLengths  in/out: depth -> set of characters stored in leaves at
//                 that depth
void get_lengths(Node* root, int len, int &cnt,
                 std::map<int, std::set<char> > &huffmanLengths)
{
    if (root == nullptr) {
        return;
    }

    const bool atLeaf = root->isLeaf();
    if (atLeaf) {
        cnt++;
        std::set<char> &sameDepth = huffmanLengths[len];
        sameDepth.insert(root->getChar());
    }

    // Both children are visited unconditionally, exactly as for inner nodes.
    const int childDepth = len + 1;
    get_lengths(root->getLeft(), childDepth, cnt, huffmanLengths);
    get_lengths(root->getRight(), childDepth, cnt, huffmanLengths);
}
|
||
|
||
// Builds a Huffman table from the character frequencies of a stream.
//
// Reads `is` until get() fails, counts how often each character occurs,
// builds a Huffman tree from those counts, derives a code length per
// character from the depth of its leaf, and hands the lengths to
// initialize_table to assign the actual codes.
//
// Edge cases handled explicitly:
//  * empty input   - no tree is built and the table is initialised empty
//                    (previously top()/pop() were called on an empty queue);
//  * single symbol - the lone leaf is the root; it is given code length 1
//                    instead of 0 so that it has a real, writable code and a
//                    non-zero length.
HuffmanTable::HuffmanTable(std::basic_istream<char> &is) {
    // count frequency of appearance of each character
    std::unordered_map<char, int> freq;
    char ch;
    while (is.get(ch)) {
        freq[ch]++;
    }

    std::map<int, std::set<char> > huffmanLengths;
    int total_cnt = 0;

    if (!freq.empty()) {
        // Live nodes of the Huffman tree. Presumably NodeComp orders the
        // queue so that the lowest frequency is on top - confirm in node.h.
        std::priority_queue<Node*, std::vector<Node*>, NodeComp> pq;

        // One leaf per distinct character.
        for (const auto &pair : freq) {
            pq.push(new Node(pair.first, pair.second, nullptr, nullptr));
        }

        // Merge the two top nodes until a single root remains.
        while (pq.size() > 1)
        {
            Node *left = pq.top(); pq.pop();
            Node *right = pq.top(); pq.pop();

            // Internal node whose frequency is the sum of its children's.
            int sum = left->getFreq() + right->getFreq();
            pq.push(new Node('\0', sum, left, right));
        }

        // root stores pointer to root of Huffman Tree
        Node* root = pq.top();

        // With exactly one distinct character the root is itself the leaf;
        // start at depth 1 so the symbol does not end up with a 0-bit code.
        const int rootDepth = (freq.size() == 1) ? 1 : 0;
        get_lengths(root, rootDepth, total_cnt, huffmanLengths);

        // NOTE(review): the nodes allocated above are never freed in this
        // constructor. Whether `delete root` would release the whole tree
        // depends on Node's destructor, which is not visible here - confirm
        // and add the appropriate cleanup.
    }

    initialize_table(total_cnt, huffmanLengths, this->huffmanCodes, this->symbols, this->counts);
}
|
||
|
||
// Returns the (code length, code) pair stored for `c`.
//
// For a character without an entry the result is {0, 0}, the same value the
// previous implementation returned. Unlike before, the lookup no longer
// inserts that placeholder into huffmanCodes: an inserted zero-length entry
// made a later write_symbol() call for the same character pass its
// "No code in table" check and write nothing.
std::pair<int, short> HuffmanTable::operator[](const char &c) {
    const auto found = huffmanCodes.find(c);
    if (found == huffmanCodes.end()) {
        return std::pair<int, short>(0, 0);
    }
    return found->second;
}
|
||
|
||
// Writes the code of character `c` to the bit stream `os`.
//
// Throws std::runtime_error when the table has no entry for `c`; nothing is
// written or logged in that case.
void HuffmanTable::write_symbol(obitstream &os, const char &c) {
    const auto found = huffmanCodes.find(c);
    if (found == huffmanCodes.end()) {
        throw std::runtime_error("No code in table for char!");
    }

    const int bitLength = found->second.first;
    const short codeBits = found->second.second;

    // NOTE(review): this trace is emitted on stderr for every symbol written;
    // it looks like leftover debugging output - confirm before removing.
    std::cerr << "Write code for " << c << " " << static_cast<int>(c) << " : "
              << std::bitset<16>(codeBits) << " " << " len " << bitLength << std::endl;

    os.writebits(codeBits, bitLength);
}
|
||
|
||
// Decodes one symbol from `bs`, reading it bit by bit.
//
// For each candidate length the function keeps
//   first - the first code of that length,
//   index - the position in `symbols` of the first symbol of that length,
// and compares the bits read so far against the range
// [first, first + counts[len]). On a match the corresponding entry of
// `symbols` is returned.
//
// Returns the decoded symbol as a value in 0..255, or -1 when no code
// matched within the lengths covered by `counts`.
//
// Changes from the previous version:
//  * the loop no longer runs with len == MAX_LEN, which read counts[MAX_LEN],
//    one element past the end of the array;
//  * the symbol is converted through unsigned char, so the byte 0xFF is no
//    longer returned as -1 (indistinguishable from the failure value) on
//    platforms where char is signed. Casting the result back to char yields
//    the same character as before.
int HuffmanTable::decode_one_symbol(ibitstream &bs)
{
    uint16_t code = 0;
    int first = 0, index = 0;
    const int lenLimit = static_cast<int>(this->counts.size());

    for (int len = 1; len < lenLimit; len++) {
        // read one bit and append it to the code accumulated so far
        code |= static_cast<uint16_t>(bs.getbits(1));

        const int count = this->counts[len];

        if (code < first + count) {
            return static_cast<unsigned char>(this->symbols[index + (code - first)]);
        }

        // No match at this length: skip its symbols and make room for the
        // next bit in both the code and the first-code tracker.
        index += count;
        first += count;
        first <<= 1;
        code <<= 1;
    }

    return -1;
}
|
||
|
||
char *HuffmanTable::to_header() {
|
||
char *header = new char[HEADER_SIZE];
|
||
|
||
for (size_t i = 0; i < HEADER_SIZE; i++)
|
||
{
|
||
if (huffmanCodes.find(2 * i) != huffmanCodes.end()) {
|
||
int len = huffmanCodes[2 * i].first;
|
||
if (len > 0xf) throw std::runtime_error("Codes longer than 0xf are not allowed!");
|
||
|
||
header[i] |= (len & 0xf) << 4;
|
||
}
|
||
|
||
if (huffmanCodes.find(2 * i + 1) != huffmanCodes.end()) {
|
||
int len = huffmanCodes[2 * i + 1].first;
|
||
if (len > 0xf) throw std::runtime_error("Codes longer than 0xf are not allowed!");
|
||
|
||
header[i] |= (len & 0xf);
|
||
}
|
||
}
|
||
|
||
return header;
|
||
} |