#!/usr/bin/env python
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements.  See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License.  You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import re
import sqlite3
import sys


def create_tables(db_name):
    """Create the ``loginfo`` table in the SQLite database ``db_name``.

    The database file is created by sqlite3 if it does not exist yet.

    Args:
        db_name: path of the SQLite database file.

    Raises:
        sqlite3.Error: if the table cannot be created, for example
            because it already exists.  The connection is closed before
            the error propagates.
    """
    schema = (
        'CREATE TABLE loginfo('
        'id INTEGER NOT NULL PRIMARY KEY AUTOINCREMENT, '
        'log_id TEXT, '
        'public_key TEXT, '  # path to PEM-encoded file
        'distrusted INTEGER, '  # non-zero if not trusted
        'min_valid_timestamp INTEGER, '
        'max_valid_timestamp INTEGER, '
        'url TEXT)'
    )
    cxn = sqlite3.connect(db_name)
    try:
        cur = cxn.cursor()
        try:
            cur.execute(schema)
        finally:
            cur.close()
        cxn.commit()
    finally:
        # Always release the database handle, even when execute() failed.
        cxn.close()


def record_id_arg(cur, args, required=False):
    """Consume a leading ``#<record-number>`` argument, if present.

    Args:
        cur: cursor on a database containing the ``loginfo`` table.
        args: list of remaining command-line arguments.  When its first
            element is a record id, that element is removed from the list.
        required: when true, a missing record id is a fatal error.

    Returns:
        The record number as a string (without the leading ``#``), or
        None when no record id was given and ``required`` is false.

    Exits with status 1 when a required record id is missing, or when the
    referenced record does not exist in ``loginfo``.
    """
    # startswith() is safe for an empty-string argument, unlike indexing [0].
    if not args or len(args[0]) < 2 or not args[0].startswith('#'):
        if required:
            sys.stderr.write('A record id was not provided\n')
            sys.exit(1)
        return None
    record_id = args.pop(0)[1:]
    cur.execute('SELECT id FROM loginfo WHERE id = ?', [record_id])
    if cur.fetchone() is None:
        sys.stderr.write('Record #%s was not found\n' % record_id)
        sys.exit(1)
    return record_id


def log_id_arg(cur, args, required=True):
    """Consume a leading 64-character hexadecimal log id argument.

    Args:
        cur: not used by this function.
        args: list of remaining command-line arguments.  When its first
            element is 64 characters long it is removed from the list.
        required: when false, a missing log id yields None instead of a
            fatal error.

    Returns:
        The log id converted to upper case, or None when no 64-character
        argument was given and ``required`` is false.

    Exits with status 1 when a required log id is missing, or when the
    64-character argument contains anything other than hex digits.
    """
    if len(args) < 1 or len(args[0]) != 64:
        if not required:
            return None
        sys.stderr.write('A log id was not provided\n')
        sys.exit(1)
    log_id = args.pop(0).upper()
    # A log id is a hex string, so only 0-9 and A-F are acceptable.
    if not re.match(r'[0-9A-F]{64}\Z', log_id):
        sys.stderr.write('The log id is not formatted properly\n')
        sys.exit(1)
    return log_id


def public_key_arg(args):
    """Consume the leading argument as the path of a public key file.

    Args:
        args: list of remaining command-line arguments; its first element
            is removed and treated as the key file path.

    Returns:
        The path exactly as given on the command line.

    Exits with status 1 when no argument is left, or when the path is not
    a regular file that the current user may read.
    """
    if len(args) < 1:
        sys.stderr.write('A public key file was not provided\n')
        sys.exit(1)
    pubkey = args.pop(0)
    # exists() alone would accept a directory, which cannot be a key file.
    if not (os.path.isfile(pubkey) and os.access(pubkey, os.R_OK)):
        sys.stderr.write('Public key file %s could not be read\n' % pubkey)
        sys.exit(1)
    return pubkey


