Karpathy Video: makemore 1

Karpathy video - simple neural net bigram, neural-net model to predict names.

2 114

References

Notes

What is MakeMore? It makes more of what you give it. We're building character level language model - predict next character given a sequence of characters. Bigram language model - working with two characters at a time. With a bigram language model, you look at only the previous character to predict the next one.

words = open("names.txt",'r').read().splitlines()
words[0:10]
out[2]

['emma',

'olivia',

'ava',

'isabella',

'sophia',

'charlotte',

'mia',

'amelia',

'harper',

'evelyn']

len(words)
out[3]

32033

min(len(w) for w in words)
out[4]

2

max(len(w) for w in words)
out[5]

15

b = {} # Counting frequencies of bigrams
for w in words:
  # Special start character + list characters + special end character
  chs =['<S>'] + list(w) + ["<E>"]
  for ch1, ch2 in zip(chs,chs[1:]):
    bigram = (ch1, ch2)
    b[bigram] = b.get(bigram,0) + 1
out[6]
sorted(b.items(),key=lambda kv: kv[1],reverse=True)
out[7]

[(('n', '<E>'), 6763),

(('a', '<E>'), 6640),

(('a', 'n'), 5438),

(('<S>', 'a'), 4410),

(('e', '<E>'), 3983),

(('a', 'r'), 3264),

(('e', 'l'), 3248),

(('r', 'i'), 3033),

(('n', 'a'), 2977),

(('<S>', 'k'), 2963),

(('l', 'e'), 2921),

(('e', 'n'), 2675),

(('l', 'a'), 2623),

(('m', 'a'), 2590),

(('<S>', 'm'), 2538),

(('a', 'l'), 2528),

(('i', '<E>'), 2489),

(('l', 'i'), 2480),

(('i', 'a'), 2445),

(('<S>', 'j'), 2422),

(('o', 'n'), 2411),

(('h', '<E>'), 2409),

(('r', 'a'), 2356),

(('a', 'h'), 2332),

(('h', 'a'), 2244),

(('y', 'a'), 2143),

(('i', 'n'), 2126),

(('<S>', 's'), 2055),

(('a', 'y'), 2050),

(('y', '<E>'), 2007),

(('e', 'r'), 1958),

(('n', 'n'), 1906),

(('y', 'n'), 1826),

(('k', 'a'), 1731),

(('n', 'i'), 1725),

(('r', 'e'), 1697),

(('<S>', 'd'), 1690),

(('i', 'e'), 1653),

(('a', 'i'), 1650),

(('<S>', 'r'), 1639),

(('a', 'm'), 1634),

(('l', 'y'), 1588),

(('<S>', 'l'), 1572),

(('<S>', 'c'), 1542),

(('<S>', 'e'), 1531),

(('j', 'a'), 1473),

(('r', '<E>'), 1377),

(('n', 'e'), 1359),

(('l', 'l'), 1345),

(('i', 'l'), 1345),

(('i', 's'), 1316),

(('l', '<E>'), 1314),

(('<S>', 't'), 1308),

(('<S>', 'b'), 1306),

(('d', 'a'), 1303),

(('s', 'h'), 1285),

(('d', 'e'), 1283),

(('e', 'e'), 1271),

(('m', 'i'), 1256),

(('s', 'a'), 1201),

(('s', '<E>'), 1169),

(('<S>', 'n'), 1146),

(('a', 's'), 1118),

(('y', 'l'), 1104),

(('e', 'y'), 1070),

(('o', 'r'), 1059),

(('a', 'd'), 1042),

(('t', 'a'), 1027),

(('<S>', 'z'), 929),

(('v', 'i'), 911),

(('k', 'e'), 895),

(('s', 'e'), 884),

(('<S>', 'h'), 874),

(('r', 'o'), 869),

(('e', 's'), 861),

(('z', 'a'), 860),

(('o', '<E>'), 855),

(('i', 'r'), 849),

(('b', 'r'), 842),

(('a', 'v'), 834),

(('m', 'e'), 818),

(('e', 'i'), 818),

(('c', 'a'), 815),

(('i', 'y'), 779),

(('r', 'y'), 773),

(('e', 'm'), 769),

(('s', 't'), 765),

(('h', 'i'), 729),

(('t', 'e'), 716),

(('n', 'd'), 704),

(('l', 'o'), 692),

(('a', 'e'), 692),

(('a', 't'), 687),

(('s', 'i'), 684),

(('e', 'a'), 679),

(('d', 'i'), 674),

(('h', 'e'), 674),

(('<S>', 'g'), 669),

(('t', 'o'), 667),

(('c', 'h'), 664),

(('b', 'e'), 655),

(('t', 'h'), 647),

(('v', 'a'), 642),

(('o', 'l'), 619),

(('<S>', 'i'), 591),

(('i', 'o'), 588),

(('e', 't'), 580),

(('v', 'e'), 568),

(('a', 'k'), 568),

(('a', 'a'), 556),

(('c', 'e'), 551),

(('a', 'b'), 541),

(('i', 't'), 541),

(('<S>', 'y'), 535),

(('t', 'i'), 532),

(('s', 'o'), 531),

(('m', '<E>'), 516),

(('d', '<E>'), 516),

(('<S>', 'p'), 515),

(('i', 'c'), 509),

(('k', 'i'), 509),

(('o', 's'), 504),

(('n', 'o'), 496),

(('t', '<E>'), 483),

(('j', 'o'), 479),

(('u', 's'), 474),

(('a', 'c'), 470),

(('n', 'y'), 465),

(('e', 'v'), 463),

(('s', 's'), 461),

(('m', 'o'), 452),

(('i', 'k'), 445),

(('n', 't'), 443),

(('i', 'd'), 440),

(('j', 'e'), 440),

(('a', 'z'), 435),

(('i', 'g'), 428),

(('i', 'm'), 427),

(('r', 'r'), 425),

(('d', 'r'), 424),

(('<S>', 'f'), 417),

(('u', 'r'), 414),

(('r', 'l'), 413),

(('y', 's'), 401),

(('<S>', 'o'), 394),

(('e', 'd'), 384),

(('a', 'u'), 381),

(('c', 'o'), 380),

(('k', 'y'), 379),

(('d', 'o'), 378),

(('<S>', 'v'), 376),

(('t', 't'), 374),

(('z', 'e'), 373),

(('z', 'i'), 364),

(('k', '<E>'), 363),

(('g', 'h'), 360),

(('t', 'r'), 352),

(('k', 'o'), 344),

(('t', 'y'), 341),

(('g', 'e'), 334),

(('g', 'a'), 330),

(('l', 'u'), 324),

(('b', 'a'), 321),

(('d', 'y'), 317),

(('c', 'k'), 316),

(('<S>', 'w'), 307),

(('k', 'h'), 307),

(('u', 'l'), 301),

(('y', 'e'), 301),

(('y', 'r'), 291),

(('m', 'y'), 287),

(('h', 'o'), 287),

(('w', 'a'), 280),

(('s', 'l'), 279),

(('n', 's'), 278),

(('i', 'z'), 277),

(('u', 'n'), 275),

(('o', 'u'), 275),

(('n', 'g'), 273),

(('y', 'd'), 272),

(('c', 'i'), 271),

(('y', 'o'), 271),

(('i', 'v'), 269),

(('e', 'o'), 269),

(('o', 'm'), 261),

(('r', 'u'), 252),

(('f', 'a'), 242),

(('b', 'i'), 217),

(('s', 'y'), 215),

(('n', 'c'), 213),

(('h', 'y'), 213),

(('p', 'a'), 209),

(('r', 't'), 208),

(('q', 'u'), 206),

(('p', 'h'), 204),

(('h', 'r'), 204),

(('j', 'u'), 202),

(('g', 'r'), 201),

(('p', 'e'), 197),

(('n', 'l'), 195),

(('y', 'i'), 192),

(('g', 'i'), 190),

(('o', 'd'), 190),

(('r', 's'), 190),

(('r', 'd'), 187),

(('h', 'l'), 185),

(('s', 'u'), 185),

(('a', 'x'), 182),

(('e', 'z'), 181),

(('e', 'k'), 178),

(('o', 'v'), 176),

(('a', 'j'), 175),

(('o', 'h'), 171),

(('u', 'e'), 169),

(('m', 'm'), 168),

(('a', 'g'), 168),

(('h', 'u'), 166),

(('x', '<E>'), 164),

(('u', 'a'), 163),

(('r', 'm'), 162),

(('a', 'w'), 161),

(('f', 'i'), 160),

(('z', '<E>'), 160),

(('u', '<E>'), 155),

(('u', 'm'), 154),

(('e', 'c'), 153),

(('v', 'o'), 153),

(('e', 'h'), 152),

(('p', 'r'), 151),

(('d', 'd'), 149),

(('o', 'a'), 149),

(('w', 'e'), 149),

(('w', 'i'), 148),

(('y', 'm'), 148),

(('z', 'y'), 147),

(('n', 'z'), 145),

(('y', 'u'), 141),

(('r', 'n'), 140),

(('o', 'b'), 140),

(('k', 'l'), 139),

(('m', 'u'), 139),

(('l', 'd'), 138),

(('h', 'n'), 138),

(('u', 'd'), 136),

(('<S>', 'x'), 134),

(('t', 'l'), 134),

(('a', 'f'), 134),

(('o', 'e'), 132),

(('e', 'x'), 132),

(('e', 'g'), 125),

(('f', 'e'), 123),

(('z', 'l'), 123),

(('u', 'i'), 121),

(('v', 'y'), 121),

(('e', 'b'), 121),

(('r', 'h'), 121),

(('j', 'i'), 119),

(('o', 't'), 118),

(('d', 'h'), 118),

(('h', 'm'), 117),

(('c', 'l'), 116),

(('o', 'o'), 115),

(('y', 'c'), 115),

(('o', 'w'), 114),

(('o', 'c'), 114),

(('f', 'r'), 114),

(('b', '<E>'), 114),

(('m', 'b'), 112),

(('z', 'o'), 110),

(('i', 'b'), 110),

(('i', 'u'), 109),

(('k', 'r'), 109),

(('g', '<E>'), 108),

(('y', 'v'), 106),

(('t', 'z'), 105),

(('b', 'o'), 105),

(('c', 'y'), 104),

(('y', 't'), 104),

(('u', 'b'), 103),

(('u', 'c'), 103),

(('x', 'a'), 103),

(('b', 'l'), 103),

(('o', 'y'), 103),

(('x', 'i'), 102),

(('i', 'f'), 101),

(('r', 'c'), 99),

(('c', '<E>'), 97),

(('m', 'r'), 97),

(('n', 'u'), 96),

(('o', 'p'), 95),

(('i', 'h'), 95),

(('k', 's'), 95),

(('l', 's'), 94),

(('u', 'k'), 93),

(('<S>', 'q'), 92),

(('d', 'u'), 92),

(('s', 'm'), 90),

(('r', 'k'), 90),

(('i', 'x'), 89),

(('v', '<E>'), 88),

(('y', 'k'), 86),

(('u', 'w'), 86),

(('g', 'u'), 85),

(('b', 'y'), 83),

(('e', 'p'), 83),

(('g', 'o'), 83),

(('s', 'k'), 82),

(('u', 't'), 82),

(('a', 'p'), 82),

(('e', 'f'), 82),

(('i', 'i'), 82),

(('r', 'v'), 80),

(('f', '<E>'), 80),

(('t', 'u'), 78),

(('y', 'z'), 78),

(('<S>', 'u'), 78),

(('l', 't'), 77),

(('r', 'g'), 76),

(('c', 'r'), 76),

(('i', 'j'), 76),

(('w', 'y'), 73),

(('z', 'u'), 73),

(('l', 'v'), 72),

(('h', 't'), 71),

(('j', '<E>'), 71),

(('x', 't'), 70),

(('o', 'i'), 69),

(('e', 'u'), 69),

(('o', 'k'), 68),

(('b', 'd'), 65),

(('a', 'o'), 63),

(('p', 'i'), 61),

(('s', 'c'), 60),

(('d', 'l'), 60),

(('l', 'm'), 60),

(('a', 'q'), 60),

(('f', 'o'), 60),

(('p', 'o'), 59),

(('n', 'k'), 58),

(('w', 'n'), 58),

(('u', 'h'), 58),

(('e', 'j'), 55),

(('n', 'v'), 55),

(('s', 'r'), 55),

(('o', 'z'), 54),

(('i', 'p'), 53),

(('l', 'b'), 52),

(('i', 'q'), 52),

(('w', '<E>'), 51),

(('m', 'c'), 51),

(('s', 'p'), 51),

(('e', 'w'), 50),

(('k', 'u'), 50),

(('v', 'r'), 48),

(('u', 'g'), 47),

(('o', 'x'), 45),

(('u', 'z'), 45),

(('z', 'z'), 45),

(('j', 'h'), 45),

(('b', 'u'), 45),

(('o', 'g'), 44),

(('n', 'r'), 44),

(('f', 'f'), 44),

(('n', 'j'), 44),

(('z', 'h'), 43),

(('c', 'c'), 42),

(('r', 'b'), 41),

(('x', 'o'), 41),

(('b', 'h'), 41),

(('p', 'p'), 39),

(('x', 'l'), 39),

(('h', 'v'), 39),

(('b', 'b'), 38),

(('m', 'p'), 38),

(('x', 'x'), 38),

(('u', 'v'), 37),

(('x', 'e'), 36),

(('w', 'o'), 36),

(('c', 't'), 35),

(('z', 'm'), 35),

(('t', 's'), 35),

(('m', 's'), 35),

(('c', 'u'), 35),

(('o', 'f'), 34),

(('u', 'x'), 34),

(('k', 'w'), 34),

(('p', '<E>'), 33),

(('g', 'l'), 32),

(('z', 'r'), 32),

(('d', 'n'), 31),

(('g', 't'), 31),

(('g', 'y'), 31),

(('h', 's'), 31),

(('x', 's'), 31),

(('g', 's'), 30),

(('x', 'y'), 30),

(('y', 'g'), 30),

(('d', 'm'), 30),

(('d', 's'), 29),

(('h', 'k'), 29),

(('y', 'x'), 28),

(('q', '<E>'), 28),

(('g', 'n'), 27),

(('y', 'b'), 27),

(('g', 'w'), 26),

(('n', 'h'), 26),

(('k', 'n'), 26),

(('g', 'g'), 25),

(('d', 'g'), 25),

(('l', 'c'), 25),

(('r', 'j'), 25),

(('w', 'u'), 25),

(('l', 'k'), 24),

(('m', 'd'), 24),

(('s', 'w'), 24),

(('s', 'n'), 24),

(('h', 'd'), 24),

(('w', 'h'), 23),

(('y', 'j'), 23),

(('y', 'y'), 23),

(('r', 'z'), 23),

(('d', 'w'), 23),

(('w', 'r'), 22),

(('t', 'n'), 22),

(('l', 'f'), 22),

(('y', 'h'), 22),

(('r', 'w'), 21),

(('s', 'b'), 21),

(('m', 'n'), 20),

(('f', 'l'), 20),

(('w', 's'), 20),

(('k', 'k'), 20),

(('h', 'z'), 20),

(('g', 'd'), 19),

(('l', 'h'), 19),

(('n', 'm'), 19),

(('x', 'z'), 19),

(('u', 'f'), 19),

(('f', 't'), 18),

(('l', 'r'), 18),

(('p', 't'), 17),

(('t', 'c'), 17),

(('k', 't'), 17),

(('d', 'v'), 17),

(('u', 'p'), 16),

(('p', 'l'), 16),

(('l', 'w'), 16),

(('p', 's'), 16),

(('o', 'j'), 16),

(('r', 'q'), 16),

(('y', 'p'), 15),

(('l', 'p'), 15),

(('t', 'v'), 15),

(('r', 'p'), 14),

(('l', 'n'), 14),

(('e', 'q'), 14),

(('f', 'y'), 14),

(('s', 'v'), 14),

(('u', 'j'), 14),

(('v', 'l'), 14),

(('q', 'a'), 13),

(('u', 'y'), 13),

(('q', 'i'), 13),

(('w', 'l'), 13),

(('p', 'y'), 12),

(('y', 'f'), 12),

(('c', 'q'), 11),

(('j', 'r'), 11),

(('n', 'w'), 11),

(('n', 'f'), 11),

(('t', 'w'), 11),

(('m', 'z'), 11),

(('u', 'o'), 10),

(('f', 'u'), 10),

(('l', 'z'), 10),

(('h', 'w'), 10),

(('u', 'q'), 10),

(('j', 'y'), 10),

(('s', 'z'), 10),

(('s', 'd'), 9),

(('j', 'l'), 9),

(('d', 'j'), 9),

(('k', 'm'), 9),

(('r', 'f'), 9),

(('h', 'j'), 9),

(('v', 'n'), 8),

(('n', 'b'), 8),

(('i', 'w'), 8),

(('h', 'b'), 8),

(('b', 's'), 8),

(('w', 't'), 8),

(('w', 'd'), 8),

(('v', 'v'), 7),

(('v', 'u'), 7),

(('j', 's'), 7),

(('m', 'j'), 7),

(('f', 's'), 6),

(('l', 'g'), 6),

(('l', 'j'), 6),

(('j', 'w'), 6),

(('n', 'x'), 6),

(('y', 'q'), 6),

(('w', 'k'), 6),

(('g', 'm'), 6),

(('x', 'u'), 5),

(('m', 'h'), 5),

(('m', 'l'), 5),

(('j', 'm'), 5),

(('c', 's'), 5),

(('j', 'v'), 5),

(('n', 'p'), 5),

(('d', 'f'), 5),

(('x', 'd'), 5),

(('z', 'b'), 4),

(('f', 'n'), 4),

(('x', 'c'), 4),

(('m', 't'), 4),

(('t', 'm'), 4),

(('z', 'n'), 4),

(('z', 't'), 4),

(('p', 'u'), 4),

(('c', 'z'), 4),

(('b', 'n'), 4),

(('z', 's'), 4),

(('f', 'w'), 4),

(('d', 't'), 4),

(('j', 'd'), 4),

(('j', 'c'), 4),

(('y', 'w'), 4),

(('v', 'k'), 3),

(('x', 'w'), 3),

(('t', 'j'), 3),

(('c', 'j'), 3),

(('q', 'w'), 3),

(('g', 'b'), 3),

(('o', 'q'), 3),

(('r', 'x'), 3),

(('d', 'c'), 3),

(('g', 'j'), 3),

(('x', 'f'), 3),

(('z', 'w'), 3),

(('d', 'k'), 3),

(('u', 'u'), 3),

(('m', 'v'), 3),

(('c', 'x'), 3),

(('l', 'q'), 3),

(('p', 'b'), 2),

(('t', 'g'), 2),

(('q', 's'), 2),

(('t', 'x'), 2),

(('f', 'k'), 2),

(('b', 't'), 2),

(('j', 'n'), 2),

(('k', 'c'), 2),

(('z', 'k'), 2),

(('s', 'j'), 2),

(('s', 'f'), 2),

(('z', 'j'), 2),

(('n', 'q'), 2),

(('f', 'z'), 2),

(('h', 'g'), 2),

(('w', 'w'), 2),

(('k', 'j'), 2),

(('j', 'k'), 2),

(('w', 'm'), 2),

(('z', 'c'), 2),

(('z', 'v'), 2),

(('w', 'f'), 2),

(('q', 'm'), 2),

(('k', 'z'), 2),

(('j', 'j'), 2),

(('z', 'p'), 2),

(('j', 't'), 2),

(('k', 'b'), 2),

(('m', 'w'), 2),

(('h', 'f'), 2),

(('c', 'g'), 2),

(('t', 'f'), 2),

(('h', 'c'), 2),

(('q', 'o'), 2),

(('k', 'd'), 2),

(('k', 'v'), 2),

(('s', 'g'), 2),

(('z', 'd'), 2),

(('q', 'r'), 1),

(('d', 'z'), 1),

(('p', 'j'), 1),

(('q', 'l'), 1),

(('p', 'f'), 1),

(('q', 'e'), 1),

(('b', 'c'), 1),

(('c', 'd'), 1),

(('m', 'f'), 1),

(('p', 'n'), 1),

(('w', 'b'), 1),

(('p', 'c'), 1),

(('h', 'p'), 1),

(('f', 'h'), 1),

(('b', 'j'), 1),

(('f', 'g'), 1),

(('z', 'g'), 1),

(('c', 'p'), 1),

(('p', 'k'), 1),

(('p', 'm'), 1),

(('x', 'n'), 1),

(('s', 'q'), 1),

(('k', 'f'), 1),

(('m', 'k'), 1),

(('x', 'h'), 1),

(('g', 'f'), 1),

(('v', 'b'), 1),

(('j', 'p'), 1),

(('g', 'z'), 1),

(('v', 'd'), 1),

(('d', 'b'), 1),

(('v', 'h'), 1),

(('h', 'h'), 1),

(('g', 'v'), 1),

(('d', 'q'), 1),

(('x', 'b'), 1),

(('w', 'z'), 1),

(('h', 'q'), 1),

(('j', 'b'), 1),

(('x', 'm'), 1),

(('w', 'g'), 1),

(('t', 'b'), 1),

(('z', 'x'), 1)]

