文章详情

短信预约-IT技能 免费直播动态提醒

请输入下面的图形验证码

提交验证

短信预约提醒成功

数据工程中的单元测试完全指南

2024-11-30 08:25

关注

本文带你深入探索如何将这些成熟的软件工程实践应用到数据工程中。

1 单元测试的重要性

在数据工程的背景下,采用单元测试可以确保您的数据和业务逻辑的准确性,进而产出高质量的数据,获得您的数据分析师、科学家和决策者对数据的信任。

2 单元测试数据流水线

数据流水线通常涉及复杂的数据抽取、转换和加载(ETL)操作序列,出错的可能性很大。为了对这些操作进行单元测试,我们将流水线拆分为单个组件,并对每个组件进行独立验证。

以一个简单的流水线为例,该流水线从CSV文件中提取数据,通过清除空值来转换数据,然后将其加载到数据库中。以下是使用pandas的基于Python的示例:

import pandas as pd
from sqlalchemy import create_engine

# 加载CSV文件的函数
def load_data(file_name):
    data = pd.read_csv(file_name)
    return data

# 清理数据的函数
def clean_data(data):
    data = data.dropna()
    return data

# 将数据保存到SQL数据库的函数
def save_data(data, db_string, table_name):
    engine = create_engine(db_string)
    data.to_sql(table_name, engine, if_exists='replace')

# 运行数据流水线
data = load_data('data.csv')
data = clean_data(data)
save_data(data, 'sqlite:///database.db', 'my_table')

为了对这个流水线进行单元测试,我们使用像pytest这样的库为每个函数编写单独的测试。

在这个示例中,有三个主要的函数:load_data、clean_data和save_data。我们会为每个函数编写测试。对于load_data和save_data,需要设置一个临时的CSV文件和SQLite数据库,可以使用pytest库的fixture功能来实现。

import os
import pandas as pd
import pytest
from sqlalchemy import create_engine, inspect

# 使用pytest fixture来设置临时的CSV文件和SQLite数据库
@pytest.fixture
def csv_file(tmp_path):
    data = pd.DataFrame({
        'name': ['John', 'Jane', 'Doe'],
        'age': [34, None, 56]  # Jane的年龄缺失
    })
    file_path = tmp_path / "data.csv"
    data.to_csv(file_path, index=False)
    return file_path


@pytest.fixture
def sqlite_db(tmp_path):
    file_path = tmp_path / "database.db"
    return 'sqlite:///' + str(file_path)


def test_load_data(csv_file):
    data = load_data(csv_file)
    
    assert 'name' in data.columns
    assert 'age' in data.columns
    assert len(data) == 3


def test_clean_data(csv_file):
    data = load_data(csv_file)
    data = clean_data(data)
    
    assert data['age'].isna().sum() == 0
    assert len(data) == 2  # Jane的记录应该被删除


def test_save_data(csv_file, sqlite_db):
    data = load_data(csv_file)
    data = clean_data(data)
    save_data(data, sqlite_db, 'my_table')
    
    # 检查数据是否保存正确
    engine = create_engine(sqlite_db)
    inspector = inspect(engine)
    tables = inspector.get_table_names()
    
    assert 'my_table' in tables
    
    loaded_data = pd.read_sql('my_table', engine)
    assert len(loaded_data) == 2  # 只应该存在John和Doe的记录

这里是另一个例子:假设您有一个从CSV文件中加载数据并将其中的“日期”列从字符串转换为日期时间的流水线:

def convert_date(data, date_column):
    data[date_column] = pd.to_datetime(data[date_column])
    return data

为上述函数编写的单元测试将传入具有已知日期字符串格式的DataFrame。然后,它将验证函数是否正确将日期转换为日期时间对象,并且它是否适当处理无效格式。

我们为上述场景编写一个单元测试。该测试首先使用有效日期检查函数,断言输出DataFrame中的“date”列确实是datetime类型,并且值与预期相符。然后,它检查在给出无效日期时,函数是否正确引发了ValueError。

import pandas as pd
import pytest

