快速幂

快速幂,二进制取幂(Binary Exponentiation,也称平方法),是一个在 Θ(logn)\Theta(\log n) 的时间内计算 ana^n 的小技巧,而暴力的计算需要 Θ(n)\Theta(n) 的时间。而这个技巧也常常用在非计算的场景,因为它可以应用在任何具有结合律的运算中。其中显然的是它可以应用于模意义下取幂、矩阵幂等运算,我们接下来会讨论。

算法描述

计算 aann 次方表示将n个n乘在一起: an=a×a×an 个 aa^{n} = \underbrace{a \times a \cdots \times a}_{n\text{ 个 a}} 。然而当a、n太大的时侯,这种方法就不太适用了。不过我们知道: ab+c=abac,a2b=abab=(ab)2a^{b+c} = a^b \cdot a^c,\,\,a^{2b} = a^b \cdot a^b = (a^b)^2 。二进制取幂的想法是,我们将取幂的任务按照指数的 二进制表示 来分割成更小的任务。

首先我们将n表示为 2 进制,举一个例子:

313=3(1101)2=3834313^{13} = 3^{(1101)_2} = 3^8 \cdot 3^4 \cdot 3^1

因为n有 log2n+1\lfloor \log_2 n \rfloor + 1 个二进制位,因此当我们知道了 a1,a2,a4,a8,,a2log2na^1, a^2, a^4, a^8, \dots, a^{2^{\lfloor \log_2 n \rfloor}} 后,我们只用计算 Θ(logn)\Theta(\log n) 次乘法就可以计算出 ana^n

于是我们只需要知道一个快速的方法来计算上述 3 的 2k2^k 次幂的序列。这个问题很简单,因为序列中(除第一个)任意一个元素就是其前一个元素的平方。举一个例子:

31=332=(31)2=32=934=(32)2=92=8138=(34)2=812=6561\begin{align} 3^1 &= 3 \\ 3^2 &= \left(3^1\right)^2 = 3^2 = 9 \\ 3^4 &= \left(3^2\right)^2 = 9^2 = 81 \\ 3^8 &= \left(3^4\right)^2 = 81^2 = 6561 \end{align}

因此为了计算 3133^{13} ,我们只需要将对应二进制位为 1 的整系数幂乘起来就行了:

313=6561813=15943233^{13} = 6561 \cdot 81 \cdot 3 = 1594323

将上述过程说得形式化一些,如果把n写作二进制为 (ntnt1n1n0)2(n_tn_{t-1}\cdots n_1n_0)_2 ,那么有:

n=nt2t+nt12t1+nt22t2++n121+n020n = n_t2^t + n_{t-1}2^{t-1} + n_{t-2}2^{t-2} + \cdots + n_12^1 + n_02^0

其中 ni0,1n_i\in{0,1} 。那么就有

an=(ant2t++n020)=an020×an121××ant2t\begin{aligned} a^n & = (a^{n_t 2^t + \cdots + n_0 2^0})\\\\ & = a^{n_0 2^0} \times a^{n_1 2^1}\times \cdots \times a^{n_t2^t} \end{aligned}

根据上式我们发现,原问题被我们转化成了形式相同的子问题的乘积,并且我们可以在常数时间内从 2i2^i 项推出 2i+12^{i+1} 项。

这个算法的复杂度是 Θ(logn)\Theta(\log n) 的,我们计算了 Θ(logn)\Theta(\log n)2k2^k 次幂的数,然后花费 Θ(logn)\Theta(\log n) 的时间选择二进制为 1 对应的幂来相乘。

代码实现

首先我们可以直接按照上述递归方法实现:

long long binpow(long long a, long long b) {
  if (b == 0) return 1;
  long long res = binpow(a, b / 2);
  if (b % 2)
    return res * res * a;
  else
    return res * res;
}

第二种实现方法是非递归式的。它在循环的过程中将二进制位为 1 时对应的幂累乘到答案中。尽管两者的理论复杂度是相同的,但第二种在实践过程中的速度是比第一种更快的,因为递归会花费一定的开销。

long long binpow(long long a, long long b) {
  long long res = 1;
  while (b > 0) {
    if (b & 1) res = res * a;
    a = a * a;
    b >>= 1;
  }
  return res;
}

例题

POJ 1001

题目

Description

Problems involving the computation of exact values of very large magnitude and precision are common. For example, the computation of the national debt is a taxing experience for many computer systems. This problem requires that you write a program to compute the exact value of Rn where R is a real number ( 0.0 < R < 99.999 ) and n is an integer such that 0 < n <= 25.