import torch
N = torch.zeros((27,27),dtype=torch.int32)
out[8]
# Character Lookup Array
chars = sorted(list(set(''.join(words))))
stoi = {s:i+1 for i, s in enumerate(chars)} # Mapping of characters to integers
stoi['.'] = 0
itos = {i:s for s,i in stoi.items()}

b = {} # Counting frequencies of bigrams
for w in words:
  # Special start character + list characters + special end character
  chs =['.'] + list(w) + ["."]
  for ch1, ch2 in zip(chs,chs[1:]):
    ix1 = stoi[ch1]
    ix2 = stoi[ch2]
    N[ix1,ix2] += 1
out[9]
import matplotlib.pyplot as plt
%matplotlib inline
plt.figure(figsize=(16,16))
plt.imshow(N,cmap='Blues')
for i in range(27):
  for j in range(27):
    chstr = itos[i] + itos[j]
    plt.text(j,i,chstr,ha="center",va="bottom",color="gray")
    plt.text(j,i,N[i,j].item(),ha="center",va="top",color="gray")
plt.axis("off")
out[10]

(-0.5, 26.5, 26.5, -0.5)

Jupyter Notebook Image

<Figure size 1600x1600 with 1 Axes>

N[0,:]
p = N[0].float()
p = p / p.sum()
p
out[11]

