/*
 * Atari DVG bytecode disassembler.
 *
 * Examples:
 *  dvg_disasm --decode 0xca80 
 *  dvg_disasm --decode 0xa0e4 0x115e
 *  dvg_disasm --list Asteroids 0x1800 0x71e 0x4000 0x1000
 *  dvg_disasm --sgec Asteroids 0x1800 0x71e 0x4000 0x1000
 *
 * Thanks:
 *  https://wiki.philpem.me.uk/_media/elec/vecgen/vecgen.pdf
 *
 * Copyright 2021 faddenSoft.  Licensed under the Apache License, Version 2.0.
 */
#include <stdlib.h>
#include <stdint.h>
#include <stdio.h>
#include <string.h>
#include <errno.h>
#include <limits.h>

#include <new>
#include <string>

// Number of elements in a true array (not a pointer): total size divided by
// the size of one element.
#define NELEM(arr) (sizeof(arr) / sizeof(*(arr)))

// Instruction types recognized by the disassembler.  INSTR_UNKNOWN is the
// zero-valued "not decoded" placeholder; the remaining values are assigned
// sequentially after it.
enum Opcode {
    INSTR_UNKNOWN = 0,
    VCTR,       // two-word vector (x/y delta, scale, brightness)
    LABS,       // two-word absolute position (x/y, scale)
    HALT,
    JSRL,       // subroutine call to a word address
    RTSL,
    JMPL,       // jump to a word address
    SVEC        // one-word short vector
};

// Converts a 3-bit sign/magnitude field to a signed integer.  Bit 2 is the
// sign flag and bits 0-1 are the magnitude.  When the sign flag is clear the
// argument is returned untouched (no masking is applied).
int sign3(int val) {
    const bool negative = (val & 0x0004) != 0;
    return negative ? -(val & 0x03) : val;
}

// Converts an 11-bit sign/magnitude field to a signed integer.  Bit 10 is the
// sign flag and bits 0-9 are the magnitude.  When the sign flag is clear the
// argument is returned untouched (no masking is applied).
int sign11(int val) {
    const bool negative = (val & 0x0400) != 0;
    return negative ? -(val & 0x03ff) : val;
}

// Maps the top four bits of an instruction word to its opcode.  Values
// 0x0-0x9 are all VCTR (the nibble is not a fixed opcode there); 0xa-0xf each
// select a distinct instruction.
Opcode GetOpcode(uint16_t code) {
    static const Opcode kByTopNibble[16] = {
        VCTR, VCTR, VCTR, VCTR, VCTR, VCTR, VCTR, VCTR, VCTR, VCTR,  // 0x0-0x9
        LABS,   // 0xa
        HALT,   // 0xb
        JSRL,   // 0xc
        RTSL,   // 0xd
        JMPL,   // 0xe
        SVEC    // 0xf
    };
    return kByTopNibble[(code >> 12) & 0x0f];
}

