Main Page   Class Hierarchy   Alphabetical List   Compound List   File List   Compound Members   File Members  

filters.cpp

00001 // filters.cpp - written and placed in the public domain by Wei Dai
00002 
00003 #include "pch.h"
00004 #include "filters.h"
00005 #include "mqueue.h"
00006 #include <memory>
00007 
00008 NAMESPACE_BEGIN(CryptoPP)
00009 
// Single shared BitBucket instance with static storage duration.
// NOTE(review): BitBucket is declared in filters.h (not visible here); it is
// presumably a sink that discards everything written to it — confirm there.
BitBucket g_bitBucket;
00011 
00012 Filter::Filter(BufferedTransformation *outQ)
00013         : m_outQueue(outQ ? outQ : new MessageQueue)
00014 {
00015 }
00016 
00017 void Filter::Detach(BufferedTransformation *newOut)
00018 {
00019         m_outQueue.reset(newOut ? newOut : new MessageQueue);
00020         NotifyAttachmentChange();
00021 }
00022 
00023 void Filter::Insert(Filter *filter)
00024 {
00025         filter->m_outQueue.reset(m_outQueue.release());
00026         m_outQueue.reset(filter);
00027         NotifyAttachmentChange();
00028 }
00029 
00030 // *************************************************************
00031 
00032 FilterWithBufferedInput::BlockQueue::BlockQueue(unsigned int blockSize, unsigned int maxBlocks)
00033         : m_buffer(blockSize * maxBlocks)
00034 {
00035         ResetQueue(blockSize, maxBlocks);
00036 }
00037 
00038 void FilterWithBufferedInput::BlockQueue::ResetQueue(unsigned int blockSize, unsigned int maxBlocks)
00039 {
00040         m_buffer.Resize(blockSize * maxBlocks);
00041         m_blockSize = blockSize;
00042         m_maxBlocks = maxBlocks;
00043         m_size = 0;
00044         m_begin = m_buffer;
00045 }
00046 
00047 const byte *FilterWithBufferedInput::BlockQueue::GetBlock()
00048 {
00049         if (m_size >= m_blockSize)
00050         {
00051                 const byte *ptr = m_begin;
00052                 if ((m_begin+=m_blockSize) == m_buffer.End())
00053                         m_begin = m_buffer;
00054                 m_size -= m_blockSize;
00055                 return ptr;
00056         }
00057         else
00058                 return NULL;
00059 }
00060 
00061 const byte *FilterWithBufferedInput::BlockQueue::GetContigousBlocks(unsigned int &numberOfBlocks)
00062 {
00063         numberOfBlocks = STDMIN(numberOfBlocks, STDMIN((unsigned int)(m_buffer.End()-m_begin), m_size)/m_blockSize);
00064         const byte *ptr = m_begin;
00065         if ((m_begin+=m_blockSize*numberOfBlocks) == m_buffer.End())
00066                 m_begin = m_buffer;
00067         m_size -= m_blockSize*numberOfBlocks;
00068         return ptr;
00069 }
00070 
00071 unsigned int FilterWithBufferedInput::BlockQueue::GetAll(byte *outString)
00072 {
00073         unsigned int size = m_size;
00074         unsigned int numberOfBlocks = m_maxBlocks;
00075         const byte *ptr = GetContigousBlocks(numberOfBlocks);
00076         memcpy(outString, ptr, numberOfBlocks*m_blockSize);
00077         memcpy(outString+numberOfBlocks*m_blockSize, m_begin, m_size);
00078         m_size = 0;
00079         return size;
00080 }
00081 
00082 void FilterWithBufferedInput::BlockQueue::Put(const byte *inString, unsigned int length)
00083 {
00084         assert(m_size + length <= m_buffer.size);
00085         byte *end = (m_size < m_buffer+m_buffer.size-m_begin) ? m_begin + m_size : m_begin + m_size - m_buffer.size;
00086         unsigned int len = STDMIN(length, (unsigned int)(m_buffer+m_buffer.size-end));
00087         memcpy(end, inString, len);
00088         if (len < length)
00089                 memcpy(m_buffer, inString+len, length-len);
00090         m_size += length;
00091 }
00092 
00093 FilterWithBufferedInput::FilterWithBufferedInput(unsigned int firstSize, unsigned int blockSize, unsigned int lastSize, BufferedTransformation *outQ)
00094         : Filter(outQ), m_firstSize(firstSize), m_blockSize(blockSize), m_lastSize(lastSize)
00095         , m_firstInputDone(false)
00096         , m_queue(1, m_firstSize)
00097 {
00098 }
00099 
00100 void FilterWithBufferedInput::Put(byte inByte)
00101 {
00102         Put(&inByte, 1);
00103 }
00104 
00105 void FilterWithBufferedInput::Put(const byte *inString, unsigned int length)
00106 {
00107         if (length == 0)
00108                 return;
00109 
00110         unsigned int newLength = m_queue.CurrentSize() + length;
00111 
00112         if (!m_firstInputDone && newLength >= m_firstSize)
00113         {
00114                 unsigned int len = m_firstSize - m_queue.CurrentSize();
00115                 m_queue.Put(inString, len);
00116                 FirstPut(m_queue.GetContigousBlocks(m_firstSize));
00117                 assert(m_queue.CurrentSize() == 0);
00118                 m_queue.ResetQueue(m_blockSize, (2*m_blockSize+m_lastSize-2)/m_blockSize);
00119 
00120                 inString += len;
00121                 newLength -= m_firstSize;
00122                 m_firstInputDone = true;
00123         }
00124 
00125         if (m_firstInputDone)
00126         {
00127                 if (m_blockSize == 1)
00128                 {
00129                         while (newLength > m_lastSize && m_queue.CurrentSize() > 0)
00130                         {
00131                                 unsigned int len = newLength - m_lastSize;
00132                                 const byte *ptr = m_queue.GetContigousBlocks(len);
00133                                 NextPut(ptr, len);
00134                                 newLength -= len;
00135                         }
00136 
00137                         if (newLength > m_lastSize)
00138                         {
00139                                 unsigned int len = newLength - m_lastSize;
00140                                 NextPut(inString, len);
00141                                 inString += len;
00142                                 newLength -= len;
00143                         }
00144                 }
00145                 else
00146                 {
00147                         while (newLength >= m_blockSize + m_lastSize && m_queue.CurrentSize() >= m_blockSize)
00148                         {
00149                                 NextPut(m_queue.GetBlock(), m_blockSize);
00150                                 newLength -= m_blockSize;
00151                         }
00152 
00153                         if (newLength >= m_blockSize + m_lastSize && m_queue.CurrentSize() > 0)
00154                         {
00155                                 assert(m_queue.CurrentSize() < m_blockSize);
00156                                 unsigned int len = m_blockSize - m_queue.CurrentSize();
00157                                 m_queue.Put(inString, len);
00158                                 inString += len;
00159                                 NextPut(m_queue.GetBlock(), m_blockSize);
00160                                 newLength -= m_blockSize;
00161                         }
00162 
00163                         while (newLength >= m_blockSize + m_lastSize)
00164                         {
00165                                 NextPut(inString, m_blockSize);
00166                                 inString += m_blockSize;
00167                                 newLength -= m_blockSize;
00168                         }
00169                 }
00170         }
00171 
00172         m_queue.Put(inString, newLength - m_queue.CurrentSize());
00173 }
00174 
00175 void FilterWithBufferedInput::MessageEnd(int propagation)
00176 {
00177         if (!m_firstInputDone && m_firstSize==0)
00178                 FirstPut(NULL);
00179 
00180         SecByteBlock temp(m_queue.CurrentSize());
00181         m_queue.GetAll(temp);
00182         LastPut(temp, temp.size);
00183 
00184         m_firstInputDone = false;
00185         m_queue.ResetQueue(1, m_firstSize);
00186 
00187         Filter::MessageEnd(propagation);
00188 }
00189 
00190 void FilterWithBufferedInput::ForceNextPut()
00191 {
00192         if (m_firstInputDone && m_queue.CurrentSize() >= m_blockSize)
00193                 NextPut(m_queue.GetBlock(), m_blockSize);
00194 }
00195 
00196 // *************************************************************
00197 
00198 
00199 
00200 // *************************************************************
00201 
00202 ProxyFilter::ProxyFilter(Filter *filter, unsigned int firstSize, unsigned int lastSize, BufferedTransformation *outQ)
00203         : FilterWithBufferedInput(firstSize, 1, lastSize, outQ), m_filter(filter), m_proxy(NULL)
00204 {
00205         if (m_filter.get())
00206                 m_filter->Attach(m_proxy = new OutputProxy(*this, false));
00207 }
00208 
00209 void ProxyFilter::Flush(bool completeFlush, int propagation)
00210 {
00211         if (m_filter.get())
00212         {
00213                 bool passSignal = m_proxy->GetPassSignal();
00214                 m_proxy->SetPassSignal(false);
00215                 m_filter->Flush(completeFlush, -1);
00216                 m_proxy->SetPassSignal(passSignal);
00217         }
00218         Filter::Flush(completeFlush, propagation);
00219 }
00220 
00221 void ProxyFilter::SetFilter(Filter *filter)
00222 {
00223         bool passSignal = m_proxy ? m_proxy->GetPassSignal() : false;
00224         m_filter.reset(filter);
00225         if (filter)
00226         {
00227                 std::auto_ptr<OutputProxy> temp(m_proxy = new OutputProxy(*this, passSignal));
00228                 m_filter->TransferAllTo(*m_proxy);
00229                 m_filter->Attach(temp.release());
00230         }
00231         else
00232                 m_proxy=NULL;
00233 }
00234 
00235 void ProxyFilter::NextPut(const byte *s, unsigned int len) 
00236 {
00237         if (m_filter.get())
00238                 m_filter->Put(s, len);
00239 }
00240 
00241 // *************************************************************
00242 
00243 void StreamCipherFilter::Put(const byte *inString, unsigned int length)
00244 {
00245         SecByteBlock temp(length);
00246         cipher.ProcessString(temp, inString, length);
00247         AttachedTransformation()->Put(temp, length);
00248 }
00249 
00250 void HashFilter::Put(byte inByte)
00251 {
00252         m_hashModule.Update(&inByte, 1);
00253         if (m_putMessage)
00254                 AttachedTransformation()->Put(inByte);
00255 }
00256 
00257 void HashFilter::Put(const byte *inString, unsigned int length)
00258 {
00259         m_hashModule.Update(inString, length);
00260         if (m_putMessage)
00261                 AttachedTransformation()->Put(inString, length);
00262 }
00263 
00264 void HashFilter::MessageEnd(int propagation)
00265 {
00266         SecByteBlock buf(m_hashModule.DigestSize());
00267         m_hashModule.Final(buf);
00268         AttachedTransformation()->Put(buf, buf.size);
00269         Filter::MessageEnd(propagation);
00270 }
00271 
00272 // *************************************************************
00273 
00274 HashVerifier::HashVerifier(HashModule &hm, BufferedTransformation *outQueue, word32 flags)
00275         : FilterWithBufferedInput(flags & HASH_AT_BEGIN ? hm.DigestSize() : 0, 1, flags & HASH_AT_BEGIN ? 0 : hm.DigestSize(), outQueue)
00276         , m_hashModule(hm), m_flags(flags)
00277         , m_expectedHash(flags & HASH_AT_BEGIN ? hm.DigestSize() : 0), m_verified(false)
00278 {
00279 }
00280 
00281 void HashVerifier::FirstPut(const byte *inString)
00282 {
00283         if (m_flags & HASH_AT_BEGIN)
00284         {
00285                 memcpy(m_expectedHash, inString, m_expectedHash.size);
00286                 if (m_flags & PUT_HASH)
00287                         AttachedTransformation()->Put(inString, m_expectedHash.size);
00288         }
00289 }
00290 
00291 void HashVerifier::NextPut(const byte *inString, unsigned int length)
00292 {
00293         m_hashModule.Update(inString, length);
00294         if (m_flags & PUT_MESSAGE)
00295                 AttachedTransformation()->Put(inString, length);
00296 }
00297 
00298 void HashVerifier::LastPut(const byte *inString, unsigned int length)
00299 {
00300         if (m_flags & HASH_AT_BEGIN)
00301         {
00302                 assert(length == 0);
00303                 m_verified = m_hashModule.Verify(m_expectedHash);
00304         }
00305         else
00306         {
00307                 m_verified = (length==m_hashModule.DigestSize() && m_hashModule.Verify(inString));
00308                 if (m_flags & PUT_HASH)
00309                         AttachedTransformation()->Put(inString, length);
00310         }
00311 
00312         if (m_flags & PUT_RESULT)
00313                 AttachedTransformation()->Put(m_verified);
00314 
00315         if ((m_flags & THROW_EXCEPTION) && !m_verified)
00316                 throw HashVerificationFailed();
00317 }
00318 
00319 // *************************************************************
00320 
00321 void SignerFilter::MessageEnd(int propagation)
00322 {
00323         SecByteBlock buf(m_signer.SignatureLength());
00324         m_signer.Sign(m_rng, m_messageAccumulator.release(), buf);
00325         AttachedTransformation()->Put(buf, buf.size);
00326         Filter::MessageEnd(propagation);
00327         m_messageAccumulator.reset(m_signer.NewMessageAccumulator());
00328 }
00329 
00330 void VerifierFilter::PutSignature(const byte *sig)
00331 {
00332         memcpy(m_signature.ptr, sig, m_signature.size);
00333 }
00334 
00335 void VerifierFilter::MessageEnd(int propagation)
00336 {
00337         AttachedTransformation()->Put((byte)m_verifier.Verify(m_messageAccumulator.release(), m_signature));
00338         Filter::MessageEnd(propagation);
00339         m_messageAccumulator.reset(m_verifier.NewMessageAccumulator());
00340 }
00341 
00342 // *************************************************************
00343 
00344 void Source::PumpAll()
00345 {
00346         while (PumpMessages()) {}
00347         while (Pump()) {}
00348 }
00349 
00350 StringSource::StringSource(const char *string, bool pumpAll, BufferedTransformation *outQueue)
00351         : Source(outQueue), m_store(string)
00352 {
00353         if (pumpAll)
00354                 PumpAll();
00355 }
00356 
00357 StringSource::StringSource(const byte *string, unsigned int length, bool pumpAll, BufferedTransformation *outQueue)
00358         : Source(outQueue), m_store(string, length)
00359 {
00360         if (pumpAll)
00361                 PumpAll();
00362 }
00363 
00364 bool Store::GetNextMessage()
00365 {
00366         if (!m_messageEnd && !AnyRetrievable())
00367         {
00368                 m_messageEnd=true;
00369                 return true;
00370         }
00371         else
00372                 return false;
00373 }
00374 
00375 unsigned int Store::CopyMessagesTo(BufferedTransformation &target, unsigned int count) const
00376 {
00377         if (m_messageEnd || count == 0)
00378                 return 0;
00379         else
00380         {
00381                 CopyTo(target);
00382                 if (GetAutoSignalPropagation())
00383                         target.MessageEnd(GetAutoSignalPropagation()-1);
00384                 return 1;
00385         }
00386 }
00387 
00388 unsigned long StringStore::TransferTo(BufferedTransformation &target, unsigned long transferMax)
00389 {
00390         unsigned long result = CopyTo(target, transferMax);
00391         m_count += result;
00392         return result;
00393 }
00394 
00395 unsigned long StringStore::CopyTo(BufferedTransformation &target, unsigned long copyMax) const
00396 {
00397         unsigned int len = (unsigned int)STDMIN((unsigned long)(m_length-m_count), copyMax);
00398         target.Put(m_store+m_count, len);
00399         return len;
00400 }
00401 
00402 unsigned long RandomNumberStore::CopyTo(BufferedTransformation &target, unsigned long copyMax) const
00403 {
00404         unsigned int len = (unsigned int)STDMIN((unsigned long)(m_length-m_count), copyMax);
00405         for (unsigned int i=0; i<len; i++)
00406                 target.Put(m_rng.GenerateByte());
00407         return len;
00408 }
00409 
00410 unsigned long RandomNumberStore::TransferTo(BufferedTransformation &target, unsigned long transferMax)
00411 {
00412         unsigned long len = RandomNumberStore::CopyTo(target, transferMax);
00413         m_count += len;
00414         return len;
00415 }
00416 
00417 RandomNumberSource::RandomNumberSource(RandomNumberGenerator &rng, unsigned int length, bool pumpAll, BufferedTransformation *outQueue)
00418         : Source(outQueue), m_store(rng, length)
00419 {
00420         if (pumpAll)
00421                 PumpAll();
00422 }
00423 
00424 NAMESPACE_END

Generated at Mon Jan 15 01:16:32 2001 for Crypto++ by doxygen 1.2.4 written by Dimitri van Heesch, © 1997-2000