def test_convert_date():
    # 使用有效日期进行测试
    test_data = pd.DataFrame({
        'date': ['2021-01-01', '2021-01-02']
    })
    
    converted_data = convert_date(test_data.copy(), 'date')
    
    assert pd.api.types.is_datetime64_any_dtype(converted_data['date'])
    assert converted_data.loc[0, 'date'] == pd.Timestamp('2021-01-01')
    assert converted_data.loc[1, 'date'] == pd.Timestamp('2021-01-02')

    # 使用无效日期进行测试
    test_data = pd.DataFrame({
        'date': ['2021-13-01']  # 这个日期是无效的,因为没有第13个月
    })
    
    with pytest.raises(ValueError):
        convert_date(test_data, 'date')

以下是最后一个例子:假设您有一个加载数据并进行聚合的流水线,计算每个地区的总销售额:

def aggregate_sales(data):
    aggregated = data.groupby('region').sales.sum().reset_index()
    return aggregated

为该函数编写的单元测试将向其传递具有各个地区销售数据的DataFrame。测试将验证函数是否正确计算每个地区的总销售额。

我们为该函数编写一个单元测试。在这个测试中,我们首先向aggregate_sales函数传递一个具有已知销售数据的DataFrame,并检查它是否正确聚合了销售额。然后,向其传递一个没有销售数据的DataFrame,并检查它是否正确将这些销售额聚合为0。这样可以确保函数正确处理典型情况和边缘情况。

以下是使用pytest库为aggregate_sales函数编写单元测试的示例:

import pandas as pd
import pytest

def test_aggregate_sales():
    # 各个地区的销售数据
    test_data = pd.DataFrame({
        'region': ['North', 'North', 'South', 'South', 'East', 'East', 'West', 'West'],
        'sales': [100, 200, 300, 400, 500, 600, 700, 800]
    })
    
    aggregated = aggregate_sales(test_data)
    
    assert aggregated.loc[aggregated['region'] == 'North', 'sales'].values[0] == 300
    assert aggregated.loc[aggregated['region'] == 'South', 'sales'].values[0] == 700
    assert aggregated.loc[aggregated['region'] == 'East', 'sales'].values[0] == 1100
    assert aggregated.loc[aggregated['region'] == 'West', 'sales'].values[0] == 1500

    # 没有销售数据的测试
    test_data = pd.DataFrame({
        'region': ['North', 'South', 'East', 'West'],
        'sales': [0, 0, 0, 0]
    })
    
    aggregated = aggregate_sales(test_data)
    
    assert aggregated.loc[aggregated['region'] == 'North', 'sales'].values[0] == 0
    assert aggregated.loc[aggregated['region'] == 'South', 'sales'].values[0] == 0
    assert aggregated.loc[aggregated['region'] == 'East', 'sales'].values[0] == 0
    assert aggregated.loc[aggregated['region'] == 'West', 'sales'].values[0] == 0

本文转载自微信公众号「Java学研大本营」,可以通过以下二维码关注。转载本文请联系公众号。

来源:Java学研大本营内容投诉

免责声明:

① 本站未注明“稿件来源”的信息均来自网络整理。其文字、图片和音视频稿件的所属权归原作者所有。本站收集整理出于非商业性的教育和科研之目的,并不意味着本站赞同其观点或证实其内容的真实性。仅作为临时的测试数据,供内部测试之用。本站并未授权任何人以任何方式主动获取本站任何信息。

② 本站未注明“稿件来源”的临时测试数据将在测试完成后最终做删除处理。有问题或投稿请发送至: 邮箱/279061341@qq.com QQ/279061341

软考中级精品资料免费领

  • 历年真题答案解析
  • 备考技巧名师总结
  • 高频考点精准押题
  • 2024年上半年信息系统项目管理师第二批次真题及答案解析(完整版)

    难度     813人已做
    查看
  • 【考后总结】2024年5月26日信息系统项目管理师第2批次考情分析

    难度     354人已做
    查看
  • 【考后总结】2024年5月25日信息系统项目管理师第1批次考情分析

    难度     318人已做
    查看
  • 2024年上半年软考高项第一、二批次真题考点汇总(完整版)

    难度     435人已做
    查看
  • 2024年上半年系统架构设计师考试综合知识真题

    难度     224人已做
    查看

相关文章

发现更多好内容

猜你喜欢

AI推送时光机
位置:首页-资讯-后端开发
咦!没有更多了?去看看其它编程学习网 内容吧
首页课程
资料下载
问答资讯