def time_arg(args):
    """Consume the leading argument as a timestamp.

    Args:
        args: list of remaining command-line arguments; its first element
            is removed and parsed.

    Returns:
        The timestamp as a positive int, or None when the argument is
        ``-`` (no bound).

    Exits with status 1 when no argument is left, or when the argument is
    not an integer greater than or equal to 1.
    """
    if len(args) < 1:
        sys.stderr.write('A timestamp was not provided\n')
        sys.exit(1)
    raw = args.pop(0)
    if raw == '-':
        return None
    try:
        val = int(raw)
    except ValueError:
        val = None

    if val is None or val < 1:
        sys.stderr.write('The timestamp "%s" is invalid\n' % raw)
        sys.exit(1)

    return val


def configure_public_key(cur, args):
    """Handle the ``configure-public-key`` command.

    Accepted argument forms, matching the usage text
    ``configure-public-key [log-id|record-id] /path/log-pub-key.pem``:

      * ``#<record> <path>``: set the public key of an existing record.
      * ``<log-id> <path>``: insert a new record with that log id and key.
      * ``<path>``: insert a new record holding only the key.

    Args:
        cur: cursor on a database containing the ``loginfo`` table.
        args: list of command-line arguments following the command name;
            consumed arguments are removed from it.

    Exits through the argument helpers or usage() on invalid arguments.
    """
    record_id = record_id_arg(cur, args, False)
    log_id = None
    # Only look for a log id when a key path still follows it, so that a
    # lone path argument is never mistaken for a log id.
    if not record_id and len(args) >= 2:
        log_id = log_id_arg(cur, args, False)
    public_key = public_key_arg(args)
    if len(args) != 0:
        usage()
    if record_id:
        stmt = 'UPDATE loginfo SET public_key = ? WHERE id = ?'
        cur.execute(stmt, [public_key, record_id])
    elif log_id:
        stmt = 'INSERT INTO loginfo (log_id, public_key) VALUES(?, ?)'
        cur.execute(stmt, [log_id, public_key])
    else:
        stmt = 'INSERT INTO loginfo (public_key) VALUES(?)'
        cur.execute(stmt, [public_key])


def configure_url(cur, args):
    """Handle the ``configure-url`` command.

    With a record id the URL of that record is updated; with a log id a
    new record holding the log id and URL is inserted; with neither, a new
    record holding only the URL is inserted.

    Args:
        cur: cursor on a database containing the ``loginfo`` table.
        args: list of command-line arguments following the command name;
            consumed arguments are removed from it.
    """
    record_id = record_id_arg(cur, args, False)
    # A record id and a log id are mutually exclusive, so the log id is
    # only looked for when no record id was given.
    log_id = None if record_id else log_id_arg(cur, args, False)

    # Exactly one argument, the URL, must remain.
    if len(args) != 1:
        usage()
    url = args.pop(0)

    if record_id:
        cur.execute('UPDATE loginfo SET url = ? WHERE id = ?',
                    [url, record_id])
        return
    if log_id:
        cur.execute('INSERT INTO loginfo (log_id, url) VALUES(?, ?)',
                    [log_id, url])
        return
    cur.execute('INSERT INTO loginfo (url) VALUES(?)', [url])


def forget_log(cur, args):
    """Handle the ``forget`` command: delete records from ``loginfo``.

    A record id deletes that single record; otherwise a log id is
    required and every record with that log id is deleted.

    Args:
        cur: cursor on a database containing the ``loginfo`` table.
        args: list of command-line arguments following the command name;
            consumed arguments are removed from it.
    """
    record_id = record_id_arg(cur, args, False)
    log_id = None if record_id else log_id_arg(cur, args, True)

    # Nothing may follow the identifier.
    if args:
        usage()

    if record_id:
        query, params = 'DELETE FROM loginfo WHERE id = ?', [record_id]
    else:
        query, params = 'DELETE FROM loginfo WHERE log_id = ?', [log_id]
    cur.execute(query, params)


