算法竞赛进阶指南0x0101

hicancan的算法笔记0x0101

洛谷P1226【模板】快速幂

题目描述

给你三个整数 a,b,pa,b,p,求 abmodpa^b \bmod p

输入格式

输入只有一行三个整数,分别代表 a,b,pa,b,p

输出格式

输出一行一个字符串 a^b mod p=s,其中 a,b,pa,b,p 分别为题目给定的值, ss 为运算结果。

样例

样例输入

2 10 9

样例输出

2^10 mod 9=7

提示

样例解释

210=10242^{10} = 10241024mod9=71024 \bmod 9 = 7

数据规模与约定

对于 100%100\% 的数据,保证 0a,b<2310\le a,b < 2^{31}a+b>0a+b>02p<2312 \leq p \lt 2^{31}

题目分析

思路一:直接计算

直接计算 abmodpa^b \bmod p,然后输出结果。
代码如下:

#include <iostream>
using namespace std;
int main()
{
    long long a, b, p;
    cin >> a >> b >> p;
    long long result = 1;
    for (int i = 0; i < b; i++)//一个一个乘
        result = (result * a) % p;
    cout << a << "^" << b << " mod " << p << "=" << result << endl;
    return 0;
}

测试点结果:
202411032316711.png
最后三个测试点TLE了

分析原因:

一个一个乘以a,时间复杂度为:O(b)O(b),当b很大的时候,时间复杂度会很大,因此我们需要优化。

没有利用到中间结果,比如如果算到了a100a^{100},完全可以自乘平方,直接得到a200a^{200}

从下至上的想的话,从头开始就自乘:也就是知道了a,就能知道a2a^{2}a4a^{4}a8a^{8}a16a^{16},最终知道:a2na^{2^n}

但是:很遗憾2n2^{n}不一定等于bb,所以我们需要找到最接近bb2n2^{n},也就是2[log2(b)]2^{\left[log2(b)\right]},然后剩下的指数(b2[log2(b)])(b-2^{\left[log2(b)\right]})一个一个乘。
由此产生了思路二:

思路二:自乘平方递归

自乘平方并用2[log2(b)]2^{\left[log2(b)\right]}找到不超过b的最大2的幂次的指数

代码实现:

#include <iostream>
#include <cmath>
using namespace std;
int main()
{
    long long a, b, p;
    cin >> a >> b >> p;
    long long result = a;
    int mindex = floor(log2(b)); // 求出不大于b的最大2次幂:mindex
    for (int index = 1; index <= mindex; index++) // 自乘平方直至指数到2的mindex次方
        result = (result * result) % p;
    for (int i = 0; i < b - (1 << mindex); i++) // 剩下的指数一个一个乘
        result = (result * a) % p;
    cout << a << "^" << b << " mod " << p << "=" << result << endl;
    return 0;
}

测试点结果:

202411032316085.png

仍然有俩个测试点TLE了

分析原因:
剩下的指数还是一个一个乘,而剩下的指数个数为:$$2^{\left[log2(b)\right]}$$
如果b比较大,那么比b小的2的最大幂次与b差距的可能的差距的最大值就会越来越大,因此我们利用递归的思想,将剩下的指数作为新的b再进行自乘平方,递归求解。

优化:递归剩下的指数

#include <iostream>
#include <cmath>
using namespace std;
// 自乘递归求幂
long long quickpow(long long a, long long b, long long p)
{
    if (b == 0)
        return 1;//递归出口:指数为0,返回1
    long long result = a;
    int mindex = floor(log2(b));
    for (int index = 1; index <= mindex; index++)
        result = (result * result) % p;
    b = b - (1 << mindex);//更新b,准备递归
    return (result * quickpow(a, b, p)) % p;
}
int main()
{
    long long a, b, p;
    cin >> a >> b >> p;
    cout << a << "^" << b << " mod " << p << "=" << quickpow(a, b, p) << endl;
    return 0;
}

测试点结果:
202411032317333.png
成功AC了!
不过用了log2()log_{2}()函数,可能还能对log2()log_{2}()函数进行改变,比如使用位运算去计算不超过b的最大2的幂次,也就是找到b的二进制最高位。

尝试:位运算找b二进制最高位,替代cmath库中log2函数

#include <iostream>
using namespace std;
// 自乘递归求幂
long long quickpow(long long a, long long b, long long p)
{
    if (b == 0)
        return 1; // 递归出口:指数为0,返回1
    long long result = a;
    int mindex = 0;
    while ((b >> mindex) != 1)
        mindex++; // 逐步右移直到等于1
    for (int index = 1; index <= mindex; index++)
        result = (result * result) % p;
    b = b - (1 << mindex); // 更新b,准备递归
    return (result * quickpow(a, b, p)) % p;
}
int main()
{
    long long a, b, p;
    cin >> a >> b >> p;
    cout << a << "^" << b << " mod " << p << "=" << quickpow(a, b, p) << endl;
    return 0;
}

