PMI

PMI (pointwise mutual information) is a measure of association used in information theory and statistics.

Given a list of observed pairs (x, y), the PMI of x and y is

\[\mathrm{pmi}(x, y) = \log\frac{p(x,y)}{p(x)\,p(y)}\]

where

  • \(p(x)\): probability of x
  • \(p(y)\): probability of y
  • \(p(x,y)\): joint probability of x and y

Example: marginals p(x=0) = 0.8, p(x=1) = 0.2, p(y=0) = 0.25, p(y=1) = 0.75, joint probabilities p(x=0,y=0) = 0.1, p(x=0,y=1) = 0.7, p(x=1,y=0) = 0.15, p(x=1,y=1) = 0.05, with logarithms taken to base 2 (verified in the sketch after the list):

  • pmi(x=0;y=0) = −1
  • pmi(x=0;y=1) = 0.222392
  • pmi(x=1;y=0) = 1.584963
  • pmi(x=1;y=1) = -1.584963
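
A minimal sketch verifying these numbers in Python, assuming the joint distribution stated above (this snippet is not part of the class notebook):

from math import log2

# Joint and marginal probabilities from the example above.
p_xy = {(0, 0): 0.1, (0, 1): 0.7, (1, 0): 0.15, (1, 1): 0.05}
p_x = {0: 0.8, 1: 0.2}
p_y = {0: 0.25, 1: 0.75}

for (x, y), pxy in p_xy.items():
    pmi = log2(pxy / (p_x[x] * p_y[y]))
    print(f"pmi(x={x};y={y}) = {pmi:.6f}")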

Example notebook: see the notebook under class/PMI.

!pip install pyspark
Collecting pyspark
  Downloading https://files.pythonhosted.org/packages/f0/26/198fc8c0b98580f617cb03cb298c6056587b8f0447e20fa40c5b634ced77/pyspark-3.0.1.tar.gz (204.2MB)
     |████████████████████████████████| 204.2MB 70kB/s
Collecting py4j==0.10.9
  Downloading https://files.pythonhosted.org/packages/9e/b6/6a4fb90cd235dc8e265a6a2067f2a2c99f0d91787f06aca4bcf7c23f3f80/py4j-0.10.9-py2.py3-none-any.whl (198kB)
     |████████████████████████████████| 204kB 44.6MB/s
Building wheels for collected packages: pyspark
  Building wheel for pyspark (setup.py) ... done
  Created wheel for pyspark: filename=pyspark-3.0.1-py2.py3-none-any.whl size=204612243 sha256=ae3121fc30af19c4ec22b0beb2c7452d103f59dc5ad06c6fa21a5b108cdbf54a
  Stored in directory: /root/.cache/pip/wheels/5e/bd/07/031766ca628adec8435bb40f0bd83bb676ce65ff4007f8e73f
Successfully built pyspark
Installing collected packages: py4j, pyspark
Successfully installed py4j-0.10.9 pyspark-3.0.1
import requests
from pyspark.context import SparkContext

# Download the adjective-noun pairs data file.
r = requests.get('https://www.cse.ust.hk/msbd5003/data/adj_noun_pairs.txt')
with open('adj_noun_pairs.txt', 'wb') as f:
    f.write(r.content)

sc = SparkContext.getOrCreate()
# Data file at https://www.cse.ust.hk/msbd5003/data

lines = sc.textFile('adj_noun_pairs.txt')
lines.count()
3162692
lines.getNumPartitions()
2
lines.take(5)
['early radical',
 'french revolution',
 'pejorative way',
 'violent means',
 'positive label']
# Converting lines into word pairs. 
# Data is dirty: some lines have more than 2 words, so filter them out.
pairs = lines.map(lambda l: tuple(l.split())).filter(lambda p: len(p)==2)
pairs.cache()
PythonRDD[4] at RDD at PythonRDD.scala:53
pairs.take(5)
[('early', 'radical'),
 ('french', 'revolution'),
 ('pejorative', 'way'),
 ('violent', 'means'),
 ('positive', 'label')]
N = pairs.count()
N
3162674
# Compute the frequency of each pair.
# Ignore pairs that are not frequent enough (appear fewer than 100 times).
pair_freqs = pairs.map(lambda p: (p,1)).reduceByKey(lambda f1, f2: f1 + f2) \
                  .filter(lambda pf: pf[1] >= 100)
pair_freqs.take(5)
[(('political', 'philosophy'), 160),
 (('human', 'society'), 154),
 (('16th', 'century'), 950),
 (('first', 'man'), 166),
 (('same', 'time'), 2744)]
# Computing the frequencies of the adjectives and the nouns
a_freqs = pairs.map(lambda p: (p[0],1)).reduceByKey(lambda x,y: x+y)
n_freqs = pairs.map(lambda p: (p[1],1)).reduceByKey(lambda x,y: x+y)
a_freqs.take(5)
[('violent', 1191),
 ('positive', 2302),
 ('self-defined', 3),
 ('political', 15935),
 ('differ', 381)]
n_freqs.count()
106333
# Broadcasting the adjective and noun frequencies. 
#a_dict = a_freqs.collectAsMap()
#a_dict = sc.parallelize(a_dict).map(lambda x: x)
n_dict = sc.broadcast(n_freqs.collectAsMap())
a_dict = sc.broadcast(a_freqs.collectAsMap())
a_dict.value['violent']
1191
from math import log

# Computing the PMI for a pair.
# With pair frequency f: p(w1,w2) = f/N, p(w1) = a_freq/N, p(w2) = n_freq/N,
# so pmi = log2(p(w1,w2) / (p(w1) * p(w2))) = log2(f * N / (a_freq * n_freq)).
def pmi_score(pair_freq):
    w1, w2 = pair_freq[0]
    f = pair_freq[1]
    pmi = log(float(f) * N / (a_dict.value[w1] * n_dict.value[w2]), 2)
    return pmi, (w1, w2)
# Computing the PMI for all pairs.
scored_pairs = pair_freqs.map(pmi_score)
# Printing the most strongly associated pairs. 
scored_pairs.top(10)
[(14.41018838546462, ('magna', 'carta')),
 (13.071365888694997, ('polish-lithuanian', 'Commonwealth')),
 (12.990597616733414, ('nitrous', 'oxide')),
 (12.64972604311254, ('latter-day', 'Saints')),
 (12.50658937509916, ('stainless', 'steel')),
 (12.482331020687814, ('pave', 'runway')),
 (12.19140721768055, ('corporal', 'punishment')),
 (12.183248694293388, ('capital', 'punishment')),
 (12.147015483562537, ('rush', 'yard')),
 (12.109945794428935, ('globular', 'cluster'))]
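
As a quick sanity check (not in the original notebook), a single pair's score can be recomputed on the driver from the broadcast dictionaries; the helper name check_pmi and the choice of pair are illustrative:

# Recompute one pair's PMI directly from the broadcast frequency dictionaries
# and the filtered pair-frequency RDD defined above.
def check_pmi(w1, w2):
    f = pair_freqs.lookup((w1, w2))[0]   # frequency of this (adjective, noun) pair
    return log(float(f) * N / (a_dict.value[w1] * n_dict.value[w2]), 2)

check_pmi('magna', 'carta')   # should reproduce the top score above (about 14.41)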

Another way: instead of broadcast variables, simply collect the frequency dictionaries on the driver; Spark then ships them to the executors as part of each task's closure.

n_dict = n_freqs.collectAsMap()
a_dict = a_freqs.collectAsMap()
from math import log

# Computing the PMI for a pair.
def pmi_score(pair_freq):
    w1, w2 = pair_freq[0]
    f = pair_freq[1]
    pmi = log(float(f)*N/(a_dict[w1]*n_dict[w2]), 2)
    return pmi, (w1, w2)

scored_pairs = pair_freqs.map(pmi_score)
scored_pairs.top(10)
[(14.41018838546462, ('magna', 'carta')),
 (13.071365888694997, ('polish-lithuanian', 'Commonwealth')),
 (12.990597616733414, ('nitrous', 'oxide')),
 (12.64972604311254, ('latter-day', 'Saints')),
 (12.50658937509916, ('stainless', 'steel')),
 (12.482331020687814, ('pave', 'runway')),
 (12.19140721768055, ('corporal', 'punishment')),
 (12.183248694293388, ('capital', 'punishment')),
 (12.147015483562537, ('rush', 'yard')),
 (12.109945794428935, ('globular', 'cluster'))]