image frame

All those... moments... will be lost in time, like tears... in... rain.

一个通过猜测api绕过支付的案例

背景

案例已去除所有敏感信息.

前段时间考研时接到的一个需求, 客户是抓取自己的网站, 可能是用来做测试?

该网站是售卖现在比较火的龙傲天短剧, 通常是那种每集2分钟, 动辄100+集的龙傲天短剧。通常采用前xx集免费,等消费者看上头了,后面剧集便需要通过支付购买方可解锁观看,由此实现变现。

支付这个功能实际上是对安全的要求十分高的,通常使用大厂的接口会靠谱的多。然而,支付接口也只是支付这一大块功能中的一个组成部分。即便使用了靠谱的支付接口,其他方面的逻辑漏洞也又可能导致

本次的案例就是支付接口之外的漏洞。该网站在消费者支付后,会在后端验证该消费者的消费信息后,给该主机后台发送所购买的视频的链接。在前端并不能看到这个链接,也看不出什么异常。

分析

在多次的抓包分析后,发现了这样的问题:对付费视频的api做请求时,请求中并没有任何用于检验身份的token。这就意味着,未付费的消费者如果碰巧猜测到了付费视频的api,那么便可以不用付费即可获得视频数据。

本次案例中的视频api有这样的结构:xxxx.com/video_id/num_id/base.m3u8. 其中当集数增加1后,base以36进制也增加1。由此,我们可以通过前一段免费的视频来猜测后面收费视频的api。

问题的关键在base+1的实现。该base是一个字符串,每个位上的字符取值集合为[0-9a-z],按照顺序在数值上主次+1,也就是4+1=5,d+1=e,当z+1后会变回0,并向高位进位。数学上与36进制数字同构。

实现

关键实现一个36进制数字的类:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
class ABNum:
def __init__(self, value: str | int, pad=-1) -> None:
self.pad = pad
if isinstance(value, int):
self.value = value
return
value_list = [self.map_to_int(v) for v in reversed(value)]
self.pad = len(value_list)
self.value = 0
for index, v in enumerate(value_list):
self.value += v * 36 ** index

def plusOne(self):
self.value += 1

def __add__(self, other):
if isinstance(other, int):
value = self.value + other
return ABNum(value, self.pad)

def __str__(self):
value = self.value
value_list = []
while True:
value, div = divmod(value, 36)
value_list.append(div)
if value < 36:
value_list.append(value)
break
pad_num = self.pad - len(value_list)
value_list.reverse()
if pad_num > 0:
value_list = [0] * pad_num + value_list
return ''.join([self.map_to_str(v) for v in value_list])

@staticmethod
def map_to_int(v: str) -> int:
assert len(v) == 1
if v.isdecimal():
return int(v)
return ord(v) - 87

@staticmethod
def map_to_str(v: int) -> str:
if v >= 0 and v <= 9:
return str(v)
return chr(87+v)

该类并没有完全实现作为一个数字该实现的所有方法,这里仅满足使用即可。算法使用常规的短除法,按照数字定义,从10进制与36进制之间转换。类的内部使用10进制保存数值以及处理算术运算,在需要时再转换为36进制即可。

需要注意的点有:

1. 注意原始36进制数据的位数,最终求出来的36进制数要记得填充0.

2. 注意10进制与36进制转换时,需要做reverse.

总结

作为渗透测试人员、爬虫工程人员,要多留意数据之间的潜在关系。看出关系后,要能联系到可以实现的数学表达。遇到该类网站,未了解全貌后不要擅自爬取。本质上这个已经有些超过合法爬虫的范围,已经触犯法律了。可以提醒站长漏洞。

作为网站运营,这种出售链接的方式,应当使用完全随机的链接,以防止攻击者猜测出规律,绕过支付。更稳妥的方式还是验证浏览者身份信息,以及即时生成随机的视频链接。

理解pandas架构

前言

Pandas 是数据科学中常用的数据处理库;然而初学者并没有学习到Pandas的思想,仍然用数组的思维将Pandas仅仅作为数据的容器,总使用迭代的方式处理数据,这样既不优雅,又容易犯错;此外Pandas的API众多,初学者很容易迷失在茫茫的API文档中,不得要领;

数据模型

首先需要牢记的是,Pandas采用计算式;也就是说绝大多数的操作都是计算出一个结果,而并不修改原数据;如果你需要对原数据作出修改,通常需要手动重新赋值,或者加上inplace参数;

初识数据

这是一个常规数据表的结构示意图:

在Pandas中,数据由Series与DataFrame表示;其中DataFrame可以视为Series的集合;

索引

列名

特征(列)

观测(行)

选择

如果是选择某列,可以直接采用下面的方式选取:

1
col = df["col_name"]

此时得到的是一个Series对象;

若选择多列,则采用下面的方式选取:

1
cols = df[["col_name_1", "col_name_2", "col_name_3"]]

此时得到的cols是一个DataFrame对象。

若选择涉及到行,则需要使用数据框的loc/iloc接口;

loc

loc 暴露了数据框根据index对数据进行选择的接口:

1
sub_df = df.loc[]

iloc

iloc 暴露了数据框根据位置对数据进行选择的接口:

1
sub_df = df.iloc[2:8, 1:9]

筛选

筛选的过程分成两步:

  1. 根据筛选条件生成掩码;

  2. 根据掩码对数据进行筛选。

掩码

以行方向举例,每行都是一个独立的观测

用筛选条件,对所有观测做一个映射,满足筛选条件则映射到True,否则映射到False;

由此获得一个与索引长度相同的掩码数组;

根据该掩码数组,对应为True的观测被保留;对应为False的观测被舍弃。

上述为通用的筛选逻辑,其中前文介绍的loc/iloc接口均可以使用,例如:

1
df.loc[lambda s: s['shield'] == 8, :]

处理

数据处理常用的三大函数类:apply、map、reduce

这三大函数类都接受一个处理函数func、一个序列seq、以及其他的定制参数作为参数,但它们在数据处理的逻辑上会有一些差别。

apply

apply类函数会将序列整体作为一个参数传递给处理函数,该处理函数会将序列整体作为参数进行处理,并返回一个值作为结果;

从效果上来看比较类似:

1
2
3
df['col_name'].apply(func)

func(df['col_name'])

但当apply的序列对象为一个DataFrame时,apply会作用在DataFrame的每个Series上,并将每个Series的计算结果合并成一个Series。此时需要额外传递一个参数axis指明apply的方向。默认0代表行方向,1代表列方向。

1
df.apply(func, axis=1)

map

map类函数与apply类似,接受一个处理函数func,一个序列seq作为参数;不同的是,apply会将seq作为一个整体被func调用,但在map中,func会作用在seq中的每个元素上,并返回对应的计算值。最终所有计算值由map收集,并拼凑成一个与原seq形状相同的数据结构返回。

通常map只作用在Series上:

1
2
3
df['col_name'].map(func)

[func(i) for i in df['col_name']]

如果需要对DataFrame整体进行map操作,api为applymap

reduce

reduce类函数同样接受一个处理函数func,一个序列seq作为参数;

reduce函数会取seq中前两个元素s1, s2作为参数去调用func:func(s1, s2),并将结果继续与后续元素做func运算,直到消耗完seq中的所有元素,并将最后的规约值作为结果返回;

1
2
3
df['col_name'].reduce(func)

func(func(func(func(s1, s2), s3), s4)...)

合并

数据合并通常有两类:具有相同特征的两个数据框按行合并成具有更多观测的数据框、具有相同观测的两个数据框按列合并成具有更多特征的数据框。

通常前者更简单,使用concat:

1
df = pd.concat(df_1, df_2)

往往采用外连结的方式保证所有特征都被保留。

后者通常会复杂一些,往往使用merge:

1
df = pd.merge(df_1, df_2, left_on="left_col", right_on="right_col", how="left")

这里参考数据库的左右连接与内外连接,两者逻辑上完全一致。

重铸

重铸这个词来自R中的数据处理。

用来修改数据的结构。

melt

stack

unstack

窗口 window

聚合 groupby

groupby 同样是数据处理中最重要的函数之一。

groupby

groupby 将数据按照某一标准分类,以便让方便后续的apply操作;

groupby通常传入列名即可,pandas会将改列值相同的观测分成一类;

1
df.groupby('col_name')

groupby返回的是一个惰性对象,也就是说分类并不会立刻开始,而是会在后续运算时一并运算。

常见的会在其后加上

Pandas API 一览

Series 是Pandas 中的最小单位

强化学习笔记

教材:蘑菇书

绪论

  • 强化学习与监督学习的不同

    • 强化学习输入的样本是序列数据,而不像监督学习里面样本都是独立的。
    • 强化学习需要不断的试错探索。探索和利用是强化学习里核心的问题。需要在探索和利用之间找到一个权衡。
    • 强化学习里没有强监督者,只有奖励信号,并且奖励信号是延迟的。
    • 强化学习可以达到超人例子,而监督学习只能接近人。
  • 一些术语和概念

    • 通过预演(rollout)获取一系列观测

    • 每个观测称之为一个轨迹(trajectory)

    • 轨迹是当前帧状态以及它的策略动作的序列:$(s_0, a_0, s_1, a_1,…)$

    • 每个观测结束,我们可以通过观测序列以及最终奖励(eventual reward)来训练智能体

    • 一场游戏成为一个回合(episode)或者实验(trial)

      强化学习 -> 深度强化学习
      
      将特征工程并入深度学习中
      

序列决策

智能体和环境

  • 概念
    • 强化学习研究的问题是智能体与环境交互的问题

奖励

  • 概念
    • 奖励是由环境给的一种标量的反馈信号,可以显示智能体在某一步的某个策略的表现。
    • 强化学习的目的是最大化期望的累积奖励

序列决策

奖励分为近期奖励和远期奖励,两者的权衡是强化学习的重要课题之一

历史是观测、动作、奖励的序列:
$$H_t=o_1, a_1, r_1,…,o_t,a_t,r_t$$

智能体的当前动作依赖于它之前得到的历史,整个游戏的(智能体)状态可以看为关于历史的函数:
$$s_t=f(H_t)$$

此外,环境有自己的函数:
$$s^e_t=f^e(H_t)$$

智能体自身的函数:
$$s^a_t=f^a(H_t)$$

观测与状态:

状态是对世界的完整描述
观测是对状态的部分描述(通常是智能体能获取的信息)

完全可观测和部分可观测

完全可观测指智能体状态和环境状态等价,此时通常建模为马尔可夫决策过程,此时$o_t=s^e_t=s^a_t$
部分可观测指智能体无法获取环境运作的所有状态,此时通常建模为部分可观测马尔可夫决策过程

动作空间

动作空间指给定环境下,有效动作的集合

离散动作空间
连续动作空间

智能体的组成成分和类型

策略,智能体用策略选择下一步的动作
价值函数,用价值函数评估智能体当前的状态
模型,表示智能体对环境状态的理解

