Skip to content

Commit 2e53727

Browse files
committed
Finalized database router support
1 parent 8314379 commit 2e53727

4 files changed

Lines changed: 213 additions & 5 deletions

File tree

README.md

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,38 @@ TASKS = {
6868

6969
The `id_function` must return a UUID (either `uuid.UUID` or string representation). Additionally, the PostgreSQL-specific [`RandomUUID`](https://docs.djangoproject.com/en/stable/ref/contrib/postgres/functions/#django.contrib.postgres.functions.RandomUUID) or other database expressions are supported on Django 6.0+.
7070

71+
### Using a separate database
72+
73+
Database routing is controlled by Django's standard database router API. If you want
74+
`django_tasks_db` models to use a separate database, define a router.
75+
76+
> **Note:** In the example below, `"task_queue"` is a placeholder. You should replace it with the specific database **alias** you have configured in your `settings.DATABASES` dictionary.
77+
78+
```python
79+
class TaskDBRouter:
80+
def db_for_read(self, model, **hints):
81+
if model._meta.app_label == "django_tasks_database":
82+
return "task_queue"
83+
return None
84+
85+
def db_for_write(self, model, **hints):
86+
if model._meta.app_label == "django_tasks_database":
87+
return "task_queue"
88+
return None
89+
90+
def allow_migrate(self, db, app_label, model_name=None, **hints):
91+
if app_label == "django_tasks_database":
92+
return db == "task_queue"
93+
return None
94+
```
95+
96+
Then enable it:
97+
98+
```python
99+
DATABASE_ROUTERS = ["path.to.TaskDBRouter"]
100+
```
101+
102+
71103
## Contributing
72104

73105
See [CONTRIBUTING.md](./CONTRIBUTING.md) for information on how to contribute.

tests/settings.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -32,11 +32,15 @@
3232
DATABASES = {
3333
"default": dj_database_url.config(
3434
default="sqlite:///" + os.path.join(BASE_DIR, "db.sqlite3")
35-
)
35+
),
36+
"secondary": dj_database_url.parse(
37+
"sqlite:///" + os.path.join(BASE_DIR, "db-secondary.sqlite3")
38+
),
3639
}
3740

38-
if "sqlite" in DATABASES["default"]["ENGINE"]:
39-
DATABASES["default"]["TEST"] = {"NAME": os.path.join(BASE_DIR, "db-test.sqlite3")}
41+
for alias, db in DATABASES.items():
42+
if "sqlite" in db["ENGINE"]:
43+
db["TEST"] = {"NAME": os.path.join(BASE_DIR, f"db-test-{alias}.sqlite3")}
4044

4145

4246
USE_TZ = True

tests/settings_fast.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from .settings import *
22

33
# Unset custom test settings to use in-memory DB
4-
if "sqlite" in DATABASES["default"]["ENGINE"]:
5-
del DATABASES["default"]["TEST"]
4+
for db in DATABASES.values():
5+
if "sqlite" in db["ENGINE"] and "TEST" in db:
6+
del db["TEST"]

tests/tests.py

Lines changed: 171 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1753,3 +1753,174 @@ def test_display_run_after_returns_db_value(self) -> None:
17531753
result = self.admin.display_run_after(db_task_result)
17541754

17551755
self.assertEqual(result, expected_run_after)
1756+
1757+
1758+
class DBTaskResultSecondaryRouter:
1759+
def db_for_read(self, model: type[Any], **hints: Any) -> str | None:
1760+
if model._meta.app_label == "django_tasks_database":
1761+
return "secondary"
1762+
return None
1763+
1764+
def db_for_write(self, model: type[Any], **hints: Any) -> str | None:
1765+
if model._meta.app_label == "django_tasks_database":
1766+
return "secondary"
1767+
return None
1768+
1769+
1770+
class DBTaskResultSplitRouter:
1771+
def db_for_read(self, model: type[Any], **hints: Any) -> str | None:
1772+
if model._meta.app_label == "django_tasks_database":
1773+
return "default"
1774+
return None
1775+
1776+
def db_for_write(self, model: type[Any], **hints: Any) -> str | None:
1777+
if model._meta.app_label == "django_tasks_database":
1778+
return "secondary"
1779+
return None
1780+
1781+
1782+
@override_settings(
1783+
TASKS={
1784+
"default": {"BACKEND": "django_tasks_db.DatabaseBackend"},
1785+
"dummy": {"BACKEND": "django_tasks.backends.dummy.DummyBackend"},
1786+
},
1787+
DATABASE_ROUTERS=["tests.tests.DBTaskResultSecondaryRouter"],
1788+
)
1789+
class DatabaseRouterTestCase(TransactionTestCase):
1790+
databases = {"default", "secondary"}
1791+
1792+
def tearDown(self) -> None:
1793+
logger = logging.getLogger("django_tasks_db")
1794+
tasks_logger = logging.getLogger("django_tasks")
1795+
1796+
# Reset the logger after every run, to ensure the correct `stdout` is used
1797+
for handler in logger.handlers:
1798+
logger.removeHandler(handler)
1799+
1800+
for handler in tasks_logger.handlers:
1801+
tasks_logger.removeHandler(handler)
1802+
1803+
def test_enqueue_uses_router_database(self) -> None:
1804+
result = test_tasks.calculate_meaning_of_life.enqueue()
1805+
1806+
self.assertTrue(
1807+
DBTaskResult.objects.using("secondary").filter(id=result.id).exists()
1808+
)
1809+
self.assertFalse(
1810+
DBTaskResult.objects.using("default").filter(id=result.id).exists()
1811+
)
1812+
1813+
def test_get_result_uses_router_database(self) -> None:
1814+
backend = task_backends["default"]
1815+
1816+
db_result = DBTaskResult.objects.using("secondary").create(
1817+
task_path="tests.tasks.calculate_meaning_of_life",
1818+
args_kwargs={"args": [], "kwargs": {}},
1819+
run_after=test_tasks.calculate_meaning_of_life.run_after, # type: ignore[misc]
1820+
backend_name="default",
1821+
)
1822+
1823+
retrieved_result = backend.get_result(db_result.id)
1824+
self.assertEqual(retrieved_result.id, str(db_result.id))
1825+
1826+
with self.assertRaises(DBTaskResult.DoesNotExist):
1827+
DBTaskResult.objects.using("default").get(id=db_result.id)
1828+
1829+
def test_worker_uses_router_database(self) -> None:
1830+
result = test_tasks.calculate_meaning_of_life.enqueue()
1831+
1832+
self.assertEqual(DBTaskResult.objects.ready().using("secondary").count(), 1)
1833+
self.assertEqual(DBTaskResult.objects.ready().using("default").count(), 0)
1834+
1835+
call_command(
1836+
"db_worker",
1837+
verbosity=0,
1838+
batch=True,
1839+
interval=0,
1840+
startup_delay=False,
1841+
backend_name="default",
1842+
)
1843+
1844+
result.refresh()
1845+
self.assertEqual(result.status, TaskResultStatus.SUCCESSFUL)
1846+
1847+
1848+
@override_settings(
1849+
TASKS={"default": {"BACKEND": "django_tasks_db.DatabaseBackend"}},
1850+
DATABASE_ROUTERS=["tests.tests.DBTaskResultSplitRouter"],
1851+
)
1852+
class DatabaseRouterSplitReadWriteTestCase(TransactionTestCase):
1853+
databases = {"default", "secondary"}
1854+
1855+
def tearDown(self) -> None:
1856+
logger = logging.getLogger("django_tasks_db")
1857+
tasks_logger = logging.getLogger("django_tasks")
1858+
1859+
# Reset the logger after every run, to ensure the correct `stdout` is used
1860+
for handler in logger.handlers:
1861+
logger.removeHandler(handler)
1862+
1863+
for handler in tasks_logger.handlers:
1864+
tasks_logger.removeHandler(handler)
1865+
1866+
for handler in prune_db_tasks_logger.handlers:
1867+
prune_db_tasks_logger.removeHandler(handler)
1868+
1869+
def test_worker_uses_write_router_database(self) -> None:
1870+
result = test_tasks.calculate_meaning_of_life.enqueue()
1871+
1872+
self.assertTrue(
1873+
DBTaskResult.objects.using("secondary").filter(id=result.id).exists()
1874+
)
1875+
self.assertEqual(DBTaskResult.objects.using("default").count(), 0)
1876+
1877+
call_command(
1878+
"db_worker",
1879+
verbosity=0,
1880+
batch=True,
1881+
interval=0,
1882+
startup_delay=False,
1883+
backend_name="default",
1884+
)
1885+
1886+
db_result = DBTaskResult.objects.using("secondary").get(id=result.id)
1887+
self.assertEqual(db_result.status, TaskResultStatus.SUCCESSFUL)
1888+
1889+
def test_prune_uses_write_router_database(self) -> None:
1890+
result = test_tasks.noop_task.enqueue()
1891+
1892+
DBTaskResult.objects.using("secondary").filter(id=result.id).update(
1893+
status=TaskResultStatus.SUCCESSFUL,
1894+
finished_at=timezone.now(),
1895+
)
1896+
self.assertEqual(
1897+
DBTaskResult.objects.using("secondary").finished().count(),
1898+
1,
1899+
)
1900+
1901+
call_command(
1902+
"prune_db_task_results",
1903+
verbosity=0,
1904+
min_age_days=0,
1905+
)
1906+
1907+
self.assertEqual(DBTaskResult.objects.using("secondary").count(), 0)
1908+
1909+
def test_prune_dry_run_uses_write_router_database(self) -> None:
1910+
result = test_tasks.noop_task.enqueue()
1911+
1912+
DBTaskResult.objects.using("secondary").filter(id=result.id).update(
1913+
status=TaskResultStatus.SUCCESSFUL,
1914+
finished_at=timezone.now(),
1915+
)
1916+
1917+
stdout = StringIO()
1918+
call_command(
1919+
"prune_db_task_results",
1920+
verbosity=3,
1921+
min_age_days=0,
1922+
dry_run=True,
1923+
stdout=stdout,
1924+
)
1925+
1926+
self.assertEqual(DBTaskResult.objects.using("secondary").count(), 1)

0 commit comments

Comments
 (0)