# Exponentiation (jl)
# Fast powering, gcd / extended gcd, Jacobi symbol, primality tests, Z_p arithmetic and CRT.


"""
    msb(x)

Return the number of binary digits of `x`, i.e. the 1-based position of its most
significant set bit, by shifting right until nothing is left. Returns `0` for
`x ≤ 0`.
"""
function msb(x)
    nbits = 0
    rest = copy(x)
    while rest > 0
        rest >>= 1
        nbits += 1
    end
    return nbits
end

"""
    bad_pow(x, n)

Naive power: `x^n` as `n` nested multiplications `x * (x * (... * oneunit(x)))`.
Negative `n` is evaluated as `inv(x)^(-n)`. Uses O(|n|) multiplications and
recursion depth.
"""
function bad_pow(x, n)
    if n == 0
        return oneunit(x)
    elseif n < 0
        return bad_pow(inv(x), -n)
    else
        return x * bad_pow(x, n - 1)
    end
end

"""
    better_pow(x, n)

Compute `x^n` by left-to-right binary exponentiation (square-and-multiply), using
O(log |n|) multiplications. `n == 0` gives `oneunit(x)`; a negative `n` is
evaluated as `inv(x)^(-n)`.
"""
function better_pow(x, n)
    c = oneunit(x)
    n == 0 && return c
    n < 0 && return better_pow(inv(x), -n)
    # Bits of n are numbered 0:ndigits-1; start at the top set bit (starting one
    # higher only squares the unit). Square each step, multiply by x on a 1 bit.
    for i in ndigits(n; base=2)-1:-1:0
        c *= c
        isodd(n >> i) && (c *= x)
    end
    return c
end

"""
    my_gcd(x, y)

Greatest common divisor by Euclid's algorithm, printing every `(x, y)` pair that
is reduced (via `@show`). Arguments are taken by absolute value, so the result is
non-negative and `my_gcd(0, 0) == 0`.
"""
function my_gcd(x, y)
    # Negative inputs made `mod` (sign of the divisor) cycle without terminating.
    x, y = abs(x), abs(y)
    # `>=`, not `>`: with `>` equal arguments were swapped forever.
    x >= y || return my_gcd(y, x)
    y == 0 && return x
    @show (x, y)
    return my_gcd(y, mod(x, y))
end

"""
    my_xgcd(x, y)

Extended Euclidean algorithm. Returns a 3-element vector `[s, t, g]` with
`s*x + t*y == g`, where `g` is the last nonzero remainder (a gcd of `x` and `y`;
its sign follows `÷` truncation, so it can be negative for negative inputs).
Each intermediate pair is printed via `@show U, V`.
"""
function my_xgcd(x, y)
    # Invariant: U[1]*x + U[2]*y == U[3] and V[1]*x + V[2]*y == V[3].
    U = [oneunit(x), zero(x), x]
    V = [zero(y), oneunit(y), y]
    while !(V[3] == zero(x))
        quotient = div(U[3], V[3])
        remainder_row = U - V * quotient
        U = V
        V = remainder_row
        @show U, V
    end
    return U
end

"""
    Jacobi(a, m)

Jacobi symbol `(a/m)` for odd positive `m`, returned as `-1`, `0` or `1`.
`a` may be any integer, including a negative one, because it is first reduced
with `mod`. Results for even or non-positive `m` are not meaningful.
"""
function Jacobi(a, m)
    t = 1
    # mod, not %: a negative remainder previously broke the reduction.
    a = mod(a, m)
    while !iszero(a)
        # Remove factors of 2 using (2/m) = -1 exactly when m ≡ 3 or 5 (mod 8).
        while iseven(a)
            a ÷= 2
            m % 8 in (3, 5) && (t = -t)
        end
        # Quadratic reciprocity: the sign flips when both are ≡ 3 (mod 4).
        a, m = m, a
        a % 4 == 3 && m % 4 == 3 && (t = -t)
        a = mod(a, m)
    end
    # The loop ends with m == gcd of the inputs; a common factor gives 0.
    return isone(m) ? t : 0
end

"""
    SPprime(n, a)

Strong probable-prime (Miller–Rabin) test of `n` to base `a`. Writes
`n - 1 = 2^s * d` with `d` odd and returns `true` if `a^d ≡ 1 (mod n)` or
`a^(2^j * d) ≡ -1 (mod n)` for some `0 ≤ j < s`, otherwise `false`.
Returns `false` for `n < 2`.
"""
function SPprime(n, a)
    # For n == 1, n - 1 == 0 has no odd part and the halving loop never ended.
    n < 2 && return false
    s = trailing_zeros(n - 1)
    d = (n - 1) >> s
    b = powermod(a, d, n)
    (b == 1 || b == n - 1) && return true
    for _ in 1:s-1
        b = powermod(b, 2, n)
        b == n - 1 && return true
    end
    return false
end