策略

策略是一个函数,将输入状态映射到动作

随机性策略,使用$\pi$函数,$\pi(a|s)=p(a_t=a|s_t=s)$,根据概率分布选择动作。
确定性策略,$a = argmax\pi(a|s)$,容易被对手预测

价值函数

价值函数的值是对未来奖励的预测,用于评估状态的好坏。

折扣因子是将收益从时间向收益的转换

价值函数定义:
$$V_{\pi}(s)=E_{\pi}[G_t|s_t=s]=E[\sum_{k=0}^{\INF}\gamma^kr_{t+k+1}|s_t=s],对于所有的s\inS$$

另一种价值函数Q函数,包含两个变量,状态和动作,定义为
$$Q_{\pi}(s)=E_{\pi}[G_t|s_t=s, a_t=a]=E[\sum_{k=0}^{\INF}\gamma^kr_{t+k+1}|s_t=s, a_t=a]$$

Q函数是需要学习的函数,可以获取某个状态需要采取的最优动作。

模型

模型由转移概率和奖励函数组成。

转移概率:
$$p^a_{ss’}=p(s_{t+1}=s’|s_t=s,a_t=a)$$

奖励函数值当前状态采取某动作,可以获得的奖励:
$$R(s, a)=E[r_{t+1}|s_t=s, a_t=a]$$

策略 + 价值函数 + 模型 => 马尔可夫决策过程

智能体分类

学习过程

基于策略的学习,基于策略的智能体,有策略梯度算法
基于价值的学习,基于价值的智能体,有Q学习,Sarsa算法,适用离散环境,连续环境下效果差
演员-评论员智能体,类似上述两者的集成

模型

有模型,通过学习状态的转移来采取策略,类似多了特征工程,可以一定程度减缓数据匮乏的问题
免模型,通过学习价值函数和策略函数做决策,泛化性强,消除了特征工程与真实环境之间的信息损失

学习与规划

学习与规划是序列决策的两个基本问题。

探索和利用

探索和利用是强化学习的两个核心问题。

探索-利用窘境

单步强化学习

知道每个动作的奖励
执行奖励最大的动作

单步强化学习对应K-臂赌博机模型。

