甲斐性のない男が機械学習とか最適化とかを綴るブログ

うどんくらいしか食べる気がしない

競プロとかに使うアルゴリズム実装メモ(二分探索、2次元累積和、しゃくとり法)

はじめに

アルゴリズムメモ第3段です。今回は二分探索法、2次元累積和、しゃくとり法と様々な問題に使える汎用的なアルゴリズムを書いていきます。
今回は勉強のため、アルゴリズムの本質的な部分を記述した抽象クラスと実際の問題を解く具象クラス(関数オブジェクト)を分け、テンプレートメソッドパターンもしくはストラテジパターンを用いてコーディングしました。

二分探索

用途

  • 広義単調増加関数または広義単調減少関数f(x)が条件を満たし、かつ、条件成立と不成立の境界となる点xを探索(条件成立と不成立の境界が1つのみ)
  • ソート済み配列から条件が成立する値が存在するかの判定

アルゴリズムのポイント

  1. 変数lを探索範囲の最小値、変数rを最大値で初期化
  2. mr,lの中点(例えば、m = (l + r)/2で求める)として、f(m)が条件を満たすか否かを判定
    • このとき、見つけたい点(条件成立と不成立の境界)がmより小さいか、大きいかはf(x)が広義単調増加(減少)関数なのでわかる
  3. 見つけたい点がm以上ならl = m、そうでなければr = mとする
  4. この2~3の操作をr - l \leq  \epsilon  を満たすまで繰り返す(\epsilonは収束判定基準の定数で、問題に応じて適切な値を定める)
  5. 収束したら、その時点のl,rから最終的な出力値を決める

抽象クラスの実装

上記アルゴリズムを実現する抽象クラスを作成し、見つけたい点が現時点以上かの判定、条件判定、中点計算、戻り値計算を純粋仮想関数とし、上記アルゴリズムをテンプレートメソッドパターンで実装。

#include <vector>

template <typename T>
class abstract_binary_search
{
private:
    virtual bool is_in_right(T x) = 0; //求めているのがxより大きいか
    virtual bool is_converge(T L, T R) = 0; //収束判定
    virtual T ret_val(T L, T R) = 0; //見つからなかった場合に返す値
    virtual T get_middle(T L, T R) = 0; //中央値を返す

protected:
    T search_rec(T L, T R);

public:
    virtual T search() = 0;
    virtual bool is_match(T x) = 0;  //それが求めている値か

    virtual ~abstract_binary_search();
};

template <typename T>
T abstract_binary_search<T>::search_rec(T L, T R) {
    if (is_converge(L, R) == true) {
        return ret_val(L, R);
    }

    T mid = get_middle(L, R);

    if (is_in_right(mid) == true) {
        return search_rec(mid, R);
    }

    return search_rec(L, mid);

}


template <typename T>
abstract_binary_search<T>::~abstract_binary_search()
{
}

具体的な問題を解く実装

Lower bownd

素数nの数列が与えられて、指定された値q以上の値が入っているindexを求める問題。もし数列内にq以上の値がなければnを返す。

#include <iostream>

class lower_bound :
    public abstract_binary_search<int>
{
private:
    std::vector<int> data_list;//検索する数列
    int N; //数列のサイズ
    int search_data;//検索対象データ
    virtual bool is_in_right(int x); //求めているのがxより大きいか
    virtual int ret_val(int L, int R); //見つからなかった場合に返す値
    virtual bool is_converge(int L, int R); //収束判定
    virtual int get_middle(int L, int R);
public:
    lower_bound(const std::vector<int>& data_list);
    void set_search_data(int a_search_data);
    virtual int search();
    virtual bool is_match(int x);  //それが求めている値か
    ~lower_bound();
};

bool lower_bound::is_in_right(int x) {
    return data_list[x] <= search_data;
}


int lower_bound::ret_val(int L, int R) {
    if (search_data <= data_list[L])
        return L;
    else
        return L + 1;
}

int lower_bound::get_middle(int L, int R) {
    return (L + R) / 2;
}

lower_bound::lower_bound(const std::vector<int>& data_list):
    data_list(data_list)
{
    N = data_list.size();
}

int lower_bound::search() {
    return search_rec(0, N);
}

bool lower_bound::is_match(int x) {
    if (x >= N) return false;
    return data_list[x] == search_data;
}

bool lower_bound::is_converge(int L, int R) {
    return R - L <= 1;
}

void lower_bound::set_search_data(int a_search_data) {
    search_data = a_search_data;
}

lower_bound::~lower_bound()
{
}

