# Fernando Guevara Vasquez, 2013
import random,fractions

# Number of random bases tried by is_probable_prime (one Miller-Rabin round
# per base). More rounds make it less likely that a composite number is
# reported as prime, at the cost of extra modular exponentiations.
_mrpt_num_trials = 5 # number of bases to test

def is_probable_prime(n):
    """
    Miller-Rabin primality test.

    Returns *False* when *n* is certainly composite and *True* when *n* is
    very likely prime. ``_mrpt_num_trials`` random bases are tried.
    Requires n >= 2 (checked with ``assert``).

    Adapted (with corrections) from:
    http://rosettacode.org/wiki/Miller-Rabin_primality_test#Python
    """
    assert n >= 2
    # 2 is the only even prime
    if n == 2:
        return True
    if n % 2 == 0:
        return False

    # factor out the powers of two: n - 1 == 2**s * d with d odd
    d = n - 1
    s = 0
    while d % 2 == 0:
        d >>= 1
        s += 1
    assert 2 ** s * d == n - 1

    for _ in range(_mrpt_num_trials):
        base = random.randint(2, n - 1)
        y = pow(base, d, n)
        if y == 1 or y == n - 1:
            continue  # this base proves nothing; try another one
        # square up to s-1 more times, looking for n-1 (i.e. -1 mod n)
        for _ in range(s - 1):
            y = pow(y, 2, n)
            if y == n - 1:
                break
            if y == 1:
                return False  # nontrivial square root of 1: composite
        else:
            return False  # never reached n-1: composite
    return True  # probable prime

def randprime(ndigits, maxtries=500):
    """
    Generate a random (probable) prime with *ndigits*+1 decimal digits.

    Despite the parameter name, the prime p satisfies
    10**ndigits <= p < 10**(ndigits+1), i.e. it has ndigits+1 digits; this is
    the convention :py:func:`rsa_check` checks against.

    Inputs:
       *ndigits*  non-negative integer; the prime has ndigits+1 digits
       *maxtries* number of random candidates tested before giving up

    Outputs:
       a probable prime (see :py:func:`is_probable_prime`). Only odd
       candidates are drawn, so for ndigits == 0 the result is 3, 5 or 7.

    Raises:
       AssertionError('no prime found!') if none of the *maxtries*
       candidates is a probable prime
    """
    lo = 10**ndigits
    hi = 10**(ndigits+1)
    # Even numbers (and 1) can never be returned, so draw only odd candidates
    # >= 3: this doubles the hit rate and keeps is_probable_prime's n >= 2
    # precondition satisfied.
    start = max(lo, 3) | 1
    for _ in range(maxtries):
        n = random.randrange(start, hi, 2)
        if is_probable_prime(n):
            return n
    # explicit raise (not ``assert``) so the failure is not stripped by -O
    raise AssertionError('no prime found!')

def extended_gcd(a,b):
    """
    Carries out extended Euclid's algorithm on *a* and *b*.

    Outputs:
       *g*,*x*,*y* integers such that :math:`g = \\text{gcd}(a,b) \\geq 0`
       and :math:`ax+by = g`

    See http://rosettacode.org/wiki/Modular_inverse.
    """
    # Run Euclid on |a| and |b|; the signs are folded back into the
    # coefficients at the end.
    lastremainder, remainder = abs(a), abs(b)
    x, lastx, y, lasty = 0, 1, 1, 0
    while remainder:
        lastremainder, (quotient, remainder) = remainder, divmod(lastremainder, remainder)
        # Invariants: lastx*|a| + lasty*|b| == lastremainder
        #             x*|a|     + y*|b|     == remainder
        x, lastx = lastx - quotient*x, x
        y, lasty = lasty - quotient*y, y
    # Negating the coefficient of a negative input keeps a*x + b*y == g.
    return lastremainder, lastx * (-1 if a < 0 else 1), lasty * (-1 if b < 0 else 1)
 
def modinv(a, m):
    r"""
    Computes the inverse of *a* modulo *m*

    Outputs:
        x integer such that :math:`ax \mod m = 1`

    Raises:
        ValueError if gcd(a, m) != 1, i.e. no inverse exists

    from http://rosettacode.org/wiki/Modular_inverse
    """
    g, x, _ = extended_gcd(a, m)
    if g != 1:
        raise ValueError('%s has no inverse modulo %s (gcd is %s)' % (a, m, g))
    # the Bezout coefficient may be negative; for m > 0 this maps it to [0, m)
    return x % m


