/*
 * Atari AVG bytecode disassembler.
 *
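 * Examples:
 *  avg_disasm --decode 0x68c1
 *  avg_disasm --decode 0x1f00 0x0010
 *  avg_disasm --list Battlezone 0x3000 0x784 0x2000 0x1000
 *  avg_disasm --sgec Battlezone 0x3000 0x784 0x2000 0x1000
 *
 * For --list and --sgec, the numeric arguments are, in order:
 *  - the byte offset of the data in the file;
 *  - its length in bytes (must be even);
 *  - the base address of vector memory (vram-addr);
 *  - the byte offset of the listed data from that base (start-offset).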
 *
 * Thanks:
 *  http://www.ionpool.net/arcade/atari_docs/avg.pdf
 *  http://www.brouhaha.com/~eric/software/vecsim/
 *
 * Copyright 2020 faddenSoft.  Licensed under the Apache License, Version 2.0.
 */
#include <errno.h>
#include <limits.h>
#include <stdint.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>

#include <memory>
#include <new>
#include <string>
#include <vector>

// Number of elements in a true array (not a pointer).  The argument is
// parenthesized so that expressions passed as |x| bind correctly.
#define NELEM(x) (sizeof(x) / sizeof((x)[0]))

// Instruction classes produced by GetOpcode().  Values are pinned
// explicitly; INSTR_UNKNOWN is zero.
enum Opcode {
    INSTR_UNKNOWN = 0,
    VCTR = 1,   // vector whose dx/dy need a second word (two-word instr)
    HALT = 2,
    SVEC = 3,   // vector with small dx/dy packed into one word
    STAT = 4,   // intensity / color / flags
    SCAL = 5,   // scale factor
    CNTR = 6,
    JSR  = 7,   // subroutine call (printed as VJSR)
    RTS  = 8,   // return (printed as VRTS)
    JMP  = 9    // jump (printed as VJMP)
};

// Sign-extends a signed 5-bit value.
//
// Only the low 5 bits of |val| are used; bit 4 is the sign bit, so the
// result is in [-16, 15].
//
// This uses plain integer arithmetic instead of narrowing to int8_t and
// shifting right.  Before C++20, converting an out-of-range value to a signed
// type and right-shifting a negative value are both implementation-defined,
// and left-shifting a negative value is undefined.
int sign5(int val) {
    int bits = val & 0x1f;
    return (bits ^ 0x10) - 0x10;
}

// Sign-extends a signed 13-bit value.
//
// Only the low 13 bits of |val| are used; bit 12 is the sign bit, so the
// result is in [-4096, 4095].
//
// This uses plain integer arithmetic instead of narrowing to int16_t and
// shifting right.  Before C++20, converting an out-of-range value to a signed
// type and right-shifting a negative value are both implementation-defined,
// and left-shifting a negative value is undefined.
int sign13(int val) {
    int bits = val & 0x1fff;
    return (bits ^ 0x1000) - 0x1000;
}

// Extracts the opcode from the first word of an instruction.
//
// The top three bits pick the instruction class.  Class 3 (0x6000-0x7fff)
// is shared by STAT and SCAL; bit 12 tells them apart (clear = STAT,
// set = SCAL).  Every bit pattern maps to a real opcode, so INSTR_UNKNOWN
// is never returned.
Opcode GetOpcode(uint16_t code) {
    static const Opcode kByTopBits[8] = {
        VCTR, HALT, SVEC, STAT, CNTR, JSR, RTS, JMP
    };

    Opcode result = kByTopBits[(code >> 13) & 0x07];
    if (result == STAT && (code & 0x1000) != 0) {
        result = SCAL;
    }
    return result;
}

// Formats an instruction as a string.
//
// |code0| is the instruction's first word.  |code1| is the word after it;
// it is only examined for two-word instructions (VCTR).  |baseAddr| is added
// to a JSR/JMP target, after converting the word address to bytes, to show
// the target's address.
//
// On return:
//  - |isTwoWords| is true if the instruction consumed |code1|.
//  - |addrRef| holds the raw word-address operand of a JSR/JMP, or -1 for
//    any other instruction.
std::string decodeInstr(uint16_t code0, uint16_t code1, int baseAddr,
        bool& isTwoWords, int& addrRef) {
    isTwoWords = false;
    addrRef = -1;

    // Start with an empty string so no path can return uninitialized stack
    // memory, even if a case below fails to fill the buffer.
    char outBuf[64] = "";

    switch (GetOpcode(code0)) {
    case VCTR: {
            // 000YYYYY YYYYYYYY  IIIXXXXX XXXXXXXX
            int dy = sign13(code0 & 0x1fff);
            int dx = sign13(code1 & 0x1fff);
            int ii = code1 >> 13;
            // Intensity 1 is shown as-is; other values are doubled.
            // NOTE(review): presumably this puts the 3-bit field on STAT's
            // 4-bit scale -- confirm against the AVG doc.
            if (ii != 1) {
                ii *= 2;
            }
            snprintf(outBuf, sizeof(outBuf), "VCTR dx=%+d dy=%+d in=%d",
                dx, dy, ii);

            isTwoWords = true;
        }
        break;
    case HALT: {
            // 00100000 00100000
            snprintf(outBuf, sizeof(outBuf), "HALT");
        }
        break;
    case SVEC: {
            // 010YYYYY IIIXXXXX
            // The 5-bit deltas are doubled.
            int dy = sign5((code0 >> 8) & 0x1f) * 2;
            int dx = sign5(code0 & 0x1f) * 2;
            int ii = (code0 >> 5) & 0x07;
            if (ii != 1) {
                ii *= 2;
            }
            snprintf(outBuf, sizeof(outBuf), "SVEC dx=%+d dy=%+d in=%d",
                dx, dy, ii);
        }
        break;
    case STAT: {
            // 0110-EHO IIIICCCC
            // (different for e.g. Star Wars: 0110-RGB ZZZZZZZZ)
            int eflag = (code0 >> 10) & 0x01;
            int hflag = (code0 >> 9) & 0x01;
            int oflag = (code0 >> 8) & 0x01;
            int ii = (code0 >> 4) & 0x0f;
            int cc = code0 & 0x0f;
            snprintf(outBuf, sizeof(outBuf), "STAT in=%d cl=%d flags=%c/%c/%c",
                ii, cc,
                eflag ? 'E' : 'N',
                hflag ? 'H' : 'L',
                oflag ? 'I' : 'O');
        }
        break;
    case SCAL: {
            // 0111-BBB LLLLLLLL
            int bs = (code0 >> 8) & 0x07;
            int ls = code0 & 0xff;
            // value *= (2 ^ (1 - B)) * (1 - L / (2 ^ B))  ???
            // VECSIM uses this formula (fixed point, 8192 == 1.0):
            int scale = (16384 - (ls << 6)) >> bs;
            snprintf(outBuf, sizeof(outBuf), "SCAL b=%d l=%d (* %.3f)",
                bs, ls, scale / 8192.0);
        }
        break;
    case CNTR: {
            // 10000000 01------
            snprintf(outBuf, sizeof(outBuf), "CNTR");
        }
        break;
    case JSR: {
            // 101-AAAA AAAAAAAA
            // (spec says 12, VECSIM masks 13, doesn't matter for Battlezone)
            int addr = code0 & 0x0fff;
            snprintf(outBuf, sizeof(outBuf), "VJSR a=$%04x ($%04x)",
                addr, baseAddr + addr * 2);
            addrRef = addr;
        }
        break;
    case RTS: {
            // 110----- --------
            snprintf(outBuf, sizeof(outBuf), "VRTS");
        }
        break;
    case JMP: {
            // 111-AAAA AAAAAAAA
            // (spec says 12, VECSIM masks 13, doesn't matter for Battlezone)
            int addr = code0 & 0x0fff;
            snprintf(outBuf, sizeof(outBuf), "VJMP a=$%04x ($%04x)",
                addr, baseAddr + addr * 2);
            addrRef = addr;
        }
        break;
    case INSTR_UNKNOWN:
        // Shouldn't be possible: GetOpcode() covers every bit pattern.
        snprintf(outBuf, sizeof(outBuf), "!UNKNOWN!");
        break;
    }

    return std::string(outBuf);
}

// Opens a file for reading and checks that the byte range
// [offset, offset + length) lies entirely within it.
//
// |length| must be even and non-negative.  The range may end exactly at the
// end of the file.
//
// Returns an open FILE* rewound to the start of the file, which the caller
// must fclose().  Returns NULL after printing a diagnostic to stderr if any
// check fails.
FILE* prepareFile(const std::string& fileName, int offset, int length) {
    if ((length & 0x01) != 0) {
        fprintf(stderr, "Length must be even\n");
        return NULL;
    }
    if (length < 0) {
        fprintf(stderr, "Invalid length\n");
        return NULL;
    }

    FILE* fp = fopen(fileName.c_str(), "rb");
    if (fp == NULL) {
        fprintf(stderr, "Unable to open '%s'\n", fileName.c_str());
        return NULL;
    }

    long fileLen = -1;
    if (fseek(fp, 0, SEEK_END) == 0) {
        fileLen = ftell(fp);    // -1 on failure
    }
    if (fileLen < 0) {
        fprintf(stderr, "Unable to determine length of '%s'\n",
            fileName.c_str());
        fclose(fp);
        return NULL;
    }
    rewind(fp);

    if (offset < 0 || offset >= fileLen) {
        fprintf(stderr, "Invalid offset\n");
        fclose(fp);
        return NULL;
    }
    // 0 <= offset < fileLen here, so the subtraction can't overflow.  The
    // original "offset + length" sum could overflow int.  Use '>' so a range
    // that ends exactly at EOF is accepted.
    if (length > fileLen - offset) {
        fprintf(stderr, "Invalid length\n");
        fclose(fp);
        return NULL;
    }

    return fp;
}

// Loads |length| bytes starting at byte |offset| of |fp| into memory.
//
// Returns a newly allocated buffer, which the caller must release with
// delete[].  Returns NULL after printing a diagnostic to stderr if the length
// is negative, the seek fails, the allocation fails, or fewer than |length|
// bytes could be read.
uint8_t* loadFile(FILE* fp, int offset, int length) {
    if (length < 0) {
        fprintf(stderr, "Invalid length\n");
        return NULL;
    }
    // fseek() reports failure as any nonzero value, not necessarily < 0.
    if (fseek(fp, offset, SEEK_SET) != 0) {
        fprintf(stderr, "Seek failed\n");
        return NULL;
    }
    // Plain new[] throws instead of returning NULL.  Use the nothrow form so
    // the NULL check below really catches allocation failure.
    uint8_t* buf = new (std::nothrow) uint8_t[length];
    if (buf == NULL) {
        fprintf(stderr, "Unable to allocate %d bytes\n", length);
        return NULL;
    }
    size_t actual = fread(buf, 1, static_cast<size_t>(length), fp);
    if (actual != static_cast<size_t>(length)) {
        fprintf(stderr, "Failed to read all data (got %zu)\n", actual);
        delete[] buf;
        return NULL;
    }

    return buf;
}

// Reads a little-endian 16-bit word from |p| (which must have 2 bytes).
static uint16_t readWord(const uint8_t* p) {
    return static_cast<uint16_t>(p[0] | (p[1] << 8));
}

// Outputs a series of decoded instructions to stdout.
//
// Reads |length| bytes at byte |offset| of |fileName| and disassembles them
// as AVG instructions.
//  - |vramAddr| is the base address that JSR/JMP word addresses are
//    relative to.
//  - |listStartOffset| is the byte offset of this data from that base.
// Together they give the addresses shown in the listing and let branch
// targets be mapped back into the listed range.
//
// Normal mode prints, for each instruction:
//  - its address;
//  - its raw word(s);
//  - a '>' marker if some JSR/JMP in the range targets it;
//  - the decoded text.
// With |doSgec| set, each line is instead
// "set-comment +<file offset>:<decoded>", with " <<<" appended on targets.
//
// Returns 0 on success, 1 if the file couldn't be opened or read.
int doList(const std::string& fileName, int offset, int length, int vramAddr,
        int listStartOffset, bool doSgec) {

    FILE* fp = prepareFile(fileName, offset, length);
    if (fp == NULL) {
        return 1;
    }
    // loadFile() allocates with new[]; the array unique_ptr uses delete[].
    std::unique_ptr<uint8_t[]> data(loadFile(fp, offset, length));
    fclose(fp);
    if (!data) {
        return 1;
    }

    const int numWords = length / 2;
    // Per-word flags:
    //  - isRef marks words targeted by a JSR/JMP within this range.
    //  - twoWords marks words that begin a two-word instruction.
    std::vector<char> isRef(numWords, 0);
    std::vector<char> twoWords(numWords, 0);

    // First pass: find instruction boundaries and branch targets.
    for (int i = 0; i < length; i += 2) {
        const bool isLastWord = (i + 2 >= length);
        uint16_t code0 = readWord(&data[i]);
        // The final word has no successor; decode it with a zero second word.
        uint16_t code1 = isLastWord ? 0 : readWord(&data[i + 2]);
        bool isTwoWords;
        int addrRef;
        decodeInstr(code0, code1, vramAddr,
            /*ref*/ isTwoWords, /*ref*/ addrRef);
        if (isLastWord) {
            // A two-word instruction can't start on the last word.
            isTwoWords = false;
        }
        twoWords[i / 2] = isTwoWords;
        if (addrRef != -1) {
            // JMP or JSR: convert the word address to a word index into
            // this listing.
            int index = addrRef - listStartOffset / 2;
            if (index < 0 || index >= numWords) {
                fprintf(stderr, "Invalid address reference %04x at +%06x\n",
                    addrRef, i);
                // keep going
            } else {
                isRef[index] = 1;
            }
        }

        if (isTwoWords) {
            i += 2;
        }
    }

    // Second pass: generate output, stepping over the same boundaries.
    for (int i = 0; i < length; i += 2) {
        if (doSgec) {
            printf("set-comment +%06x:", offset + i);
        } else {
            printf("%04x: ", vramAddr + listStartOffset + i);
        }

        const bool hasSecondWord = twoWords[i / 2] != 0;
        const bool isTarget = isRef[i / 2] != 0;
        uint16_t code0 = readWord(&data[i]);
        uint16_t code1 = hasSecondWord ? readWord(&data[i + 2]) : 0;

        if (!doSgec) {
            if (hasSecondWord) {
                printf("%04x %04x ", code0, code1);
            } else {
                printf("%04x      ", code0);
            }
            printf("%c ", isTarget ? '>' : ' ');
        }

        bool isTwoWords;
        int addrRef;
        std::string decoded = decodeInstr(code0, code1, vramAddr,
            /*ref*/ isTwoWords, /*ref*/ addrRef);
        printf("%s%s\n", decoded.c_str(), (doSgec && isTarget) ? " <<<" : "");

        if (hasSecondWord) {
            i += 2;
        }
    }

    return 0;
}

// Formats a one-word instruction and prints it to stdout.
//
// The decoder is given a dummy second word (0xcccc).  If |code0| turns out
// to be a two-word instruction, no decode is printed; an error goes to
// stderr instead.  Returns 0 on success, 1 on that error.
int doDecode(uint16_t code0) {
    bool needsSecondWord = false;
    int ignoredRef = 0;
    const std::string text = decodeInstr(code0, 0xcccc, 0,
        /*ref*/ needsSecondWord, /*ref*/ ignoredRef);
    if (!needsSecondWord) {
        printf("%04x -> %s\n", code0, text.c_str());
        return 0;
    }
    fprintf(stderr, "ERROR: instruction is two words\n");
    return 1;
}

// Formats a two-word instruction and prints it to stdout.
//
// Both words are always echoed.  If |code0| decodes as a one-word
// instruction, the decoder simply doesn't look at |code1|.  Always
// returns 0.
int doDecode(uint16_t code0, uint16_t code1) {
    bool ignoredTwoWords = false;
    int ignoredRef = 0;
    const std::string text = decodeInstr(code0, code1, 0,
        /*ref*/ ignoredTwoWords, /*ref*/ ignoredRef);
    printf("%04x %04x -> %s\n", code0, code1, text.c_str());
    return 0;
}

// Prints command-line usage to stderr.
void usage() {
    static const char* const kLines[] = {
        "Usage: avgdis --list <filename> <offset> <length> <vram-addr> <start-offset>\n",
        "       avgdis --sgec <filename> <offset> <length> <vram-addr> <start-offset>\n",
        "       avgdis --decode <code0> [<code1>]\n",
    };
    for (const char* line : kLines) {
        fputs(line, stderr);
    }
}

int main(int argc, char* argv[]) {
    if (argc < 2) {
        usage();
        return 2;
    }

    int result = 2;

    if (strcmp(argv[1], "--list") == 0) {
        if (argc != 7) {
            usage();
        } else {
            const char* fileName = argv[2];
            int offset = strtol(argv[3], NULL, 0);
            int length = strtol(argv[4], NULL, 0);
            int vramAddr = strtol(argv[5], NULL, 0);
            int listStartOffset = strtol(argv[6], NULL, 0);
            result = doList(fileName, offset, length, vramAddr, listStartOffset, false);
        }
    } else if (strcmp(argv[1], "--sgec") == 0) {
        if (argc != 7) {
            usage();
        } else {
            const char* fileName = argv[2];
            int offset = strtol(argv[3], NULL, 0);
            int length = strtol(argv[4], NULL, 0);
            int vramAddr = strtol(argv[5], NULL, 0);
            int listStartOffset = strtol(argv[6], NULL, 0);
            result = doList(fileName, offset, length, vramAddr, listStartOffset, true);
        }
    } else if (strcmp(argv[1], "--decode") == 0) {
        if (argc != 3 && argc != 4) {
            usage();
        } else {
            uint16_t code0 = (uint16_t) strtol(argv[2], NULL, 0);
            if (argc == 4) {
                uint16_t code1 = (uint16_t) strtol(argv[3], NULL, 0);
                result = doDecode(code0, code1);
            } else {
                result = doDecode(code0);
            }
        }
    } else {
        usage();
    }

    return result;
}