tensor([0.0000, 0.1377, 0.0408, 0.0481, 0.0528, 0.0478, 0.0130, 0.0209, 0.0273,

0.0184, 0.0756, 0.0925, 0.0491, 0.0792, 0.0358, 0.0123, 0.0161, 0.0029,

0.0512, 0.0642, 0.0408, 0.0024, 0.0117, 0.0096, 0.0042, 0.0167, 0.0290])

g = torch.Generator().manual_seed(214783647)
ix = torch.multinomial(p,num_samples=1,replacement=True,generator=g).item()
itos[ix]
out[12]

'.'

g = torch.Generator().manual_seed(214783647)
p = torch.randn(3,generator=g)
p = p /p.sum()
p
out[13]

tensor([0.1830, 0.4276, 0.3894])

torch.multinomial(p,num_samples=100,replacement=True,generator=g)
out[14]

tensor([0, 0, 1, 2, 2, 1, 0, 1, 1, 0, 1, 1, 1, 2, 0, 2, 1, 0, 2, 2, 0, 0, 1, 0,

2, 1, 0, 0, 1, 2, 0, 2, 0, 0, 1, 2, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 0,

0, 0, 2, 1, 0, 2, 1, 2, 2, 1, 1, 2, 1, 2, 1, 2, 0, 2, 1, 2, 2, 0, 2, 2,

1, 2, 1, 2, 1, 1, 2, 0, 1, 0, 2, 0, 1, 1, 1, 1, 2, 1, 2, 1, 2, 0, 1, 2,

2, 0, 1, 2])