测试点结果:
202411032317766.png

突发奇想:从找2[log2(b)]2^{\left[log2(b)\right]},转换成不大于b的最大2的幂次呢?

#include <iostream>
using namespace std;
// 求不大于b的最大2的幂次的快速算法(按位或运算快速1的扩散)
int maxpowerof2(long long x)
{
                         // 0010 1100 0000 0000 0000 0000 0000 0000 0000 0001
    x = x | (x >> 1);    // 0011 1110 0000 0000 0000 0000 0000 0000 0000 0000
    x = x | (x >> 2);    // 0011 1111 1000 0000 0000 0000 0000 0000 0000 0000
    x = x | (x >> 4);    // 0011 1111 1111 1000 0000 0000 0000 0000 0000 0000
    x = x | (x >> 8);    // 0011 1111 1111 1111 1111 1000 0000 0000 0000 0000
    x = x | (x >> 16);   // 0011 1111 1111 1111 1111 1111 1111 1111 1111 1111
    x = x | (x >> 32);   // 0011 1111 1111 1111 1111 1111 1111 1111 1111 1111
    return (x + 1) >> 1; // 0100 0000 0000 0000 0000 0000 0000 0000 0000 0000
                         // 0010 0000 0000 0000 0000 0000 0000 0000 0000 0000
}
// 自乘递归求幂
long long quickpow(long long a, long long b, long long p)
{
    if (b == 0)
        return 1; // 递归出口:指数为0,返回1
    long long result = a, btemp = b;
    int mindex = maxpowerof2(btemp);
    for (int index = 2; index <= mindex; index <<= 1)
        result = (result * result) % p;
    b = b - mindex; // 更新b,准备递归
    return (result * quickpow(a, b, p)) % p;
}
int main()
{
    long long a, b, p;
    cin >> a >> b >> p;
    cout << a << "^" << b << " mod " << p << "=" << quickpow(a, b, p) << endl;
    return 0;
}

测试点结果:
202411032318330.png
结果居然有一道题TLE了,说明存在不够快的问题,不过这个算法的思想还是值得学习的,比如利用按位或运算进行1的扩散快速求不大于b的最大2的幂次。
不够快的原因是什么呢?试验:将quickpow部分改成如下代码测试:

long long quickpow(long long a, long long b, long long p)
{
    if (b == 0)
        return 1; // 递归出口:指数为0,返回1
    long long result = a, btemp = b;
    int mindex = log2(maxpowerof2(btemp));
    for (int index = 1; index <= mindex; index++)
        result = (result * result) % p;
    b = b - (1 << mindex); // 更新b,准备递归
    return (result * quickpow(a, b, p)) % p;
}

测试点结果:
202411032318720.png
居然都AC了,说明不是maxpowerof2的算法不够快,而是每次index <<= 1相比于index++慢,多次循环后导致超时。

思路二小结

重新观察思路二发现主要的算法优化集中于剩下的b2[log2(b)]b-2^{\left[log2(b)\right]}个的指数上,然后还发现算法的优化与二进制以及位运算息息相关,于是思考能否不让指数剩下,也就是把指数b用aaa2a^{2}a4a^{4}a8a^{8}a16a^{16},……,a2na^{2^n} 充分的表示出来,而每一个都可以选或者不选对应着 a1×2na^{1\times 2^n} 或者 a0×2n=1a^{0\times 2^n}=1(1在乘积中能维持式子值不变对应着不选该项),那么如何把b用这2种项表示出来呢?思路三呼之欲出。

思路三:二进制表示指数b,以拆分aba^{b}a(01)×2na^{(0或1)\times{2^n}}的乘积

#include <iostream>
using namespace std;
long long quickpow(long long a, long long b, long long p)
{
    long long result = 1;
    while (b > 0)
    {
        if (b & 1)
            result = (result * a) % p; // 如果b的当前位为1,则乘上a
        a = (a * a) % p; // a自乘,更换乘积因子
        b >>= 1; // b右移一位,准备下一次判断 b & 1
    }
    return result;
}
int main()
{
    long long a, b, p;
    cin >> a >> b >> p;
    cout << a << "^" << b << " mod " << p << "=" << quickpow(a, b, p) << endl;
    return 0;
}

