#include <stdio.h>
#include <stdlib.h>
#include <math.h>
#include <time.h>

// Witnesses (bases) for the Miller-Rabin primality test: the first nine primes.
// is_prime also uses them for trial division before the strong-probable-prime rounds.
const int a[9] = {2, 3, 5, 7, 11, 13, 17, 19, 23};

// Compute (x * y) mod p without overflowing long long, using
// double-and-add multiplication.
// Requires p > 0. Negative x or y are first reduced into [0, p), so the
// result is always in [0, p).
long long int mod_mul(long long int x, long long int y,
                      long long int p) {
    long long int res = 0;

    x %= p;
    y %= p;
    // C's % keeps the sign of the dividend; shift into [0, p).
    if (x < 0) x += p;
    if (y < 0) y += p;

    while (y > 0) {
        if (y & 1) {
            // res + x can exceed LLONG_MAX when p > LLONG_MAX / 2, so
            // compare against p - x instead of forming the sum first.
            if (res >= p - x) res -= p - x;
            else res += x;
        }
        // Same overflow-free trick for doubling x (x << 1 could overflow).
        if (x >= p - x) x -= p - x;
        else x += x;
        y >>= 1;
    }
    return res;
}

// Compute x^y mod p by binary (square-and-multiply) exponentiation.
// Requires p > 0. A non-positive y leaves the loop unexecuted, giving 1 mod p.
long long int mod_pow(long long int x, long long int y,
                      long long int p) {
    // Reduce the initial 1 so that x^0 mod 1 yields 0, not 1.
    long long int res = 1 % p;

    while (y > 0) {
        if (y & 1) res = mod_mul(res, x, p);
        x = mod_mul(x, x, p);
        y >>= 1;
    }
    return res;
}

// Determine whether n is prime using the Miller-Rabin test with the nine
// prime witnesses in `a`. Returns 1 if n is prime, 0 otherwise.
// With these bases the test is deterministic for n < 3,825,123,056,546,413,051;
// beyond that bound a composite could in principle pass every base.
int is_prime(long long int n) {
    const int nwit = (int)(sizeof a / sizeof a[0]);
    int i, j, s;
    long long int d, x;

    if (n < 2) return 0;
    // Trial division by the witnesses: settles small n and guarantees that
    // every witness is coprime to n in the rounds below.
    for (i = 0; i < nwit; ++i) {
        if (n % a[i] == 0) return n == a[i];
    }

    // Write n - 1 = 2^s * d with d odd.
    // Parentheses are required: == binds tighter than &.
    d = n - 1;
    s = 0;
    while ((d & 1) == 0) {
        d >>= 1;
        s++;
    }

    for (i = 0; i < nwit; ++i) {
        x = mod_pow(a[i], d, n);
        if (x == 1 || x == n - 1) continue;   // a^d == +-1: this base passes
        // Square up to s-1 times, looking for -1 (mod n).
        for (j = 1; j < s; ++j) {
            x = mod_mul(x, x, n);
            if (x == n - 1) break;
            if (x == 1) return 0;   // non-trivial square root of 1 => composite
        }
        if (x != n - 1) return 0;   // -1 never reached => composite
    }
    return 1;
}

// Print p as a sum of two squares, "p = a^2  + b^2".
// p is expected to be a prime with p % 4 == 1, and x a quadratic non-residue
// mod p. Then u = x^((p-1)/4) satisfies u^2 == -1 (mod p).
// Run the Euclidean algorithm on (p, u), counting u itself as the first term.
// The first term below sqrt(p) is a, and the term after it is b
// (Hermite-Serret / Brillhart). Invalid input is reported on stderr and nothing
// is printed to stdout.
void sqr_sum(int p, int x){
    long long u;
    int r0, r1, r2;

    if (p < 5 || p % 4 != 1) {
        fprintf(stderr, "sqr_sum: %d is not of the form 4k+1 (>= 5)\n", p);
        return;
    }

    u = mod_pow(x, (p - 1) / 4, p) % p;
    if (u < 0) u += p;   // normalize into [0, p)
    // Without u^2 == -1 the Euclidean steps below could divide by zero.
    if ((u * u + 1) % p != 0) {
        fprintf(stderr, "sqr_sum: %d is not a quadratic non-residue mod %d\n", x, p);
        return;
    }
    if (u > p / 2) u = p - u;   // use the root in (0, p/2)

    // Advance until r1 < sqrt(p). u itself may already qualify (e.g. p = u^2 + 1).
    // The integer test r1*r1 >= p avoids floating-point sqrt.
    r0 = p;
    r1 = (int)u;
    while ((long long)r1 * r1 >= p) {
        r2 = r0 % r1;
        r0 = r1;
        r1 = r2;
    }
    printf("%d = %d^2  + %d^2\n", p, r1, r0 % r1);
}

// Return a quadratic non-residue modulo n, or 0 if the search finds none.
// The fast paths pick 2, 3 or 5 via quadratic reciprocity. This is valid for
// primes n with n % 4 == 1, the case main passes in. Otherwise it prints "* "
// and scans cand = 1 .. n-1, returning the first cand whose Euler-criterion
// power cand^((n-1)/2) mod n is not 1.
int Quadratic_nonresidue(int n){
    int cand;

    if (n % 8 == 5)
        return 2;
    if (n % 3 == 2)
        return 3;
    switch (n % 5) {
    case 2:
    case 3:
        return 5;
    default:
        break;
    }

    printf("* ");   // emitted before falling back to the exhaustive search
    for (cand = 1; cand < n; ++cand) {
        if (mod_pow(cand, (n - 1) / 2, n) != 1)
            return cand;
    }
    return 0;
}

/* Upper limit on the number of experiments (loop iterations) run by main */
#define MAX_NUM 20

// Demo driver. Makes MAX_NUM attempts; each one draws random numbers from
// rand() until it gets a prime n with n % 4 == 1, then prints n as a sum of
// two squares. An attempt is skipped if no quadratic non-residue is found.
int main(int argc, char *argv[]) {
    long int ct = 0, n, x;

    (void)argc;   // command-line arguments are not used
    (void)argv;
    srand((unsigned)time(NULL));

    while (ct++ < MAX_NUM) {
        // Draw until n is a prime of the form 4k+1. The cheap residue test
        // runs first so is_prime only sees candidates that can qualify.
        do {
            n = rand();
        } while (n % 4 != 1 || !is_prime(n));

        x = Quadratic_nonresidue(n);
        if (x == 0) continue;
        sqr_sum(n, x);
    }
    return 0;
}