Input

The input will consist of a set of pairs of values for R and n. The R value will occupy columns 1 through 6, and the n value will be in columns 8 and 9.

Output

The output will consist of one line for each line of input giving the exact value of R^n. Leading zeros should be suppressed in the output. Insignificant trailing zeros must not be printed. Don't print the decimal point if the result is an integer.

Sample Input

95.123 12
0.4321 20
5.1234 15
6.7592  9
98.999 10
1.0100 12

Sample Output

548815620517731830194541.899025343415715973535967221869852721
.00000005148554641076956121994511276767154838481760200726351203835429763013462401
43992025569.928573701266488041146654993318703707511666295476720493953024
29448126.764121021618164430206909037173276672
90429072743629540498.107596019456651774561044010001
1.126825030131969720661201

解答

#include <iostream>
#include <string>
#include <algorithm>

//#define DEBUG 

using namespace std;


void plusString(string& str1, string str2) {
    
    reverse(str1.begin(), str1.end());
    reverse(str2.begin(), str2.end());

    str1.size() < str2.size() ? str1.resize(str2.size(), '0') : str2.resize(str1.size(), '0') ;
    
    int add = 0;
    for(int i = 0; i < str1.size(); ++i) {
        int out = str1[i] - '0' + str2[i] - '0' + add;
        str1[i] = out % 10 + '0';
        add = out / 10;
    }

    if(add != 0) str1.push_back(add + '0');

    reverse(str1.begin(), str1.end());
    return;
}


string format(string raw, int bits, bool is_integer) {
    if(is_integer) {
        for(;;) {
            if(raw[0] == '0') raw.erase(raw.begin());
            else break;
        }
    } else {
        raw.insert(raw.begin() + (raw.size() - bits), '.');
        for(;;) {
            if(raw[0] == '0') raw.erase(raw.begin());
            else break;
        }
        for(;;) {
            if(raw[raw.size() - 1] == '0') raw.erase(raw.end() - 1);
            else break;
        }
        if(raw[raw.size() - 1] == '.') raw.erase(raw.end() - 1);
    }
    if(raw == "") return "0";
    else return raw;
}


struct BigNum {
    string num_str;

    BigNum() : num_str("") {}

    BigNum(string num_str) : num_str(num_str) {}

    int size() const {
        return this->num_str.size();
    }

    char operator[] (int i) const {
        return this->num_str[i];
    }
};


BigNum operator* (const BigNum& num1, const BigNum& num2) {
    string ans("0");

    for(int i = num1.size() - 1; i >= 0; --i) {
        string temp;
        int add = 0;

        temp.resize(num1.size() - 1 - i, '0');

        for(int j = num2.size() - 1; j >= 0; --j) {
            int out = (num1[i] - '0') * (num2[j] - '0') + add;
            temp.push_back(out % 10 + '0');
            add = out / 10;
        }

        if(add != 0) temp.push_back(add + '0');

        reverse(temp.begin(), temp.end());
#ifdef DEBUG
        cout << "ans " << ans << " temp " << temp << endl;
#endif
        plusString(ans, temp);
    }

    return BigNum(ans);
}


int main() {
#ifdef DEBUG
    string a, b;
    cin >> a >> b;
    cout << (BigNum(a) * BigNum(b)).num_str << endl << endl;
#endif
    string R; int n;

    while(cin >> R >> n) {
        BigNum ori_sum("1");
        BigNum base; int mag;

#ifdef DEBUG
        cout << endl << "Handle " << R << " " << n << endl;
#endif
        bool is_integer;
        if(R.find('.') != R.npos) {
            is_integer = false;
            mag = R.size() - 1 - R.find('.');
            R.erase(R.begin() + R.find('.'));
            base.num_str = R;
        } else {
            is_integer = true;
            mag = 0;
            base.num_str = R;
        }

#ifdef DEBUG
        cout << "Filter " << base.num_str << " " << mag << endl;
#endif

        int n_backup = n;
        while(n > 0) {
#ifdef DEBUG
            cout << "ori_sum " << ori_sum.num_str << " base " << base.num_str << endl;
#endif
            if(n & 1) ori_sum = ori_sum * base;
            base = base * base;
            n >>= 1;
        }

        cout << format(ori_sum.num_str, mag*n_backup, is_integer) << endl;
    }

    return 0;
}

最后更新于