p.shape
out[15]

torch.Size([27])

P = N.float()
P = P / P.sum(1,keepdim=True)
out[16]
P[0].sum()
out[17]

tensor(1.)

g = torch.Generator().manual_seed(214783647)
ix = 0
for i in range(20):
  out = []
  while True:
    p = N[ix].float()
    p = p / p.sum()
    ix = torch.multinomial(p,num_samples=1,replacement=True,generator=g).item()
    out.append(itos[ix])
    if ix==0:
      break
  print(''.join(out))
out[18]

kaa.
akyremilsandearvikyloria.
hte.
reckadevaiacadivi.
atayama.
a.
luloradan.
dror.
ancherwite.
drin.
konnilinninadrala.
m.
shri.
ka.
de.
avaxyle.
dyncovi.
enesman.
grran.
kah.

# Goal": maximize likelihood of the data wrt to model parameters (statistical modeling); equivakent to maximizing the log lokelihood (because log is monotonic); equivalent to minimizing the negative log likelihood; eqivalent to minimizing the negative log likelihood
# log(a*b*c) = log(a) + log(b) + log(c)
log_likelihood = 0.0
for w in words:
  chs = ['.'] + list(w) + ['.']
  for ch1, ch2 in zip(chs, chs[1:]):
    ix1 = stoi[ch1]
    ix2 = stoi[ch2]
    prob = P[ix1, ix2]
    logprop = torch.log(prob)
    log_likelihood += logprop
    n+=1
    # print(f'{ch1}{ch2}: {prob :.4f} {logprop :.4f}')