测试点结果:
202411032318949.png
成功AC!不禁感叹,没有递归处理剩余的b2[log2(b)]b-2^{\left[log2(b) \right]}个指数的思路是多么的巧妙!究其原因是进制的唯一性,能够恰巧的将b精准的表示成a1×2na^{1\times 2^n} 或者 a0×2n=1a^{0\times 2^n}=1的乘积,从而避免了递归处理剩余的b2[log2(b)]b-2^{\left[log2(b) \right]}个指数。

从大思路来讲,我们从自下而上的乘的观念转变成了将b如何表示成这些乘的自上而下的视角,那么如果我们一开始没有一眼看到这种二进制的表示方法,那么我们将如何把b拆分呢?正过来是自乘平方,倒过来就是开方,换句话说是指数除以2,于是思路四呼之欲出。

思路四:指数除以2,2分递归

如果b为偶数,ab=ab/2×ab/2a^{b}=a^{b/2}\times a^{b/2},如果b为奇数,则再乘一个a,递归处理ab/2a^{b/2},直到b=0b=0,返回1,结束递归。当然也可以写成奇偶统一的形式:ab=ab/2×ab/2×ab%2a^{b}=a^{b/2}\times a^{b/2}\times a^{b\%2}。代码实现如下:

#include <iostream>
using namespace std;
long long quickpow(long long a, long long b, long long p)
{
    if (b == 0)
        return 1; // 递归出口:指数为0,返回1
    long long result = quickpow(a, b / 2, p); // 递归处理指数b/2
    result = (result * result) % p; // 指数为偶数
    if (b & 1) //位运算判断b是否为奇数
        result = (result * a) % p; // 指数为奇数,再乘一个a
    return result;
}
int main()
{
    long long a, b, p;
    cin >> a >> b >> p;
    cout << a << "^" << b << " mod " << p << "=" << quickpow(a, b, p) << endl;
    return 0;
}

测试点结果:
202411032318894.png
成功AC!看来自上而下的想法,即指数除以2,也是可行的,而且代码实现起来也很简单,不过这种思路的缺点是递归的深度太深:本质原因是未能一眼看出b如何充分乘积表示,而是用递归只看了一步二分表示,从而导致递归深度太深。那就多分一点,减少递归深度,于是思路五呼之欲出。

思路五:2分可以递归,3分4分呢?

既然2分可以递归,那么3分也可以递归,甚至n分也可以递归,然后递归处理ab/na^{b/n},直到b=0b=0,返回1,结束递归。代码实现如下:

#include <iostream>
using namespace std;
long long quickpow(long long a, long long b, long long p)
{
    if (b == 0)
        return 1;// 递归出口:指数为0,返回1
    long long result = quickpow(a, b / 3, p); // 递归处理指数b/3
    result = ((result * result) % p * result) % p;
    if (b % 3 == 1)
        result = (result * a) % p;
    if (b % 3 == 2)
        result = ((result * a) % p * a) % p;//小心每乘一次都要取模不然会wa
    return result;
}
int main()
{
    long long a, b, p;
    cin >> a >> b >> p;
    cout << a << "^" << b << " mod " << p << "=" << quickpow(a, b, p) << endl;
    return 0;
}

测试点结果
202411032318305.png
发现,三分比二分还快!原因就在于,三分的递归深度比二分浅,但是每次递归的计算量比二分多,在这两个方面的博弈下,三分比二分快一点点。当然,如果分更多,比如四分,那么递归深度会更浅,但是每次递归的计算量也会更多,所以分多少,我们还需要分析。
类似的四分的结果:
202411032318643.png
都AC了,而且比三分还快!说明在n分的n比较小的情况下,n越大,递归深度变浅的影响大于每次递归计算量变多的影响。

思路六:n分呢?对n进行试验

写成n分的代码如下:

#include <iostream>
using namespace std;
int n = 5;
long long quickpow(long long a, long long b, long long p)
{
    if (b == 0)
        return 1;// 递归出口:指数为0,返回1
    long long temp = quickpow(a, b / n, p); // 递归处理指数b/n
    long long result = 1;
    for (int i = 0; i < n; i++)
        result = (result * temp) % p;
    for (int i = 0; i < b % n; i++)
        result = (result * a) % p; // 处理指数b%n
    return result;
}
int main()
{
    long long a, b, p;
    cin >> a >> b >> p;
    cout << a << "^" << b << " mod " << p << "=" << quickpow(a, b, p) << endl;
    return 0;
}

测试点结果:
n=5
202411032319982.png
n=10
202411032319810.png
n=50
202411032319978.png
n=100
202411032320704.png
n=500
202411032321786.png
我们可以看到100分的表现还不错,但到500分的时候表现就不是很好了,说明分太多,递归深度变浅,但是每次递归计算量变多的影响也变大,导致整体效率下降。所以,分多少的最佳值n能不能提前预判呢?答案是肯定的,我们可以通过计算递归深度和每次递归计算量来预判n的最佳值。

