a_template

本文最后更新于 2024年11月9日 晚上

板子

优雅的解法,少不了优雅的板子。

注:打 ()(*) 内容表示有待完善。

基础算法

二分

闭区间寻找左边界:

1
2
3
4
5
6
7
8
9
10
11
12
13
bool findLeft(int x) {
int l = 0, r = n - 1;
while (l < r) {
int mid = (l + r) >> 1;
// if (a[mid] < x) l = mid + 1;
if (缺了) {
l = mid + 1;
} else (刚好 | 超了) {
r = mid;
}
}
return a[r] == x;
}

闭区间寻找右边界:

1
2
3
4
5
6
7
8
9
10
11
12
13
bool findRight(int x) {
int l = 0, r = n - 1;
while (l < r) {
int mid = (l + r + 1) >> 1;
// if (a[mid] <= x) l = mid;
if (缺了 | 刚好) {
l = mid;
} else (超了) {
r = mid - 1;
}
}
return a[r] == x;
}

哈希

在 C++ 中,使用 std::unordered_map 时可能会因为哈希冲突导致查询、插入操作降低到 O(n)O(n),此时可以使用 std::map 进行替代,或者自定义一个哈希函数。

在 Python3 中,同理。但是 Python 不允许自定义哈希函数,此时可以尝试桶哈希。

1
2
3
4
5
6
7
8
9
10
11
12
13
template<class T>
struct CustomHash {
size_t operator()(T x) const {
static const size_t _prime = 0x9e3779b97f4a7c15;
size_t _hash_value = std::hash<T>()(x);
return _hash_value ^ (_hash_value >> 30) ^ _prime;
}
};

// 示例
std::unordered_map<int, int, CustomHash<int>> f1;
std::unordered_map<long long, int, CustomHash<long long>> f2;
std::unordered_map<std::string, int, CustomHash<long long>> f3;

数据结构

并查集

并查集虽然一般用来解决集合问题,但数据结构实现上本质是一个由多棵有向根树组成的森林。在采用了路径压缩和按秩合并后,每一次查询与插入的时间复杂度都会均摊为一个常数。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
class DisjointSetUnion {
/* 并查集类
集合元素定义为从 0 开始的整数。
*/

int sz; // 集合个数
std::vector<int> p; // p[i]表示第i个结点的祖宗编号
std::vector<int> cnt; // cnt[i]表示第i个结点所在集合中的结点总数

public:
DisjointSetUnion(int n) : p(n), cnt(n, 1) {
for (int i = 0; i < n; i++) {
p[i] = i;
}
sz = n;
}

int find(int x) {
if (p[x] != x) {
p[x] = find(p[x]);
}
return p[x];
}

void merge(int a, int b) {
int pa = find(a), pb = find(b);
if (pa != pb) {
p[pa] = pb;
cnt[pb] += cnt[pa];
sz--;
}
}

bool same(int a, int b) {
return find(a) == find(b);
}

int size() {
return sz;
}

int size(int a) {
int pa = find(a);
return cnt[pa];
}
};
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
class DSU:
def __init__(self, n: int) -> None:
self.n = n
self.sz = n # 集合个数
self.p = [i for i in range(n)] # p[i]表示第i个结点的祖宗编号
self.cnt = [1 for i in range(n)] # cnt[i]表示第i个结点所在集合中的结点总数

def find(self, x: int) -> int:
if self.p[x] != x:
self.p[x] = self.find(self.p[x])
return self.p[x]

def merge(self, a: int, b: int) -> None:
pa, pb = self.find(a), self.find(b)
if pa != pb:
self.p[pa] = pb
self.cnt[pb] += self.cnt[pa]
self.sz -= 1

def same(self, a: int, b: int) -> bool:
return self.find(a) == self.find(b)

def size(self) -> int:
return self.sz

def size(self, a: int) -> int:
return self.cnt[a]

树状数组

利用更多的区间维护一个序列的信息,所有维护信息的区间组成的形状形如一棵树,故称为树状数组。

下方代码模板目前支持的操作有:

  • 区间查询:查询序列 [1, pos] 索引的元素之和。时间复杂度 O(logn)O(\log n)
  • 单点修改:修改序列 pos 索引的元素值。时间复杂度 O(logn)O(\log n)

更多内容见:https://oiwiki.org/ds/fenwick/

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
template<class T>
class BinaryIndexedTree {
private:
int n;
std::vector<T> arr;

int lowbit(int x) {
return x & (-x);
}

public:
BinaryIndexedTree(int n) : n(n), arr(n + 1) {}

void update(int pos, T x) {
while (pos <= n) {
arr[pos] += x;
pos += lowbit(pos);
}
}

T query_sum(int pos) {
T ret = 0;
while (pos) {
ret += arr[pos];
pos -= lowbit(pos);
}
return ret;
}
};
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
class BinaryIndexedTree:
def __init__(self, n: int):
"""
初始化序列 O(n)。下标从 1 开始,初始化维护序列区间为 [1,n]。
"""
self.n = n
self.arr = [0] * (n + 1)

