# Source code for trigrams

# -*- coding: UTF-8 -*-

"""
Simple trigrams-based text generation
"""

__version__ = '0.1.0'

import errno
import json
import re
from random import sample


class TrigramsDB(object):
    """
    A trigrams database. It has two main methods: ``feed``, to initialize it
    with some existing text, and ``generate``, to generate some new text. The
    more text you feed it, the more "random" the generated text will be.

    Internally the database is a plain ``dict`` mapping a lowercased
    ``"word1###word2"`` key to the list of distinct words seen right after
    that pair of words.
    """

    _WSEP = '###'  # words separator

    def __init__(self, dbfile=None):
        """
        Initialize a new trigrams database. If ``dbfile`` is given, the
        database is read and written from/to this file. A ``dbfile`` that does
        not exist yet yields an empty database; it is created on ``save``.
        """
        self.dbfile = dbfile
        self._load()

    def save(self, output=None):
        """
        Save the database to a file. If ``output`` is not given, the ``dbfile``
        given in the constructor is used. If neither is available this is a
        no-op.
        """
        if output is None:
            if self.dbfile is None:
                return
            output = self.dbfile

        with open(output, 'w') as f:
            f.write(self._dump())

    def feed(self, text=None, source=None):
        """
        Feed some text to the database, either from a string (``text``) or a
        file (``source``). If both are given, both are fed.

        >>> db = TrigramsDB()
        >>> db.feed("This is my text")
        >>> db.feed(source="some/file.txt")
        """
        if text is not None:
            # ``str.split()`` without argument splits on runs of whitespace
            # and, unlike ``re.split(r'\s+', ...)``, never yields empty
            # strings for leading/trailing whitespace (e.g. a final newline),
            # which would otherwise be stored as bogus empty "words".
            words = text.split()
            for trigram in zip(words, words[1:], words[2:]):
                self._insert(trigram)

        if source is not None:
            with open(source, 'r') as f:
                self.feed(f.read())

    def generate(self, **kwargs):
        """
        Generate some text from the database. By default only 70 words are
        generated, but you can change this using keyword arguments.

        Keyword arguments:

        - ``wlen``: maximum length (words)
        - ``words``: a list of words to use to begin the text with

        Returns the generated words joined with single spaces. An empty string
        is returned when fewer than two starting words are given and the
        database is empty. When a single starting word is given and no known
        trigram begins with it, that word alone is returned.
        """
        words = [self._sanitize(w) for w in (kwargs.get('words') or [])]
        max_wlen = kwargs.get('wlen', 70)

        if len(words) < 2:
            if not self._db:
                return ''

            if not words:
                # ``random.sample`` needs a real sequence: dict views are
                # rejected on recent Python versions.
                words = sample(list(self._db), 1)[0].split(self._WSEP)
            else:
                # keys are stored lowercased (see ``_insert``), so the seed
                # word must be lowercased as well to match them.
                prefix = words[0].lower() + self._WSEP
                candidates = [k for k in self._db if k.startswith(prefix)]
                if not candidates:
                    # nothing is known to follow this word
                    return words[0]
                words.append(sample(candidates, 1)[0].split(self._WSEP)[1])

        while len(words) < max_wlen:
            next_word = self._get(words[-2], words[-1])
            if next_word is None:
                break
            words.append(next_word)

        return ' '.join(words)

    def _load(self):
        """
        Load the database from its ``dbfile`` if it has one. The database
        starts empty if there's no ``dbfile`` or if it doesn't exist yet.
        """
        self._db = {}
        if self.dbfile is None:
            return

        try:
            with open(self.dbfile, 'r') as f:
                self._db = json.loads(f.read())
        except IOError as exc:
            # Only a missing file means "new database"; any other I/O error
            # (permissions, directory, ...) is a real problem.
            if exc.errno != errno.ENOENT:
                raise

    def _dump(self):
        """
        Return a string version of the database, which can then be used by
        ``_load`` to get the original object back.
        """
        return json.dumps(self._db)

    def _get(self, word1, word2):
        """
        Return a possible next word after ``word1`` and ``word2``, or ``None``
        if there's no possibility.
        """
        key = self._WSEP.join([self._sanitize(word1), self._sanitize(word2)])
        key = key.lower()
        if key not in self._db:
            return None
        return sample(self._db[key], 1)[0]

    def _sanitize(self, word):
        """
        Sanitize a word for insertion in the DB: strip any occurrence of the
        words separator so that keys can always be split back unambiguously.
        """
        return word.replace(self._WSEP, '')

    def _insert(self, trigram):
        """
        Insert a trigram (a sequence of three words) in the DB. The first two
        words form the lowercased key, the third one is recorded as a possible
        next word.
        """
        words = [self._sanitize(w) for w in trigram]
        key = self._WSEP.join(words[:2]).lower()
        next_word = words[2]

        # we could use a set here, but sets are not serializables in JSON. This
        # is the same reason we use dicts instead of defaultdicts.
        followers = self._db.setdefault(key, [])
        if next_word not in followers:
            followers.append(next_word)