def trust_distrust_log(cur, args):
    """Set the ``distrusted`` flag for a record or a log id.

    Args:
        cur: cursor on a database containing the ``loginfo`` table.
        args: the user's arguments (a record id or a log id) followed by
            the flag value as the final element, as built by trust_log()
            and distrust_log().  The list is consumed.

    With a record id the existing record is updated; with a log id a new
    record holding the log id and flag is inserted.

    Exits through the argument helpers or usage() on invalid arguments.
    """
    # Take the flag off the end first.  Otherwise, when the user supplied
    # no identifier, the (integer) flag would be the first element and the
    # id parsers would crash trying to treat it as a string.
    flag = args.pop() if args else None

    # could take a record id or a log id
    record_id = record_id_arg(cur, args, False)
    log_id = None
    if not record_id:
        log_id = log_id_arg(cur, args)

    # Nothing may remain besides the identifier that was just consumed.
    if len(args) != 0:
        usage()

    if not record_id:
        stmt = 'INSERT INTO loginfo (log_id, distrusted) VALUES(?, ?)'
        cur.execute(stmt, [log_id, flag])
    else:
        stmt = 'UPDATE loginfo SET distrusted = ? WHERE id = ?'
        cur.execute(stmt, [flag, record_id])


def trust_log(cur, args):
    """Handle the ``trust`` command by passing flag value 0 along."""
    # Build a new list so the caller's argument list is left untouched.
    flagged_args = args + [0]
    trust_distrust_log(cur, flagged_args)


def distrust_log(cur, args):
    """Handle the ``distrust`` command by passing flag value 1 along."""
    # Build a new list so the caller's argument list is left untouched.
    flagged_args = args + [1]
    trust_distrust_log(cur, flagged_args)


def time_range(cur, args):
    """Handle the ``valid-time-range`` command.

    Expects a record id or a log id followed by the minimum and maximum
    valid timestamps (each either a positive integer or ``-``).  With a
    record id the existing record is updated; with a log id a new record
    is inserted.

    Args:
        cur: cursor on a database containing the ``loginfo`` table.
        args: list of command-line arguments following the command name;
            consumed arguments are removed from it.
    """
    # Either identifier form is accepted; a log id is mandatory when no
    # record id was supplied.
    record_id = record_id_arg(cur, args, False)
    log_id = None if record_id else log_id_arg(cur, args)

    lower = time_arg(args)
    upper = time_arg(args)
    if args:
        usage()

    if record_id:
        query = ('UPDATE loginfo SET min_valid_timestamp = ?, '
                 'max_valid_timestamp = ? WHERE id = ?')
        params = [lower, upper, record_id]
    else:
        query = ('INSERT INTO loginfo '
                 '(log_id, min_valid_timestamp, max_valid_timestamp) '
                 'VALUES(?, ?, ?)')
        params = [log_id, lower, upper]
    cur.execute(query, params)


class ConfigEntry:
    """Plain attribute container; attributes are assigned by the code using it."""


def dump_ll(cur):
    """Return every row of ``loginfo`` as a list of ConfigEntry objects.

    Each row's first seven columns are copied, by position, onto the
    attributes id, log_id, public_key, distrusted, min_valid_timestamp,
    max_valid_timestamp and url.

    Args:
        cur: cursor on a database containing the ``loginfo`` table.
    """
    # Attribute names in the positional order they are read from each row.
    field_names = (
        'id',
        'log_id',
        'public_key',
        'distrusted',
        'min_valid_timestamp',
        'max_valid_timestamp',
        'url',
    )
    cur.execute('SELECT * FROM loginfo')
    entries = []
    for row in cur.fetchall():
        entry = ConfigEntry()
        for position, name in enumerate(field_names):
            setattr(entry, name, row[position])
        entries.append(entry)
    return entries


