Skip to content
Snippets Groups Projects
nqueens.c 5.2 KiB
Newer Older
  • Learn to ignore specific revisions
  • gback's avatar
    gback committed
    /* NQueens solver.
     * Written by gback for CS 3214
     */
    
    spruett3's avatar
    spruett3 committed
    #include <errno.h>
    #include <getopt.h>
    #include <limits.h>
    #include <stdbool.h>
    #include <stdio.h>
    #include <stdlib.h>
    #include <string.h>
    #include "threadpool.h"
    #include "threadpool_lib.h"
    
    /* Largest supported board dimension; main() rejects any larger N. */
    #define MAX_N (18)

    /* Number of bits in one word of struct board's bit array. */
    #define WORD_BITS (sizeof(long) * CHAR_BIT)

    /* Words needed to store one bit per square of a MAX_N x MAX_N board.
     * Squares are indexed flat as x * N + y, so the highest bit ever touched
     * is MAX_N * MAX_N - 1; rounding up the total bit count (rather than
     * per row) is sufficient and keeps each board copy small. */
    #define MAX_LONGS ((MAX_N * MAX_N + WORD_BITS - 1) / WORD_BITS)
    
    spruett3's avatar
    spruett3 committed
    
    /* Nodes whose row index is below this depth fan their children out to
     * the thread pool; deeper nodes are searched sequentially. Set by -d. */
    static int max_parallel_depth = 6;

    /* Known N-queens solution counts for N = 0 .. 18 (OEIS A000170).
     * N = 0 has exactly one solution, the empty placement, which is also
     * what backtrack() reports for an empty 0 x 0 board. */
    static const int valid_solutions[] = {1, 1, 0, 0, 2, 10, 4, 40, 92, 352, 724, 2680, 14200,
                                          73712, 365596, 2279184, 14772512, 95815104, 666090624};
    
    /* Occupancy bitmap: bit (x * N + y) is set when square (x, y) holds a
     * queen. The words are unsigned so that setting the top bit of a word
     * (1UL << (WORD_BITS - 1)) stores a well-defined value instead of going
     * through an implementation-defined conversion to signed long. */
    struct board {
        unsigned long bits[MAX_LONGS];
    };
    
    /* One node of the search: a partial placement plus its progress. */
    struct board_state {
        struct board board;  /* queens placed so far, one per row in [0, row) */
        int N;               /* board dimension */
        int row;             /* next row to fill; row == N means all filled */
    };
    
    static bool is_queen(struct board* board, int x, int y, int N) {
        if (x < 0 || x >= N || y < 0 || y >= N) {
            return false;
        }
        long long idx = x * N + y;
    
        return (board->bits[idx / WORD_BITS] & (1UL << (idx % WORD_BITS))) ==
            (1UL << (idx % WORD_BITS));
    
    spruett3's avatar
    spruett3 committed
    }
    
    /* Place a queen on square (x, y) of an N x N board by setting its bit.
     * Performs no bounds checking; (x, y) is expected to lie on the board. */
    static void set_queen(struct board* board, int x, int y, int N) {
        int bit = x * N + y;
        size_t word = bit / WORD_BITS;
        board->bits[word] |= 1UL << (bit % WORD_BITS);
    }
    /* Remove any queen from square (x, y) of an N x N board by clearing its
     * bit. Performs no bounds checking; (x, y) is expected to lie on the board. */
    static void unset_queen(struct board* board, int x, int y, int N) {
        int bit = x * N + y;
        size_t word = bit / WORD_BITS;
        board->bits[word] &= ~(1UL << (bit % WORD_BITS));
    }
    
    /* Check a (possibly partial) placement on an N x N board.
     * Returns the number of queens found, or -1 as soon as some queen sees
     * another one. Each queen only looks forward: down its column, right
     * along its row, and down both diagonals. That is enough because the
     * earlier queen of any attacking pair always finds the later one. */
    static int solved(struct board* board, int N) {
        int count = 0;
        for (int row = 0; row < N; row++) {
            for (int col = 0; col < N; col++) {
                if (!is_queen(board, row, col, N)) {
                    continue;
                }
                count++;
                for (int d = 1; d < N; d++) {
                    bool attacked = is_queen(board, row + d, col, N)
                        || is_queen(board, row, col + d, N)
                        || is_queen(board, row + d, col + d, N)
                        || is_queen(board, row + d, col - d, N);
                    if (attacked) {
                        return -1;
                    }
                }
            }
        }
        return count;
    }
    
    
    /* Thread-pool task that counts the ways to finish a partial placement.
     *
     * _state is a struct board_state whose rows [0, row) each hold one queen.
     * Returns the number of conflict-free completions, carried in the void *
     * result (callers cast it back to long).
     *
     * While row < max_parallel_depth the node fans out: one child state per
     * column of `row`. The first N - 1 children are submitted to the pool as
     * futures and the last is evaluated on the calling thread. Deeper nodes
     * are searched sequentially by placing and removing queens in *_state
     * itself, which is restored before returning.
     */
    static void* backtrack(struct thread_pool* pool, void* _state) {
        struct board_state* state = _state;
        int N = state->N;

        if (state->row == N) {
            /* Every row filled: count it only if no two queens attack. */
            return (void*)(long)(solved(&state->board, N) == N);
        }
        if (solved(&state->board, N) == -1) {
            /* Prune: the partial board already contains an attacking pair. */
            return (void*)0;
        }

        long slns = 0;
        if (state->row < max_parallel_depth) {
            struct board_state* children = calloc(N, sizeof *children);
            /* The last child runs inline, so only N - 1 futures are needed. */
            struct future** futures = calloc(N - 1, sizeof *futures);
            /* calloc(0, ...) may legitimately return NULL when N == 1. */
            if (children == NULL || (N > 1 && futures == NULL)) {
                fprintf(stderr, "backtrack: out of memory\n");
                abort();
            }
            for (int i = 0; i < N; i++) {
                children[i].board = state->board;
                children[i].N = N;
                children[i].row = state->row + 1;
                set_queen(&children[i].board, state->row, i, N);
                if (i < N - 1) {
                    futures[i] = thread_pool_submit(pool, backtrack, &children[i]);
                }
            }
            slns += (long)backtrack(pool, &children[N - 1]);
            for (int i = 0; i < N - 1; i++) {
                slns += (long)future_get(futures[i]);
                future_free(futures[i]);
            }
            free(futures);
            free(children);
        } else {
            int row = state->row;
            state->row = row + 1;
            for (int col = 0; col < N; col++) {
                set_queen(&state->board, row, col, N);
                slns += (long)backtrack(pool, state);
                unset_queen(&state->board, row, col, N);
            }
            state->row = row;
        }
        return (void*)slns;
    }
    
    static void benchmark(int N, int threads) {
        printf("Solving N = %d\n", N);
        struct board_state state;
        memset(&state.board, 0, sizeof(struct board));
        state.N = N;
        state.row = 0;
    
        struct thread_pool* pool = thread_pool_new(threads);
    
        struct benchmark_data* bdata = start_benchmark();
        
        struct future* fut = thread_pool_submit(pool, backtrack, &state);
        long slns = (long)future_get(fut);
    
        stop_benchmark(bdata);
    
        future_free(fut);
        thread_pool_shutdown_and_destroy(pool);
    
        printf("Solutions: %d\n", (int)slns);
        if (slns == valid_solutions[N]) {
            printf("Solution ok.\n");
            report_benchmark_results(bdata);
        }
        else { 
            fprintf(stderr, "Solution bad.\n");
            abort();
        }
    }
    
    /* Print the command-line synopsis to stderr, showing `depth` and
     * `nthreads` as the defaults for -d and -n, then abort(). Never returns. */
    static void usage(char *av0, int depth, int nthreads) {
        fprintf(stderr,
                "Usage: %s [-d <n>] [-n <n>] <N>\n"
                " -d        parallel recursion depth, default %d\n"
                " -n        number of threads in pool, default %d\n",
                av0, depth, nthreads);
        abort();
    }
    
    gback's avatar
    gback committed
    
    
    spruett3's avatar
    spruett3 committed
    /* Parse s as a base-10 integer in [lo, hi] and store it in *out.
     * Returns false, leaving *out untouched, if s is empty, has trailing
     * characters, overflows long, or lies outside the range. */
    static bool parse_int_arg(const char *s, long lo, long hi, int *out) {
        char *end;
        errno = 0;
        long v = strtol(s, &end, 10);
        if (end == s || *end != '\0' || errno == ERANGE || v < lo || v > hi) {
            return false;
        }
        *out = (int)v;
        return true;
    }

    /* Entry point: nqueens [-d depth] [-n threads] N
     * Invalid options or option values print usage(); an N outside
     * [0, MAX_N] (or a non-numeric N) is reported and aborts. */
    int main(int ac, char** av) {
        int threads = 4;
        int c;

        while ((c = getopt(ac, av, "d:n:h")) != -1) {
            switch (c) {
            case 'd':
                if (!parse_int_arg(optarg, 0, INT_MAX, &max_parallel_depth)) {
                    usage(av[0], max_parallel_depth, threads);
                }
                break;
            case 'n':
                /* A pool needs at least one worker to make progress. */
                if (!parse_int_arg(optarg, 1, INT_MAX, &threads)) {
                    usage(av[0], max_parallel_depth, threads);
                }
                break;
            case 'h':
            default:
                /* -h, or getopt's '?' for an unknown option / missing argument. */
                usage(av[0], max_parallel_depth, threads);
                break;
            }
        }
        if (optind >= ac) {
            usage(av[0], max_parallel_depth, threads);
        }

        int N = 0;
        if (!parse_int_arg(av[optind], 0, MAX_N, &N)) {
            fprintf(stderr, "N must be between 0 and %d\n", MAX_N);
            abort();
        }
        benchmark(N, threads);
        return 0;
    }