Main Page   Class Hierarchy   Alphabetical List   Compound List   File List   Compound Members   File Members  

integer.cpp

00001 // integer.cpp - written and placed in the public domain by Wei Dai
00002 
00003 #include "pch.h"
00004 #include "integer.h"
00005 #include "modarith.h"
00006 #include "nbtheory.h"
00007 #include "asn.h"
00008 #include "oids.h"
00009 #include "words.h"
00010 
00011 #include <iostream>
00012 
00013 #include "algebra.cpp"
00014 #include "eprecomp.cpp"
00015 
00016 NAMESPACE_BEGIN(CryptoPP)
00017 
// Builds a dword from two halves: highWord is converted to dword, shifted left
// by WORD_BITS, and OR-ed with lowWord.
#define MAKE_DWORD(lowWord, highWord) ((dword(highWord)<<WORD_BITS) | (lowWord))
00019 
00020 // CodeWarrior defines _MSC_VER
00021 #if defined(_MSC_VER) && !defined(__MWERKS__) && defined(_M_IX86) && (_M_IX86<=600)
00022 
00023 // Add() and Subtract() are coded in Pentium assembly for a speed increase
00024 // of about 10-20 percent for an RSA signature
00025 
// Multi-word addition C = A + B over N words (MSVC x86 inline assembly).
// Returns the carry out of the last word pair in eax.
// Words are consumed two per loop iteration, so N is expected to be even
// (the portable C++ variant of Add asserts N%2 == 0).
// Register use follows the GCC twin's notes: ecx = C, edx = A on entry
// (the first two __fastcall arguments); B and N are read from the stack.
static __declspec(naked) word __fastcall Add(word *C, const word *A, const word *B, unsigned int N)
{
        __asm
        {
                // naked function: registers used below are saved by hand
                push ebp
                push ebx
                push esi
                push edi

                mov esi, [esp+24]       ; N
                mov ebx, [esp+20]       ; B

                sub ecx, edx            // ecx = C - A, so [edx+ecx] addresses C while edx walks A
                xor eax, eax

                sub eax, esi            // eax = -N
                lea ebx, [ebx+4*esi]    // ebx = B + 4*N bytes (one past the end of B)

                sar eax, 1              // clears the carry flag
                jz      loopend         // eax == 0: no word pairs to process

loopstart:
                mov    esi,[edx]            // esi = next word of A
                mov    ebp,[edx+4]          // ebp = the word after it

                mov    edi,[ebx+8*eax]      // matching word of B, indexed back from the end
                lea    edx,[edx+8]          // advance A by two words (lea leaves the flags alone)

                adc    esi,edi              // add low words plus incoming carry
                mov    edi,[ebx+8*eax+4]

                adc    ebp,edi              // add high words plus carry
                inc    eax                  // next pair; inc sets ZF but does not touch CF

                mov    [edx+ecx-8],esi      // store the two sums into C
                mov    [edx+ecx-4],ebp

                jnz    loopstart            // loop until the pair index reaches zero

loopend:
                adc eax, 0                  // eax is 0 here, so this yields the final carry flag
                pop edi
                pop esi
                pop ebx
                pop ebp
                ret 8                       // pop the two stack arguments (B and N)
        }
}
00074 
// Multi-word subtraction C = A - B over N words (MSVC x86 inline assembly).
// Returns the borrow out of the last word pair in eax.
// Same structure as the assembly Add above, with sbb in place of adc.
// Words are consumed two per loop iteration, so N is expected to be even
// (the portable C++ variant of Subtract asserts N%2 == 0).
static __declspec(naked) word __fastcall Subtract(word *C, const word *A, const word *B, unsigned int N)
{
        __asm
        {
                // naked function: registers used below are saved by hand
                push ebp
                push ebx
                push esi
                push edi

                mov esi, [esp+24]       ; N
                mov ebx, [esp+20]       ; B

                sub ecx, edx            // ecx = C - A, so [edx+ecx] addresses C while edx walks A
                xor eax, eax

                sub eax, esi            // eax = -N
                lea ebx, [ebx+4*esi]    // ebx = B + 4*N bytes (one past the end of B)

                sar eax, 1              // clears the carry flag
                jz      loopend         // eax == 0: no word pairs to process

loopstart:
                mov    esi,[edx]            // esi = next word of A
                mov    ebp,[edx+4]          // ebp = the word after it

                mov    edi,[ebx+8*eax]      // matching word of B, indexed back from the end
                lea    edx,[edx+8]          // advance A by two words (lea leaves the flags alone)

                sbb    esi,edi              // subtract low words with incoming borrow
                mov    edi,[ebx+8*eax+4]

                sbb    ebp,edi              // subtract high words with borrow
                inc    eax                  // next pair; inc sets ZF but does not touch CF

                mov    [edx+ecx-8],esi      // store the two differences into C
                mov    [edx+ecx-4],ebp

                jnz    loopstart            // loop until the pair index reaches zero

loopend:
                adc eax, 0                  // eax is 0 here, so this yields the final borrow flag
                pop edi
                pop esi
                pop ebx
                pop ebp
                ret 8                       // pop the two stack arguments (B and N)
        }
}
00123 
00124 #elif defined(__GNUC__) && defined(__i386__)
00125 
// Multi-word addition C = A + B over N words (GCC i386 extended assembly).
// N must be even; two words are added per loop iteration.
// Returns the carry out of the last word pair.
static word Add(word *C, const word *A, const word *B, unsigned int N)
{
        assert (N%2 == 0);

        register word carry;

        // Notes and further work (by Alister Lee):
        // - get extended asm to accept parameter into ebx. Currently, the parameter
        //   is accepted into eax and moved to ebx resulting in an extra instruction
        //   outside the loop. I think this is a bug in gcc.
        // - get extended asm to save and restore ebp through the clobbered list.
        //   I think this is a limitation of gcc.

        // NOTE(review): the asm below overwrites its input registers (esi, edx, ecx)
        // and the flags and stores through C, but only edi and ebx are listed as
        // clobbers and the statement is not volatile. Confirm this is acceptable
        // for the compilers in use.

        // on entry esi = N, edx = A, ecx = C, eax = B through extended asm (see below)
        __asm__(
                                "push %%ebp\n\t"                                        // can't automatically save ebp
                                "mov %%eax, %%ebx\n\t"                          // ebx is B (can't automatically accept
                                                                                                        // parameter into ebx)
                                "sub %%edx, %%ecx\n\t"                          // hold the distance between C & A so
                                                                                                        // we can add this to A to get C
                                "xor %%eax, %%eax\n\t"
                                "sub %%esi, %%eax\n\t"                          // eax is a negative index from end of B
                                "lea (%%ebx,%%esi,4), %%ebx\n\t"        // ebx is end of B
                                "sar $1, %%eax\n\t"                                     // eax is number of dwords
                                                                                                        // this also clears the carry flag
                                "jz 1f\n"                                                       // to loopend
                                                                                                        // if no dwords then nothing to do

                        "0:\n\t"                                                                // loopstart:
                                "mov 0(%%edx), %%esi\n\t"                       // load next dword of A into ebp:esi
                                "mov 4(%%edx), %%ebp\n\t"
                                "mov (%%ebx,%%eax,8), %%edi\n\t"        // load next word of B, using eax as index
                                "lea 8(%%edx), %%edx\n\t"                       // advance A
                                "adc %%edi, %%esi\n\t"                          // add with carry
                                "mov 4(%%ebx,%%eax,8), %%edi\n\t"       // load next word of B, using eax as index
                                "adc %%edi, %%ebp\n\t"                          // add with carry
                                "inc %%eax\n\t"                                         // advance index into B
                                                                                                        // no more words when zero
                                "mov %%esi, -8(%%edx,%%ecx)\n\t"        // store ebp:esi into next dword of C
                                "mov %%ebp, -4(%%edx,%%ecx)\n\t"
                                "jnz 0b\n"                                                      // to loopstart
                                                                                                        // carry flag feeds into next iteration

                        "1:\n\t"                                                                // loopend:
                                "adc $0, %%eax\n\t"                                     // capture carry flag
                                "pop %%ebp"

                        : "=a" (carry)
                        : "S" (N), "d" (A), "c" (C), "a" (B)
                        : "%edi", "%ebx"
                        );

        return carry;
}
00180 
// Multi-word subtraction C = A - B over N words (GCC i386 extended assembly).
// N must be even; two words are subtracted per loop iteration.
// Returns the borrow out of the last word pair.
static word Subtract(word *C, const word *A, const word *B, unsigned int N)
{
        assert (N%2 == 0);

        register word carry;

        // Notes: see notes on Add above

        // NOTE(review): the asm below overwrites its input registers (esi, edx, ecx)
        // and the flags and stores through C, but only edi and ebx are listed as
        // clobbers and the statement is not volatile. Confirm this is acceptable
        // for the compilers in use.

        // on entry esi = N, edx = A, ecx = C, eax = B (see the constraint lists below)
        __asm__(
                                "push %%ebp\n\t"                        // ebp is used as a scratch register
                                "mov %%eax, %%ebx\n\t"                  // ebx = B
                                "sub %%edx, %%ecx\n\t"                  // ecx = C - A
                                "xor %%eax, %%eax\n\t"
                                "sub %%esi, %%eax\n\t"                  // eax = -N
                                "lea (%%ebx,%%esi,4), %%ebx\n\t"        // ebx = end of B
                                "sar $1, %%eax\n\t"                     // eax = -N/2; also clears the carry flag
                                "jz 1f\n"                               // nothing to do

                        "0:\n\t"
                                "mov 0(%%edx), %%esi\n\t"               // load two words of A into esi, ebp
                                "mov 4(%%edx), %%ebp\n\t"
                                "mov (%%ebx,%%eax,8), %%edi\n\t"        // matching low word of B
                                "lea 8(%%edx), %%edx\n\t"               // advance A
                                "sbb %%edi, %%esi\n\t"                  // subtract with borrow
                                "mov 4(%%ebx,%%eax,8), %%edi\n\t"       // matching high word of B
                                "sbb %%edi, %%ebp\n\t"                  // subtract with borrow
                                "inc %%eax\n\t"                         // next pair; zero when done
                                "mov %%esi, -8(%%edx, %%ecx)\n\t"       // store both results into C
                                "mov %%ebp, -4(%%edx, %%ecx)\n\t"
                                "jnz 0b\n"

                        "1:\n\t"
                                "adc $0, %%eax\n\t"                     // eax is 0 here: capture the borrow flag
                                "pop %%ebp"
                : "=a" (carry)
                : "S" (N), "d" (A), "c" (C), "a" (B)
                : "%edi", "%ebx"
        );

        return carry;
}
00222 
00223 #else   // defined(_MSC_VER) && !defined(__MWERKS__) && defined(_M_IX86) && (_M_IX86<=600)
00224 
00225 static word Add(word *C, const word *A, const word *B, unsigned int N)
00226 {
00227         assert (N%2 == 0);
00228 
00229         word carry=0;
00230         for (unsigned i = 0; i < N; i+=2)
00231         {
00232                 dword u = (dword) carry + A[i] + B[i];
00233                 C[i] = LOW_WORD(u);
00234                 u = (dword) HIGH_WORD(u) + A[i+1] + B[i+1];
00235                 C[i+1] = LOW_WORD(u);
00236                 carry = HIGH_WORD(u);
00237         }
00238         return carry;
00239 }
00240 
00241 static word Subtract(word *C, const word *A, const word *B, unsigned int N)
00242 {
00243         assert (N%2 == 0);
00244 
00245         word borrow=0;
00246         for (unsigned i = 0; i < N; i+=2)
00247         {
00248                 dword u = (dword) A[i] - B[i] - borrow;
00249                 C[i] = LOW_WORD(u);
00250                 u = (dword) A[i+1] - B[i+1] - (word)(0-HIGH_WORD(u));
00251                 C[i+1] = LOW_WORD(u);
00252                 borrow = 0-HIGH_WORD(u);
00253         }
00254         return borrow;
00255 }
00256 
00257 #endif  // defined(_MSC_VER) && !defined(__MWERKS__) && defined(_M_IX86) && (_M_IX86<=600)
00258 
00259 static int Compare(const word *A, const word *B, unsigned int N)
00260 {
00261         while (N--)
00262                 if (A[N] > B[N])
00263                         return 1;
00264                 else if (A[N] < B[N])
00265                         return -1;
00266 
00267         return 0;
00268 }
00269 
00270 static word Increment(word *A, unsigned int N, word B=1)
00271 {
00272         assert(N);
00273         word t = A[0];
00274         A[0] = t+B;
00275         if (A[0] >= t)
00276                 return 0;
00277         for (unsigned i=1; i<N; i++)
00278                 if (++A[i])
00279                         return 0;
00280         return 1;
00281 }
00282 
00283 static word Decrement(word *A, unsigned int N, word B=1)
00284 {
00285         assert(N);
00286         word t = A[0];
00287         A[0] = t-B;
00288         if (A[0] <= t)
00289                 return 0;
00290         for (unsigned i=1; i<N; i++)
00291                 if (A[i]--)
00292                         return 0;
00293         return 1;
00294 }
00295 
00296 static void TwosComplement(word *A, unsigned int N)
00297 {
00298         Decrement(A, N);
00299         for (unsigned i=0; i<N; i++)
00300                 A[i] = ~A[i];
00301 }
00302 
00303 static word LinearMultiply(word *C, const word *A, word B, unsigned int N)
00304 {
00305         word carry=0;
00306         for(unsigned i=0; i<N; i++)
00307         {
00308                 dword p = (dword)A[i] * B + carry;
00309                 C[i] = LOW_WORD(p);
00310                 carry = HIGH_WORD(p);
00311         }
00312         return carry;
00313 }
00314 
// Multiplies the two-word numbers (A1:A0) and (B1:B0) and stores the four-word
// product in C[0..3]. Uses three word-by-word multiplications: A0*B0, A1*B1 and
// one product of the differences (A1-A0) and (B0-B1).
static void AtomicMultiply(word *C, word A0, word A1, word B0, word B1)
{
/*
        word s;
        dword d;

        if (A1 >= A0)
                if (B0 >= B1)
                {
                        s = 0;
                        d = (dword)(A1-A0)*(B0-B1);
                }
                else
                {
                        s = (A1-A0);
                        d = (dword)s*(word)(B0-B1);
                }
        else
                if (B0 > B1)
                {
                        s = (B0-B1);
                        d = (word)(A1-A0)*(dword)s;
                }
                else
                {
                        s = 0;
                        d = (dword)(A0-A1)*(B1-B0);
                }
*/
        // this segment is the branchless equivalent of above
        // D holds both differences in both directions; entries may have wrapped around.
        word D[4] = {A1-A0, A0-A1, B0-B1, B1-B0};
        unsigned int ai = A1 < A0;              // 1 when A1-A0 wrapped
        unsigned int bi = B0 < B1;              // 1 when B0-B1 wrapped
        unsigned int di = ai & bi;              // both wrapped: use the reversed differences
        dword d = (dword)D[di]*D[di+2];
        D[1] = D[3] = 0;                        // so that si == 1 selects s = 0
        unsigned int si = ai + !bi;             // 0 -> A1-A0, 1 -> 0, 2 -> B0-B1
        word s = D[si];                         // correction subtracted from the top half below

        dword A0B0 = (dword)A0*B0;
        C[0] = LOW_WORD(A0B0);

        dword A1B1 = (dword)A1*B1;
        // second word: sum of the partial words that land in position 1
        dword t = (dword) HIGH_WORD(A0B0) + LOW_WORD(A0B0) + LOW_WORD(d) + LOW_WORD(A1B1);
        C[1] = LOW_WORD(t);

        // top two words: A1*B1 plus the remaining high words and carry, minus the correction s
        t = A1B1 + HIGH_WORD(t) + HIGH_WORD(A0B0) + HIGH_WORD(d) + HIGH_WORD(A1B1) - s;
        C[2] = LOW_WORD(t);
        C[3] = HIGH_WORD(t);
}
00365 
// Multiplies the two-word numbers (A1:A0) and (B1:B0) and adds the product to
// the four words already in C[0..3]. Returns the word carried out of C[3].
// The product is formed from A0*B0, A1*B1 and one product of differences, in
// the same way as AtomicMultiply.
static word AtomicMultiplyAdd(word *C, word A0, word A1, word B0, word B1)
{
        // D holds both differences in both directions; entries may have wrapped around.
        word D[4] = {A1-A0, A0-A1, B0-B1, B1-B0};
        unsigned int ai = A1 < A0;              // 1 when A1-A0 wrapped
        unsigned int bi = B0 < B1;              // 1 when B0-B1 wrapped
        unsigned int di = ai & bi;              // both wrapped: use the reversed differences
        dword d = (dword)D[di]*D[di+2];
        D[1] = D[3] = 0;                        // so that si == 1 selects s = 0
        unsigned int si = ai + !bi;             // 0 -> A1-A0, 1 -> 0, 2 -> B0-B1
        word s = D[si];                         // correction subtracted in the third word below

        dword A0B0 = (dword)A0*B0;
        dword t = A0B0 + C[0];
        C[0] = LOW_WORD(t);

        dword A1B1 = (dword)A1*B1;
        // second word: carry from word 0 plus the partial words landing in position 1
        t = (dword) HIGH_WORD(t) + LOW_WORD(A0B0) + LOW_WORD(d) + LOW_WORD(A1B1) + C[1];
        C[1] = LOW_WORD(t);

        // third word: carry plus the partial words landing in position 2, minus the correction s
        t = (dword) HIGH_WORD(t) + LOW_WORD(A1B1) + HIGH_WORD(A0B0) + HIGH_WORD(d) + HIGH_WORD(A1B1) - s + C[2];
        C[2] = LOW_WORD(t);

        // fourth word: carry plus the high word of A1*B1
        t = (dword) HIGH_WORD(t) + HIGH_WORD(A1B1) + C[3];
        C[3] = LOW_WORD(t);
        return HIGH_WORD(t);
}
00392 
00393 static inline void AtomicSquare(word *C, word A, word B)
00394 {
00395         dword t1 = (dword) A*A;
00396         C[0] = LOW_WORD(t1);
00397 
00398         dword t2 = (dword) A*B;
00399         t1 = (dword) HIGH_WORD(t1) + LOW_WORD(t2) + LOW_WORD(t2);
00400         C[1] = LOW_WORD(t1);
00401 
00402         t1 = (dword) B*B + HIGH_WORD(t1) + HIGH_WORD(t2) + HIGH_WORD(t2);
00403         C[2] = LOW_WORD(t1);
00404         C[3] = HIGH_WORD(t1);
00405 }
00406 
00407 static inline void AtomicMultiplyBottom(word *C, word A0, word A1, word B0, word B1)
00408 {
00409         dword t = (dword)A0*B0;
00410         C[0] = LOW_WORD(t);
00411         C[1] = HIGH_WORD(t) + A0*B1 + A1*B0;
00412 }
00413 
// Column-accumulation helpers for the Comba routines below. They expect the
// enclosing function to declare `dword p; word c, d, e;` plus the arrays they
// index (A, B, R). The triple (e, d, c) is a three-word accumulator with c as
// the least significant word.