"""
    SolvayStrassen(n)

Deterministic Solovay–Strassen-style primality check. `n` passes if Euler's
criterion `a^((n-1)/2) ≡ (a/n) (mod n)` holds for every base `a` in
`2:(n+3)÷2`, where `(a/n)` is computed by `Jacobi`. On the first failing base it
prints `Witness=<a>` and returns `false`. Integers below 2 are not prime.
"""
function SolvayStrassen(n)
    n < 2 && return false
    n == 2 && return true
    n % 2 == 0 && return false
    for a in 2:(n + 3) ÷ 2
        # mod maps the symbol -1/0/1 to n-1/0/1, the representative powermod returns.
        if powermod(a, (n - 1) ÷ 2, n) != mod(Jacobi(a, n), n)
            print("Witness=", a, "\n")
            return false
        end
    end
    return true
end

"""
    Miller(n)

Deterministic Miller test: `n` passes when `SPprime(n, a)` holds for every base
`a` in `2:W`, with `W = min(round(2*log(n)^2), n - 1)`. Integers below 2 are not
prime. NOTE: the `2 ln(n)^2` bound looks like Bach's GRH-conditional bound, so
correctness presumably relies on GRH.
"""
function Miller(n)
    # 0 and 1 gave an empty base range and were reported prime.
    n < 2 && return false
    W = Integer(min(round(2 * log(n)^2), n - 1))
    return all(a -> SPprime(n, a), 2:W)
end

"""
    AKS(n)

Agrawal–Kayal–Saxena primality test: return `true` if `n` is prime and `false`
otherwise (including every `n < 2`).

Polynomials are coefficient vectors of length `r`. Entry `k + 1` holds the
coefficient of `X^k`, and all products are taken in `Z_n[X]/(X^r - 1)`.
"""
function AKS(n)
    # Floor of the k-th root of m (m ≥ 1) by integer Newton iteration. Works in
    # BigInt so that neither the starting power of two nor x^(k-1) can overflow.
    function int_root(m, k)
        newton(z) = ((k - 1) * z + m ÷ z^(k - 1)) ÷ k
        x = big(1) << cld(ndigits(m; base=2), k)    # 2^ceil(bits/k) ≥ m^(1/k)
        y = newton(x)
        while y < x
            x = y
            y = newton(x)
        end
        return x
    end

    # Is m an exact i-th power?
    ispower(m, i) = int_root(m, i)^i == m

    # Is m a perfect power b^i for some i ≥ 2?
    ispower(m) = any(i -> ispower(m, i), 2:ceil(Int, log2(m)))

    # Smallest r above K = floor(log2(m)^2) with m^e mod r != 1 for every e in 1:K,
    # i.e. ord_r(m) > log2(m)^2. An r sharing a factor with m is also accepted.
    function findOrder(m)
        K = floor(Int, log2(m)^2)
        uval = floor(Int, max(3, log2(m)^5))
        for r in K+1:uval
            all(e -> powermod(m, e, r) != 1, 1:K) && return r
        end
        return nothing
    end

    # Product of A and B in Z_n[X]/(X^r - 1), r = length(B): the cyclic
    # convolution of the coefficient vectors, each entry reduced into 0:n-1.
    function ⋆(A::Vector{T}, B::Vector{T}) where {T<:Integer}
        r = length(B)
        C = zeros(T, r)
        for i in 0:r-1
            iszero(A[i+1]) && continue
            for j in 0:r-1
                k = (i + j) % r + 1
                # widemul keeps the product exact before reducing mod n.
                C[k] = mod(C[k] + widemul(A[i+1], B[j+1]), n)
            end
        end
        return C
    end

    # (X + a)^m in Z_n[X]/(X^r - 1) by left-to-right square-and-multiply.
    # Vectors have length r (not r + 1), so reduction is modulo X^r - 1.
    function Ξ(a::Integer, m::Integer, r::Integer)
        Xpol = zeros(typeof(n), r)     # X + a
        Xpol[1] = a
        Xpol[2] = 1
        Rpol = zeros(typeof(n), r)     # running power, starts at 1
        Rpol[1] = 1
        for i in ndigits(m; base=2)-1:-1:0
            Rpol = Rpol ⋆ Rpol
            isodd(m >> i) && (Rpol = Rpol ⋆ Xpol)
        end
        return Rpol
    end

    # 0, 1 and negative numbers are not prime (AKS(1) used to return true).
    n < 2 && return false
    # Step 1: a perfect power is composite.
    ispower(n) && return false
    # Step 2: choose r with ord_r(n) > log2(n)^2.
    r = findOrder(n)
    # Step 3: a nontrivial common factor with some a ≤ r proves n composite.
    for a in 2:r
        1 < gcd(a, n) < n && return false
    end
    # Step 4
    n ≤ r && return true
    # Step 5: require (X + a)^n ≡ X^(n mod r) + a in Z_n[X]/(X^r - 1).
    for a in 1:ceil(Int, √r) * ceil(Int, log2(n))
        P = Ξ(a, n, r)
        P[1] -= a
        P[n % r + 1] -= 1
        all(c -> iszero(mod(c, n)), P) || return false
    end
    # Step 6
    return true
end
        

# Primality predicate used by NextPrime. It is a plain (non-const) global, so it
# can be rebound to another test, e.g. `PrimeP = AKS`.
PrimeP = Miller

