# Fernando Guevara Vasquez, 2013
import random,fractions

# Number of Miller-Rabin rounds: how many random bases is_probable_prime
# draws and tests before reporting n as a probable prime.
_mrpt_num_trials = 5

def is_probable_prime(n):
    """
    Miller-Rabin primality test.

    A return value of *False* means *n* is certainly not prime. A return value of
    *True* means *n* is very likely a prime.

    Integers smaller than 2 (0, 1 and negative numbers) are not prime, so
    *False* is returned for them. The number of random bases tried is the
    module-level constant ``_mrpt_num_trials``.

    This code comes from:
    http://rosettacode.org/wiki/Miller-Rabin_primality_test#Python
    (with corrections!)

    Example::

        >>> is_probable_prime(50800665469)
        True
        >>> is_probable_prime(2**1279 -1)
        True

    """
    # 0, 1 and negatives are not prime. (An assert used to guard this; it is
    # stripped under ``python -O``, and n == 1 then looped forever below
    # because d == 0 stays "even".)
    if n < 2:
        return False
    # 2 and 3 are prime; handling them here guarantees the base range
    # [2, n-2] used below is non-empty for every remaining odd n (n >= 5).
    if n < 4:
        return True
    # every other even number is composite
    if n % 2 == 0:
        return False
    # write n-1 as 2**s * d with d odd, by repeatedly dividing by 2
    s, d = 0, n - 1
    while d % 2 == 0:
        s, d = s + 1, d >> 1
    assert 2**s * d == n - 1
    for _ in range(_mrpt_num_trials):
        # The base a = n-1 can never expose a composite: d is odd, so
        # (n-1)**d == -1 (mod n) always. Drawing from [2, n-2] makes every
        # trial informative.
        a = random.randint(2, n - 2)
        x = pow(a, d, n)
        if x == 1 or x == n - 1:
            continue  # a is not a witness; try another base
        for _ in range(s - 1):
            x = pow(x, 2, n)
            if x == 1:
                return False  # nontrivial square root of 1: composite
            if x == n - 1:
                break
        if x != n - 1:
            return False  # a is a witness: composite
    return True  # probable prime

def randprime(ndigits):
    """
    Generate a random (probable) prime with *ndigits* + 1 decimal digits,
    i.e. a prime in the range [10**ndigits, 10**(ndigits+1) - 1].

    The "+1" matches :py:func:`rsa_check`, which expects
    ``len(str(p)) == ndigits + 1``.

    Candidates are drawn uniformly from that range and tested with
    :py:func:`is_probable_prime` until one passes. By Bertrand's postulate
    the range always contains a prime (there is one in
    (10**ndigits, 2*10**ndigits]), so the loop terminates with probability 1;
    on average about 2.3*(ndigits+1) candidates are needed.

    Example::

        >>> randprime(10)
        50800665469

    """
    low = 10**ndigits
    # Exclusive of 10**(ndigits+1), which would have one digit too many.
    high = 10**(ndigits + 1) - 1
    # No fixed attempt cap: a cap of a few hundred tries fails noticeably
    # often for large ndigits, because primes thin out as numbers grow.
    while True:
        n = random.randint(low, high)
        # n == 1 is possible when ndigits == 0, and 1 is not prime.
        if n >= 2 and is_probable_prime(n):
            return n

def egcd(a,b):
    r"""
    Carries out the extended Euclidean algorithm on *a* and *b*.

    Outputs:
       A tuple *(g, x, y)* of integers with :math:`g = \text{gcd}(a,b)` and
       :math:`ax+by = g`.

    How does this work? Assuming :math:`a\geq b`, the integer division gives
       :math:`a = bq+r`
    If :math:`r=0` then we are done since gcd(a,b) = b; if not we have
       :math:`\Rightarrow \text{gcd}(a,b) = \text{gcd}(b,r)`

    For the extended gcd part, if :math:`r=0` we certainly have
    :math:`b = \text{gcd}(a,b) = a \cdot 0 + b \cdot 1`. Otherwise, if we know
    that :math:`bx + ry = \text{gcd}(b,r) = \text{gcd}(a,b)`, we must have that
        :math:`\text{gcd}(a,b) = bx + ry = bx + (a-bq)y = ay + (x-qy)b`

    The recurrence is evaluated iteratively, carrying the Bezout coefficients
    of the two most recent remainders, so very large inputs cannot exceed
    Python's recursion limit. If *a* < *b*, the first step has q = 0 and
    simply swaps the roles of *a* and *b*. A zero argument is allowed:
    ``egcd(5, 0) == (5, 1, 0)``. *a* and *b* are intended to be
    non-negative; for negative arguments *g* may come out negative.

    Example::

        >>> egcd(880,560)
        (80, 2, -3)

    """
    # Invariants: old_r == a*old_x + b*old_y  and  r == a*x + b*y.
    old_r, r = a, b
    old_x, x = 1, 0
    old_y, y = 0, 1
    while r != 0:
        q = old_r // r
        old_r, r = r, old_r - q * r
        old_x, x = x, old_x - q * x
        old_y, y = y, old_y - q * y
    return old_r, old_x, old_y
   