// MulAcc(x, y): adds A[x]*B[y] into the accumulator (e, d, c).
#define MulAcc(x, y)                                                            \
        p = (dword)A[x] * B[y] + c;                                     \
        c = LOW_WORD(p);                                                                \
        p = (dword)d + HIGH_WORD(p);                                    \
        d = LOW_WORD(p);                                                                \
        e += HIGH_WORD(p);

// SaveMulAcc(s, x, y): stores the finished low word c into R[s], shifts the
// accumulator down one word (d -> c, e -> d) and adds A[x]*B[y] to start the
// next column.
#define SaveMulAcc(s, x, y)                                             \
        R[s] = c;                                                                               \
        p = (dword)A[x] * B[y] + d;                                     \
        c = LOW_WORD(p);                                                                \
        p = (dword)e + HIGH_WORD(p);                                    \
        d = LOW_WORD(p);                                                                \
        e = HIGH_WORD(p);

// MulAcc1(x, y): same as MulAcc, but both factors come from A (A[x]*A[y]).
#define MulAcc1(x, y)                                                           \
        p = (dword)A[x] * A[y] + c;                                     \
        c = LOW_WORD(p);                                                                \
        p = (dword)d + HIGH_WORD(p);                                    \
        d = LOW_WORD(p);                                                                \
        e += HIGH_WORD(p);

// SaveMulAcc1(s, x, y): same as SaveMulAcc, but both factors come from A.
#define SaveMulAcc1(s, x, y)                                            \
        R[s] = c;                                                                               \
        p = (dword)A[x] * A[y] + d;                                     \
        c = LOW_WORD(p);                                                                \
        p = (dword)e + HIGH_WORD(p);                                    \
        d = LOW_WORD(p);                                                                \
        e = HIGH_WORD(p);
00443 
// Squaring helpers: accumulate 2*A[x]*A[y] (an off-diagonal term, which occurs
// twice in a square) into the three-word accumulator (e, d, c), c being the
// least significant word. They expect `dword p; word c, d, e;` and the arrays
// A and R to be declared by the enclosing function.
//
// The product is accumulated twice instead of being doubled inside the dword:
// A[x]*A[y] can occupy all 2*WORD_BITS bits, so computing p + p would silently
// drop the top bit of the doubled value.

// SquAcc(x, y): adds 2*A[x]*A[y] into (e, d, c).
#define SquAcc(x, y)                                            \
        p = (dword)A[x] * A[y] + c;                             \
        c = LOW_WORD(p);                                        \
        p = (dword)d + HIGH_WORD(p);                            \
        d = LOW_WORD(p);                                        \
        e += HIGH_WORD(p);                                      \
        p = (dword)A[x] * A[y] + c;                             \
        c = LOW_WORD(p);                                        \
        p = (dword)d + HIGH_WORD(p);                            \
        d = LOW_WORD(p);                                        \
        e += HIGH_WORD(p);

// SaveSquAcc(s, x, y): stores the finished low word c into R[s], shifts the
// accumulator down one word (d -> c, e -> d) and adds 2*A[x]*A[y] to start the
// next column.
#define SaveSquAcc(s, x, y)                                     \
        R[s] = c;                                               \
        p = (dword)A[x] * A[y] + d;                             \
        c = LOW_WORD(p);                                        \
        p = (dword)e + HIGH_WORD(p);                            \
        d = LOW_WORD(p);                                        \
        e = HIGH_WORD(p);                                       \
        p = (dword)A[x] * A[y] + c;                             \
        c = LOW_WORD(p);                                        \
        p = (dword)d + HIGH_WORD(p);                            \
        d = LOW_WORD(p);                                        \
        e += HIGH_WORD(p);
