Main Page   Class Hierarchy   Alphabetical List   Compound List   File List   Compound Members   File Members  

zinflate.cpp

00001 // zinflate.cpp - written and placed in the public domain by Wei Dai
00002 
00003 #include "pch.h"
00004 #include "zinflate.h"
00005 
00006 NAMESPACE_BEGIN(CryptoPP)
00007 
00008 inline bool LowFirstBitReader::FillBuffer(unsigned int length)
00009 {
00010         while (m_bitsBuffered < length)
00011         {
00012                 byte b;
00013                 if (!m_store.Get(b))
00014                         return false;
00015                 m_buffer |= (unsigned long)b << m_bitsBuffered;
00016                 m_bitsBuffered += 8;
00017         }
00018         assert(m_bitsBuffered <= sizeof(unsigned long)*8);
00019         return true;
00020 }
00021 
00022 inline unsigned long LowFirstBitReader::PeekBits(unsigned int length)
00023 {
00024         bool result = FillBuffer(length);
00025         assert(result);
00026         return m_buffer & (((unsigned long)1 << length) - 1);
00027 }
00028 
00029 inline void LowFirstBitReader::SkipBits(unsigned int length)
00030 {
00031         assert(m_bitsBuffered >= length);
00032         m_buffer >>= length;
00033         m_bitsBuffered -= length;
00034 }
00035 
00036 inline unsigned long LowFirstBitReader::GetBits(unsigned int length)
00037 {
00038         unsigned long result = PeekBits(length);
00039         SkipBits(length);
00040         return result;
00041 }
00042 
00043 inline HuffmanDecoder::code_t HuffmanDecoder::NormalizeCode(HuffmanDecoder::code_t code, unsigned int codeBits)
00044 {
00045         return code << (MAX_CODE_BITS - codeBits);
00046 }
00047 
00048 void HuffmanDecoder::Initialize(const unsigned int *codeBits, unsigned int nCodes)
00049 {
00050         if (nCodes == 0)
00051                 throw Err("null code");
00052 
00053         m_maxCodeBits = *std::max_element(codeBits, codeBits+nCodes);
00054 
00055         if (m_maxCodeBits == 0)
00056                 throw Err("null code");
00057 
00058         SecBlock<unsigned int> blCount(m_maxCodeBits+1);
00059         std::fill(blCount.Begin(), blCount.End(), 0);
00060         unsigned int i;
00061         for (i=0; i<nCodes; i++)
00062                 blCount[codeBits[i]]++;
00063 
00064         code_t code = 0;
00065         SecBlock<code_t> nextCode(m_maxCodeBits+1);
00066         nextCode[1] = 0;
00067         for (i=2; i<=m_maxCodeBits; i++)
00068         {
00069                 // compute this while checking for overflow: code = (code + blCount[i-1]) << 1
00070                 if (code > code + blCount[i-1])
00071                         throw Err("codes oversubscribed");
00072                 code += blCount[i-1];
00073                 if (code > (code << 1))
00074                         throw Err("codes oversubscribed");
00075                 code <<= 1;
00076                 nextCode[i] = code;
00077         }
00078 
00079         if (code > (1 << m_maxCodeBits) - blCount[m_maxCodeBits])
00080                 throw Err("codes oversubscribed");
00081         else if (m_maxCodeBits != 1 && code < (1 << m_maxCodeBits) - blCount[m_maxCodeBits])
00082                 throw Err("codes incomplete");
00083 
00084         m_codeToValue.Resize(nCodes - blCount[0]);
00085         unsigned int j=0;
00086         for (i=0; i<nCodes; i++) 
00087         {
00088                 unsigned int len = codeBits[i];
00089                 if (len != 0)
00090                 {
00091                         code = NormalizeCode(nextCode[len]++, len);
00092                         m_codeToValue[j].code = code;
00093                         m_codeToValue[j].len = len;
00094                         m_codeToValue[j].value = i;
00095                         j++;
00096                 }
00097         }
00098         std::sort(m_codeToValue.Begin(), m_codeToValue.End());
00099 
00100         m_cacheBits = STDMIN(9U, m_maxCodeBits);
00101         m_cacheMask = (1 << m_cacheBits) - 1;
00102         code_t leftoverMask = ~NormalizeCode(m_cacheMask, m_cacheBits);
00103         m_cache.Resize(1 << m_cacheBits);
00104 
00105         for (i=0; i<m_cache.size; i++)
00106         {
00107                 code_t normalizedCode = bitReverse(i);
00108                 const CodeInfo &codeInfo = *(std::upper_bound(m_codeToValue.Begin(), m_codeToValue.End(), normalizedCode, CodeLessThan)-1);
00109                 if (codeInfo.len <= m_cacheBits)
00110                 {
00111                         m_cache[i].type = 0;
00112                         m_cache[i].value = codeInfo.value;
00113                         m_cache[i].len = codeInfo.len;
00114                 }
00115                 else
00116                 {
00117                         m_cache[i].begin = &codeInfo;
00118                         const CodeInfo *last = std::upper_bound(m_codeToValue.Begin(), m_codeToValue.End(), normalizedCode + leftoverMask, CodeLessThan)-1;
00119                         if (codeInfo.len == last->len)
00120                         {
00121                                 m_cache[i].type = 1;
00122                                 m_cache[i].len = codeInfo.len;
00123                         }
00124                         else
00125                         {
00126                                 m_cache[i].type = 2;
00127                                 m_cache[i].end = last+1;
00128                         }
00129                 }
00130         }
00131 }
00132 
00133 inline unsigned int HuffmanDecoder::Decode(code_t code, /* out */ value_t &value) const
00134 {
00135         assert(m_codeToValue.size > 0);
00136         const LookupEntry &entry = m_cache[code & m_cacheMask];
00137         if (entry.type == 0)
00138         {
00139                 value = entry.value;
00140                 return entry.len;
00141         }
00142         else
00143         {
00144                 code_t normalizedCode = bitReverse(code);
00145                 const CodeInfo &codeInfo = (entry.type == 1)
00146                         ? entry.begin[(normalizedCode << m_cacheBits) >> (MAX_CODE_BITS - (entry.len - m_cacheBits))]
00147                         : *(std::upper_bound(entry.begin, entry.end, normalizedCode, CodeLessThan)-1);
00148                 value = codeInfo.value;
00149                 return codeInfo.len;
00150         }
00151 }
00152 
00153 bool HuffmanDecoder::Decode(LowFirstBitReader &reader, value_t &value) const
00154 {
00155         reader.FillBuffer(m_maxCodeBits);
00156         unsigned int codeBits = Decode(reader.PeekBuffer(), value);
00157         if (codeBits > reader.BitsBuffered())
00158                 return false;
00159         reader.SkipBits(codeBits);
00160         return true;
00161 }
00162 
00163 // *************************************************************
00164 
00165 Inflator::Inflator(BufferedTransformation *outQueue, bool repeat)
00166         : Filter(outQueue), m_repeat(repeat), m_decodersInitializedWithFixedCodes(false)
00167         , m_state(PRE_STREAM), m_reader(m_inQueue)
00168 {
00169 }
00170 
00171 inline void Inflator::OutputByte(byte b)
00172 {
00173         m_window[m_current++] = b;
00174         if (m_current == m_window.size)
00175         {
00176                 ProcessDecompressedData(m_window + m_lastFlush, m_window.size - m_lastFlush);
00177                 m_lastFlush = 0;
00178                 m_current = 0;
00179         }
00180         if (m_maxDistance < m_window.size)
00181                 m_maxDistance++;
00182 }
00183 
00184 void Inflator::OutputString(const byte *string, unsigned int length)
00185 {
00186         while (length--)
00187                 OutputByte(*string++);
00188 }
00189 
00190 void Inflator::OutputPast(unsigned int length, unsigned int distance)
00191 {
00192         if (distance > m_maxDistance)
00193                 throw BadBlockErr();
00194         unsigned int start;
00195         if (m_current > distance)
00196                 start = m_current - distance;
00197         else
00198                 start = m_current + m_window.size - distance;
00199 
00200         if (start + length > m_window.size)
00201         {
00202                 for (; start < m_window.size; start++, length--)
00203                         OutputByte(m_window[start]);
00204                 start = 0;
00205         }
00206 
00207         if (start + length > m_current || m_current + length >= m_window.size)
00208         {
00209                 while (length--)
00210                         OutputByte(m_window[start++]);
00211         }
00212         else
00213         {
00214                 memcpy(m_window + m_current, m_window + start, length);
00215                 m_current += length;
00216                 m_maxDistance = STDMIN(m_window.size, m_maxDistance + length);
00217         }
00218 }
00219 
00220 void Inflator::Put(const byte *inString, unsigned int length)
00221 {
00222         LazyPutter lp(m_inQueue, inString, length);
00223         ProcessInput(false);
00224 }
00225 
00226 void Inflator::Flush(bool completeFlush, int propagation)
00227 {
00228         if (completeFlush)
00229                 ProcessInput(true);
00230         FlushOutput();
00231         Filter::Flush(completeFlush, propagation);
00232 }
00233 
00234 void Inflator::MessageEnd(int propagation)
00235 {
00236         ProcessInput(true);
00237         if (!(m_state == PRE_STREAM || m_state == AFTER_END))
00238                 throw UnexpectedEndErr();
00239         Filter::MessageEnd(propagation);
00240 }
00241 
00242 void Inflator::ProcessInput(bool flush)
00243 {
00244         while (1)
00245         {
00246                 switch (m_state)
00247                 {
00248                 case PRE_STREAM:
00249                         if (!flush && m_inQueue.CurrentSize() < MaxPrestreamHeaderSize())
00250                                 return;
00251                         ProcessPrestreamHeader();
00252                         m_state = WAIT_HEADER;
00253                         m_maxDistance = 0;
00254                         m_current = 0;
00255                         m_lastFlush = 0;
00256                         m_window.Resize(1 << GetLog2WindowSize());
00257                         break;
00258                 case WAIT_HEADER:
00259                         {
00260                         // maximum number of bytes before actual compressed data starts
00261                         const unsigned int MAX_HEADER_SIZE = bitsToBytes(3+5+5+4+19*7+286*15+19*15);
00262                         if (m_inQueue.CurrentSize() < (flush ? 1 : MAX_HEADER_SIZE))
00263                                 return;
00264                         DecodeHeader();
00265                         break;
00266                         }
00267                 case DECODING_BODY:
00268                         if (!DecodeBody())
00269                                 return;
00270                         break;
00271                 case POST_STREAM:
00272                         if (!flush && m_inQueue.CurrentSize() < MaxPoststreamTailSize())
00273                                 return;
00274                         ProcessPoststreamTail();
00275                         m_state = m_repeat ? PRE_STREAM : AFTER_END;
00276                         Filter::MessageEnd(GetAutoSignalPropagation());
00277                         break;
00278                 case AFTER_END:
00279                         m_inQueue.TransferTo(*AttachedTransformation());
00280                         return;
00281                 }
00282         }
00283 }
00284 
00285 void Inflator::DecodeHeader()
00286 {
00287         if (!m_reader.FillBuffer(3))
00288                 throw UnexpectedEndErr();
00289         m_eof = m_reader.GetBits(1);
00290         m_blockType = m_reader.GetBits(2);
00291         switch (m_blockType)
00292         {
00293         case 0: // stored
00294                 {
00295                 m_reader.SkipBits(m_reader.BitsBuffered() % 8);
00296                 if (!m_reader.FillBuffer(32))
00297                         throw UnexpectedEndErr();
00298                 m_storedLen = m_reader.GetBits(16);
00299                 word16 nlen = m_reader.GetBits(16);
00300                 if (nlen != (word16)~m_storedLen)
00301                         throw BadBlockErr();
00302                 break;
00303                 }
00304         case 1: // fixed codes
00305                 if (!m_decodersInitializedWithFixedCodes)
00306                 {
00307                         unsigned int codeLengths[288];
00308                         std::fill(codeLengths + 0, codeLengths + 144, 8);
00309                         std::fill(codeLengths + 144, codeLengths + 256, 9);
00310                         std::fill(codeLengths + 256, codeLengths + 280, 7);
00311                         std::fill(codeLengths + 280, codeLengths + 288, 8);
00312                         m_literalDecoder.Initialize(codeLengths, 288);
00313                         std::fill(codeLengths + 0, codeLengths + 32, 5);
00314                         m_distanceDecoder.Initialize(codeLengths, 32);
00315                         m_decodersInitializedWithFixedCodes = true;
00316                 }
00317                 m_nextDecode = LITERAL;
00318                 break;
00319         case 2: // dynamic codes
00320                 {
00321                 m_decodersInitializedWithFixedCodes = false;
00322                 if (!m_reader.FillBuffer(5+5+4))
00323                         throw UnexpectedEndErr();
00324                 unsigned int hlit = m_reader.GetBits(5);
00325                 unsigned int hdist = m_reader.GetBits(5);
00326                 unsigned int hclen = m_reader.GetBits(4);
00327 
00328                 SecBlock<unsigned int> codeLengths(286+32);
00329                 unsigned int i;
00330                 static const unsigned int border[] = {    // Order of the bit length code lengths
00331                         16, 17, 18, 0, 8, 7, 9, 6, 10, 5, 11, 4, 12, 3, 13, 2, 14, 1, 15};
00332                 std::fill(codeLengths.ptr, codeLengths+19, 0);
00333                 for (i=0; i<hclen+4; i++)
00334                         codeLengths[border[i]] = m_reader.GetBits(3);
00335 
00336                 try
00337                 {
00338                         HuffmanDecoder codeLengthDecoder(codeLengths, 19);
00339                         for (i = 0; i < hlit+257+hdist+1; )
00340                         {
00341                                 unsigned int k, count, repeater;
00342                                 bool result = codeLengthDecoder.Decode(m_reader, k);
00343                                 if (!result)
00344                                         throw UnexpectedEndErr();
00345                                 if (k <= 15)
00346                                 {
00347                                         count = 1;
00348                                         repeater = k;
00349                                 }
00350                                 else switch (k)
00351                                 {
00352                                 case 16:
00353                                         if (!m_reader.FillBuffer(2))
00354                                                 throw UnexpectedEndErr();
00355                                         count = 3 + m_reader.GetBits(2);
00356                                         if (i == 0)
00357                                                 throw BadBlockErr();
00358                                         repeater = codeLengths[i-1];
00359                                         break;
00360                                 case 17:
00361                                         if (!m_reader.FillBuffer(3))
00362                                                 throw UnexpectedEndErr();
00363                                         count = 3 + m_reader.GetBits(3);
00364                                         repeater = 0;
00365                                         break;
00366                                 case 18:
00367                                         if (!m_reader.FillBuffer(7))
00368                                                 throw UnexpectedEndErr();
00369                                         count = 11 + m_reader.GetBits(7);
00370                                         repeater = 0;
00371                                         break;
00372                                 }
00373                                 if (i + count > hlit+257+hdist+1)
00374                                         throw BadBlockErr();
00375                                 std::fill(codeLengths + i, codeLengths + i + count, repeater);
00376                                 i += count;
00377                         }
00378                         m_literalDecoder.Initialize(codeLengths, hlit+257);
00379                         if (hdist == 0 && codeLengths[hlit+257] == 0)
00380                         {
00381                                 if (hlit != 0)  // a single zero distance code length means all literals
00382                                         throw BadBlockErr();
00383                         }
00384                         else
00385                                 m_distanceDecoder.Initialize(codeLengths+hlit+257, hdist+1);
00386                         m_nextDecode = LITERAL;
00387                 }
00388                 catch (HuffmanDecoder::Err &)
00389                 {
00390                         throw BadBlockErr();
00391                 }
00392                 break;
00393                 }
00394         default:
00395                 throw BadBlockErr();    // reserved block type
00396         }
00397         m_state = DECODING_BODY;
00398 }
00399 
00400 bool Inflator::DecodeBody()
00401 {
00402         bool blockEnd = false;
00403         switch (m_blockType)
00404         {
00405         case 0: // stored
00406                 assert(m_reader.BitsBuffered() == 0);
00407                 while (!m_inQueue.IsEmpty() && !blockEnd)
00408                 {
00409                         unsigned int size;
00410                         const byte *block = m_inQueue.Spy(size);
00411                         size = STDMIN(size, (unsigned int)m_storedLen);
00412                         OutputString(block, size);
00413                         m_inQueue.Skip(size);
00414                         m_storedLen -= size;
00415                         if (m_storedLen == 0)
00416                                 blockEnd = true;
00417                 }
00418                 break;
00419         case 1: // fixed codes
00420         case 2: // dynamic codes
00421                 static const unsigned int lengthStarts[] = {
00422                         3, 4, 5, 6, 7, 8, 9, 10, 11, 13, 15, 17, 19, 23, 27, 31,
00423                         35, 43, 51, 59, 67, 83, 99, 115, 131, 163, 195, 227, 258};
00424                 static const unsigned int lengthExtraBits[] = {
00425                         0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2,
00426                         3, 3, 3, 3, 4, 4, 4, 4, 5, 5, 5, 5, 0};
00427                 static const unsigned int distanceStarts[] = {
00428                         1, 2, 3, 4, 5, 7, 9, 13, 17, 25, 33, 49, 65, 97, 129, 193,
00429                         257, 385, 513, 769, 1025, 1537, 2049, 3073, 4097, 6145,
00430                         8193, 12289, 16385, 24577};
00431                 static const unsigned int distanceExtraBits[] = {
00432                         0, 0, 0, 0, 1, 1, 2, 2, 3, 3, 4, 4, 5, 5, 6, 6,
00433                         7, 7, 8, 8, 9, 9, 10, 10, 11, 11,
00434                         12, 12, 13, 13};
00435 
00436                 switch (m_nextDecode)
00437                 {
00438                         while (true)
00439                         {
00440                 case LITERAL:
00441                                 if (!m_literalDecoder.Decode(m_reader, m_literal))
00442                                 {
00443                                         m_nextDecode = LITERAL;
00444                                         break;
00445                                 }
00446                                 if (m_literal < 256)
00447                                         OutputByte((byte)m_literal);
00448                                 else if (m_literal == 256)      // end of block
00449                                 {
00450                                         blockEnd = true;
00451                                         break;
00452                                 }
00453                                 else
00454                                 {
00455                                         if (m_literal > 285)
00456                                                 throw BadBlockErr();
00457                                         unsigned int bits;
00458                 case LENGTH_BITS:
00459                                         bits = lengthExtraBits[m_literal-257];
00460                                         if (!m_reader.FillBuffer(bits))
00461                                         {
00462                                                 m_nextDecode = LENGTH_BITS;
00463                                                 break;
00464                                         }
00465                                         m_literal = m_reader.GetBits(bits) + lengthStarts[m_literal-257];
00466                 case DISTANCE:
00467                                         if (!m_distanceDecoder.Decode(m_reader, m_distance))
00468                                         {
00469                                                 m_nextDecode = DISTANCE;
00470                                                 break;
00471                                         }
00472                 case DISTANCE_BITS:
00473                                         bits = distanceExtraBits[m_distance];
00474                                         if (!m_reader.FillBuffer(bits))
00475                                         {
00476                                                 m_nextDecode = DISTANCE_BITS;
00477                                                 break;
00478                                         }
00479                                         m_distance = m_reader.GetBits(bits) + distanceStarts[m_distance];
00480                                         OutputPast(m_literal, m_distance);
00481                                 }
00482                         }
00483                 }
00484         }
00485         if (blockEnd)
00486         {
00487                 if (m_eof)
00488                 {
00489                         FlushOutput();
00490                         m_reader.SkipBits(m_reader.BitsBuffered()%8);
00491                         if (m_reader.BitsBuffered())
00492                         {
00493                                 // undo too much lookahead
00494                                 SecByteBlock buffer(m_reader.BitsBuffered() / 8);
00495                                 for (unsigned int i=0; i<buffer.size; i++)
00496                                         buffer[i] = m_reader.GetBits(8);
00497                                 m_inQueue.Unget(buffer, buffer.size);
00498                         }
00499                         m_state = POST_STREAM;
00500                 }
00501                 else
00502                         m_state = WAIT_HEADER;
00503         }
00504         return blockEnd;
00505 }
00506 
00507 void Inflator::FlushOutput()
00508 {
00509         assert(m_current >= m_lastFlush);
00510         ProcessDecompressedData(m_window + m_lastFlush, m_current - m_lastFlush);
00511         m_lastFlush = m_current;
00512 }
00513 
00514 NAMESPACE_END

Generated at Mon Jan 15 01:16:38 2001 for Crypto++ by doxygen 1.2.4 written by Dimitri van Heesch, © 1997-2000