/*
 * COPYRIGHT INFORMATION - DO NOT REMOVE
 * Copyright (c) 2003 LinuxMagic Inc. All Rights Reserved.
 *
 * This file contains Original Code and/or Modifications of Original Code as
 * defined in and that are subject to the Free Source Code License Version
 * 1.0 (the 'License'). You may not use this file except in compliance with
 * the License. Please obtain a copy of the License at:
 *
 * http://www.linuxmagic.com/opensource/licensing/FSCL.txt
 *
 * and read it before using this file.
 *
 * The Original Code and all software distributed under the License are
 * distributed on an 'AS IS' basis, WITHOUT WARRANTY OF ANY KIND, EITHER
 * EXPRESS OR IMPLIED, AND LINUXMAGIC HEREBY DISCLAIMS ALL SUCH WARRANTIES,
 * INCLUDING WITHOUT LIMITATION, ANY WARRANTIES OF MERCHANTABILITY, FITNESS
 * FOR A PARTICULAR PURPOSE, QUIET ENJOYMENT OR NON-INFRINGEMENT. Please see
 * the License for the specific language governing rights and limitations
 * under the License."
 *
 * Please read the terms of this license carefully. By using or downloading
 * this software or file, you are accepting and agreeing to the terms of this
 * license with LinuxMagic Inc. If you are agreeing to this license on behalf
 * of a company, you represent that you are authorized to bind the company to
 * such a license. If you do not meet this criterion or you do not agree to
 * any of the terms of this license, do NOT download, distribute, use or alter
 * this software or file in any way.
 *
 * Author(s): Josh Wilsdon <josh@wizard.ca>
 *
 * Version: $Id: getmx.c,v 1.4 2003/09/19 18:56:03 josh Exp $
 *
 * DO NOT MODIFY WITHOUT CONSULTING THE LICENSE
 *
 */

#include <ctype.h>
#include <resolv.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <strings.h>
#include <arpa/inet.h>
#include <arpa/nameser.h>
#include <netinet/in.h>
#include "getmx.h"

/* #define DEBUG 1 */

/* private local prototypes */
/*
 * The definitions further down omit the `static` keyword; they still get
 * internal linkage because these earlier `static` declarations are in scope.
 */
/* append a new MX host entry (name + preference) to mxdata */
static int add_mxhost(mxdata_t mxdata, char *domain, short preference);
/* attach an IPv4 address to the mxdata host entry whose name matches domain */
static int add_mxhostip(mxdata_t mxdata, char *domain, unsigned long ip);
/* walk a raw DNS reply of anslen bytes, filling mxdata; 1 means parse error */
static int parse_mx(mxdata_t mxdata, unsigned char *answer, int anslen);
/* big-endian field readers: decode at answer[*pos] and advance *pos */
static int read_short(unsigned char *answer, int *pos, short *out);
static int read_ulong(unsigned char *answer, int *pos, unsigned long *out);
/* decode a (possibly compressed) domain name at answer[*pos]; 1 on error */
static int read_name(unsigned char *answer, int *pos, char *name);

/* public functions */

/*
 * mx_query - look up the MX records of a domain.
 *
 * Sends a C_IN / T_MX query through the system resolver and parses the
 * reply into a newly allocated mxdata structure: one host entry per MX
 * record (name + preference) plus any IPv4 addresses the reply's
 * additional section supplies for those hosts.
 *
 * Returns the filled structure (release it with mx_free()), or NULL when
 * domain is NULL, memory runs out, the resolver cannot be initialised, the
 * query fails, or the reply cannot be parsed.  Nothing is leaked on any of
 * the failure paths.
 */
