Python 是一种动态类型语言,使用起来非常简单,如果我们不想接触复杂的程序,它肯定是进行复杂计算的首选语言。Python 提供了优秀的库(Pandas、NumPy、Matplotlib、ScyPy、PyTorch、TensorFlow 等)来支持对数据结构或数组的逻辑、数学和科学操作。
Java 是一种非常健壮的语言,具有强类型,因此有更严格的语法规则,所以不易出现程序错误。与Python一样,它也提供了大量的库来处理数据结构、线性代数、机器学习和数据处理(ND4J、Mahout、Spark、Deeplearning4J 等)。
本文将介绍如何对大量表格数据进行简单的数据分析,并使用 Java 和 Python 计算一些统计数据。我们可以看到使用各个平台进行数据分析的不同技术,对比它们的扩展方式,以及应用并行计算来提高其性能的可行性。
提出问题
我们要对不同州的一大批城市的价格做一个简单的分析,这里假设有一个包含此信息的 CSV 文件。阅读文件并继续过滤掉一些州,并将剩余的州按城市-州分组以进行一些基本统计。希望能够找到有效执行的解决方案,并且能够随着输入数据规模的增长而有良好的扩展。
数据样本是:
城市 | 州 | 基本价格 | 实际价格 |
La Jose | PA | 34.17 | 33.19 |
Preachers Slough | WA | 27,46 | 90.17 |
Doonan Corners | NY | 92.0 | 162.46 |
Doonan Corners | NY | 97.45 | 159.46 |
Castle Rock | WA | 162.16 | 943.21 |
Marble Rock | IA | 97.13 | 391.49 |
Mineral | CA | 99.13 | 289.37 |
Blountville | IN | 92.50 | 557.66 |
Blountsville | IN | 122.50 | 557.66 |
Coe | IN | 187.85 | 943.98 |
Cecilia | KY | 92.85 | 273.61 |
目的是展示如何使用 Java 和 Python 解决这些类型的问题。该示例非常简单且范围有限,但很容易拓展到更具挑战性的问题。
Java 的方法
首先定义一个封装数据元素的 Java 记录:
record InputEntry(String city, String state, double basePrice, double actualPrice) {}
记录(record)是 JDK 14 中引入的一种新型类型声明。它是定义提供构造函数、访问器、equals 和哈希实现的不可变类的一种简捷方式。
接下来,读取 CVS 文件并将它们增加到一个列表中:
List inputEntries = readRecordEntriesFromCSVFile(recordEntries.csv);
为了按城市和州对输入的元素进行分组,将其定义:
record CityState(String city, String state) {};
使用以下类来封装属于一个组的所有元素的统计信息:
record StatsAggregation(StatsAccumulator basePrice, StatsAccumulator actualPrice) {}
StatsAccumulator是Guava 库的一部分。可以将双精度值集合添加到类中,它会计算基本统计数据,例如计数、平均值、方差或标准差。可以使用StatsAccumulator来获取InputEntry的basePrice和actualPrice的统计数据。
现在我们已经拥有了解决问题的所有材料。Java Streams提供了一个强大的框架来实现数据操作和分析。它的声明式编程风格,对选择、过滤、分组和聚合的支持,简化了数据操作和统计分析。它的框架还提供了一个强大的实现,可以处理大量的(甚至是无限的流),并通过使用并行性、懒惰性和短路操作来高效处理。所有这些特性使Java Streams成为解决这类问题的绝佳选择。实现非常简单:
Map stats = inputEntries.stream().
filter(i -> !(i.state().equals("MN") || i.state().equals("CA"))).collect(
groupingBy(entry -> new CityState(entry.city(), entry.state()),
collectingAndThen(Collectors.toList(),
list -> {StatsAccumulator sac = new StatsAccumulator();
sac.addAll(list.stream().mapToDouble(InputEntry::basePrice));
StatsAccumulator sas = new StatsAccumulator();
sas.addAll(list.stream().mapToDouble(InputEntry::actualPrice));
return new StatsAggregation(sac, sas);}
)));
在代码的第 2 行,我们使用Stream::filter. 这是一个布尔值函数,用于过滤列表中的元素。可以实现一个 lambda 表达式来删除任何包含“MN”或“CA”状态的元素。
然后继续收集列表的元素并调用Collectors::groupingBy()(第 3 行),它接受两个参数:
- 一个分类功能,使用CityState记录来做城市和州的分组(第3行)。
- 下游的收集器,包含属于同一<城州>的元素。使用Collectors::collectingAndThen(第 4 行),它采用两个参数分两步进行归约:
·我们使用Collectors::toList(第 4 行),它返回一个收集器,它将属于同一<城州>的所有元素放到一个列表中。
·随后对这个列表进行了整理转换。使用一个lambda函数(第5行至第9行)来定义两个StatsAccumulator(s),在这里分别计算前一个列表中的basePrice和actualPrice元素的统计数据。最后,返回到新创建的包含这些元素的StatsAggregation记录。
正如前文所述,使用Java Streams的优势之一是,它提供了一种简单的机制,可以使用多线程进行并行处理。这允许利用CPU的多核资源,同时执行多个线程。只要在流中添加一个 "parallel":
Map stats = inputEntries.stream().parallel().
这导致流框架将元素列表细分为多个部分,并同时在单独的线程中运行它们。随着所有不同的线程完成它们的计算,框架将它们串行添加到生成的 Map 中。
在第4行中使用Collectors::groupingByConcurrent而不是Collectors:groupingBy。在这种情况下,框架使用并发映射,允许将来自不同线程的元素直接插入到此映射中,而不必串行组合。
有了这三种可能性,可以检查它们如何执行之前的统计计算(不包括从 CSV 文件加载数据的时间),因为加载量从500万条翻倍到2000万条:
串行 | 平行 | 并行 & GroupByConcurrent | |
五百万个元素 | 3.045 秒 | 1.941 秒 | 1.436 秒 |
一千万个元素 | 6.405 秒 | 2.876 秒 | 2.785 秒 |
两千万个元素 | 8.507 秒 | 4.956 秒 | 4.537 秒 |
可以看到并行运行大大提高了性能;随着负载的增加,时间几乎减半。使用 GroupByConcurrent 还可额外获得 10% 的收益。
最后,得到结果是微不足道的;例如,要获得印第安纳州 Blountsville 的统计数据,我们只需要:
StatsAggregation aggreg = stateAggr.get(new CityState("Blountsville ", "IN"));
System.out.println("Blountsville, IN");
System.out.println("basePrice.mean: " + aggreg.basePrice().mean());
System.out.println("basePrice.populationVariance: " + aggreg.basePrice().populationVariance());
System.out.println("basePrice.populationStandardDeviation: " + aggreg.basePrice().populationStandardDeviation());
System.out.println("actualPrice.mean: " + aggreg.basePrice().mean());
System.out.println("actualPrice.populationVariance: " + aggreg.actualPrice().populationVariance());
System.out.println("actualPrice.populationStandardDeviation: " + aggreg.actualPrice().populationStandardDeviation());
得到的结果:
Blountsville : IN
basePrice.mean: 50.302588996763795
basePrice.sampleVariance: 830.7527439246837
basePrice.sampleStandardDeviation: 28.822781682632293
basePrice.count: 309
basePrice.min: 0.56
basePrice.max: 99.59
actualPrice.mean: 508.8927831715211
actualPrice.sampleVariance: 78883.35878833274
actualPrice.sampleStandardDeviation: 280.86181440048546
actualPrice.count: 309
actualPrice.min: 0.49
actualPrice.max: 999.33
Python的方法
在 Python 中,有几个库可以处理数据统计和分析。其中,Pandas 库非常适合处理大量表格数据,它提供了非常有效的过滤、分组和统计分析方法。
使用 Python 分析以前的数据:
import pandas as pd
def group_aggregations(df_group_by):
df_result = df_group_by.agg(
{'basePrice': ['count', 'min', 'max', 'mean', 'std', 'var'],
'actualPrice': ['count', 'min', 'max', 'mean', 'std', 'var']}
)
return df_result
if __name__ == '__main__':
df = pd.read_csv("recordEntries.csv")
excluded_states = ['MN', 'CA']
df_st = df.loc[~ df['state'].isin(excluded_states)]
group_by = df_st.groupby(['city', 'state'], sort=False)
aggregated_results = group_aggregations(group_by)
在主要部分,先调用pandas.read_csv()(第 11 行)将文件中用逗号分隔的值加载到 PandasDataFrame中。
在第13行,使用~df['state'].isin(excluded_states)来得到一个Pandas系列的布尔值,使用pandas.loc()来过滤其中不包括的州(MN和CA)。
接下来,在第14行使用DataFrame.groupby()来按城市和州进行分组。结果由group_aggregations()处理,保存每个组的basePrice和actualPrice的统计数据。
在Python中打印结果是非常直接的。IN和Blountsville的结果:
print(aggregated_results.loc['Blountsville', 'IN']['basePrice'])
print(aggregated_results.loc['Blountsville', 'IN']['actualPrice'])
统计数据:
base_price:
Name: (Blountsville, IN), dtype: float64
count 309.000000
min 0.560000
max 99.590000
mean 50.302589
std 28.822782
var 830.752744
actual_price:
Name: (Blountsville, IN), dtype: float64
count 309.000000
min 0.490000
max 999.330000
mean 508.892783
std 280.861814
var 78883.358788
为了并行运行前面的代码,我们必须记住,Python并不像Java那样支持细粒度的锁机制。必须解决好与全局解释器锁(GIL)的问题,无论你有多少个CPU多核或线程,一次只允许一个线程执行。
为了支持并发,我们必须考虑到有一个CPU 密集型进程,因此,最好的方法是使用multiprocessing。所以需要修改代码:
from multiprocessing import Pool
import pandas as pd
def aggreg_basePrice(df_group):
ct_st, grp = df_group
return ct_st, grp.basePrice.agg(['count', 'min', 'max', 'mean', 'std', 'var'])
if __name__ == '__main__':
df = pd.read_csv("recordEntries.csv")
start = time.perf_counter()
excluded_states = ['MN', 'CA']
filtr = ~ df['state'].isin(excluded_states)
df_st = df.loc[filtr]
grouped_by_ct_st = df_st.groupby(['city', 'state'], sort=False)
with Pool() as p:
list_parallel = p.map(aggreg_basePrice, [(ct_st, grouped) for ct_st, grouped in grouped_by_ct_st])
print(f'Time elapsed parallel: {round(finish - start, 2)} sec')
和之前一样,使用Pandas groupby()来获得按城市和州分组的数据(第14行)。在下一行,使用多进程库提供的Pool()来映射分组的数据,使用aggreg_basePrice来计算每组的统计数据。Pool()会对数据进行分割,并在几个平行的独立进程中进行统计计算。
正如下面的表格中所示,多进程比串行运行进程慢得多。因此,对于这些类型的问题,不值得使用这种方法。
可以使用另一种并发运行代码 - Modin。Modin提供了一种无缝的方式来并行化你的代码,当你必须处理大量的数据时是非常有用的。将导入语句从import pandas as pd改为import modin.pandas as pd,可以并行运行代码,并利用环境中可能存在的内核集群来加速代码的执行。
下面的表格是刚刚涉及的不同场景的运行时间(和以前一样,不包括从CSV文件中读取数据的时间):
串行 | 多进程 | Modin 过程 | |
五百万个元素 | 1.94 秒 | 20.25 秒 | 6.99 秒 |
一千万个元素 | 4.07 秒 | 25.1 秒 | 12.88 秒 |
两千万个元素 | 7.62 秒 | 36.2 秒 | 25.94 秒 |
根据表格显示,在Python中串行运行代码甚至比在Java中更快。然而,使用多进程会大大降低性能。使用Moding可以改善结果,使串行运行进程更有利。值得一提的是,和以前一样,我们在计算时间时不包括从CSV文件中读取数据的时间。
可以发现,对于 Pandas 中的 CPU 密集型进程来说,并行化代码是没有优势的。从某种意义上说,这反映了 Pandas 最初的架构方式。Pandas 在串行模式下的运行速度令人印象深刻,而且即使处理大量数据也具有很好的扩展性。
需要指出的是,Python中统计数字的计算速度取决于它的执行方式。为了得到快速的计算,需要应用到统计函数。例如,做一个简单的pandas.DataFrame.describe()来获得统计信息,运行速度会非常慢。
Java 的 Streams 或 Python 的 Pandas 是对大量数据进行分析和统计的两个绝佳选择。两者都有非常可靠的框架,以及足够的支持,能够实现出色的性能和可扩展性。
Java 提供了非常强大的基础架构,非常适合处理复杂的程序流。它非常高效,可以有效地并行运行进程。适用于快速获得结果。
Python 非常适合做数学和统计。它非常简单,相当快,非常适合进行复杂的计算。
译者介绍
翟珂,51CTO社区编辑,目前在杭州从事软件研发工作,做过电商、征信等方面的系统,享受分享知识的过程,充实自己的生活。
原文Data Statistics and Analysis With Java and Python,作者:Manu Barriola