#include "huffman_table.h"
#include "node.h"

#include <cstddef>
#include <cstdint>
#include <iostream>
#include <iterator>
#include <map>
#include <memory>
#include <queue>
#include <set>
#include <stdexcept>
#include <unordered_map>
#include <utility>
#include <vector>

// Size in bytes of the serialized code-length header: two 4-bit lengths per
// byte, i.e. one slot for each of 256 symbols.
#define HEADER_SIZE 128

// NOTE(review): the template arguments below were reconstructed from how the
// values are used in this file (int lengths, char symbols, short codes).
// Confirm them against the declarations in huffman_table.h.
namespace {
// Code length -> symbols that use that length (symbols kept sorted).
using SymbolsByLength = std::map<int, std::set<char>>;
// Symbol -> (code length in bits, code value).
using CodeTable = std::unordered_map<char, std::pair<int, short>>;
}  // namespace

// Assigns a code to every symbol in `huffmanLengths` and stores
// (length, code) for it in `table`.
//
// Lengths are visited in ascending order and symbols in set order. After each
// symbol the running code is advanced with
//   code := (code + 1) << ((bit length of the next symbol) - (current bit length))
// so the shift is non-zero only when moving on to the next, greater length.
void initialize_table(const SymbolsByLength &huffmanLengths, CodeTable &table) {
    short code = 0;

    for (auto lengthIt = huffmanLengths.begin(); lengthIt != huffmanLengths.end(); ++lengthIt) {
        const int length = lengthIt->first;
        // Bind by reference: copying the set for every length is wasted work.
        const std::set<char> &symbols = lengthIt->second;
        const auto nextLengthIt = std::next(lengthIt);

        for (auto symbolIt = symbols.begin(); symbolIt != symbols.end(); ++symbolIt) {
            auto &entry = table[*symbolIt];
            entry.first = length;  // save current bit length for code
            entry.second = code;

            // Only the last symbol of a length that is followed by a greater
            // length widens the code; otherwise the length stays the same.
            int nextLength = length;
            if (std::next(symbolIt) == symbols.end() && nextLengthIt != huffmanLengths.end()) {
                nextLength = nextLengthIt->first;
            }
            code = static_cast<short>((code + 1) << (nextLength - length));
        }
    }
}

// Builds the table from a serialized header of HEADER_SIZE bytes.
//
// Byte i holds the code length of symbol 2*i in its high nibble and of symbol
// 2*i + 1 in its low nibble; a zero nibble means the symbol has no code.
// Throws std::runtime_error if `header` is null.
HuffmanTable::HuffmanTable(uint8_t *header) {
    if (header == nullptr) {
        throw std::runtime_error("Header pointer is null!");
    }

    SymbolsByLength huffmanLengths;
    for (int i = 0; i < HEADER_SIZE; ++i) {
        const int highLength = (header[i] & 0b11110000) >> 4;
        const int lowLength = header[i] & 0b1111;

        if (highLength != 0) {
            huffmanLengths[highLength].insert(static_cast<char>(i * 2));
        }
        if (lowLength != 0) {
            huffmanLengths[lowLength].insert(static_cast<char>(i * 2 + 1));
        }
    }

    // build up codes
    initialize_table(huffmanLengths, this->huffmanCodes);
}

// Walks the tree rooted at `root` and records every leaf's symbol under the
// depth at which it was found; `len` is the depth assigned to `root` itself.
// A null `root` is ignored.
void get_lengths(Node *root, int len, SymbolsByLength &huffmanLengths) {
    if (!root) {
        return;
    }

    // found a leaf node
    if (root->isLeaf()) {
        huffmanLengths[len].insert(root->getChar());
    }

    get_lengths(root->getLeft(), len + 1, huffmanLengths);
    get_lengths(root->getRight(), len + 1, huffmanLengths);
}

