00001
00002
00003 #include "pch.h"
00004 #include "filters.h"
00005 #include "mqueue.h"
00006 #include <memory>
00007
00008 NAMESPACE_BEGIN(CryptoPP)
00009
00010 BitBucket g_bitBucket; // namespace-scope BitBucket instance; NOTE(review): presumably a shared sink that discards its input - confirm against BitBucket's declaration
00011
00012 Filter::Filter(BufferedTransformation *outQ)
00013 : m_outQueue(outQ ? outQ : new MessageQueue)
00014 {
00015 }
00016
00017 void Filter::Detach(BufferedTransformation *newOut)
00018 {
00019 m_outQueue.reset(newOut ? newOut : new MessageQueue);
00020 NotifyAttachmentChange();
00021 }
00022
00023 void Filter::Insert(Filter *filter)
00024 {
00025 filter->m_outQueue.reset(m_outQueue.release());
00026 m_outQueue.reset(filter);
00027 NotifyAttachmentChange();
00028 }
00029
00030
00031
00032 FilterWithBufferedInput::BlockQueue::BlockQueue(unsigned int blockSize, unsigned int maxBlocks)
00033 : m_buffer(blockSize * maxBlocks)
00034 {
00035 ResetQueue(blockSize, maxBlocks);
00036 }
00037
00038 void FilterWithBufferedInput::BlockQueue::ResetQueue(unsigned int blockSize, unsigned int maxBlocks)
00039 {
00040 m_buffer.Resize(blockSize * maxBlocks);
00041 m_blockSize = blockSize;
00042 m_maxBlocks = maxBlocks;
00043 m_size = 0;
00044 m_begin = m_buffer;
00045 }
00046
00047 const byte *FilterWithBufferedInput::BlockQueue::GetBlock()
00048 {
00049 if (m_size >= m_blockSize)
00050 {
00051 const byte *ptr = m_begin;
00052 if ((m_begin+=m_blockSize) == m_buffer.End())
00053 m_begin = m_buffer;
00054 m_size -= m_blockSize;
00055 return ptr;
00056 }
00057 else
00058 return NULL;
00059 }
00060
00061 const byte *FilterWithBufferedInput::BlockQueue::GetContigousBlocks(unsigned int &numberOfBlocks)
00062 {
00063 numberOfBlocks = STDMIN(numberOfBlocks, STDMIN((unsigned int)(m_buffer.End()-m_begin), m_size)/m_blockSize);
00064 const byte *ptr = m_begin;
00065 if ((m_begin+=m_blockSize*numberOfBlocks) == m_buffer.End())
00066 m_begin = m_buffer;
00067 m_size -= m_blockSize*numberOfBlocks;
00068 return ptr;
00069 }
00070
00071 unsigned int FilterWithBufferedInput::BlockQueue::GetAll(byte *outString)
00072 {
00073 unsigned int size = m_size;
00074 unsigned int numberOfBlocks = m_maxBlocks;
00075 const byte *ptr = GetContigousBlocks(numberOfBlocks);
00076 memcpy(outString, ptr, numberOfBlocks*m_blockSize);
00077 memcpy(outString+numberOfBlocks*m_blockSize, m_begin, m_size);
00078 m_size = 0;
00079 return size;
00080 }
00081
00082 void FilterWithBufferedInput::BlockQueue::Put(const byte *inString, unsigned int length)
00083 {
00084 assert(m_size + length <= m_buffer.size);
00085 byte *end = (m_size < m_buffer+m_buffer.size-m_begin) ? m_begin + m_size : m_begin + m_size - m_buffer.size;
00086 unsigned int len = STDMIN(length, (unsigned int)(m_buffer+m_buffer.size-end));
00087 memcpy(end, inString, len);
00088 if (len < length)
00089 memcpy(m_buffer, inString+len, length-len);
00090 m_size += length;
00091 }
00092
00093 FilterWithBufferedInput::FilterWithBufferedInput(unsigned int firstSize, unsigned int blockSize, unsigned int lastSize, BufferedTransformation *outQ)
00094 : Filter(outQ), m_firstSize(firstSize), m_blockSize(blockSize), m_lastSize(lastSize)
00095 , m_firstInputDone(false)
00096 , m_queue(1, m_firstSize)
00097 {
00098 }
00099
00100 void FilterWithBufferedInput::Put(byte inByte)
00101 {
00102 Put(&inByte, 1);
00103 }
00104
00105 void FilterWithBufferedInput::Put(const byte *inString, unsigned int length)
00106 {
00107 if (length == 0)
00108 return;
00109
00110 unsigned int newLength = m_queue.CurrentSize() + length;
00111
00112 if (!m_firstInputDone && newLength >= m_firstSize)
00113 {
00114 unsigned int len = m_firstSize - m_queue.CurrentSize();
00115 m_queue.Put(inString, len);
00116 FirstPut(m_queue.GetContigousBlocks(m_firstSize));
00117 assert(m_queue.CurrentSize() == 0);
00118 m_queue.ResetQueue(m_blockSize, (2*m_blockSize+m_lastSize-2)/m_blockSize);
00119
00120 inString += len;
00121 newLength -= m_firstSize;
00122 m_firstInputDone = true;
00123 }
00124
00125 if (m_firstInputDone)
00126 {
00127 if (m_blockSize == 1)
00128 {
00129 while (newLength > m_lastSize && m_queue.CurrentSize() > 0)
00130 {
00131 unsigned int len = newLength - m_lastSize;
00132 const byte *ptr = m_queue.GetContigousBlocks(len);
00133 NextPut(ptr, len);
00134 newLength -= len;
00135 }
00136
00137 if (newLength > m_lastSize)
00138 {
00139 unsigned int len = newLength - m_lastSize;
00140 NextPut(inString, len);
00141 inString += len;
00142 newLength -= len;
00143 }
00144 }
00145 else
00146 {
00147 while (newLength >= m_blockSize + m_lastSize && m_queue.CurrentSize() >= m_blockSize)
00148 {
00149 NextPut(m_queue.GetBlock(), m_blockSize);
00150 newLength -= m_blockSize;
00151 }
00152
00153 if (newLength >= m_blockSize + m_lastSize && m_queue.CurrentSize() > 0)
00154 {
00155 assert(m_queue.CurrentSize() < m_blockSize);
00156 unsigned int len = m_blockSize - m_queue.CurrentSize();
00157 m_queue.Put(inString, len);
00158 inString += len;
00159 NextPut(m_queue.GetBlock(), m_blockSize);
00160 newLength -= m_blockSize;
00161 }
00162
00163 while (newLength >= m_blockSize + m_lastSize)
00164 {
00165 NextPut(inString, m_blockSize);
00166 inString += m_blockSize;
00167 newLength -= m_blockSize;
00168 }
00169 }
00170 }
00171
00172 m_queue.Put(inString, newLength - m_queue.CurrentSize());
00173 }
00174
00175 void FilterWithBufferedInput::MessageEnd(int propagation)
00176 {
00177 if (!m_firstInputDone && m_firstSize==0)
00178 FirstPut(NULL);
00179
00180 SecByteBlock temp(m_queue.CurrentSize());
00181 m_queue.GetAll(temp);
00182 LastPut(temp, temp.size);
00183
00184 m_firstInputDone = false;
00185 m_queue.ResetQueue(1, m_firstSize);
00186
00187 Filter::MessageEnd(propagation);
00188 }
00189
00190 void FilterWithBufferedInput::ForceNextPut()
00191 {
00192 if (m_firstInputDone && m_queue.CurrentSize() >= m_blockSize)
00193 NextPut(m_queue.GetBlock(), m_blockSize);
00194 }
00195
00196
00197
00198
00199
00200
00201
00202 ProxyFilter::ProxyFilter(Filter *filter, unsigned int firstSize, unsigned int lastSize, BufferedTransformation *outQ)
00203 : FilterWithBufferedInput(firstSize, 1, lastSize, outQ), m_filter(filter), m_proxy(NULL)
00204 {
00205 if (m_filter.get())
00206 m_filter->Attach(m_proxy = new OutputProxy(*this, false));
00207 }
00208
00209 void ProxyFilter::Flush(bool completeFlush, int propagation)
00210 {
00211 if (m_filter.get())
00212 {
00213 bool passSignal = m_proxy->GetPassSignal();
00214 m_proxy->SetPassSignal(false);
00215 m_filter->Flush(completeFlush, -1);
00216 m_proxy->SetPassSignal(passSignal);
00217 }
00218 Filter::Flush(completeFlush, propagation);
00219 }
00220
00221 void ProxyFilter::SetFilter(Filter *filter)
00222 {
00223 bool passSignal = m_proxy ? m_proxy->GetPassSignal() : false;
00224 m_filter.reset(filter);
00225 if (filter)
00226 {
00227 std::auto_ptr<OutputProxy> temp(m_proxy = new OutputProxy(*this, passSignal));
00228 m_filter->TransferAllTo(*m_proxy);
00229 m_filter->Attach(temp.release());
00230 }
00231 else
00232 m_proxy=NULL;
00233 }
00234
00235 void ProxyFilter::NextPut(const byte *s, unsigned int len)
00236 {
00237 if (m_filter.get())
00238 m_filter->Put(s, len);
00239 }
00240
00241
00242
00243 void StreamCipherFilter::Put(const byte *inString, unsigned int length)
00244 {
00245 SecByteBlock temp(length);
00246 cipher.ProcessString(temp, inString, length);
00247 AttachedTransformation()->Put(temp, length);
00248 }
00249
00250 void HashFilter::Put(byte inByte)
00251 {
00252 m_hashModule.Update(&inByte, 1);
00253 if (m_putMessage)
00254 AttachedTransformation()->Put(inByte);
00255 }
00256
00257 void HashFilter::Put(const byte *inString, unsigned int length)
00258 {
00259 m_hashModule.Update(inString, length);
00260 if (m_putMessage)
00261 AttachedTransformation()->Put(inString, length);
00262 }
00263
00264 void HashFilter::MessageEnd(int propagation)
00265 {
00266 SecByteBlock buf(m_hashModule.DigestSize());
00267 m_hashModule.Final(buf);
00268 AttachedTransformation()->Put(buf, buf.size);
00269 Filter::MessageEnd(propagation);
00270 }
00271
00272
00273
00274 HashVerifier::HashVerifier(HashModule &hm, BufferedTransformation *outQueue, word32 flags)
00275 : FilterWithBufferedInput(flags & HASH_AT_BEGIN ? hm.DigestSize() : 0, 1, flags & HASH_AT_BEGIN ? 0 : hm.DigestSize(), outQueue)
00276 , m_hashModule(hm), m_flags(flags)
00277 , m_expectedHash(flags & HASH_AT_BEGIN ? hm.DigestSize() : 0), m_verified(false)
00278 {
00279 }
00280
00281 void HashVerifier::FirstPut(const byte *inString)
00282 {
00283 if (m_flags & HASH_AT_BEGIN)
00284 {
00285 memcpy(m_expectedHash, inString, m_expectedHash.size);
00286 if (m_flags & PUT_HASH)
00287 AttachedTransformation()->Put(inString, m_expectedHash.size);
00288 }
00289 }
00290
00291 void HashVerifier::NextPut(const byte *inString, unsigned int length)
00292 {
00293 m_hashModule.Update(inString, length);
00294 if (m_flags & PUT_MESSAGE)
00295 AttachedTransformation()->Put(inString, length);
00296 }
00297
00298 void HashVerifier::LastPut(const byte *inString, unsigned int length)
00299 {
00300 if (m_flags & HASH_AT_BEGIN)
00301 {
00302 assert(length == 0);
00303 m_verified = m_hashModule.Verify(m_expectedHash);
00304 }
00305 else
00306 {
00307 m_verified = (length==m_hashModule.DigestSize() && m_hashModule.Verify(inString));
00308 if (m_flags & PUT_HASH)
00309 AttachedTransformation()->Put(inString, length);
00310 }
00311
00312 if (m_flags & PUT_RESULT)
00313 AttachedTransformation()->Put(m_verified);
00314
00315 if ((m_flags & THROW_EXCEPTION) && !m_verified)
00316 throw HashVerificationFailed();
00317 }
00318
00319
00320
00321 void SignerFilter::MessageEnd(int propagation)
00322 {
00323 SecByteBlock buf(m_signer.SignatureLength());
00324 m_signer.Sign(m_rng, m_messageAccumulator.release(), buf);
00325 AttachedTransformation()->Put(buf, buf.size);
00326 Filter::MessageEnd(propagation);
00327 m_messageAccumulator.reset(m_signer.NewMessageAccumulator());
00328 }
00329
00330 void VerifierFilter::PutSignature(const byte *sig)
00331 {
00332 memcpy(m_signature.ptr, sig, m_signature.size);
00333 }
00334
00335 void VerifierFilter::MessageEnd(int propagation)
00336 {
00337 AttachedTransformation()->Put((byte)m_verifier.Verify(m_messageAccumulator.release(), m_signature));
00338 Filter::MessageEnd(propagation);
00339 m_messageAccumulator.reset(m_verifier.NewMessageAccumulator());
00340 }
00341
00342
00343
00344 void Source::PumpAll()
00345 {
00346 while (PumpMessages()) {}
00347 while (Pump()) {}
00348 }
00349
00350 StringSource::StringSource(const char *string, bool pumpAll, BufferedTransformation *outQueue)
00351 : Source(outQueue), m_store(string)
00352 {
00353 if (pumpAll)
00354 PumpAll();
00355 }
00356
00357 StringSource::StringSource(const byte *string, unsigned int length, bool pumpAll, BufferedTransformation *outQueue)
00358 : Source(outQueue), m_store(string, length)
00359 {
00360 if (pumpAll)
00361 PumpAll();
00362 }
00363
00364 bool Store::GetNextMessage()
00365 {
00366 if (!m_messageEnd && !AnyRetrievable())
00367 {
00368 m_messageEnd=true;
00369 return true;
00370 }
00371 else
00372 return false;
00373 }
00374
00375 unsigned int Store::CopyMessagesTo(BufferedTransformation &target, unsigned int count) const
00376 {
00377 if (m_messageEnd || count == 0)
00378 return 0;
00379 else
00380 {
00381 CopyTo(target);
00382 if (GetAutoSignalPropagation())
00383 target.MessageEnd(GetAutoSignalPropagation()-1);
00384 return 1;
00385 }
00386 }
00387
00388 unsigned long StringStore::TransferTo(BufferedTransformation &target, unsigned long transferMax)
00389 {
00390 unsigned long result = CopyTo(target, transferMax);
00391 m_count += result;
00392 return result;
00393 }
00394
00395 unsigned long StringStore::CopyTo(BufferedTransformation &target, unsigned long copyMax) const
00396 {
00397 unsigned int len = (unsigned int)STDMIN((unsigned long)(m_length-m_count), copyMax);
00398 target.Put(m_store+m_count, len);
00399 return len;
00400 }
00401
00402 unsigned long RandomNumberStore::CopyTo(BufferedTransformation &target, unsigned long copyMax) const
00403 {
00404 unsigned int len = (unsigned int)STDMIN((unsigned long)(m_length-m_count), copyMax);
00405 for (unsigned int i=0; i<len; i++)
00406 target.Put(m_rng.GenerateByte());
00407 return len;
00408 }
00409
00410 unsigned long RandomNumberStore::TransferTo(BufferedTransformation &target, unsigned long transferMax)
00411 {
00412 unsigned long len = RandomNumberStore::CopyTo(target, transferMax);
00413 m_count += len;
00414 return len;
00415 }
00416
00417 RandomNumberSource::RandomNumberSource(RandomNumberGenerator &rng, unsigned int length, bool pumpAll, BufferedTransformation *outQueue)
00418 : Source(outQueue), m_store(rng, length)
00419 {
00420 if (pumpAll)
00421 PumpAll();
00422 }
00423
00424 NAMESPACE_END