#!/usr/bin/env python
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements.  See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License.  You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import binascii
import os
import sqlite3
import ssl
import struct
import sys
import tempfile

from contextlib import closing

# Record tags that audit() expects to find in an audit file.  Each tag is
# read from the file as a big-endian unsigned 16-bit value.
SERVER_START = 1  # begins the data recorded for one server
KEY_START = 2     # followed by a 2-byte length and the server key bytes
CERT_START = 3    # followed by a 3-byte length and a DER certificate
SCT_START = 4     # followed by a 2-byte length and the SCT bytes


def usage():
    """Write the command-line synopsis to stderr and exit with status 1."""
    # sys.stderr.write() produces the same output on Python 2 and Python 3;
    # the "print >> stream" form only works on Python 2.
    sys.stderr.write('Usage: %s /path/to/audit/files '
                     '[/path/to/log-config-db]\n' % sys.argv[0])
    sys.exit(1)


def audit(fn, tmp, already_checked, cur):
    """Parse one audit file and run verify_single_proof.py for its SCTs.

    The file is read as a sequence of packages, each laid out as:

      SERVER_START tag
      KEY_START tag, 2-byte key length, key bytes
      one or more of: CERT_START tag, 3-byte DER length, DER bytes
      one or more of: SCT_START tag, 2-byte SCT length, SCT bytes

    Tags and 2-byte lengths are big-endian.  The first certificate of a
    package is written to a temporary PEM file, and verify_single_proof.py
    is run once per SCT with that certificate and the SCT's timestamp.

    Arguments:
      fn: path of the audit file to parse.
      tmp: not used by this function; kept so existing callers still work.
      already_checked: dict keyed by server key.  SCTs of a package whose
          key is already present are printed but not verified again; the
          key of every newly verified package is added.
      cur: sqlite3 cursor used to look up the log URL in the loginfo
          table by log id, or a false value to skip the lookup.

    Raises:
      AssertionError: an expected tag is missing or the key is empty.
      struct.error: the file ends in the middle of a record.

    The print() calls take a single pre-formatted string so the output is
    the same on Python 2 and Python 3.
    """
    print('Auditing %s...' % fn)

    # Read the whole file up front; "with" closes the handle right away.
    with open(fn, 'rb') as f:
        log_bytes = f.read()

    def read_u16(pos):
        # Big-endian unsigned 16-bit value at byte offset pos.
        return struct.unpack_from('>H', log_bytes, pos)[0]

    offset = 0
    while offset < len(log_bytes):
        print('Got package from server...')
        assert read_u16(offset) == SERVER_START, 'missing SERVER_START tag'
        offset += 2

        assert read_u16(offset) == KEY_START, 'missing KEY_START tag'
        offset += 2

        key_size = read_u16(offset)
        assert key_size > 0, 'empty server key'
        offset += 2

        key = log_bytes[offset:offset + key_size]
        offset += key_size

        # Decide once per package whether this key was handled earlier.
        # Testing (and setting) the flag per SCT would make every SCT after
        # the first one of a new key look as if it had been checked.
        seen_before = key in already_checked

        # at least one certificate
        assert read_u16(offset) == CERT_START, 'missing CERT_START tag'

        # for each certificate; only the first one (the leaf) is kept
        leaf = None
        while read_u16(offset) == CERT_START:
            offset += 2
            high, mid, low = struct.unpack_from('BBB', log_bytes, offset)
            offset += 3
            der_size = (high << 16) | (mid << 8) | low
            print('  Certificate size: %s' % hex(der_size))
            if leaf is None:
                leaf = (offset, der_size)
            offset += der_size

        leaf_start, leaf_size = leaf
        pem = ssl.DER_cert_to_PEM_cert(
            log_bytes[leaf_start:leaf_start + leaf_size])

        tmp_fd, tmp_leaf_pem = tempfile.mkstemp(text=True)
        try:
            with os.fdopen(tmp_fd, 'w') as f:
                f.write(pem)

            # at least one SCT
            assert read_u16(offset) == SCT_START, 'missing SCT_START tag'

            # for each SCT:
            while offset < len(log_bytes) and \
                    read_u16(offset) == SCT_START:
                offset += 2
                sct_size = read_u16(offset)
                offset += 2
                print('  SCT size: %s' % hex(sct_size))

                # The code reads a 32-byte log id at byte 1 of the SCT and
                # an 8-byte big-endian timestamp at byte 33.
                # NOTE(review): byte 0 is presumably the SCT version --
                # confirm against the writer of the audit file.
                log_id = log_bytes[offset + 1:offset + 1 + 32]
                # Decode so the id is text on Python 3 as well; a bytes
                # value would print as b'...' and would not match a text
                # log_id column in the database lookup.
                log_id_hex = binascii.hexlify(log_id).decode('ascii').upper()
                print('    Log id: %s' % log_id_hex)
                timestamp_ms = struct.unpack_from('>Q', log_bytes,
                                                  offset + 33)[0]
                print('    Timestamp: %s' % timestamp_ms)

                #  If we ever need the full SCT: sct = (offset, sct_size)
                offset += sct_size

                if seen_before:
                    print('  (SCTs already checked)')
                    continue

                log_url_arg = ''
                if cur:
                    stmt = 'SELECT * FROM loginfo WHERE log_id = ?'
                    cur.execute(stmt, [log_id_hex])
                    recs = cur.fetchall()
                    # NOTE(review): column 6 of loginfo is used as the log
                    # URL -- confirm against the table definition.
                    if recs and recs[0][6] is not None:
                        log_url = recs[0][6]

                        # verify_single_proof doesn't accept <scheme>://
                        if '://' in log_url:
                            log_url = log_url.split('://', 1)[1]
                        log_url_arg = '--log_url %s' % log_url

                        print('    Log URL: ' + log_url)

                # NOTE(review): the command is passed to the shell as one
                # string, so the log URL from the database must be trusted.
                cmd = 'verify_single_proof.py --cert %s --timestamp %s %s' % \
                      (tmp_leaf_pem, timestamp_ms, log_url_arg)
                print('>%s<' % cmd)
                os.system(cmd)

            already_checked[key] = True
        finally:
            # Remove the temporary PEM file even if parsing fails part way.
            os.unlink(tmp_leaf_pem)


def main():
    """Audit every '.out' file found under the directory given in argv.

    sys.argv[1] is the directory that is walked recursively.  The optional
    sys.argv[2] is the path of an sqlite3 database; when given, a cursor on
    it is passed to audit() for log URL lookups.  With any other number of
    arguments usage() is called.
    """
    if len(sys.argv) not in (2, 3):
        usage()

    top = sys.argv[1]
    # Passed through to audit(), which does not currently use it.
    tmp = '/tmp'

    if len(sys.argv) == 3:
        cxn = sqlite3.connect(sys.argv[2])
        cur = cxn.cursor()
    else:
        cxn = None
        cur = None

    # could serialize this between runs to further limit duplicate checking
    already_checked = dict()

    try:
        for dirpath, _dnames, fnames in os.walk(top):
            for fn in fnames:
                if fn.endswith('.out'):
                    audit(os.path.join(dirpath, fn), tmp, already_checked,
                          cur)
    finally:
        # Release the database even if an audit file turns out malformed.
        if cxn is not None:
            cxn.close()


# Run the audit only when this file is executed as a script, not when it is
# imported as a module.
if __name__ == "__main__":
    main()