n分递归的理论分析:最佳值n的提前预判计算

递归深度:floor(logn(b))floor(log_n(b))

每次递归计算量:nn

总的时间复杂度:O(nlogn(b))O(n\cdot log_n(b))
bnln(n)b\cdot\frac {n}{ln(n)}

因此希望时间复杂度最小,也就是求nln(n)\frac {n}{ln(n)}的最小值

因此求导得ln(n)1ln2(n)\frac {ln(n)-1}{ln^2(n)},令其为0,解得n=e,即n=2.718281828459045,取整为3,所以n的最佳理论值为3。从上面的测试结果来看,n=3时,效率确实相对较高。因此打破对2分的固有认知,事实上3分、4分效率比二分更高。

而因为联想到2分对应着二进制,3分对应着三进制,4分对应着四进制,所以n分对应着n进制,因此,我们可以将快速幂的指数进制表示算法扩展到任意进制的快速幂算法,即任意进制快速幂算法。

而且我们大胆推测,3进制的快速幂算法,应该比2进制更快,不妨试试。

思路七:3进制快速幂,看看是否比2进制快

任意进制快速幂算法的思路是,将指数b转换为n进制表示,然后从低位到高位,依次处理指数的每一位,计算abini1a^{b_i\cdot n^{i-1}},然后将这些结果相乘,得到aba^b。如果n=3,代码实现如下:

#include <iostream>
using namespace std;

long long quickpow(long long a, long long b, long long p)
{
    long long result = 1;
    while (b > 0)
    {
        if (b % 3 == 1) // 检查b的当前位(三进制下的个位)
            result = (result * a) % p;
        if (b % 3 == 2)
            result = ((result * a) % p * a % p) % p;
        a = ((a * a) % p * a) % p; // a=a*a*a,更换乘积因子
        b /= 3;                    // b除以3,准备下一次判断 b % 3
    }
    return result;
}
int main()
{
    long long a, b, p;
    cin >> a >> b >> p;
    cout << a << "^" << b << " mod " << p << "=" << quickpow(a, b, p) << endl;
    return 0;
}

测试点结果
202411032321718.png

成功的AC,而且和三分的效率相当,并且都比2进制或者二分快。这也正好符合我们刚刚的n分递归最佳n的理论分析。

思路八:n进制快速幂呢?不妨对n进行实验!

n进制快速幂代码如下:

#include <iostream>
using namespace std;
int n = 4; 
long long quickpow(long long a, long long b, long long p)
{
    long long result = 1;
    while (b > 0)
    {
        for (int i = 0; i < b % n; i++)//提取b的当前位(n进制下的个位)
            result = (result * a) % p;
        long long temp = a;
        for (int i = 0; i < n-1; i++)
            a = (a * temp) % p;//自乘n-1次,更换乘积因子
        b /= n;
    }
    return result;
}
int main()
{
    long long a, b, p;
    cin >> a >> b >> p;
    cout << a << "^" << b << " mod " << p << "=" << quickpow(a, b, p) << endl;
    return 0;
}

测试点结果:
n=2
202411032322885.png
n=3
202411032322150.png
n=4
202411032322046.png
n=5
202411032322635.png
不难发现,实验结果依然是n=4最快,与递归中4分递归也是最快的是一致的。
因此又一次论证了2不是最快的,理论和实践证明3比2快,实践证明4甚至更快!

总结

本质上,一共只有2个大思路:

  1. 第一个是自下而上的想法,先自乘平方,再对剩下的递归,由于没有把b精准的分解,而是有剩下的指数,因此代码的实现度较低。
  2. 第二个是自上而下的想法,而这个大思路中又分为两个小思路:
    1. n分递归法,运用递归我们不需要一步看清b的最终分解方式,只需要对指数n分递归即可,实现度较高。
    2. n进制快速幂法,运用进制转换,我们一步看清了b的最终分解方式,因此代码更简洁效率也高,思维度也较高。
      最终我们通过试验结合对复杂度O(nlogn(b))O(n\cdot log_n(b))的最值分析,证明了3、4分递归和或者3、4进制快速幂具有更快的速度,但是由于计算机中的数字是以2进制存储的,因此2进制快速幂的代码实现度最高更有利于用位运算操作,代码更简洁,因此常规的模版便是2进制快速幂!

算法竞赛进阶指南0x0101
https://hicancan.cn/2024/11/07/0x0101/
作者
hicancan
发布于
2024年11月7日
更新于
2024年11月16日
许可协议