mxdata_t mx_query(char *domain)
{
    mxdata_t mxdata;
    unsigned char answer[MAXDNAME + 1];
    int res, h;
#ifdef DEBUG
    int i;
#endif

    if (domain == NULL) {
        return (NULL);
    }

    mxdata = malloc(sizeof(struct mxdata));
    if (mxdata == NULL) {
        return (NULL);
    }

    /* initialize mxdata */
    mxdata->mxcount = 0;
    mxdata->_mxhost_count = 0;
    mxdata->mxhost = NULL;

    if (res_init() != 0) {
        free(mxdata);
        return (NULL);
    }
    res = res_query(domain, C_IN, T_MX, answer, (int) sizeof(answer));
    if (res <= 0) {
        /* query failed: no hosts were added yet, so only mxdata to free */
        free(mxdata);
        return (NULL);
    }
    /*
     * The resolver may report the full reply length even when it had to
     * truncate the reply to fit the buffer; never parse past what was stored.
     */
    if (res > (int) sizeof(answer)) {
        res = (int) sizeof(answer);
    }

#ifdef DEBUG
    printf("%p ", (void *) &answer[0]);
    for (i = 0; i < res; i++) {
        printf("%02x(%c) ", answer[i], isprint(answer[i]) ? answer[i] : '.');
        if ((i > 0) && (((i + 1) % 15) == 0)) {
            printf("\n%p ", (void *) &answer[i + 1]);
        }
    }
    printf("\n");
#endif

    if (parse_mx(mxdata, answer, res) != 0) {
        /* release whatever hosts were collected before the parse error */
        for (h = 0; h < mxdata->mxcount; h++) {
            free(mxdata->mxhost[h]->name);
            free(mxdata->mxhost[h]->addr);
            free(mxdata->mxhost[h]);
        }
        free(mxdata->mxhost);
        free(mxdata);
        return (NULL);
    }

    return (mxdata);
}

/*
 * mx_free - release an mxdata structure returned by mx_query().
 *
 * Frees, for every host entry, its name copy, its address array and the
 * entry itself; then the host-pointer array and finally mxdata.
 *
 * Returns -1 if mxdata is NULL, 0 otherwise.
 */
int mx_free(mxdata_t mxdata)
{
    int i;

    if (mxdata == NULL) {
        return (-1);
    }

    for (i = 0; i < mxdata->mxcount; i++) {
        /*
         * name and addr belong to the host entry, not to individual
         * addresses: free each exactly once regardless of addrcount
         * (freeing them per address double-freed hosts with several
         * addresses and leaked the name of hosts with none).
         */
        free(mxdata->mxhost[i]->name);
        free(mxdata->mxhost[i]->addr);
        free(mxdata->mxhost[i]);
    }
    free(mxdata->mxhost);
    free(mxdata);

    return (0);
}

/* private functions */

/*
 * add_mxhost - append an MX host entry to mxdata.
 *
 * Stores a private copy of domain together with preference; the new entry
 * starts with no addresses.  The host-pointer array grows one slot at a
 * time, with _mxhost_count tracking its allocated size and mxcount the
 * number of slots in use.
 *
 * Returns 0 on success, -1 on NULL arguments or when memory runs out.  On
 * failure mxdata is left unchanged and nothing is leaked.
 */
int add_mxhost(mxdata_t mxdata, char *domain, short preference)
{
    mxhost_t mxhost, *tmpmxhost;

    if (mxdata == NULL || domain == NULL) {
        return (-1);
    }

    mxhost = malloc(sizeof(struct mxhost));
    if (mxhost == NULL) {
        return (-1);
    }

    /* initialize mxhost */
    mxhost->name = strdup(domain);
    if (mxhost->name == NULL) {
        free(mxhost);
        return (-1);
    }
    mxhost->preference = preference;
    mxhost->addrcount = 0;
    mxhost->_addr_count = 0;
    mxhost->addr = NULL;

    /* add to mxdata, growing the pointer array when it is full */
    if (mxdata->_mxhost_count <= mxdata->mxcount) {
        tmpmxhost = realloc(mxdata->mxhost,
                            sizeof(*tmpmxhost) * (mxdata->_mxhost_count + 1));
        if (tmpmxhost == NULL) {
            /* the old array is still valid; just drop the new entry */
            free(mxhost->name);
            free(mxhost);
            return (-1);
        }
        mxdata->mxhost = tmpmxhost;
        mxdata->_mxhost_count++;
    }
    mxdata->mxhost[mxdata->mxcount] = mxhost;
    mxdata->mxcount++;

    return (0);
}