int main()
{

    int n, q;
    std::cin >> n;
    std::vector<int> S(n);
    
    for (int i = 0; i < n; i++) {
        int  s;
        std::cin >> s;
        S[i] = s;
    }
    std::cin >> q;

    int count = 0;
    
    lower_bound lbs(S);
    for (int i = 0; i < q; i++) {
        int  t;
        std::cin >> t;
        lbs.set_search_data(t);
        
        int res = lbs.search();
        std::cout << res << std::endl;
        if (lbs.is_match(res)) 
            count++;
    }
    std::cout << count << std :: endl;
    return 0;
}
POJ 1064:Cable master

蟻本に二分探索の例として載っている問題。問題はリンクを参照。xが切り取る紐の長さ、f(x)が作れる同じ長さ(長さ=x)の数だとすると、f(x)は広義単調減少関数なので二分探索法が使える。
つまり、f(x)が条件を満たしつつ、xが最大となる点を二分探索する(条件を満たさなければxを小さくし、満たすならまだ長くできると考えxを大きくする)。

#include <vector>
#include <algorithm>
#include <functional>
#include <iostream>
#include <iomanip>
#include <math.h>

class cable_master :
    public abstract_binary_search<double>
{
private:
    double tol;
    std::vector<double> len_list;
    int K;
    int N;
    bool is_condition(double x);
    virtual bool is_in_right(double x); //求めているのがxより大きいか
    //virtual bool is_in_left(double x); //求めているのがxより小さいか
    virtual double ret_val(double L, double R); //見つからなかった場合に返す値
    virtual double get_middle(double L, double R);
    virtual bool is_converge(double L, double R); //収束判定
public:
    cable_master(const std::vector<double>& len_list, int K, double tol);
    virtual double search();
    virtual bool is_match(double x);
    ~cable_master();
};

bool cable_master::is_condition(double x) {
    int count = 0;
    for (int i = 0; i < N; i++) 
        count += int(len_list[i] / x);

    if (count >= K)
        return true;

    return false;
}

bool cable_master::is_in_right(double x) {
    return is_condition(x);
}

double cable_master::ret_val(double L, double R) {
    return R;
}

double cable_master::get_middle(double L, double R) {
    return (L + R) / 2.0; 
}

bool cable_master::is_converge(double L, double R) {
    return R - L < tol;
}

cable_master::cable_master(const std::vector<double>& len_list, int K, double tol) :
    len_list(len_list),
    K(K),
    tol(tol)
{
    N = len_list.size();
}

double cable_master::search(){
    return search_rec(0.0, *std::max_element(len_list.begin(), len_list.end()));

}

bool cable_master::is_match(double x) {
    return false;
}

cable_master::~cable_master()
{
}



int main()
{

    std::vector<double> cable_list;
    int N, K;
    std::cin >> N >> K;

    for (int i = 0; i < N; i++) {
        double len;
        std::cin >> len;
        cable_list.push_back(len);
    }
    cable_master cm(cable_list, K, 0.0001);
    std::cout << std::fixed << std::setprecision(2) << floor(cm.search() * 100) / 100.0 << std::endl;
}

2次元累積和

用途

  • \displaystyle \sum_{i = p_x}^{q_x} \sum_{j = p_y}^{q_y}f(x_i, y_j)の計算

アルゴリズムのポイント

  1. あらかじめ、全てのn,mについて\displaystyle  S_{nm} = \sum_{i = 1}^{n} \sum_{j = 1}^{m} f(x_i, y_j)  を求めておく
    • これは、\displaystyle S_{nm} =f(x_n, y_m) + S_{n-1, m} + S_{n, m-1} - S_{n-1, m - 1}で効率的に計算できる
  2. 目的の値は\displaystyle  \sum_{i = p_x}^{q_x} \sum_{j = p_y}^{q_y}f(x_i, y_j) =  S_{q_x, q_y} - S_{q_x-1, q_y} - S_{q_x, q_y-1} + S_{q_x - 1, q_y-1}  で求まる

抽象クラスの実装

このアルゴリズムはストラテジーパターンで実装する。具体的には、f(x_i, y_j)を計算する関数オブジェクト受け取り、それに応じた部分和を求める。

#include <vector>
#include <functional>

template <typename T>
class func_sum
{
private:
    std::vector<std::vector<T>> sum_table;
protected:
    virtual T acc_sum(int i, int j);
public:
    func_sum(std::function<T(int, int)> func, int nx, int ny) ;
    ~func_sum();
    //目的の区間挿話を求める関数
    T calc_section_sum(int px, int qx, int py, int qy);
};

