#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""Match words across two texts by where they occur in each text.

Every word is described by the list of (binned) relative positions at which
it appears in its text.  Two words are considered similar when the frequency
vectors built from those position lists have a high cosine similarity.

Usage: <script> SOURCEFILE TARGETFILE
"""

import io
import sys
from collections import Counter
from math import sqrt


def wordlist(text):
    """Return the whitespace-separated tokens of *text*, lower-cased."""
    return text.lower().split()


class Collection(dict):
    """A dict whose values are lists; item assignment appends to the list.

    ``c[key] = value`` appends *value* to the list stored under *key*,
    creating the list on first use.  Values supplied to the constructor are
    each wrapped in a one-element list.
    """

    def __init__(self, *args, **kargs):
        dict.__init__(self, *args, **kargs)
        # Iterate over a snapshot so the dict is never mutated while one of
        # its own views is being walked.
        for key, value in list(self.items()):
            dict.__setitem__(self, key, [value])

    def __setitem__(self, item, value):
        if item in self:
            self[item].append(value)
        else:
            dict.__setitem__(self, item, [value])


def normalize(n, maximum, k=100):
    """Scale *n* from the range ``[0, maximum)`` to ``[1, k + 1)``.

    Returns a float.  Raises ZeroDivisionError when *maximum* is 0.
    """
    return float(n) / maximum * k + 1


def distribution(words, k=100):
    """Map each word to the list of position bins at which it occurs.

    The index of every occurrence is normalized against the total number of
    words and truncated to an int, giving a bin between 1 and *k*.  An empty
    *words* sequence yields an empty Collection.
    """
    vec = Collection()
    numwords = len(words)
    for index, word in enumerate(words):
        # Collection.__setitem__ appends, so repeated words accumulate bins.
        vec[word] = int(normalize(index, numwords, k))
    return vec


def freq(seq):
    """Return a dict mapping each element of *seq* to its occurrence count."""
    return dict(Counter(seq))


def scalar(vec):
    """Return the Euclidean norm of the values of the mapping *vec*."""
    return sqrt(sum(value * value for value in vec.values()))


def sim(v, w):
    """Return the cosine similarity of the sparse vectors *v* and *w*.

    Both arguments are mappings from dimension to numeric weight.  When
    either vector has zero length (for instance an empty mapping) the
    similarity is defined as 0.0 instead of raising ZeroDivisionError.
    """
    norm = scalar(v) * scalar(w)
    if not norm:
        return 0.0
    dot = sum(v[elem] * w[elem] for elem in v if elem in w)
    return float(dot) / norm


def compare(query, source, target, k=100, threshold=0.5):
    """Find the words in *target* distributed like *query* is in *source*.

    *source* and *target* map words to lists of position bins, as produced
    by distribution().  Returns a list of ``(similarity, query, word)``
    tuples, sorted ascending, for every target word whose similarity to the
    query is strictly greater than *threshold*.

    *k* is accepted for backward compatibility and is not used.
    Raises KeyError when *query* is not a key of *source*.
    """
    # The query's frequency vector does not depend on the target word, so
    # build it once rather than once per comparison.
    queryfreq = freq(source[query])
    comparisons = []
    for word in target:
        similarity = sim(queryfreq, freq(target[word]))
        if similarity > threshold:
            comparisons.append((similarity, query, word))
    return sorted(comparisons)


def filetowords(fname):
    """Read the file *fname* and return its lower-cased words.

    The file is decoded as UTF-8 with universal newlines and is closed
    before returning.  Punctuation handling is delegated to
    ``text.depunctuate2`` (project module; its exact behaviour is defined
    there, not here).
    """
    # Imported here rather than under ``__main__`` so the function also
    # works when this module is imported instead of run as a script.
    from text import depunctuate2

    with io.open(fname, 'r', encoding='utf-8') as handle:
        text = handle.read()
    text = depunctuate2(text.lower())
    return wordlist(text)


def main(argv=None):
    """Print, for each source word, its best-matching target words.

    Every distinct source word is listed in sorted order as
    ``index/total<TAB>word``.  For words occurring more than twice, up to
    five target words with the highest similarity follow, one per line,
    each prefixed with a tab.  Returns the process exit status.
    """
    argv = sys.argv if argv is None else argv
    if len(argv) < 3:
        prog = argv[0] if argv else "script"
        sys.stderr.write("usage: %s SOURCEFILE TARGETFILE\n" % prog)
        return 2

    sourcefile, targetfile = argv[1], argv[2]
    sourcewords = filetowords(sourcefile)
    sourcedist = distribution(sourcewords)
    targetdist = distribution(filetowords(targetfile))

    vocabulary = sorted(set(sourcewords))
    total = len(vocabulary)
    for i, q in enumerate(vocabulary):
        print("%d/%d\t%s" % (i, total, q))
        if len(sourcedist[q]) > 2:
            # compare() already returns its matches sorted ascending, so
            # the last five entries are the five most similar words.
            best = compare(q, sourcedist, targetdist, k=100)[-5:]
            for similarity, query, matched in best:
                print("\t%s" % matched)
    return 0


if __name__ == "__main__":
    sys.exit(main())