/*
 * add_mxhostip - record an IPv4 address for an MX host already in mxdata.
 *
 * domain must equal a stored host name exactly, compared case-insensitively
 * because DNS names are case-insensitive.  A plain prefix match would wrongly
 * attach addresses of e.g. "mx.example.com" to a host named "mx".  ip is
 * appended to that host's address array as given.
 *
 * Returns 0 when the address was stored, 1 when no MX host is named domain
 * (e.g. glue records for name servers), -1 on NULL arguments or when memory
 * runs out (the host's existing addresses are kept).
 */
int add_mxhostip(mxdata_t mxdata, char *domain, unsigned long ip)
{
    int i;
    mxhost_t host;
    unsigned long *tmplong;

    if (mxdata == NULL || domain == NULL) {
        return (-1);
    }

    for (i = 0; i < mxdata->mxcount; i++) {
        host = mxdata->mxhost[i];
        if (host->name == NULL || strcasecmp(host->name, domain) != 0) {
            continue;
        }

        /* this is the record we need to update */
        if (host->_addr_count <= host->addrcount) {
            /* need to allocate more space */
            tmplong = realloc(host->addr,
                              sizeof(*tmplong) * (host->_addr_count + 1));
            if (tmplong == NULL) {
                return (-1);
            }
            host->addr = tmplong;
            host->_addr_count++;
        }
        host->addr[host->addrcount] = ip;
        host->addrcount++;

        return (0);
    }

    return (1);
}

/*
 * read_rr_header - read the owner name and fixed fields of one resource
 * record starting at *pos.
 *
 * On success stores the owner name in name, the record TYPE in *type and
 * the RDLENGTH in *rdlen, leaves *pos on the first RDATA byte and returns
 * 0.  Returns 1 if the name is malformed or the record (including its
 * RDATA) would extend past anslen.
 */
static int read_rr_header(unsigned char *answer, int anslen, int *pos,
                          char *name, short *type, int *rdlen)
{
    short rdlen_raw;

    if (read_name(answer, pos, name) == 1) {
        return (1);
    }
    /* fixed part: TYPE(16) CLASS(16) TTL(32) RDLENGTH(16) */
    if (*pos + 10 > anslen) {
        return (1);
    }
    read_short(answer, pos, type);
    *pos += 2;                  /* skip class */
    *pos += 4;                  /* skip ttl */
    read_short(answer, pos, &rdlen_raw);
    *rdlen = (unsigned short) rdlen_raw;        /* RDLENGTH is unsigned */
    if (*pos + *rdlen > anslen) {
        return (1);
    }

    return (0);
}

/*
 * parse_mx - walk a raw DNS reply and fill mxdata.
 *
 * answer holds anslen bytes of a reply to a T_MX query.  Every MX record in
 * the answer section becomes a host entry (add_mxhost()); every 4-byte A
 * record in the additional section is attached to the MX host of the same
 * name (add_mxhostip()).  Other record types (CNAME answers, SOA/NS
 * authorities, AAAA/OPT additionals, ...) are skipped using their RDLENGTH
 * so they cannot throw the parser out of alignment.
 *
 * Returns 0 on success, 1 on a malformed or truncated reply or when storing
 * a host/address fails.
 */