00460 
// Squares the four-word number A into the eight-word result R, one result
// column at a time. Off-diagonal products go through SquAcc/SaveSquAcc (which
// add the product doubled); diagonal products A[i]*A[i] go through MulAcc1.
// NOTE(review): no caller is visible in this part of the file; RecursiveSquare
// calls CombaMultiply4(R, A, A) for N==4 instead.
static void CombaSquare4(word *R, const word *A)
{
        dword p;
        word c, d, e;           // three-word column accumulator, c lowest

        // column 0: A0*A0
        p = (dword)A[0] * A[0];
        R[0] = LOW_WORD(p);
        c = HIGH_WORD(p);
        d = e = 0;

        // column 1: 2*A0*A1
        SquAcc(0, 1);

        // column 2: 2*A2*A0 + A1*A1 (stores R[1] first)
        SaveSquAcc(1, 2, 0);
        MulAcc1(1, 1);

        // column 3: 2*A0*A3 + 2*A1*A2 (stores R[2] first)
        SaveSquAcc(2, 0, 3);
        SquAcc(1, 2);

        // column 4: 2*A3*A1 + A2*A2 (stores R[3] first)
        SaveSquAcc(3, 3, 1);
        MulAcc1(2, 2);

        // column 5: 2*A2*A3 (stores R[4] first)
        SaveSquAcc(4, 2, 3);

        // column 6 and the final carry word: A3*A3
        R[5] = c;
        p = (dword)A[3] * A[3] + d;
        R[6] = LOW_WORD(p);
        R[7] = e + HIGH_WORD(p);
}
00489 
// Multiplies the four-word numbers A and B into the eight-word result R, one
// result column at a time. Each column sums every A[i]*B[j] with i+j equal to
// the column index; SaveMulAcc stores the finished word of the previous column
// before starting the next.
static void CombaMultiply4(word *R, const word *A, const word *B)
{
        dword p;
        word c, d, e;           // three-word column accumulator, c lowest

        // column 0
        p = (dword)A[0] * B[0];
        R[0] = LOW_WORD(p);
        c = HIGH_WORD(p);
        d = e = 0;

        // column 1
        MulAcc(0, 1);
        MulAcc(1, 0);

        // column 2 (stores R[1])
        SaveMulAcc(1, 2, 0);
        MulAcc(1, 1);
        MulAcc(0, 2);

        // column 3 (stores R[2])
        SaveMulAcc(2, 0, 3);
        MulAcc(1, 2);
        MulAcc(2, 1);
        MulAcc(3, 0);

        // column 4 (stores R[3])
        SaveMulAcc(3, 3, 1);
        MulAcc(2, 2);
        MulAcc(1, 3);

        // column 5 (stores R[4])
        SaveMulAcc(4, 2, 3);
        MulAcc(3, 2);

        // column 6 and the final carry word
        R[5] = c;
        p = (dword)A[3] * B[3] + d;
        R[6] = LOW_WORD(p);
        R[7] = e + HIGH_WORD(p);
}
00524 
// Multiplies the eight-word numbers A and B into the sixteen-word result R, one
// result column at a time. Each column sums every A[i]*B[j] with i+j equal to
// the column index; SaveMulAcc stores the finished word of the previous column
// before starting the next.
static void CombaMultiply8(word *R, const word *A, const word *B)
{
        dword p;
        word c, d, e;           // three-word column accumulator, c lowest

        // column 0
        p = (dword)A[0] * B[0];
        R[0] = LOW_WORD(p);
        c = HIGH_WORD(p);
        d = e = 0;

        // column 1
        MulAcc(0, 1);
        MulAcc(1, 0);

        // column 2 (stores R[1])
        SaveMulAcc(1, 2, 0);
        MulAcc(1, 1);
        MulAcc(0, 2);

        // column 3 (stores R[2])
        SaveMulAcc(2, 0, 3);
        MulAcc(1, 2);
        MulAcc(2, 1);
        MulAcc(3, 0);

        // column 4 (stores R[3])
        SaveMulAcc(3, 0, 4);
        MulAcc(1, 3);
        MulAcc(2, 2);
        MulAcc(3, 1);
        MulAcc(4, 0);

        // column 5 (stores R[4])
        SaveMulAcc(4, 0, 5);
        MulAcc(1, 4);
        MulAcc(2, 3);
        MulAcc(3, 2);
        MulAcc(4, 1);
        MulAcc(5, 0);

        // column 6 (stores R[5])
        SaveMulAcc(5, 0, 6);
        MulAcc(1, 5);
        MulAcc(2, 4);
        MulAcc(3, 3);
        MulAcc(4, 2);
        MulAcc(5, 1);
        MulAcc(6, 0);

        // column 7 (stores R[6])
        SaveMulAcc(6, 0, 7);
        MulAcc(1, 6);
        MulAcc(2, 5);
        MulAcc(3, 4);
        MulAcc(4, 3);
        MulAcc(5, 2);
        MulAcc(6, 1);
        MulAcc(7, 0);

        // column 8 (stores R[7])
        SaveMulAcc(7, 1, 7);
        MulAcc(2, 6);
        MulAcc(3, 5);
        MulAcc(4, 4);
        MulAcc(5, 3);
        MulAcc(6, 2);
        MulAcc(7, 1);

        // column 9 (stores R[8])
        SaveMulAcc(8, 2, 7);
        MulAcc(3, 6);
        MulAcc(4, 5);
        MulAcc(5, 4);
        MulAcc(6, 3);
        MulAcc(7, 2);

        // column 10 (stores R[9])
        SaveMulAcc(9, 3, 7);
        MulAcc(4, 6);
        MulAcc(5, 5);
        MulAcc(6, 4);
        MulAcc(7, 3);

        // column 11 (stores R[10])
        SaveMulAcc(10, 4, 7);
        MulAcc(5, 6);
        MulAcc(6, 5);
        MulAcc(7, 4);

        // column 12 (stores R[11])
        SaveMulAcc(11, 5, 7);
        MulAcc(6, 6);
        MulAcc(7, 5);

        // column 13 (stores R[12])
        SaveMulAcc(12, 6, 7);
        MulAcc(7, 6);

        // column 14 and the final carry word
        R[13] = c;
        p = (dword)A[7] * B[7] + d;
        R[14] = LOW_WORD(p);
        R[15] = e + HIGH_WORD(p);
}
00615 
// Computes only the low four words of the product of the four-word numbers A
// and B and stores them in R[0..3]. Columns 0-2 are accumulated in full; for
// the last column only the low words of the products matter, so plain word
// multiplication is used.
static void CombaMultiplyBottom4(word *R, const word *A, const word *B)
{
        dword p;
        word c, d, e;           // three-word column accumulator, c lowest

        // column 0
        p = (dword)A[0] * B[0];
        R[0] = LOW_WORD(p);
        c = HIGH_WORD(p);
        d = e = 0;

        // column 1
        MulAcc(0, 1);
        MulAcc(1, 0);

        // column 2 (stores R[1])
        SaveMulAcc(1, 2, 0);
        MulAcc(1, 1);
        MulAcc(0, 2);

        // column 3: carry word d plus the low words of the four cross products
        R[2] = c;
        R[3] = d + A[0] * B[3] + A[1] * B[2] + A[2] * B[1] + A[3] * B[0];
}
00636 
// Computes only the low eight words of the product of the eight-word numbers A
// and B and stores them in R[0..7]. Columns 0-6 are accumulated in full; for
// the last column only the low words of the products matter, so plain word
// multiplication is used.
static void CombaMultiplyBottom8(word *R, const word *A, const word *B)
{
        dword p;
        word c, d, e;           // three-word column accumulator, c lowest

        // column 0
        p = (dword)A[0] * B[0];
        R[0] = LOW_WORD(p);
        c = HIGH_WORD(p);
        d = e = 0;

        // column 1
        MulAcc(0, 1);
        MulAcc(1, 0);

        // column 2 (stores R[1])
        SaveMulAcc(1, 2, 0);
        MulAcc(1, 1);
        MulAcc(0, 2);

        // column 3 (stores R[2])
        SaveMulAcc(2, 0, 3);
        MulAcc(1, 2);
        MulAcc(2, 1);
        MulAcc(3, 0);

        // column 4 (stores R[3])
        SaveMulAcc(3, 0, 4);
        MulAcc(1, 3);
        MulAcc(2, 2);
        MulAcc(3, 1);
        MulAcc(4, 0);

        // column 5 (stores R[4])
        SaveMulAcc(4, 0, 5);
        MulAcc(1, 4);
        MulAcc(2, 3);
        MulAcc(3, 2);
        MulAcc(4, 1);
        MulAcc(5, 0);

        // column 6 (stores R[5])
        SaveMulAcc(5, 0, 6);
        MulAcc(1, 5);
        MulAcc(2, 4);
        MulAcc(3, 3);
        MulAcc(4, 2);
        MulAcc(5, 1);
        MulAcc(6, 0);

        // column 7: carry word d plus the low words of the eight cross products
        R[6] = c;
        R[7] = d + A[0] * B[7] + A[1] * B[6] + A[2] * B[5] + A[3] * B[4] +
                                A[4] * B[3] + A[5] * B[2] + A[6] * B[1] + A[7] * B[0];
}
00684 
00685 #undef MulAcc
00686 #undef SaveMulAcc
00687 
00688 static void AtomicInverseModPower2(word *C, word A0, word A1)
00689 {
00690         assert(A0%2==1);
00691 
00692         dword A=MAKE_DWORD(A0, A1), R=A0%8;
00693 
00694         for (unsigned i=3; i<2*WORD_BITS; i*=2)
00695                 R = R*(2-R*A);
00696 
00697         assert(R*A==1);
00698 
00699         C[0] = LOW_WORD(R);
00700         C[1] = HIGH_WORD(R);
00701 }
00702 
00703 // ********************************************************
00704 
// Shorthand for the halves and quarters of the operand arrays used by the
// recursive routines below. They expand in terms of the enclosing function's
// A, B, T, R, N and a local N2 (which those functions set to N/2).

// low and high halves of the inputs A and B
#define A0              A
#define A1              (A+N2)
#define B0              B
#define B1              (B+N2)

// quarters of the workspace T, each N2 words apart
#define T0              T
#define T1              (T+N2)
#define T2              (T+N)
#define T3              (T+N+N2)

// quarters of the result R, each N2 words apart
#define R0              R
#define R1              (R+N2)
#define R2              (R+N)
#define R3              (R+N+N2)
00719 
// R[2*N] - result = A*B
// T[2*N] - temporary work space
// A[N] --- multiplier
// B[N] --- multiplicand
//
// Sizes 2, 4 and 8 are handled by the fixed-size routines. Larger sizes split
// each operand into halves and combine three half-size products:
// A0*B0, A1*B1 and (A1-A0)*(B0-B1).
// N must be even and at least 2.