print(f"{log_likelihood=}")
nll = -log_likelihood
print(f"{nll=}")
print(f"{nll/n}")
out[19]

log_likelihood=tensor(-559891.7500)
nll=tensor(559891.7500)
2.4539220333099365

xs, ys = [], []

for w in words[:1]:
  chs = ['.'] + list(w) + ['.']
  for ch1, ch2 in zip(chs, chs[1:]):
    ix1 = stoi[ch1]
    ix2 = stoi[ch2]
    print(ch1,ch2)
    xs.append(ix1)
    ys.append(ix2)
xs = torch.tensor(xs)
ys = torch.tensor(ys)
out[20]

. e
e m
m m
m a
a .

xs
out[21]

tensor([ 0, 5, 13, 13, 1])

ys
out[22]

tensor([ 5, 13, 13, 1, 0])

import torch.nn.functional as F
xenc = F.one_hot(xs,num_classes=27).float()
xenc
out[23]

tensor([[1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,

0., 0., 0., 0., 0., 0., 0., 0., 0.],

[0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,

0., 0., 0., 0., 0., 0., 0., 0., 0.],

[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0.,

0., 0., 0., 0., 0., 0., 0., 0., 0.],

[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0.,

0., 0., 0., 0., 0., 0., 0., 0., 0.],

[0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,

0., 0., 0., 0., 0., 0., 0., 0., 0.]])