int parse_mx(mxdata_t mxdata, unsigned char *answer, int anslen)
{
    char name[MAXDNAME + 1];
    unsigned long tmpulong;
    short tmpshort, type, questions, answers, authorities, additionals;
    int i, pos, rdlen, rdend;
#ifdef DEBUG
    struct in_addr in;
#endif

    /* HEADER: ID(16) FLAGS(16) QDCOUNT ANCOUNT NSCOUNT ARCOUNT = 12 bytes */
    if (answer == NULL || anslen < 12) {
        return (1);
    }
    pos = 0;
    pos += 2;                   /* skip ID */
    pos += 2;                   /* skip flags */
    read_short(answer, &pos, &questions);
    read_short(answer, &pos, &answers);
    read_short(answer, &pos, &authorities);
    read_short(answer, &pos, &additionals);

#ifdef DEBUG
    printf("Q: %d A: %d A: %d A: %d\n", questions, answers, authorities,
           additionals);
#endif

    /* QUESTION(S): QNAME QTYPE(16) QCLASS(16) */
    for (i = 0; i < (unsigned short) questions; i++) {
        if (read_name(answer, &pos, name) == 1 || pos + 4 > anslen) {
            return (1);
        }
        pos += 4;               /* skip Q type and Q class */
#ifdef DEBUG
        printf("Question: %s\n", name);
#endif
    }

    /* ANSWER(S): collect MX hosts, skip anything else (e.g. CNAME) */
    for (i = 0; i < (unsigned short) answers; i++) {
        if (read_rr_header(answer, anslen, &pos, name, &type, &rdlen) != 0) {
            return (1);
        }
        rdend = pos + rdlen;
#ifdef DEBUG
        printf("Answer: %s\n", name);
#endif
        if (type == T_MX) {
            /* RDATA: PREFERENCE(16) + EXCHANGE (at least the root label) */
            if (rdlen < 3) {
                return (1);
            }
            read_short(answer, &pos, &tmpshort);
            if (read_name(answer, &pos, name) == 1 || pos > rdend) {
                return (1);
            }
            if (add_mxhost(mxdata, name, tmpshort) != 0) {
                return (1);
            }
#ifdef DEBUG
            printf("%02d %s\n", tmpshort, name);
#endif
        }
        pos = rdend;
    }

    /* AUTHORIT(Y/IES): not needed (may be NS or SOA); skip by RDLENGTH */
    for (i = 0; i < (unsigned short) authorities; i++) {
        if (read_rr_header(answer, anslen, &pos, name, &type, &rdlen) != 0) {
            return (1);
        }
#ifdef DEBUG
        printf("Authority %s (type %d)\n", name, type);
#endif
        pos += rdlen;
    }

    /* ADDITIONAL(S): attach IPv4 (A) addresses to matching MX hosts */
    for (i = 0; i < (unsigned short) additionals; i++) {
        if (read_rr_header(answer, anslen, &pos, name, &type, &rdlen) != 0) {
            return (1);
        }
        rdend = pos + rdlen;
#ifdef DEBUG
        printf("Additional %s (type %d)\n", name, type);
#endif
        if (type == T_A && rdlen == 4) {
            read_ulong(answer, &pos, &tmpulong);
            /* 1 just means "not an MX host" (e.g. NS glue); only <0 fails */
            if (add_mxhostip(mxdata, name, tmpulong) < 0) {
                return (1);
            }
#ifdef DEBUG
            in.s_addr = htonl(tmpulong);
            printf("ADDR: %s\n", (char *) inet_ntoa(in));
#endif
        }
        pos = rdend;
    }

    return (0);
}

/* utility functions */

/*
 * read_ulong - fetch a big-endian (network order) 32-bit field.
 *
 * Stores the value of answer[*pos] .. answer[*pos + 3] in *out and advances
 * *pos by four.  The bytes are assembled explicitly rather than memcpy'd
 * into *out: unsigned long may be 8 bytes wide, and copying only 4 bytes
 * into it fills the wrong half on big-endian 64-bit hosts.
 *
 * No bounds checking is done; the caller must know four bytes are
 * available.  Always returns 0.
 */
int read_ulong(unsigned char *answer, int *pos, unsigned long *out)
{
    const unsigned char *p = answer + *pos;

    *out = ((unsigned long) p[0] << 24) |
           ((unsigned long) p[1] << 16) |
           ((unsigned long) p[2] << 8) |
           (unsigned long) p[3];
    *pos += 4;

    return (0);
}

/*
 * read_short - fetch a big-endian (network order) 16-bit field.
 *
 * Decodes answer[*pos] (high byte) and answer[*pos + 1] (low byte) into
 * *out, then moves *pos forward by two.  No bounds checking is done; the
 * caller must know two bytes are available.  Always returns 0.
 */