template <typename T>
T func_sum<T>::acc_sum(int i, int j) {
    return sum_table[i][j];
};

template <typename T>
func_sum<T>::~func_sum() {}

template <typename T>
func_sum<T>::func_sum(std::function<T(int, int)> func, int nx, int ny) {
    sum_table = std::vector<std::vector<T>>(nx + 1, std::vector<T>(ny + 1, T()));
    //j=0の場合の総和の計算
    for (int i = 1; i <= nx; i++) {
        sum_table[i][0] += func(i - 1, 0);
    }
    //i=0の場合の総和の計算
    for (int j = 1; j <= ny; j++) {
        sum_table[0][j] += func(0, j - 1);
    }
    //総和の計算
    for (int i = 1; i <= nx; i++) {
        for (int j = 1; j <= ny; j++) {
            sum_table[i][j] += func(i - 1, j - 1) +
                sum_table[i - 1][j] +
                sum_table[i][j - 1] -
                sum_table[i - 1][j - 1];

        }
    }
}

template <typename T>
T func_sum<T>::calc_section_sum(int px, int qx, int py, int qy) {
    return acc_sum(qx, qy) - 
           acc_sum(qx, py - 1) - 
           acc_sum(px - 1, qy) + 
           acc_sum(px - 1, py - 1);
}

具体的な問題を解く実装

\sum_{i = p_x}^{q_x} \sum_{j = p_y}^{q_y} \frac{x_i}{y_i}の計算

f(x_i, y_j)= \frac{x_i}{y_i}  として、上記アルゴリズムに関数オブジェクトを渡す。

int main()
{
    int N, M;
    int px, py, qx, qy;
    std::cin >> N >> M ;
    std::cin >> px >> qx >> py >> qy;
    std::vector<double> x_list(N, 0);
    std::vector<double> y_list(M, 0);
    for (int i = 0; i < N; i++) {
        double x;
        std::cin >> x;
        x_list[i] = x;
    }
    for (int i = 0; i < M; i++) {
        double y;
        std::cin >> y;
        y_list[i] = y;
    }
    auto func = [](int i, int j, 
                   const std::vector<double>& x_list,
                   const std::vector<double>& y_list) { return x_list[i] / y_list[j]; };

    std::function<double(int, int)>  func_curried = std::bind(func, 
                                                              std::placeholders::_1, 
                                                              std::placeholders::_2, 
                                                              x_list, 
                                                              y_list);

    func_sum<double> ds(func_curried, x_list.size(), y_list.size() );
    std::cout<<ds.calc_section_sum(px, qx, py, qy)<<std::endl;
    return 0;
}
ABC106 D:AtCoder Express 2

問題はリンクを参照。この問題はf(x_i, y_j)  x_i = L_i, y_i = R_iのときそのペア(L_i, R_i)の存在する個数を返す関数とすれば2次元累積和で解ける。

#include <iostream>
int main()
{
    int N, M, Q;
    std::cin >> N >> M >> Q;
    std::vector<std::vector<int>> table(N, std::vector<int>(N, 0));
    for (int i = 0; i < M; i++) {
        int L, R;
        std::cin >> L >> R;
        table[L - 1][R - 1]++;
    }

    auto func = [](int i, int j,
        const std::vector<std::vector<int>>& table) { return table[i][j]; };

    std::function<int(int, int)>  func_curried = std::bind(func,
        std::placeholders::_1,
        std::placeholders::_2,
        table);

    func_sum<int> ds(func_curried, N, N);

    for (int i = 0; i < Q; i++) {
        int p, q;
        std::cin >> p >> q;
        std::cout << ds.calc_section_sum(p, q, p, q) << std::endl;
    }
    return 0;
}

しゃくとり法

用途

  • 解きたい問題が以下のいずれかの性質を満たす時、条件C(l, r)を満たすl,rの数え上げ、もしくは評価関数f(l,r)の最大化・最小化
    • l_p \leq l\leq r \leq r_qとして、C(l, r)が条件を満たすならばC(l_p, r_q)も条件を満たす(以降、これを上位区間条件と呼ぶ)
      • 例えば、整数の数列x_1, x_2, \cdots, x_nの中からその和がS以上となる連続する部分数列を数え上げる問題
    • l \leq l_p \leq r_q \leq rとして、C(l, r)が条件を満たすならばC(l_p, r_q)も条件を満たす(以降、これを部分区間条件と呼ぶ)
      • 例えば、整数の数列x_1, x_2, \cdots, x_nの中からその和がS以下となる連続する部分数列を数え上げる問題

アルゴリズム