def update(self, pos: int, x: int) -> None:
"""
单点修改 O(log n)。在 pos 这个位置加上 x。
"""
while pos <= self.n:
self.arr[pos] += x
pos += self._lowbit(pos)

def query_sum(self, pos: int) -> int:
"""
区间求和 O(log n)。返回 [1,pos] 的区间和。
"""
ret = 0
while pos:
ret += self.arr[pos]
pos -= self._lowbit(pos)
return ret

def _lowbit(self, x: int) -> int:
return x & (-x)

SortedList *

例程:https://www.acwing.com/activity/content/code/content/8475415/

官方:https://github.com/grantjenks/python-sortedcontainers/blob/master/src/sortedcontainers/sortedlist.py

有序列表类。导入方法 from sortedcontainers import SortedList。可以类比 C++ 中的 map 类。共有以下内容,全部都是 O(logn)O(\log n) 的时间复杂度:

  1. add(value): 添加一个值到有序列表
  2. discard(value): 删除列表中的值(如果存在)
  3. remove(value): 删除列表中的值(必须存在)
  4. pop(index=-1): 删除并返回指定索引处的值
  5. bisect_left(value): 返回插入值的最左索引
  6. bisect_right(value): 返回插入值的最右索引
  7. count(value): 计算值在列表中的出现次数

数学

模运算

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
template<class T>
T modPower(T a, T b, T p) {
// return: a^b % p
T res = 1 % p;
for (; b; b >>= 1, a = (a * a) % p) {
if (b & 1) {
res = (res * a) % p;
}
}
return res;
}

template<class T>
T modAdd(T a, T b, T p) {
// return: a+b % p
return ((a % p) + (b % p)) % p;
}

template<class T>
T modMul(T a, T b, T p) {
// return: a*b % p
T res = 0;
for (; b; b >>= 1, a = modAdd(a, a, p)) {
if (b & 1) {
res = modAdd(res, a, p);
}
}
return res;
}

template<class T>
T modSumOfEqualRatioArray(T q, T k, T p) {
// return: (q^0 + q^1 + ... + q^k) % p
if (k == 0) {
return 1;
}
if (k % 2 == 0) {
return modAdd<T>((T) 1, modMul(q, modSumOfEqualRatioArray(q, k - 1, p), p), p);
}
return modMul(((T) 1 + modPower(q, k / 2 + (T) 1, p)), modSumOfEqualRatioArray(q, k / 2, p), p);
}

质数筛

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
struct PrimesCount {
int n;
vector<int> pre, vis;
PrimesCount(int n) : n(n), pre(n + 1), vis(n + 1) {
eulerFilter();
}
void eulerFilter() {
// O(n)
vector<int> primes;
for (int i = 2; i <= n; i++) {
if (!vis[i]) {
primes.push_back(i);
pre[i] = pre[i - 1] + 1;
} else {
pre[i] = pre[i - 1];
}
for (int j = 0; primes[j] <= n / i; j++) {
vis[primes[j] * i] = true;
if (i % primes[j] == 0) {
break;
}
}
}
}
void eratosthenesFilter() {
// O(nloglogn)
for (int i = 2; i <= n; i++) {
if (!vis[i]) {
pre[i] = pre[i - 1] + 1;
for (int j = i; j <= n; j += i) {
vis[j] = true;
}
} else {
pre[i] = pre[i - 1];
}
}
}
void simpleFilter() {
// O(nlogn)
for (int i = 2; i <= n; i++) {
if (!vis[i]) {
pre[i] = pre[i - 1] + 1;
} else {
pre[i] = pre[i - 1];
}
for (int j = i; j <= n; j += i) {
vis[j] = true;
}
}
}
};

/* usage
PrimesCount obj(n); // construct an object
cout << obj.pre[n] << "\n"; // pre[i] means prime numbers in range of [1, i]
*/

乘法逆元

假设当前需要在 % p\% \ p 的情况下除以 aa,则可以转化为乘以 aa 的乘法逆元 a1a^{-1},即:

numanum×a1(mod p)其中 a1=ap2 当且仅当 a 与 p 互质\begin{aligned} &\frac{\text{num}}{a} \equiv \text{num} \times a^{-1} (\text{mod } p)\\ &\text{其中 } a^{-1} = a^{p-2} \text{ 当且仅当 $a$ 与 $p$ 互质} \end{aligned}

对于任意 aa 的整数倍 tt,一定有下式成立:其中的 xx 就是整数 aa 的乘法逆元,记作 a1a^{-1}