int read_short(unsigned char *answer, int *pos, short *out)
{
    const unsigned char *field = &answer[*pos];
    unsigned int value = ((unsigned int) field[0] << 8) | field[1];

    *out = (short) value;
    *pos = *pos + 2;
    return 0;
}

/*
 * read_name - decode a (possibly compressed) domain name from a DNS message.
 *
 * answer is the start of the DNS message (compression pointers are offsets
 * from it) and *pos the offset of the name.  Labels are joined with '.'
 * into name, with no trailing dot; the root name yields "".  name must have
 * room for 255 bytes.  Longer names are rejected instead of overflowing it.
 *
 * *pos is advanced past the name as it is encoded at *pos: past the
 * terminating zero label, or past the first 2-byte compression pointer.
 * Pointers are followed, including pointers to pointers, but at most 99
 * hops are allowed so that pointer loops terminate.
 *
 * Returns 0 on success, and 1 on a malformed name: a reserved label type,
 * too many pointer hops, or a name longer than the buffer.  name is always
 * NUL-terminated.
 *
 * NOTE(review): the message length is not passed in, so label data and
 * pointer targets are not checked against the end of the buffer.
 */
int read_name(unsigned char *answer, int *pos, char *name)
{
    enum { NAME_BUF_LEN = 255, MAX_REDIRECTS = 100 };
    const unsigned char *ptr = answer + *pos;
    int index = 0;
    int compressed = 0;         /* set once a pointer has been followed */
    int redirects = 0;
    unsigned int len, offset;

    name[0] = '\0';             /* start w/ empty string */
    for (;;) {
        len = *ptr++;
        if (!compressed) {
            (*pos)++;
        }
        if (len == 0x00) {
            break;              /* root label: end of name */
        }
        if ((len & 0xC0) == 0xC0) {
            /* compression pointer: 14-bit offset from start of message */
            if (++redirects >= MAX_REDIRECTS) {
                goto fail;      /* too many redirections, giving up */
            }
            offset = ((len & 0x3F) << 8) | *ptr;
            if (!compressed) {
                (*pos)++;       /* second byte of the pointer */
            }
            compressed = 1;
            /* re-examine the target: it may be a label OR another pointer */
            ptr = answer + offset;
            continue;
        }
        if ((len & 0xC0) != 0) {
            goto fail;          /* 0x40 / 0x80: reserved label types */
        }
        /* need room for an optional '.', the label and the final NUL */
        if (index + (index > 0 ? 1 : 0) + (int) len + 1 > NAME_BUF_LEN) {
            goto fail;
        }
        if (index > 0) {
            name[index++] = '.';
        }
        memcpy(name + index, ptr, len);
        index += (int) len;
        ptr += len;
        if (!compressed) {
            *pos += (int) len;
        }
    }
    name[index] = '\0';
    return (0);

  fail:
    name[index] = '\0';
    return (1);
}

/* testing functions */

#ifdef DEBUG
/*
 * Debug driver: looks up the MX records of argv[1] and prints each host
 * with its preference and any IPv4 addresses from the reply.
 * Exits 0 after printing usage or results, 1 if the lookup fails.
 */
int main(int argc, char *argv[])
{
    struct in_addr in;
    mxdata_t mxdata;
    int i, j;

    if (2 != argc) {
        printf("usage: %s mx_domain\n", argv[0]);
        return 0;
    }

    mxdata = mx_query(argv[1]);
    if (mxdata == NULL) {
        /* mx_query() returns NULL on any failure; don't dereference it */
        fprintf(stderr, "MX lookup for %s failed\n", argv[1]);
        return 1;
    }

    printf("results: (%d)\n", mxdata->mxcount);
    for (i = 0; i < mxdata->mxcount; i++) {
        printf("%02d %s ", mxdata->mxhost[i]->preference,
               mxdata->mxhost[i]->name);
        for (j = 0; j < mxdata->mxhost[i]->addrcount; j++) {
            in.s_addr = htonl(mxdata->mxhost[i]->addr[j]);
            printf("%s ", (char *) inet_ntoa(in));
        }
        printf("\n");
    }

    mx_free(mxdata);

    return 0;
}
#endif