def dump(cur, args):
    """Handle the ``dump`` command: print every log record to stdout.

    Args:
        cur: cursor on a database containing the ``loginfo`` table.
        args: list of command-line arguments following the command name;
            must be empty, otherwise usage() is invoked.

    Unset text fields are shown as ``(not configured)``; unset time bounds
    are shown as ``-INF`` / ``+INF``.
    """
    if len(args) != 0:
        usage()

    not_conf = '(not configured)'
    for rec in dump_ll(cur):
        mint = \
            str(rec.min_valid_timestamp) if rec.min_valid_timestamp else '-INF'
        maxt = \
            str(rec.max_valid_timestamp) if rec.max_valid_timestamp else '+INF'
        lines = [
            'Log entry:',
            '  Record ' + str(rec.id) +
            (' (DISTRUSTED)' if rec.distrusted else ''),
            '  Log id         : ' + (rec.log_id if rec.log_id else not_conf),
            '  Public key file: ' +
            (rec.public_key if rec.public_key else not_conf),
            '  URL            : ' + (rec.url if rec.url else not_conf),
            '  Time range     : ' + mint + ' to ' + maxt,
            '',  # blank separator line after each record
        ]
        # write() is used instead of the print statement so the module
        # parses on both Python 2 and Python 3; the output is the same.
        sys.stdout.write('\n'.join(lines) + '\n')


def usage():
    """Write the command-line help text to stderr and exit with status 1.

    The program name shown in the text is taken from ``sys.argv[0]``.
    """
    text = """Usage: %s /path/to/log-config-db command args

Commands:
  display config-db contents:
    dump
  configure public key:
    configure-public-key [log-id|record-id] /path/log-pub-key.pem
  configure URL:
    configure-url [log-id|record-id] http://www.example.com/path/
  configure min and/or max valid timestamps:
    valid-time-range log-id|record-id min-range max-range
  mark log as trusted (default):
    trust log-id|record-id
  mark log as untrusted:
    distrust log-id|record-id
  remove log config from config-db:
    forget log-id|record-id

log-id is a 64-character hex string representation of a log id

record-id references an existing entry and is in the form:
  #<record-number>
  (displayable with the dump command)
""" % sys.argv[0]
    # The extra newline reproduces the one the print statement appended.
    sys.stderr.write(text + '\n')
    sys.exit(1)


def main(argv):
    """Run one configuration command against a log-config database.

    Args:
        argv: full argument vector: program name, database path, command
            name, then the command's own arguments.

    The database is created on demand for every command except ``dump``
    and ``forget``, which require it to exist already.  Changes are
    committed only when the command handler returns normally.

    Exits with status 1 (via usage() or directly) on too few arguments,
    an unknown command, or a missing database for ``dump``/``forget``.
    """
    if len(argv) < 3:
        usage()

    db_name = argv[1]
    cmd = argv[2]
    args = argv[3:]

    cmds = {'configure-public-key': configure_public_key,
            'configure-url': configure_url,
            'distrust': distrust_log,
            'trust': trust_log,
            'forget': forget_log,
            'valid-time-range': time_range,
            'dump': dump,
            }

    cmds_requiring_db = ['dump', 'forget']  # db must already exist

    if cmd not in cmds:
        usage()

    if not os.path.exists(db_name):
        if cmd in cmds_requiring_db:
            sys.stderr.write('Database "%s" does not exist\n' % db_name)
            sys.exit(1)
        create_tables(db_name)

    cxn = sqlite3.connect(db_name)
    try:
        cur = cxn.cursor()
        try:
            cmds[cmd](cur, args)
        finally:
            cur.close()
        # Reached only when the handler succeeded; a handler that raised
        # or called sys.exit() leaves its changes uncommitted.
        cxn.commit()
    finally:
        cxn.close()

# Run the command-line tool only when this file is executed as a script,
# passing the full argument vector (including the program name) to main().
if __name__ == "__main__":
    main(sys.argv)
