首页 > 解决方案 > SQLAlchemy 和 PyTest:如何在测试期间更改数据库?

问题描述

我正在用 Python 编写一个“批处理”过程(不使用任何框架)。项目配置在一个config.ini文件中

[db]
db_uri = mysql+pymysql://root:password@localhost:3306/manage

我有另一个文件config.test要在测试期间交换

[db]
db_uri = sqlite://

我有一个简单的test_sample.py

# tests/test_sample.py

import pytest
import shutil
import os

import batch
import batch_manage.utils.getconfig as getconfig_class

class TestClass():
    def setup_method(self, method):
        """ Rename the config """
        shutil.copyfile("config.ini", "config.bak")
        os.remove('config.ini')
        shutil.copyfile("config.test", "config.ini")

    def teardown_method(self, method):
        """ Replace the config """
        shutil.copyfile("config.bak", "config.ini")
        os.remove('config.bak')


    def test_can_get_all_data_from_table(self):
        conf = getconfig_class.get_config('db')
        db_uri = conf.get('db_uri')
        assert db_uri == "sqlite://"
        # This pass! ok!
        people = batch.get_all_people()
        assert len(people) == 0
        # This fails, because counts the records in production database

db_uriassert 可以(在测试时是 sqlite 而不是 mysql)但 len 不是 0,而是 42(MySql 数据库中的记录数。

我怀疑 SqlAlchemy ORM 的会话存在问题。我做了几次尝试,没有可能覆盖/删除它。

其余的代码非常简单:

# batch_manage/models/base.py
from sqlalchemy import create_engine
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import sessionmaker

import batch_manage.utils.getconfig as getconfig_class

conf = getconfig_class.get_config('db')
db_uri = conf.get('db_uri')

engine = create_engine(db_uri)
Session = sessionmaker(bind=engine)

Base = declarative_base()
# batch_manage/models/persone.py
from sqlalchemy import Column, String, Integer, Date

from batch_manage.models.base import Base


class Persone(Base):
    __tablename__ = "persone"

    idpersona = Column(Integer, primary_key=True)
    nome = Column(String)
    created_at = Column(Date)

    def __init__(self, nome, created_at):
        self.nome = nome
        self.created_at = created_at

而它batch.py本身

# batch.py
import click

from batch_manage.models.base import Session
from batch_manage.models.persone import Persone

def get_all_people():
    """ Get all people from database """
    session = Session()

    people = session.query(Persone).all()

    return people


@click.command()
def batch():
    click.echo("------------------------------")
    click.echo("Running Batch")
    click.echo("------------------------------")

    people = get_all_people()

    for item in people:
        print(f"Persona con ID {item.idpersona} creata il {item.created_at}")



if __name__ == '__main__':
    batch()

第一个“解决方案”(不优雅,我会重构它)

我暂时通过以下方式进行了更改测试:

def test_can_get_all_data_from_table(self):
        conf = getconfig_class.get_config('db')
        db_uri = conf.get('db_uri')
        assert db_uri == "sqlite://"

        from sqlalchemy.orm import sessionmaker
        from sqlalchemy import create_engine
        engine = create_engine(db_uri)
        Session = sessionmaker(bind=engine)
        session = Session()

        people = batch.get_all_people(session)
        assert len(people) == 0

get_all_people方法

def get_all_people(session = None):
    """ Get all people from database """
    if session is None:
      session = Session()

    people = session.query(Persone).all()

    return people

但是这个解决方案并不优雅,并且还减少了代码覆盖率,因为没有遵循 if 路径。

标签: pythonsqlalchemy

解决方案


因此,如果我正确地遵循您的代码,看起来您在设置测试之前正在导入您的 ORM 内容。这是您当前的操作顺序:

  1. batch.py是进口的。
  2. 您的其他模块已导入。
  3. 在文件的顶级模块代码中models/base.py,您配置要使用的数据库。
  4. 您的测试类设置,在它已经加载后更改配置。

因此,对于解决方案:

在测试本身中导入所有模块

如果您只是想更改操作顺序,请在您进入测试之前不要导入您的代码。无论如何,这通常是很好的测试实践:

class TestClass():
    ... (your existing code) ...

    def test_can_get_all_data_from_table(self):
        # ONLY import stuff inside your test
        from batch_manage.models.base import Session
        from batch_manage.models.persone import Persone

这可能会解决您眼前的问题,但可能有一个更优雅的解决方案

不要立即配置数据库

我不知道你是否在使用 Flask,但无论哪种方式,Flask 测试文档都有一些关于如何设置测试数据库的很好的说明。导入模块后,您需要配置数据库 URL。

例如:

from sqlalchemy.ext.declarative import declarative_base

Base = declarative_base()

(your models)

请注意,我还没有定义引擎。我可以在运行时做到这一点。

def setup_engine():
    engine = create_engine(db_url)
    Base.metadata.bind = engine

在您的主代码中,在向用户提供内容之前,您需要调用setup_engine. 在您的测试环境中,您将调用自己的setup_engine绑定到测试环境的方法。


推荐阅读