// Builds the table from the symbol frequencies of everything readable from
// `is` (the stream is consumed until get() fails).
//
// An empty stream produces an empty table. A stream containing a single
// distinct symbol gives that symbol a 1-bit code, because a 0-bit code cannot
// be represented in the header (0 means "no code").
HuffmanTable::HuffmanTable(std::basic_istream<char> &is) {
    // count frequency of appearance of each character
    std::unordered_map<char, int> freq;
    char ch;
    while (is.get(ch)) {
        freq[ch]++;
    }
    std::cerr << "Calculated freqs" << std::endl;

    // Priority queue of the live nodes of the Huffman tree, seeded with one
    // leaf per distinct character.
    std::priority_queue<Node *, std::vector<Node *>, NodeComp> pq;
    for (const auto &symbolAndCount : freq) {
        pq.push(new Node(symbolAndCount.first, symbolAndCount.second, nullptr, nullptr));
    }
    std::cerr << "Filled PQ: " << pq.size() << std::endl;

    // Nothing was read: there is no tree to build, and calling top()/pop()
    // on an empty queue would be undefined behaviour.
    if (pq.empty()) {
        return;
    }

    // Repeatedly merge the two nodes at the top of the queue into a new
    // internal node whose frequency is the sum of theirs, until one remains.
    while (pq.size() > 1) {
        Node *left = pq.top();
        pq.pop();
        Node *right = pq.top();
        pq.pop();

        const int sum = left->getFreq() + right->getFreq();
        pq.push(new Node('\0', sum, left, right));
    }

    // root stores pointer to root of Huffman Tree
    Node *root = pq.top();

    // A lone leaf would otherwise sit at depth 0 and receive a zero-length code.
    const int rootDepth = root->isLeaf() ? 1 : 0;

    SymbolsByLength huffmanLengths;
    get_lengths(root, rootDepth, huffmanLengths);

    initialize_table(huffmanLengths, this->huffmanCodes);

    // NOTE(review): the tree nodes allocated above are never freed. Whether
    // `delete root` alone is sufficient depends on Node's destructor, which is
    // not visible from this file - confirm and release the tree here.
}

// Returns (code length, code) for `c`. A symbol without a code yields a
// value-initialized pair (length 0) and, unlike unordered_map::operator[],
// is NOT inserted into the table, so write_symbol() still rejects it.
std::pair<int, short> HuffmanTable::operator[](const char &c) {
    const auto found = huffmanCodes.find(c);
    if (found == huffmanCodes.end()) {
        return {};
    }
    return found->second;
}

// Writes the code of `c` to `os` via writebits(code, length).
// Throws std::runtime_error if `c` has no code in the table.
void HuffmanTable::write_symbol(obitstream &os, const char &c) {
    const auto found = huffmanCodes.find(c);
    if (found == huffmanCodes.end()) {
        throw std::runtime_error("No code in table for char!");
    }
    os.writebits(found->second.second, found->second.first);
}

// Serializes the code lengths into a newly allocated array of HEADER_SIZE
// bytes, using the layout read by the header constructor (symbol 2*i in the
// high nibble of byte i, symbol 2*i + 1 in the low nibble, 0 = no code).
//
// The caller owns the returned array and must release it with delete[].
// Throws std::runtime_error if any code is longer than 0xf bits.
uint8_t *HuffmanTable::to_header() {
    // Code length of `symbol`, or 0 when the table has no code for it.
    const auto lengthOf = [this](std::size_t symbol) -> int {
        const auto found = huffmanCodes.find(static_cast<char>(symbol));
        if (found == huffmanCodes.end()) {
            return 0;
        }
        const int len = found->second.first;
        if (len > 0xf) {
            throw std::runtime_error("Codes longer than 0xf are not allowed!");
        }
        return len;
    };

    // Zero-initialized so absent symbols read back as "no code"; held by a
    // unique_ptr so the buffer is released if lengthOf throws.
    std::unique_ptr<uint8_t[]> header(new uint8_t[HEADER_SIZE]());

    for (std::size_t i = 0; i < HEADER_SIZE; ++i) {
        const int highLength = lengthOf(2 * i);
        const int lowLength = lengthOf(2 * i + 1);
        header[i] = static_cast<uint8_t>(((highLength & 0xf) << 4) | (lowLength & 0xf));
    }

    return header.release();
}