00001
00002
00003 #include "pch.h"
00004 #include "zinflate.h"
00005
00006 NAMESPACE_BEGIN(CryptoPP)
00007
00008 inline bool LowFirstBitReader::FillBuffer(unsigned int length)
00009 {
00010 while (m_bitsBuffered < length)
00011 {
00012 byte b;
00013 if (!m_store.Get(b))
00014 return false;
00015 m_buffer |= (unsigned long)b << m_bitsBuffered;
00016 m_bitsBuffered += 8;
00017 }
00018 assert(m_bitsBuffered <= sizeof(unsigned long)*8);
00019 return true;
00020 }
00021
00022 inline unsigned long LowFirstBitReader::PeekBits(unsigned int length)
00023 {
00024 bool result = FillBuffer(length);
00025 assert(result);
00026 return m_buffer & (((unsigned long)1 << length) - 1);
00027 }
00028
00029 inline void LowFirstBitReader::SkipBits(unsigned int length)
00030 {
00031 assert(m_bitsBuffered >= length);
00032 m_buffer >>= length;
00033 m_bitsBuffered -= length;
00034 }
00035
00036 inline unsigned long LowFirstBitReader::GetBits(unsigned int length)
00037 {
00038 unsigned long result = PeekBits(length);
00039 SkipBits(length);
00040 return result;
00041 }
00042
00043 inline HuffmanDecoder::code_t HuffmanDecoder::NormalizeCode(HuffmanDecoder::code_t code, unsigned int codeBits)
00044 {
00045 return code << (MAX_CODE_BITS - codeBits);
00046 }
00047
00048 void HuffmanDecoder::Initialize(const unsigned int *codeBits, unsigned int nCodes)
00049 {
00050 if (nCodes == 0)
00051 throw Err("null code");
00052
00053 m_maxCodeBits = *std::max_element(codeBits, codeBits+nCodes);
00054
00055 if (m_maxCodeBits == 0)
00056 throw Err("null code");
00057
00058 SecBlock<unsigned int> blCount(m_maxCodeBits+1);
00059 std::fill(blCount.Begin(), blCount.End(), 0);
00060 unsigned int i;
00061 for (i=0; i<nCodes; i++)
00062 blCount[codeBits[i]]++;
00063
00064 code_t code = 0;
00065 SecBlock<code_t> nextCode(m_maxCodeBits+1);
00066 nextCode[1] = 0;
00067 for (i=2; i<=m_maxCodeBits; i++)
00068 {
00069
00070 if (code > code + blCount[i-1])
00071 throw Err("codes oversubscribed");
00072 code += blCount[i-1];
00073 if (code > (code << 1))
00074 throw Err("codes oversubscribed");
00075 code <<= 1;
00076 nextCode[i] = code;
00077 }
00078
00079 if (code > (1 << m_maxCodeBits) - blCount[m_maxCodeBits])
00080 throw Err("codes oversubscribed");
00081 else if (m_maxCodeBits != 1 && code < (1 << m_maxCodeBits) - blCount[m_maxCodeBits])
00082 throw Err("codes incomplete");
00083
00084 m_codeToValue.Resize(nCodes - blCount[0]);
00085 unsigned int j=0;
00086 for (i=0; i<nCodes; i++)
00087 {
00088 unsigned int len = codeBits[i];
00089 if (len != 0)
00090 {
00091 code = NormalizeCode(nextCode[len]++, len);
00092 m_codeToValue[j].code = code;
00093 m_codeToValue[j].len = len;
00094 m_codeToValue[j].value = i;
00095 j++;
00096 }
00097 }
00098 std::sort(m_codeToValue.Begin(), m_codeToValue.End());
00099
00100 m_cacheBits = STDMIN(9U, m_maxCodeBits);
00101 m_cacheMask = (1 << m_cacheBits) - 1;
00102 code_t leftoverMask = ~NormalizeCode(m_cacheMask, m_cacheBits);
00103 m_cache.Resize(1 << m_cacheBits);
00104
00105 for (i=0; i<m_cache.size; i++)
00106 {
00107 code_t normalizedCode = bitReverse(i);
00108 const CodeInfo &codeInfo = *(std::upper_bound(m_codeToValue.Begin(), m_codeToValue.End(), normalizedCode, CodeLessThan)-1);
00109 if (codeInfo.len <= m_cacheBits)
00110 {
00111 m_cache[i].type = 0;
00112 m_cache[i].value = codeInfo.value;
00113 m_cache[i].len = codeInfo.len;
00114 }
00115 else
00116 {
00117 m_cache[i].begin = &codeInfo;
00118 const CodeInfo *last = std::upper_bound(m_codeToValue.Begin(), m_codeToValue.End(), normalizedCode + leftoverMask, CodeLessThan)-1;
00119 if (codeInfo.len == last->len)
00120 {
00121 m_cache[i].type = 1;
00122 m_cache[i].len = codeInfo.len;
00123 }
00124 else
00125 {
00126 m_cache[i].type = 2;
00127 m_cache[i].end = last+1;
00128 }
00129 }
00130 }
00131 }
00132
00133 inline unsigned int HuffmanDecoder::Decode(code_t code, value_t &value) const
00134 {
00135 assert(m_codeToValue.size > 0);
00136 const LookupEntry &entry = m_cache[code & m_cacheMask];
00137 if (entry.type == 0)
00138 {
00139 value = entry.value;
00140 return entry.len;
00141 }
00142 else
00143 {
00144 code_t normalizedCode = bitReverse(code);
00145 const CodeInfo &codeInfo = (entry.type == 1)
00146 ? entry.begin[(normalizedCode << m_cacheBits) >> (MAX_CODE_BITS - (entry.len - m_cacheBits))]
00147 : *(std::upper_bound(entry.begin, entry.end, normalizedCode, CodeLessThan)-1);
00148 value = codeInfo.value;
00149 return codeInfo.len;
00150 }
00151 }
00152
00153 bool HuffmanDecoder::Decode(LowFirstBitReader &reader, value_t &value) const
00154 {
00155 reader.FillBuffer(m_maxCodeBits);
00156 unsigned int codeBits = Decode(reader.PeekBuffer(), value);
00157 if (codeBits > reader.BitsBuffered())
00158 return false;
00159 reader.SkipBits(codeBits);
00160 return true;
00161 }
00162
00163
00164
00165 Inflator::Inflator(BufferedTransformation *outQueue, bool repeat)
00166 : Filter(outQueue), m_repeat(repeat), m_decodersInitializedWithFixedCodes(false)
00167 , m_state(PRE_STREAM), m_reader(m_inQueue)
00168 {
00169 }
00170
00171 inline void Inflator::OutputByte(byte b)
00172 {
00173 m_window[m_current++] = b;
00174 if (m_current == m_window.size)
00175 {
00176 ProcessDecompressedData(m_window + m_lastFlush, m_window.size - m_lastFlush);
00177 m_lastFlush = 0;
00178 m_current = 0;
00179 }
00180 if (m_maxDistance < m_window.size)
00181 m_maxDistance++;
00182 }
00183
00184 void Inflator::OutputString(const byte *string, unsigned int length)
00185 {
00186 while (length--)
00187 OutputByte(*string++);
00188 }
00189
00190 void Inflator::OutputPast(unsigned int length, unsigned int distance)
00191 {
00192 if (distance > m_maxDistance)
00193 throw BadBlockErr();
00194 unsigned int start;
00195 if (m_current > distance)
00196 start = m_current - distance;
00197 else
00198 start = m_current + m_window.size - distance;
00199
00200 if (start + length > m_window.size)
00201 {
00202 for (; start < m_window.size; start++, length--)
00203 OutputByte(m_window[start]);
00204 start = 0;
00205 }
00206
00207 if (start + length > m_current || m_current + length >= m_window.size)
00208 {
00209 while (length--)
00210 OutputByte(m_window[start++]);
00211 }
00212 else
00213 {
00214 memcpy(m_window + m_current, m_window + start, length);
00215 m_current += length;
00216 m_maxDistance = STDMIN(m_window.size, m_maxDistance + length);
00217 }
00218 }
00219
00220 void Inflator::Put(const byte *inString, unsigned int length)
00221 {
00222 LazyPutter lp(m_inQueue, inString, length);
00223 ProcessInput(false);
00224 }
00225
00226 void Inflator::Flush(bool completeFlush, int propagation)
00227 {
00228 if (completeFlush)
00229 ProcessInput(true);
00230 FlushOutput();
00231 Filter::Flush(completeFlush, propagation);
00232 }
00233
00234 void Inflator::MessageEnd(int propagation)
00235 {
00236 ProcessInput(true);
00237 if (!(m_state == PRE_STREAM || m_state == AFTER_END))
00238 throw UnexpectedEndErr();
00239 Filter::MessageEnd(propagation);
00240 }
00241
00242 void Inflator::ProcessInput(bool flush)
00243 {
00244 while (1)
00245 {
00246 switch (m_state)
00247 {
00248 case PRE_STREAM:
00249 if (!flush && m_inQueue.CurrentSize() < MaxPrestreamHeaderSize())
00250 return;
00251 ProcessPrestreamHeader();
00252 m_state = WAIT_HEADER;
00253 m_maxDistance = 0;
00254 m_current = 0;
00255 m_lastFlush = 0;
00256 m_window.Resize(1 << GetLog2WindowSize());
00257 break;
00258 case WAIT_HEADER:
00259 {
00260
00261 const unsigned int MAX_HEADER_SIZE = bitsToBytes(3+5+5+4+19*7+286*15+19*15);
00262 if (m_inQueue.CurrentSize() < (flush ? 1 : MAX_HEADER_SIZE))
00263 return;
00264 DecodeHeader();
00265 break;
00266 }
00267 case DECODING_BODY:
00268 if (!DecodeBody())
00269 return;
00270 break;
00271 case POST_STREAM:
00272 if (!flush && m_inQueue.CurrentSize() < MaxPoststreamTailSize())
00273 return;
00274 ProcessPoststreamTail();
00275 m_state = m_repeat ? PRE_STREAM : AFTER_END;
00276 Filter::MessageEnd(GetAutoSignalPropagation());
00277 break;
00278 case AFTER_END:
00279 m_inQueue.TransferTo(*AttachedTransformation());
00280 return;
00281 }
00282 }
00283 }
00284
00285 void Inflator::DecodeHeader()
00286 {
00287 if (!m_reader.FillBuffer(3))
00288 throw UnexpectedEndErr();
00289 m_eof = m_reader.GetBits(1);
00290 m_blockType = m_reader.GetBits(2);
00291 switch (m_blockType)
00292 {
00293 case 0:
00294 {
00295 m_reader.SkipBits(m_reader.BitsBuffered() % 8);
00296 if (!m_reader.FillBuffer(32))
00297 throw UnexpectedEndErr();
00298 m_storedLen = m_reader.GetBits(16);
00299 word16 nlen = m_reader.GetBits(16);
00300 if (nlen != (word16)~m_storedLen)
00301 throw BadBlockErr();
00302 break;
00303 }
00304 case 1:
00305 if (!m_decodersInitializedWithFixedCodes)
00306 {
00307 unsigned int codeLengths[288];
00308 std::fill(codeLengths + 0, codeLengths + 144, 8);
00309 std::fill(codeLengths + 144, codeLengths + 256, 9);
00310 std::fill(codeLengths + 256, codeLengths + 280, 7);
00311 std::fill(codeLengths + 280, codeLengths + 288, 8);
00312 m_literalDecoder.Initialize(codeLengths, 288);
00313 std::fill(codeLengths + 0, codeLengths + 32, 5);
00314 m_distanceDecoder.Initialize(codeLengths, 32);
00315 m_decodersInitializedWithFixedCodes = true;
00316 }
00317 m_nextDecode = LITERAL;
00318 break;
00319 case 2:
00320 {
00321 m_decodersInitializedWithFixedCodes = false;
00322 if (!m_reader.FillBuffer(5+5+4))
00323 throw UnexpectedEndErr();
00324 unsigned int hlit = m_reader.GetBits(5);
00325 unsigned int hdist = m_reader.GetBits(5);
00326 unsigned int hclen = m_reader.GetBits(4);
00327
00328 SecBlock<unsigned int> codeLengths(286+32);
00329 unsigned int i;
00330 static const unsigned int border[] = {
00331 16, 17, 18, 0, 8, 7, 9, 6, 10, 5, 11, 4, 12, 3, 13, 2, 14, 1, 15};
00332 std::fill(codeLengths.ptr, codeLengths+19, 0);
00333 for (i=0; i<hclen+4; i++)
00334 codeLengths[border[i]] = m_reader.GetBits(3);
00335
00336 try
00337 {
00338 HuffmanDecoder codeLengthDecoder(codeLengths, 19);
00339 for (i = 0; i < hlit+257+hdist+1; )
00340 {
00341 unsigned int k, count, repeater;
00342 bool result = codeLengthDecoder.Decode(m_reader, k);
00343 if (!result)
00344 throw UnexpectedEndErr();
00345 if (k <= 15)
00346 {
00347 count = 1;
00348 repeater = k;
00349 }
00350 else switch (k)
00351 {
00352 case 16:
00353 if (!m_reader.FillBuffer(2))
00354 throw UnexpectedEndErr();
00355 count = 3 + m_reader.GetBits(2);
00356 if (i == 0)
00357 throw BadBlockErr();
00358 repeater = codeLengths[i-1];
00359 break;
00360 case 17:
00361 if (!m_reader.FillBuffer(3))
00362 throw UnexpectedEndErr();
00363 count = 3 + m_reader.GetBits(3);
00364 repeater = 0;
00365 break;
00366 case 18:
00367 if (!m_reader.FillBuffer(7))
00368 throw UnexpectedEndErr();
00369 count = 11 + m_reader.GetBits(7);
00370 repeater = 0;
00371 break;
00372 }
00373 if (i + count > hlit+257+hdist+1)
00374 throw BadBlockErr();
00375 std::fill(codeLengths + i, codeLengths + i + count, repeater);
00376 i += count;
00377 }
00378 m_literalDecoder.Initialize(codeLengths, hlit+257);
00379 if (hdist == 0 && codeLengths[hlit+257] == 0)
00380 {
00381 if (hlit != 0)
00382 throw BadBlockErr();
00383 }
00384 else
00385 m_distanceDecoder.Initialize(codeLengths+hlit+257, hdist+1);
00386 m_nextDecode = LITERAL;
00387 }
00388 catch (HuffmanDecoder::Err &)
00389 {
00390 throw BadBlockErr();
00391 }
00392 break;
00393 }
00394 default:
00395 throw BadBlockErr();
00396 }
00397 m_state = DECODING_BODY;
00398 }
00399
00400 bool Inflator::DecodeBody()
00401 {
00402 bool blockEnd = false;
00403 switch (m_blockType)
00404 {
00405 case 0:
00406 assert(m_reader.BitsBuffered() == 0);
00407 while (!m_inQueue.IsEmpty() && !blockEnd)
00408 {
00409 unsigned int size;
00410 const byte *block = m_inQueue.Spy(size);
00411 size = STDMIN(size, (unsigned int)m_storedLen);
00412 OutputString(block, size);
00413 m_inQueue.Skip(size);
00414 m_storedLen -= size;
00415 if (m_storedLen == 0)
00416 blockEnd = true;
00417 }
00418 break;
00419 case 1:
00420 case 2:
00421 static const unsigned int lengthStarts[] = {
00422 3, 4, 5, 6, 7, 8, 9, 10, 11, 13, 15, 17, 19, 23, 27, 31,
00423 35, 43, 51, 59, 67, 83, 99, 115, 131, 163, 195, 227, 258};
00424 static const unsigned int lengthExtraBits[] = {
00425 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2,
00426 3, 3, 3, 3, 4, 4, 4, 4, 5, 5, 5, 5, 0};
00427 static const unsigned int distanceStarts[] = {
00428 1, 2, 3, 4, 5, 7, 9, 13, 17, 25, 33, 49, 65, 97, 129, 193,
00429 257, 385, 513, 769, 1025, 1537, 2049, 3073, 4097, 6145,
00430 8193, 12289, 16385, 24577};
00431 static const unsigned int distanceExtraBits[] = {
00432 0, 0, 0, 0, 1, 1, 2, 2, 3, 3, 4, 4, 5, 5, 6, 6,
00433 7, 7, 8, 8, 9, 9, 10, 10, 11, 11,
00434 12, 12, 13, 13};
00435
00436 switch (m_nextDecode)
00437 {
00438 while (true)
00439 {
00440 case LITERAL:
00441 if (!m_literalDecoder.Decode(m_reader, m_literal))
00442 {
00443 m_nextDecode = LITERAL;
00444 break;
00445 }
00446 if (m_literal < 256)
00447 OutputByte((byte)m_literal);
00448 else if (m_literal == 256)
00449 {
00450 blockEnd = true;
00451 break;
00452 }
00453 else
00454 {
00455 if (m_literal > 285)
00456 throw BadBlockErr();
00457 unsigned int bits;
00458 case LENGTH_BITS:
00459 bits = lengthExtraBits[m_literal-257];
00460 if (!m_reader.FillBuffer(bits))
00461 {
00462 m_nextDecode = LENGTH_BITS;
00463 break;
00464 }
00465 m_literal = m_reader.GetBits(bits) + lengthStarts[m_literal-257];
00466 case DISTANCE:
00467 if (!m_distanceDecoder.Decode(m_reader, m_distance))
00468 {
00469 m_nextDecode = DISTANCE;
00470 break;
00471 }
00472 case DISTANCE_BITS:
00473 bits = distanceExtraBits[m_distance];
00474 if (!m_reader.FillBuffer(bits))
00475 {
00476 m_nextDecode = DISTANCE_BITS;
00477 break;
00478 }
00479 m_distance = m_reader.GetBits(bits) + distanceStarts[m_distance];
00480 OutputPast(m_literal, m_distance);
00481 }
00482 }
00483 }
00484 }
00485 if (blockEnd)
00486 {
00487 if (m_eof)
00488 {
00489 FlushOutput();
00490 m_reader.SkipBits(m_reader.BitsBuffered()%8);
00491 if (m_reader.BitsBuffered())
00492 {
00493
00494 SecByteBlock buffer(m_reader.BitsBuffered() / 8);
00495 for (unsigned int i=0; i<buffer.size; i++)
00496 buffer[i] = m_reader.GetBits(8);
00497 m_inQueue.Unget(buffer, buffer.size);
00498 }
00499 m_state = POST_STREAM;
00500 }
00501 else
00502 m_state = WAIT_HEADER;
00503 }
00504 return blockEnd;
00505 }
00506
00507 void Inflator::FlushOutput()
00508 {
00509 assert(m_current >= m_lastFlush);
00510 ProcessDecompressedData(m_window + m_lastFlush, m_current - m_lastFlush);
00511 m_lastFlush = m_current;
00512 }
00513
00514 NAMESPACE_END