上位区間条件を満たす問題と部分区間条件を満たす問題とで若干異なる(うまいことやれば、共通にできるかもしれない)

  1. 条件を満たさない間rをインクリメント(右端まで来てインクリメントできない場合はそのまま)
  2. 条件を満たしたところで、インクリメントを止め、その時点のl, rを記録(この時点でlp = l, r \leq r_pを満たす(l_p, r_p)は全て条件を満たす)
  3. 条件が満たさなくなるまでlをインクリメント
  4. この1~3の操作をrがインクリメントできなくなり、かつ、条件C(l, r)を満たさなくなるまで繰り返す
  1. 条件を満たす間rをインクリメント(右端まで来てインクリメントできない場合はそのまま)
  2. 条件を満たさなくなったところで、インクリメントを止め、その時点のl, rを記録(この時点でlp = l, r_p \leq rを満たす(l_p, r_p)は全て条件を満たす)
  3. 条件を満たすまでlをインクリメント
  4. この1~3の操作をrがインクリメントできなくなり、かつ、lrと同じになるまで繰り返す

抽象クラスの実装

こちらもテンプレートメソッドパターンで実装する。上述の通り上位区間条件と部分区間条件を分けて考えるので、それぞれ抽象クラスを作る。

#include <functional>
 
//ある区間が条件を満たすならば、
//それを包含する上位区間も条件を満たす問題を解くクラス
template <typename T>
class abstract_inchworm_superset
{
private:
    virtual void move_left_pointer() = 0;//lを右に動かす処理
    virtual void move_right_pointer() = 0;//rを右に動かす処理
    virtual bool is_cond() = 0; //条件成立か否かを返す
    virtual bool is_right_end() = 0;//rが右端に到達したかを返す
    virtual std::pair<T, T> no_next() = 0;//次の要素がない場合に終了値を返す
    virtual std::pair<T, T> get_LR() = 0;//lとrを返す
public:
    virtual std::pair<T, T> next(); //条件を満たすデータを返し、ポインタを次に進める
};
 
template <typename T>
std::pair<T, T> abstract_inchworm_superset<T>::next(){
    while (is_cond() == false && is_right_end() == false) {
            move_right_pointer();
    }
    if (is_cond() == false)
        return no_next();
    std::pair<T, T> ret = get_LR();
    move_left_pointer();
    return ret;
}
 
//ある区間が条件を満たすならば、
//それが包含する部分区間も条件を満たす問題を解くクラス
template <typename T>
class abstract_inchworm_subset
{
private:
    virtual void move_left_pointer() = 0;//lを右にインクリメントする処理
    virtual void move_right_pointer() = 0;//rを右にインクリメントする処理
    virtual bool is_cond() = 0; //条件成立か否かを返す
    virtual bool is_right_end() = 0;//rが右端に到達したかを返す
    virtual std::pair<T, T> no_next() = 0;//次の要素がない場合に終了値を返す
    virtual std::pair<T, T> get_LR() = 0;//lとrを返す
    virtual bool is_left_eq_right() = 0;//lとrが一致しているかを返す
public:
    virtual std::pair<T, T> next(); //条件を満たすデータを返し、ポインタを次に進める
};
 
template <typename T>
std::pair<T, T> abstract_inchworm_subset<T>::next() {
    std::pair<T, T> ret;
    while (is_right_end() == false && is_cond() == true) {
        move_right_pointer();
    }
    ret = get_LR();
    if (is_left_eq_right() == true) {
        return no_next();
    }
    
    move_left_pointer();
    return ret;
}

具体的な問題を解く実装

POJ 3061:Subsequence

蟻本に例として載っている和がS以上になる部分数列を数え上げる問題。上位区間条件型のしゃくとり法で解ける。

#include <iostream>
#include <algorithm>
#include <vector>

class subsequence :
    public abstract_inchworm_superset<int>
{
private: 
    int Left, Right;
    long long S;
    int N;
    long long LRsum;
    std::vector<long long> sequence;
    virtual void move_left_pointer();
    virtual void move_right_pointer();
    virtual bool is_cond();
    virtual bool is_right_end();
    virtual std::pair<int, int> no_next();
    virtual std::pair<int, int> get_LR();
public:

    void set_S(long long S);
    subsequence(std::vector<long long>& sequence, int N, int start);
};

subsequence::subsequence(std::vector<long long>& sequence, int N, int start):
    sequence(sequence),
    N(N),
    Left(start),
    Right(start),
    S(0),
    LRsum(0)
{
}

void subsequence::set_S(long long S) {
    this->S = S;
}

void subsequence::move_left_pointer() {
    LRsum -= sequence[Left];
    Left++;
}