xenc.shape
out[24]

torch.Size([5, 27])

plt.imshow(xenc)
out[25]

<matplotlib.image.AxesImage at 0x7eb1d1277c10>

Jupyter Notebook Image

<Figure size 640x480 with 1 Axes>

W = torch.randn((27,27))
xenc @ W
out[26]

tensor([[-2.0454e+00, -3.5652e-01, -9.9976e-01, -2.5065e-01, -2.0485e-02,

-3.7248e-01, 1.3631e-02, -6.6458e-01, -4.3562e-01, 1.1678e+00,

-3.1893e-01, -3.2290e-01, -1.4551e+00, -2.3795e+00, 2.7208e+00,

1.9710e+00, -4.3540e-02, -6.3442e-01, -1.5164e+00, -1.2616e+00,

1.2295e+00, 8.5375e-02, 7.0326e-01, 1.0973e+00, 1.6222e-01,

1.9416e+00, 2.6361e-02],

[-1.5266e+00, 1.7824e+00, -1.9095e+00, 1.4598e+00, 1.1004e+00,

-2.2231e+00, 7.3944e-01, 5.8398e-01, -5.3106e-01, 4.1470e-01,

-5.7726e-01, 9.4285e-01, 3.7042e-02, 2.1097e+00, -3.0585e-02,

-1.1600e+00, 4.0923e-01, 2.6036e-01, 7.5992e-01, 1.0163e+00,

7.7561e-01, 7.8655e-01, 3.0344e-01, 1.9462e+00, -8.6809e-01,

-7.0019e-01, -1.6206e-01],

[-1.4347e+00, 1.3512e+00, -2.9737e-01, -1.5572e+00, -3.0751e-01,

1.7622e+00, 1.5343e+00, 1.2580e+00, -1.0035e-01, -8.3215e-01,

1.2751e+00, 9.5201e-01, -1.6963e-01, 1.7545e-01, 1.1466e-01,

2.1630e-03, 1.0484e+00, -6.6757e-01, -2.3991e-02, -5.1581e-01,

3.3165e-01, -1.7703e+00, -1.6880e+00, -7.7751e-01, 9.1288e-01,

9.1313e-01, -1.1471e+00],

[-1.4347e+00, 1.3512e+00, -2.9737e-01, -1.5572e+00, -3.0751e-01,

1.7622e+00, 1.5343e+00, 1.2580e+00, -1.0035e-01, -8.3215e-01,

1.2751e+00, 9.5201e-01, -1.6963e-01, 1.7545e-01, 1.1466e-01,

2.1630e-03, 1.0484e+00, -6.6757e-01, -2.3991e-02, -5.1581e-01,

3.3165e-01, -1.7703e+00, -1.6880e+00, -7.7751e-01, 9.1288e-01,

9.1313e-01, -1.1471e+00],

[-5.8451e-01, -2.4627e-01, 1.9328e+00, -6.6693e-01, 1.1274e+00,

1.1800e+00, -8.2000e-01, -1.1936e+00, -1.7690e+00, 3.9007e-01,

-4.8504e-01, 1.1760e+00, -1.5293e-01, 1.0496e+00, 1.6835e-01,

8.9445e-01, -1.2797e+00, -2.9126e-01, 2.3389e+00, 2.7473e-01,

-5.9567e-01, -1.2981e+00, 8.3209e-01, 1.0973e+00, -1.6951e+00,

-1.6946e+00, -1.2167e+00]])

out[27]

You can read more about how comments are sorted in this blog post.

User Comments