def modinv(a, n):
    r"""
    Computes the multiplicative inverse of *a* modulo *n*. For the
    multiplicative inverse to exist, we must have
    :math:`\text{gcd}(a,n)=1`

    Outputs:
        x integer such that :math:`ax \mod n = 1`; for positive *n* it is
        reduced to the range :math:`0 \leq x < n`.

    Raises:
        ValueError if :math:`\text{gcd}(a,n) \neq 1` (no inverse exists).

    Example::

        >>> modinv(17,880)
        673

    """
    g, x, _ = egcd(a, n)
    if g != 1:
        raise ValueError("gcd(a,n) must be 1 to guarantee existence of multiplicative inverse of a")
    # The Bezout coefficient x from egcd may be negative; x % n normalises it.
    return x % n


def rsa_check(p,q,N,N2,e,d,ndigits):
    """
    Do some checks on some RSA numbers. The numbers p,q,N,N2,e,d could come
    from :py:func:`rsa_gen`

    Checks performed (each failing one adds a line to the report):

    * p and q pass :py:func:`is_probable_prime`;
    * p, q and e have exactly ``ndigits + 1`` decimal digits;
    * N == p*q and N2 == (p-1)*(q-1);
    * e <= N2, and both e and d are coprime to N2;
    * e*d == 1 (mod N2);
    * 9 random messages x drawn from [10**ndigits, 10**(ndigits+1)] survive
      encryption ``pow(x, e, N)`` followed by decryption ``pow(m, d, N)``.

    Outputs:
        A pair *(status, msg)*: *status* is True when every check passed and
        *msg* is a human-readable report.

    Example::

        >>> status,msg = rsa_check(23,41,943,880,17,503,1)
        >>> print str(status)+" "+msg
        False The following diagnostics failed:
          e and d are not multiplicative inverses of each other
          problem in enc/dec
        >>> status,msg = rsa_check(23,41,943,880,17,673,1)
        >>> print str(status)+" "+msg
        True All tests passed!
          N  has 3 digits
          N2 has 3 digits

    """
    # fractions.gcd was removed in Python 3.9; math.gcd exists since 3.5.
    try:
        from math import gcd
    except ImportError:  # Python 2
        from fractions import gcd

    # Every condition is evaluated, in this order, exactly as before.
    checks = [
        (not is_probable_prime(p),     '  p does not appear to be prime\n'),
        (not is_probable_prime(q),     '  q does not appear to be prime\n'),
        (len(str(p)) != ndigits + 1,   '  p does not have correct number of digits\n'),
        (len(str(q)) != ndigits + 1,   '  q does not have correct number of digits\n'),
        (len(str(e)) != ndigits + 1,   '  e does not have correct number of digits\n'),
        (N != p * q,                   '  N must be p*q\n'),
        (N2 != (p - 1) * (q - 1),      '  N2 must be (p-1)*(q-1)\n'),
        (not e <= N2,                  '  e must be smaller than N2\n'),
        (gcd(e, N2) != 1,              '  e not relatively prime to N2\n'),
        (gcd(d, N2) != 1,              '  d not relatively prime to N2\n'),
        (e * d % N2 != 1,              '  e and d are not multiplicative inverses of each other\n'),
    ]
    problems = [message for failed, message in checks if failed]

    encdec = True
    for _ in range(1, 10):
        # pick a random message, encrypt it with e, decrypt it with d
        x = random.randint(10**ndigits, 10**(ndigits+1))
        m = pow(x, e, N)
        y = pow(m, d, N)
        encdec = encdec and x == y
    if not encdec:
        problems.append('  problem in enc/dec')

    if problems:
        return False, 'The following diagnostics failed:\n' + ''.join(problems)
    summary = ('All tests passed!\n'
               + '  N  has %d digits\n' % len(str(N))
               + '  N2 has %d digits\n' % len(str(N2)))
    return True, summary

# Davis encoding table: the character at index i is encoded as the number
# i + 10 (see davis_enc / davis_dec). The table holds 90 characters, so the
# codes span exactly the two-digit numbers 10..99.
_davis_table= " ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz.,:;'\"`!@#$%^&*-+()[]{}?/<>0123456789"