仅探索法 (类似创业
仅利用法 (类似上班

马尔可夫决策过程

马尔可夫过程

掠过

马尔可夫性质

齐次马尔可夫,简化版

马尔可夫链

马尔可夫奖励过程

比马尔可夫链多了一个奖励函数

回报与价值函数

注意这里,回报、价值、奖励是不同的概念

奖励 r: 奖励是当前时刻触发的收益
回报 g: 回报是当前状态往后,直到最终时刻所获的所有收益
价值 v: 价值函数是回报的期望

贝尔曼方程

贝尔曼方程是价值函数的另一种形式

价值函数 = 即时奖励 + 未来所有奖励

可以通过价值函数的贝尔曼形式得到解析解,问题在于通常来说P是不知道的,需要学习

解析解的问题在于复杂度很高,对于高状态的矩阵算起来非常困难

迭代算法

蒙特卡洛

动态规划

时序差分学习

马尔可夫决策过程

决策过程中多一个智能体做决策的动作

价值函数 V

Q 函数

Q 函数是基于 V 函数的动作层面的空间划分

Q 是某个动作的价值,指站在此状态,选择某个动作的价值

V 是所有动作总空间的期望价值,指站在此状态,对所有动作的价值的期望

贝尔曼期望方程

先把价值函数写成Q函数的分解,再对Q函数作贝尔曼方程分解,得到一个贝尔曼期望方程的一种形式

先把价值函数写成即时奖励和后续奖励的分解,再将后续奖励中的价值函数作贝尔曼方程分解,得到贝尔曼期望方程的另一种形式

备份图

备份图描述了状态价值函数的计算分解

策略评估

前提:已经马尔可夫决策过程、策略

此时不断迭代,所有的价值函数最终都会收敛

预测与控制

预测:已知马尔可夫决策过程,策略,计算每个状态的价值

控制:已知马尔可夫决策过程,虚招最佳策略,以及最佳价值函数

动态规划

dp 解决同时有最优子结构和重叠子问题性质的问题。马尔可夫决策过程就具有这种性质,可以用动态规划去解。

马尔可夫决策过程中的策略评估

前提:已经马尔可夫决策过程、策略

将贝尔曼期望备份转换为迭代过程,直到收敛,这个过程是同步备份。

马尔可夫决策过程的控制

如何寻找最优策略,进而计算最佳价值函数

朴素方法:穷举

迭代算法:策略迭代、价值迭代

策略迭代

策略评估 + 策略改进

巨大的基础:马尔可夫过程已知。

  1. 初始化策略,和价值函数
  2. 根据当前策略,计算状态新价值函数
  3. 根据状态价值函数,推算Q 函数
  4. 最优化Q 函数,得到新策略
  5. 判断迭代是否结束,未结束,则跳转到2

基础:马尔可夫过程已知

策略 + 价值函数 –贝尔曼方程迭代–> 新价值函数
价值函数 –> Q 函数
最大化Q 函数 –> 新的策略

Q 函数可以看成一个Q 表格

这个是大致的策略迭代算法

价值迭代

最优性原理: 一个策略在状态s达到最优价值的充要条件是,任何可以从s到达的状态s’都达到了最优价值。

算法:

更新Q ,更新V 若干次
从最后的V中提取最优策略

表格型方法

使用查找表

比如蒙特卡洛、Q学习、Sarsa

有模型

使用概率函数和奖励函数来描述环境

免模型

因为很多时候环境未知,概率函数甚至奖励函数都未知。

此时免模型需要通过试探来估计概率函数等环境信息。

有模型和免模型

有模型可以直接从环境推导智能体
免模型需要不断和环境交互,迭代智能体

Q 表格

Q 表格是主要的训练对象

强化是指用下一个状态的价值来更新当前状态的价值。

免模型预测

前提:无法获取马尔可夫决策过程模型

蒙特卡洛方法

蒙特克罗使用采样,使用经验品君回报估计价值函数

优点:不需要状态转移函数和奖励函数,也不用动态规划中的自举。

缺点:只能用于有终止的马尔可夫决策过程

算法:

蒙特克罗和动态规划比较

  1. 蒙特克罗适用于环境未知,而且更新速度快
  2. 动态规划适用于有模型,但每次迭代需要更新所有状态,速度很慢

时序差分方法(TD)

免模型控制

策略迭代进行广义推广,兼容蒙特卡洛和时序差分,也就是广义策略迭代。

Sarsa 同策略时序差分控制

Q 学习 异策略时序差分控制

同策略和异策略的区别

一个例子:用Q 学习解决悬崖寻路问题

策略梯度

策略梯度算法

策略梯度实现技巧

1. 添加基线

2. 分配合适的分数

蒙特卡洛策略梯度

一个例子:用策略梯度算法解决悬崖寻路问题

近端策略优化

一个例子:用近端策略优化算法解决悬崖寻路问题

深度Q 网络

理解Flask架构

WSGI 协议

WSGI 是一种网络接口协议;在Web与Python之间构造了一层转换器;
当用户请求发到服务器后,服务器按照WSGI协议将请求转换为Python中的env变量与response函数;
通常开发人员从env中提取出必要的信息,作出反应后按照WSGI协议调用response函数,实现
WSGI协议的服务器会将返回值转换为Web响应。

Flask 架构

Flask 架构

Flask 上下文

Flask 的架构与Flask 上下文的设计密切相关。

在Flask 应用开发代码中,有两个特点:

  1. 开发人员不需要处理过多的线程问题,全局变量request即可自动处理好。
  1. 开发人员不需要将请求作为参数或返回值在不同的函数间传来传去;

ContextVar + LocalProxy 解决方案

ContextVar 源码分析
1
from contextvars import ContextVar

ContextVar 是属于标准库的数据结构;旨在提供一个统一的变量入口;该变量在开发代码中唯一,在运行时会自动根据线程分离。

该标准在3.7加入python标准库,原作者也写了一个3.6的兼容版本:https://github.com/MagicStack/contextvars/blob/master/contextvars/__init__.py

该版本的contextvar库是构建在thread.local上的。

在代码的最底部定义了一个全局变量_state用于管理所有的Context与ContextVar,而且该变量是一个thread.local创建的对象,这保证了不同线程间变量是分隔的;

比如Flask中的request上下文变量在Flask收到多个请求时,实际的数据是这样的:

上下文变量在多线程中的实际表现

上图左边Flask处理多请求;右边为对应线程中request在整个进程中的实际存储形式;

下面我们来看contextvars的源码:

首先最重要的是该模块中的全局变量_state:

1
_state = threading.local()

该变量为一个local变量,会自己隔离不同线程中的变量;此外该变量没有声明在__all__中,是一个模块级的全局变量,不用担心全局的命名冲突。

以及该变量对应的两个辅助函数:

1
2
3
4
5
6
7
8
9
10
def _get_context():
ctx = getattr(_state, 'context', None)
if ctx is None:
ctx = Context()
_state.context = ctx
return ctx


def _set_context(ctx):
_state.context = ctx

_get_context 函数用于从获取_state 变量的 ‘context’ 属性,该’context’实际上是线程分离的(因为_state 是local变量,而’context’是_state 的属性);
若当前线程的’context’为None,则将其初始化为一个上下文对象Context,并返回;

_set_context 函数就很好理解了,传入一个上下文对象,将该上下文对象设置为当前线程的’context’。

然后我们来看上下文对象和上下文变量对象。其关系是,上下文对象为上下文变量对象的容器。我们先看上下文对象:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
class ContextMeta(type(collections.abc.Mapping)):

# contextvars.Context is not subclassable.

def __new__(mcls, names, bases, dct):
cls = super().__new__(mcls, names, bases, dct)
if cls.__module__ != 'contextvars' or cls.__name__ != 'Context':
raise TypeError("type 'Context' is not an acceptable base type")
return cls


class Context(collections.abc.Mapping, metaclass=ContextMeta):

def __init__(self):
self._data = immutables.Map()
self._prev_context = None

def run(self, callable, *args, **kwargs):
if self._prev_context is not None:
raise RuntimeError(
'cannot enter context: {} is already entered'.format(self))

self._prev_context = _get_context()
try:
_set_context(self)
return callable(*args, **kwargs)
finally:
_set_context(self._prev_context)
self._prev_context = None

def copy(self):
new = Context()
new._data = self._data
return new

def __getitem__(self, var):
if not isinstance(var, ContextVar):
raise TypeError(
"a ContextVar key was expected, got {!r}".format(var))
return self._data[var]

def __contains__(self, var):
if not isinstance(var, ContextVar):
raise TypeError(
"a ContextVar key was expected, got {!r}".format(var))
return var in self._data

def __len__(self):
return len(self._data)

def __iter__(self):
return iter(self._data)

Context 对象的结构如下:

Context 对象的结构

_data 属性是一个映射字典,用于保存当前上下文中的上下文变量。

_prev_context 属性用于临时保存之前的上下文对象,用于临时上下文切换。

比较重要的是这个方法:__getitem__

__getitem__ 方法实现了映射协议,该方法先检查传入的key是否为上下文变量ContextVar实例,如果不是则抛出异常,如果是则在_data属性中以该实例为key,获取实际的保存的值;

接下来是ContextVar的代码:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
class ContextVarMeta(type):

# contextvars.ContextVar is not subclassable.

def __new__(mcls, names, bases, dct):
cls = super().__new__(mcls, names, bases, dct)
if cls.__module__ != 'contextvars' or cls.__name__ != 'ContextVar':
raise TypeError("type 'ContextVar' is not an acceptable base type")
return cls

def __getitem__(cls, name):
return


class ContextVar(metaclass=ContextVarMeta):

def __init__(self, name, *, default=_NO_DEFAULT):
if not isinstance(name, str):
raise TypeError("context variable name must be a str")
self._name = name
self._default = default

@property
def name(self):
return self._name

def get(self, default=_NO_DEFAULT):
ctx = _get_context()
try:
return ctx[self]
except KeyError:
pass

if default is not _NO_DEFAULT:
return default

if self._default is not _NO_DEFAULT:
return self._default

raise LookupError

def set(self, value):
ctx = _get_context()
data = ctx._data
try:
old_value = data[self]
except KeyError:
old_value = Token.MISSING

updated_data = data.set(self, value)
ctx._data = updated_data
return Token(ctx, self, old_value)

def reset(self, token):
if token._used:
raise RuntimeError("Token has already been used once")

if token._var is not self:
raise ValueError(
"Token was created by a different ContextVar")

if token._context is not _get_context():
raise ValueError(
"Token was created in a different Context")

ctx = token._context
if token._old_value is Token.MISSING:
ctx._data = ctx._data.delete(token._var)
else:
ctx._data = ctx._data.set(token._var, token._old_value)

token._used = True

def __repr__(self):
r = '<ContextVar name={!r}'.format(self.name)
if self._default is not _NO_DEFAULT:
r += ' default={!r}'.format(self._default)
return r + ' at {:0x}>'.format(id(self))

ContextVar 有两个属性:_name_default 以及一个动态属性:name

_name 属性用于标记该ContextVar的名称,_default用与指明该ContextVar的默认值;

name 返回当前保存的_name值。

ContextVar 有三个重要的方法:get, set, reset

我们先聊get 方法

get 方法会先调用_get_context加载当前线程的上下文,然后尝试在该上下文中查找当前ContextVar对应的值,并返回;如果该上下文没有当前ContextVar,则尝试返回默认值;

set 方法同样先调用_get_context加载当前线程的上下文,然后获取该上下文的_data,尝试先获取其ContextVar的旧值,再更新为新值,最终返回一个Token,该Token中保存了上下文对象,当前上下文变量对象以及旧值;

reset 方法传入一个Token,使用Token中的信息将当前ContextVar恢复到上一次的值;

总结一下:

  1. _state 这个local 变量负责分离、管理线程变量;该对象是模块自身生成管理的;

local 逻辑模拟图

  1. Context 是一个上下文对象,用于管理上下文变量对象;该对象实现了映射协议,key值为ContextVar,val为对应的实际的值;

Context 结构

  1. ContextVar 是一个上下文变量对象,与实际的值之间有一个对应关系;此外,该对象还实现了临时恢复的功能;

ContextVar 结构

LocalProxy

Context 与 ContextVar帮助我们实现了变量的线程分离;但每次使用变量,总是需要调用some_var.get() 方法与some_var.set() 方法,非常的不优雅,而且容易产生语义上的误解;

werkzeug 针对ContextVar,设计了LocalProxy类,该类将ContextVar的某个属性封装起来,并将赋值,读值等方法内部转发委托给其封装的ContextVar或其某个属性;某种角度上,LocalProxy类是ContextVar的一种语法糖类;

AppContext + RequestContext

AppContext 与 RequestContext 实例在Flask中是全局变量,在所有线程中均可见,但是被local分离,互不影响;

我们检查global.py,可以看到这几个关键的全局变量声明:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
_cv_app: ContextVar["AppContext"] = ContextVar("flask.app_ctx")
app_ctx: "AppContext" = LocalProxy(
_cv_app, unbound_message=_no_app_msg
)
current_app: "Flask" = LocalProxy(
_cv_app, "app", unbound_message=_no_app_msg
)
g: "_AppCtxGlobals" = LocalProxy(
_cv_app, "g", unbound_message=_no_app_msg
)

_cv_request: ContextVar["RequestContext"] = ContextVar("flask.request_ctx")
request_ctx: "RequestContext" = LocalProxy(
_cv_request, unbound_message=_no_req_msg
)
request: "Request" = LocalProxy(
_cv_request, "request", unbound_message=_no_req_msg
)
session: "SessionMixin" = LocalProxy(
_cv_request, "session", unbound_message=_no_req_msg
)

可以看到这里声明了两个上下文变量,分别是应用上下文变量_cv_app与请求上下文变量_cv_request

而我们常用的 current_app, g, request, session 都是它们的某个属性代理;

总结一下,全局环境中存在两个ContextVar变量,_cv_app_cv_request,这两个变量负责管理current_appgrequestsession这些LocalProxy变量;

当我们使用_cv_app这些ContextVar变量的时候,我们实际上使用的是当前线程对应的’context’字典中以_cv_app作为key找到的值;

当我们使用g这些LocalProxy变量的时候,我们实际上使用的是该LocalProxy对应的ContextVar变量本身或其某个属性;

如下图所示:

Flask 上下文结构

需要注意,以上的变量均是全局变量。若干的context是指不同线程中的context变量。

上下文推送

在Flask里,当我们需要使用这些LocalProxy变量时,我们需要将对应的上下文变量推送到当前上下文栈中。

在未进行上下文推送前,_cv_app_cv_request 实际上是空对象,此时使用这些变量下的LocalProxy变量,往往会触发所谓的上下文错误。

在上下文推送后,这两个ContextVar会被更新,此时其对应的LocalProxy变量也才变得可用。

通常上下文推送是由Flask自动推送的,如果在开发期间需要使用,需要手动推送上下文变量。通常这个发生在初始化过程中,比如数据库初始化等等。

应用上下文
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
class AppContext:
"""The app context contains application-specific information. An app
context is created and pushed at the beginning of each request if
one is not already active. An app context is also pushed when
running CLI commands.
"""

def __init__(self, app: "Flask") -> None:
self.app = app
self.url_adapter = app.create_url_adapter(None)
self.g: _AppCtxGlobals = app.app_ctx_globals_class()
self._cv_tokens: t.List[contextvars.Token] = []

def push(self) -> None:
"""Binds the app context to the current context."""
self._cv_tokens.append(_cv_app.set(self))
appcontext_pushed.send(self.app)

def pop(self, exc: t.Optional[BaseException] = _sentinel) -> None: # type: ignore
"""Pops the app context."""
try:
if len(self._cv_tokens) == 1:
if exc is _sentinel:
exc = sys.exc_info()[1]
self.app.do_teardown_appcontext(exc)
finally:
ctx = _cv_app.get()
_cv_app.reset(self._cv_tokens.pop())

if ctx is not self:
raise AssertionError(
f"Popped wrong app context. ({ctx!r} instead of {self!r})"
)

appcontext_popped.send(self.app)

def __enter__(self) -> "AppContext":
self.push()
return self

def __exit__(
self,
exc_type: t.Optional[type],
exc_value: t.Optional[BaseException],
tb: t.Optional[TracebackType],
) -> None:
self.pop(exc_value)

这里的关键在于AppContext维护了一个_cv_tokens栈,用于保存之前的上下文;push方法将全局_cv_app设置为当前的AppContext的实例;并将之前的_cv_app的值以token的形式存在栈里,方便后序恢复;

此外,该类实现了with协议,可以使用with语句激活上下文;

Flask 与 蓝图系统

蓝图实际上和FlaskApp本质上很像,两者实际上都是继承自Scaffold;

在实际开发中,几乎可以把蓝图和FlaskApp同等对待,甚至蓝图本身还可以注册蓝图;

这里我们着重讲一下蓝图的注册;

使用app.register_blueprint(blueprint)将blueprint注册到app.

register_blueprint 会调用blueprint的register方法做以下这些事情:

  1. 从blueprint构建蓝图的name

  2. 根据name判断是否已经有同名蓝图注册;如果有,则抛出异常(注意这里只是判断同名,不是判断是否已经注册过)

  3. 判断blueprint是否有自己的静态文件夹,如果有,就添加到路由规则集中;

  4. 判断blueprint是否之前被注册过了,如果注册已经注册过了,就将当前的蓝图以覆盖的形式更新注册;

  5. 将自身的命令接口注册到主app上;

  6. 更新blueprint的子blueprint的路由信息,并调用这些子blueprint的register方法,将这些子blueprint注册到主app上;

因此最终只会有一个路由规则集,那就是主app的路由规则集,其余所有蓝图及其子孙蓝图的路由规则集都会在合理变换后加到主app的路由规则集里;

路由系统

在FlaskAPP 开发过程中,每个App和蓝图都定义自己的路由系统,不过这只是开发阶段;在运行时,实际上只存在一套路由系统,那就是主App的路由系统,其余所有的子路由系统都在运行时做了合适的变换,并补充到了主App的路由系统中;

Flask 路由系统并不复杂,但在深入之前,我们先了解一些werkzeug中的路由组件;

Map 与 Rule

Map和Rule是werkzeug中负责管理路由的组件;Rule负责定义具体的路由规则,而Map则是Rule的容器,同时也负责根据路径对其保存的Rule进行查询,并给出路由终点;

一个来自官方文档的例子:

1
2
3
4
5
url_map = Map([
Rule('/', endpoint='new_url'),
Rule('/<short_id>', endpoint='follow_short_link'),
Rule('/<short_id>+', endpoint='short_link_details')
])

传入请求的environ,调用Map的bind_to_environ方法,生成一个URLAdapter;

1
adapter = url_map.bind_to_environ(request.environ)

该adapter中保存着路由的终点以及传入的参数值;

1
2
for endpoint, values in adapter.match():
do_something()

这里我们需要注意的是,werkzeug中的路由系统仅仅完成了url到endpoint这一步,并没有进一步找到对应的视图函数;在Flask中,存在一个由endpoint到视图函数的字典,这个字典完成了拼图上的最后一块;

Flask 路由

推送上下文后,Flask会尝试做一次full_dispatch_request();

在full_dispatch_request中,Flask会做一些初始判断与对request的预处理,若一切顺利,Flask会再做一次dispatch_request();

在dispatch_request中,我们会发现Flask已经获取到负责路由请求的Rule,并根据其endpoint在Flask的view_functions中查询视图函数;后面便是对视图函数的调用了;

实际上Flask在请求传递进来,构造请求上下文时,就已经完成了路由,并将这些路由信息(如路由规则,路由终点,路由参数等)记录在了请求对象request中;

检查Flask的请求上下文的__init__函数,我们可以看到请求上下文的构造过程:

  1. 将上下文本身的app指向FlaskApp

  2. 根据FlaskApp提供的请求类与传入的environ构造请求对象,并赋值给request

  3. 使用app.create_url_adapter处理请求;在这个过程中,调用了werkzeug中的函数获取路由终点,并将结果保存在上下文的url_adapter变量中;

在请求上下文推送后,上下文会调用自己的match_request方法,尝试从自身url_adapter变量中获取路由终点,路由参数,并将这些信息绑定到request上,至此,路由信息才被写入请求中;

后面的逻辑就很平坦了,Flask从请求中获取路由终点与路由参数;以路由终点作为key,从view_functions中查询视图函数,并将路由参数作为参数调用视图函数,获取响应;

完成响应的构造与后序处理后,委托werkzeug返回给用户端;

视图函数与可拔插视图类

视图函数就是我们在web开发模型的MVC中常说的View,名副其实的视图函数;

通常视图函数就是我们的业务代码,按照正常的业务逻辑编写即可;其中有两点值得称道:

  1. 函数只用从请求链接中接受参数,并不需要传入响应本身;直接导入全局的request使用即可;

  2. 函数写完后只需要加一个装饰器即可自动注册到目标app或蓝图中,同时可以直接定义从链接中接受那些参数,非常直观方便;

Flask中优秀的上下文设计使得视图函数与请求解藕,使得开发逻辑更加清晰;

建议将后台的一些工具辅助函数单独写在一个模块中,而将业务逻辑相似的代码单独放在一起,方便开发管理;

这里我们重点讲一下可拔插视图类,这个工具让视图函数的抽象层级提高了一层,合理的利用可拔插视图类,我们可以利用继承,将相似的路由逻辑和业务逻辑抽离出来,提高我们的效率;你完全可以在工作的过程中不断自己总结,编写自己的可拔插视图类,这些代码将会在未来解放你的大量生产力;

我们来看一个简单的例子,这个例子来自Flask的官方文档:

我们需要将某个链接路由到对用户列表的查询,我们完全可以使用视图函数实现:

1
2
3
4
@app.route('/query/users/')
def show_users():
users = User.query.all()
return render_template('users.html', users=users)

假设此时又出现一个需求,要求我们将某个链接路由到对商品列表的查询,此时你不得不在写一个逻辑类似的视图函数:

1
2
3
4
@app.route('/query/goods/')
def show_goods():
users = Goods.query.all()
return render_template('goods.html', goods=goods)

假设此时又出现一个需求…

这么编写非常冗余,而且有个致命的问题,万一某一天,你的老板决定要给这些查询加上一个额外的参数,用于单独查询某页的数据结果;这个时候你不得不重新修改所有的视图函数,而这些视图函数还可能分布在不同的地方,此外这种乏味单调的修改极容易出现错误;

此时有若干解决方法:

lisp程序员可能会使用宏,在程序运行时用宏自动生成对应的视图函数;也有可能想到使用函数式编程,先写一个工厂函数,在运行时用工厂函数即时计算对应的视图函数;等等

这里Flask提供一种类似函数式编程思想的工具:可拔插视图类;

首先我们将相似的逻辑抽象出来:

1
2
3
4
5
6
7
8
9
10
11
12
13
from flask.views import View

class ListView(View):

def get_template_name(self):
raise NotImplementedError()

def render_template(self, context):
return render_template(self.get_template_name(), **context)

def dispatch_request(self):
context = {'objects': self.get_objects()}
return self.render_template(context)

然后,我们继承该类,完善处理具体逻辑的子类:

1
2
3
4
5
6
7
class UserView(ListView):

def get_template_name(self):
return 'user.html'

def get_objects(self):
return User.query.all()

此时如果需要临时加上额外参数,那么仅修改基类即可完成所有类的逻辑更新;

某种角度上来说,可拔插视图类的思想有些类似函数式编程中的高阶函数;不同的是在面向对象中,新的类通过继承的方式产生,通过新类的重写实现子类的定制,将相似逻辑写在基类中的方式将相似的逻辑抽离出来;而在函数式编程中,新的函数通过某个函数的计算生成,通过传参的方式定制新函数的行为,将相似逻辑写在工厂函数中的方式将相似的逻辑抽离出来;

模板系统

模板系统在视图函数返回与响应生成这两个时间点之间发挥作用;

视图函数返回的并不是最终的响应,而是最终响应的一些必要信息,保存在变量rv中,需要以此为参数,调用FlaskApp的finalize_request()才能得到最终的响应对象;

通常生成rv过程中会涉及到模板的渲染;

Flask默认使用jinja2作为渲染引擎;

开发过程中,最常用的渲染函数有render_template以及stream_template;前者用于渲染常规响应,后者用于渲染流式响应;

模板系统通常做两件事情:

  1. 根据提供的模板名称寻找到模板本身

  2. 在提供的上下文的帮助下,填充并渲染模板,并将结果返回

模板路由

模板渲染

找到渲染模板后,实际的渲染过程发生在_render_stream中,两者接受相同的参数:app、template、context;但逻辑有些许差别,下面我们分开讨论:

1
2
3
4
5
6
def _render(app: "Flask", template: Template, context: t.Dict[str, t.Any]) -> str:
app.update_template_context(context)
before_render_template.send(app, template=template, context=context)
rv = template.render(context)
template_rendered.send(app, template=template, context=context)
return rv

_render中,首先用app的一些信息更新了context中的信息;注意这里的上下文是指用于模板渲染的上下文,而不是我们前面讨论的应用上下文等;这里更新的原因很简单,模板渲染很可能也需要一些来自app的信息;

中间插入了一个信号发送,这里暂时不展开;

紧接着用context作为上下文对模板进行渲染,获取到rv,这里的rv是字符串;

又一个信号发送;

返回渲染后的字符串rv作为渲染结果

_stream函数与_render类似:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18

def _stream(
app: "Flask", template: Template, context: t.Dict[str, t.Any]
) -> t.Iterator[str]:
app.update_template_context(context)
before_render_template.send(app, template=template, context=context)

def generate() -> t.Iterator[str]:
yield from template.generate(context)
template_rendered.send(app, template=template, context=context)

rv = generate()

# If a request context is active, keep it while generating.
if request:
rv = stream_with_context(rv)

return rv

不同的地方是这里使用template.generate替代了template_render,并且返回的是一个字符串迭代器作为rv,而非字符串;

响应

在full_dispatch_request方法的最后,获取了rv并调用finalize_request方法,完成响应的构建;

关键的响应生成逻辑写在了make_response中,这个方法非常长,但逻辑并不复杂,这里就不放出来,仅做一些逻辑说明:

make_response接受rv作为参数,并返回一个响应对象;

前面的很多叙述省略了很多异常处理的情况,真实情况,rv可能的类型很多:str、bytes、dict、list、generator、iterator、tuple;这里对其类型做了判断,并分开处理;

  • tuple
    rv必须是长度为2或3的tuple。长度为2时应当为(rv, headers)或(rv, status)的形式将其解包;

  • None
    抛出异常

  • 不是响应类

    • str、bytes、bytearray
      使用rv作为参数构造响应对象,并赋值回给rv

    • dict、list
      使用rv作为参数构造json响应,并赋值回给rv

    • BaseResposne、callable
      使用response_class.force_type将rv强制转换为响应对象,并返回给rv

最终rv必定为响应类型,更新rv的headers及status后,作为最终响应返回;

这里需要注意的只有一点,就是rv作为callable对象返回的情况,此时rv应当接受两个参数:environ与start_response;

基于python的算法实现

定式

用数学去分析和思考问题

  1. 搜索类
  1. 将问题看成一个代数空间
  2. 算法的每个部分的作用,通常前面会先处理特殊情况,然后对输入做一些规范化,转换为一般问题
  3. 循环要满足循环不变的性质,这在算法证明中很重要
  4. 从信息的角度看问题,我有哪些信息,使用那种数据结构可以充分高效地存储这些信息,以及基于这种数据结构,使用哪些算法可以充分的用到所有信息。
  5. 除了信息本身,被研究问题的一些数学性质可以很好的压缩解的代数空间,从而优化算法。

条件 condition

(增量式) 创建某种语境

1
2
3
4
# context 1 or context 2
if context_1:
do_some_thing()
# target context

(分支式) 创建某种语境

1
2
3
4
5
6
7
8
# context 1 of context 2 or ...
if context_1:
do_some_thing_1()
elif context_2:
do_some_thing_2()
elif context_3:
do_some_thing_3()
# target context

(分支式) 分类处理

1
2
3
4
5
6
7
8
# context 1 of context 2 or ...
if context_1:
do_some_thing_1()
elif context_2:
do_some_thing_2()
elif context_3:
do_some_thing_3()
# end

序列 sequence

一些有关序列的小技巧, 帮助你快速自信的写出正确的代码

python 以 0 作为序列首位索引

使用左闭右开区间,length 必定数组越界,length-1 为最后索引

好处是处理 start, length,end 的关系时无需+-1

1. 索引遍历数组

1
2
for _ in range(len(array)):    # [0, length) -> [0, length-1]
do_some_thing()

2. start + length 切片

[start: start+length]

3. end + length 切片

[(end-1)-length, (end-1)]

4. start + end 计算 length

length = end - start

循环 loop

1
2
3
while not target_context:
do_some_thing()
# target context

递归 recurse

编码风格

  1. 终止条件写在最前面

操作式递归

  1. 如何理解操作式递归代码

阅读递归代码时要将递归链/树视为一个整体,而递归函数就是逐步零碎的调整整个链/树,直到触发终止条件;

实际上终止条件在大多数的递归中都不会触发,而是一种兜底,面向链/树整体的全局终止条件;

在递归点前后的代码分别称为前序代码和后序代码;前序代码对整体的调整由浅入深;直到遇到递归点,此时深度再进一层;

在触发终止条件后,后序代码开始对整体由深返浅地调整,最终结束。

计算式递归

  1. 如何理解计算式递归代码

同样的,首先定义终止条件,并需要注意到,终止条件中返回的值将是整个递归树的基础值;所有结果都将由基础值组合得到;

通常计算式递归并没有什么前序代码,如果有,那大部分是为了计算返回值而设计的;

计算式递归需要有一个锚点接受后续递归的计算结果;通常这个锚点是多个;用于组合成当前的计算值并返回;

因此计算式递归中有两种值:从终止条件计算的基础值;由基础值组合而成的组合值;

通常终止条件会与前序代码混在一起,用于计算基础值;后序代码通常用于计算组合值;

例题1. 计算斐波那契数列

1
2
3
4
5
6
7
8
9
10
11
12
13
14
def feb_rec(n):
# 基础值
if n in [0, 1]:
return n

# 递归锚点
result_rec_2 = feb_rec(n-2)
result_rec_1 = feb_rec(n-1)

# 计算组合值
result_rec = result_rec_1 + result_rec_2

# 返回组合值
return result_rec

例题2. 计算LCA

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
def LCA(node, p, q):
result = None
def LCA_rec(node):
# 定义基础值 [None, p, q]
if not node or node is p or node is q:
return node

# 递归锚点
left = LCA_rec(node.left)
right = LCA_rec(node.right)

# 计算组合值
# [None & None] => None
# [None & [p | q]] => [p | q]
# [p & q] => LCA
if left is None and right is None:
return None
elif (left is p and right is q) or (left is q and right is p):
result = node
return node
elif left is None:
return right
else:
return left

LCA_rec(root, p, q)
return result

例题3. 判断镜像树

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
def mirror(root):
def mirror_rec(left, right):
# 基础值
if not left and not right:
return True
elif not left or not right:
return False
if left.val != right.val:
return False

# 递归锚点
mirror_out = mirror_rec(left.left, right.right)
mirror_in = mirror_rec(left.right, right.left)

# 计算组合值
mirror = mirror_out and mirror_in

return mirror

1. 链表倒转

链表 [1, 3, 5, 9, 10] 递归倒转为 [10, 9, 5, 3, 1]

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
def rec(node: ListNode):
# 终止条件写在前面
next_node = node.next
if not next_node:
return

# 递归点
rec(next_node)

# 后序
next_node.next = node

# 调整
if node is header:
next_node = None

2. 镜像树判断

二分 binary

二分是建立在以排序序列上的搜索算法

采用左必右开的习惯

需要注意的点:

  1. 何时停止

  2. 非严格排序时控制搜索位置

1
2
3
4
5
6
7
8
9
def left_binsearch(nums, left, right, value):
"""非降序列第一个不小于value的位置."""
while left < right:
mid = left + (right-left)//2
if nums[mid] < value:
left = mid + 1
else:
right = mid
return left
1
2
3
4
5
6
7
8
9
def right_binsearch(nums, left, right, value):
"""非降序列第一个大于value的位置."""
while left < right:
mid = left + (right-left)//2
if nums[mid] <= value:
left = mid + 1
else:
right = mid
return left

杂项 misc

二维数组

二维数组使用r,c坐标系

1. 二维数组的快速初始化

1
grid = [[0] * width for _ in range(height)]

此时

height = len(grid) 与 row 相关

width = len(grid[0]) 与 col 相关

2. 二维数组遍历

  1. 正常遍历(行优先,从左向右,自上而下)
1
2
3
for r in range(len(grid)):
for c in range(len(grid[0])):
grid[r][c] = 0
  1. 遍历对角(r==c)
1
2
3
for r in range(len(grid)):
c = r
grid[r][c] = 0
  1. 遍历反对角(r+c==len-1)
1
2
3
for r in range(len(grid)):
c = len(grid)-1-r
grid[r][c] = 0
  1. 遍历对角方向(加上偏移)

3. 二维数组游走

1
2
d = {(+1, +1), (+1, -1), (-1, +1), (-1, -1)}
pos = (0, 0)

pyhthon specific

高效倍增与折半

1
2
i <= 1
i >= 1 # 向下去整

判断奇偶

1
n & 1 == True # 奇数 否则为偶数

倍增

1
n <<= 1

减半并向下取整

通常用于完全树与堆的数组实现中

1
n >>= 1

数据结构

栈是简化的树

栈通常与递归联系起来

单调栈

单调栈是栈的加强,在栈的基础上要求栈内部的元素有序。

单调栈的用途并不广泛,而是集中处理一系列问题。

预先定义好顺序信息后,单调只需要两个元素即可破坏。因此当某元素需要入栈时,需要检测当前元素与栈顶元素,判断是否破坏了栈内部的单调性质,如果破坏了,那么说明栈顶元素需要抛出。重复检查,直到不破坏栈内部的单调性,再将元素入栈。

和栈不同,单调栈通常并不和递归联系起来。

单调栈实践

1
2
3
4
5
stack = []
for i in range(n):
while (len(stack) != 0 and stakc[-1] > nums[i]):
stack.pop()
stack.append(nums[i])

并查集

并查集通常用于处理等价类

1. 并查集的字典实现

字典结构 + 函数式

前提:这些类(数学意义上的类)要可以hash,否则实现上做不了字典的key

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
# 初始化
uf = {item: -1 for item in [1, 2, 3, 4, 5, 6, 7]} 负数表示此为根,且绝对值表示此类的大小

# 查 根
def find(uf: dict, one: int) -> int:
while uf[one] > 0:
one = uf[one]
return one

# 并 切记需要并根元素
def union(uf: dict, one, other) -> None:
# 查根
root_one = find(uf, one)
root_other = find(uf, other)

# 判断/ 也是一种逻辑切分
if root_one == root_other:
return

# 将小树并到大树上,尽量保持平衡
if uf[root_one] < uf[root_other]:
root_one, root_other = root_other, root_one

uf[root_other] += uf[root_one]
uf[root_one] = root_other

2. 并查集的类实现

类 + 对象式

TODO

1

堆的数组实现

基于堆是完全树的拓扑结构;

数组从1开始索引;索引0位存储堆的当前存数量;

数组索引为n的节点,其父节点索引为n//2,左节点为2n,右节点为2n+1;

最后一个节点的索引heap[-1]

这里是最小堆的实现, 因此要将大的数下沉

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
# 初始化
heap = [0] * 50

# 获取节点数量
def get_num(heap):
return heap[0]

# 下沉 pos 位置的元素
def sink(heap, pos):
# 获取当前节点值, 左右节点值
current, left, right = heap[pos], heap[2*pos], heap[2*pos+1]

# 当前值最小, 停止下沉
if current <= left and current <= right:
return

# 当前值最大, 挑选左右较小节点
if current > left and current > right:
small_pos = 2*pos if heap[2*pos] < heap[2*pos+1] else 2*pos+1
# 否则若左节点较小
elif current > left:
small_pos = 2*pos
# 否则必定右节点较小
else:
small_pos = 2*pos+1

heap[pos], heap[small_pos] = heap[small_pos], heap[pos]

# 递归点
sink(heap, small_pos)

# 上浮操作
def up(heap, pos):
# 获取父节点
parent_pos = pos // 2

# 若父节点是元节点, 或父节点比当前节点小
if parent_pos == 0 or heap[parent_pos] <= heap[pos]:
return

# 交换父子节点
heap[parent_pos], heap[pos] = heap[pos], heap[parent_pos]

# 递归点
up(heap, parent_pos)

# 取元素
def pop(heap):
# 先判断是否有元素
if heap[0] <= 0:
return None

# 取出堆顶元素
value = heap[1]

# 将最后元素置于堆顶
heap[1], heap[heap[0]] = heap[heap[0]], 0
# 元素计数 -1
heap[0] -= 1
# 堆顶元素下沉
sink(heap, 1)
return value

# 放入元素
def push(heap, value):
# 元素计数 +1
heap[0] += 1
# 将元素放在最后
heap[heap[0]] = value
# 最后元素上浮
up(heap, heap[0])

堆的树实现

略, 使用双向树, 即节点不仅需要知道子节点, 还需要知道父节点

表示图的数据结构

邻接矩阵

邻接矩阵在处理一些代数相关的问题时,比较常用

邻接表

1
2
3
4
5
from collections import defaultdict
adj_table = defaultdict(list)

# 添加 a -> b, 权重为 w 的边
adj_table[a].append((b, w))

⚠️: 对于无向图, 切记初始化的时候不要忘记记录另一个方向的边

算法

排序

冒泡排序

最简单的一种排序方式之一

1
2
3
4
5
6
7
8
9
10
11
def sort(nums):
length = len(nums)

# 外循环控制冒泡次数
for out in range(length-1):

# 内循环控制冒泡操作
for inner in range(length-1-out):
if nums[inner] > nums[inner+1]:
nums[inner], nums[inner+1] = nums[inner+1], nums[inner]
return nums

选择排序

一种与插入排序对偶的排序方式

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
def sort(nums):
length = len(nums)

# 外循环控制未排序的范围
for left in range(length):

small_pos = left
# 内循环负责寻找最小值
for lst in range(left, length):
if nums[lst] < nums[small_pos]:
small_pos = lst

# 循环结束后将选择的最小索引与外循环控制的排序部分交换
nums[left], nums[small_pos] = nums[small_pos], nums[left]
return nums

插入排序

另一种最简单的排序方式之一

1
2
3
4
5
6
7
8
9
10
11
12
13
def sort(nums):
length = len(nums)

# 外循环逐个检查未排序元素
for current_index in range(length):
# 内循环负责将当前元素放到合适的位置
for sorted_index in range(current_index, 0, -1):
if nums[sorted_index-1] > nums[sorted_index]:
nums[sorted_index], nums[sorted_index-1] = nums[sorted_index-1], nums[sorted_index]
else:
break

return nums

希尔排序

基于插入排序的改良

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
def sort(nums):
length = len(nums)
h = 1

while h < length/3: h = 3*h + 1

while h >= 1:
for i in range(h, length):
for j in range(i, h-1, -h):
if nums[j] < nums[j-h]:
nums[j], nums[j-h] = nums[j-h], nums[j]
else:
break
h = h // 3

return nums

归并排序

一种后递归排序

  1. V 形数据排序

  2. 锯齿数据排序

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
aux = [0] * len(nums)

def merge(nums, lo, mid, hi):
left, right = lo, mid+1

for k in range(lo, hi+1):
aux[k] = nums[k]

for k in range(lo, hi+1):
if left > mid:
nums[k] = aux[right]
right += 1
elif right > hi:
nums[k] = aux[left]
left += 1
elif aux[left] > aux[right]:
nums[k] = aux[right]
right += 1
else:
nums[k] = aux[left]
left += 1

def sort_u2d(nums):
def sort_rec(nums, lo, hi):
if lo >= hi:
return
mid = lo + (hi-lo)//2
sort_rec(nums, lo, mid)
sort_rec(nums, mid+1, hi)
merge(nums, lo, mid, hi)

length = len(nums)
sort_rec(nums, 0, length-1)
return nums

def sort_d2u(nums):
length = len(nums)
size = 1

while size < length:
for lo in range(0, length-size, 2*size):
merge(nums, lo, lo+size-1, min(lo+size-1, length-1))
size *= 2

return nums

快速排序

归并排序的一种改良

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
def partition(nums, lo, hi):
left, right = lo, hi
v = nums[lo]

while left != right:
while nums[right] >= v and left < right:
right -= 1

while nums[left] <= v and left < right:
left += 1

if left < right:
nums[left], nums[right] = nums[right], nums[left]

nums[lo], nums[left] = nums[left], nums[lo]
return left

def _sort(nums, lo, hi):
if lo >= hi:
return
k = partition(nums, lo, hi)
_sort(nums, lo, k-1)
_sort(nums, k+1, hi)

def sort(nums):
random.shuffle(nums)
_sort(nums, 0, len(nums)-1)
return nums

另一种写法

1
2
3
4
5
6
7
8
9
10
def quick_sort(nums, l , r):
if l >= r: return
i, j = l, r
while i < j:
while nums[j] >= nums[l] and i < j: j -= 1
while nums[i] <= nums[l] and i < j: i += 1
nums[i], nums[j] = nums[j], nums[i]
nums[i], nums[l] = nums[l], nums[i]
quick_sort(nums, l, i - 1)
quick_sort(nums, i + 1, r)

作者:jyd
链接:https://leetcode.cn/problems/ba-shu-zu-pai-cheng-zui-xiao-de-shu-lcof/solution/mian-shi-ti-45-ba-shu-zu-pai-cheng-zui-xiao-de-s-4/
来源:力扣(LeetCode)
著作权归作者所有。商业转载请联系作者获得授权,非商业转载请注明出处。

堆排序

一种动态的惰性排序

搜索

大多数常见的搜索算法依赖树(图)状数据结构

二分搜索

二分搜索利用了数据的顺序性质,这跳过了许多潜在的错误目标,以此提高效率。

左二分搜索

1
2
3
4
5
6
7
8
def bisect_left(sequence, target, left, right):
while left < right:
mid = (left + right) // 2
if sequenct[mid] < target:
left = mid + 1
else:
right = mid
return left

右二分搜索

1
2
3
4
5
6
7
8
def bisect_right(sequence, target, left, right):
while left < right:
mid = (left + right) // 2
if sequenct[mid] <= target:
left = mid + 1
else:
right = mid
return left

DFS 搜索

BFS 搜索

BFS 本质上为一种暴力搜索,但是搜索的模式顺序是固定的。

如果利用数据的某些数学性质进行剪枝,则可以大幅度优化BFS

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
from collections import deque
def bfs(init_state):
queue = deque([init_state])
step = 0
visited = set()

def get_next_state(current_state):
""""""
pass

while queue:
for _ in range(len(queue)):
current_state = queue.popleft()
if current_state is ok:
return step
for next_state in get_next_state(current_state):
if next_state not in visited:
visited.add(next_state)
queue.append(next_state)
return -1

双向BFS

BFS 的变体,可以减少搜索空间,同时压缩时间

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
from collections import deque
def double_bfs(init_state, target_state):
queue_init = deque([init_state])
queue_target = deque([target_state])
step = 0
visited_forward = set()
visited_backward = set()

def get_next_state(current_state, direction):
pass

while not queue_init and not queue_target:
queue, direction = queue_init, "f" if len(queue_init) < len(queue_target) else queue_target, "b"
for _ in range(len(queue)):
current_state = queue.popleft()
if (direction == "f" and current_state in visited_backward) or (direction == "b" and current_state in visited_forward):
return step

for next_state in get_next_state(current_state, direction):
if direction == "f" and next_state not in visited_forward:
visited_forward.add(next_state)
elif direction == "b" and next_state not in visited_backward:
visited_backward.add(next_state)
queue.append(next_state)
return -1

动态规划

最长子序列

最长子串

最长递增子序列

子串最大和

0/1 背包问题

0/1 是指挑选物品,拿或者不拿,而拿做多只能拿一个

限制:背包限制,每个物品有体积,而背包容积有限

优化目标:每个物品具有某种效益,最大化总效益

SOLVE:

定义 dp[i][j] 为从前i个物品中选出总重量不超过j时总价值的最大值

已知 dp[i-1][j],现在物品范围从前i-1个扩大到前i个,

dp表中,每行为i,每列为j

内循环中,j,也就是背包容量逐渐增大

外循环中,i,也就是可选的物品范围逐渐增大

当下面临的问题是,当前新增的物品能不能装下,若能装下,是拿还是不拿

1
2
3
4
5
6
7
8
for i in range(1, n+1):
for j in range(1, W+1):
if w[i] > j:
dp[i+1][j] = dp[i][j]
else:
dp[i+1][j] = max(dp[i][j], v[i]+dp[i][j-w[i]])

return dp[n][W]

  1. LIS 问题

串算法

  1. 字符

假定一集合,该集合内元素个数有限且良序,则该集合可以称为一个字符集;
该集合内的所有元素称为字符。

  1. 字符串

给定一指定字符集,由该字符集内字符组成的有序排列,称为字符串。

字符串定式

  1. 索引
1
2
3
4
5
6
7
8
9
10
# 字符串首
s[0]

# 越界
s[len(s)]

# 字符串尾
s[len(s)-1]

s[-1]

字符串基础算法

双指针

  1. 同向快慢指针

  2. 对向指针

AC 树

dp

经典例题: 计算两字符串的编辑距离

前缀函数

前置定义

对于一个长度为n的字符串s,以及一个小于n的整数i,
记其前缀s[0:i]与后缀s[n-i:n]为一个前后缀组合

前缀函数定义

对于长度为n的字符串,记其前缀函数为$\pi(i)$
其定义域为[1, n] 之间的整数;

其定义为前缀s[0:i]的所有前后缀组合中,前后缀相等且长度最长的组合的长度。

朴素前缀函数计算
1
2
3
4
5
6
7
8
9
def prefix_function(s):
n = len(s)
pi = [0] * n
for i in range(1, n):
for j in range(i, -1, -1):
if s[0:j] == s[(i+1)-j:(i+1)]:
p[i] = j
break
return pi
前缀函数优化(一)

注意到 pi[i+1] 最多比 p[i]
大1,这时,当且仅当新增加的字符s[i+1]与s[p[i]-1+1]相同

因此,j从i位置逐渐扫描到0的过程中,超出p[i]部分的扫描是多余的。

1
2
3
4
5
6
7
8
9
def prefix_function(s):
n = len(s)
pi = [0] * n
for i in range(1, n):
for j in range(p[i-1]+1, -1, -1):
if s[0:j] == s[(i+1)-j:(i+1)]:
p[i] = j
break
return pi
前缀函数优化(二)

当新增字符s[i+1]与s[p[i]]不同时,我们的思路为找到s[0:i]中第二长的前后缀组合,再检测

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
def prefix_function(s):
n = len(s)
pi = [0] * n

for i in range(1, n):
j = pi[i-1]

while not (j <= 0 or s[i] == s[j]):
j = pi[j-1]

if s[i] == s[j]:
j += 1

pi[i] = j

return pi

KMP

后缀数组

倍增法计算后缀数组

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
from itertools import zip_longest, islice

def to_int_keys(l):
"""
l: iterable of keys
returns: a list with integer keys
"""
index = {v: i for i, v in enumerate(sorted(set(l)))}
return [index[v] for v in l]

def suffix(to_int_keys, s):
n = len(s)
k = 1
ans = to_int_keys(s)
while k < n:
ans = to_int_keys(
list(zip_longest(ans, islice(ans, k, None),
fillvalue=-1)))
k <<= 1
return ans

suffix(to_int_keys, 'banana')

判断字符是否由重复字符组成

1
2
def repeated(s):
return s in (s+s)[1:-1]

判断字符串是否回文

1
2
def check(s):
return s[:] == s[::-1]

LCP 最长公共前缀

若干字符中查找最长的公共前缀,如果不存在,则返回’’

  1. 使用遍历

可以注意到

LCP(s1, s2, …, sn) = LCP(s1, LCP(s2, LCP(s3, LCP(…LCP(sn-1, sn)))))

这本质上是个 reduce 的关系,先写一个寻找两个字符串最大公共前缀的函数lcp
再将lcp用reduce函数作用[s1, s2, …, sn]的迭代列表上即可

某个字符串查询任意两后缀的最长公共子串

  1. LCP例子

  2. 使用倍增预先计算QMR数组,再用QMR数组推导结果

LIS 最长递增子序列

使用前缀数组,注意到前面的前缀LIS可以推导其后紧接着的前缀LIS
这是一个广义reduce问题

记前缀为q1, q2, …, qk,则有关系:

$$LIS(q_n) = max({LIS(q_i) | i < n & q_i[-1] < q_n[-1]}) + 1$$

之所以是广义reduce,
是指每一个后续指是有前面所有数值reduce出来的,而非狭义的相邻两个

最长递增子串

同上,为广义reduce问题

$$LIS(q_n) = q_n[1] > q_{n-1} ? LIS(q_{n-1}+1) : 1 $$

LCS 最长公共子序列

前缀数组+DP 问题

1
2
3
4
5
6
if i == 0 or j == 0:
LCS[i][j] = 0
elif x[i] == y[j]:
LCS[i][j] = LCS[i-1][j-1] + 1
else: # x[i] != y[j]
LCS[i][j] = max(LCS[i-1][j], LCS[i][j-1])

最长公共子串

前缀数组+DP 问题

1
2
3
4
5
6
7
if i == 0 or j == 0:
LCS[i][j] = 0
elif x[i] == y[j]:
LCS[i][j] = LCS[i-1][j-1] + 1
else: # x[i] != y[j]
LCS[i][j] = 0

子串性质

长串包含子串,记N(S, s) 为S中s的个数,若ls包含ss,则有N(S, ls) <= N(S, ss)

可以这样计算N(S, s):

1
2
def N(S, s):
return len(S.split(s))-1

一些技巧

对于数组 a = range(10)

1. 差分

差分数组为 diff[i] = a[i] - a[i-1], i 从 0 到 9,其中d[0] = 0

且有如下性质:

  1. 差分数组与原数组的关系

    d[i+1] 指 a[i] 变换到 a[i+1] 的增量;

    因此有:

    a[i] + d[i+1] + d[i+2] + … + d[j] = a[j]

    因此有:

    a[j] - a[i] = sum([d[i+1],…, d[j]])

    也就是说:

    a[i] 到 a[j] 的增量为 sum(d): (i, j] 左开右闭原则,中间的差分项为j-i个

  2. 差分数组增减与原数组对应的变换

    d[i] += n

    表示a[i] 及之后的数组全部断层平移n个单位;

    d[i] += n; d[j] -= n; j >= i;

    表示a[i] 到a[j]内的数组全部断层平移n个单位;包含a[i]但不包含a[j];左闭右开;

2. 前缀和

前缀和s[i] = sum(a): [0, i]

3. 倍增

倍增是一种思想,不是明确的算法。其核心的想法基于下面这个事实:

对于任意大于0的整数n,总是存在一个整数i,$2^{i-1}<n<2^i$
使得[1, n]之间的任何整数均可由${2^k: k\inZ&k[1,
i]}$中的若干元素加和而的,且每个元素至多取一次

实际上任何数字均可以这样对应,这也是二进制的原理。

其优点往往在于可以利用以往的结果来推演后续结果,而减少了重复计算;

比如后面的倍增法计算后缀数组

LCA 问题

1
2
3
4
5
6
7
8
9
10
11
12
13
14
def LCA(node, left, right):
def LCA_rec(node):
if not node or node == left or node == right:
return node

node_left = LCA(left, left, right)
node_right = LCA(right, left, right)

if node_left is None:
return node_right
if node_right is None:
return node_left

return node

简单图算法

邻接表

数组实现

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
# 点数 与 边数
num_vtex = 5
num_edge = 8

# 需要+1, 这是0索引导致的
# start, end, weight, 中保存了所有边的信息
start = [-1] * num_edge # 起点
end = [-1] * num_edge # 终点
weight = [-1] * num_edge # 权重

# first 中存储每个顶点的第一条边
first = [-1] * num_vtex

# next 中存储某边的下一个边
next = [-1] * num_edge

# 读入边
for edge in range(num_edge):
# 边i的下一个边为上一个对应的起始边
next[edge] = first[start[edge]]
# 更新起始边
first[start[edge]] = edge

# 遍历顶点i的所有边
# 1. 找到顶点i的第一条边
k = first[i]
# 2. 遍历
while k != -1:
print(start[k], end[k], weight[k])
k = next[k]

# 遍历所有边
for i in range(num_vtex):
k = first[i]
while k != -1"
print(start[k], end[k], weight[k])
k = next[k]

字典实现

1
adj_table = {}
  1. 读入边
1
2
3
4
for start, end, weight in edges:
if start not in adj_table:
adj_table[start] = []
adj_table.append(end, weight)
  1. 遍历某个点的所有边
1
2
for end, weight in adj_table[start]:
do_some_thing(end, weight)

defaultdict实现

1
2
from collections import defaultdict
adj_table = defaultdict(list)
  1. 读入边
1
2
for start, end, weight in edges:
adj_table[start].append((end, weight))
  1. 遍历某个节点的邻接点

遍历节点s

1
2
for end, weight in adj_table[s]:
do_some_thing()

邻接矩阵

1
2
3
4
5
6
7
num_vtex = 20

adj_mat = [[INF] * num_vtex for _ in ragne(num_vtex)]

# 读入边
for start, end, weight in edges:
adj_mat[start][end] = weight

n次邻接矩阵

到达矩阵

邻接矩阵中元素仅有0、1;0表示无法一步到达,1表示可以一步到达。

前提:简单图
记方阵A_{nn}为这样的单步邻接矩阵,记A^k_{nn}为A_{nn}的k次矩阵幂,那么A^k_{nn}[i][j]则为k步从点i到点j的路径数目。

参考离散数学。

概率转移矩阵

参考随机过程。

快幂算法

快幂算法有点类似倍增法,数学基础都是二进制。

假设我们需要进行n次幂乘

  1. 找到一个k,使得2^k<=n<2^{k+1}
  2. 计算X^{(i)}=A^{2^i},其中i\in[0, k],需要注意的是,这里不断倍增即可,即X^{(i)}=X^{(i-1}*X^{(i-1)}
  3. 找到n的二进制表示,根据二进制表示用X^{(i)}组合出乘积的形式

最短路径

Floyd-Warshall 算法

求任意两点间的最短距离

使用邻接矩阵

1
2
3
4
5
for k in range(num_vtex):
for start in range(num_vtex):
for end in range(num_vtex):
if adj_mat[start][end] > adj_mat[start][k]+adj_mat[k][end]:
adj_mat = adj_mat[start][k] + adj_mat[k][end]

Dijkstra 算法

求某点到其余点的最短距离

将点分为两组,P表示已找到最短距离的点的集合,Q表示未找到最短距离的点的集合

  1. 初始化。将原点加入P中,并标记原点的最短距离为0;
  2. 迭代。在Q中找到距离原点最近的点,将其加入P中。对其所有出边进行松弛操作;
  3. 直到Q为空集,迭代结束

维护一张表,该表记录某个点是否为P集合点,以及前点。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
from math import inf

# 1. 初始化。假设源点为0

distance = [inf] * num_vtex
prefix = [None] * num_vtex
finished = set()
unfinish = set(vtex)

distance[0] = 0
finished.add(0)
unfinish.remove(0)

# 2. 迭代。

while unfinish:

# unfinish 中查找最小distance的点
min = inf
for vtex in unfinish:
if distance[vtex] < min:
min = distance[vtex]
start = vtex

finished.add(start)
unfinish.remove(start)

# 松弛该点的所有出边
for end, weight in adj_table[start]:
if distance[start] + weight < distance[end]:
distance[end] = dis[start] + weight
prefix[end] = start

Bellman-Ford 算法

求任意两点间的最小距离,可处理负权重边

用邻接表

1
2
3
4
for _ in range(num_vtex-1):
for edge in range(num_edge):
if dis[end[edge]] > dis[end[edge]] + weight[edge]:
dis[end[edge]] = dis[end[edge]] + weight[edge]

搜索

深度优先

判断两点是否连通,计算当前连通分量内的点数

1
2
3
4
5
6
7
8
9
10
11
12
visited = [False] //
total = 0

def dfs(graph, node):
if visited(node):
return

visited[node] = True
total += 1

for adj_node in graph.adj(node):
dfs(graph, adj_node)
计算连通分量

广度优先

需要使用一个队列

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
total = 0
visited = [False] //

def bfs(graph, node):
queue = [node]

while queue:
for _ in range(len(queue)):
current_node = queue.popleft()

visited[current_node] = True

do_some_thing()

for adj_node in graph.adj(node):
if not visited[adj_node]:
queue.append(adj_node)

判断成环

dfs 判断

适用于有环图与无环图

不断深度遍历,若出现某节点遍历到已遍历节点,则说明有环,否则无环;可以利用n个点的信息提前结束遍历

Union-Find 判断 (更适合动态情况)

适用于无环图

遍历所有边,判断边两点的根结点

若不一致,则无环,将两点并起来

若根结点一致,则有环,结束遍历

拓扑排序 判断

适用于有环图

不断删除入度为0的节点,若最终全部删除完,则无环,否则有环

最小生成树

两种最小生成树的原理都是切分定理,Krustra处理边,Prime处理点

注意:n个点的最小生成树由n-1条边组成,可以用此提前结束循环

Krustra

取一个空集,不断从边中取出最小边,试图加入该集合

若使集合成环,则舍弃;否则加入集合

最终大小为n-1的边集合必定组成最小生成树

算法的关键在于判断成环

Prime

取一个空集,随机取一个点加入空集,作为初始数据

遍历该集合所有外连边,取最小边,加入该集合

最终大小为n的点集合的边必定组成最小生成树

算法的关键在于维护、遍历外连边

高级算法

AC 自动机

多叉树 + trie

概念

推荐b站的视频

python 实现

  1. 定义树节点
1
2
3
4
5
6
7
8
class Node:
def __init__(self, value: str, parent=None):
self.children = {}
self.value = value
self.faild = None
self.words = []
if not parent:
self.parent = self
  1. 扫描 pattern

    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    def scan(patterns):
    root = Node('')
    for pattern in patterns:
    p = root
    for char in pattern:
    if char not in p.children:
    p.children[char] = Node(char, parent=p)
    p = p.children[char]
    p.words = len(pattern)
    return root
  2. 构建 failed 指针

1
2
3
4
5
6
7
8
9
10
11
12
13
def build(root):
queue = [root]
while queue:
for _ in range(len(queue)):
current_node = queue.pop(0)
prob_node = current_node.parent.failed
while prob_node.value != '':
if current_node.value in prob_node.children:
current_node.failed = prob_node.children[current_node.value]
break
prob_node = prob_node.failed
for child_key in sorted(current_node.children.keys()):
queue.append(current_node.children[child_key])
  1. 查询

AD 自动微分

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
import math
import copy

class Expression:
_symbols = {}

@staticmethod
def updateSymbols(**kwargs):
while kwargs:
symbol_str, symbol_val = kwargs.popitem()
symbol_obj = Expression._symbols[symbol_str]
symbol_obj.value = symbol_val

def evalAndDeriveAt(self, variable, **kwargs):
self.updateSymbols(**kwargs)

def __add__(self, other):
if isinstance(other, (int, float)):
other = Constant(other)
return AddExpression(self, other)

def __sub__(self, other):
if isinstance(other, (int, float)):
other = Constant(other)
return SubExpression(self, other)

def __mul__(self, other):
if isinstance(other, (int, float)):
other = Constant(other)
return MultiExpression(self, other)

def __truediv__(self, other):
if isinstance(other, (int, float)):
other = Constant(other)
return DivExpression(self, other)

def __pow__(self, other):
if isinstance(other, (int, float)):
other = Constant(other)
return PowExpression(self, other)

def __radd__(self, other):
if isinstance(other, (int, float)):
other = Constant(other)
return AddExpression(other, self)

def __rsub__(self, other):
if isinstance(other, (int, float)):
other = Constant(other)
return SubExpression(other, self)

def __rmul__(self, other):
if isinstance(other, (int, float)):
other = Constant(other)
return MultiExpression(other, self)

def __rtruediv__(self, other):
if isinstance(other, (int, float)):
other = Constant(other)
return DivExpression(other, self)

def __rpow__(self, other):
if isinstance(other, (int, float)):
other = Constant(other)
return PowExpression(other, self)

@staticmethod
def sin(expression):
return SinExpression(expression)

@staticmethod
def cos(expression):
return CosExpression(expression)

@staticmethod
def log(expression):
return LogExpression(expression)

@staticmethod
def abs(expression):
return AbsExpression(expression)

@staticmethod
def exp(expression):
return ExpExpression(expression)

class SingleExpression(Expression):
def __init__(self, expression) -> None:
super().__init__()
self.expression = expression

def evalAndDeriveAt(self, variable, **kwargs):
super().evalAndDeriveAt(variable, **kwargs)
return self.expression.evalAndDeriveAt(variable, **kwargs)

class SinExpression(SingleExpression):
def evalAndDeriveAt(self, variable, **kwargs):
value, partial = super().evalAndDeriveAt(variable, **kwargs)
return Dual(math.sin(value), partial*math.cos(value))

def __repr__(self) -> str:
return f"sin({self.expression})"

class CosExpression(SingleExpression):
def evalAndDeriveAt(self, variable, **kwargs):
value, partial = super().evalAndDeriveAt(variable, **kwargs)
return Dual(math.cos(value), -partial*math.sin(value))

def __repr__(self) -> str:
return f"cos({self.expression})"

class LogExpression(SingleExpression):
def evalAndDeriveAt(self, variable, **kwargs):
value, partial = super().evalAndDeriveAt(variable, **kwargs)
assert value > 0
return Dual(math.log(value), partial/value)

def __repr__(self) -> str:
return f"log({self.expression})"

class AbsExpression(SingleExpression):
def evalAndDeriveAt(self, variable, **kwargs):
value, partial = super().evalAndDeriveAt(variable, **kwargs)
return Dual(math.abs(value), partial*math.sign(value))

def __repr__(self) -> str:
return f"|{self.expression}|"

class ExpExpression(SingleExpression):
def evalAndDeriveAt(self, variable, **kwargs):
value, partial = super().evalAndDeriveAt(variable, **kwargs)
return Dual(math.exp(value), partial*math.exp(value))

def __repr__(self) -> str:
return f"e ^ ({self.expression})"

class BinaryExpression(Expression):
def __init__(self, left, right) -> None:
self.left = left
self.right = right

def evalAndDeriveAt(self, variable, **kwargs):
super().evalAndDeriveAt(variable, **kwargs)
return self.left.evalAndDeriveAt(variable, **kwargs), self.right.evalAndDeriveAt(variable, **kwargs)

class AddExpression(BinaryExpression):
def evalAndDeriveAt(self, variable, **kwargs):
[vleft, dleft], [vright, dright] = super().evalAndDeriveAt(variable, **kwargs)
return Dual(vleft+vright, dleft+dright)

def __repr__(self) -> str:
return f"({self.left} + {self.right})"

class SubExpression(BinaryExpression):
def evalAndDeriveAt(self, variable, **kwargs):
[vleft, dleft], [vright, dright] = super().evalAndDeriveAt(variable, **kwargs)
return Dual(vleft-vright, dleft-dright)

def __repr__(self) -> str:
if self.left == 0:
return f"- {self.right}"
return f"({self.left} - {self.right})"

class MultiExpression(BinaryExpression):
def evalAndDeriveAt(self, variable, **kwargs):
[vleft, dleft], [vright, dright] = super().evalAndDeriveAt(variable, **kwargs)
return Dual(vleft*vright, vleft*dright+dleft*vright)

def __repr__(self) -> str:
return f"({self.left} * {self.right})"

class DivExpression(BinaryExpression):
def evalAndDeriveAt(self, variable, **kwargs):
[vleft, dleft], [vright, dright] = super().evalAndDeriveAt(variable, **kwargs)
assert vright != 0, ZeroDivisionError
return Dual(vleft/vright, (dleft*vright-vleft*dright)/(vright*vright))

def __repr__(self) -> str:
return f"({self.left} / {self.right})"

class PowExpression(BinaryExpression):
def evalAndDeriveAt(self, variable, **kwargs):
[vleft, dleft], [vright, dright] = super().evalAndDeriveAt(variable, **kwargs)
assert vright != 0, ZeroDivisionError
return Dual(math.pow(vleft, vright), (vright*math.pow(vleft, vright-1)*dleft))

def __repr__(self) -> str:
return f"({self.left} ^ ({self.right}))"

class Dual(Expression):
def __init__(self, value=0, partial=0) -> None:
self.value = value
self.partial = partial

def eval(self):
return self.value

def __iter__(self):
yield self.value
yield self.partial

def __repr__(self) -> str:
return str(self.value)

class Variable(Dual):
def __init__(self, symbol, value=0, partial=0) -> None:
super().__init__(value, partial)
self.symbol = symbol
Expression._symbols[symbol] = self

def evalAndDeriveAt(self, variable, **kwargs):
partial = 1 if variable is self else 0
value = self.value
return value, partial

def __repr__(self):
return self.symbol

class Constant(Dual):
def evalAndDeriveAt(self, variable, **kwargs):
return self.value, 0

def __repr__(self) -> str:
return str(self.value)

class Matrix(Expression):
def __init__(self, matdata) -> None:
self.data = []
first_row = len(matdata[0])
for row in matdata:
assert(len(row) == first_row)
current_row = [ele for ele in row]
self.data.append(current_row)
self.size = len(self.data), len(self.data[0]) # row, col

@property
def T(self):
col, row = self.size
data = [[0] * col for _ in range(row)]
for r in range(row):
for c in range(col):
data[r][c] = self.data[c][r]
return Matrix(data)

def __getitem__(self, index):
return self.data[index]

def __iter__(self):
return iter(self.data)

def __add__(self, other):
assert self.size == other.size
data = copy.deepcopy(self.data)
row, col = self.size
for r in range(row):
for c in range(col):
data[r][c] += self[r][c]
return Matrix(data)

def __sub__(self, other):
assert self.size == other.size
data = copy.deepcopy(self.data)
row, col = self.size
for r in range(row):
for c in range(col):
data[r][c] -= self[r][c]
return Matrix(data)

def __mul__(self, other):
assert isinstance(other, (int, float, Expression, Matrix))
if isinstance(other, Matrix):
left_row, left_col = self.size
right_row, right_col = other.size
assert left_col == right_row
data = [[0]*right_col for _ in range(left_row)]
for row in range(left_row):
for col in range(right_col):
data[row][col] = self[row][0] * other[0][col]
for index in range(1, left_col):
data[row][col] += self[row][index] * other[index][col]
return Matrix(data)

if isinstance(other, (int, float, Expression)):
data = copy.deepcopy(self.data)
row, col = self.size
for r in range(row):
for c in range(col):
data[r][c] *= other
return Matrix(data)

def __truediv__(self, other):
assert isinstance(other, (int, float, Expression))
if isinstance(other, (int, float)):
assert other != 0
return self * (1/other)

def __rmul__(self, other):
assert isinstance(other, (int, float, Expression))
return self * other

def __repr__(self):
msg = ""
for row in self:
msg += "|"
msg += ", ".join([repr(r) for r in row])
msg += "|\n"
return msg

def evalAndDeriveAt(self, variable, **kwargs):
row, col = self.size
data = [[0] * col for _ in range(row)]
value = [[0] * col for _ in range(row)]
partial = [[0] * col for _ in range(row)]
for r in range(row):
for c in range(col):
data[r][c] = self.data[r][c].evalAndDeriveAt(variable, **kwargs)
value[r][c] = data[r][c].value
partial[r][c] = data[r][c].partial
return Dual(Matrix(value), Matrix(partial))

class Vector(Matrix):
def __init__(self, matdata) -> None:
data = []
for ele in matdata:
data.append([ele])
super().__init__(data)

def __repr__(self):
return f"<Vector: {[d[0] for d in self.data]}>"

sin = Expression.sin
cos = Expression.cos
log = Expression.log
abs = Expression.abs
exp = Expression.exp

if __name__ == "__main__":

x = Variable('x')
y = Variable('y')
const_2 = Constant(2)

f = sin(const_2*x*y+x)
g = sin(2*x*y+x)

print(f)
print(g)

print(f.evalAndDeriveAt(x, x=5, y=10).partial)
print(g.evalAndDeriveAt(x, x=5, y=10).partial)

print()

TestVector = Vector([
x + y,
x * y
])
print()
print(TestVector)
print()
print(TestVector*TestVector.T)
print()
print(TestVector.T*TestVector)
print()
print((TestVector*TestVector.T).evalAndDeriveAt(x, x=4, y=2).value)
print((TestVector*TestVector.T).evalAndDeriveAt(x, x=4, y=2).partial)
print()
print((TestVector.T*TestVector).evalAndDeriveAt(y, x=4, y=2).value)
print((TestVector.T*TestVector).evalAndDeriveAt(y, x=4, y=2).partial)
print()

x1 = Variable('x1')
x2 = Variable('x2')
x3 = Variable('x3')

G = Vector([
3 * x1 + cos(x2*x3) - 3/2,
4 * x1 ** 2 - 625 * x2 ** 2 + 2 * x2 - 1,
exp(0-x1*x2) + 20 * x3 + (10*math.pi-3)/3
])

print()
print(G)
print()

F = (1/2 * G.T * G)[0][0]

print()
print(F)
print()

print(F.evalAndDeriveAt(x1, x1=5, x2=8, x3=10).value)
print(F.evalAndDeriveAt(x1, x1=5, x2=8, x3=10).partial)

LRU 缓存

LRU 缓存的另一种说法:保留插入顺序的字典

为了保证查找,插入的O(1)复杂度,我们使用字典作为缓存,此外使用双向链表来记录key的访问

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
class DLinkedNode:
def __init__(self, key=0, value=0):
self.key = key
self.value = value
self.prev = None
self.next = None

class LRUCache:
def __init__(self, capacity: 'int'):
self.cache = {}
self.head = DLinkedNode()
self.tail = DLinkedNode()
self.head.next = self.tail
self.tail.next = self.head
self.capacity = capacity
self.size = 0

def get(self, key: 'int'):
if key not in self.cache:
return -1
node = self.cache[key]
self.move_to_head(node)
return node.value

def put(self, key: 'int', value: 'int'):
if key not in self.cache:
node = DLinkedNode(key, value)
self.cache[key] = node
self.add_to_head(node)
self.size += 1
if self.size > self.capacity:
removed = self.remove_tail()
self.cache.pop(removed.key)
self.size -= 1
else:
node = self.cache[key]
node.value = value
self.move_to_head(node)

def add_to_head(self, node):
node.prev = self.head
node.next = self.head.next
self.head.next.prev = node
self.head.next = node

def remove_node(self, node):
node.prev.next = node.next
node.next.prev = node.prev

def move_to_head(self, node):
self.remove_node(node)
self.add_to_head(node)

def remove_tail(self):
node = self.tail.prev
self.remove_node(node)
return node

MISC 一些重要的python细节

列表和字典key的查询方式并不相同,时间复杂度也不是一个量级

1
2
3
4
5
6
7
8
9
10
11
# 先构建列表,字典,集合
lst = [_ for _ in range(10000)]
dct = {_: None for _ in range(10000)}
st = set(lst)

%timeit 9999 in lst
99.1 µs ± 111 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)
%timeit 9999 in dct
42.3 ns ± 0.346 ns per loop (mean ± std. dev. of 7 runs, 10,000,000 loops each)
%timeit 9999 in st
41.1 ns ± 0.0835 ns per loop (mean ± std. dev. of 7 runs, 10,000,000 loops each)

列表使用线性查找O(n),而字典使用hash查找O(1)

RMQ 问题

ST 算法

参考: RMQ问题

ST 算法先对数据进行预处理,为一种离线算法,适合数据不变情况,
动态情况考虑使用线段树

假设原始数据:

nums = [3, 4, 2, 2, 2, 4, 2, 0, 0, 0, 0, 1]

  1. 先计算预处理数组

    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    length = int(math.log2(len(array))) + 1
    rmq = [[0]* length for _ in range(len(nums))]

    def rmq_init():
    for l, num in enumerate(nums):
    rmq[l][0] = num

    for r in range(1, length):
    for l in range(len(nums)):
    if l + (1 << (r-1)) >= len(nums):
    break
    rmq[l][r] = max(rmq[l, r-1],rmq[l + (1 << (r-1))][r-1] )
  2. 查询

    1
    2
    3
    def rmq_query(l, r):
    k = int(math.log2(r-l+1))
    return max(rmq[l][k], rmq[r-(1<<k)+1][k])
  • Copyrights © 2023-2024 Ivory
  • 访问人数: | 浏览次数:

请我喝杯咖啡吧~

支付宝
微信