#include <stdio.h>
#include <stdlib.h>
#include <math.h>
#include <time.h>

/* Witness bases for the Miller-Rabin test: the first nine primes.
 * is_prime() also uses them for trial division before the test proper. */
const int a[9] = {
    2, 3, 5, 7, 11,
    13, 17, 19, 23,
};

// calculate x * y module p for big number (avoid overflow)
long long int mod_mul(long long int x, long long int y, 
                      long long int p) {
    long long int res = 0;
    x %= p, y %= p;
    while (y > 0) {
        if(y & 1) {
            res += x;
            if(res >= p) res -= p;
        }
        x <<= 1, y >>= 1;
        if(x >= p) x -= p;
    }
    return res;
}

/*
 * Compute x^y mod p by binary (square-and-multiply) exponentiation.
 *
 * Precondition: p > 0. A non-positive exponent y leaves the loop unexecuted,
 * so the result is x^0 reduced mod p. For x >= 0 the result lies in [0, p).
 * All products go through mod_mul, so overflow safety is whatever mod_mul
 * guarantees for modulus p.
 */
long long int mod_pow(long long int x, long long int y,
                      long long int p) {
    /* Start from 1 % p, not 1: when p == 1 and y == 0 the loop never runs,
     * and the correct answer is 0, not 1. */
    long long int res = 1 % p;
    while (y > 0) {
        if (y & 1) res = mod_mul(res, x, p);
        x = mod_mul(x, x, p);
        y >>= 1;
    }
    return res;
}

/*
 * Return 1 if n is prime, 0 otherwise (every n < 2 is reported non-prime).
 *
 * Step 1: trial division by the bases in a[]. This settles every n that
 *         shares a factor with a base. Afterwards n >= 29, so every base is
 *         smaller than n.
 * Step 2: write n - 1 = 2^s * d with d odd, then run one Miller-Rabin round
 *         per base. Any base that proves compositeness makes us return 0.
 *
 * NOTE(review): with the nine bases 2..23 the test is usually cited as
 * deterministic for n < 3,825,123,056,546,413,051 (OEIS A014233). Above that
 * it is only probabilistic — confirm if such large n matter.
 */
int is_prime(long long int n) {
    const int nbases = (int)(sizeof a / sizeof a[0]);
    int i, j, s;
    long long int d, x;

    if (n < 2) return 0;
    for (i = 0; i < nbases; ++i) {
        if (n % a[i] == 0) return n == a[i];
    }

    /* n - 1 = 2^s * d with d odd. The parentheses matter: == binds tighter
     * than &, so `d & 1 == 0` would mean `d & 0` and never loop. */
    d = n - 1;
    s = 0;
    while ((d & 1) == 0) {
        d >>= 1;
        s++;
    }

    for (i = 0; i < nbases; ++i) {
        int pass = 0;

        x = mod_pow(a[i], d, n);
        if (x == 1 || x == n - 1) continue;
        for (j = 1; j < s; ++j) {
            x = mod_mul(x, x, n);
            /* Previous x was not +-1 but squares to 1: a nontrivial square
             * root of 1 exists, so n is composite. */
            if (x == 1) return 0;
            if (x == n - 1) {
                pass = 1;
                break;
            }
        }
        if (!pass) return 0; /* a[i] witnesses that n is composite */
    }
    return 1;
}

/* Upper limit on the number of experiments (random primes to decompose). */
#define MAX_NUM (20)

/*
 * Draw MAX_NUM random primes of the form 4k + 1 and print each one as a sum
 * of two squares, n = i^2 + p^2. Fermat's two-square theorem guarantees such
 * a decomposition exists for these primes.
 */
int main(int argc, char *argv[]) {
    long int ct = 0;

    (void)argc;
    (void)argv;
    srand((unsigned)time(NULL));

    while (ct++ < MAX_NUM) {
        long int n;

        /* Rejection sampling: keep drawing until n is prime and n % 4 == 1. */
        do {
            n = rand();
        } while (is_prime(n) == 0 || (n - 1) % 4 != 0);

        /* Try i = 1, 2, ... and check whether n - i^2 is a perfect square.
         * p is long (not int) so it matches the %ld conversion below. */
        long int sqrn = (long int) sqrt(n + 0.5);
        int found = 0;
        for (long int i = 1; i < sqrn; ++i) {
            long int x = n - i * i;
            long int p = (long int) sqrt(x + 0.5);
            if (p * p == x) {
                printf("%ld = %ld^2 + %ld^2\n", n, i, p);
                found = 1;
                break;
            }
        }
        if (!found) {
            /* Should be unreachable for a true 4k+1 prime; report instead of
             * silently printing nothing. */
            fprintf(stderr, "%ld: no two-square decomposition found\n", n);
        }
    }
    return 0;
}
