Date: Friday, Mar 28, 2025 at 9:30 AM US/Eastern
Writing Python code that merely works is one thing—writing code that remains clear, adaptable, and well-tested over time is another. This seminar is all about designing Python code that you won’t regret later, striking the right balance between readability, extensibility, and testability. Whether you’re working solo or contributing to a team, the principles we’ll cover will help ensure that your code remains maintainable and free of unnecessary complexity.
We’ll start by exploring best practices for writing clear, well-documented code, including how to craft docstrings that actually help, choose meaningful names, and structure modules for easier comprehension. We’ll then dive into modular design, discussing how to break down complex logic into reusable, loosely coupled components that are easier to test and extend. Along the way, we’ll highlight common pitfalls—such as hidden side effects, excessive coupling, and overuse of global state—that make code harder to maintain.
From there, we’ll turn to testing strategies that go beyond the basics of unit tests. In addition to writing traditional test cases, we’ll introduce property-based testing with Hypothesis, a powerful tool that generates test cases dynamically, helping you uncover edge cases you might never have thought to check. We’ll also discuss when to use unit tests, integration tests, and property-based testing to maximize coverage without writing excessive boilerplate.
By the end of this session, you’ll have a concrete set of techniques to write Python code that’s easier to read, modify, and test—saving you and your collaborators from future headaches.
python -m pip install numpy pandas scipy pytest hypothesis pyyaml httpx
print("Let's take a look!")
We start by illustrating a consequence of poorly written code. After all, we cannot simply assert that code quality matters without specifying why.
To do this, we’ll start by describing “churn”—when changes to external requirements result in changes to code that create risk.
As an example, imagine some reporting code that operates on a simple CSV file like the following:
date,entity,value
2020-01-02,abc,100.73
2020-01-02,def,98.81
2020-01-02,xyz,103.58
2020-01-03,abc,102.05
2020-01-03,def,99.74
2020-01-03,xyz,102.85
2020-01-04,abc,100.76
2020-01-04,def,99.12
2020-01-04,xyz,102.89
Using pure Python, we’ll read in this data and report the average value per entity.
from collections import defaultdict, Counter
from csv import reader
from datetime import datetime
from pathlib import Path
def process(filename):
    data, count = defaultdict(float), Counter()
    with open(filename) as f:
        for lineno, (_, entity, value) in enumerate(reader(f), start=1):
            if lineno == 1: continue # skip header
            data[entity] += float(value)
            count[entity] += 1
    return {k: round(v / count[k], 2) for k, v in data.items()}
if __name__ == '__main__':
    data_dir = Path('data')
    results = process(data_dir / 'values.csv')
    print(
        (title := 'Average Value (per Entity)'),
        f'{"":\N{box drawings double horizontal}<{len(title)}}',
        f'{"name":^7} {"value":^7}',
        f'{"":\N{box drawings light horizontal}^7} {"":\N{box drawings light horizontal}^7}',
        sep='\n',
    )
    for name, value in sorted(results.items()):
        print(f' {name:<6} {value:>6,} ')
If we were to nitpick this code in code review, there are some inconsequential
details we could criticize. For example, we could suggest that this code use
pandas.read_csv instead of csv.reader… but the author could argue that they
didn’t want to have a dependency against pandas. We could suggest that the
padding width used in the output formatting specifiers should be parameterised
to avoid an ‘update anomaly’ if someone changes the width in one place but
fails to change it in another… but the author could argue that if this happens,
they’d notice it immediately and just go fix it. As we go back and forth, we
may struggle to find changes to suggest for this code where the author cannot
argue that their original choice was intentional.
Instead, we could agree with the author on likely changes that might occur. For example, could the file format change (and in what ways)? Could the desired analytics change? How might these possible changes then change our code? What would the consequences of that change be?
For example, if the file format changes such that the header line is underlined…
date,entity,value
─────────────────
2020-01-02,abc,100.73
2020-01-02,def,98.81
2020-01-02,xyz,103.58
2020-01-03,abc,102.05
2020-01-03,def,99.74
2020-01-03,xyz,102.85
2020-01-04,abc,100.76
2020-01-04,def,99.12
2020-01-04,xyz,102.89
… the code might only have to change minimally:
from collections import defaultdict, Counter
from csv import reader
from datetime import datetime
from pathlib import Path
def process(filename):
    data, count = defaultdict(float), Counter()
    with open(filename) as f:
        for lineno, fields in enumerate(reader(f), start=1):
            if lineno == 1: continue # skip header
            if lineno == 2: continue # skip underline
            _, entity, value = fields
            data[entity] += float(value)
            count[entity] += 1
    return {k: round(v / count[k], 2) for k, v in data.items()}
if __name__ == '__main__':
    data_dir = Path('data')
    results = process(data_dir / 'values.new.csv')
    print(
        (title := 'Average Value (per Entity)'),
        f'{"":\N{box drawings double horizontal}<{len(title)}}',
        f'{"name":^7} {"value":^7}',
        f'{"":\N{box drawings light horizontal}^7} {"":\N{box drawings light horizontal}^7}',
        sep='\n',
    )
    for name, value in sorted(results.items()):
        print(f' {name:<6} {value:>6,} ')
However, if the analysis requirements change such that we want to print the minimum and maximum values for each entity, the code must change significantly:
from collections import defaultdict, Counter
from csv import reader
from datetime import datetime
from pathlib import Path
def process(filename):
    data, count = defaultdict(float), Counter()
    smallest, largest = defaultdict(lambda: float('+inf')), defaultdict(lambda: float('-inf'))
    with open(filename) as f:
        for lineno, fields in enumerate(reader(f), start=1):
            if lineno == 1: continue # skip header
            if lineno == 2: continue # skip underline
            _, entity, value = fields
            data[entity] += float(value)
            count[entity] += 1
            smallest[entity] = min(float(value), smallest[entity])
            largest[entity] = max(float(value), largest[entity])
    return (
        {k: round(v / count[k], 2) for k, v in data.items()},
        {**smallest},
        {**largest},
    )
if __name__ == '__main__':
    data_dir = Path('data')
    average, smallest, largest = process(data_dir / 'values.new.csv')
    headings = {'name': 7, 'value': 7, 'minimum': 9, 'maximum': 9}
    print(
        (title := 'Average Value (per Entity)'),
        f'{"":\N{box drawings double horizontal}<{len(title)}}',
        ' '.join(k.center(v) for k, v in headings.items()),
        ' '.join(f'{'':\N{box drawings light horizontal}^{v}}' for k, v in headings.items()),
        sep='\n',
    )
    for name, value in sorted(average.items()):
        print(f' {name:<6} {value:>6,.2f} {smallest[name]:>8,.2f} {largest[name]:>8,.2f}')
Ignore the exact changes for a moment and consider where the changes are occurring: throughout the program. If the program were any longer, these changes would have to be carefully threaded through it, and getting that right would take real effort.
Consider also the consequence of process changing. If this function is used
anywhere else in the code base, we may then be required to revalidate that
code. After all, even if we were simply fixing a bug in process, by fixing
that bug, we may change the behavior in someone else’s code. Even if
identifying this change in behavior does not subsequently affect the code we
wrote—after all, our code is now more nearly correct—it is important that we
identify such a change in behavior if only to communicate to the author of the
other code that our fix may affect them. (Consider that in many circumstances,
“bugs” and “fixes” can be very subjective, and so it may not be sufficient for
us to simply tell other users to “live with it.”)
What are the alternative versions of this program?
Here’s one that still uses pure Python:
from collections import defaultdict, namedtuple
from csv import reader
from datetime import datetime
from itertools import chain
from pathlib import Path
from statistics import mean
def load_data(filename):
    Entry = namedtuple('Entry', 'date entity value')
    data = defaultdict(lambda: defaultdict(set))
    with open(filename) as f:
        for lineno, (date, entity, value) in enumerate(reader(f), start=1):
            if lineno == 1: continue # skip header
            ent = Entry(datetime.fromisoformat(date), entity, float(value))
            data[ent.entity][ent.date].add(ent)
    return data
if __name__ == '__main__':
    data_dir = Path('data')
    entries = load_data(data_dir / 'values.csv')
    headings = {'name': 7, 'value': 7}
    print(
        (title := 'Average Value (per Entity)'),
        f'{"":\N{box drawings double horizontal}<{len(title)}}',
        ' '.join(k.center(v) for k, v in headings.items()),
        ' '.join(f'{'':\N{box drawings light horizontal}^{v}}' for k, v in headings.items()),
        sep='\n',
    )
    for entity, ents in sorted(entries.items()):
        print(f' {entity:<6} {mean(x.value for x in chain.from_iterable(ents.values())):>6,.2f}')
Here’s how it changes subject to the two changes previously described:
from collections import defaultdict, namedtuple
from csv import reader
from datetime import datetime
from itertools import chain
from pathlib import Path
from statistics import mean
def load_data(filename):
    Entry = namedtuple('Entry', 'date entity value')
    data = defaultdict(lambda: defaultdict(set))
    with open(filename) as f:
        for lineno, fields in enumerate(reader(f), start=1):
            if lineno == 1: continue # skip header
            if lineno == 2: continue # skip underline
            date, entity, value = fields
            ent = Entry(datetime.fromisoformat(date), entity, float(value))
            data[ent.entity][ent.date].add(ent)
    return data
if __name__ == '__main__':
    data_dir = Path('data')
    entries = load_data(data_dir / 'values.new.csv')
    headings = {'name': 7, 'value': 7, 'minimum': 9, 'maximum': 9}
    print(
        (title := 'Average Value (per Entity)'),
        f'{"":\N{box drawings double horizontal}<{len(title)}}',
        ' '.join(k.center(v) for k, v in headings.items()),
        ' '.join(f'{'':\N{box drawings light horizontal}^{v}}' for k, v in headings.items()),
        sep='\n',
    )
    for entity, ents in sorted(entries.items()):
        print(
            f' {entity:<6}',
            f' {mean(x.value for x in chain.from_iterable(ents.values())):>6,.2f}',
            f' {min(x.value for x in chain.from_iterable(ents.values())):>8,.2f}',
            f' {max(x.value for x in chain.from_iterable(ents.values())):>8,.2f}',
        )
The churn is considerably lower! As long as we agree on the relative likelihood of seeing the suggested changes, we now have a basis for objectively measuring code quality: minimizing churn.
Note that if we had written this code using pandas, the “churn” would be
approximately the same.
from pathlib import Path
from pandas import read_csv
def load_data(filename):
    return (
        read_csv(filename, index_col=['date', 'entity'], parse_dates=['date'])
        # .groupby(['date', 'entity']).last()
        .squeeze(axis='columns')
    )
if __name__ == '__main__':
    data_dir = Path('data')
    entries = load_data(data_dir / 'values.csv')
    headings = {'name': 7, 'value': 7}
    print(
        (title := 'Average Value (per Entity)'),
        f'{"":\N{box drawings double horizontal}<{len(title)}}',
        ' '.join(k.center(v) for k, v in headings.items()),
        ' '.join(f'{'':\N{box drawings light horizontal}^{v}}' for k, v in headings.items()),
        sep='\n',
    )
    for name, aggs in entries.groupby('entity').agg(['mean']).iterrows():
        print(
            f' {name:<6}',
            f' {aggs["mean"]:>6,.2f}',
        )
And after the two changes:
from pathlib import Path
from pandas import read_csv
def load_data(filename):
    return (
        read_csv(filename, index_col=['date', 'entity'], parse_dates=['date'], skiprows=[1])
        # .groupby(['date', 'entity']).last()
        .squeeze(axis='columns')
    )
if __name__ == '__main__':
    data_dir = Path('data')
    entries = load_data(data_dir / 'values.new.csv')
    headings = {'name': 7, 'value': 7, 'minimum': 9, 'maximum': 9}
    print(
        (title := 'Average Value (per Entity)'),
        f'{"":\N{box drawings double horizontal}<{len(title)}}',
        ' '.join(k.center(v) for k, v in headings.items()),
        ' '.join(f'{'':\N{box drawings light horizontal}^{v}}' for k, v in headings.items()),
        sep='\n',
    )
    for name, aggs in entries.groupby('entity').agg(['mean', 'min', 'max']).iterrows():
        print(
            f' {name:<6}',
            f' {aggs["mean"]:>6,.2f}',
            f' {aggs["min"]:>8,.2f}',
            f' {aggs["max"]:>8,.2f}',
        )
While the use of pandas.read_csv makes for much less code overall, the
pattern of “churn” is similar. Thus, we may rightly assess that for code that
already works, replacing the pure Python parsing with pandas may have minimal
actual impact on how this code behaves under requirements change. Of course,
there are other considerations which might better motivate the use of pandas here
(just as there are other changes that we might anticipate to requirements
that might make pandas a worse overall choice.)
Rather than litigate pure Python versus pandas, let’s consider what this
example shows us:
Note that considering analytical changes nicely motivates the use of
pandas in this example: it is structurally amenable to preserving data,
allowing for less churn when requirements change!
print("Let's take a look!")
Let’s say we have a function that loads some data. We might argue that we should resist writing such a function, since it’s likely to be a source of churn; however, as we know from previous discussions, certain formats (like CSV or even Parquet) simply do not have enough fidelity to accurately capture all of the details of complex data. As a result, we may have no choice but to perform the same operations (parsing dates, restoring categorical types, deduplicating rows) every time we load or store data, and we cannot cut corners on them. This is a clear case of intentional repetition, where modularization can eliminate the risk of an update anomaly.
from pathlib import Path
from pandas import read_csv
def load_data(filename):
    return (
        read_csv(filename, parse_dates=['date'])
        .assign(
            date=lambda df: df['date'].dt.to_period('D'),
            entity=lambda df: df['entity'].astype('category'),
        )
        .groupby(['date', 'entity'], observed=True).last()
        .squeeze(axis='columns')
    )
if __name__ == '__main__':
    data_dir = Path('data')
    s = load_data(data_dir / 'values.csv')
    print(s)
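The same centralization can be sketched without pandas. save_rows and load_rows below are hypothetical pure-Python counterparts (not part of the original code), assuming the date,entity,value layout used throughout: because every reader and writer of the file goes through this pair, a change to the format or its conversions touches exactly one place, and the pair exposes an obvious round-trip property.

```python
import csv
from datetime import date
from pathlib import Path
from tempfile import TemporaryDirectory

# Hypothetical helpers: all type conversions for this CSV layout live here and nowhere else.
HEADER = ('date', 'entity', 'value')

def save_rows(rows, filename):
    with open(filename, 'w', newline='') as f:
        writer = csv.writer(f)
        writer.writerow(HEADER)
        for day, entity, value in rows:
            # Python 3 writes floats in their shortest round-tripping form
            writer.writerow((day.isoformat(), entity, value))

def load_rows(filename):
    with open(filename, newline='') as f:
        rows = csv.reader(f)
        if tuple(next(rows)) != HEADER:
            raise ValueError(f'unexpected header in {filename}')
        return [(date.fromisoformat(day), entity, float(value)) for day, entity, value in rows]

# Entities containing commas or quotes are quoted by the csv module and survive the round trip.
rows = [(date(2020, 1, 2), 'abc', 100.73), (date(2020, 1, 2), 'd,e"f', 98.81)]
with TemporaryDirectory() as d:
    save_rows(rows, Path(d) / 'values.csv')
    assert load_rows(Path(d) / 'values.csv') == rows
```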
How do we write a test for load_data?
from collections import namedtuple
from tempfile import TemporaryDirectory
from pathlib import Path
from pytest import fixture
from pandas import read_csv, Series, MultiIndex, to_datetime, CategoricalIndex
@fixture
def dummy_data():
    DummyData = namedtuple('DummyData', 'filename data')
    with TemporaryDirectory() as d:
        d = Path(d)
        data = Series(
            index=MultiIndex.from_arrays([
                to_datetime(['2020-01-02']*3).to_period('D').rename('date'),
                CategoricalIndex('abc def xyz'.split(), name='entity'),
            ]),
            data=[100.73, 98.81, 103.58],
            name='value',
        )
        save_data(data, filename := (d / 'values.csv'))
        yield DummyData(filename=filename, data=data)
def save_data(data, filename):
    data.to_csv(filename)
def load_data(filename):
    return (
        read_csv(filename, parse_dates=['date'])
        .assign(
            date=lambda df: df['date'].dt.to_period('D'),
            entity=lambda df: df['entity'].astype('category'),
        )
        .groupby(['date', 'entity'], observed=True).last()
        .squeeze(axis='columns')
    )
def test_load_data(dummy_data):
    data = load_data(dummy_data.filename)
    assert (data == dummy_data.data).all()
This is what we might call an “expected vs actual” or an “Oracular” test. In
the test, we are comparing some actual result derived by running the code
against an expected result that was provided by some “Oracle.” In this case,
the Oracle is provided by the hard-coded, manually-verified data in the
@pytest.fixture called dummy_data.
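The same oracular pattern applies to the pure-Python process function from the start of this section. In the sketch below, the oracle is a dictionary of averages worked out by hand from the sample CSV shown earlier; SAMPLE, EXPECTED and test_process_matches_oracle are illustrative names, not part of the original code.

```python
from collections import Counter, defaultdict
from csv import reader
from pathlib import Path
from tempfile import TemporaryDirectory

def process(filename):
    # the pure-Python averaging function from the start of this section
    data, count = defaultdict(float), Counter()
    with open(filename) as f:
        for lineno, (_, entity, value) in enumerate(reader(f), start=1):
            if lineno == 1:
                continue  # skip header
            data[entity] += float(value)
            count[entity] += 1
    return {k: round(v / count[k], 2) for k, v in data.items()}

SAMPLE = """\
date,entity,value
2020-01-02,abc,100.73
2020-01-02,def,98.81
2020-01-02,xyz,103.58
2020-01-03,abc,102.05
2020-01-03,def,99.74
2020-01-03,xyz,102.85
2020-01-04,abc,100.76
2020-01-04,def,99.12
2020-01-04,xyz,102.89
"""

# The oracle: per-entity averages computed by hand from SAMPLE, rounded to 2 places.
EXPECTED = {'abc': 101.18, 'def': 99.22, 'xyz': 103.11}

def test_process_matches_oracle():
    with TemporaryDirectory() as d:
        filename = Path(d) / 'values.csv'
        filename.write_text(SAMPLE)
        assert process(filename) == EXPECTED
```

Like test_load_data, this pins down one fixed input and one fixed output, and says little about any other input.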
Despite this test having full coverage over load_data, we could rightly
argue that it is not very revelatory. Similarly, we could argue that there
are rapidly diminishing returns if we extend this test merely by adding
in new hard-coded examples.
Does this mean this test is useless? Not exactly! The test itself was fairly simple to write and will in fact catch some meaningful errors or changes to the code. However, it simply fails to tell us that much about the correctness of the code itself, other than that the code produces a specified output given a specified, fixed input. We can consider this a decent “smoke test”: a test that tells us when we’ve made a trivial mistake (but may not identify if we’ve made a more significant mistake!)
The test can only minimally help us evaluate whether changes to our code are incorrect or may affect other users.
A better approach may be to discover “properties” of our code: relationships that must hold across many inputs, including relationships we observe when comparing the outputs for related inputs (the differential between inputs).
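As a small illustration of properties defined by the differential between inputs (sometimes called metamorphic testing), here is a hand-rolled sketch using only the standard library, standing in for what Hypothesis automates. check_mean_properties is a hypothetical helper; it draws random inputs and checks how statistics.mean must respond when the input is reordered or shifted, without ever knowing the "right" answer for any single input.

```python
import math
import random
from statistics import mean

def check_mean_properties(trials=200, seed=0):
    """Hypothetical hand-rolled property check: compare outputs across related inputs."""
    rng = random.Random(seed)
    for _ in range(trials):
        values = [rng.uniform(-1e6, 1e6) for _ in range(rng.randint(1, 50))]
        shift = rng.uniform(-1e6, 1e6)
        shuffled = rng.sample(values, k=len(values))
        # property 1: reordering the input does not change the mean
        # (statistics.mean sums exactly, so this holds with ==)
        assert mean(shuffled) == mean(values)
        # property 2: shifting every input by a constant shifts the mean by that constant
        # (only approximately, since each v + shift is itself rounded)
        assert math.isclose(mean(v + shift for v in values), mean(values) + shift,
                            rel_tol=1e-9, abs_tol=1e-6)
    return trials
```

Unlike this fixed-seed loop, Hypothesis also searches deliberately for awkward inputs and shrinks any failure to a minimal example.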
Here, we use Hypothesis to generate test cases, and then check a round-trip property against each one: saving data and loading it back should reproduce the original.
from datetime import date
from pathlib import Path
from tempfile import TemporaryDirectory
from numpy import unique
from pandas import read_csv, to_datetime, CategoricalIndex, MultiIndex, Series
from hypothesis import given, note
from hypothesis.strategies import builds, text, characters, composite, lists, dates, floats
def save_data(data, filename):
    data.to_csv(filename)
def load_data(filename):
    return (
        read_csv(filename, parse_dates=['date'], dtype={'entity': str})
        .assign(
            date=lambda df: df['date'].dt.to_period('D'),
            entity=lambda df: df['entity'].astype('category'),
        )
        .groupby(['date', 'entity'], observed=True).last()
        .squeeze(axis='columns')
    )
@composite
def data(draw):
    index = MultiIndex.from_product([
        # unique, sorted dates: load_data groups by (date, entity), which sorts and deduplicates the index
        to_datetime(sorted(
            draw(lists(dates(min_value=date(1999, 1, 1), max_value=date(2099, 12, 31)), min_size=1, max_size=100, unique=True))
        )).to_period('D').rename('date'),
        CategoricalIndex(unique(
            draw(
                lists(text(alphabet=characters(), min_size=1), min_size=1, max_size=100)
            )
        ), name='entity'),
    ])
    # NaN != NaN, so exclude it from the equality-based round-trip check
    data = draw(lists(floats(allow_nan=False), min_size=len(index), max_size=len(index)))
    return Series(index=index, data=data, name='value')
@given(
    data=data(),
    filename=builds(Path, text(
        alphabet=characters(exclude_characters={'\0', '/'}),
        min_size=1,
    ).filter(lambda name: name not in {'.', '..'})),
)
def test_load_save_data(data, filename):
    # assume(...)
    note(data.index)
    note(data.values)
    with TemporaryDirectory() as d:
        d = Path(d)
        save_data(data, d / filename)
        assert (data == load_data(d / filename)).all()
Not only does this test help us identify an error in load_data, specifically,
that pandas.read_csv needs to take a keyword argument dtype={'entity': str}
to avoid incorrectly parsing that column… it even helps us find an error in
both pandas.MultiIndex and pandas.CategoricalIndex!
from pandas import CategoricalIndex, MultiIndex
entities = [b'abc', b'abc\0']
# CORRECT
cat = CategoricalIndex(entities)
assert cat.tolist() == entities
assert len({*cat.tolist()}) == len({*entities})
# CORRECT
idx = MultiIndex.from_product([entities])
assert idx.get_level_values(0).tolist() == entities
assert len({*idx.get_level_values(0).tolist()}) == len({*entities})
entities = ['abc', 'abc\0']
# INCORRECT
cat = CategoricalIndex(entities)
assert cat.tolist() != entities
assert len({*cat.tolist()}) < len({*entities})
# INCORRECT
idx = MultiIndex.from_product([entities])
assert idx.get_level_values(0).tolist() != entities
assert len({*idx.get_level_values(0).tolist()}) < len({*entities})
entities = ['abc', 'abc\0def']
# INCORRECT
cat = CategoricalIndex(entities)
assert cat.tolist() != entities
assert len({*cat.tolist()}) < len({*entities})
# INCORRECT
idx = MultiIndex.from_product([entities])
assert idx.get_level_values(0).tolist() != entities
assert len({*idx.get_level_values(0).tolist()}) < len({*entities})
How might we extend this test?
from datetime import date
from pathlib import Path
from tempfile import TemporaryDirectory
from numpy import unique
from pandas import read_csv, to_datetime, CategoricalIndex, MultiIndex, Series, concat
from hypothesis import given, note, settings
from hypothesis.strategies import builds, text, characters, composite, lists, dates, floats
def save_data(data, filename):
    data.to_csv(filename)
def load_data(filename):
    return (
        read_csv(filename, parse_dates=['date'], dtype={'entity': str})
        .assign(
            date=lambda df: df['date'].dt.to_period('D'),
            entity=lambda df: df['entity'].astype('category'),
        )
        .groupby(['date', 'entity'], observed=True).last()
        .squeeze(axis='columns')
        # ['values']
    )
@composite
def data(draw):
    index = MultiIndex.from_product([
        # unique, sorted dates: load_data groups by (date, entity), which sorts and deduplicates the index
        to_datetime(sorted(
            draw(lists(dates(min_value=date(1999, 1, 1), max_value=date(2099, 12, 31)), min_size=1, max_size=100, unique=True))
        )).to_period('D').rename('date'),
        CategoricalIndex(unique(
            draw(
                lists(text(alphabet=characters(codec='utf-8', exclude_characters={'\0'}), min_size=1), min_size=1, max_size=100)
            )
        ), name='entity'),
    ])
    # NaN != NaN, so exclude it from the equality-based round-trip checks
    data = draw(lists(floats(allow_nan=False), min_size=len(index), max_size=len(index)))
    return Series(index=index, data=data, name='value')
@settings(
    max_examples=10
)
@given(
    data=data(),
    filename=builds(Path, text(
        alphabet=characters(exclude_characters={'\0', '/'}),
        min_size=1,
        max_size=100,
    ).filter(lambda name: name not in {'.', '..'})),
)
def test_load_save_data(data, filename):
    note(data.index)
    note(data.values)
    with TemporaryDirectory() as d:
        d = Path(d)
        # property: saving and then loading round-trips
        save_data(data, d / filename)
        assert (data == load_data(d / filename)).all()
        # property: when rows are duplicated, the last one wins
        save_data(concat([data * 0, data]), d / filename)
        assert (data == load_data(d / filename)).all()
        # property: loading two halves separately and combining them matches the whole
        lhs, rhs = data.iloc[:len(data)//2], data.iloc[len(data)//2:]
        save_data(lhs, d / f'{filename}.lhs')
        save_data(rhs, d / f'{filename}.rhs')
        assert (data == concat([
            load_data(d / f'{filename}.lhs'),
            load_data(d / f'{filename}.rhs'),
        ]).sort_index()).all()
Ultimately, however, we may rightly ask ourselves: is this code written to be testable? Does it expose interesting and useful properties? Herein we have another piece of objective guidance for comparing the quality of two distinct approaches: which one better reveals testable properties of our code?
print("Let's take a look!")
There’s a reason that dataclasses.dataclass and collections.namedtuple
come up so frequently in our code! They’re the simplest way to create “entities,”
and the creation of “entities” gives us space to expose testable properties and
resist churn.
For example, let’s say we have some LLM that we’re controlling via a REST API. We want to test a prompt with various temperature values. We have specified these in a YAML file.
We may have code that looks like:
from asyncio import run
from pathlib import Path
from tempfile import TemporaryDirectory
from httpx import AsyncClient
from yaml import load, dump, Loader
async def do_post(client, url, params, data):
    if not params:
        yield (await client.post(url, data=data)).json()
        return
    v = params.pop(k := next(iter(params)))
    x = v['start']
    while x <= v['stop']:
        async for res in do_post(client, url, params.copy(), data={**data, k: x}):
            yield res
        x += v['step']
async def main(config_filename):
    root_url = 'http://llm.example'
    with open(config_filename) as f:
        prompts = load(f, Loader=Loader)
    async with AsyncClient() as client:
        results = []
        for x in prompts:
            async for res in do_post(client, f'{root_url}/prompt', x['params'].copy(), {'prompt': x['prompt']}):
                results.append(res)
    print(f'{results = }')
if __name__ == '__main__':
    with TemporaryDirectory() as d:
        d = Path(d)
        with open(config_filename := (d / 'config.yml'), mode='wt') as f:
            dump([
                {
                    'prompt': 'What is the capital of California?',
                    'params': {
                        'temperature': {'start': .1, 'stop': .3, 'step': .1},
                    },
                },
            ], f)
        run(main(config_filename))
Consider how difficult this code is to test! We could write an “integration test” against an actual config file and the actual API, but we cannot easily test these pieces in isolation.
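Even before any restructuring, one way to exercise do_post in isolation is to hand it a stand-in for the HTTP client. FakeClient and FakeResponse below are hypothetical test doubles that mimic only what do_post touches (an awaitable post and a .json() on its result); do_post is copied unchanged from the listing above, and the top_k parameter is invented for illustration. Integer ranges are used deliberately to sidestep floating-point accumulation.

```python
from asyncio import run

class FakeResponse:
    """Hypothetical test double: mimics only the .json() method of an HTTP response."""
    def __init__(self, payload):
        self.payload = payload
    def json(self):
        return self.payload

class FakeClient:
    """Hypothetical test double for httpx.AsyncClient: records each POST instead of sending it."""
    def __init__(self):
        self.calls = []
    async def post(self, url, data):
        self.calls.append((url, data))
        return FakeResponse(data)  # echo the form data back as the "response"

async def do_post(client, url, params, data):
    # unchanged from the listing above
    if not params:
        yield (await client.post(url, data=data)).json()
        return
    v = params.pop(k := next(iter(params)))
    x = v['start']
    while x <= v['stop']:
        async for res in do_post(client, url, params.copy(), data={**data, k: x}):
            yield res
        x += v['step']

async def sweep(client, params):
    # collect every response for one prompt; pass a copy because do_post pops keys
    return [res async for res in do_post(client, 'http://llm.example/prompt', params.copy(), {'prompt': 'hi'})]

client = FakeClient()
params = {
    'temperature': {'start': 1, 'stop': 3, 'step': 1},
    'top_k': {'start': 10, 'stop': 20, 'step': 10},
}
results = run(sweep(client, params))
```

This works, but it only tests do_post through its side effects on a fake; the refactorings that follow aim to make such tests less necessary.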
Let’s consider changes to this code to make it more amenable to testing.
First, we’ll introduce some simple entities, but leave the rest of the code the same:
from asyncio import run
from dataclasses import dataclass
from pathlib import Path
from tempfile import TemporaryDirectory
from httpx import AsyncClient
from yaml import dump, load, Loader
@dataclass(frozen=True)
class LiteralChoices:
    choices : frozenset
    @classmethod
    def from_yaml(cls, *, payload):
        return cls(choices=frozenset(payload))
@dataclass(frozen=True)
class RangedChoices:
    start : float
    stop : float
    step : float
    @classmethod
    def from_yaml(cls, *, payload):
        return cls(**payload)
@dataclass(frozen=True)
class Prompt:
    prompt : str
    params : dict[str, LiteralChoices | RangedChoices]
    @classmethod
    def from_yaml(cls, *, payload):
        params = {
            k: (RangedChoices if isinstance(v, dict) else LiteralChoices).from_yaml(payload=v)
            for k, v in payload['params'].items()
        }
        return cls(prompt=payload['prompt'], params=params)
async def do_post(client, url, params, data):
    if not params:
        yield (await client.post(url, data=data)).json()
        return
    v = params.pop(k := next(iter(params)))
    x = v.start
    while x <= v.stop:
        async for res in do_post(client, url, params.copy(), data={**data, k: x}):
            yield res
        x += v.step
async def main(config_filename):
    root_url = 'http://llm.example'
    with open(config_filename) as f:
        config = [Prompt.from_yaml(payload=x) for x in load(f, Loader=Loader)]
    async with AsyncClient() as client:
        results = []
        for x in config:
            # pass a copy: do_post pops keys from the dict it receives
            async for res in do_post(client, f'{root_url}/prompt', x.params.copy(), {'prompt': x.prompt}):
                results.append(res)
    print(f'{results = }')
if __name__ == '__main__':
    with TemporaryDirectory() as d:
        d = Path(d)
        with open(config_filename := (d / 'config.yml'), mode='wt') as f:
            dump([
                {
                    'prompt': 'What is the capital of California?',
                    'params': {
                        'temperature': {'start': .1, 'stop': .2, 'step': .1},
                    },
                },
            ], f)
        run(main(config_filename))
This represents approximately a 50% increase in the total code (with only a minimal increase in overall functionality!) Setting aside the relatively low complexity of the added code, we need to be able to justify this additional effort.
The introduction of entities is, in fact, quite easy for us to justify: we have created surface area for better testability.
from asyncio import run
from dataclasses import dataclass, asdict
from pathlib import Path
from tempfile import TemporaryDirectory
from httpx import AsyncClient
from yaml import dump, load, Loader
from hypothesis import given, assume
from hypothesis.strategies import sets, builds, floats
@dataclass(frozen=True)
class LiteralChoices:
    choices : frozenset
    @classmethod
    def from_yaml(cls, *, payload):
        return cls(choices=frozenset(payload))
@given(payload=sets(builds(object)))
def test_literal_choices(payload):
    assert LiteralChoices.from_yaml(payload=payload).choices == {*payload}
@dataclass(frozen=True)
class RangedChoices:
    start : float
    stop : float
    step : float
    @classmethod
    def from_yaml(cls, *, payload):
        return cls(**payload)
@given(payload=builds(dict, start=floats(), stop=floats(), step=floats()))
def test_ranged_choices(payload):
    assume(payload['start'] <= payload['stop'])
    assert asdict(RangedChoices.from_yaml(payload=payload)) == payload
@dataclass(frozen=True)
class Prompt:
    prompt : str
    params : dict[str, LiteralChoices | RangedChoices]
    @classmethod
    def from_yaml(cls, *, payload):
        params = {
            k: (RangedChoices if isinstance(v, dict) else LiteralChoices).from_yaml(payload=v)
            for k, v in payload['params'].items()
        }
        return cls(prompt=payload['prompt'], params=params)
async def do_post(client, url, params, data):
    if not params:
        yield (await client.post(url, data=data)).json()
        return
    v = params.pop(k := next(iter(params)))
    x = v.start
    while x <= v.stop:
        async for res in do_post(client, url, params.copy(), data={**data, k: x}):
            yield res
        x += v.step
async def main(config_filename):
    root_url = 'http://llm.example'
    with open(config_filename) as f:
        config = [Prompt.from_yaml(payload=x) for x in load(f, Loader=Loader)]
    async with AsyncClient() as client:
        results = []
        for x in config:
            # pass a copy: do_post pops keys from the dict it receives
            async for res in do_post(client, f'{root_url}/prompt', x.params.copy(), {'prompt': x.prompt}):
                results.append(res)
    print(f'{results = }')
if __name__ == '__main__':
    with TemporaryDirectory() as d:
        d = Path(d)
        with open(config_filename := (d / 'config.yml'), mode='wt') as f:
            dump([
                {
                    'prompt': 'What is the capital of California?',
                    'params': {
                        'temperature': {'start': .1, 'stop': .2, 'step': .1},
                    },
                },
            ], f)
        run(main(config_filename))
Now, we can start moving functionality around to better put it under the spotlight of our tests!
from asyncio import run
from dataclasses import dataclass, asdict
from itertools import islice
from pathlib import Path
from tempfile import TemporaryDirectory
from httpx import AsyncClient
from yaml import dump, load, Loader
from hypothesis import given, assume, note
from hypothesis.strategies import sets, builds, floats, composite
@dataclass(frozen=True)
class LiteralChoices:
    choices : frozenset
    @classmethod
    def from_yaml(cls, *, payload):
        return cls(choices=frozenset(payload))
@given(payload=sets(builds(object)))
def test_literal_choices(payload):
    assert LiteralChoices.from_yaml(payload=payload).choices == {*payload}
@dataclass(frozen=True)
class RangedChoices:
    start : float
    stop : float
    step : float
    @classmethod
    def from_yaml(cls, *, payload):
        return cls(**payload)
    @property
    def choices(self):
        x = self.start
        while x <= self.stop:
            yield x
            x += self.step
@composite
def ranged_payload(draw):
    start = draw(floats(min_value=0, max_value=1, allow_nan=False, allow_infinity=False))
    stop = start + draw(floats(min_value=0.1, max_value=1, allow_nan=False, allow_infinity=False))
    step = draw(floats(min_value=(stop - start) / 100, max_value=stop - start, allow_nan=False, allow_infinity=False))
    return {'start': start, 'stop': stop, 'step': step}
@given(payload=ranged_payload())
def test_ranged_choices(payload):
    choices = sorted(islice(RangedChoices.from_yaml(payload=payload).choices, 10_000))
    note(f'{choices = }')
    assert payload['start'] == choices[0]
    assert payload['start'] + (payload['stop'] - payload['start']) // payload['step'] * payload['step'] == choices[-1]
@dataclass(frozen=True)
class Prompt:
    prompt : str
    params : dict[str, LiteralChoices | RangedChoices]
    @classmethod
    def from_yaml(cls, *, payload):
        params = {
            k: (RangedChoices if isinstance(v, dict) else LiteralChoices).from_yaml(payload=v)
            for k, v in payload['params'].items()
        }
        return cls(prompt=payload['prompt'], params=params)
async def do_post(client, url, params, data):
    if not params:
        yield (await client.post(url, data=data)).json()
        return
    v = params.pop(k := next(iter(params)))
    for x in v.choices:
        async for res in do_post(client, url, params.copy(), data={**data, k: x}):
            yield res
async def main(config_filename):
    root_url = 'http://llm.example'
    with open(config_filename) as f:
        config = [Prompt.from_yaml(payload=x) for x in load(f, Loader=Loader)]
    async with AsyncClient() as client:
        results = []
        for x in config:
            # pass a copy: do_post pops keys from the dict it receives
            async for res in do_post(client, f'{root_url}/prompt', x.params.copy(), {'prompt': x.prompt}):
                results.append(res)
    print(f'{results = }')
if __name__ == '__main__':
    with TemporaryDirectory() as d:
        d = Path(d)
        with open(config_filename := (d / 'config.yml'), mode='wt') as f:
            dump([
                {
                    'prompt': 'What is the capital of California?',
                    'params': {
                        'temperature': {'start': .1, 'stop': .2, 'step': .1},
                    },
                },
            ], f)
        run(main(config_filename))
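Because both choice types now expose a choices attribute, the set of requests that do_post walks through could itself be computed by a pure function and tested without any client at all. The expand helper below is hypothetical (not part of the listing above), sketched with itertools.product over simplified copies of the two entity classes; it uses integer ranges to sidestep floating-point drift, and the model parameter is invented for illustration.

```python
from dataclasses import dataclass
from itertools import product
from math import prod

@dataclass(frozen=True)
class LiteralChoices:
    choices: frozenset

@dataclass(frozen=True)
class RangedChoices:
    start: int
    stop: int
    step: int

    @property
    def choices(self):
        # same accumulation loop as above; a fresh generator on every access
        x = self.start
        while x <= self.stop:
            yield x
            x += self.step

def expand(params):
    """Hypothetical helper: one dict per combination of parameter values."""
    names = list(params)
    return [dict(zip(names, combo)) for combo in product(*(params[name].choices for name in names))]

params = {
    'temperature': RangedChoices(start=1, stop=3, step=1),
    'model': LiteralChoices(choices=frozenset({'small', 'large'})),
}
combos = expand(params)
# property: the number of combinations is the product of each parameter's choice count
assert len(combos) == prod(len(list(p.choices)) for p in params.values())
```

The count property holds for any params mapping, which makes expand a natural target for a Hypothesis test in place of the fixed example here.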
Let’s zoom in on the logic…
from dataclasses import dataclass
from itertools import islice
from hypothesis import given, assume, note
from hypothesis.strategies import floats, composite
@dataclass(frozen=True)
class RangedChoices:
    start : float
    stop : float
    step : float
    @classmethod
    def from_yaml(cls, *, payload):
        return cls(**payload)
    @property
    def choices(self):
        x = self.start
        while x <= self.stop:
            yield x
            x += self.step
@composite
def ranged_payload(draw):
    start = draw(floats(min_value=0, max_value=1, allow_nan=False, allow_infinity=False))
    stop = start + draw(floats(min_value=0.1, max_value=1, allow_nan=False, allow_infinity=False))
    step = draw(floats(min_value=(stop - start) / 100, max_value=stop - start, allow_nan=False, allow_infinity=False))
    return {'start': start, 'stop': stop, 'step': step}
@given(payload=ranged_payload())
def test_ranged_choices(payload):
    choices = sorted(islice(RangedChoices.from_yaml(payload=payload).choices, 10_000))
    note(f'{choices = }')
    assert payload['start'] == choices[0]
    assert payload['start'] + (payload['stop'] - payload['start']) // payload['step'] * payload['step'] == choices[-1]
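Here, we can see a very subtle bug arising from the inexactness of IEEE 754 floating-point values: `choices` repeatedly adds `step` to `start`, so rounding error accumulates, and the last generated value may drift from `start + n * step` or overshoot `stop` and be dropped entirely. Using `decimal.Decimal` values (drawn with a fixed number of decimal places) avoids this: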
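To make the failure concrete, here is a minimal sketch that reuses the accumulation strategy from `RangedChoices.choices` in a standalone function (the name `accumulate` is ours, for illustration only) and runs it on the same `.1`/`.3`/`.1` range used later in the configuration file:

```python
from decimal import Decimal

def accumulate(start, stop, step):
    # Same strategy as RangedChoices.choices: repeatedly add step to start.
    x = start
    while x <= stop:
        yield x
        x += step

floats_out = list(accumulate(0.1, 0.3, 0.1))
decimals_out = list(accumulate(Decimal('0.1'), Decimal('0.3'), Decimal('0.1')))
print(floats_out)    # [0.1, 0.2]
print(decimals_out)  # [Decimal('0.1'), Decimal('0.2'), Decimal('0.3')]
print(0.2 + 0.1)     # 0.30000000000000004, which is greater than 0.3
```

With floats, `0.2 + 0.1` lands just above `0.3`, so the final value is silently dropped; with `Decimal`, every step is exact.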
from dataclasses import dataclass
from decimal import Decimal
from itertools import islice
from hypothesis import given, assume, note
from hypothesis.strategies import builds, decimals, composite
@dataclass(frozen=True)
class RangedChoices:
    start : Decimal
    stop : Decimal
    step : Decimal
    @classmethod
    def from_yaml(cls, *, payload):
        return cls(**payload)
    @property
    def choices(self):
        x = self.start
        while x <= self.stop:
            yield x
            x += self.step
@composite
def ranged_payload(draw):
    start = draw(decimals(min_value=0, max_value=1, allow_nan=False, allow_infinity=False, places=4))
    stop = start + draw(decimals(min_value=Decimal('0.1'), max_value=1, allow_nan=False, allow_infinity=False, places=4))
    step = draw(decimals(min_value=(stop - start) / 100, max_value=stop - start, allow_nan=False, allow_infinity=False, places=4))
    return {'start': start, 'stop': stop, 'step': step}
@given(payload=ranged_payload())
def test_ranged_choices(payload):
    choices = sorted(islice(RangedChoices.from_yaml(payload=payload).choices, 10_000))
    note(f'{choices = }')
    assert payload['start'] == choices[0]
    assert payload['start'] + (payload['stop'] - payload['start']) // payload['step'] * payload['step'] == choices[-1]
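Let’s reïntegrate these results into the full program, along with a few other small improvements: a test for `LiteralChoices`, an `all_params` property on `Prompt` that uses `itertools.product` in place of the recursive `do_post`, and concurrent requests via `asyncio.gather`: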
from asyncio import run, gather
from dataclasses import dataclass, asdict
from decimal import Decimal
from itertools import islice, product, repeat
from math import prod
from pathlib import Path
from tempfile import TemporaryDirectory
from httpx import AsyncClient
from yaml import dump, load, Loader, add_representer
from hypothesis import given, assume, note
from hypothesis.strategies import sets, builds, floats, composite, decimals, dictionaries, text, characters
@dataclass(frozen=True)
class LiteralChoices:
    choices : frozenset
    @classmethod
    def from_yaml(cls, *, payload):
        return cls(choices=frozenset(payload))
@given(payload=sets(builds(object)))
def test_literal_choices(payload):
    assert {*LiteralChoices.from_yaml(payload=[]).choices} == {*()}
    choices = LiteralChoices.from_yaml(payload=payload).choices
    assert len(choices) == len(set(choices))
    if payload:
        x = payload.pop()
        assert x not in {*LiteralChoices.from_yaml(payload=payload).choices}
@dataclass(frozen=True)
class RangedChoices:
    start : Decimal
    stop : Decimal
    step : Decimal
    @classmethod
    def from_yaml(cls, *, payload):
        return cls(**payload)
    @property
    def choices(self):
        x = self.start
        while x <= self.stop:
            yield x
            x += self.step
@composite
def ranged_payloads(draw):
    start = draw(decimals(min_value=0, max_value=1, allow_nan=False, allow_infinity=False, places=4))
    stop = start + draw(decimals(min_value=Decimal('0.1'), max_value=1, allow_nan=False, allow_infinity=False, places=4))
    step = draw(decimals(min_value=(stop - start) / 100, max_value=stop - start, allow_nan=False, allow_infinity=False, places=4))
    return {'start': start, 'stop': stop, 'step': step}
@given(payload=ranged_payloads())
def test_ranged_choices(payload):
    choices = sorted(islice(RangedChoices.from_yaml(payload=payload).choices, 10_000))
    note(choices)
    assert payload['start'] == choices[0]
    assert payload['start'] + (payload['stop'] - payload['start']) // payload['step'] * payload['step'] == choices[-1]
@dataclass(frozen=True)
class Prompt:
    prompt : str
    params : dict[str, LiteralChoices | RangedChoices]
    @classmethod
    def from_yaml(cls, *, payload):
        params = {
            k: (RangedChoices if isinstance(v, dict) else LiteralChoices).from_yaml(payload=v)
            for k, v in payload['params'].items()
        }
        return cls(prompt=payload['prompt'], params=params)
    @property
    def all_params(self):
        yield from map(dict, product(*(zip(repeat(k), v.choices) for k, v in self.params.items())))
@given(
    payload=builds(dict,
        prompt=builds(object),
        params=dictionaries(
            text(alphabet=characters(codec='ascii', min_codepoint=ord('a'), max_codepoint=ord('z')), min_size=1, max_size=20),
            ranged_payloads(),
            min_size=0,
            max_size=4,
        )
    ),
)
def test_prompt(payload):
    p = Prompt.from_yaml(payload=payload)
    all_params = [*p.all_params]
    if payload['params']:
        assert len(all_params) == prod(len({*v.choices}) for v in p.params.values())
    else:
        assert len(all_params) == 1
async def main(config_filename):
    root_url = 'http://llm.example'
    with open(config_filename) as f:
        config = [Prompt.from_yaml(payload=x) for x in load(f, Loader=Loader)]
    async with AsyncClient() as client:
        requests = [
            client.post(f'{root_url}/prompt', data={'prompt': x.prompt, **params})
            for x in config
            for params in x.all_params
        ]
        results = [x.json() for x in await gather(*requests)]
    print(f'{results = }')
if __name__ == '__main__':
    with TemporaryDirectory() as d:
        d = Path(d)
        with open(config_filename := (d / 'config.yml'), mode='wt') as f:
            dump([
                {
                    'prompt': 'What is the capital of California?',
                    'params': {
                        'temperature': {'start': Decimal('.1'), 'stop': Decimal('.3'), 'step': Decimal('.1')},
                    },
                },
            ], f)
        run(main(config_filename))
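The `all_params` property is doing the work the recursive `do_post` used to do. A stripped-down sketch of the same expression, operating on plain lists instead of `LiteralChoices`/`RangedChoices` objects (the standalone function and the second parameter `top_p` are hypothetical, added only for illustration), shows what it produces:

```python
from itertools import product, repeat

def all_params(params):
    # params maps each parameter name to an iterable of candidate values
    # (plain lists here, standing in for LiteralChoices/RangedChoices .choices).
    # zip(repeat(k), values) tags every value with its name: (k, v1), (k, v2), ...
    # product(...) picks one (name, value) pair per parameter; dict() assembles them.
    yield from map(dict, product(*(zip(repeat(k), v) for k, v in params.items())))

# 'top_p' is a hypothetical second parameter, used only to show the cross product.
grid = list(all_params({'temperature': [0.1, 0.2, 0.3], 'top_p': [0.9, 1.0]}))
print(len(grid))             # 6
print(grid[0])               # {'temperature': 0.1, 'top_p': 0.9}
print(list(all_params({})))  # [{}]
```

The number of combinations is the product of the number of choices per parameter, and an empty `params` yields exactly one (empty) parameter set, which is what `test_prompt` asserts.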
In general, this is why we lean toward first establishing entities (here, `LiteralChoices`, `RangedChoices`, and `Prompt`) that describe the nature of our work: each one exposes properties we can test in isolation. It is also why we rely on techniques like iteration helpers.
If we revisit a previous example, we can see how this plays out even in cases where we cannot strictly avoid churn.
Here is our pure Python code for processing our data set:
from collections import defaultdict, namedtuple
from csv import reader
from datetime import datetime
from itertools import chain
from pathlib import Path
from statistics import mean
def load_data(filename):
    Entry = namedtuple('Entry', 'date entity value')
    data = defaultdict(lambda: defaultdict(set))
    with open(filename) as f:
        for lineno, (date, entity, value) in enumerate(reader(f), start=1):
            if lineno == 1: continue # skip header
            ent = Entry(datetime.fromisoformat(date), entity, float(value))
            data[ent.entity][ent.date].add(ent)
    return data
if __name__ == '__main__':
    data_dir = Path('data')
    entries = load_data(data_dir / 'values.csv')
    headings = {'name': 7, 'value': 7}
    print(
        (title := 'Average Value (per Entity)'),
        f'{"":\N{box drawings double horizontal}<{len(title)}}',
        ' '.join(k.center(v) for k, v in headings.items()),
        ' '.join(f'{"":\N{box drawings light horizontal}^{v}}' for k, v in headings.items()),
        sep='\n',
    )
    for entity, ents in sorted(entries.items()):
        print(f' {entity:<6} {mean(x.value for x in chain.from_iterable(ents.values())):>6,.2f}')
Let’s zoom in on the CSV parsing code.
from csv import reader
from pathlib import Path
data_dir = Path('data')
filename = data_dir / 'values.csv'
with open(filename) as f:
    for lineno, (date, entity, value) in enumerate(reader(f), start=1):
        if lineno == 1: continue # skip header
        ...
We may consider writing an iteration helper to create more “surface area” for testing:
from csv import reader
from pathlib import Path
def skip_headers(lines):
    for lineno, contents in enumerate(lines, start=1):
        if lineno == 1: continue # skip header
        yield contents
if __name__ == '__main__':
    data_dir = Path('data')
    filename = data_dir / 'values.csv'
    with open(filename) as f:
        for date, entity, value in skip_headers(reader(f)):
            ...
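One payoff of this extra surface area is that the helper can be exercised without touching the filesystem. A minimal sketch, assuming an in-memory `io.StringIO` with made-up rows stands in for `data/values.csv`:

```python
from csv import reader
from io import StringIO

def skip_headers(lines):
    # Same helper as above: drop the first line, pass everything else through.
    for lineno, contents in enumerate(lines, start=1):
        if lineno == 1: continue  # skip header
        yield contents

# Hypothetical in-memory stand-in for data/values.csv; no file needed.
sample = StringIO('date,entity,value\n2025-01-01,abc,1.5\n2025-01-02,abc,2.5\n')
rows = list(skip_headers(reader(sample)))
print(rows)  # [['2025-01-01', 'abc', '1.5'], ['2025-01-02', 'abc', '2.5']]

# Because it accepts any iterable, we can also check a simple property directly:
# skipping the header is the same as slicing off the first item.
for n in range(6):
    items = list(range(n))
    assert list(skip_headers(items)) == items[1:]
```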
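However, as soon as we need to adjust how this works (here, a new file, `values.new.csv`, with an underline row beneath the header), the helper starts to churn: it grows a `has_underline` flag, and the decision about what to skip stays buried inside it.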
from csv import reader
from pathlib import Path
def skip_headers(lines, has_underline=False):
    for lineno, contents in enumerate(lines, start=1):
        if lineno == 1: continue # skip header
        if has_underline and lineno == 2: continue # skip underline
        yield contents
if __name__ == '__main__':
    data_dir = Path('data')
    filename = data_dir / 'values.new.csv'
    with open(filename) as f:
        for date, entity, value in skip_headers(reader(f), has_underline=True):
            ...
If we instead write this as a proper iteration helper, one that supplements the iteration but allows decisions to be made at the superficial-most layer, we end up with something more flexible (i.e., something that churns less) and more testable.
from csv import reader
from collections import namedtuple
from datetime import datetime
from pathlib import Path
class Line(namedtuple('Line', 'contents lineno is_header is_underline')):
    @classmethod
    def from_contents(cls, contents, *, lineno):
        return cls(contents, lineno, lineno == 1, lineno == 2)
def mark_headers(lines):
    for lineno, contents in enumerate(lines, start=1):
        yield Line.from_contents(contents=contents, lineno=lineno)
class Entry(namedtuple('Entry', 'date entity value')):
    @classmethod
    def from_line(cls, *, date, entity, value):
        return cls(date=datetime.fromisoformat(date), entity=entity, value=float(value))
if __name__ == '__main__':
    data_dir = Path('data')
    filename = data_dir / 'values.new.csv'
    with open(filename) as f:
        for line in mark_headers(reader(f)):
            if line.is_header:
                header = line.contents
                continue
            if line.is_underline: continue
            ent = Entry.from_line(**dict(zip(header, line.contents)))
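Since `mark_headers` only annotates lines, the same helper serves both file formats unchanged; only the outermost loop decides what to skip. A minimal sketch, assuming hypothetical in-memory contents for the old-style (`values.csv`) and new-style (`values.new.csv`) files:

```python
from collections import namedtuple
from csv import reader
from io import StringIO

class Line(namedtuple('Line', 'contents lineno is_header is_underline')):
    @classmethod
    def from_contents(cls, contents, *, lineno):
        # Record position-based facts; do not decide what to skip.
        return cls(contents, lineno, lineno == 1, lineno == 2)

def mark_headers(lines):
    for lineno, contents in enumerate(lines, start=1):
        yield Line.from_contents(contents=contents, lineno=lineno)

# Hypothetical file contents: the new style adds an underline row after the header.
old_style = 'date,entity,value\n2025-01-01,abc,1.5\n2025-01-02,abc,2.5\n'
new_style = 'date,entity,value\n----,------,-----\n2025-01-01,abc,1.5\n'

# The caller, not the helper, chooses which marks matter for each format.
old_rows = [line.contents for line in mark_headers(reader(StringIO(old_style)))
            if not line.is_header]
new_rows = [line.contents for line in mark_headers(reader(StringIO(new_style)))
            if not (line.is_header or line.is_underline)]
print(old_rows)  # [['2025-01-01', 'abc', '1.5'], ['2025-01-02', 'abc', '2.5']]
print(new_rows)  # [['2025-01-01', 'abc', '1.5']]
```

This version also exposes properties that hold for any input, such as "marking never drops or reorders rows" and "exactly one line is marked as the header", which are easy to check directly.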
In this sense, we restructure our code to move the pieces that are less interesting or harder to test (opening files, deciding which lines to skip) to the superficial-most layers, exposing more surface area for testing and reducing churn in the pieces that are less likely to change in significant ways!
What makes for good or bad code?
Is it a sense of taste, a sense of smell, a personal judgement?
In some cases, yes. But in most cases, we can establish an objective metric for evaluating the quality of our code by negotiating and agreeing upon likely changes to requirements and then assessing how amenable the code is to supporting those changes. We can measure that amenability objectively in terms of how much code “churns” (changes) in response and how much code we then have to revalidate.
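As one crude, illustrative proxy for churn (our assumption, not a standard metric), we can count the lines removed and added between two versions of the same code. A minimal sketch using the standard library's `difflib.ndiff` on the two versions of `skip_headers` shown earlier:

```python
from difflib import ndiff

# The two versions of skip_headers from earlier, as text.
old_version = '''\
def skip_headers(lines):
    for lineno, contents in enumerate(lines, start=1):
        if lineno == 1: continue # skip header
        yield contents
'''
new_version = '''\
def skip_headers(lines, has_underline=False):
    for lineno, contents in enumerate(lines, start=1):
        if lineno == 1: continue # skip header
        if has_underline and lineno == 2: continue # skip underline
        yield contents
'''

def churn(old, new):
    # Count removed ('- ') and added ('+ ') lines; ndiff's '? ' hint lines are ignored.
    diff = list(ndiff(old.splitlines(), new.splitlines()))
    removed = sum(line.startswith('- ') for line in diff)
    added = sum(line.startswith('+ ') for line in diff)
    return removed, added

print(churn(old_version, new_version))  # (1, 2)
```

A fuller measure would also count the call sites that had to change (here, the new `has_underline=True` argument) and the tests that must be rerun.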
Similarly, when we think about the quality of our code, we can also consider an objective metric relating to its testability. We know that Hypothesis in Python, and property-based testing in general, is an extremely powerful way to build confidence in the correctness of our work. Thus, we can consider better code to be code that reveals more properties for testing.
From these two simple criteria, we can derive direct, objective guidance on techniques and approaches that can significantly improve the quality of our work!