// Formats an instruction as a string.
std::string decodeInstr(uint16_t code0, uint16_t code1, int baseAddr,
        bool& isTwoWords, int& addrRef) {
    Opcode opc = INSTR_UNKNOWN;
    isTwoWords = false;
    addrRef = -1;

    opc = GetOpcode(code0);
    if (opc == INSTR_UNKNOWN) {
        // shouldn't be possible -- all bit patterns are covered
        return "!UNKNOWN!";
    }

    char outBuf[64];

    switch (opc) {
    case VCTR: {    // SSSS -mYY YYYY YYYY | BBBB -mXX XXXX XXXX
            int dx = sign11(code1 & 0x07ff);
            int dy = sign11(code0 & 0x07ff);
            int sc = code0 >> 12;   // local scale
            int bb = code1 >> 12;   // brightness
            isTwoWords = true;

            snprintf(outBuf, sizeof(outBuf), "VCTR x=%+d y=%+d sc=%d b=%d",
                dx, dy, sc, bb);
        }
        break;
    case LABS: {    // 1010 00yy yyyy yyyy | SSSS 00xx xxxx xxxx
            int xc = code1 & 0x07ff;
            int yc = code0 & 0x07ff;
            int scRaw = code1 >> 12;

            // Sign-extend the scale factor.  (It's usually 0 or 1 in ROM.)
            int left = (uint8_t)(scRaw << 4);
            int sc = (int8_t)left >> 4;

            isTwoWords = true;

            snprintf(outBuf, sizeof(outBuf), "LABS x=%d y=%d sc=%d",
                xc, yc, sc);
        }
        break;
    case HALT: {    // 1011 0000 0000 0000
            snprintf(outBuf, sizeof(outBuf), "HALT");
        }
        break;
    case JSRL: {    // 1100 aaaa aaaa aaaa
            int addr = code0 & 0x0fff;
            snprintf(outBuf, sizeof(outBuf), "JSRL a=$%04x ($%04x)",
                addr, baseAddr + addr * 2);
            addrRef = addr;
        }
        break;
    case RTSL: {    // 1101 0000 0000 0000
            // 110----- --------
            snprintf(outBuf, sizeof(outBuf), "RTSL");
        }
        break;
    case JMPL: {    // 1110 aaaa aaaa aaaa
            int addr = code0 & 0x0fff;
            snprintf(outBuf, sizeof(outBuf), "JMPL a=$%04x ($%04x)",
                addr, baseAddr + addr * 2);
            addrRef = addr;
        }
        break;
    case SVEC: {    // 1111 smYY BBBB SmXX
            int dy = sign3((code0 >> 8) & 0x07);
            int dx = sign3(code0 & 0x07);
            int sc = ((code0 >> 11) & 0x01) | ((code0 >> 2) & 0x02);
            int bb = (code0 >> 4) & 0x0f;
            snprintf(outBuf, sizeof(outBuf), "SVEC x=%+d y=%+d sc=%d b=%d",
                dx, dy, sc, bb);
        }
        break;
    case INSTR_UNKNOWN:
        break;
    }

    return std::string(outBuf);
}

// Opens a file for reading, and checks that the requested byte range lies
// entirely within it.
//
//  fileName - path of the file to open
//  offset   - byte offset of the start of the range; must be >= 0 and less
//             than the file length
//  length   - number of bytes in the range; must be even and non-negative,
//             and offset + length must not exceed the file length (a range
//             that ends exactly at end-of-file is valid)
//
// Returns the open stream, positioned at the start of the file, or NULL after
// printing a message to stderr.  The caller is responsible for fclose().
FILE* prepareFile(std::string fileName, int offset, int length) {
    if ((length & 0x01) != 0) {
        fprintf(stderr, "Length must be even\n");
        return NULL;
    }
    if (length < 0) {
        fprintf(stderr, "Invalid length\n");
        return NULL;
    }

    FILE* fp = fopen(fileName.c_str(), "rb");
    if (fp == NULL) {
        fprintf(stderr, "Unable to open '%s'\n", fileName.c_str());
        return NULL;
    }

    // Determine the file length; fails for unseekable streams.
    long fileLen = -1;
    if (fseek(fp, 0, SEEK_END) == 0) {
        fileLen = ftell(fp);
    }
    if (fileLen < 0) {
        fprintf(stderr, "Unable to determine length of '%s'\n",
            fileName.c_str());
        fclose(fp);
        return NULL;
    }
    rewind(fp);

    if (offset < 0 || offset >= fileLen) {
        fprintf(stderr, "Invalid offset\n");
        fclose(fp);
        return NULL;
    }
    // Compare against the space remaining rather than computing
    // offset + length, which could overflow an int.
    if (length > fileLen - offset) {
        fprintf(stderr, "Invalid length\n");
        fclose(fp);
        return NULL;
    }

    return fp;
}

// Loads relevant section of file into memory.
//
//  fp     - open, seekable stream to read from
//  offset - byte offset of the first byte to load (>= 0)
//  length - number of bytes to load (>= 0)
//
// Returns a newly allocated buffer holding exactly `length` bytes, which the
// caller must release with delete[].  Returns NULL, after printing a message
// to stderr, if the arguments are invalid, the seek fails, memory cannot be
// allocated, or fewer than `length` bytes could be read.
uint8_t* loadFile(FILE* fp, int offset, int length) {
    if (fp == NULL || offset < 0 || length < 0) {
        fprintf(stderr, "Invalid load request\n");
        return NULL;
    }
    if (fseek(fp, offset, SEEK_SET) != 0) {
        fprintf(stderr, "Seek failed\n");
        return NULL;
    }

    // Plain new[] reports failure by throwing; the nothrow form makes the
    // NULL check below meaningful.
    uint8_t* buf = new (std::nothrow) uint8_t[length];
    if (buf == NULL) {
        fprintf(stderr, "Unable to allocate %d bytes\n", length);
        return NULL;
    }

    const size_t wanted = static_cast<size_t>(length);
    const size_t actual = fread(buf, 1, wanted, fp);
    if (actual != wanted) {
        fprintf(stderr, "Failed to read all data (got %zu)\n", actual);
        delete[] buf;
        return NULL;
    }

    return buf;
}