"""
    NextPrime(n)

Smallest integer `≥ n` accepted by the primality predicate bound to `PrimeP`,
checking only odd candidates. Returns `2` for every `n ≤ 2`.
"""
function NextPrime(n)
    n <= 2 && return 2
    candidate = n % 2 == 0 ? n + 1 : n
    while !PrimeP(candidate)
        candidate += 2
    end
    return candidate
end


### Define Zp: arithmetic in the integers modulo p (ZpElement)

import Base.+, Base.-, Base.*,Base.^,Base.inv,Base.==,Base.oneunit,Base.zero

"""
    ZpElement(value, prime)

Element of the integers modulo `prime` (presumably prime; this is not checked).
The constructor does not reduce `value` modulo `prime`.
"""
struct ZpElement{T<:Integer} <: Number
    value::T
    prime::T
end

# Concrete field types avoid boxing. Mixed arguments (e.g. a BigInt value with an
# Int modulus) are promoted to a common integer type; other numbers go through
# `convert(Integer, ·)`.
ZpElement(value, prime) =
    ZpElement(promote(convert(Integer, value), convert(Integer, prime))...)

# Multiplicative identity of Z_p, using the modulus of `x`.
oneunit(x::ZpElement) = ZpElement(1, x.prime)

# Additive identity of Z_p, using the modulus of `x`.
zero(x::ZpElement) = ZpElement(0, x.prime)

# Sum reduced modulo the left operand's modulus; `y.prime` is not checked.
function +(x::ZpElement, y::ZpElement)
    p = x.prime
    return ZpElement(mod(x.value + y.value, p), p)
end

# Integer multiple of a Zp element, reduced modulo `y.prime`. Any `Integer`
# scalar is accepted (not only BigInt), so `-1 * y` and `3 * y` work.
function *(t::Integer, y::ZpElement)
    return ZpElement(mod(t * y.value, y.prime), y.prime)
end

# Additive inverse modulo `y.prime`. Computed directly: the former `-1*y` needed
# an Int-scalar `*` method that did not exist.
function -(y::ZpElement)
    return ZpElement(mod(-y.value, y.prime), y.prime)
end


# Difference reduced modulo the left operand's modulus; `y.prime` is not checked.
# Computed directly rather than through unary minus.
function -(x::ZpElement, y::ZpElement)
    return ZpElement(mod(x.value - y.value, x.prime), x.prime)
end

# Two Zp elements are equal when both their stored values and moduli are equal.
function ==(x::ZpElement, y::ZpElement)
    return (x.value == y.value) & (x.prime == y.prime)
end

# `==` is overloaded, so `hash` must agree with it for Dict/Set use. The default
# identity-based hash differs for equal BigInt fields; hash the field values instead.
Base.hash(x::ZpElement, h::UInt) = hash(x.prime, hash(x.value, hash(:ZpElement, h)))

# Product reduced modulo the left operand's modulus; `y.prime` is not checked.
*(x::ZpElement, y::ZpElement) = ZpElement(mod(x.value * y.value, x.prime), x.prime)


# Multiplicative inverse modulo `x.prime` via the extended Euclidean algorithm
# (`my_xgcd`). Throws a DomainError when `x.value` shares a factor with the
# modulus, instead of silently returning a non-inverse.
function inv(x::ZpElement)
    s, _, g = my_xgcd(x.value, x.prime)    # s*value + t*prime == g
    isone(abs(g)) || throw(DomainError(x.value, "not invertible modulo $(x.prime)"))
    # g can be -1 for negative values; then s*g (== -s) is the inverse.
    return ZpElement(mod(s * g, x.prime), x.prime)
end

# Integer power of a Zp element by square-and-multiply (`better_pow`). Accepts
# any Integer exponent, so `x^0` and `x^-1` no longer fall back to Base's generic
# method, which needs `one(ZpElement)`.
function ^(x::ZpElement, i::Integer)
    return better_pow(x, i)
end

"""
    CRT(a, b)

Combine the congruences `x ≡ a[1] (mod a[2])` and `x ≡ b[1] (mod b[2])` into one,
returned as `[x, a[2]*b[2]]` with `x` in `0:a[2]*b[2]-1`. Throws
`AssertionError("Must be coprime")` unless `my_xgcd` reports a gcd of 1.
"""
function CRT(a::Vector{T}, b::Vector{T}) where {T<:Integer}
    ap, bp, r = my_xgcd(a[2], b[2])    # ap*a[2] + bp*b[2] == r
    # isone, not ===: for BigInt, === compares object identity and never matched.
    # An explicit throw is used because @assert may be compiled out.
    isone(r) || throw(AssertionError("Must be coprime"))
    M = a[2] * b[2]
    return [mod(a[1] * bp * b[2] + b[1] * ap * a[2], M), M]
end

"""
    CRT(x)

Fold a list of congruences `[residue, modulus]` with pairwise-coprime moduli into
a single one using the two-argument `CRT`. An empty list gives the identity
congruence `[0, 1]`, since every integer is ≡ 0 (mod 1).
"""
function CRT(x::Vector{Vector{T}}) where {T<:Integer}
    isempty(x) && return [zero(T), one(T)]
    return reduce(CRT, x)
end