tat×x(modp)1a1×x(modp)1a×x(modp)\begin{aligned}\frac{t}{a} \equiv t \times x\quad (\mod p) \\\frac{1}{a} \equiv 1 \times x\quad (\mod p) \\1 \equiv a \times x\quad (\mod p) \\\end{aligned}

费马小定理:对于两个互质的整数 g,hg,h 而言,一定有下式成立:

gh11(modh)g^{h-1} \equiv 1\quad (\mod h)

于是本题的推导就可以得到,当 aapp 互质时,有:

ap11(modp)a^{p-1} \equiv 1 \quad (\mod p)

于是 aa 的乘法逆元就是:

a1=ap2a^{-1} = a^{p-2}

时间复杂度 O(logp)O(\log p)

组合数

Cnk=C(n,k)=(nk)=n!k!(nk)!C_n^k = C(n, k) = \binom{n}{k} = \frac{n!}{k!(n-k)!}

Python3.8 库函数求解

如果使用 Python3.8 及以上的版本,则可以直接使用 math.comb(n, k) 来计算组合数 CnkC_n^k。时间复杂度为 O(min(k,nk))O(\min(k,n-k))

递推法求解

利用 Cnk=Cn1k+Cn1k1C_n^k = C_{n-1}^k + C_{n-1}^{k-1} 进行递推求解。

例题:https://www.acwing.com/problem/content/887/。求解 qqCnk % pC_{n}^k\ \%\ p 的结果,其中 q104,1kn2×103q\le 10^4,1\le k \le n \le 2\times 10^3pp 为常数 109+710^9+7

O(nk)O(nk) 预处理出所有的组合数,O(q)O(q) 查询 qq 次组合数。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
#include <iostream>
using namespace std;

const int N = 2000;
const int K = 2000;
const int P = 1e9 + 7;

int C[N + 1][K + 1];

int main() {
// O(nk) 预处理
for (int a = 0; a <= N; a++) {
for (int b = 0; b <= a; b++) {
if (b == 0) {
C[a][b] = 1;
} else {
C[a][b] = (C[a - 1][b] + C[a - 1][b - 1]) % P;
}
}
}

// O(1) 查询
int q;
cin >> q;
while (q--) {
int n, k;
cin >> n >> k;
cout << C[n][k] << "\n";
}

return 0;
}

乘法逆元法求解

如果题目中有取模运算,就可以将组合数公式中的「除法运算」转换为「关于逆元的乘法运算」进行求解。

例题:https://www.acwing.com/problem/content/888/。求解 qqCnk % pC_{n}^k\ \%\ p 的结果,其中 q104,1kn105q\le 10^4,1\le k \le n \le 10^5pp 为常数 109+710^9+7。此题中需要对组合数 CnkC_n^k 的计算结果模上常数 pp,由于此题的模数 ppn,kn,k 一定互质,因此才可以采用将除法转换为乘法逆元的预处理做法来求解。如果仍然采用上述递推法将会超时。

O(nlogp)O(n\log p) 预处理出所有的阶乘和乘法逆元,O(q)O(q) 查询 qq 次组合数。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
#include <iostream>

using namespace std;
using ll = long long;

const int N = 1e5;
const int P = 1e9 + 7;

int fact[N + 1]; // fact[i] 表示 i 的阶乘
int infact[N + 1]; // infact[i] 表示 i 的阶乘的逆元

int qmi(int a, int b, int p) {
int res = 1 % p;
while (b) {
if (b & 1) res = (ll) res * a % p;
a = (ll) a * a % p;
b >>= 1;
}
return res;
}

int main() {
// O(n log p) 预处理
fact[0] = 1, infact[0] = 1;
for (int a = 1; a <= N; a++) {
fact[a] = (ll) fact[a - 1] * a % P;
infact[a] = (ll) infact[a - 1] * qmi(a, P - 2, P) % P;
}

// O(1) 查询
int q;
cin >> q;
while (q--) {
int n, k;
cin >> n >> k;
cout << (ll) fact[n] * infact[k] % P * infact[n - k] % P << "\n";
}

return 0;
}

字符串

sstream

控制中间结果的运算精度

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
#include <iostream>
#include <iomanip>
#include <sstream>

using ll = long long;
using namespace std;

// 控制中间结果的运算精度
void solve() {
double x = 1.2345678;
cout << x << "\n"; // 输出 1.23457

stringstream ss;
ss << fixed << setprecision(3) << x;
cout << ss.str() << "\n"; // 输出 1.235
}

计算几何

浮点数默认输出 6 位,范围内的数据正常打印,最后一位四舍五入,范围外的数据未知。


a_template
https://blog.dwj601.cn/Algorithm/a_template/
作者
Mr_Dwj
发布于
2024年3月21日
更新于
2024年11月9日
许可协议