#!/usr/bin/python3
# vim: expandtab tabstop=4

import optparse, os, re, sys, tempfile, time

class KnownHostsFile:
    """In-memory model of an OpenSSH ``known_hosts`` file.

    ``self.lines`` holds one ``[tags, fptype, fpvalue]`` list per entry:
    ``tags`` is the first space-separated field split on commas, ``fptype``
    is the second field and ``fpvalue`` is the rest of the line (the key and
    anything following it).
    """

    # Loopback names/addresses: these are never merged into another entry,
    # they are split off onto a line of their own instead.
    _LOCAL_TAGS = ('::1', '127.0.0.1', 'localhost')

    # Tag classifiers used by sort_tags(); compiled once. Both accept the
    # optional "[addr]:port" notation.
    _IP6_TAG = re.compile(
        r'^(|\[)([0-9a-fA-F]{0,4}:){1,7}([0-9a-fA-F]{0,4}|:)(|\]:[0-9]{1,5})$')
    _IP4_TAG = re.compile(
        r'^(|\[)([0-9]{1,3}\.){3,3}[0-9]{1,3}(|\]:[0-9]{1,5})$')

    def __init__(self):
        self.lines = []

    def load(self, fn):
        """Append the entries found in file *fn* to ``self.lines``.

        Blank lines and lines holding only a comment are skipped, so they are
        not reproduced by :meth:`save`. A line that does not have at least
        three space-separated fields is printed and skipped.
        """
        with open(fn, 'r') as f:
            for line in f:
                # Nothing but whitespace in front of an optional '#'.
                if not line.split('#', 1)[0].strip():
                    continue
                try:
                    tags, fptype, fpvalue = line.rstrip().split(' ', 2)
                except ValueError:
                    # Fewer than three fields: report the raw line, drop it.
                    print(line)
                    continue
                self.lines.append([tags.split(','), fptype, fpvalue])

    def save(self, fn):
        """Write all entries to file *fn*, one per line, replacing its content."""
        with open(fn, 'w') as f:
            for tags, fptype, fpvalue in self.lines:
                f.write(' '.join((','.join(tags), fptype, fpvalue)) + '\n')

    def compact(self, verbose=None):
        """Merge entries that share the same ``fpvalue``.

        Tags of a later entry are appended to the first entry carrying the
        same ``fpvalue``; tags already present there are dropped. A loopback
        tag (see ``_LOCAL_TAGS``) on such a later entry is not merged but
        moved to a separate single-tag entry. The first entry for a key is
        kept as it is.

        When *verbose* is true, each decision is printed.

        Returns True and replaces ``self.lines`` if the result differs from
        the current entries, otherwise returns False and leaves
        ``self.lines`` untouched.

        Raises RuntimeError if two entries share an ``fpvalue`` but have
        different ``fptype`` values.
        """
        merged = []
        first_seen = {}  # fpvalue -> index in ``merged`` of its first entry
        for tags, fptype, fpvalue in self.lines:
            idx = first_seen.get(fpvalue)
            if idx is None:
                first_seen[fpvalue] = len(merged)
                # Copy the tag list: merging below must not modify the entry
                # still referenced from self.lines.
                merged.append([list(tags), fptype, fpvalue])
                continue
            target = merged[idx]
            if target[1] != fptype:
                raise RuntimeError("Weeha, fptype mismatch")
            for tag in tags:
                if tag in self._LOCAL_TAGS:
                    if verbose:
                        print('rewriting harmful tag: %s' % tag)
                    merged.append([[tag], fptype, fpvalue])
                    continue
                if tag in target[0]:
                    if verbose:
                        print('skipping duplicate tag: %s' % tag)
                    continue
                if verbose:
                    print('joining tag %s with %s' % (tag, ','.join(target[0])))
                target[0].append(tag)
        # Compare content rather than line counts: splitting off loopback
        # tags can leave the number of lines unchanged or even larger.
        if merged != self.lines:
            self.lines = merged
            return True
        return False

    def sort_tags(self):
        """Order the tags of every entry: names, then IPv4, then IPv6.

        Each group is sorted as plain strings. Tags matching neither address
        pattern count as names.

        Returns True and replaces ``self.lines`` if any entry changed,
        otherwise returns False.
        """
        newlines = []
        for tags, fptype, fpvalue in self.lines:
            names, ip4, ip6 = [], [], []
            for tag in tags:
                if self._IP6_TAG.match(tag):
                    ip6.append(tag)
                elif self._IP4_TAG.match(tag):
                    ip4.append(tag)
                else:
                    names.append(tag)
            newlines.append(
                [sorted(names) + sorted(ip4) + sorted(ip6), fptype, fpvalue])
        if newlines != self.lines:
            self.lines = newlines
            return True
        return False

def main():
    """Command-line entry point: compact and sort a known_hosts file.

    Takes at most one positional argument, the file to process; without it
    ``~/.ssh/known_hosts`` is used. The file is rewritten only when
    compacting or sorting changed something. Options:

    -n / --dry-run   exit before writing anything
    -v / --verbose   print the file being handled and pass verbose on to
                     ``KnownHostsFile.compact``
    -b / --backup    copy the original file next to itself before rewriting
    """
    parser = optparse.OptionParser(
        usage='%prog [options] <known hosts file|~/.ssh/known_hosts>')

    parser.add_option('-n', '--dry-run', action='store_true', dest='dryrun', help='don\'t modify anything')
    parser.add_option('-v', '--verbose', action='store_true', dest='verbose', help='show what\'s going on')
    parser.add_option('-b', '--backup', action='store_true', dest='backup', help='create backup')
    options, args = parser.parse_args()

    if len(args) == 0:
        khfn = os.path.expanduser('~/.ssh/known_hosts')
    elif len(args) == 1:
        khfn = args[0]
    else:
        # parser.error() prints the message and exits, so khfn is always set
        # past this point.
        parser.error('incorrect number of arguments')

    if options.verbose:
        print('Handling known_hosts file: %s' % khfn)

    khf = KnownHostsFile()
    khf.load(khfn)
    changes1 = khf.compact(verbose=options.verbose)
    changes2 = khf.sort_tags()
    if changes1 or changes2:
        if options.dryrun:
            print('Exit without action')
            sys.exit(0)
        if options.backup:
            d_, b_ = os.path.split(khfn)
            # Uniquely named backup in the same directory as the original.
            bfd, bfn = tempfile.mkstemp(dir=d_, prefix=b_ + '-')
            # Copy raw bytes so the backup is identical to the original
            # whatever its encoding; both handles are closed even on error.
            with os.fdopen(bfd, 'wb') as backup, open(khfn, 'rb') as original:
                backup.write(original.read())
            print('backup created: %s' % bfn)
        khf.save(khfn)

# Run the command-line entry point only when this file is executed as a
# script, not when it is imported as a module.
if __name__ == "__main__":
    main()