void subsequence::move_right_pointer() {
    if (Right < N) {
        LRsum += sequence[Right];
        Right++;
    }
}

bool subsequence::is_cond() {
    if (LRsum >= S)
        return true;
    return false;
}

bool subsequence::is_right_end() {
    if (Right >= N)
        return true;
    return false;
}

std::pair<int, int> subsequence::no_next() {
    return std::make_pair(-1, -1);
}

std::pair<int, int> subsequence::get_LR() {
    return std::make_pair(Left, Right);
}

int main()
{

    int N;
    std::cin >> N;
    for (int k = 0; k < N; k++) {
        std::vector<long long> sequence;
        int n;
        long long S;
        std::cin >> n >> S;
        for (int i = 0; i < n; i++) {
            long long a;
            std::cin >> a;
            sequence.push_back(a);
        }

        subsequence seq(sequence, sequence.size(), 0);
        seq.set_S(S);
        int min_len = INT_MAX;
        while (1) {
            std::pair<int, int> LR = seq.next();
            if (LR.first == -1) break;
            min_len = std::min(min_len, LR.second - LR.first);
        }
        if (min_len == INT_MAX)
            std::cout << 0 << std::endl;
        else
            std::cout << min_len << std::endl;
    }
    return 0;
}
ARC098 D:Xor Sum 2

部分区間条件型のしゃくとり法で解けると気付くのが難しい問題。気づいてしまえば、普通にしゃくとり法を適用するだけ(部分区間を数え上げる問題なので、何も考えずしゃくとり法を適用してACしちゃったって人も結構いるかも)。なぜ、しゃくとり法で解けるかは解説参照。
ここで注意するのは、関数is_cond()のtrueを返す条件が現在指しているrの更に一つ先のrにおいて、条件が成立している場合としている。つまり、right pointer(r)を動かしても条件成立するなら、実際にright pointerを動かすという処理になっている。なお、この実装では区間[l, r)という半区間で扱っているため、一つ先のrrが指示しているインデックスの値を足すこと(XORを取ること)に当たる。

#include <vector>
#include <stdio.h>
class xor_sum :
    public abstract_inchworm_subset<int>
{
private:
    int Left, Right;
    
    int N;
    long long LR_XOR;
    long long LRsum;
    std::vector<long long> sequence;
 
    virtual void move_left_pointer();
    virtual void move_right_pointer();
    virtual bool is_cond();
    virtual bool is_right_end();
    virtual std::pair<int, int> no_next();
    virtual std::pair<int, int> get_LR();
    virtual bool is_left_eq_right();
public:
 
    xor_sum(std::vector<long long>& sequence, int N, int start);
};
 
xor_sum::xor_sum(std::vector<long long>& sequence, int N, int start) :
    sequence(sequence),
    N(N),
    Left(start),
    Right(start),
    LR_XOR(0),
    LRsum(0)
{
}
 
 
void xor_sum::move_left_pointer() {
    LRsum -= sequence[Left];
    LR_XOR ^= sequence[Left];
    Left++;
 
}
 
void xor_sum::move_right_pointer() {
    if (Right < N) {
        LRsum += sequence[Right];
        LR_XOR ^= sequence[Right];
        Right++;
    }
}
 
bool xor_sum::is_cond() {
    if (Left >= N) 
        return false;
    //Rightが1つ先に行っても条件を満たすか
    if ((LRsum + sequence[Right]) == (LR_XOR ^ sequence[Right]))
        return true;
    return false;
}
 
bool xor_sum::is_right_end() {
    if (Right >= N)
        return true;
    return false;
}
 
std::pair<int, int> xor_sum::no_next() {
    return std::make_pair(-1, -1);
}
 
std::pair<int, int> xor_sum::get_LR() {
    return std::make_pair(Left, Right);
}
 
bool xor_sum::is_left_eq_right() {
    return (Left == Right);
}
 
 
int main()
{
    int N;
    std::vector<long long> A;
    scanf("%d", &N);
 
    for (int i = 0; i < N; i++) {
        long long a;
        scanf("%lld", &a);
        A.push_back(a);
    }
    xor_sum xs(A, N, 0);
    long long count = 0;
    while (1) {
        std::pair<int, int> LR = xs.next();
        if (LR.first == -1) break;
        count += LR.second - LR.first ;
    }
    printf("%lld\n", count);
    return 0;
}

おわりに

今回は勉強のため、本質的なアルゴリズムを記述した抽象クラスと実際の問題を解く具象クラスを分けましたが、実際のコンテスト時はそんなことせずに、普通に書いた方が速いと思います。