void RecursiveMultiply(word *R, word *T, const word *A, const word *B, unsigned int N)
{
        assert(N>=2 && N%2==0);

        if (N==2)
                AtomicMultiply(R, A[0], A[1], B[0], B[1]);
        else if (N==4)
                CombaMultiply4(R, A, B);
        else if (N==8)
                CombaMultiply8(R, A, B);
        else
        {
                const unsigned int N2 = N/2;
                int carry;              // signed: -1 marks a negative middle term

                int aComp = Compare(A0, A1, N2);
                int bComp = Compare(B0, B1, N2);

                // The selector equals 3*aComp + bComp, so each (aComp, bComp) pair
                // gets a distinct value. R0 and R1 serve as scratch for the
                // differences; they are overwritten by A0*B0 further down.
                switch (2*aComp + aComp + bComp)
                {
                case -4:        // A0 < A1 and B0 < B1: B0-B1 wraps around
                        Subtract(R0, A1, A0, N2);
                        Subtract(R1, B0, B1, N2);
                        RecursiveMultiply(T0, T2, R0, R1, N2);
                        Subtract(T1, T1, R0, N2);       // undo the wrap-around of B0-B1
                        carry = -1;
                        break;
                case -2:        // A0 < A1 and B0 > B1: both differences are non-negative
                        Subtract(R0, A1, A0, N2);
                        Subtract(R1, B0, B1, N2);
                        RecursiveMultiply(T0, T2, R0, R1, N2);
                        carry = 0;
                        break;
                case 2:         // A0 > A1 and B0 < B1: negate both differences
                        Subtract(R0, A0, A1, N2);
                        Subtract(R1, B1, B0, N2);
                        RecursiveMultiply(T0, T2, R0, R1, N2);
                        carry = 0;
                        break;
                case 4:         // A0 > A1 and B0 > B1: A1-A0 wraps around
                        Subtract(R0, A1, A0, N2);
                        Subtract(R1, B0, B1, N2);
                        RecursiveMultiply(T0, T2, R0, R1, N2);
                        Subtract(T1, T1, R1, N2);       // undo the wrap-around of A1-A0
                        carry = -1;
                        break;
                default:        // at least one pair of halves is equal: the product is zero
                        SetWords(T0, 0, N);
                        carry = 0;
                }

                RecursiveMultiply(R0, T2, A0, B0, N2);
                RecursiveMultiply(R2, T2, A1, B1, N2);

                // now T[01] holds (A1-A0)*(B0-B1), R[01] holds A0*B0, R[23] holds A1*B1

                // middle term = (A1-A0)*(B0-B1) + A0*B0 + A1*B1, added in at word offset N2
                carry += Add(T0, T0, R0, N);
                carry += Add(T0, T0, R2, N);
                carry += Add(R1, R1, T0, N);

                assert (carry >= 0 && carry <= 2);
                Increment(R3, N2, carry);       // propagate the carry into the top quarter
        }
}
00789 
00790 // R[2*N] - result = A*A
00791 // T[2*N] - temporary work space
00792 // A[N] --- number to be squared
00793 
00794 void RecursiveSquare(word *R, word *T, const word *A, unsigned int N)
00795 {
00796         assert(N && N%2==0);
00797 
00798         if (N==2)
00799                 AtomicSquare(R, A[0], A[1]);
00800         else if (N==4)
00801         {
00802                 // VC60 workaround: MSVC 6.0 has an optimization bug that makes
00803                 // (dword)A*B where either A or B has been cast to a dword before
00804                 // very expensive. Revisit a CombaSquare4() function when this
00805                 // bug is fixed.
00806                 CombaMultiply4(R, A, A);
00807         }
00808         else
00809         {
00810                 const unsigned int N2 = N/2;
00811 
00812                 RecursiveSquare(R0, T2, A0, N2);
00813                 RecursiveSquare(R2, T2, A1, N2);
00814                 RecursiveMultiply(T0, T2, A0, A1, N2);
00815 
00816                 word carry = Add(R1, R1, T0, N);
00817                 carry += Add(R1, R1, T0, N);
00818                 Increment(R3, N2, carry);
00819         }
00820 }
00821 
00822 // R[N] - bottom half of A*B
00823 // T[N] - temporary work space
00824 // A[N] - multiplier
00825 // B[N] - multiplicant
00826 
00827 void RecursiveMultiplyBottom(word *R, word *T, const word *A, const word *B, unsigned int N)
00828 {
00829         assert(N>=2 && N%2==0);
00830 
00831         if (N==2)
00832                 AtomicMultiplyBottom(R, A[0], A[1], B[0], B[1]);
00833         else if (N==4)
00834                 CombaMultiplyBottom4(R, A, B);
00835         else if (N==8)
00836                 CombaMultiplyBottom8(R, A, B);
00837         else
00838         {
00839                 const unsigned int N2 = N/2;
00840 
00841                 RecursiveMultiply(R, T, A0, B0, N2);
00842                 RecursiveMultiplyBottom(T0, T1, A1, B0, N2);
00843                 Add(R1, R1, T0, N2);
00844                 RecursiveMultiplyBottom(T0, T1, A0, B1, N2);
00845                 Add(R1, R1, T0, N2);
00846         }
00847 }
00848 
00849 // R[N] --- upper half of A*B
00850 // T[2*N] - temporary work space
00851 // L[N] --- lower half of A*B
00852 // A[N] --- multiplier
00853 // B[N] --- multiplicand
00854 
00855 void RecursiveMultiplyTop(word *R, word *T, const word *L, const word *A, const word *B, unsigned int N)
00856 {
00857         assert(N>=2 && N%2==0);
00858 
00859         if (N==2)
00860         {
00861                 AtomicMultiply(T, A[0], A[1], B[0], B[1]);
00862                 R[0] = T[2];
00863                 R[1] = T[3];
00864         }
00865         else if (N==4)
00866         {
00867                 CombaMultiply4(T, A, B);
00868                 R[0] = T[4];
00869                 R[1] = T[5];
00870                 R[2] = T[6];
00871                 R[3] = T[7];
00872         }
00873         else
00874         {
00875                 const unsigned int N2 = N/2;
00876                 int carry;
00877 
00878                 int aComp = Compare(A0, A1, N2);
00879                 int bComp = Compare(B0, B1, N2);
00880 
00881                 switch (2*aComp + aComp + bComp)
00882                 {
00883                 case -4:
00884                         Subtract(R0, A1, A0, N2);
00885                         Subtract(R1, B0, B1, N2);
00886                         RecursiveMultiply(T0, T2, R0, R1, N2);
00887                         Subtract(T1, T1, R0, N2);
00888                         carry = -1;
00889                         break;
00890                 case -2:
00891                         Subtract(R0, A1, A0, N2);
00892                         Subtract(R1, B0, B1, N2);
00893                         RecursiveMultiply(T0, T2, R0, R1, N2);
00894                         carry = 0;
00895                         break;
00896                 case 2:
00897                         Subtract(R0, A0, A1, N2);
00898                         Subtract(R1, B1, B0, N2);
00899                         RecursiveMultiply(T0, T2, R0, R1, N2);
00900                         carry = 0;
00901                         break;
00902                 case 4:
00903                         Subtract(R0, A1, A0, N2);
00904                         Subtract(R1, B0, B1, N2);
00905                         RecursiveMultiply(T0, T2, R0, R1, N2);
00906                         Subtract(T1, T1, R1, N2);
00907                         carry = -1;
00908                         break;
00909                 default:
00910                         SetWords(T0, 0, N);
00911                         carry = 0;
00912                 }
00913 
00914                 RecursiveMultiply(T2, R0, A1, B1, N2);
00915 
00916                 // now T[01] holds (A1-A0)*(B0-B1), T[23] holds A1*B1
00917 
00918                 CopyWords(R0, L+N2, N2);
00919                 word c2 = Subtract(R0, R0, L, N2);
00920                 c2 += Subtract(R0, R0, T0, N2);
00921                 word t = (Compare(R0, T2, N2) == -1);
00922 
00923                 carry += t;
00924                 carry += Increment(R0, N2, c2+t);
00925                 carry += Add(R0, R0, T1, N2);
00926                 carry += Add(R0, R0, T3, N2);
00927 
00928                 CopyWords(R1, T3, N2);
00929                 assert (carry >= 0 && carry <= 2);
00930                 Increment(R1, N2, carry);
00931         }
00932 }
00933 
00934 // R[NA+NB] - result = A*B
00935 // T[NA+NB] - temporary work space
00936 // A[NA] ---- multiplier
00937 // B[NB] ---- multiplicant
00938 
00939 void AsymmetricMultiply(word *R, word *T, const word *A, unsigned int NA, const word *B, unsigned int NB)
00940 {
00941         if (NA == NB)
00942         {
00943                 if (A == B)
00944                         RecursiveSquare(R, T, A, NA);
00945                 else
00946                         RecursiveMultiply(R, T, A, B, NA);
00947 
00948                 return;
00949         }
00950 
00951         if (NA > NB)
00952         {
00953                 std::swap(A, B);
00954                 std::swap(NA, NB);
00955         }
00956 
00957         assert(NB % NA == 0);
00958         assert((NB/NA)%2 == 0);         // NB is an even multiple of NA
00959 
00960         if (NA==2 && !A[1])
00961         {
00962                 switch (A[0])
00963                 {
00964                 case 0:
00965                         SetWords(R, 0, NB+2);
00966                         return;
00967                 case 1:
00968                         CopyWords(R, B, NB);
00969                         R[NB] = R[NB+1] = 0;
00970                         return;
00971                 default:
00972                         R[NB] = LinearMultiply(R, B, A[0], NB);
00973                         R[NB+1] = 0;
00974                         return;
00975                 }
00976         }
00977 
00978         RecursiveMultiply(R, T, A, B, NA);
00979         CopyWords(T+2*NA, R+NA, NA);
00980 
00981         unsigned i;
00982 
00983         for (i=2*NA; i<NB; i+=2*NA)
00984                 RecursiveMultiply(T+NA+i, T, A, B+i, NA);
00985         for (i=NA; i<NB; i+=2*NA)
00986                 RecursiveMultiply(R+i, T, A, B+i, NA);
00987 
00988         if (Add(R+NA, R+NA, T+2*NA, NB-NA))
00989                 Increment(R+NB, NA);
00990 }
00991 
00992 // R[N] ----- result = A inverse mod 2**(WORD_BITS*N)
00993 // T[3*N/2] - temporary work space
00994 // A[N] ----- an odd number as input
00995 
00996 void RecursiveInverseModPower2(word *R, word *T, const word *A, unsigned int N)
00997 {
00998         if (N==2)
00999                 AtomicInverseModPower2(R, A[0], A[1]);
01000         else
01001         {
01002                 const unsigned int N2 = N/2;
01003                 RecursiveInverseModPower2(R0, T0, A0, N2);
01004                 T0[0] = 1;
01005                 SetWords(T0+1, 0, N2-1);
01006                 RecursiveMultiplyTop(R1, T1, T0, R0, A0, N2);
01007                 RecursiveMultiplyBottom(T0, T1, R0, A1, N2);
01008                 Add(T0, R1, T0, N2);
01009                 TwosComplement(T0, N2);
01010                 RecursiveMultiplyBottom(R1, T1, R0, T0, N2);
01011         }
01012 }
01013 
01014 // R[N] --- result = X/(2**(WORD_BITS*N)) mod M
01015 // T[3*N] - temporary work space
01016 // X[2*N] - number to be reduced
01017 // M[N] --- modulus
01018 // U[N] --- multiplicative inverse of M mod 2**(WORD_BITS*N)
01019 
01020 void MontgomeryReduce(word *R, word *T, const word *X, const word *M, const word *U, unsigned int N)
01021 {
01022         RecursiveMultiplyBottom(R, T, X, U, N);
01023         RecursiveMultiplyTop(T, T+N, X, R, M, N);
01024         if (Subtract(R, X+N, T, N))
01025         {
01026                 word carry = Add(R, R, M, N);
01027                 assert(carry);
01028         }
01029 }
01030 
01031 // R[N] --- result = X/(2**(WORD_BITS*N/2)) mod M
01032 // T[2*N] - temporary work space
01033 // X[2*N] - number to be reduced
01034 // M[N] --- modulus
01035 // U[N/2] - multiplicative inverse of M mod 2**(WORD_BITS*N/2)
01036 // V[N] --- 2**(WORD_BITS*3*N/2) mod M
01037 
// Reduces X using half-length pieces. N must be even and at least 4.
// The statements below reuse T and R as scratch space between steps, so
// their order matters.
// NOTE(review): T0..T3 and R0/R1 are macros defined earlier in the file
// (outside this chunk); presumably they name N2-word pieces of T and R in
// the same way the X0..X3 macros below do for X -- confirm.
01038 void HalfMontgomeryReduce(word *R, word *T, const word *X, const word *M, const word *U, const word *V, unsigned int N)
01039 {
01040         assert(N%2==0 && N>=4);
01041 
      // local names for the two halves of M and V (N2 words each)
01042 #define M0              M
01043 #define M1              (M+N2)
01044 #define V0              V
01045 #define V1              (V+N2)
01046 
      // local names for the four N2-word quarters of X
01047 #define X0              X
01048 #define X1              (X+N2)
01049 #define X2              (X+N)
01050 #define X3              (X+N+N2)
01051 
01052         const unsigned int N2 = N/2;
              // T[01] = V0*X3, then add the low N words of X; c2 keeps the carry out
01053         RecursiveMultiply(T0, T2, V0, X3, N2);
01054         int c2 = Add(T0, T0, X0, N);
              // T3 = low half of T0*U
01055         RecursiveMultiplyBottom(T3, T2, T0, U, N2);
              // T2 = high half of T3*M0, with T0 supplied as the known low half
01056         RecursiveMultiplyTop(T2, R, T0, T3, M0, N2);
              // c2 is decreased by each borrow of the following two subtractions
01057         c2 -= Subtract(T2, T1, T2, N2);
01058         RecursiveMultiply(T0, R, T3, M1, N2);
01059         c2 -= Subtract(T0, T2, T0, N2);
              // c3 starts as minus the borrow of X2 - T1
01060         int c3 = -(int)Subtract(T1, X2, T1, N2);
01061         RecursiveMultiply(R0, T2, V1, X3, N2);
01062         c3 += Add(R, R, T, N);
01063 
              // apply the carry/borrow balance c2 to the upper half of R
01064         if (c2>0)
01065                 c3 += Increment(R1, N2);
01066         else if (c2<0)
01067                 c3 -= Decrement(R1, N2, -c2);
01068 
              // final correction: at most one subtraction or addition of M
01069         assert(c3>=-1 && c3<=1);
01070         if (c3>0)
01071                 Subtract(R, R, M, N);
01072         else if (c3<0)
01073                 Add(R, R, M, N);
01074 
01075 #undef M0
01076 #undef M1
01077 #undef V0
01078 #undef V1
01079 
01080 #undef X0
01081 #undef X1
01082 #undef X2
01083 #undef X3
01084 }
01085 
01086 #undef A0
01087 #undef A1
01088 #undef B0
01089 #undef B1
01090 
01091 #undef T0
01092 #undef T1
01093 #undef T2
01094 #undef T3
01095 
01096 #undef R0
01097 #undef R1
01098 #undef R2
01099 #undef R3
01100 
01101 // do a 3 word by 2 word divide, returns quotient and leaves remainder in A
// A points to three words with A[0] the least significant; {B1,B0} is the
// two-word divisor with B1 the high word. The caller must ensure the top two
// words of A are smaller than the divisor (asserted below).
// NOTE(review): dword, MAKE_DWORD, LOW_WORD and HIGH_WORD are defined outside
// this chunk; the comments assume LOW_WORD/HIGH_WORD select the lower/upper
// word of a dword -- confirm.
01102 static word SubatomicDivide(word *A, word B0, word B1)
01103 {
01104         // assert {A[2],A[1]} < {B1,B0}, so quotient can fit in a word
01105         assert(A[2] < B1 || (A[2]==B1 && A[1] < B0));
01106 
01107         dword p, u;
01108         word Q;
01109 
01110         // estimate the quotient: do a 2 word by 1 word divide
              // (when B1+1 wraps to zero, the top word A[2] is taken as the estimate)
01111         if (B1+1 == 0)
01112                 Q = A[2];
01113         else
01114                 Q = word(MAKE_DWORD(A[1], A[2]) / (B1+1));
01115 
01116         // now subtract Q*B from A
01117         p = (dword) B0*Q;
01118         u = (dword) A[0] - LOW_WORD(p);
01119         A[0] = LOW_WORD(u);
              // (word)(0-HIGH_WORD(u)) feeds the borrow from the previous word into this one
01120         u = (dword) A[1] - HIGH_WORD(p) - (word)(0-HIGH_WORD(u)) - (dword)B1*Q;
01121         A[1] = LOW_WORD(u);
01122         A[2] += HIGH_WORD(u);
01123 
01124         // Q <= actual quotient, so fix it
              // keep subtracting the divisor while the remainder is still >= {B1,B0}
01125         while (A[2] || A[1] > B1 || (A[1]==B1 && A[0]>=B0))
01126         {
01127                 u = (dword) A[0] - B0;
01128                 A[0] = LOW_WORD(u);
01129                 u = (dword) A[1] - B1 - (word)(0-HIGH_WORD(u));
01130                 A[1] = LOW_WORD(u);
01131                 A[2] += HIGH_WORD(u);
01132                 Q++;
01133                 assert(Q);      // shouldn't overflow
01134         }
01135 
01136         return Q;
01137 }
01138 
01139 // do a 4 word by 2 word divide, returns 2 word quotient in Q0 and Q1
01140 static inline void AtomicDivide(word &Q0, word &Q1, const word *A, word B0, word B1)
01141 {
01142         if (!B0 && !B1) // if divisor is 0, we assume divisor==2**(2*WORD_BITS)
01143         {
01144                 Q0 = A[2];
01145                 Q1 = A[3];
01146         }
01147         else
01148         {
01149                 word T[4];
01150                 T[0] = A[0]; T[1] = A[1]; T[2] = A[2]; T[3] = A[3];
01151                 Q1 = SubatomicDivide(T+1, B0, B1);
01152                 Q0 = SubatomicDivide(T, B0, B1);
01153 
01154 #ifndef NDEBUG
01155                 // multiply quotient and divisor and add remainder, make sure it equals dividend
01156                 assert(!T[2] && !T[3] && (T[1] < B1 || (T[1]==B1 && T[0]<B0)));
01157                 word P[4];
01158                 AtomicMultiply(P, Q0, Q1, B0, B1);
01159                 Add(P, P, T, 4);
01160                 assert(memcmp(P, A, 4*WORD_SIZE)==0);
01161 #endif
01162         }
01163 }
01164 
01165 // for use by Divide(), corrects the underestimated quotient {Q1,Q0}
01166 static void CorrectQuotientEstimate(word *R, word *T, word &Q0, word &Q1, const word *B, unsigned int N)
01167 {
01168         assert(N && N%2==0);
01169 
01170         if (Q1)
01171         {
01172                 T[N] = T[N+1] = 0;
01173                 unsigned i;
01174                 for (i=0; i<N; i+=4)
01175                         AtomicMultiply(T+i, Q0, Q1, B[i], B[i+1]);
01176                 for (i=2; i<N; i+=4)
01177                         if (AtomicMultiplyAdd(T+i, Q0, Q1, B[i], B[i+1]))
01178                                 T[i+5] += (++T[i+4]==0);
01179         }
01180         else
01181         {
01182                 T[N] = LinearMultiply(T, B, Q0, N);
01183                 T[N+1] = 0;
01184         }
01185 
01186         word borrow = Subtract(R, R, T, N+2);
01187         assert(!borrow && !R[N+1]);
01188 
01189         while (R[N] || Compare(R, B, N) >= 0)
01190         {
01191                 R[N] -= Subtract(R, R, B, N);
01192                 Q1 += (++Q0==0);
01193                 assert(Q0 || Q1); // no overflow
01194         }
01195 }
01196 
01197 // R[NB] -------- remainder = A%B
01198 // Q[NA-NB+2] --- quotient      = A/B
01199 // T[NA+2*NB+4] - temp work space
01200 // A[NA] -------- dividend
01201 // B[NB] -------- divisor
01202 
// Long division of A by B. NA and NB must be non-zero and even, NB <= NA,
// and at least one of the top two words of B must be non-zero (all asserted).
// The work is done on normalized copies of A and B inside T.
01203 void Divide(word *R, word *Q, word *T, const word *A, unsigned int NA, const word *B, unsigned int NB)
01204 {
01205         assert(NA && NB && NA%2==0 && NB%2==0);
01206         assert(B[NB-1] || B[NB-2]);
01207         assert(NB <= NA);
01208 
01209         // set up temporary work space
              // TA: NA+2 words (copy of A), TB: NB words (copy of B), TP: the rest (scratch)
01210         word *const TA=T;
01211         word *const TB=T+NA+2;
01212         word *const TP=T+NA+2+NB;
01213 
01214         // copy B into TB and normalize it so that TB has highest bit set to 1
              // shiftWords is 1 when B's top word is zero, so B moves up one whole word first
01215         unsigned shiftWords = (B[NB-1]==0);
01216         TB[0] = TB[NB-1] = 0;
01217         CopyWords(TB+shiftWords, B, NB-shiftWords);
01218         unsigned shiftBits = WORD_BITS - BitPrecision(TB[NB-1]);
01219         assert(shiftBits < WORD_BITS);
01220         ShiftWordsLeftByBits(TB, NB, shiftBits);
01221 
01222         // copy A into TA and normalize it
              // the same word and bit shifts are applied; TA has two extra words for the overflow
01223         TA[0] = TA[NA] = TA[NA+1] = 0;
01224         CopyWords(TA+shiftWords, A, NA);
01225         ShiftWordsLeftByBits(TA, NA+2, shiftBits);
01226 
01227         if (TA[NA+1]==0 && TA[NA] <= 1)
01228         {
                      // the two extra words hold at most 1: produce the top quotient word
                      // by repeated subtraction of TB from the top NB words of TA
01229                 Q[NA-NB+1] = Q[NA-NB] = 0;
01230                 while (TA[NA] || Compare(TA+NA-NB, TB, NB) >= 0)
01231                 {
01232                         TA[NA] -= Subtract(TA+NA-NB, TA+NA-NB, TB, NB);
01233                         ++Q[NA-NB];
01234                 }
01235         }
01236         else
01237         {
                      // otherwise treat the two extra words as part of the dividend
01238                 NA+=2;
01239                 assert(Compare(TA+NA-NB, TB, NB) < 0);
01240         }
01241 
              // {B1,B0} = top two words of TB plus one, used as the divisor for the estimates
01242         word B0 = TB[NB-2] + 1;
01243         word B1 = TB[NB-1] + (B0==0);
01244 
01245         // start reducing TA mod TB, 2 words at a time
              // each pass estimates two quotient words, then corrects them against the full TB
01246         for (unsigned i=NA-2; i>=NB; i-=2)
01247         {
01248                 AtomicDivide(Q[i-NB], Q[i-NB+1], TA+i-2, B0, B1);
01249                 CorrectQuotientEstimate(TA+i-NB, TP, Q[i-NB], Q[i-NB+1], TB, NB);
01250         }
01251 
01252         // copy TA into R, and denormalize it
01253         CopyWords(R, TA+shiftWords, NB);
01254         ShiftWordsRightByBits(R, NB, shiftBits);
01255 }
01256 
01257 static inline unsigned int EvenWordCount(const word *X, unsigned int N)
01258 {
01259         while (N && X[N-2]==0 && X[N-1]==0)
01260                 N-=2;
01261         return N;
01262 }
01263 
01264 // return k
01265 // R[N] --- result = A^(-1) * 2^k mod M
01266 // T[4*N] - temporary work space
01267 // A[NA] -- number to take inverse of
01268 // M[N] --- modulus
01269 
// Requires NA <= N and N even and non-zero (asserted). If f becomes zero
// during the loop, R is set to zero and 0 is returned.
01270 unsigned int AlmostInverse(word *R, word *T, const word *A, unsigned int NA, const word *M, unsigned int N)
01271 {
01272         assert(NA<=N && N && N%2==0);
01273 
              // T is split into four N-word areas named b, c, f and g
01274         word *b = T;
01275         word *c = T+N;
01276         word *f = T+2*N;
01277         word *g = T+3*N;
              // bcLen / fgLen: even word counts currently in use for (b,c) and (f,g)
01278         unsigned int bcLen=2, fgLen=EvenWordCount(M, N);
              // k: total number of bits shifted out of f; s: number of swaps performed
01279         unsigned int k=0, s=0;
01280 
              // start with b = 1, c = 0, f = A (zero-padded to N words), g = M
01281         SetWords(T, 0, 3*N);
01282         b[0]=1;
01283         CopyWords(f, A, NA);
01284         CopyWords(g, M, N);
01285 
01286         while (1)
01287         {
                      // drop whole zero words from the bottom of f, shifting c up by the same amount
01288                 word t=f[0];
01289                 while (!t)
01290                 {
01291                         if (EvenWordCount(f, fgLen)==0)
01292                         {
                                      // f is entirely zero: report a zero result
01293                                 SetWords(R, 0, N);
01294                                 return 0;
01295                         }
01296 
01297                         ShiftWordsRightByWords(f, fgLen, 1);
01298                         if (c[bcLen-1]) bcLen+=2;
01299                         assert(bcLen <= N);
01300                         ShiftWordsLeftByWords(c, bcLen, 1);
01301                         k+=WORD_BITS;
01302                         t=f[0];
01303                 }
01304 
                      // i = number of trailing zero bits in the lowest word of f
01305                 unsigned int i=0;
01306                 while (t%2 == 0)
01307                 {
01308                         t>>=1;
01309                         i++;
01310                 }
01311                 k+=i;
01312 
                      // f, with those bits removed, equals 1: finished
01313                 if (t==1 && f[1]==0 && EvenWordCount(f, fgLen)==2)
01314                 {
                              // the result is b after an even number of swaps, M - b after an odd number
01315                         if (s%2==0)
01316                                 CopyWords(R, b, N);
01317                         else
01318                                 Subtract(R, M, b, N);
01319                         return k;
01320                 }
01321 
                      // apply the i-bit shift: f moves right, c moves left, growing c if bits spill over
01322                 ShiftWordsRightByBits(f, fgLen, i);
01323                 t=ShiftWordsLeftByBits(c, bcLen, i);
01324                 if (t)
01325                 {
01326                         c[bcLen] = t;
01327                         bcLen+=2;
01328                         assert(bcLen <= N);
01329                 }
01330 
                      // shrink the working length once the top word pair of both f and g is zero
01331                 if (f[fgLen-2]==0 && g[fgLen-2]==0 && f[fgLen-1]==0 && g[fgLen-1]==0)
01332                         fgLen-=2;
01333 
                      // when Compare reports f below g, exchange the roles of (f,g) and (b,c)
01334                 if (Compare(f, g, fgLen)==-1)
01335                 {
01336                         std::swap(f, g);
01337                         std::swap(b, c);
01338                         s++;
01339                 }
01340 
01341                 Subtract(f, f, g, fgLen);
01342 
                      // b += c, growing the (b,c) length when the addition carries out
01343                 if (Add(b, b, c, bcLen))
01344                 {
01345                         b[bcLen] = 1;
01346                         bcLen+=2;
01347                         assert(bcLen <= N);
01348                 }
01349         }
01350 }
01351 
01352 // R[N] - result = A/(2^k) mod M
01353 // A[N] - input
01354 // M[N] - modulus
01355 
01356 void DivideByPower2Mod(word *R, const word *A, unsigned int k, const word *M, unsigned int N)
01357 {
01358         CopyWords(R, A, N);
01359 
01360         while (k--)
01361         {
01362                 if (R[0]%2==0)
01363                         ShiftWordsRightByBits(R, N, 1);
01364                 else
01365                 {
01366                         word carry = Add(R, R, M, N);
01367                         ShiftWordsRightByBits(R, N, 1);
01368                         R[N-1] += carry<<(WORD_BITS-1);
01369                 }
01370         }
01371 }
01372 
01373 // R[N] - result = A*(2^k) mod M
01374 // A[N] - input
01375 // M[N] - modulus
01376 
01377 void MultiplyByPower2Mod(word *R, const word *A, unsigned int k, const word *M, unsigned int N)
01378 {
01379         CopyWords(R, A, N);
01380 
01381         while (k--)
01382                 if (ShiftWordsLeftByBits(R, N, 1) || Compare(R, M, N)>=0)
01383                         Subtract(R, R, M, N);
01384 }
01385 
01386 // ******************************************************************
01387 
01388 static const unsigned int RoundupSizeTable[] = {2, 2, 2, 4, 4, 8, 8, 8, 8};
01389 
01390 static inline unsigned int RoundupSize(unsigned int n)
01391 {
01392         if (n<=8)
01393                 return RoundupSizeTable[n];
01394         else if (n<=16)
01395                 return 16;
01396         else if (n<=32)
01397                 return 32;
01398         else if (n<=64)
01399                 return 64;
01400         else return 1U << BitPrecision(n-1);
01401 }
01402 
01403 Integer::Integer()
01404         : reg(2), sign(POSITIVE)
01405 {
01406         reg[0] = reg[1] = 0;
01407 }
01408 
01409 Integer::Integer(const Integer& t)
01410         : reg(RoundupSize(t.WordCount())), sign(t.sign)
01411 {
01412         CopyWords(reg, t.reg, reg.size);
01413 }
01414 
01415 Integer::Integer(signed long value)
01416         : reg(2)
01417 {
01418         if (value >= 0)
01419                 sign = POSITIVE;
01420         else
01421         {
01422                 sign = NEGATIVE;
01423                 value = -value;
01424         }
01425         reg[0] = word(value);
01426         reg[1] = sizeof(value)>WORD_SIZE ? word(value>>WORD_BITS) : 0;
01427 }
01428 
01429 bool Integer::IsConvertableToLong() const
01430 {
01431         if (ByteCount() > sizeof(long))
01432                 return false;
01433 
01434         unsigned long value = reg[0];
01435         value += sizeof(value)>WORD_SIZE ? ((unsigned long)reg[1]<<WORD_BITS) : 0;
01436 
01437         if (sign==POSITIVE)
01438                 return (signed long)value >= 0;
01439         else
01440                 return -(signed long)value < 0;
01441 }
01442 
01443 signed long Integer::ConvertToLong() const
01444 {
01445         unsigned long value = reg[0];
01446         value += sizeof(value)>WORD_SIZE ? ((unsigned long)reg[1]<<WORD_BITS) : 0;
01447         return sign==POSITIVE ? value : -(signed long)value;
01448 }
01449 
01450 Integer::Integer(BufferedTransformation &encodedInteger, unsigned int byteCount, Signedness s)
01451 {
01452         Decode(encodedInteger, byteCount, s);
01453 }
01454 
01455 Integer::Integer(const byte *encodedInteger, unsigned int byteCount, Signedness s)
01456 {
01457         Decode(encodedInteger, byteCount, s);
01458 }
01459 
01460 Integer::Integer(BufferedTransformation &bt)
01461 {
01462         BERDecode(bt);
01463 }
01464 
01465 Integer::Integer(RandomNumberGenerator &rng, unsigned int bitcount)
01466 {
01467         Randomize(rng, bitcount);
01468 }
01469 
01470 Integer::Integer(RandomNumberGenerator &rng, const Integer &min, const Integer &max, RandomNumberType rnType, const Integer &equiv, const Integer &mod)
01471 {
01472         if (!Randomize(rng, min, max, rnType, equiv, mod))
01473                 throw Integer::RandomNumberNotFound();
01474 }
01475 
01476 Integer Integer::Power2(unsigned int e)
01477 {
01478         Integer r((word)0, bitsToWords(e+1));
01479         r.SetBit(e);
01480         return r;
01481 }
01482 
01483 const Integer &Integer::Zero()
01484 {
01485         static const Integer zero;
01486         return zero;
01487 }
01488 
01489 const Integer &Integer::One()
01490 {
01491         static const Integer one(1,2);
01492         return one;
01493 }
01494 
01495 bool Integer::operator!() const
01496 {
01497         return IsNegative() ? false : (reg[0]==0 && WordCount()==0);
01498 }
01499 
01500 Integer& Integer::operator=(const Integer& t)
01501 {
01502         if (this != &t)
01503         {
01504                 reg.New(RoundupSize(t.WordCount()));
01505                 CopyWords(reg, t.reg, reg.size);
01506                 sign = t.sign;
01507         }
01508         return *this;
01509 }
01510 
01511 bool Integer::GetBit(unsigned int n) const
01512 {
01513         if (n/WORD_BITS >= reg.size)
01514                 return 0;
01515         else
01516                 return bool((reg[n/WORD_BITS] >> (n % WORD_BITS)) & 1);
01517 }
01518 
01519 void Integer::SetBit(unsigned int n, bool value)
01520 {
01521         if (value)
01522         {
01523                 reg.CleanGrow(RoundupSize(bitsToWords(n+1)));
01524                 reg[n/WORD_BITS] |= (word(1) << (n%WORD_BITS));
01525         }
01526         else
01527         {
01528                 if (n/WORD_BITS < reg.size)
01529                         reg[n/WORD_BITS] &= ~(word(1) << (n%WORD_BITS));
01530         }
01531 }
01532 
01533 byte Integer::GetByte(unsigned int n) const
01534 {
01535         if (n/WORD_SIZE >= reg.size)
01536                 return 0;
01537         else
01538                 return byte(reg[n/WORD_SIZE] >> ((n%WORD_SIZE)*8));
01539 }
01540 
01541 void Integer::SetByte(unsigned int n, byte value)
01542 {
01543         reg.CleanGrow(RoundupSize(bytesToWords(n+1)));
01544         reg[n/WORD_SIZE] &= ~(word(0xff) << 8*(n%WORD_SIZE));
01545         reg[n/WORD_SIZE] |= (word(value) << 8*(n%WORD_SIZE));
01546 }
01547 
01548 unsigned long Integer::GetBits(unsigned int i, unsigned int n) const
01549 {
01550         assert(n <= sizeof(unsigned long)*8);
01551         unsigned long v = 0;
01552         for (unsigned int j=0; j<n; j++)
01553                 v |= GetBit(i+j) << j;
01554         return v;
01555 }
01556 
01557 Integer Integer::operator-() const
01558 {
01559         Integer result(*this);
01560         result.Negate();
01561         return result;
01562 }
01563 
01564 Integer Integer::AbsoluteValue() const
01565 {
01566         Integer result(*this);
01567         result.sign = POSITIVE;
01568         return result;
01569 }
01570 
01571 void Integer::swap(Integer &a)
01572 {
01573         reg.swap(a.reg);
01574         std::swap(sign, a.sign);
01575 }
01576 
01577 Integer::Integer(word value, unsigned int length)
01578         : reg(RoundupSize(length)), sign(POSITIVE)
01579 {
01580         reg[0] = value;
01581         SetWords(reg+1, 0, reg.size-1);
01582 }
01583 
01584 
01585 Integer::Integer(const char *str)
01586         : reg(2), sign(POSITIVE)
01587 {
01588         word radix;
01589         unsigned length = strlen(str);
01590 
01591         SetWords(reg, 0, 2);
01592 
01593         if (length == 0)
01594                 return;
01595 
01596         switch (str[length-1])
01597         {
01598         case 'h':
01599         case 'H':
01600                 radix=16;
01601                 break;
01602         case 'o':
01603         case 'O':
01604                 radix=8;
01605                 break;
01606         case 'b':
01607         case 'B':
01608                 radix=2;
01609                 break;
01610         default:
01611                 radix=10;
01612         }
01613 
01614         if (strncmp("0x", str, 2) == 0)
01615                 radix = 16;
01616 
01617         for (unsigned i=0; i<length; i++)
01618         {
01619                 word digit;
01620 
01621                 if (str[i] >= '0' && str[i] <= '9')
01622                         digit = str[i] - '0';
01623                 else if (str[i] >= 'A' && str[i] <= 'F')
01624                         digit = str[i] - 'A' + 10;
01625                 else if (str[i] >= 'a' && str[i] <= 'f')
01626                         digit = str[i] - 'a' + 10;
01627                 else
01628                         digit = radix;
01629 
01630                 if (digit < radix)
01631                 {
01632                         *this *= radix;
01633                         *this += digit;
01634                 }
01635         }
01636 
01637         if (str[0] == '-')
01638                 Negate();
01639 }
01640 
01641 unsigned int Integer::WordCount() const
01642 {
01643         return CountWords(reg, reg.size);
01644 }
01645 
01646 unsigned int Integer::ByteCount() const
01647 {
01648         unsigned wordCount = WordCount();
01649         if (wordCount)
01650                 return (wordCount-1)*WORD_SIZE + BytePrecision(reg[wordCount-1]);
01651         else
01652                 return 0;
01653 }
01654 
01655 unsigned int Integer::BitCount() const
01656 {
01657         unsigned wordCount = WordCount();
01658         if (wordCount)
01659                 return (wordCount-1)*WORD_BITS + BitPrecision(reg[wordCount-1]);
01660         else
01661                 return 0;
01662 }
01663 
01664 void Integer::Decode(const byte *input, unsigned int inputLen, Signedness s)
01665 {
01666         StringStore store(input, inputLen);
01667         Decode(store, inputLen, s);
01668 }
01669 
01670 void Integer::Decode(BufferedTransformation &bt, unsigned int inputLen, Signedness s)
01671 {
01672         assert(bt.MaxRetrievable() >= inputLen);
01673 
01674         byte b;
01675         bt.Peek(b);
01676         sign = ((s==SIGNED) && (b & 0x80)) ? NEGATIVE : POSITIVE;
01677 
01678         while (inputLen>0 && (sign==POSITIVE ? b==0 : b==0xff))
01679         {
01680                 bt.Skip(1);
01681                 inputLen--;
01682                 bt.Peek(b);
01683         }
01684 
01685         reg.CleanNew(RoundupSize(bytesToWords(inputLen)));
01686 
01687         for (unsigned int i=inputLen; i > 0; i--)
01688         {
01689                 bt.Get(b);
01690                 reg[(i-1)/WORD_SIZE] |= b << ((i-1)%WORD_SIZE)*8;
01691         }
01692 
01693         if (sign == NEGATIVE)
01694         {
01695                 for (unsigned i=inputLen; i<reg.size*WORD_SIZE; i++)
01696                         reg[i/WORD_SIZE] |= 0xff << (i%WORD_SIZE)*8;
01697                 TwosComplement(reg, reg.size);
01698         }
01699 }
01700 
01701 unsigned int Integer::MinEncodedSize(Signedness signedness) const
01702 {
01703         unsigned int outputLen = STDMAX(1U, ByteCount());
01704         if (signedness == UNSIGNED)
01705                 return outputLen;
01706         if (NotNegative() && (GetByte(outputLen-1) & 0x80))
01707                 outputLen++;
01708         if (IsNegative() && *this < -Power2(outputLen*8-1))
01709                 outputLen++;
01710         return outputLen;
01711 }
01712 
01713 unsigned int Integer::Encode(byte *output, unsigned int outputLen, Signedness signedness) const
01714 {
01715         ArraySink sink(output, outputLen);
01716         return Encode(sink, outputLen);
01717 }
01718 
01719 unsigned int Integer::Encode(BufferedTransformation &bt, unsigned int outputLen, Signedness signedness) const
01720 {
01721         if (signedness == UNSIGNED || NotNegative())
01722         {
01723                 for (unsigned int i=outputLen; i > 0; i--)
01724                         bt.Put(GetByte(i-1));
01725         }
01726         else
01727         {
01728                 // take two's complement of *this
01729                 Integer temp = Integer::Power2(8*STDMAX(ByteCount(), outputLen)) + *this;
01730                 for (unsigned i=0; i<outputLen; i++)
01731                         bt.Put(temp.GetByte(outputLen-i-1));
01732         }
01733         return outputLen;
01734 }
01735 
01736 void Integer::DEREncode(BufferedTransformation &bt) const
01737 {
01738         DERGeneralEncoder enc(bt, INTEGER);
01739         Encode(enc, MinEncodedSize(SIGNED), SIGNED);
01740         enc.MessageEnd();
01741 }
01742 
01743 void Integer::BERDecode(const byte *input, unsigned int len)
01744 {
01745         StringStore store(input, len);
01746         BERDecode(store);
01747 }
01748 
01749 void Integer::BERDecode(BufferedTransformation &bt)
01750 {
01751         BERGeneralDecoder dec(bt, INTEGER);
01752         if (!dec.IsDefiniteLength() || dec.MaxRetrievable() < dec.RemainingLength())
01753                 BERDecodeError();
01754         Decode(dec, dec.RemainingLength(), SIGNED);
01755         dec.MessageEnd();
01756 }
01757 
01758 void Integer::DEREncodeAsOctetString(BufferedTransformation &bt, unsigned int length) const
01759 {
01760         DERGeneralEncoder enc(bt, OCTET_STRING);
01761         Encode(enc, length);
01762         enc.MessageEnd();
01763 }
01764 
01765 void Integer::BERDecodeAsOctetString(BufferedTransformation &bt, unsigned int length)
01766 {
01767         BERGeneralDecoder dec(bt, OCTET_STRING);
01768         if (!dec.IsDefiniteLength() || dec.RemainingLength() != length)
01769                 BERDecodeError();
01770         Decode(dec, length);
01771         dec.MessageEnd();
01772 }
01773 
01774 unsigned int Integer::OpenPGPEncode(byte *output, unsigned int len) const
01775 {
01776         ArraySink sink(output, len);
01777         return OpenPGPEncode(sink);
01778 }
01779 
01780 unsigned int Integer::OpenPGPEncode(BufferedTransformation &bt) const
01781 {
01782         word16 bitCount = BitCount();
01783         bt.PutWord16(bitCount);
01784         return 2 + Encode(bt, bitsToBytes(bitCount));
01785 }
01786 
01787 void Integer::OpenPGPDecode(const byte *input, unsigned int len)
01788 {
01789         StringStore store(input, len);
01790         OpenPGPDecode(store);
01791 }
01792 
01793 void Integer::OpenPGPDecode(BufferedTransformation &bt)
01794 {
01795         word16 bitCount;
01796         if (bt.GetWord16(bitCount) != 2 || bt.MaxRetrievable() < bitsToBytes(bitCount))
01797                 throw OpenPGPDecodeErr();
01798         Decode(bt, bitsToBytes(bitCount));
01799 }
01800 
01801 void Integer::Randomize(RandomNumberGenerator &rng, unsigned int nbits)
01802 {
01803         const unsigned int nbytes = nbits/8 + 1;
01804         SecByteBlock buf(nbytes);
01805         rng.GetBlock(buf, nbytes);
01806         if (nbytes)
01807                 buf[0] = (byte)Crop(buf[0], nbits % 8);
01808         Decode(buf, nbytes, UNSIGNED);
01809 }
01810 
01811 void Integer::Randomize(RandomNumberGenerator &rng, const Integer &min, const Integer &max)
01812 {
01813         assert(max >= min);
01814 
01815         Integer range = max - min;
01816         const unsigned int nbits = range.BitCount();
01817 
01818         do
01819         {
01820                 Randomize(rng, nbits);
01821         }
01822         while (*this > range);
01823 
01824         *this += min;
01825 }
01826 
01827 bool Integer::Randomize(RandomNumberGenerator &rng, const Integer &min, const Integer &max, RandomNumberType rnType, const Integer &equiv, const Integer &mod)
01828 {
01829         assert(!equiv.IsNegative() && equiv < mod);
01830 
01831         switch (rnType)
01832         {
01833                 case ANY:
01834                         if (mod == One())
01835                                 Randomize(rng, min, max);
01836                         else
01837                         {
01838                                 Integer min1 = min + (equiv-min)%mod;
01839                                 if (max < min1)
01840                                         return false;
01841                                 Randomize(rng, Zero(), (max - min1) / mod);
01842                                 *this *= mod;
01843                                 *this += min1;
01844                         }
01845                         return true;
01846 
01847                 case PRIME:
01848                         int i;
01849                         i = 0;
01850                         while (1)
01851                         {
01852                                 if (++i==16)
01853                                 {
01854                                         // check if there are any suitable primes in [min, max]
01855                                         Integer first = min;
01856                                         if (FirstPrime(first, max, equiv, mod))
01857                                         {
01858                                                 // if there is only one suitable prime, we're done
01859                                                 *this = first;
01860                                                 if (!FirstPrime(first, max, equiv, mod))
01861                                                         return true;
01862                                         }
01863                                         else
01864                                                 return false;
01865                                 }
01866 
01867                                 Randomize(rng, min, max);
01868                                 if (FirstPrime(*this, STDMIN(*this+mod*PrimeSearchInterval(max), max), equiv, mod))
01869                                         return true;
01870                         }
01871 
01872                 default:
01873                         assert(false);
01874                         return false;
01875         }
01876 }
01877 
01878 std::istream& operator>>(std::istream& in, Integer &a)
01879 {
01880         char c;
01881         unsigned int length = 0;
01882         SecBlock<char> str(length + 16);
01883 
01884         std::ws(in);
01885 
01886         do
01887         {
01888                 in.read(&c, 1);
01889                 str[length++] = c;
01890                 if (length >= str.size)
01891                         str.Grow(length + 16);
01892         }
01893         while (in && (c=='-' || c=='x' || (c>='0' && c<='9') || (c>='a' && c<='f') || (c>='A' && c<='F') || c=='h' || c=='H' || c=='o' || c=='O' || c==',' || c=='.'));
01894 
01895         if (in.gcount())
01896                 in.putback(c);
01897         str[length-1] = '\0';
01898         a = Integer(str);
01899 
01900         return in;
01901 }
01902 
01903 std::ostream& operator<<(std::ostream& out, const Integer &a)
01904 {
01905         // Get relevant conversion specifications from ostream.
01906         long f = out.flags() & std::ios::basefield; // Get base digits.
01907         int base, block;
01908         char suffix;
01909         switch(f)
01910         {
01911         case std::ios::oct :
01912                 base = 8;
01913                 block = 8;
01914                 suffix = 'o';
01915                 break;
01916         case std::ios::hex :
01917                 base = 16;
01918                 block = 4;
01919                 suffix = 'h';
01920                 break;
01921         default :
01922                 base = 10;
01923                 block = 3;
01924                 suffix = '.';
01925         }
01926 
01927         SecBlock<char> s(a.BitCount() / (BitPrecision(base)-1) + 1);
01928         Integer temp1=a, temp2;
01929         unsigned i=0;
01930         const char vec[]="0123456789ABCDEF";
01931 
01932         if (a.IsNegative())
01933         {
01934                 out << '-';
01935                 temp1.Negate();
01936         }
01937 
01938         if (!a)
01939                 out << '0';
01940 
01941         while (!!temp1)
01942         {
01943                 word digit;
01944                 Integer::Divide(digit, temp2, temp1, base);
01945                 s[i++]=vec[digit];
01946                 temp1=temp2;
01947         }
01948 
01949         while (i--)
01950         {
01951                 out << s[i];
01952 //              if (i && !(i%block))
01953 //                      out << ",";
01954         }
01955         return out << suffix;
01956 }
01957 
01958 Integer& Integer::operator++()
01959 {
01960         if (NotNegative())
01961         {
01962                 if (Increment(reg, reg.size))
01963                 {
01964                         reg.CleanGrow(2*reg.size);
01965                         reg[reg.size/2]=1;
01966                 }
01967         }
01968         else
01969         {
01970                 word borrow = Decrement(reg, reg.size);
01971                 assert(!borrow);
01972                 if (WordCount()==0)
01973                         *this = Zero();
01974         }
01975         return *this;
01976 }
01977 
01978 Integer& Integer::operator--()
01979 {
01980         if (IsNegative())
01981         {
01982                 if (Increment(reg, reg.size))
01983                 {
01984                         reg.CleanGrow(2*reg.size);
01985                         reg[reg.size/2]=1;
01986                 }
01987         }
01988         else
01989         {
01990                 if (Decrement(reg, reg.size))
01991                         *this = -One();
01992         }
01993         return *this;
01994 }
01995 
01996 void PositiveAdd(Integer &sum, const Integer &a, const Integer& b)
01997 {
01998         word carry;
01999         if (a.reg.size == b.reg.size)
02000                 carry = Add(sum.reg, a.reg, b.reg, a.reg.size);
02001         else if (a.reg.size > b.reg.size)
02002         {
02003                 carry = Add(sum.reg, a.reg, b.reg, b.reg.size);
02004                 CopyWords(sum.reg+b.reg.size, a.reg+b.reg.size, a.reg.size-b.reg.size);
02005                 carry = Increment(sum.reg+b.reg.size, a.reg.size-b.reg.size, carry);
02006         }
02007         else
02008         {
02009                 carry = Add(sum.reg, a.reg, b.reg, a.reg.size);
02010                 CopyWords(sum.reg+a.reg.size, b.reg+a.reg.size, b.reg.size-a.reg.size);
02011                 carry = Increment(sum.reg+a.reg.size, b.reg.size-a.reg.size, carry);
02012         }
02013 
02014         if (carry)
02015         {
02016                 sum.reg.CleanGrow(2*sum.reg.size);
02017                 sum.reg[sum.reg.size/2] = 1;
02018         }
02019         sum.sign = Integer::POSITIVE;
02020 }
02021 
02022 void PositiveSubtract(Integer &diff, const Integer &a, const Integer& b)
02023 {
02024         unsigned aSize = a.WordCount();
02025         aSize += aSize%2;
02026         unsigned bSize = b.WordCount();
02027         bSize += bSize%2;
02028 
02029         if (aSize == bSize)
02030         {
02031                 if (Compare(a.reg, b.reg, aSize) >= 0)
02032                 {
02033                         Subtract(diff.reg, a.reg, b.reg, aSize);
02034                         diff.sign = Integer::POSITIVE;
02035                 }
02036                 else
02037                 {
02038                         Subtract(diff.reg, b.reg, a.reg, aSize);
02039                         diff.sign = Integer::NEGATIVE;
02040                 }
02041         }
02042         else if (aSize > bSize)
02043         {
02044                 word borrow = Subtract(diff.reg, a.reg, b.reg, bSize);
02045                 CopyWords(diff.reg+bSize, a.reg+bSize, aSize-bSize);
02046                 borrow = Decrement(diff.reg+bSize, aSize-bSize, borrow);
02047                 assert(!borrow);
02048                 diff.sign = Integer::POSITIVE;
02049         }
02050         else
02051         {
02052                 word borrow = Subtract(diff.reg, b.reg, a.reg, aSize);
02053                 CopyWords(diff.reg+aSize, b.reg+aSize, bSize-aSize);
02054                 borrow = Decrement(diff.reg+aSize, bSize-aSize, borrow);
02055                 assert(!borrow);
02056                 diff.sign = Integer::NEGATIVE;
02057         }
02058 }
02059 
02060 Integer Integer::Plus(const Integer& b) const
02061 {
02062         Integer sum((word)0, STDMAX(reg.size, b.reg.size));
02063         if (NotNegative())
02064         {
02065                 if (b.NotNegative())
02066                         PositiveAdd(sum, *this, b);
02067                 else
02068                         PositiveSubtract(sum, *this, b);
02069         }
02070         else
02071         {
02072                 if (b.NotNegative())
02073                         PositiveSubtract(sum, b, *this);
02074                 else
02075                 {
02076                         PositiveAdd(sum, *this, b);
02077                         sum.sign = Integer::NEGATIVE;
02078                 }
02079         }
02080         return sum;
02081 }
02082 
02083 Integer& Integer::operator+=(const Integer& t)
02084 {
02085         reg.CleanGrow(t.reg.size);
02086         if (NotNegative())
02087         {
02088                 if (t.NotNegative())
02089                         PositiveAdd(*this, *this, t);
02090                 else
02091                         PositiveSubtract(*this, *this, t);
02092         }
02093         else
02094         {
02095                 if (t.NotNegative())
02096                         PositiveSubtract(*this, t, *this);
02097                 else
02098                 {
02099                         PositiveAdd(*this, *this, t);
02100                         sign = Integer::NEGATIVE;
02101                 }
02102         }
02103         return *this;
02104 }
02105 
02106 Integer Integer::Minus(const Integer& b) const
02107 {
02108         Integer diff((word)0, STDMAX(reg.size, b.reg.size));
02109         if (NotNegative())
02110         {
02111                 if (b.NotNegative())
02112                         PositiveSubtract(diff, *this, b);
02113                 else
02114                         PositiveAdd(diff, *this, b);
02115         }
02116         else
02117         {
02118                 if (b.NotNegative())
02119                 {
02120                         PositiveAdd(diff, *this, b);
02121                         diff.sign = Integer::NEGATIVE;
02122                 }
02123                 else
02124                         PositiveSubtract(diff, b, *this);
02125         }
02126         return diff;
02127 }
02128 
02129 Integer& Integer::operator-=(const Integer& t)
02130 {
02131         reg.CleanGrow(t.reg.size);
02132         if (NotNegative())
02133         {
02134                 if (t.NotNegative())
02135                         PositiveSubtract(*this, *this, t);
02136                 else
02137                         PositiveAdd(*this, *this, t);
02138         }
02139         else
02140         {
02141                 if (t.NotNegative())
02142                 {
02143                         PositiveAdd(*this, *this, t);
02144                         sign = Integer::NEGATIVE;
02145                 }
02146                 else
02147                         PositiveSubtract(*this, t, *this);
02148         }
02149         return *this;
02150 }
02151 
02152 Integer& Integer::operator<<=(unsigned int n)
02153 {
02154         const unsigned int wordCount = WordCount();
02155         const unsigned int shiftWords = n / WORD_BITS;
02156         const unsigned int shiftBits = n % WORD_BITS;
02157 
02158         reg.CleanGrow(RoundupSize(wordCount+bitsToWords(n)));
02159         ShiftWordsLeftByWords(reg, wordCount + shiftWords, shiftWords);
02160         ShiftWordsLeftByBits(reg+shiftWords, wordCount+bitsToWords(shiftBits), shiftBits);
02161         return *this;
02162 }
02163 
02164 Integer& Integer::operator>>=(unsigned int n)
02165 {
02166         const unsigned int wordCount = WordCount();
02167         const unsigned int shiftWords = n / WORD_BITS;
02168         const unsigned int shiftBits = n % WORD_BITS;
02169 
02170         ShiftWordsRightByWords(reg, wordCount, shiftWords);
02171         if (wordCount > shiftWords)
02172                 ShiftWordsRightByBits(reg, wordCount-shiftWords, shiftBits);
02173         if (IsNegative() && WordCount()==0)   // avoid -0
02174                 *this = Zero();
02175         return *this;
02176 }
02177 
02178 void PositiveMultiply(Integer &product, const Integer &a, const Integer &b)
02179 {
02180         unsigned aSize = RoundupSize(a.WordCount());
02181         unsigned bSize = RoundupSize(b.WordCount());
02182 
02183         product.reg.CleanNew(RoundupSize(aSize+bSize));
02184         product.sign = Integer::POSITIVE;
02185 
02186         SecWordBlock workspace(aSize + bSize);
02187         AsymmetricMultiply(product.reg, workspace, a.reg, aSize, b.reg, bSize);
02188 }
02189 
02190 void Multiply(Integer &product, const Integer &a, const Integer &b)
02191 {
02192         PositiveMultiply(product, a, b);
02193 
02194         if (a.NotNegative() != b.NotNegative())
02195                 product.Negate();
02196 }
02197 
02198 Integer Integer::Times(const Integer &b) const
02199 {
02200         Integer product;
02201         Multiply(product, *this, b);
02202         return product;
02203 }
02204 
02205 /*
02206 void PositiveDivide(Integer &remainder, Integer &quotient,
02207                                    const Integer &dividend, const Integer &divisor)
02208 {
02209         remainder.reg.CleanNew(divisor.reg.size);
02210         remainder.sign = Integer::POSITIVE;
02211         quotient.reg.New(0);
02212         quotient.sign = Integer::POSITIVE;
02213         unsigned i=dividend.BitCount();
02214         while (i--)
02215         {
02216                 word overflow = ShiftWordsLeftByBits(remainder.reg, remainder.reg.size, 1);
02217                 remainder.reg[0] |= dividend[i];
02218                 if (overflow || remainder >= divisor)
02219                 {
02220                         Subtract(remainder.reg, remainder.reg, divisor.reg, remainder.reg.size);
02221                         quotient.SetBit(i);
02222                 }
02223         }
02224 }
02225 */
02226 
02227 void PositiveDivide(Integer &remainder, Integer &quotient,
02228                                    const Integer &a, const Integer &b)
02229 {
02230         unsigned aSize = a.WordCount();
02231         unsigned bSize = b.WordCount();
02232 
02233         if (!bSize)
02234                 throw Integer::DivideByZero();
02235 
02236         if (a.PositiveCompare(b) == -1)
02237         {
02238                 remainder = a;
02239                 remainder.sign = Integer::POSITIVE;
02240                 quotient = Integer::Zero();
02241                 return;
02242         }
02243 
02244         aSize += aSize%2;       // round up to next even number
02245         bSize += bSize%2;
02246 
02247         remainder.reg.CleanNew(RoundupSize(bSize));
02248         remainder.sign = Integer::POSITIVE;
02249         quotient.reg.CleanNew(RoundupSize(aSize-bSize+2));
02250         quotient.sign = Integer::POSITIVE;
02251 
02252         SecWordBlock T(aSize+2*bSize+4);
02253         Divide(remainder.reg, quotient.reg, T, a.reg, aSize, b.reg, bSize);
02254 }
02255 
02256 void Integer::Divide(Integer &remainder, Integer &quotient, const Integer &dividend, const Integer &divisor)
02257 {
02258         PositiveDivide(remainder, quotient, dividend, divisor);
02259 
02260         if (dividend.IsNegative())
02261         {
02262                 quotient.Negate();
02263                 if (remainder.NotZero())
02264                 {
02265                         --quotient;
02266                         remainder = divisor.AbsoluteValue() - remainder;
02267                 }
02268         }
02269 
02270         if (divisor.IsNegative())
02271                 quotient.Negate();
02272 }
02273 
02274 void Integer::DivideByPowerOf2(Integer &r, Integer &q, const Integer &a, unsigned int n)
02275 {
02276         q = a;
02277         q >>= n;
02278 
02279         const unsigned int wordCount = bitsToWords(n);
02280         if (wordCount <= a.WordCount())
02281         {
02282                 r.reg.Resize(RoundupSize(wordCount));
02283                 CopyWords(r.reg, a.reg, wordCount);
02284                 SetWords(r.reg+wordCount, 0, r.reg.size-wordCount);
02285                 if (n % WORD_BITS != 0)
02286                         r.reg[wordCount-1] %= (1 << (n % WORD_BITS));
02287         }
02288         else
02289         {
02290                 r.reg.Resize(RoundupSize(a.WordCount()));
02291                 CopyWords(r.reg, a.reg, r.reg.size);
02292         }
02293         r.sign = POSITIVE;
02294 
02295         if (a.IsNegative() && r.NotZero())
02296         {
02297                 --q;
02298                 r = Power2(n) - r;
02299         }
02300 }
02301 
02302 Integer Integer::DividedBy(const Integer &b) const
02303 {
02304         Integer remainder, quotient;
02305         Integer::Divide(remainder, quotient, *this, b);
02306         return quotient;
02307 }
02308 
02309 Integer Integer::Modulo(const Integer &b) const
02310 {
02311         Integer remainder, quotient;
02312         Integer::Divide(remainder, quotient, *this, b);
02313         return remainder;
02314 }
02315 
02316 void Integer::Divide(word &remainder, Integer &quotient, const Integer &dividend, word divisor)
02317 {
02318         if (!divisor)
02319                 throw Integer::DivideByZero();
02320 
02321         assert(divisor);
02322 
02323         if ((divisor & (divisor-1)) == 0)       // divisor is a power of 2
02324         {
02325                 quotient = dividend >> (BitPrecision(divisor)-1);
02326                 remainder = dividend.reg[0] & (divisor-1);
02327                 return;
02328         }
02329 
02330         unsigned int i = dividend.WordCount();
02331         quotient.reg.CleanNew(RoundupSize(i));
02332         remainder = 0;
02333         while (i--)
02334         {
02335                 quotient.reg[i] = word(MAKE_DWORD(dividend.reg[i], remainder) / divisor);
02336                 remainder = word(MAKE_DWORD(dividend.reg[i], remainder) % divisor);
02337         }
02338 
02339         if (dividend.NotNegative())
02340                 quotient.sign = POSITIVE;
02341         else
02342         {
02343                 quotient.sign = NEGATIVE;
02344                 if (remainder)
02345                 {
02346                         --quotient;
02347                         remainder = divisor - remainder;
02348                 }
02349         }
02350 }
02351 
02352 Integer Integer::DividedBy(word b) const
02353 {
02354         word remainder;
02355         Integer quotient;
02356         Integer::Divide(remainder, quotient, *this, b);
02357         return quotient;
02358 }
02359 
02360 word Integer::Modulo(word divisor) const
02361 {
02362         if (!divisor)
02363                 throw Integer::DivideByZero();
02364 
02365         assert(divisor);
02366 
02367         word remainder;
02368 
02369         if ((divisor & (divisor-1)) == 0)       // divisor is a power of 2
02370                 remainder = reg[0] & (divisor-1);
02371         else
02372         {
02373                 unsigned int i = WordCount();
02374 
02375                 if (divisor <= 5)
02376                 {
02377                         dword sum=0;
02378                         while (i--)
02379                                 sum += reg[i];
02380                         remainder = word(sum%divisor);
02381                 }
02382                 else
02383                 {
02384                         remainder = 0;
02385                         while (i--)
02386                                 remainder = word(MAKE_DWORD(reg[i], remainder) % divisor);
02387                 }
02388         }
02389 
02390         if (IsNegative() && remainder)
02391                 remainder = divisor - remainder;
02392 
02393         return remainder;
02394 }
02395 
02396 void Integer::Negate()
02397 {
02398         if (!!(*this))     // don't flip sign if *this==0
02399                 sign = Sign(1-sign);
02400 }
02401 
02402 int Integer::PositiveCompare(const Integer& t) const
02403 {
02404         unsigned size = WordCount(), tSize = t.WordCount();
02405 
02406         if (size == tSize)
02407                 return CryptoPP::Compare(reg, t.reg, size);
02408         else
02409                 return size > tSize ? 1 : -1;
02410 }
02411 
02412 int Integer::Compare(const Integer& t) const
02413 {
02414         if (NotNegative())
02415         {
02416                 if (t.NotNegative())
02417                         return PositiveCompare(t);
02418                 else
02419                         return 1;
02420         }
02421         else
02422         {
02423                 if (t.NotNegative())
02424                         return -1;
02425                 else
02426                         return -PositiveCompare(t);
02427         }
02428 }
02429 
02430 Integer Integer::SquareRoot() const
02431 {
02432         if (!IsPositive())
02433                 return Zero();
02434 
02435         // overestimate square root
02436         Integer x, y = Power2((BitCount()+1)/2);
02437         assert(y*y >= *this);
02438 
02439         do
02440         {
02441                 x = y;
02442                 y = (x + *this/x) >> 1;
02443         } while (y<x);
02444 
02445         return x;
02446 }
02447 
02448 bool Integer::IsSquare() const
02449 {
02450         Integer r = SquareRoot();
02451         return *this == r.Squared();
02452 }
02453 
02454 bool Integer::IsUnit() const
02455 {
02456         return (WordCount() == 1) && (reg[0] == 1);
02457 }
02458 
02459 Integer Integer::MultiplicativeInverse() const
02460 {
02461         return IsUnit() ? *this : Zero();
02462 }
02463 
02464 Integer a_times_b_mod_c(const Integer &x, const Integer& y, const Integer& m)
02465 {
02466         return x*y%m;
02467 }
02468 
02469 Integer a_exp_b_mod_c(const Integer &x, const Integer& e, const Integer& m)
02470 {
02471         ModularArithmetic mr(m);
02472         return mr.Exponentiate(x, e);
02473 }
02474 
02475 Integer Integer::Gcd(const Integer &a, const Integer &b)
02476 {
02477         return EuclideanDomainOf<Integer>().Gcd(a, b);
02478 }
02479 
02480 Integer Integer::InverseMod(const Integer &m) const
02481 {
02482         assert(m.NotNegative());
02483 
02484         if (IsNegative() || *this>=m)
02485                 return (*this%m).InverseMod(m);
02486 
02487         if (m.IsEven())
02488         {
02489                 if (!m || IsEven())
02490                         return Zero();  // no inverse
02491                 if (*this == One())
02492                         return One();
02493 
02494                 Integer u = m.InverseMod(*this);
02495                 return !u ? Zero() : (m*(*this-u)+1)/(*this);
02496         }
02497 
02498         SecBlock<word> T(m.reg.size * 4);
02499         Integer r((word)0, m.reg.size);
02500         unsigned k = AlmostInverse(r.reg, T, reg, reg.size, m.reg, m.reg.size);
02501         DivideByPower2Mod(r.reg, r.reg, k, m.reg, m.reg.size);
02502         return r;
02503 }
02504 
02505 word Integer::InverseMod(const word mod) const
02506 {
02507         word g0 = mod, g1 = *this % mod;
02508         word v0 = 0, v1 = 1;
02509         word y;
02510 
02511         while (g1)
02512         {
02513                 if (g1 == 1)
02514                         return v1;
02515                 y = g0 / g1;
02516                 g0 = g0 % g1;
02517                 v0 += y * v1;
02518 
02519                 if (!g0)
02520                         break;
02521                 if (g0 == 1)
02522                         return mod-v0;
02523                 y = g1 / g0;
02524                 g1 = g1 % g0;
02525                 v1 += y * v0;
02526         }
02527         return 0;
02528 }
02529 
02530 // ********************************************************
02531 
02532 ModularArithmetic::ModularArithmetic(BufferedTransformation &bt)
02533 {
02534         BERSequenceDecoder seq(bt);
02535         OID oid(seq);
02536         if (oid != ASN1::prime_field())
02537                 BERDecodeError();
02538         modulus.BERDecode(seq);
02539         seq.MessageEnd();
02540         result.reg.Resize(modulus.reg.size);
02541 }
02542 
02543 void ModularArithmetic::DEREncode(BufferedTransformation &bt) const
02544 {
02545         DERSequenceEncoder seq(bt);
02546         ASN1::prime_field().DEREncode(seq);
02547         modulus.DEREncode(seq);
02548         seq.MessageEnd();
02549 }
02550 
02551 void ModularArithmetic::DEREncodeElement(BufferedTransformation &out, const Element &a) const
02552 {
02553         a.DEREncodeAsOctetString(out, MaxElementByteLength());
02554 }
02555 
02556 void ModularArithmetic::BERDecodeElement(BufferedTransformation &in, Element &a) const
02557 {
02558         a.BERDecodeAsOctetString(in, MaxElementByteLength());
02559 }
02560 
02561 const Integer& ModularArithmetic::Half(const Integer &a) const
02562 {
02563         if (a.reg.size==modulus.reg.size)
02564         {
02565                 CryptoPP::DivideByPower2Mod(result.reg.ptr, a.reg, 1, modulus.reg, a.reg.size);
02566                 return result;
02567         }
02568         else
02569                 return result1 = (a.IsEven() ? (a >> 1) : ((a+modulus) >> 1));
02570 }
02571 
02572 const Integer& ModularArithmetic::Add(const Integer &a, const Integer &b) const
02573 {
02574         if (a.reg.size==modulus.reg.size && b.reg.size==modulus.reg.size)
02575         {
02576                 if (CryptoPP::Add(result.reg.ptr, a.reg, b.reg, a.reg.size)
02577                         || Compare(result.reg, modulus.reg, a.reg.size) >= 0)
02578                 {
02579                         CryptoPP::Subtract(result.reg.ptr, result.reg, modulus.reg, a.reg.size);
02580                 }
02581                 return result;
02582         }
02583         else
02584         {
02585                 result1 = a+b;
02586                 if (result1 >= modulus)
02587                         result1 -= modulus;
02588                 return result1;
02589         }
02590 }
02591 
02592 Integer& ModularArithmetic::Accumulate(Integer &a, const Integer &b) const
02593 {
02594         if (a.reg.size==modulus.reg.size && b.reg.size==modulus.reg.size)
02595         {
02596                 if (CryptoPP::Add(a.reg, a.reg, b.reg, a.reg.size)
02597                         || Compare(a.reg, modulus.reg, a.reg.size) >= 0)
02598                 {
02599                         CryptoPP::Subtract(a.reg, a.reg, modulus.reg, a.reg.size);
02600                 }
02601         }
02602         else
02603         {
02604                 a+=b;
02605                 if (a>=modulus)
02606                         a-=modulus;
02607         }
02608 
02609         return a;
02610 }
02611 
02612 const Integer& ModularArithmetic::Subtract(const Integer &a, const Integer &b) const
02613 {
02614         if (a.reg.size==modulus.reg.size && b.reg.size==modulus.reg.size)
02615         {
02616                 if (CryptoPP::Subtract(result.reg.ptr, a.reg, b.reg, a.reg.size))
02617                         CryptoPP::Add(result.reg.ptr, result.reg, modulus.reg, a.reg.size);
02618                 return result;
02619         }
02620         else
02621         {
02622                 result1 = a-b;
02623                 if (result1.IsNegative())
02624                         result1 += modulus;
02625                 return result1;
02626         }
02627 }
02628 
02629 Integer& ModularArithmetic::Reduce(Integer &a, const Integer &b) const
02630 {
02631         if (a.reg.size==modulus.reg.size && b.reg.size==modulus.reg.size)
02632         {
02633                 if (CryptoPP::Subtract(a.reg, a.reg, b.reg, a.reg.size))
02634                         CryptoPP::Add(a.reg, a.reg, modulus.reg, a.reg.size);
02635         }
02636         else
02637         {
02638                 a-=b;
02639                 if (a.IsNegative())
02640                         a+=modulus;
02641         }
02642 
02643         return a;
02644 }
02645 
02646 const Integer& ModularArithmetic::Inverse(const Integer &a) const
02647 {
02648         if (!a)
02649                 return a;
02650 
02651         CopyWords(result.reg.ptr, modulus.reg, modulus.reg.size);
02652         if (CryptoPP::Subtract(result.reg.ptr, result.reg, a.reg, a.reg.size))
02653                 Decrement(result.reg.ptr+a.reg.size, 1, modulus.reg.size-a.reg.size);
02654 
02655         return result;
02656 }
02657 
02658 Integer ModularArithmetic::CascadeExponentiate(const Integer &x, const Integer &e1, const Integer &y, const Integer &e2) const
02659 {
02660         if (modulus.IsOdd())
02661         {
02662                 MontgomeryRepresentation dr(modulus);
02663                 return dr.ConvertOut(dr.CascadeExponentiate(dr.ConvertIn(x), e1, dr.ConvertIn(y), e2));
02664         }
02665         else
02666                 return AbstractRing<Integer>::CascadeExponentiate(x, e1, y, e2);
02667 }
02668 
02669 void ModularArithmetic::SimultaneousExponentiate(Integer *results, const Integer &base, const Integer *exponents, unsigned int exponentsCount) const
02670 {
02671         if (modulus.IsOdd())
02672         {
02673                 MontgomeryRepresentation dr(modulus);
02674                 dr.SimultaneousExponentiate(results, dr.ConvertIn(base), exponents, exponentsCount);
02675                 for (unsigned int i=0; i<exponentsCount; i++)
02676                         results[i] = dr.ConvertOut(results[i]);
02677         }
02678         else
02679                 AbstractRing<Integer>::SimultaneousExponentiate(results, base, exponents, exponentsCount);
02680 }
02681 
02682 MontgomeryRepresentation::MontgomeryRepresentation(const Integer &m)    // modulus must be odd
02683         : ModularArithmetic(m),
02684           u((word)0, modulus.reg.size),
02685           workspace(5*modulus.reg.size)
02686 {
02687         assert(modulus.IsOdd());
02688         RecursiveInverseModPower2(u.reg, workspace, modulus.reg, modulus.reg.size);
02689 }
02690 
02691 const Integer& MontgomeryRepresentation::Multiply(const Integer &a, const Integer &b) const
02692 {
02693         word *const T = workspace.ptr;
02694         word *const R = result.reg.ptr;
02695         const unsigned int N = modulus.reg.size;
02696         assert(a.reg.size<=N && b.reg.size<=N);
02697 
02698         AsymmetricMultiply(T, T+2*N, a.reg, a.reg.size, b.reg, b.reg.size);
02699         SetWords(T+a.reg.size+b.reg.size, 0, 2*N-a.reg.size-b.reg.size);
02700         MontgomeryReduce(R, T+2*N, T, modulus.reg, u.reg, N);
02701         return result;
02702 }
02703 
02704 const Integer& MontgomeryRepresentation::Square(const Integer &a) const
02705 {
02706         word *const T = workspace.ptr;
02707         word *const R = result.reg.ptr;
02708         const unsigned int N = modulus.reg.size;
02709         assert(a.reg.size<=N);
02710 
02711         RecursiveSquare(T, T+2*N, a.reg, a.reg.size);
02712         SetWords(T+2*a.reg.size, 0, 2*N-2*a.reg.size);
02713         MontgomeryReduce(R, T+2*N, T, modulus.reg, u.reg, N);
02714         return result;
02715 }
02716 
02717 Integer MontgomeryRepresentation::ConvertOut(const Integer &a) const
02718 {
02719         word *const T = workspace.ptr;
02720         word *const R = result.reg.ptr;
02721         const unsigned int N = modulus.reg.size;
02722         assert(a.reg.size<=N);
02723 
02724         CopyWords(T, a.reg, a.reg.size);
02725         SetWords(T+a.reg.size, 0, 2*N-a.reg.size);
02726         MontgomeryReduce(R, T+2*N, T, modulus.reg, u.reg, N);
02727         return result;
02728 }
02729 
02730 const Integer& MontgomeryRepresentation::MultiplicativeInverse(const Integer &a) const
02731 {
02732 //        return (EuclideanMultiplicativeInverse(a, modulus)<<(2*WORD_BITS*modulus.reg.size))%modulus;
02733         word *const T = workspace.ptr;
02734         word *const R = result.reg.ptr;
02735         const unsigned int N = modulus.reg.size;
02736         assert(a.reg.size<=N);
02737 
02738         CopyWords(T, a.reg, a.reg.size);
02739         SetWords(T+a.reg.size, 0, 2*N-a.reg.size);
02740         MontgomeryReduce(R, T+2*N, T, modulus.reg, u.reg, N);
02741         unsigned k = AlmostInverse(R, T, R, N, modulus.reg, N);
02742 
02743 //      cout << "k=" << k << " N*32=" << 32*N << endl;
02744 
02745         if (k>N*WORD_BITS)
02746                 DivideByPower2Mod(R, R, k-N*WORD_BITS, modulus.reg, N);
02747         else
02748                 MultiplyByPower2Mod(R, R, N*WORD_BITS-k, modulus.reg, N);
02749 
02750         return result;
02751 }
02752 
02753 NAMESPACE_END

Generated at Mon Jan 15 01:16:33 2001 for Crypto++ by doxygen1.2.4 written by Dimitri van Heesch, © 1997-2000