// Outputs a series of decoded instructions to stdout.
//
// Loads `length` bytes starting at byte `offset` of `fileName`, interprets
// them as little-endian 16-bit instruction words, and prints one line per
// instruction.
//
//  fileName        - file to read
//  offset          - byte offset of the data within the file
//  length          - number of bytes to disassemble (must be even)
//  vramAddr        - base address handed to decodeInstr() for displaying
//                    jump targets; also part of the listing address column
//  listStartOffset - byte offset of the loaded data relative to vramAddr;
//                    used to convert JSRL/JMPL word addresses into indices
//                    within the loaded data
//  doSgec          - if true, emit "set-comment +<file offset>:<text>" lines
//                    instead of the address / hex dump listing
//
// Returns 0 on success, 1 if the file could not be opened, validated, or read.
int doList(std::string fileName, int offset, int length, int vramAddr,
        int listStartOffset, bool doSgec) {

    FILE* fp = prepareFile(fileName, offset, length);
    if (fp == NULL) {
        return 1;
    }
    uint8_t* data = loadFile(fp, offset, length);
    fclose(fp);
    if (data == NULL) {
        return 1;
    }

    // One flag per 16-bit word, value-initialized to false.
    bool* isRef = new bool[length / 2]();       // word is a JSRL/JMPL target
    bool* twoWords = new bool[length / 2]();    // word starts a 2-word instr

    // first pass: scan codes and populate isRef[] and twoWords[]
    for (int i = 0; i < length; i += 2) {
        uint16_t code0 = data[i] | (data[i+1] << 8);
        // The second word only exists if we're not on the last word; use zero
        // otherwise so the decoder never sees an indeterminate value.
        uint16_t code1 = 0;
        if (i + 2 < length) {
            code1 = data[i+2] | (data[i+3] << 8);
        }
        bool isTwoWords;
        int addrRef;
        decodeInstr(code0, code1, vramAddr,
            /*ref*/ isTwoWords, /*ref*/ addrRef);
        if (i + 2 >= length) {
            // A two-word instruction cut off by the end of the data is
            // treated as a single word.
            isTwoWords = false;
        }
        twoWords[i/2] = isTwoWords;
        if (addrRef != -1) {
            // JMPL or JSRL: convert the word address to an index into the
            // loaded data.
            addrRef -= listStartOffset / 2;
            if (addrRef < 0 || addrRef >= length / 2) {
                fprintf(stderr, "Invalid address reference %04x at +%06x\n",
                    addrRef, i);
                // keep going
            } else {
                isRef[addrRef] = true;
            }
        }

        if (isTwoWords) {
            i += 2;
        }
    }

    // second pass: generate output
    for (int i = 0; i < length; i += 2) {
        if (doSgec) {
            printf("set-comment +%06x:", offset + i);
        } else {
            printf("%04x: ", vramAddr + listStartOffset + i);
        }

        uint16_t code0 = data[i] | (data[i+1] << 8);
        uint16_t code1 = 0;
        if (twoWords[i/2]) {
            code1 = data[i+2] | (data[i+3] << 8);
            if (!doSgec) {
                printf("%04x %04x ", code0, code1);
            }
        } else {
            if (!doSgec) {
                printf("%04x      ", code0);
            }
        }

        if (!doSgec) {
            // '>' marks instructions that are the target of a jump or call.
            printf("%c", isRef[i/2] ? '>' : ' ');
            putchar(' ');
        }

        bool isTwoWords;
        int addrRef;
        std::string decoded = decodeInstr(code0, code1, vramAddr,
            /*ref*/ isTwoWords, /*ref*/ addrRef);
        printf("%s%s\n", decoded.c_str(),
            (doSgec && isRef[i/2]) ? " <<<" : "");

        if (isTwoWords) {
            i += 2;
        }
    }

    delete[] twoWords;
    delete[] isRef;
    delete[] data;
    return 0;
}