def rsa_check(p,q,N,N2,e,d,ndigits):
    """
    Do some checks on some RSA numbers. The numbers p,q,N,N2,e,d could come
    from :py:func:`rsa_gen`

    The checks are: p and q are (probable) primes; p, q and e have
    ndigits+1 digits; N == p*q; N2 == (p-1)*(q-1); e < N2; e and d are both
    relatively prime to N2; e*d == 1 (mod N2); and 9 random messages survive
    encryption with e followed by decryption with d.

    Outputs:
       *(ok, report)* where *ok* is True iff every check passed and *report*
       is a newline-terminated, human-readable summary
    """
    # fractions.gcd was removed in Python 3.9 and math.gcd does not exist on
    # Python 2: use whichever is available.
    try:
        from math import gcd
    except ImportError:
        from fractions import gcd

    # (passed, message) pairs, in the order they are reported
    checks = [
        (is_probable_prime(p),     'p does not appear to be prime'),
        (is_probable_prime(q),     'q does not appear to be prime'),
        (len(str(p)) == ndigits+1, 'p does not have correct number of digits'),
        (len(str(q)) == ndigits+1, 'q does not have correct number of digits'),
        (len(str(e)) == ndigits+1, 'e does not have correct number of digits'),
        (N == p*q,                 'N must be p*q'),
        (N2 == (p-1)*(q-1),        'N2 must be (p-1)*(q-1)'),
        (e < N2,                   'e must be smaller than N2'),
        (gcd(e, N2) == 1,          'e not relatively prime to N2'),
        (gcd(d, N2) == 1,          'd not relatively prime to N2'),
        (e*d % N2 == 1,            'e and d are not multiplicative inverses of each other'),
    ]

    # Round-trip random messages. Reducing modulo N keeps every message a
    # valid plaintext (0 <= x < N) even when N is small (e.g. ndigits == 0);
    # otherwise x >= N would be misreported as an enc/dec failure.
    encdec = True
    for _ in range(9):
        x = random.randint(10**ndigits, 10**(ndigits+1)) % N
        if pow(pow(x, e, N), d, N) != x:
            encdec = False
            break
    checks.append((encdec, 'problem in enc/dec'))

    failed = [msg for passed, msg in checks if not passed]
    if failed:
        # every diagnostic line, including the enc/dec one, ends with '\n'
        return False, 'The following diagnostics failed:\n' + ''.join('  %s\n' % msg for msg in failed)

    s = 'All tests passed!\n'
    s += '  N  has %d digits\n' % len(str(N))
    s += '  N2 has %d digits\n' % len(str(N2))
    return True, s

# Davis character table: the character at index i is encoded as the two-digit
# number i+10 (see davis_enc / davis_dec). The table has 90 entries, so the
# codes run from 10 (space) to 99 (the digit '9').
_davis_table= " ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz.,:;'\"`!@#$%^&*-+()[]{}?/<>0123456789"

def davis_enc(m):
    """
    Davis table encoding (from letters to numbers)

    Each character of *m* found in the Davis table is replaced by its
    two-digit code (table index + 10, i.e. 10..99) and the codes are
    concatenated. Characters not in the table (e.g. newlines) are skipped.

    Outputs:
       A large integer with an even number of digits

    Raises:
       ValueError if *m* contains no character from the Davis table

    see http://mathcircle.berkeley.edu/BMC3/crypto.pdf
    """
    codes = []
    for ch in m:
        i = _davis_table.find(ch)
        if i >= 0:
            codes.append(str(i + 10))
    if not codes:
        raise ValueError('message has no characters from the Davis table')
    return int(''.join(codes))

def davis_dec(x):
    """
    Davis table decoding (from numbers to letters)

    *x* is read as a string of two-digit codes, each between 10 and 99; code
    c decodes to the Davis table character at index c-10.

    Outputs:
        The message (str) encoded in x

    Raises:
        AssertionError if x has an odd number of digits or contains a code
        below 10

    see http://mathcircle.berkeley.edu/BMC3/crypto.pdf
    """
    s = str(x)
    assert len(s) % 2 == 0, 'number must have even number of digits!'
    chars = []
    # floor division: on Python 3, len(s)/2 is a float and range() rejects it
    for i in range(len(s) // 2):
        n = int(s[2*i:2*(i+1)])
        assert 10 <= n <= 99, 'n must be between 10 and 99'
        chars.append(_davis_table[n-10])
    return ''.join(chars)

def chop_message(m,n):
    """
    Split message *m* into consecutive packets of at most *n* characters.

    Every packet except possibly the last has exactly *n* characters; an
    empty message gives an empty list.
    """
    packets = []
    for start in range(0, len(m), n):
        packets.append(m[start:start + n])
    return packets

def encode_message(m,n):
    """
    Encode message *m* as a list of integers, one per packet of at most *n*
    characters (so each integer has at most *2n* digits), using the Davis
    table (see :py:func:`davis_enc`).
    """
    # Build a real list: on Python 3 ``map`` returns a one-shot iterator,
    # which breaks callers that index it, take len() or iterate twice.
    return [davis_enc(packet) for packet in chop_message(m, n)]

def decode_message(em, sep="\n"):
    """
    Decode a list of integer packets into a message (using Davis table, see :py:func:`davis_dec`)

    Inputs:
        *em*  iterable of integers, e.g. as returned by :py:func:`encode_message`
        *sep* string inserted between decoded packets. The default ``"\\n"``
              puts each packet on its own line; pass ``sep=""`` to get back
              the message given to :py:func:`encode_message` (minus any
              characters that are not in the Davis table).

    Outputs:
        The decoded message (str)
    """
    return sep.join(davis_dec(packet) for packet in em)