def davis_enc(m):
    """
    Davis table encoding (from letters to numbers).

    Each character of *m* that occurs in ``_davis_table`` at position *i*
    contributes the digits of *i* + 10; characters missing from the table
    are silently skipped. The concatenated digits are returned as one int.
    If nothing in *m* is encodable, ``int('')`` raises ValueError.

    Outputs:
       A large integer with an even number of digits

    see http://mathcircle.berkeley.edu/BMC3/crypto.pdf

    Example::

       >>> davis_enc("Hello!")
       184148485170

    """
    codes = []
    for ch in m:
        pos = _davis_table.find(ch)
        if pos < 0:
            continue  # not representable in the table: drop it
        codes.append(str(pos + 10))
    return int(''.join(codes))

def davis_dec(x):
    """
    Davis table decoding (from numbers to letters); the inverse of
    :py:func:`davis_enc`.

    The decimal digits of *x* are read two at a time, and each pair *n* is
    mapped to ``_davis_table[n-10]``.

    Outputs:
        The message (str) encoded in x

    Raises:
        AssertionError if *x* has an odd number of digits or a digit pair is
        outside 10..99. NOTE(review): these are ``assert`` checks, so they
        are skipped under ``python -O``.

    see http://mathcircle.berkeley.edu/BMC3/crypto.pdf

    Example::

        >>> davis_dec(184148485170)
        'Hello!'

    """
    s = str(x)
    assert len(s) % 2 == 0, 'number must have even number of digits!'
    chars = []
    # Integer division: len(s)/2 is a float on Python 3 and breaks range().
    for i in range(len(s) // 2):
        n = int(s[2*i:2*(i+1)])
        assert 10 <= n <= 99, 'n must be between 10 and 99'
        chars.append(_davis_table[n - 10])
    return ''.join(chars)

def chop_message(m,n):
    """
    Split *m* into consecutive packets of *n* items (the last packet may be
    shorter) and return them as a list.

    Any sequence supporting ``len`` and slicing works (str, list, ...). An
    empty *m* gives an empty list; *n* is expected to be a positive integer.

    Example::

        >>> chop_message("This is a not so long message!",10)
        ['This is a ', 'not so lon', 'g message!']

    """
    starts = range(0, len(m), n)
    return [m[start:start + n] for start in starts]

def encode_message(m,n):
    """
    Encode message *m* with packets of at most 2 *n* digits integers (using
    Davis table, see :py:func:`davis_enc`)

    The message is cut into chunks of at most *n* characters with
    :py:func:`chop_message`, and each chunk is encoded separately. A list is
    returned on both Python 2 and Python 3; on Python 3, ``map`` would have
    produced a lazy iterator instead.

    Example::

        >>> encode_message("This is a not so long message!",10)
        [30444555104555103710L, 50515610555110485150L, 43104941555537434170L]
    """
    return [davis_enc(packet) for packet in chop_message(m, n)]

def decode_message(em):
    """
    Decode a list of integer packets into a message (using Davis table, see
    :py:func:`davis_dec`)

    Each packet is decoded on its own, and the decoded pieces are joined with
    a newline between consecutive packets.

    Example::

        >>> print decode_message([30444555104555103710L, 50515610555110485150L, 43104941555537434170L])
        This is a 
        not so lon
        g message!
    """
    pieces = [davis_dec(packet) for packet in em]
    return "\n".join(pieces)


def chop_integer(m,n):
    """
    Chop a large integer *m* into integer packets, where each packet has
    at most *n* digits

    The decimal representation of *m* is cut into chunks of *n* characters
    (the last one may be shorter) with :py:func:`chop_message`, and each
    chunk is converted back to an int. A list is returned on both Python 2
    and Python 3; on Python 3, ``map`` would have produced a lazy iterator.

    NOTE(review): a chunk starting with '0' loses its leading zeros when it
    is converted to int, so the split is not always reversible, e.g.
    ``chop_integer(10001, 2) == [10, 0, 1]``.

    Example::
    
        >>> chop_integer(123451234512345,5)
        [12345, 12345, 12345]

    """
    return [int(chunk) for chunk in chop_message(str(m), n)]

def join_integer(l):
    """
    Concatenate the decimal representations of the integers in *l* and
    return the resulting digit string as a single integer.

    Example::

        >>> join_integer([1234567,8910111213,141516171819])
        12345678910111213141516171819L

    """
    digits = "".join(str(part) for part in l)
    return int(digits)
