文章详情

短信预约-IT技能 免费直播动态提醒

请输入下面的图形验证码

提交验证

短信预约提醒成功

基于纯SQL训练机器学习模型

2024-12-02 04:21

关注

审校 | 梁策 孙淑娟

在​​《用纯SQL在BigQuery上实现深层神经网络》​​一文中,作者声称使用纯SQL方式实现了一个深层神经网络模型。但在我打开他的​​GitHub代码仓库​​分析后发现,他是使用Python来实现迭代训练的,而这并不是真正的纯SQL方式。

在本文中,我将分享我是如何在​​开源分布式SQL数据库TiDB​​上用纯SQL方式训练机器学习模型的。主要步骤包括:

  1. 选择Iris数据集
  2. (https://scikit-learn.org/stable/auto_examples/datasets/plot_iris_dataset.html)。
  3. 选择softmax逻辑回归模型用于训练。
  4. 编写SQL语句来实现模型推理。
  5. 开始模型训练。

在测试中,我训练了一个softmax逻辑回归模型。在测试期间,我发现TiDB不允许在递归公共表表达式(CTE)中使用子查询和聚合函数。通过修改TiDB的代码,我绕过了这些限制,成功地训练了一个模型,并在Iris数据集上获得了98%的准确率。

为什么我选择TiDB来实现机器学习模型?

TiDB 5.1引入了许多新功能,包括ANSI SQL 99标准的通用表表达式(CTE)。我们可以使用CTE作为临时视图的语句来解耦复杂的SQL语句,并更高效地开发代码。此外,递归CTE可以引用自身,这对于改进SQL功能非常重要。CTE和窗口函数使SQL成为一种图灵完备的语言。

【说明】因为递归CTE可以“迭代”,所以我想尝试一下,看看是否可以使用纯SQL在TiDB上实现机器学习模型训练和推理。

鸢尾花(Iris)数据集

我选择使用scikit-learn的Iris数据集。该数据集包含3种类型,每种类型有50条记录,一共150条。每个记录有4个特征:萼片长度(SL)、萼片宽度(SW)、花瓣长度(PL)和花瓣宽度(PW)。我们可以利用这些特征来预测鸢尾花是否属于山鸢尾(Iris-setosa) 、 变色鸢尾(Iris-versicolor)和维吉尼亚鸢尾(Iris-virginica)。

以CSV格式下载数据后,我将其导入了TiDB数据库。使用的SQL脚本如下:

create table iris(sl float, sw float, pl float, pw float, type varchar(16));
LOAD DATA LOCAL INFILE 'iris.csv' INTO TABLE iris FIELDS TERMINATED BY ',' LINES TERMINATED BY '\n' ;
select * from iris limit 10;
+------+------+------+------+------------------+
| sl | sw | pl | pw | type |
+-------+--------+-------+--------+------------+
| 5.1 | 3.5 | 1.4 | 0.2 | Iris-setosa |
| 4.9 | 3 | 1.4 | 0.2 | Iris-setosa |
| 4.7 | 3.2 | 1.3 | 0.2 | Iris-setosa |
| 4.6 | 3.1 | 1.5 | 0.2 | Iris-setosa |
| 5 | 3.6 | 1.4 | 0.2 | Iris-setosa |
| 5.4 | 3.9 | 1.7 | 0.4 | Iris-setosa |
| 4.6 | 3.4 | 1.4 | 0.3 | Iris-setosa |
| 5 | 3.4 | 1.5 | 0.2 | Iris-setosa |
| 4.4 | 2.9 | 1.4 | 0.2 | Iris-setosa |
| 4.9 | 3.1 | 1.5 | 0.1 | Iris-setosa |
+---- --+-------+-------+-------+--------------+
10 rows in set (0.00 sec)
select type, count(*) from iris group by type;
+-------------------+------------------+
| type | count(*) |
+-------------------+-----------------+
| Iris-versicolor | 50 |
| Iris-setosa | 50 |
| Iris-virginica | 50 |
+-------------------+----------------+
3 rows in set (0.00 sec)

Softmax逻辑回归

我选择了一个简单的机器学习模型:用于多类分类的Softmax逻辑回归。在Softmax回归中:

成本函数是:

梯度是:

因此,我们可以使用梯度下降来升级梯度:

模型推理

我编写了一条SQL语句来实现推理。基于上面定义的模型和数据,输入数据x有五个维度(SL、SW、PL、PW和一个常数1.0),输出使用了一种热编码。SQL脚本如下:

create table data(
x0 decimal(35, 30), x1 decimal(35, 30), x2 decimal(35, 30), x3 decimal(35, 30), x4 decimal(35, 30),
y0 decimal(35, 30), y1 decimal(35, 30), y2 decimal(35, 30)
);
insert into data
select
sl, sw, pl, pw, 1.0,
case when type='Iris-setosa'then 1 else 0 end,
case when type='Iris-versicolor'then 1 else 0 end,
case when type='Iris-virginica'then 1 else 0 end
from iris;

共有15个参数(3种类型*5个维度)。SQL脚本如下:

create table weight(
w00 decimal(35, 30), w01 decimal(35, 30), w02 decimal(35, 30), w03 decimal(35, 30), w04 decimal(35, 30),
w10 decimal(35, 30), w11 decimal(35, 30), w12 decimal(35, 30), w13 decimal(35, 30), w14 decimal(35, 30),
w20 decimal(35, 30), w21 decimal(35, 30), w22 decimal(35, 30), w23 decimal(35, 30), w24 decimal(35, 30));

我将输入数据初始化为0.1、0.2、0.3。为了便于演示,我使用了不同的数字。将它们全部初始化为0.1是可以的。SQL脚本如下:

insert into weight values (
0.1, 0.1, 0.1, 0.1, 0.1,
0.2, 0.2, 0.2, 0.2, 0.2,
0.3, 0.3, 0.3, 0.3, 0.3);

接下来,我编写了一条SQL语句来计算数据推断结果的准确性。为了更好地理解,我使用伪代码来描述这个过程:

weight = (   
w00, w01, w02, w03, w04,
w10, w11, w12, w13, w14,
w20, w21, w22, w23, w24
)
for data(x0, x1, x2, x3, x4, y0, y1, y2) in all Data:
exp0 = exp(x0 * w00, x1 * w01, x2 * w02, x3 * w03, x4 * w04)
exp1 = exp(x0 * w10, x1 * w11, x2 * w12, x3 * w13, x4 * w14)
exp2 = exp(x0 * w20, x1 * w21, x2 * w22, x3 * w23, x4 * w24)
sum_exp = exp0 + exp1 + exp2

// softmax
p0 = exp0 / sum_exp
p1 = exp1 / sum_exp
p2 = exp2 / sum_exp

//推理结果
r0 = p0 > p1 and p0 > p2
r1 = p1 > p0 and p1 > p2
r2 = p2 > p0 and p2 > p1

data.correct = (y0 == r0 and y1 == r1 and y2 == r2)
return sum(Data.correct) / count(Data)

在上面的代码中,我计算了每行数据中的元素。为了对样本进行推断:

  1. 我计算出加权向量的EXP。
  2. 并且计算出softmax值。
  3. 然后,选择p0、p1和p2中最大的一个作为1,并将其余的设置为0。

如果样本的推断结果与其原始分类一致,则预测正确。然后,我将所有样本的正确数量相加,得到最终的准确率。

下面的代码显示了SQL语句的实现。我将每一行数据加上一个权重(只有一行数据),计算每一行的推断结果,并将正确的样本数相加:

select sum(y0 = r0 and y1 = r1 and y2 = r2) / count(*)
from
(select
y0, y1, y2,
p0 > p1 and p0 > p2 as r0, p1 > p0 and p1 > p2 as r1, p2 > p0 and p2 > p1 as r2
from
(select
y0, y1, y2,
e0/(e0+e1+e2) as p0, e1/(e0+e1+e2) as p1, e2/(e0+e1+e2) as p2
from
(select
y0, y1, y2,
exp(
w00 * x0 + w01 * x1 + w02 * x2 + w03 * x3 + w04 * x4
) as e0,
exp(
w10 * x0 + w11 * x1 + w12 * x2 + w13 * x3 + w14 * x4
) as e1,
exp(
w20 * x0 + w21 * x1 + w22 * x2 + w23 * x3 + w24 * x4
) as e2
from data, weight) t1
)t2
    )t3;

上面的SQL语句几乎一步一步地实现了伪代码的计算过程。我得到了如下结果:

+-------------------------------------------------------------+  
| sum(y0 = r0 and y1 = r1 and y2 = r2)/count(*) |
+-------------------------------------------------------------+
| 0.3333 |
+-------------------------------------------------------------+
1 row in set (0.01 sec)

接下来,我开始学习模型参数。

模型训练

注意:为了简化问题,我没有考虑“训练集”和“验证集”问题,而是把所有的数据仅用于进行训练。

我编写了伪代码,然后在此基础上编写了一条SQL语句:

weight = (     
w00, w01, w02, w03, w04,
w10, w11, w12, w13, w14,
w20, w21, w22, w23, w24
)
for iter in iterations:
sum00 = 0
sum01 = 0
...
sum23 = 0
sum24 = 0
for data(x0, x1, x2, x3, x4, y0, y1, y2) in all Data:
exp0 = exp(x0 * w00, x1 * w01, x2 * w02, x3 * w03, x4 * w04)
exp1 = exp(x0 * w10, x1 * w11, x2 * w12, x3 * w13, x4 * w14)
exp2 = exp(x0 * w20, x1 * w21, x2 * w22, x3 * w23, x4 * w24)
sum_exp = exp0 + exp1 + exp2
// softmax
p0 = y0 - exp0 / sum_exp
p1 = y1 - exp1 / sum_exp
p2 = y2 - exp2 / sum_exp
sum00 += p0 * x0
sum01 += p0 * x1
sum02 += p0 * x2
...
sum23 += p2 * x3
sum24 += p2 * x4
w00 = w00 + learning_rate * sum00 / Data.size
w01 = w01 + learning_rate * sum01 / Data.size
...
w23 = w23 + learning_rate * sum23 / Data.size
    w24 = w24 + learning_rate * sum24 / Data.size

因为我手动扩展了sum和w向量,所以这段代码看起来有点麻烦。然后,我开始编写SQL训练代码。首先,我编写了一条只用一次迭代的SQL语句。

我设置了如下所示的学习速率和样本数:

set @lr = 0.1;  
Query OK, 0 rows affected (0.00 sec)
set @dsize = 150;
Query OK, 0 rows affected (0.00 sec)

代码迭代了一次:

select 
w00 + @lr * sum(d00) / @dsize as w00, w01 + @lr * sum(d01) / @dsize as w01, w02 + @lr * sum(d02) / @dsize as w02, w03 + @lr * sum(d03) / @dsize as w03, w04 + @lr * sum(d04) / @dsize as w04 ,
w10 + @lr * sum(d10) / @dsize as w10, w11 + @lr * sum(d11) / @dsize as w11, w12 + @lr * sum(d12) / @dsize as w12, w13 + @lr * sum(d13) / @dsize as w13, w14 + @lr * sum(d14) / @dsize as w14,
w20 + @lr * sum(d20) / @dsize as w20, w21 + @lr * sum(d21) / @dsize as w21, w22 + @lr * sum(d22) / @dsize as w22, w23 + @lr * sum(d23) / @dsize as w23, w24 + @lr * sum(d24) / @dsize as w24
from
(select
w00, w01, w02, w03, w04,
w10, w11, w12, w13, w14,
w20, w21, w22, w23, w24,
p0 * x0 as d00, p0 * x1 as d01, p0 * x2 as d02, p0 * x3 as d03, p0 * x4 as d04,
p1 * x0 as d10, p1 * x1 as d11, p1 * x2 as d12, p1 * x3 as d13, p1 * x4 as d14,
p2 * x0 as d20, p2 * x1 as d21, p2 * x2 as d22, p2 * x3 as d23, p2 * x4 as d24
from
(select
w00, w01, w02, w03, w04,
w10, w11, w12, w13, w14,
w20, w21, w22, w23, w24,
x0, x1, x2, x3, x4,
y0 - e0/(e0+e1+e2) as p0, y1 - e1/(e0+e1+e2) as p1, y2 - e2/(e0+e1+e2) as p2
from
(select
w00, w01, w02, w03, w04,
w10, w11, w12, w13, w14,
w20, w21, w22, w23, w24,
x0, x1, x2, x3, x4, y0, y1, y2,
exp(
w00 * x0 + w01 * x1 + w02 * x2 + w03 * x3 + w04 * x4
) as e0,
exp(
w10 * x0 + w11 * x1 + w12 * x2 + w13 * x3 + w14 * x4
) as e1,
exp(
w20 * x0 + w21 * x1 + w22 * x2 + w23 * x3 + w24 * x4
) as e2
from data, weight) t1
)t2
    )t3;

一次迭代后,输出结果是模型参数,如下所示:

以下是核心代码部分,我使用递归CTE进行迭代训练:

set @num_iterations = 1000;
Query OK, 0 rows affected (0.00 sec)

其核心思想是,每次迭代的输入都是前一次迭代的结果,此外我添加了一个增量迭代变量来控制迭代次数。总体框架代码是:

with recursive cte(iter, weight) as
(
select 1, init_weight
union all
select iter+1, new_weight
from cte
where ites < @num_iterations

接下来,我将迭代的SQL语句与这个迭代框架结合在一起。为了提高计算精度,我在中间结果中添加了类型转换:

with recursive weight( iter,
w00, w01, w02, w03, w04,
w10, w11, w12, w13, w14,
w20, w21, w22, w23, w24) as
(
select 1,
cast(0.1 as DECIMAL(35, 30)), cast(0.1 as DECIMAL(35, 30)), cast (0.1 as DECIMAL(35, 30)), cast(0.1 as DECIMAL(35, 30)), cast(0.1 as DECIMAL(35, 30)),
cast(0.1 as DECIMAL(35, 30)), cast(0.1 as DECIMAL(35, 30)), cast(0.1 as DECIMAL(35, 30)), cast(0.1 as DECIMAL(35, 30)), cast(0.1 as DECIMAL(35, 30)),
cast(0.1 as DECIMAL(35, 30)), cast(0.1 as DECIMAL(35, 30)), cast(0.1 as DECIMAL(35, 30)), cast(0.1 as DECIMAL(35, 30)), cast(0.1 as DECIMAL(35, 30))
union all
select
iter + 1,
w00 + @lr * cast(sum(d00) as DECIMAL(35, 30)) / @dsize as w00, w01 + @lr * cast(sum(d01) as DECIMAL(35, 30)) / @dsize as w01, w02 + @lr * cast(sum(d02) as DECIMAL(35, 30)) / @dsize as w02, w03 + @lr * cast(sum(d03) as DECIMAL(35, 30)) / @dsize as w03, w04 + @lr * cast(sum(d04) as DECIMAL(35, 30)) / @dsize as w04 ,
w10 + @lr * cast(sum(d10) as DECIMAL(35, 30)) / @dsize as w10, w11 + @lr * cast(sum(d11) as DECIMAL(35, 30)) / @dsize as w11, w12 + @lr * cast(sum(d12) as DECIMAL(35, 30)) / @dsize as w12, w13 + @lr * cast(sum(d13) as DECIMAL(35, 30)) / @dsize as w13, w14 + @lr * cast(sum(d14) as DECIMAL(35, 30)) / @dsize as w14,
w20 + @lr * cast(sum(d20) as DECIMAL(35, 30)) / @dsize as w20, w21 + @lr * cast(sum(d21) as DECIMAL(35, 30)) / @dsize as w21, w22 + @lr * cast(sum(d22) as DECIMAL(35, 30)) / @dsize as w22, w23 + @lr * cast(sum(d23) as DECIMAL(35, 30)) / @dsize as w23, w24 + @lr * cast(sum(d24) as DECIMAL(35, 30)) / @dsize as w24
from
(select
iter, w00, w01, w02, w03, w04,
w10, w11, w12, w13, w14,
w20, w21, w22, w23, w24,
p0 * x0 as d00, p0 * x1 as d01, p0 * x2 as d02, p0 * x3 as d03, p0 * x4 as d04,
p1 * x0 as d10, p1 * x1 as d11, p1 * x2 as d12, p1 * x3 as d13, p1 * x4 as d14,
p2 * x0 as d20, p2 * x1 as d21, p2 * x2 as d22, p2 * x3 as d23, p2 * x4 as d24
from
(select
iter, w00, w01, w02, w03, w04,
w10, w11, w12, w13, w14,
w20, w21, w22, w23, w24,
x0, x1, x2, x3, x4,
y0 - e0/(e0+e1+e2) as p0, y1 - e1/(e0+e1+e2) as p1, y2 - e2/(e0+e1+e2) as p2
from
(select
iter, w00, w01, w02, w03, w04,
w10, w11, w12, w13, w14,
w20, w21, w22, w23, w24,
x0, x1, x2, x3, x4, y0, y1, y2,
exp(
w00 * x0 + w01 * x1 + w02 * x2 + w03 * x3 + w04 * x4
) as e0,
exp(
w10 * x0 + w11 * x1 + w12 * x2 + w13 * x3 + w14 * x4
) as e1,
exp(
w20 * x0 + w21 * x1 + w22 * x2 + w23 * x3 + w24 * x4
) as e2
from data, weight where iter < @num_iterations) t1
)t2
)t3
having count(*) > 0
)
select * from weight where iter = @num_iterations;

这个代码块和上面一次迭代的代码块之间有两个区别。在此代码块中:

上述代码运行结果是:

ERROR 3577 (HY000): In recursive query block of Recursive Common Table Expression 'weight', the recursive table must be referenced only once, and not in any subquery

这表明递归CTE不允许在递归部分使用子查询。不过,我可以合并上面所有的子查询。但是,即使在我手动合并它们之后还是得到了以下错误提示:

ERROR 3575 (HY000): Recursive Common Table Expression 'cte' can contain neither aggregation nor window functions in recursive query block

这表明不允许使用聚合函数。然后,我决定改变TiDB的实现代码。

根据​​提案​​​中的介绍,递归CTE的实现遵循了TiDB的基本执行框架。在咨询​​PingCAP​​的研发人员黄文军(Wenjun Huang)之后,我了解到子查询和聚合函数不被允许的原因有两个:

但我只是想测试一下这些功能。为此,我暂时删除了​​diff​​中对子查询和聚合函数的检查。

最后,我再次执行修改后的代码,输出结果如下:

成功了!经过1000次迭代,我得到了参数。

接下来,我使用新参数重新计算正确的速率:

+--------------------------------------------------------------+
| sum(y0 = r0 and y1 = r1 and y2 = r2) / count(*) |
+--------------------------------------------------------------+
| 0.9867 |
+--------------------------------------------------------------+
1 row in set (0.02 sec)

这一次,准确率达到了98%。

结论

通过使用TiDB 5.1中的递归CTE,我成功地使用纯SQL在TiDB上训练了softmax逻辑回归模型。

在测试期间,我发现TiDB的递归CTE不允许子查询和聚合函数,所以我修改了TiDB的代码以绕过这些限制。最后,我成功地训练了一个模型,并在Iris数据集上获得了98%的准确率。

最后,作为补充,在我的上述工作中还总结了下面几个想法:

原文I Trained a Machine Learning Model in Pure SQL,作者:Mingcong Han


来源:51CTO内容投诉

免责声明:

① 本站未注明“稿件来源”的信息均来自网络整理。其文字、图片和音视频稿件的所属权归原作者所有。本站收集整理出于非商业性的教育和科研之目的,并不意味着本站赞同其观点或证实其内容的真实性。仅作为临时的测试数据,供内部测试之用。本站并未授权任何人以任何方式主动获取本站任何信息。

② 本站未注明“稿件来源”的临时测试数据将在测试完成后最终做删除处理。有问题或投稿请发送至: 邮箱/279061341@qq.com QQ/279061341

软考中级精品资料免费领

  • 历年真题答案解析
  • 备考技巧名师总结
  • 高频考点精准押题
  • 2024年上半年信息系统项目管理师第二批次真题及答案解析(完整版)

    难度     813人已做
    查看
  • 【考后总结】2024年5月26日信息系统项目管理师第2批次考情分析

    难度     354人已做
    查看
  • 【考后总结】2024年5月25日信息系统项目管理师第1批次考情分析

    难度     318人已做
    查看
  • 2024年上半年软考高项第一、二批次真题考点汇总(完整版)

    难度     435人已做
    查看
  • 2024年上半年系统架构设计师考试综合知识真题

    难度     224人已做
    查看

相关文章

发现更多好内容

猜你喜欢

AI推送时光机
位置:首页-资讯-后端开发
咦!没有更多了?去看看其它编程学习网 内容吧
首页课程
资料下载
问答资讯