// Decodes a single instruction word and prints "<hex> -> <text>" to stdout.
//
// Returns 0 on success.  If the word turns out to be the start of a two-word
// instruction, an error is printed to stderr and 1 is returned.
int doDecode(uint16_t code0) {
    bool needsSecondWord;
    int ignoredRef;
    // 0xcccc is a filler second word; it only affects the (discarded) text
    // of a two-word instruction.
    const std::string text = decodeInstr(code0, 0xcccc, 0,
        /*ref*/ needsSecondWord, /*ref*/ ignoredRef);

    if (!needsSecondWord) {
        printf("%04x -> %s\n", code0, text.c_str());
        return 0;
    }

    fprintf(stderr, "ERROR: instruction is two words\n");
    return 1;
}

// Decodes an instruction given as two words and prints
// "<hex0> <hex1> -> <text>" to stdout.  No check is made that the
// instruction actually uses the second word.  Always returns 0.
int doDecode(uint16_t code0, uint16_t code1) {
    bool ignoredTwoWords;
    int ignoredRef;
    printf("%04x %04x -> %s\n", code0, code1,
        decodeInstr(code0, code1, 0,
            /*ref*/ ignoredTwoWords, /*ref*/ ignoredRef).c_str());
    return 0;
}

// Prints a summary of the command-line syntax to stderr.  The program name
// matches the one used in the examples at the top of this file.
void usage() {
    fprintf(stderr, "Usage: dvg_disasm --list <filename> <offset> <length> <vram-addr> <start-offset>\n");
    fprintf(stderr, "       dvg_disasm --sgec <filename> <offset> <length> <vram-addr> <start-offset>\n");
    fprintf(stderr, "       dvg_disasm --decode <code0> [<code1>]\n");
}

// Parses one numeric command-line argument.  Accepts the forms understood by
// strtol() with base 0 (decimal, 0x-prefixed hex, 0-prefixed octal).  The
// whole string must be consumed and the value must lie in [minVal, maxVal].
//
// On success stores the value in `value` and returns true.  On failure prints
// a message to stderr and returns false, leaving `value` untouched.
static bool parseNumericArg(const char* text, long minVal, long maxVal,
        long& value) {
    char* end = NULL;
    errno = 0;
    const long parsed = strtol(text, &end, 0);
    if (end == text || *end != '\0' || errno == ERANGE
            || parsed < minVal || parsed > maxVal) {
        fprintf(stderr, "Invalid numeric argument '%s'\n", text);
        return false;
    }
    value = parsed;
    return true;
}

// Program entry point.  Dispatches on the first argument:
//   --list / --sgec <filename> <offset> <length> <vram-addr> <start-offset>
//   --decode <code0> [<code1>]
//
// Returns 2 for a command-line error (wrong argument count, unknown command,
// or unparseable number); otherwise returns the result of the command.
int main(int argc, char* argv[]) {
    if (argc < 2) {
        usage();
        return 2;
    }

    const bool wantList = strcmp(argv[1], "--list") == 0;
    const bool wantSgec = strcmp(argv[1], "--sgec") == 0;

    if (wantList || wantSgec) {
        if (argc != 7) {
            usage();
            return 2;
        }
        // argv[3..6]: offset, length, vram-addr, start-offset
        long nums[4];
        for (int i = 0; i < 4; i++) {
            if (!parseNumericArg(argv[3 + i], INT_MIN, INT_MAX, nums[i])) {
                return 2;
            }
        }
        return doList(argv[2],
            static_cast<int>(nums[0]), static_cast<int>(nums[1]),
            static_cast<int>(nums[2]), static_cast<int>(nums[3]),
            wantSgec);
    }

    if (strcmp(argv[1], "--decode") == 0) {
        if (argc != 3 && argc != 4) {
            usage();
            return 2;
        }
        // Instruction words are 16 bits wide.
        long code0 = 0;
        if (!parseNumericArg(argv[2], 0, 0xffff, code0)) {
            return 2;
        }
        if (argc == 4) {
            long code1 = 0;
            if (!parseNumericArg(argv[3], 0, 0xffff, code1)) {
                return 2;
            }
            return doDecode(static_cast<uint16_t>(code0),
                static_cast<uint16_t>(code1));
        }
        return doDecode(static_cast<uint16_t>(code0));
    }

    usage();
    return 2;
}
