甲斐性のない男が機械学習とか最適化とかを綴るブログ

うどんくらいしか食べる気がしない

競プロとかに使うアルゴリズム実装メモ(幅優先・深さ優先探索、union-find、最小全域木)

はじめに

以前の記事にて最短経路問題を解くアルゴリズムの実装を書きましたが、今回はその続きとしてグラフアルゴリズムの中でも幅優先探索深さ優先探索、union-find、最小全域木問題を解くアルゴリズム2種について実装を書いていきます。
例によって、自分の脳内整理メモの側面が強い上、検索すればそれぞれのアルゴリズムについて素晴らしい解説記事が出てくるので、まじめな解説を見たい場合はそちらを当たった方が良いと思います。

幅優先探索

用途

  • コストが全て同じグラフ(重みなしグラフ)の最短経路探索。
  • 連結成分の列挙。
  • 重みなしグラフの最小全域木

アルゴリズムのポイント

  • 隣接する頂点から順々に探索していく。
  • 距離を求める場合は、遷移元の距離に+1をする。

実装のポイント

  • キューを使う。
  • まずスタートの頂点をキューに入れ、キューの先頭から頂点を取り出し、それに隣接する頂点を探索。
  • 未探索の全て頂点はキューに追加して、またキューの先頭から頂点を取り出して・・・ということを目的の頂点が見つかるか、キューが空になるまで繰り返していく。
  • 既に探索済みの頂点かというフラグを持たせておき、既に見た頂点だった場合はキューに追加しない(下記実装では距離が初期値-1でなければ探索済みと判定)。

実装

以前と同様、メイン関数は省略。

#include <unordered_map>
#include <vector>
#include <queue>

class breadth_first_search
{
private:
    int node_num;
    int abs_time;
    std::unordered_map<int, std::vector<int>> adj_list;
    std::vector<int> dists;//始点からの距離

public:
    breadth_first_search(const int node_num,
        const std::unordered_map
        <int, std::vector<int>>& adj_list);
    void exec_search(int start);
    int get_dist(int v);
};

breadth_first_search::breadth_first_search(const int node_num,
    const std::unordered_map
    <int, std::vector<int>>& adj_list) :
    node_num(node_num),
    adj_list(adj_list),
    dists(node_num + 1, -1)
{
}

void breadth_first_search::exec_search(int start) {
    std::queue<int> v_queue;
    v_queue.push(start);
    dists[start] = 0;
    int d = 0;
    while (1) {
        if (v_queue.empty() == true)
            break;
        int v = v_queue.front();
        v_queue.pop();
       
        for (int &adj : adj_list[v]) {
            if (dists[adj] == -1) {
                dists[adj] = dists[v] + 1;
                v_queue.push(adj);
            }
        }
    }
}

int breadth_first_search::get_dist(int v) {
    return dists[v];
}

AOJのALDS1_11_Cが通ることを確認。

深さ優先探索再帰呼び出し版)

用途

  • コストが全て同じグラフ(重みなしグラフ)において、与えられたコスト以内でたどり着ける頂点の列挙。
  • 現実における迷路探索(現実だと幅優先探索は分身かワープでもできないと無理)。
  • 深さを時系列的な情報と関連づけて探索するケース。

(基本、幅優先探索と同じことができるが、それぞれ解きたい問題に応じて使い分ける)

アルゴリズムのポイント

  • 行けるところまでどんどん深くまで探索していく。
  • これ以上探索できないというところまで探索したら根まで戻り、更に別なルートの探索を進める。

実装のポイント

  • 幅優先探索と同じノリでスタックを使うパターンと再帰呼び出しを使うパターンがある(今回の実装は再帰版)。
  • 再帰呼び出しで行う場合、分岐して最深部まで探索した後、分岐点まで戻って来たときに既に探索済みか否かという情報をクリアしたいケースの実装が楽になる(上記の深さと時系列情報を関連付ける場合、分岐点に戻る=時系列的にも戻る=探索済み情報をクリアする、とするべき問題もある。これとか)。
  • 今回の実装では、AOJの問題に対応するべく、探索で移動したらカウントアップ(後戻り時もカウントアップ)するabs_time変数を用意し、exec_search_sub関数に入った時と出る時にカウントアップさせている。また、find_abs_timesには各頂点を見つけたときのabs_timeを保持しておく。
  • find_abs_timeが初期値(-1)でなければ、確定頂点として探索しない。
  • このタイムスタンプのつけ方は、問題に応じて適宜カウントアップするタイミング、回数、場合によっては初期値にクリアする等の変更が必要。

実装

#include <unordered_map>
#include <vector>
class depth_first_search
{
private:
    int node_num;
    int abs_time;
    std::unordered_map<int, std::vector<int>> adj_list;
    std::vector<int> depths;//ルートからの深さ(距離ではない)
    std::vector<int> find_abs_times;//頂点に到達した時刻
    std::vector<int> finish_abs_times;//サーチが完了した時刻
    
    void exec_search_sub(int start, int depth);

public:
    depth_first_search(const int node_num, 
                       const std::unordered_map
                        <int, std::vector<int>>& adj_list);
    
    void exec_search(int start);
    int get_depth(int v);
    int get_find_abs_time(int v);
    int get_finish_abs_time(int v);
};

depth_first_search::depth_first_search(const int node_num, 
                                       const std::unordered_map
                                        <int, std::vector<int>>& adj_list):
    node_num(node_num),
    abs_time(0),
    adj_list(adj_list),
    depths(node_num + 1, -1),
    find_abs_times(node_num + 1, -1),
    finish_abs_times(node_num + 1, -1)
{
}

void depth_first_search::exec_search_sub(int start, int depth) {
    
    if (find_abs_times[start] != -1) return;
    depths[start] = depth;
    abs_time++;//探索進みのカウントアップ
    find_abs_times[start] = abs_time;
    for (auto &v : adj_list[start]) {
        exec_search_sub(v, depth + 1);         
    }
    abs_time++;//探索後戻りのカウントアップ
    finish_abs_times[start] = abs_time;
}

void depth_first_search::exec_search(int start) {
    for (int i = start; i < node_num + 1; i++) {
        exec_search_sub(i, 0);
    }
}

int depth_first_search::get_depth(int v) {
    return depths[v];
}
int depth_first_search::get_find_abs_time(int v) {
    return find_abs_times[v];
}

int depth_first_search::get_finish_abs_time(int v) {
    return finish_abs_times[v];
}

AOJのALDS1_11_Bが通ることを確認。

union-find

用途

  • 集合を扱うデータ構造。
  • 値xとyが同じ集合に属するか否かの判定や2つの集合の結合が高速で実行可能。
  • 後に説明する最小全域木問題を解くアルゴリズムであるクラスカル法でも使われる。

アルゴリズムのポイント

  • 各集合、木構造でデータを保持する。
  • 値xとyが同じ集合か否かを判断するには(以下、find操作)、木の根を見て同じだったら同じ集合、異なれば別の集合に属すると判断すればよい。
  • そのため、なるべく木の高さが低くなるようにデータを保持した方が、高速に判断できるようになる。
  • このことから、2つの集合を結合する操作(以下、unite操作)では木の高さが低い方の根の親を高い方の根になるように連結させる。こうすれば、連結後の木の高さは元の集合のうち高い方の高さと同じになるので、高さは増えない(両方の木の高さが同じ場合は、木の高さが1だけ増える)。

実装のポイント

  • 木は親ノードの値を保持する1次元配列と所属する木の高さを保持する1次元配列で表現する。
  • こうすれば、unite操作時、低い方の木の根の親を書きかえるように親を保持する配列を操作だけで良くなる。
  • find操作時に全てのノードを直接親ノードとつなぐことで木の縮約を行う。
  • 木の縮約時に木の高さの更新を行わないので実際の木の高さと配列に保持している高さが合わなくなるが、あまり気にしなくてよい(重要なのはノードの親子関係であり、それさえ合っていればunion-findは問題なく動作する。木の高さの配列に入っている値がおかしいとunite操作時に、接続する関係が逆になり木が高くなる恐れがあるが、それよりもfind操作で縮約する方がメリットが大きい)。

実装

#include <vector>

class union_find
{
private:
    std::vector<int> parent;
    std::vector<int> rank;
    int node_num;
public:
    union_find(int node_num);
    void unite(int x, int y);
    int find(int x);
};

union_find::union_find(int node_num):
    parent(node_num + 1, 0),
    rank(node_num + 1, 0)
{
    for (int i = 0; i <= node_num; i++) {
        parent[i] = i;
    }
}

void union_find::unite(int x, int y) {
    int x_root = find(x);
    int y_root = find(y);
    //既に同じ根なら、同じ所属なので何もしない
    if (x_root == y_root) return;
    //ランクが大きい方に小さい方をくっつける
    //結合後のランクは大きい方になる。
    if (rank[x_root] > rank[y_root]) {
        parent[y_root] = x_root;
    }
    else {
        parent[x_root] = y_root;
        if(rank[x_root] == rank[y_root]){
            //ランクが等しい場合は、必ずランクが1上がる。
            rank[y_root]++;
        }
    }
}

int union_find::find(int x) {
    if (x != parent[x]) {
        // findの結果を各parentに入れることで経路を圧縮。
        //ただし、変更コストが大きいのでランクは変えてない。
        parent[x] = find(parent[x]);
    }
    return parent[x];
}

後に説明するクラスカル法の中で一緒に確認。

プリム法

用途

  • 最小全域木問題を解く。
  • 最小全域木問題自体は与えられた無向グラフから、ループがなくなるように枝を取り除き、辺のコストの合計が最小になる全域木を求める問題。
  • 例えば、全ての街間で行き来できるように最小コストで道を作りたい的な問題に使える(完全グラフを構成して最小全域木を求める等*1)。

アルゴリズムのポイント

  • ダイクストラ法と似ており、まず探索の初期位置となる頂点を決め、そこから出ている辺を探索中リストに追加する。
  • 探索中リストの中で最小のコストの辺を取り出す。この辺は最小全域木の構成要素として確定する。
  • 上記で取り出した辺につながっている頂点を次の探索対象の頂点とし、その頂点を確定頂点とする。また同様にそこから出ている辺を探索中リストに追加する(ただし、確定頂点へつながる辺は登録しない)。
  • これらの操作を全ての頂点が確定するまで繰り返す。

実装のポイント

  • ダイクストラ法と同様に優先度付きキューを使う。
  • 各頂点に対して現在どのコストで優先度付きキューに登録しているかを保持する外部変数keyを用意する。
  • 優先度付きキューから最小コストの辺(と接続先の頂点)を取り出し、この辺を最小全域木の構成要素として確定させる。
  • 同時に接続先の頂点も確定頂点として記録し、次はその頂点につながっている辺を見ていく。
  • 辺のコストを見て、接続頂点のkeyの値より小さいコストの辺だったらkeyの値を更新し、優先度付きキューに追加する。もし、大きいコストの辺だったら、その辺は最小全域木の構成要素にはなり得ないので、優先度付きキューには追加しない。これ結構大事。これをやらないと、どんどん優先度つきキューに登録されていくので、サイズがばかでかくなり、確定ノードかの判定回数が大幅に増えるので遅くなる可能性がある。

実装

#include <vector>
#include <unordered_map>
#include <algorithm>
#include <vector>
#include <unordered_map>
#include <queue>
#include <functional>
#include <climits>


class edge {
public:
    edge(int to, long long cost);
    int get_to() const;
    long long get_cost() const;
private:
    int to;
    long long cost;
};

edge::edge(int to, long long cost) : to(to), cost(cost) {}

int edge::get_to() const { return to; }

long long edge::get_cost() const { return cost; }

class prim
{
private:
    //unordered_mapを使って隣接リストを構築
    //first:接続元、second:辺のコストと接続先
    std::unordered_map<int, std::vector<edge>> adj_list;
    std::vector<int> pred;
    std::vector<long long> key;
    int node_num;
    std::vector<bool> is_done;//確定ノードか否かを保持
    long long MST_cost;
public:
    prim(const int node_num, const std::unordered_map<int, std::vector<edge>>& adj_list);
    long long get_MST_cost() const;
    //std::unordered_map<int, std::vector<edge>> calc_MST();
    std::unordered_map<int, std::vector<edge>> get_MST_list();
    void calc_MST();
};

prim::prim(const int node_num, const std::unordered_map<int, std::vector<edge>>& adj_list):
    adj_list(adj_list),
    node_num(node_num),
    is_done(node_num + 1, false),
    pred(node_num + 1, INT_MAX),
    key(node_num + 1, LLONG_MAX),
    MST_cost(0)
{

}

void prim::calc_MST() {
    MST_cost = 0;
    key = std::vector<long long>(node_num + 1, LLONG_MAX);
    pred = std::vector<int>(node_num + 1, INT_MAX);
    //優先度付きキューpairで頂点番号と最短距離を保持
    //firstの要素で比較されるので、firstが距離、secondを遷移先の頂点とする
    std::priority_queue<std::pair<long long, int>,
        std::vector<std::pair<long long, int>>,
        std::greater<std::pair<long long, int>>> p_queue;
    p_queue.push(std::make_pair(0, adj_list.begin()->first));
    key[adj_list.begin()->first] = 0;
    pred[adj_list.begin()->first] = -1;

    while (1) {
        if (p_queue.empty() == true) break; //キューが空なら終了
        std::pair<long long, int> cost_node = p_queue.top();
        p_queue.pop();
        if (is_done[cost_node.second] == false) {
            MST_cost += cost_node.first;
        }
        
        int from_node = cost_node.second;
        is_done[from_node] = true; //キューから取り出した頂点を確定頂点とする
        for (auto &it : adj_list[from_node]) {
            int adj_node = it.get_to();
            if (is_done[adj_node] == false) {
                long long adj_cost = it.get_cost();
                if (adj_cost < key[adj_node]) {
                    p_queue.push(std::make_pair(adj_cost,adj_node));
                    key[adj_node] = adj_cost;
                    pred[adj_node] = from_node;
                }
            }
        }
    }
}


long long prim::get_MST_cost() const {
    return  MST_cost;
}

std::unordered_map<int, std::vector<edge>> prim::get_MST_list() {
    std::unordered_map<int, std::vector<edge>> MST_list;
    for (int i = 0; i < node_num + 1; i++) {
        if (pred[i] >= 0 && pred[i] < node_num + 1) {
            MST_list[i].push_back(edge(pred[i], key[i]));
            MST_list[pred[i]].push_back(edge(i, key[i]));
        }
    }
    return MST_list;
}

AOJのGRL_2_Aが通ることを確認。

クラスカル

用途

アルゴリズムのポイント

  • 最初は全ての頂点間が未連結の状態(辺がない状態)からスタート。
  • 辺をコストでソート後、コストが小さい辺から順に全域木の構成要素に追加していく。ただし、ループができる場合はその辺は追加しない。
  • 辺を追加してループができるか否かは、その辺で結んでいる両頂点が既に連結関係にあるか否かで判断可能(既に連結関係にあれば、その辺を追加すると必ずループができる)。

実装のポイント

  • 連結されている頂点同士が同じ集合になるようunion-findで管理し、頂点間の連結関係を効率よく判定および操作できるようにする。
  • 実装では隣接リストをコンストラクタで渡しているが、辺を管理する変数を用意する方が楽なので、隣接リストから辺のリストに変換させている(最初から辺のリストを構築してもよい)。
  • 辺はstd::pair<int, std::pair<int, int>>型で(コスト, (頂点, 頂点))という変数にし、リストにしたとき第1要素のコストでソートされるようにする。

実装

上述のunion-findを使用する。

#include <algorithm>
#include <vector>
#include <unordered_map>
#include <queue>
#include <functional>
#include <climits>

class edge {
public:
    edge(int to, long long cost);
    int get_to() const;
    long long get_cost() const;
private:
    int to;
    long long cost;
};

edge::edge(int to, long long cost) : to(to), cost(cost) {}

int edge::get_to() const { return to; }

long long edge::get_cost() const { return cost; }

class kruskal
{
private:
    //unordered_mapを使って隣接リストを構築
    //first:接続元、second:辺のコストと接続先
    std::unordered_map<int, std::vector<edge>> adj_list;
    //最小全域木の隣接リスト
    std::unordered_map<int, std::vector<edge>> MST_list;
    int node_num;
    long long MST_cost;
public:
    kruskal(const int node_num, const std::unordered_map<int, std::vector<edge>>& adj_list);
    long long get_MST_cost() const;
    std::unordered_map<int, std::vector<edge>> get_MST_list();
    void calc_MST();
};


kruskal::kruskal(const int node_num, const std::unordered_map<int, std::vector<edge>>& adj_list):
    adj_list(adj_list),
    node_num(node_num),
    MST_cost(0)
{
}

void kruskal::calc_MST() {
    //辺のリスト。firstがコスト、secondがつながっている両頂点
    std::vector<std::pair<long long, std::pair<int, int>>> edge_list;

    //隣接リストから辺のリストへ変換
    for (auto &v : adj_list) {
        for (auto &e : v.second) {
            edge_list.push_back(std::make_pair(e.get_cost(), 
                        std::make_pair(v.first, e.get_to())));
        }
    }

    std::sort(edge_list.begin(), edge_list.end());

    union_find uf(node_num + 1);
    MST_cost = 0;
    MST_list.clear();
    for (auto &ed : edge_list) {
        int v1 = ed.second.first;
        int v2 = ed.second.second;
        long long cost = ed.first;
        //同じ根を持つなら、既に両頂点は同じ木の中
        //=辺を追加するとループができるのでなにもしない。
        if (uf.find(v1) != uf.find(v2)) {
            //根が異なる場合、両頂点間に辺を追加
            uf.unite(v1, v2);
            MST_cost += cost;
            MST_list[v1].push_back(edge(v2, cost));
            MST_list[v2].push_back(edge(v1, cost));
        }
    }
}

long long kruskal::get_MST_cost() const {
    return MST_cost;
}

std::unordered_map<int, std::vector<edge>> kruskal::get_MST_list() {
    return MST_list;
}

プリム法同様、AOJのGRL_2_Aで確認。

おわりに

前回、今回とグラフアルゴリズムについて見てきましたが、これら以外にもネットワークフロー問題やマッチング問題など様々な問題とそれらを解くアルゴリズムがあります。また、前回紹介した最短経路問題や今回扱った最小全域木問題、他にもマッチング問題など多くの問題は離散凸解析の枠組みで抽象化され、その枠組みの中で最適化アルゴリズムが提案されています。この離散凸解析についてはまだ全然理解していませんが、いつか勉強して記事を書きたいところです。
あと、このシリーズはもう少し続きます。次回は尺取り法や累積和法などを書いていこうと思います。

*1:競プロだと、完全グラフを考えると大抵の場合TLEになるので問題の性質から元のグラフをどのように構築するかというところが肝になる。

競プロ関係の雑メモ 2018/9/9

はじめに

以前と同様、atcoderの400~500点問題穴埋め時のメモに加え、ABC109参加のメモです。

ARC088-D Wide Flip

区間[l - 1, r]を反転させた後、区間[l, r]を反転させると、l - 1番目の1文字のみ反転させることができる。同様に区間[l, r + 1]を反転させた後、区間[l, r]を反転させると、r + 1番目の1文字のみ反転させることができる。ということは、選べる区間の最小の長さKが与えられたとして、左端からN-K個、右端からもN-K個、合計2(N-K)個の要素は自由に0,1を選べる、ということになる。
じゃあ、左端もしくは右端から何番目の要素まで自由に変えられるようにする必要があるかと考えると、真ん中の要素がいくつ連続で同じ文字かを見ればよいという結論に至る。つまり、中央N/2番目の要素から左右にC個が全て同じ要素だとすると、そのC個は変化させる必要がなく、左端から(N-C)/2番目までの要素もしくは右端から(N-C)/2番目までの要素のどれかが自由に0, 1を選べるようKを定めればよい。よって、N - K = (N - C)/2という方程式が出てくるので、これをKについて解いてK = N - (N - C)/2となる。
Cについては、普通に真ん中から見て同じ要素をカウントすればいいのだけど、文字列の長さが偶数か奇数かで話が変わってくるので、少し注意が必要(実際、1度for文回す回数間違えてWAになった)。

ARC064-D An Ordinary Game

まず個数に着目。1個ずつ文字を削除していくゲームなので、負けるときの文字数の法則性を見つけ出せれば、後は逆算して元の文字数から勝者がわかると考えた。じゃあ、負けるときの文字列とは?ってところについては、2パターン考えられる。ひとつはabのような文字数が偶数での負け、もうひとつはabaのような文字数が奇数での負け(この辺り、大体の予想は立ったが、じゃあ本当?っていうところが解説見るまで分からなかった)。では、どういうときに偶数負けパターンになるか?奇数負けパターンになるか?
ポイントはゲームの性質上、先頭の文字と後尾の文字は必ず最後まで残るということ。このことから、先頭の文字と後尾の文字が同じであれば奇数負けパターン、異なるのであれば偶数負けパターンになる。なぜなら、例えばa**b***a*b(*はaとb以外の文字)のように先頭と後尾が異なるなら、必ず*は取り除くことができ、最終的な負けパターンとしてはabかababになる(そうなるように、両プレイヤーがコントロール可能)。どちらになるかは、一意に定まらないが偶数の文字数になることはかわらない。これと同じような考えで、a*b**a*b**a等のように、先頭と後尾が同じ場合は、abaかa*aかababaかのいずれかが最終形になるので必ず奇数の文字数になる。ここまで来れば、上述の通り逆算して元の文字列数からどちらが勝つかはわかる。
どうもゲーム系は苦手意識があったが(未だに「互いが最適に行動を取る」ってことのイメージがつかめてない)、最終的に詰み(負け)となる形を考えて、そこから逆算していくっていう方法が有効そう。
この問題では一発でACを取れたが、根拠の薄い予想がたまたま当たったっていうパターンなので、あまりよろしくない*1

ARC067-D Walk and Teleport

まず、いくつか移動パターンを考える。

  1. 西から東へ順々に移動するパターン(テレポートしても町は飛び越えない)。
  2. テレポートを使って、まだ移動していない町を飛び越え、その飛び越えた町へ行くため東から西へ逆走をするパターン。
  3. 全てテレポートで移動するパターン。

この中で2の移動パターンは使う必要がない。なぜなら、最終的には全ての町を訪れる必要があるため、逆走しても順方向に移動しても距離は変わらず、町を飛び越すメリットがないから。また、1のパターンはB<A(X[i + 1] - X[i])であればテレポートを使うし、そうでなければ徒歩で移動する。3のパターンは1の条件で、全ての町間の距離がテレポート使用条件に合うなら自動的に適用される。ということで、min(A(X[i + 1] - X[i]), B)を足しこんでいけばよい。
500点問題の中では明らかに簡単なような・・・

ABC 109参加メモ

2018/9/8開催のABC109に参加。4完達成もD問題の実装に時間がかかりすぎたのがよろしくない。

A問題

A*Bが奇数か偶数か。

B問題

setを使って既出かの判定と、前回の文字列の最後の文字と今見てる文字列の頭文字が同じかという判定を行う。

C問題

Xも数列に組み込んでソートしてx[i + 1] - x[i]の最大公約数を求めたが、解説読むとソートなんかしなくても|X - x[i]|の最大公約数でよかった。

D問題

着目しているマスのコイン数が奇数なら隣のマスにコインを1枚ずらし、着目するマスもそのずらした先に移動。逆に、着目しているマスのコイン数が偶数なら何もせず、着目するマスのみ隣にずらす。この操作を左端からジグザグでもグルグルでも何でもいいから順々に全てのマスに対して行えば、目的が達成される。
その理由としては、まず奇数のマスから奇数のマスへコインを一つずらすと両方偶数になる。その上で、ずらし元が奇数、ずらし先が偶数の場合、操作後はずらし先が奇数になるが、直後に奇数になったずらし先に対して操作を行えば、そのマスも偶数になる。この操作を続けていくと、どこかかでずらし元、ずらし先ともに奇数という個所が出てきて、両方偶数になる。もしそのような個所が出てこない場合、最後のマスまでコインが伝搬され最後のマスのみが奇数になる。従って、最適な最終形としては、コインの総枚数が偶数の場合は全てのマスが偶数になり、コインの総枚数が奇数の場合、1つのマスのみが奇数でそれ以外全手のマスが偶数になる。
ただ、いざ実装ってなると、配列のジグザグ操作またはグルグル操作の実装が面倒だなーと億劫になり、結局、既に操作済みのマスか?という情報を記録する配列を別途用意。ずらせるマスを探し、ずらした先の座標を次のずらし元とする作戦を取る。ACだったが実装に時間がかかりすぎた。

おわりに

最近、競プロ関連チラ裏ブログと化しているので、次回はもう少しまともな内容を書きます。

*1:もう少し正確に言うと、予想した負けパターンは、abのような2文字か、abaのような3文字、もしくは最初からabababのようなFirstプレイヤーが取れず負けるというパターンの3種類を考え、最後のパターンだけ(無駄な)例外処理をする実装でACになった。

競プロ関係の雑メモ 2018/9/1

はじめに

競プロの問題解いてて、解けなかった問題は何故解けなかったのか、何故その発想が出なかったのかを記していく自分用メモです。atcoderの400、500点くらいの問題なら、解けないなりにもボチボチ近いところまでは行けてるっていうパターンが多いので(それ以上の問題は今のところ大体歯が立たない)、あと一歩何が足りなかったのかを探って書いていきます。

ARC074-D 3N Numbers

とりあえず、境界を決めて左側の数列から大きい順にN個、右側の小さい順にN個取ればよく、後はスコアが最大になる境界位置を全探索すればよい。ただし、これだとO(N^2logN)だからTLEになる。じゃあ、どうする?優先度付きキューか?っていうところまでは発想が至った。
ということで、左側数列用優先度付きキューと右側数列用優先度付きキューを用意して、境界位置を左から順にずらし、両優先度付きキューを更新していくってことを行おうとしたところ、右側数列用優先度付きキューから指定要素を取り除く操作が出てきて、その操作には1回O(N)かかるぞ、どうすればいいんだーとなる。
境界位置探索について1回で済まそうと、左側と右側の更新を同時(同じループ内)で行おうと考えたのが間違い。
まず境界位置を左からずらして、左側数列優先度付きキューを更新していく。それが終わった後、今度は境界位置を右からずらして、右側数列の優先度付きキューを更新していく。こうすれば優先度付きキューの操作は、先頭要素取り出しと要素追加だけになるので1回の操作はO(logN)で可能。最後にどの境界位置が最適かを見ればよい。

ARC083-D Restoring Road Network

与えられている最短距離表を各頂点間を結ぶ隣接行列だとみなしてワーシャル・フロイドで解いた後、矛盾がないか探る(矛盾があれば、ワーシャルフロイドで求めた最短距離と最短距離表の値が異なる部分が出てくる)。矛盾がなければ、頂点i,j間を直接結ぶ辺が必要かを考えればよい。つまり、i,j間を直接結ばずともどこかを経由して最短距離となる経路があるかを探す。
っていうところまでは出てきたけど、ここで躓く。i, j間の最短経路で直接結ぶ以外のパスがあるかを求める方法がわからなかった(i, jを直接結ぶパスのみを削除したグラフG'を用意して、これに対してダイクストラ法でi, j間の最短距離を求め、元のグラフで求めた最短距離と変わらないかってことをやろうとしたけどTLE)。
これを求めるには、頂点kを経由して、最短距離が等しくなる経路が存在するかを調べればよい。つまり、i-k間の最短距離+k-j間の最短距離が最短距離表にあるAijと等しくなるkが存在するかみればよい。

ARC076-D Built?

x座標、y座標それぞれでソートして、隣り合う街同士にのみ道を作ることを考えればOK。x座標、y座標ごった煮にして、コストの小さい道から順に追加していけばいいってところまでは考えた。ところが、やってみると各頂点をが全て連結されているのか効率よく判断する方法がわからず、言葉だけは知っていたUnion-Findを使うのかな?って思ったところで、なんとなく敷居が高い気がして断念。
というか、ここまで来れば、まんまクラスカル法。もし、クラスカル法知らなくても、ソートして隣り合う街同士のみ考えればいいってことを鑑みると、考慮する辺数が2Vだけなので、O(VlogV)で解ける最小全域木問題だって事気付けばプリム法で解いても良かった(これ解いた時、プリム法は知ってた)。いずれにしても、問題そのものが最小全域木の典型問題にもかかわらず、ソートとか隣り合う街同士のみ考慮とか色々考えていくうちに、最小全域木問題を解くアルゴリズム使うという発想がすっぽり頭から抜けていたのが反省点。

おわりに

もう少しレベルを上げたいので、時々こういう反省ポエムも書いていきます。

競プロとかに使うアルゴリズム実装メモ(最短経路探索系)

はじめに

ここ1年くらい、ちまちまとatcoder中心に競技プログラミングに参加してたりします。ABCのD問題を解くのがやっとなのにAGCに突撃して爆死するってことを繰り返し続け、未だに緑コーダーです(パフォーマンスも1200前後がやっと)。こりゃまずいということで、もう少しまじめにアルゴリズムの勉強をしようと思い立ったので、今回から数回にわたって基本的なアルゴリズムのメモと実装を書いて、自分の理解を深めていこう思います。
今回はその中でも最短経路探索系(ダイクストラ法、ベルマン・フォード法、ワーシャル・フロイド法)のアルゴリズムについて書いていきます。これらのアルゴリズムについては、ネットや探せばいくらでも解説記事が出てくるので、ここでは主に自分が実装していく中で理解にしにくかった、ハマった点をポイントとして記していきます。そのため、ほぼ自分用のメモ&脳内整理用の記事なので、まじめな解説は他を当たった方が良いかもしれません。
ちなみに、今まで競プロも機械学習pythonで実装してきましたが、今後本業でC++を使う機会が増えそうなので、今回はリハビリの意味も込めてC++での実装を行っていきます。

ダイクストラ

用途

  • 負のコストがない場合の単一始点最短経路探索。

アルゴリズムのポイント

  • 始点に隣接している頂点からの順次最短経路を確定させていくアルゴリズム
  • 最短経路が確定した頂点に隣接する頂点と始点間のコストを計算・更新する処理を全部の頂点の最短経路が確定するまで繰り返す。
  • 負のコストがない前提なので、現時点の始点からコスト計算されている頂点の中で最もコストが小さい頂点について、そのコストを最短経路として確定させてよい。
  • 具体的には、始点の頂点のコストを0、それ以外の頂点までのコストを∞として初期化後、下記の処理を終了条件を満たすまで繰り返す。
    1. 最短経路が確定確定していない頂点の中からコストが最も小さい頂点を選択。
    2. その選択した頂点を最短経路確定頂点とする。
    3. その選択した頂点に隣接する頂点の現時点のコストが、選択した頂点へのコスト+辺のコストより大きければ、選択した頂点へのコスト+辺のコストに置き換える。
  • 最短経路の復元のため、どの頂点から来たかも記憶しておき、最後にバックトラックで辿れるようにする。

実装のポイント

  • 優先度付きキューを使って、コスト最小の頂点から取り出して最短経路確定頂点としていく。
  • 優先度付きキューの要素はSTLのpairを使う。pairの比較演算子は第1要素で比較するので、pairの第1要素にコスト、第2要素に頂点番号を入れる。
  • 同じ頂点でコストが異なる要素が優先度付きキューに残るけど気にしない(プライオリティキューから取り出した時に確定頂点か否かを確認するから、下手に操作せずに残しておく)。
  • 隣接頂点をすぐに取り出せるようにするため、グラフの表現にはunordered_mapを使った隣接リストを採用。

実装

メイン関数(呼び出し元)は省略。

#include <algorithm>
#include <vector>
#include <unordered_map>
#include <queue>
#include <functional>
#include <climits>

class edge {
public:
    edge(int to, long long cost);
    int get_to() const;
    long long get_cost() const;
private:
    int to;
    long long cost;
};

edge::edge(int to, long long cost) : to(to), cost(cost) {}

int edge::get_to() const { return to; }

long long edge::get_cost() const { return cost; }

class dykstra {
private:
    //unordered_mapを使って隣接を構築
    //first:接続元、second:辺のコストと接続先
    std::unordered_map<int, std::vector<edge>> adj_list;
    int node_num;
    int start_point;
    std::vector<int> pred;//経路(遷移元)を保持
    std::vector<long long> costs;//距離を保持
    std::vector<bool> is_done;//確定ノードか否かを保持

public:
    dykstra(const int node_num, const std::unordered_map<int, std::vector<edge>>& adj_list);
    void calc_min_cost(int start);
    long long get_cost(int end) const;
    std::vector<int> get_min_path(int end) const;
};

dykstra::dykstra(const int node_num, const std::unordered_map<int, std::vector<edge>>& adj_list) :
    adj_list(adj_list),
    start_point(0),
    node_num(node_num),
    pred(node_num + 1, INT_MAX),
    costs(node_num + 1, LLONG_MAX),
    is_done(node_num + 1, false)
{}

void dykstra::calc_min_cost(int start) {
    this->start_point = start;
    this->pred = std::vector<int>(node_num + 1, INT_MAX);
    this->costs = std::vector<long long>(node_num + 1, LLONG_MAX);
    this->is_done = std::vector<bool>(node_num + 1, false);

    //優先度付きキューpairで頂点番号と最短距離を保持
    //firstの要素で比較されるので、firstが距離、secondを頂点番号とする
    std::priority_queue<std::pair<long long, int>,
        std::vector<std::pair<long long, int>>, std::greater<std::pair<long long, int>>> p_queue;
    p_queue.push(std::make_pair(0, start));
    pred[start] = -1;
    costs[start] = 0;

    while (1) {
        if (p_queue.empty() == true) break; //キューが空なら終了
        std::pair<long long, int> cost_node = p_queue.top();
        p_queue.pop();
        is_done[cost_node.second] = true; //キューから取り出した頂点を確定頂点とする
        for (auto &it : adj_list[cost_node.second]) {
            int adj_node = it.get_to();
            if (is_done[adj_node] == false) {
                long long adj_cost = it.get_cost() + cost_node.first;
                if (adj_cost < costs[adj_node]) {//計算された隣接頂点のコストが現在のコストより小さいなら
                    costs[adj_node] = adj_cost; //隣接頂点のコストを更新
                    pred[adj_node] = cost_node.second;//隣接頂点の遷移元を更新
                    p_queue.push(std::make_pair(adj_cost, adj_node));//キューに隣接頂点の情報を突っ込む
                }
            }
        }
    }
}

long long dykstra::get_cost(int end) const {
    return costs[end];
}

std::vector<int> dykstra::get_min_path(int end) const {
    int node = end;
    std::vector<int> vec;
    vec.push_back(node);
    //終点から始点までの経路をたどる
    while (1) {
        //始点から辿れない頂点or
        //始点であれば終了
        if (pred[node] >= INT_MAX || pred[node] == -1) break;
        node = pred[node];
        vec.push_back(node);
    }
    std::reverse(vec.begin(), vec.end());
    return vec;
}

AOJのALDS1_12_Cが通ることを確認。

ベルマン・フォード法

用途

  • 負のコストがある場合の単一始点最短経路探索および負のサイクル検出。

アルゴリズムのポイント

  • 全ての辺のコストを見て、各頂点までの最短経路(コスト)を更新するという処理を所定の回数繰り返すアルゴリズム
  • 具体的には、始点の頂点のコストを0、それ以外の頂点までのコストを∞として初期化後、下記の処理を所定の回数繰り返す(回数については後述)。
    1. 各辺に対して、接続元の頂点のコスト+辺のコストが接続先の頂点のコストより小さければ、接続先頂点のコストを接続元の頂点のコスト+辺のコストに置き換える。
    2. この処理を全ての辺に対して行う。
  • 負のサイクルがあれば、始点からの最短経路は全ての頂点に対して求まらない(いくらでもコストを小さくできるので)。
  • 頂点数をVとして、負のサイクルがグラフ上にない場合、始点から各頂点までの最短経路の経由頂点数は高々V-1個(始点は除く)。
  • これらのことから負のサイクルがなければ、V-1回の繰り返しで全頂点までの最短経路が求まる。逆にV回目に更新が発生するのならばグラフに負のサイクルがあると判定可能。
  • 更に、V~V*2回目の繰り返しにおいて、最短経路が更新される頂点は負のサイクルに含まれると判断される。
  • 加えて、負のサイクルに含まれていると確定している頂点から辿れる頂点についても、負のサイクルに含まれる頂点。
  • 上記のことより、特定終点までの最短経路と負のサイクルの有無を求める場合は、全頂点の最短経路更新および負のサイクル情報の伝搬処理をV*2回繰り返す(グラフ全体に負のサイクルがあるか否かを求めるだけの場合はV回の繰り返しでよい)。

実装のポイント

  • for文使って、V*2回のループと辺数E回のループで素直にアルゴリズムを実装すればよい。
  • ダイクストラ法みたいに優先度付きキューは不要。
  • 今回は上記ダイクストラ法に合わせて、グラフの表現に隣接リストを使ったが、辺(fromとtoとcostを持つクラス)のリストの方が楽かも。
  • 負のサイクルが見つかった後もコスト配列の更新を怠ってはいけない(自戒)。

実装

#include <algorithm>
#include <vector>
#include <unordered_map>
#include <functional>
#include <climits>

class edge {
public:
    edge(int to, long long cost);
    int get_to() const;
    long long get_cost() const;
private:
    int to;
    long long cost;
};

edge::edge(int to, long long cost) : to(to), cost(cost) {}

int edge::get_to() const { return to; }

long long edge::get_cost() const { return cost; }


class bellman_ford {
private:
    //unordered_mapを使って隣接グラフを構築
    //first:接続元、second:辺のコストと接続先
    std::unordered_map<int, std::vector<edge>> adj_list;
    int node_num;
    int start_point;
    std::vector<int> pred;//経路(遷移元)を保持
    std::vector<long long> costs;//距離を保持
    bool is_negative_graph;//グラフ内に負のサイクルをもつか否か
    std::vector<bool> is_negative_pass;//経路上に負のサイクルを持つか否か

public:
    bellman_ford(const int node_num, const std::unordered_map<int, std::vector<edge>>& adj_list);
    void calc_min_cost(int start);
    long long get_cost(int end) const;
    bool get_is_negative_graph() const;
    std::vector<int> get_min_path(int end) const;
};

bellman_ford::bellman_ford(const int node_num,
    const std::unordered_map<int, std::vector<edge>>& adj_list) :
    adj_list(adj_list),
    start_point(0),
    node_num(node_num),
    pred(node_num + 1, INT_MAX),
    costs(node_num + 1, LLONG_MAX),
    is_negative_graph(false),
    is_negative_pass(node_num + 1, false)
{}

void bellman_ford::calc_min_cost(int start) {
    this->start_point = start;
    this->pred = std::vector<int>(node_num + 1, INT_MAX);
    this->costs = std::vector<long long>(node_num + 1, LLONG_MAX);
    this->is_negative_pass = std::vector<bool>(node_num + 1, false);

    pred[start] = -1;
    costs[start] = 0;
    for (int i = 0; i < 2 * node_num; i++) {
        for (auto &node : adj_list) {
            if (costs[node.first] == LLONG_MAX) continue;
            for (auto &adj : adj_list[node.first]) {
                int adj_node = adj.get_to();
                long long adj_cost = adj.get_cost() + costs[node.first];
                if (adj_cost < costs[adj_node]) {//計算された隣接頂点のコストが現在のコストより小さいなら
                    costs[adj_node] = adj_cost;//隣接頂点のコストを更新
                    pred[adj_node] = node.first;//隣接頂点の遷移元を更新
                    if (i >= node_num - 1) {
                        //頂点数回以上繰り返しても、
                        //まだ更新が発生するなら、その頂点への経路には負のサイクルあり
                        is_negative_pass[adj_node] = true;
                        is_negative_graph = true;
                    }
                }
                if (is_negative_pass[node.first] == true) {//負のサイクル情報伝搬
                                                           //経路に負のサイクルを含む頂点に連結されているなら、
                                                           //その頂点への経路も負のサイクルを含む
                    is_negative_pass[adj_node] = true;
                }
            }
        }
    }
}

bool bellman_ford::get_is_negative_graph() const {
    return is_negative_graph;
}

long long bellman_ford::get_cost(int end) const {
    if (is_negative_pass[end] == false)
        return costs[end];
    else
        return LLONG_MIN;
}

std::vector<int> bellman_ford::get_min_path(int end) const {
    int node = end;
    std::vector<int> vec;
    vec.push_back(node);
    //終点から始点までの経路をたどる
    while (1) {
        //負のサイクルを含む頂点or
        //始点から辿れない頂点or
        //始点であれば終了
        if (is_negative_pass[node] == true ||
            pred[node] >= INT_MAX || pred[node] == -1) break;
        node = pred[node];
        vec.push_back(node);
    }
    std::reverse(vec.begin(), vec.end());
    return vec;
}

AOJのGRL_1_Bが通ることを確認。加えて、グラフそのものに負のサイクルを含むか否かではなく、目的の経路に負サイクルがあるかという判定が正しく動作するか検証するために、atcoderのABC061Dでも確認。

ワーシャル・フロイド法

用途

  • 全頂点ペアの最短経路探索。
  • 負のコストがあっても良く、負のサイクル検出も可能。

アルゴリズムのポイント

  • 頂点iから頂点jへ頂点0~k-1のみを経由する最短経路が既に得られているとして、それに頂点kを追加して最短経路を更新していくという動的計画法の一種。
  • もし頂点kを追加してi→j間の最短経路更新が発生するということは、更新後のi→jの最短経路はi→kの最短経路+k→jの最短経路。
  • 従って、i→kの最短経路のコスト+k→jの最短経路のコストが、既に得られているi→jへの最短経路のコストよりも小さければ更新。
  • 具体的には頂点iから頂点jへの最短経路のコストをcost[i][j]とし、隣接行列で初期化した後、以下の処理を行う。
    1. k =0~Vに対し、以下の処理を繰り返す。
    2. 全頂点のペアi, jについて、既に得られている0~k-1のみを経由する最短経路cost[i][j]よりもcost[i][k] + cost[k][j]が小さければ、頂点kを追加することで経路が更新されるということなので、cost[i][j]をcost[i][k] + cost[k][j]に置き換える。
  • 上記処理完了後、costの対角成分(自分から自分への距離)が負ならば、そのグラフには負のサイクルがあり、その頂点は負のサイクルに含まれる。
  • 対角成分が負でなくても、負のサイクルに含まれる頂点から辿れる頂点も負のサイクルに含まれる。
  • 上記の方法で、各頂点が負のサイクルに含まれるか判断可能なので、頂点iから頂点jへの経路に負のサイクルを含むか否かは、終点jが負のサイクルに含まれるか否かで判断可能(この辺り、あまり文献が見つからず自分で考えたので少し怪しい)。やっぱり間違えてる(というより処理が足りない)。以下のようなグラフで、4→5は負のサイクルを含まないのに、負のサイクルありと判定されてしまう。ちゃんと伝搬元の負のサイクルに含まれる頂点が始点iから辿れるか?も判断しないといけない(時間がないのでコードの修正は別途)。


f:id:YamagenSakam:20180822120427p:plain

  • ダイクストラ法などと同様、経路復元のためどの頂点から来たかについても、cost[i][j]が更新されたときに合わせて記憶する(どう記憶するかは、少しややこしいのでソースコード内のコメント参照)。

実装のポイント

  • k=0~V、i=0~V、j=0~Vの3重ループを回す。
  • アルゴリズムの性質上、グラフの表現には隣接行列を採用。
  • 隣接行列の未接続の表現にLLONG_MAXを使っているので、呼び出し元は注意が必要。

実装

#include <algorithm>
#include <vector>
#include <functional>
#include <climits>
 
class warshall_floyd
{
private:
    int node_num;
    bool is_negative_graph;
    std::vector<std::vector<long long>> cost_matrix;//隣接行列を構築
    std::vector<std::vector<int>> pred;//経路(遷移元)を保持
   
public:
    warshall_floyd(const int node_num, const std::vector<std::vector<long long>> &adj_matrix);
    bool get_is_negative_graph() const;
    bool get_is_negative_pass(int start, int end) const;
    long long get_cost(int start, int end) const;
    std::vector<int> get_min_path(int start, int end) const;
};
 
warshall_floyd::warshall_floyd(const int node_num, 
        const std::vector<std::vector<long long>> &adj_matrix):
    cost_matrix(adj_matrix),
    node_num(node_num),
    is_negative_graph(false),
    pred(node_num + 1, std::vector<int>(node_num + 1, INT_MAX))
{
    //predの初期化
    for (int i = 0; i <= node_num; i++) {
        for (int j = 0; j <= node_num; j++) {
            if (i == j) pred[i][j] = -1;
            else if (adj_matrix[i][j] < LLONG_MAX) {
                pred[i][j] = i;
            }
        }
    }
    //ワーシャルフロイドの更新式
    for (int k = 0; k <= node_num; k++) {
        for (int i = 0; i <= node_num; i++) {
            for (int j = 0; j <= node_num; j++) {
                if (cost_matrix[i][k] < LLONG_MAX && cost_matrix[k][j] < LLONG_MAX) {
                    if (cost_matrix[i][k] + cost_matrix[k][j] < cost_matrix[i][j]) {
                        cost_matrix[i][j] = cost_matrix[i][k] + cost_matrix[k][j];
                        //経路復元用
                        //pred[i][j]にはi→jの最短経路におけるjの1つ前の頂点が入る。
                        //pred[k][j]にはk→jの最短経路におけるjの1つ前の頂点が入っている。
                        //この処理を通るということは最短経路がkを使った経路に変わる
                        //ということでありi→jへの経路の中にk→jの最短経路を含むことになる。
                        //従って新しいi→jにおけるjの1つ前の頂点は
                        //k→jにおけるjの1つ前の頂点、すなわちpred[k][j]となる。
                        pred[i][j] = pred[k][j];
                    }
                }
            }
        }
    }
 
    //負のサイクル検出
    //[注意!]間違っているので修正必要!(アルゴリズムのポイントを参照)
    for (int i = 0; i <= node_num; i++) {
        //自分への距離が負になるならその頂点は負のサイクルに含まれる
        if (cost_matrix[i][i] < 0) {
            for (int j = 0; j <= node_num; j++) {
                //負のサイクルに含まれる頂点から辿れる頂点も負のサイクルになる
                if (i == j) continue;
                if (cost_matrix[i][j] < LLONG_MAX && cost_matrix[j][j] >= 0)
                    cost_matrix[j][j] = -1;
            }
            is_negative_graph = true;
            //break;
        }
    }
}
 
bool warshall_floyd::get_is_negative_graph() const {
    return is_negative_graph;
}
 
bool warshall_floyd::get_is_negative_pass(int start, int end) const {
    //startからendにたどり着けない場合はfalseを返す
    if (cost_matrix[start][end] < LLONG_MAX) {
        //startからend辿りつける場合、
        //endが負のサイクルに含まれるならtrueを返す
        if (cost_matrix[end][end] < 0) return true;
    }
    return false;
}
 
 
long long warshall_floyd::get_cost(int start, int end) const {
    return cost_matrix[start][end];
}
 
std::vector<int> warshall_floyd::get_min_path(int start, int end) const {
    int node = end;
    std::vector<int> vec;
    vec.push_back(node);
    //終点から始点までの経路をたどる
    while (1) {
        //負のサイクルを含む頂点or
        //始点から辿れない頂点or
        //始点であれば終了
        if (cost_matrix[node][node] < 0 || 
            pred[start][node] >= INT_MAX || 
            pred[start][node] == -1) break;
        node = pred[start][node];
        vec.push_back(node);
    }
    std::reverse(vec.begin(), vec.end());
    return vec;
}

AOJのGRL_1_Cが通ることを確認。後、負サイクル検出ロジックの検証のため、上記同様atcoderのABC061Dも確認(ワーシャル・フロイドで解くには少しきつい問題設定だけどC++なら通る)。

おわりに

今回は、競プロレベルアップに向け最短経路探索系アルゴリズムに関するメモと実装を書きました。
ただ、ここに書いたコードがそのまま競プロに使えるケースはあまりないと思います。例えば、ダイクストラ法を使うにしても、辺のコストがあらかじめ与えられておらず(もしくは与えようとすると膨大なメモリ使用量になる等)キューから取り出して、コスト計算時に初めて辺のコストがわかるなんて問題もあります。そんな問題だとここに書いた実装はそのまま使えないので、カスタマイズする必要があります。
あとそのまま使えるにしても、問題文からはそのことがパッとわからない、または、アルゴリズムを使うのはわかるけどどうやって適用するかを考える必要がある(問題で与えられたコストを負値にする、スタートとゴール両方から計算する、等々)なんてこともあります。その辺りの考察力が勝負になってくる世界なので、強化するべきはそこなんだと思いますが、まずは基本を理解しないと考察力も鍛えられないかなと
ということで、次回以降はその他のグラフアルゴリズム(幅優先、深さ優先、最小全域木とか)、探索系(二分探索、尺取り法とか)、数学系(素数、modとか)、動的計画法辺りを書いていこうと思います。

交互方向乗数法による最適化と画像ノイズ除去への応用

はじめに

これまでの記事で近接勾配法と、それによるスパース解や低ランク解に導く正則化項を付随した最適化問題の解法、そしてその応用を見てきました。正則化項に変数間の絡みがなく各変数が独立に扱える場合は、正則化項のproximal operatorが解析的に求まるため近接勾配法は有効な手段です。ところが、各変数間に絡みがあり分離できない場合、一般的にproximal operatorの計算は容易ではありません(解析的に求まらない)。そこで今回は、変数間に絡みがある正則化項が付いていても最適解を導出することができる交互方向乗数法(Alternating Direction Method of Multipliers : ADMM)のアルゴリズムを見ていきます。また、それを用いた画像のノイズ除去をPythonで実装したので紹介します。

交互方向乗数法(ADMM)とは

いきなりですがADMMのアルゴリズムを記載します。まずADMMの対象となる最適化問題は下記のようなものになります。

\displaystyle \min_{{\bf x},{\bf y}} f({\bf x}) + g({\bf y}) \ \   {\rm s.t.} \ \ {\bf A}{\bf x} + {\bf B}{\bf y} = {\bf 0}  \tag{1}

この問題を解くために、下記に示す「拡張ラグランジュ関数」というものを定義します。

\displaystyle L_{\rho}({\bf x},{\bf y}, {\bf \lambda})= f({\bf x}) + g({\bf y})  + {\bf \lambda}^T ({\bf A}{\bf x} + {\bf B}{\bf y}) + \frac{\rho}{2} \| ({\bf A}{\bf x} + {\bf B}{\bf y})  \|^2_2 \tag{2}

この拡張ラグランジュ関数に関する詳細は補足の項や参考文献をご参照ください。とりあえず、ここでは通常のラグランジュ関数に等式制約が満たされないことに対する罰則項(第3項)がついたものと考えて問題ありません。この拡張ラグランジュ関数を用いると式(1)を解くADMMアルゴリズムは下記のようになります。

  1. {\bf x}{\bf y} {\bf \lambda}を適当な値で初期化する
  2. 以下の操作を収束するまで繰り返す

\displaystyle {\bf x}_k \leftarrow {\rm arg}\min_{\bf x} L_{\rho} ({\bf x}, {\bf y}_{k-1}, {\bf \lambda}_{k-1} ) \tag{3}
\displaystyle {\bf y}_k \leftarrow {\rm arg}\min_{\bf y} L_{\rho} ({\bf x}_k, {\bf y}, {\bf \lambda}_{k-1} ) \tag{4}
\displaystyle {\bf \lambda}_k \leftarrow \lambda_{k-1} + \rho({\bf A}{\bf x}_k + {\bf B}{\bf y}_k) \tag{5}

これだけです。アルゴリズムとしては非常に簡単です。特に式(3)と式(4)では{\bf x}{\bf y}それぞれに関する最適化問題を独立で解いているので、元々の最適化問題が関数同士の和が含まれており解くのが難しい場合でも、式(1)の形式に変換できればADMMにより簡単に解くことができます。このことを次の章で見ていきます。が、その前に補足として、上記アルゴリズムがどのように出てきたかを理解するために、双対上昇法、拡張ラグランジュ関数について簡単に説明します(細かいことはかなり省略するので、詳しい説明は参考文献をご参照ください)。

【補足】双対上昇法、拡張ラグランジュ関数について

ADMMは最適化問題の強双対性という性質を利用しています。ここでは、

\displaystyle \min_{{\bf x}} f({\bf x}) \ \   {\rm s.t.} \ \ {\bf A}{\bf x}  = {\bf b}  \tag{6}

という最適化問題を考えます。この問題のラグランジュ関数は、

\displaystyle L({\bf x}, {\bf \lambda})= f({\bf x}) + {\bf \lambda}^T( {\bf A}{\bf x}  - {\bf b} )\tag{7}

と定義されます。f({\bf x})が真凸関数である場合は、\nabla L({\bf x}^*, {\bf \lambda}^*) = 0となる{\bf x}^*{\bf \lambda}^*が存在すれば、{\bf x}^*が式(6)の最適解です。
更にf({\bf x})が真凸関数であるとき式(7)について下記も成り立ちます。

\displaystyle \min_{{\bf x}} \max_{{\bf \lambda}} L({\bf x}, {\bf \lambda}) = \max_{{\bf \lambda}} \min_{{\bf x}} L({\bf x}, {\bf \lambda})\tag{8}

これは強双対性と呼ばれる性質で、左辺は式(6)と同値の問題です。この強双対性を利用して最適化問題を解くのが双対上昇法や拡張ラグランジュ関数法、そして今回の主題であるADMMです。

双対上昇法

双対上昇法は式(6)の問題を強双対性から双対問題という問題に変換し、そちらを解こうという方針の手法です。まず、式(8)の右辺に問題を当てはめ変形すると、

\displaystyle 
\begin{eqnarray}
\max_{{\bf \lambda}} \min_{{\bf x}} L({\bf x}, {\bf \lambda}) &=& \max_{{\bf \lambda}} -f^*(-{\bf A}^T {\bf \lambda}) - {\bf \lambda}^T {\bf b} \\
& =&  \min_{{\bf \lambda}} f^*(-{\bf A}^T {\bf \lambda}) + {\bf \lambda}^T {\bf b}\\
& =&   \min_{{\bf \lambda}} \phi({\bf \lambda})
\end{eqnarray}
\tag{9}

という問題になります。ここで、f^*は以前説明した共役関数でf^*({\bf z})= \sup_{{\bf x}} {\bf z}^T {\bf x} - f({\bf x})です。この式(9)の問題は双対問題と呼ばれ、これを解く方法を双対上昇法と言います。更にこれを勾配法に基づいて解く場合は宇沢の方法といい、

\displaystyle {\bf \lambda}_k = {\bf \lambda}_{k - 1} - \alpha \nabla \phi({\bf \lambda}_k) \tag{10}

という更新式になります。で、この勾配を求めるためには、

\displaystyle \nabla f^*(-{\bf A}^T {\bf \lambda})= {\rm arg} \min_{\bf x}f(\bf x) + {\bf \lambda}_k^T {\bf A} {\bf x}\tag{11}

を計算する必要があります。結局のところこれは、式(7)にあるラグランジュ関数L({\bf x}, {\bf \lambda}_k){\bf x}に関する最小化です。つまり、主問題である{\bf x}に関する最適化と双対問題である{\bf \lambda}に関する最適化を交互に行うアルゴリズムとなります。
この双対上昇法では共役関数f^*の勾配を求めるため、共役関数が微分可能でなければ使えません。しかし、一般的に共役関数は微分可能ではなく、しかも関数fが強凸という性質*1を持っていなければアルゴリズムの収束も保証されません*2。そこで、次に説明する拡張ラグランジュの登場です。

拡張ラグランジュ関数法

まず式(6)の問題について、

\displaystyle \min_{{\bf x}} f({\bf x}) + \frac{\rho}{2} \| {\bf A}{\bf x}  - {\bf b} \|_2^2\ \   {\rm s.t.} \ \ {\bf A}{\bf x}  = {\bf b}  \tag{12}

という式に書きなおします。目的関数の第2項として\frac{\rho}{2} \| {\bf A}{\bf x}  - {\bf b} \|_2^2を付け加えましたが制約条件より0になるため、式(12)は式(6)と等価です。次に式(12)のラグランジュ関数は

\displaystyle L_{\rho}({\bf x}, {\bf \lambda})= f({\bf x}) + {\bf \lambda}^T ( {\bf A}{\bf x}  - {\bf b}) + \frac{\rho}{2} \| {\bf A}{\bf x}  - {\bf b} \|_2^2 \tag{13}

となります。これは拡張ラグランジュ関数と呼ばれる関数で、ラグランジュ関数に対し制約を満たさないことに対する罰則項\frac{\rho}{2} \| {\bf A}{\bf x}  - {\bf b} \|_2^2を加えたものとなっています。実はこの罰則項があることにより主問題は強凸となり、たとえfの共役関数f^*微分不可能だったとしても双対問題が滑らかになるため双対上昇法が適用できます(その理由については、参考文献をご参照ください)。このように、ラグランジュ関数に制約条件を満たさないことに対する罰則項\frac{\rho}{2} h({\bf x} )を加えることで問題を解きやすくした上で、 {\bf x} {\bf \lambda}の更新を交互していくのが拡張ラグランジュ関数法です。

ADMM再考

さて、式(1)は式(6)において変数が{\bf x}{\bf y}に分解可能という特別なケースです。式(1)を拡張ラグランジュ関数法で解こうとした時に{\bf x}{\bf y}に関する同時最適化が困難なことが多いです。そこで1回の更新で{\bf x}{\bf y}の同時最適化はあきらめ、個別に最適化していこうというのがADMMです。

変数間に絡みのある正則化項付き最適化問題のADMMによる解法

L1ノルム正則化は変数間が独立しており、proximal operatorが解析的に求まります。L2ノルム正則化(2乗しない)はルートを取る段階で変数間の絡みが出てきますがMoreau decompositionをうまく使ってproximal operator計算することができます(近接勾配法の記事参照)。p=2の重複なしのグループ正則化については、重複がないため各グループ独立したL2ノルムとしてproximal operatorとして計算できます。
しかし、上述のように変数間に絡みがあるとproximal operatorは容易には計算できません。例えば、重複ありのグループ正則化(p=2)の場合、同じ変数が複数のグループにまたがって現れるため、各グループ独立してproximal operatorを計算できなくなってしまいます。そこで、そのような正則化がつく最適化問題をうまく式(1)の形式かつg({\bf y})がproximal operatorが計算しやすい形に変形し、ADMMで解くことを考えます。重複ありグループ正則化を例にとって考えてみると、

\displaystyle \min_{{\bf x}} f({\bf x}) +C\sum_{g \in G}  \| {\bf x}_g \|_2 \tag{14}\ \ \ \ \ \ \ \ \ G=\left\{g_1, g_2, \cdots, g_k\right\}

という問題を解くことになるのですが、これを下記のような問題に変形します。

\displaystyle\min_{{\bf x}} f({\bf x}) +C \sum_{g' \in G'}  \| {\bf y}_{g'} \|_2   \ \   {\rm s.t.} \ \ {\bf y} = {\bf B}{\bf x}\tag{15}

ここで{\bf B}をうまく定義しG'がグループ間の重複がないようにします。具体的には{\bf B}_{g_1},{\bf B}_{g_2}, \cdots, {\bf B}_{g_k}


\displaystyle {\bf B}_{g_1} = \left[
    \begin{array}{ccccccc}
      1 & 0 & 0 & 0 & 0 & \ldots & 0 \\
      0 & 1 & 0 & 0 & 0 & \ldots &  0 \\
      \vdots & \vdots & \vdots & \vdots & \vdots & \ddots & \vdots \\
      0 & 0 & 0 & 0 & 0 & \ldots & 1 \\
   \end{array} 
    \right]
,\cdots,
\displaystyle {\bf B}_{g_k} =   \left[
   \begin{array}{ccccccc}
      1 & 0 & 0 & 0 & 0 & \ldots & 0 \\
      0 & 0 & 0 & 1 & 0 & \ldots &  0 \\
      \vdots & \vdots & \vdots & \vdots & \vdots & \ddots & \vdots \\
      0 & 0 & 0 & 0 & 0 & \ldots & 1 \\
    \end{array}
  \right]

などのように、それぞれ{\bf x}_g={\bf B}_g {\bf x}と左からかけると、グループgに所属する要素のみを抽出する行列とします。そして、{\bf B}=\left[{\bf B}_{g_1}^T, {\bf B}_{g_2}^T, \cdots, {\bf B}_{g_k}^T  \right]^Tとすると、式(15)より {\bf y}={\bf B}{\bf x}=\left[{\bf x}_{g_1}^T, {\bf x}_{g_2}^T, \cdots, {\bf x}_{g_k}^T\right]^T{\bf y}は各グループに対応する部分ベクトルを並べたベクトルとなります。こうすることで、正則化C \sum_{g' \in G'}  \| {\bf y}_{g'} \|_2は重複なしのグループ正則化になります。
ということで、この式(15)に対してADMMを適用することを考えます。この中で{\bf y}に関する更新式は、

\displaystyle \begin{eqnarray}
{\bf y}_k &\leftarrow& {\rm arg}\min_{\bf y} L_{\rho} ({\bf x}_k, {\bf y}, {\bf \lambda}_{k-1} ) \\
& = &   {\rm arg}\min_{\bf y}  C \sum_{g' \in G'}  \| {\bf y}_{g'} \|_2 + {\bf \lambda}_{k-1}^T ( {\bf y} - {\bf B}{\bf x}_k) + \frac{\rho}{2}\| {\bf y} - {\bf B}{\bf x}_k \|_2^2 \\
&=&  {\rm arg}\min_{\bf y}  \frac{1}{\rho} C \sum_{g' \in G'}  \| {\bf y}_{g'} \|_2 + \frac{1}{2} \|{\bf y} -({\bf B} {\bf x}_k + {\bf \lambda}_{k-1}/\rho )  \|_2^2 \\
&=& {\rm prox}_{ \psi/\rho}({\bf B} {\bf x}_k + {\bf \lambda}_{k-1}/\rho)
\end{eqnarray}
 \tag{16}

となります。ここで、 {\rm prox}_{ \psi}(\cdot)は、グループ正則化C \sum_{g' \in G'}  \| {\bf y}_{g'} \|_2のproximal operatorです(こちらの式(21)参照)。このように、元々の問題が正則化項のproximal operaorを求めるのが困難であり近接勾配法を利用できない場合でも、うまく式変形を行いADMMを適用すれば解けるようになることもあるため、ADMMはかなり強力なアルゴリズムだと言えます。

TV正則化によるノイズ除去への応用

冒頭でも述べたように、今回はこのADMMを応用して画像のノイズ除去を行ってみます。今、入力画像を{\bf v}として、この {\bf v}からのノイズ除去は、全変動正則化(TV正則化)と呼ばれる正則化\|\cdot\|_{TV}を用いて以下のような定式化が提案されています。

\displaystyle{\rm arg} \min_{\bf x} \|{\bf v} - {\bf x} \|_2^2 + C \| {\bf x}\|_{TV} \tag{17}

式(17)の第1項は出力画像{\bf x}がなるべく入力画像{\bf v}に近する働きがあります。第2項がTV正則化項で下記のような式となります。

\displaystyle \| {\bf x}\|_{TV} = \sum_{(i,j)}\sqrt{({\bf x}_{(i,j)} - {\bf x}_{(i + 1, j)})^2 +  ({\bf x}_{(i,j)} - {\bf x}_{(i, j + 1)})^2} \tag{18}

ここで(i,j)ピクセル位置を表します*3。基本的に画像は滑らかな性質を有しているはずなので、隣り合うピクセルの差は小さいはずです。そのため、このTV正則化は隣り合うピクセルの差が大きい場合の罰則項となり、画像を滑らかにしノイズ成分を除去する働きがあります。
さて、この式(17)をADMMで解くことを考えます。TV正則化項も重複ありグループ正則化と同じく変数間に絡みがあり、直接proximal operatorを計算するのは困難です。よって、重複ありグループ正則化を考えたときと同じく、{\bf y} = {\bf B}{\bf x}として、重複なしグループ正則化項に変換することを考えます。これは以下のように、


\displaystyle  {\bf B}_{g_1} = \left[
    \begin{array}{ccccccc}
      1 & -1 & \cdots & 0 & 0 &  \cdots & 0 \\
      1 & 0 & \cdots &  -1 & 0 &  \cdots &  0 
   \end{array}
    \right]
,
\displaystyle {\bf B}_{g_2} = \left[
    \begin{array}{ccccccc}
      0 & 1 & -1 & \cdots & 0 &  \cdots & 0 \\
      0 & 1 & 0&  \cdots & -1 &  \cdots &  0 
   \end{array}
    \right]
, \cdots

と、{\bf B}_g {\bf x} = \left[{\bf x}_{(i,j)} - {\bf x}_{(i + 1, j)},  {\bf x}_{(i,j)} - {\bf x}_{(i, j+1)}\right]となるように定義すれば、式(18)は{\bf B}_g {\bf x}のL2ノルムの和なので、 {\bf B}=\left[{\bf B}_{g_1}^T, {\bf B}_{g_2}^T, \cdots, {\bf B}_{g_k}^T  \right]^TとすればTV正則化を重複なしグループ正則化に変換することができます。よって、この{\bf B}を用いれば式(16)の更新式で{\bf y}を更新することができます。一方、{\bf x}についての更新式は、式(3)について{\bf x}微分して{\bf 0}となる点を求めればよいので解析的に求まり、以下のようになります。

\displaystyle \begin{eqnarray}
{\bf x}_k &\leftarrow& {\rm arg}\min_{\bf x} L_{\rho} ({\bf x}, {\bf y}_{k-1}, {\bf \lambda}_{k-1} ) \\
& = &  (2 {\bf I} + \rho {\bf B}^T {\bf B})^{-1} (2 {\bf v} - {\bf B}^T ({\bf  \lambda}_{k-1} - \rho  {\bf y}_{k-1})) 
\end{eqnarray}
\tag{19}

実装と実験

ノイズ除去アルゴリズムの実装

ということで、TV正則化によるノイズ除去を実装し、実験してみました。ここで、全てのピクセルを1つのベクトルとして扱ってしまうと画像サイズによってはメモリが足りなくなってしまうので、超解像を行ったときのように画像パッチを切り出し、それぞれのパッチに対してノイズ除去を行った後、そのパッチをつなぎ合わせて画像を再構築するという方法を取りました。なお、この画像切り出しやパッチのつなぎ合わせなどは、超解像の記事で記した「おれおれ画像・ベクトル・行列変換関数群」も利用してるので、そちらもご参照ください。

import os
import sys
import numpy as np
import scipy as sci
from scipy import sparse
from numpy.random import *
from PIL import Image

#ブロックソフト閾値計算
def block_soft_thresh(b, lam):
    return max(0, 1 - lam/np.linalg.norm(b, 2)) * b

#グループ正則化のproximal operator
def prox_group_norm(v, groups, gamma, lam):
    u = np.zeros(v.shape[0])
    for i in range(1, np.max(groups) + 1):
        u[groups == i] = block_soft_thresh(v[groups == i] , gamma * lam)
    return u

#グループ正則化計算
def gloup_l1_norm(x, groups, p):
    s = 0
    for i in range(1, np.max(groups) + 1):
        s += np.linalg.norm(x[groups == i], p)
    return s

#関数をメモ化する関数
def memoize(f):
    table = {}
    def func(*args):
        if not args in table:
            table[args] = f(*args)
        return table[args]
    return func

#ADMMを行う関数
#argmin_x, argmin_y, update_lam, objectiveはそれぞれ変数名に準ずる計算を行う関数
def ADMM(argmin_x, argmin_y, update_lam, p, objective, x_init, y_init, lam_init, tol= 1e-9):
    x = x_init
    y = y_init
    lam = lam_init
    result = objective(x)
    while 1:
        x_new = argmin_x(y, lam, p)
        y_new = argmin_y(x_new, lam, p)
        lam_new = update_lam(x_new, y_new, lam, p)
        result_new = objective(x_new)
        if result_new < tol or (np.abs(result - result_new)/np.abs(result) < tol) == True :
            break
        x = x_new
        y = y_new
        lam = lam_new
        result = result_new
    return x_new, result_new


#TV正則化付きのスムージング
#N,M:画像のピクセル行数、列数
#v:入力画像のベクトル
#B:正則化の変換行列(スパースなので、scipyのsparse行列を渡す)
#groups:変換後ベクトルの各要素の所属グループ
#C:正則化の係数
#p:拡張ラグランジュの係数
def TV_reg_smoothing(N, M, v, B, groups, C, p) :
    #inv(2I + pB^T B)を計算する関数。アルゴリズムによってはpが可変なので、pを引数として受け取る。
    #関数をメモ化して同一のpが入力された場合、再計算不要としている。
    inv_H = memoize(lambda p:np.array(np.linalg.inv( 2.0 * np.eye(B.shape[1], B.shape[1]) + p * (B.T * B))))

    argmin_x = lambda y, lam, p:np.dot(inv_H(p), 2.0 * v - np.array(B.T * (-p * y + lam)))
    argmin_y = lambda x, lam, p:prox_group_norm((B * x) + lam/p, groups, 1.0/p, C)
    update_lam = lambda x, y, lam, p:lam + p*((B * x) - y)
    objective = lambda x:np.linalg.norm(v - x, 2)**2 + C * gloup_l1_norm((B * x), groups, 2)

    x_init = np.random.randn(B.shape[1])
    y_init = (B * x_init)
    lam_init = np.zeros(B.shape[0])

    (x, result) = ADMM(argmin_x, argmin_y, update_lam, p, objective, x_init, y_init, lam_init, 1e-9)

    return x, result

#グループ変換行列を計算する関数
#N,M:画像のピクセル行数、列数
def calc_group_trans_mat(N, M):
    B = sci.sparse.lil_matrix((2 * (M - 1) * (N - 1) + (M - 1) + (N - 1), M * N))
    groups = np.zeros(B.shape[0], 'int')
    k = 0
    for i in range(N):
        for j in range(M):
            base = i * M + j
            if i < N -1 and j < M -1:
                B[k, base] = 1
                B[k, base + 1] = -1
                B[k + 1, base] = 1
                B[k + 1, base + M] = -1
                groups[k] = int(k/2)  + int(k % 2) + 1
                groups[k + 1] = int(k/2)  + int(k % 2) + 1
                k += 2
            #一番下の行のピクセルは右隣のピクセルとの差分のみ計算
            elif i >= N - 1 and j < M - 1:
                B[k, base] = 1
                B[k, base + 1] = -1
                groups[k] = int(k/2) + int(k % 2) + 1
                k += 1
            #一番右の劣のピクセルは下のピクセルとの差分のみ計算
            elif i < N - 1 and j >= M - 1:
                B[k, base] = 1
                B[k, base + M] = -1
                groups[k] = int(k/2) + int(k % 2) + 1
                k += 1
    return B, groups

#グレースケール画像に対するノイズ除去
#img:画像データ(PILのimageクラス)
#patch_hight:切り出す画像パッチの高さ
#patch_width:切り出す画像パッチの幅
#shift_num:画像を切り出す際のずらし量(重複して切り出してもOK)
#C:正則化の係数
#p:拡張ラグランジュの係数
def denoise_gray_img(img, patch_hight, patch_width, shift_num, C, p):
    #グループ変換行列の計算
    [B, groups] = calc_group_trans_mat(patch_hight, patch_width)

    #画像パッチの切り出し
    patchs = gray_im2patch_vec(img, patch_hight, patch_width, shift_num)

    new_patchs = np.zeros((patch_hight * patch_width, patchs.shape[1]))
    for i in range(patchs.shape[1]):
        #各パッチに対してノイズ除去を施す
        new_patchs[:, i] = TV_reg_smoothing(patch_hight, patch_width, patchs[:, i], B, groups, C, p)[0]

    #パッチをつなぎ合わせて画像の再構成
    [new_img, img_arr] = patch_vecs2gray_im(new_patchs, img.size[0], img.size[1], patch_hight, patch_width, shift_num)
    return new_img

細かい点はソースコードとそのコメントをご参照ください。とりあえず、denoise_gray_imgにPILのimageクラスのオブジェクトと切り出すパッチのサイズ、パッチ切り出し時のずらし量を渡せばノイズ除去されたimageクラスのオブジェクトを返してくれるように実装しています。超解像の時と同様、パッチを切り出す際に領域の重複を許し、再構成時に重ね合わせて平均を取るようにしています。また、画像の最後の行および列は右隣のピクセル(および下のピクセル)とのみ差分を取るようにBを工夫しています(calc_group_trans_mat関数内)。
注意する点としては、TV_reg_smoothingが受け取るBはスパースな行列なので、scipyのsparse行列のクラス渡します。また、同関数内にあるinv_Hはpを引数にとり、式(19)の逆行列部分を計算する関数ですが、同一のpに対して再度同じ計算を行うと時間がかかるので、memoizeによりメモ化し同じpが入力されたときに再計算しないようにしています。
なお、更新式はADMM関数の中に直接記述するのではなく、ADMM関数はそれぞれの更新式の計算を行う関数を引数として受け取るようにし、目的関数の出力が収束するまでそれらの更新式関数を呼び出すという設計にしているため、更新式の定義もTV_reg_smoothing関数の中で行っています。

実験と結果

いつものように256×256のレナさん画像で実験を行います。
f:id:YamagenSakam:20180630185527p:plain

これに対し、標準偏差30のガウスノイズを乗せた画像に対してノイズ除去を行ってみました。

def main():
    test_img = Image.open("ADMM_input.png")
    img_arr = np.array(test_img.convert('L'),'float')
    img_arr = (img_arr  + 30 * (np.random.randn(test_img.size[0], test_img.size[1]) ))
    img_arr[img_arr >255] = 255
    img_arr[img_arr <0] = 0
    test_img = Image.fromarray(np.uint8(img_arr))
    patch_hight = 20
    patch_width = 20
    shift_num = 10
    C = 50
    p = 20
    img2 = denoise_gray_img(test_img, patch_hight, patch_width, shift_num, C, p)
    img2.save("ADMM_result.png")

画像パッチのサイズを20×20ずらし量10として切り出し、C=50、p=20で実験を行った結果は以下のようになります。


f:id:YamagenSakam:20180701123722p:plain f:id:YamagenSakam:20180701123739p:plain 
左が入力画像(標準偏差30のガウスノイズを重畳)、右がノイズ除去後の画像

元画像レベルとまではいきませんが、それなりにノイズ除去ができているかと思います。ただ、パッチを切り貼りしている影響か、少しゆがんでしまっているようにも見えます。更に追実験として入力画像のノイズレベルを標準偏差50と上げ、その時のノイズ除去もやってみました(C=80、p=1で実験)。


f:id:YamagenSakam:20180630163116p:plain f:id:YamagenSakam:20180630163146p:plain
左が入力画像(標準偏差50のガウスノイズを重畳)、右がノイズ除去後の画像

ノイズレベルの割にはある程度ノイズ除去できているかと思いますが、TV正則化項の罰則に引っ張られ画像がノッペリとしてしまいました。Cをもっと調整すればもう少し改善できるかも知れません。

まとめ

今回は近接勾配法よりも幅広い問題を解くことができるADMMについて説明し、それを用いて画像のノイズ除去をやってみました。今回紹介した重複ありグループ正則化やTV正則化以外にも、特徴間のグラフ構造や階層構造なんかを表す正則化項がついた最適化問題もADMMを使えば解けるようになり、加えて別途制約条件がある場合でも式変形をうまくやればADMMで解けるようになるケースもあるなど応用範囲が広い手法です。更に収束も割りと速く、\rhoの値も適当に決めても0より大きければ大体収束するので、使い勝手もなかなか良いです。

参考文献

今回はこちらの書籍の主に12章と15章を参考にしました。

機械学習のための連続最適化 (機械学習プロフェッショナルシリーズ)

機械学習のための連続最適化 (機械学習プロフェッショナルシリーズ)

*1:関数f({\bf x})があって、f({\bf x}) - \frac{\mu}{2}\| {\bf x}\|が凸関数ならば、関数f\mu強凸であるといいます。fが凸関数でなくとも、強凸である可能性はあります。

*2:逆に強凸であれば共役関数は微分可能でありアルゴリズムの収束が保証されます。

*3:便宜上、(i,j)と2次元の添え字で表していて{\bf x}が行列データのように見えますが、実際は画像のピクセル値を直列に連結したベクトルとして扱っています。

近接勾配法応用編その3~トレースノルム正則化項付きロジスティック回帰による行列データの分類~

はじめに

前回の記事では、ベクトルデータの分類問題に対するスパースなロジスティック回帰を説明しましたが、今回はその拡張となる行列データに対するロジスティック回帰を見ていきます。この行列データ分類問題においても正則化項は重要な役割を持ちますが、中でもトレースノルムの正則化項を用いることで低ランクな解が得られ、多次元の時系列データの分類などに有効であることを見ていきます。

行列データの分類

まず、学習データ集合\left\{{\bf X}_i, y_i \right\}(i=1⋯N)があるとします。ここで、{\bf X}D \times Tの行列、 yy \in \left\{0,1 \right\}のラベルです。この学習データ集合から係数行列{\bf W}とバイアス項bを学習し、以下に示すような線形モデルで分類することが今回のテーマです。

\displaystyle a({\bf X}) = {\rm Tr}({\bf X}^T {\bf W}) + b \tag{1}

行列データ分類問題の特徴

では、行列データを直接扱うことにどのようなメリットがあるのでしょうか。確かに、行列データ{\bf X}をベクトルに変換して(列ベクトルを縦に連結するなどして)、学習・分類するという方法も考えられます。しかしベクトルにしてしまうと行列構造がなくなり、行と列の意味合いが異なる場合に重要な情報が失われる可能性があります。
例えば、複数のセンサーから取得される多次元の時系列データの場合、行がセンサー、列が時間に対応する行列データとしてとらえることができます。このようなデータを連結してベクトルとして扱ってしまうと、空間方向と時間方向の特徴を一緒くたにしてしまい、重要な情報が失われてしまう可能性があります。特にこのような複数センサーのデータの場合、実際に観測できるデータは{\bf X} = {\bf A}{\bf Z}のように、潜在変数(信号源){\bf Z}の空間的な線形結合で表されることがしばしばあり、クラス分類において有用な情報はこの潜在変数に隠れていることも多いです(下図参照)。


f:id:YamagenSakam:20180613220313p:plain

一方で、行列データとして直接扱うことの意味は、{\bf W} =\sum_{i = 1}^{r} \sigma_i {\bf u}_i {{\bf v}_i}^T 特異値分解して式(1)を以下のように変形すると見えてきます。

\displaystyle a({\bf X}) = \sum_{i=1}^{r}(\sigma_i {{\bf u}_i}^T {\bf X} {\bf v}_i) + b \tag{2}

このようにすると、{\bf u}は空間方向の特徴を抽出する空間フィルタ、{\bf v}は時間方向の特徴を抽出する時間フィルタとして考えることができ、時空間両方の特徴を捉えられることがわかります。上述のように観測されるデータが潜在変数の空間的な線形結合である場合にも、観測データに空間フィルタをかけることでクラス分類に有用な潜在的な情報を抽出した上で時間方向に係数をかけるので、時空間両方の特徴をとらえることができます。

行列データロジスティック回帰の目的関数

行列データの分類問題においてもロジスティック回帰は使えます。具体的には下記の誤差関数を最小化する{\bf W}bを学習データ集合から求めればよいです。
\displaystyle E({\bf W}, b) =-\sum_{i=1}^{N} \left\{y_i \ln \sigma({\rm Tr}({{\bf X}_i}^T {\bf W}) + b ) + (1 - y_i) \ln (1 - \sigma({\rm Tr}({{\bf X}_i}^T {\bf W}) + b )) \right\} \tag{3}

トレースノルム正則化による低ランク化

さて、式(2)にあるように係数行列{\bf W}のランクrが高いほど、かけ合わせる時空間フィルタの個数は増えます。もし、{\bf W}がフルランクならr = \min(D,T)個の時空間フィルタをかけて重み付き和を取ることになります。しかし、上述のような潜在変数の中でもクラス分類に有用な情報は、(空間的に)ごく一部の成分でありそれ以外はノイズというケースも多くあります。例えば、上の図の例だと、3つの信号源のうちクラス分類に有用な成分は1つです。この場合に{\bf W}がフルランクだと必要以上に複雑なモデルとなってしまい過学習となる恐れがあります。
ということで、{\bf W}を低ランクにしたいのですが、そのためにはこちらの記事で少し述べたトレースノルムの正則化 \|{\bf W} \|_{\rm trace}= \sqrt{{\rm Tr}({\bf W}^T {\bf W})}をつけると特異値がスパースになり、結果低ランクな解が得られます。具体的な式としては、
\displaystyle F({\bf W}, b) =E({\bf W}, b)  +  \|{\bf W} \|_{\rm trace} \tag{4}
であり、これを近接勾配法で解いていきます。なお、今回はバイアス項b正則化の中に含めていないことに注意してください*1b正則化項に含まれないので、{\bf W}は近接勾配法の更新式、bは通常の勾配法の更新式で更新していきます。それぞれの更新式で必要な式(3)の{\bf W}による微分および式(4)のbによる微分は、
\displaystyle \nabla E({\bf W}) = \sum_{i=1}^{N} (\sigma({\rm Tr}({{\bf X}_i}^T {\bf W})  + b) - y_i) {\bf X}_i \tag{5}
\displaystyle \frac{ dF}{db} = \sum_{i=1}^{N} (\sigma({\rm Tr}({{\bf X}_i}^T {\bf W})  + b) - y_i)  \tag{6}
となるので、後はこれを元に{\bf W}bの更新を交互に行えばよいです。

実装と実験

アルゴリズムの実装

今回の行列データロジスティック回帰は下記のように実装しました。基本的に今までと同じですが、proximal operatorはトレースノルム正則化のproximal operator計算を行うようにしており、近接勾配法の関数はバイアス項の更新(勾配法計算)も行うように変更しています。

import numpy as np
from numpy.random import *
import matplotlib.pyplot as plt
import sys

#ソフトしきい値作用素の計算
def soft_thresh(b, lam):
    x_hat = np.zeros(b.shape[0])
    x_hat[b >= lam] = b[b >= lam] - lam
    x_hat[b <= -lam] = b[b <= -lam] + lam
    return x_hat

#トレースノルム正則化のproximal operatorの計算
def prox_trace_norm(V, gamma, lam):
    [L, sig, R] = np.linalg.svd(V)
    sig_ = soft_thresh(sig, lam * gamma)
    if L.shape[1] > sig.shape[0]:
        L = L[:,:sig.shape[0]]
    if R.shape[1] > sig.shape[0]:
        R = R[:sig.shape[0], :]
    return np.dot(np.dot(L, np.diag(sig_)), R)

#近接勾配法とバイアス項の勾配法計算
def proximal_gradient_with_bias(grad_f, grad_b, prox, gamma, objective, init_x, init_b, tol = 1e-9):
    x = init_x
    b = init_b
    result = sys.maxint
    while 1:
        x_new = prox(x - gamma * grad_f(x, b), gamma)
        b_new = b - gamma * grad_b(x, b)
        result_new = objective(x_new, b_new)
        if result_new < tol or (np.abs(result - result_new)/np.abs(result) < tol) == True :
            break;
        x = x_new
        b = b_new
        result = result_new
    return x_new, b_new, result_new

#トレースノルム正則化項付きロジスティック回帰
#X:学習用データベクトルを列ベクトルとして並べた行列
#y:ラベルを並べたベクトル
def trace_norm_LogisticRegression(X, y, lam):
    sigma = lambda a : 1.0/(1.0+np.exp(-a))
    p=lambda Z, W, b:sigma(np.trace(np.dot(Z.T, W).T) + b) #メモ化したい・・・
    objective = lambda W, b: - np.sum( y * np.log(p(X, W, b)) +
                        (1 - y) *np.log(( 1 - p(X, W, b)))) + lam * np.linalg.norm(W, 'nuc')
    grad_E = lambda W, b: np.dot(X, p(X, W, b) - y) #勾配
    grad_b = lambda W, b: np.sum(p(X, W, b) - y)

    (u, l, v) = np.linalg.svd(X)
    Gamma =1.0/max(np.average(l.real*l.real,0)) 
    prox = lambda V, gamma:prox_trace_norm(V, gamma, lam) #proximal operator
    W_init = np.zeros((X.shape[0], X.shape[1]))
    b_init = 0
    (W, b, result) = proximal_gradient_with_bias(grad_E, grad_b, prox,
                                                    Gamma, objective, W_init, b_init, 1e-4)
    return W, b, result

なお、停止判定の方法については収束性を考慮して元の方法に戻しています。また、p(Z, W, b)の計算がかなり重いですが1回の更新で4回計算しているので、pをメモ化するなどの工夫を行えばもう少し速くできます(pの引数ZとWはnumpyのarray型なので簡単に辞書のキーにできず断念)。
あと、学習のステップ幅は今回かなりテキトーに決めた実装になっているので、その辺りはもう少し検討が必要です。

実験用疑似データ

実験用疑似のデータは下記のコードで生成しました。

    #学習データ
    #30×100の行列データ負例500個、正例500個
    #信号源のデータZ
    Z_train1=np.random.randn(30,100,500)
    Z_train2=np.random.randn(30,100,500)
    #30個の信号源のうち2個のみクラス分類に有用とする
    Z_train2[0, 20:50, :] += 0.2
    Z_train2[29, 80:90, :] -= 0.5
    Z_train = concatenate((Z_train1,Z_train2),2)

    #信号源データZを観測データに変換する行列A
    A = np.exp(-np.abs(np.tile(np.linspace(0, 1,30), (30,1)) -
            np.tile(np.linspace(0, 1,30), (30,1)).T)/0.1) + np.random.randn(30, 30)

    #実際に観測されるデータX
    X_train = np.dot(Z_train.T, A).T

    y_train1 = ones((500,1))
    y_train2 = zeros((500,1))
    y_train = concatenate((y_train1,y_train2),0).squeeze()

    #テストデータ
    #学習データと同じく30×100の行列データ負例500個、正例500個
    Z_test1=np.random.randn(30,100,500)
    Z_test2=np.random.randn(30,100,500)
    Z_test2[0, 20:50, :] += 0.2
    Z_test2[29, 80:90, :] -= 0.5
    Z_test = concatenate((Z_test1,Z_test2),2)

    X_test = np.dot(Z_test.T, A).T

    y_test1=ones((500,1))
    y_test2=zeros((500,1))
    y_test=concatenate((y_test1,y_test2),0).squeeze()

学習、テストに使用するデータはX_train、X_testですが、上述の信号源からの線形結合を模擬するため、まず潜在変数(信号源)Z_train、Z_testの生成を行いました。この信号源のデータは30成分ありますが、0番目の成分と29番目の成分のみがクラス分類に有用となるようなデータとしています。これを行列Aで線形結合して観測データXを生成しています。なお、行列Aは近くの信号程大きな重み、離れるほど小さな重みで足し合わせるような線形結合としており、更にこのAにもランダム成分を加えています。以下、正例・負例におけるXとZの平均波形を表示します。

  • y=1に対応する観測データ集合X_train1の平均波形

f:id:YamagenSakam:20180613230346p:plain

  • y=0に対応する観測データ集合X_train2の平均波形

f:id:YamagenSakam:20180613230616p:plain

  • y=1に対応する信号源データ集合Z_train1の平均波形

f:id:YamagenSakam:20180613230428p:plain

  • y=0に対応する信号源データ集合Z_train2の平均波形

f:id:YamagenSakam:20180613230713p:plain


やはり、観測データはy=0クラスとy=1クラスの平均データを見比べても、ノイズにまみれて差異があまり分からなくなっていますが、信号源データの場合は、2信号のみy=0クラスとy=1クラスで差異があることがはっきりわかります。

実験結果

このように生成した疑似データに対し、以下のコードのように\lambda=0, 50, 100 \cdots, 2000と変化させて、その時の推定精度とWのランクを求めてみました。

    num = 41
    scale = 50
    rank_of_W  = np.zeros(num)
    acc_arr = np.zeros(num)
    for i in range(0, num):
        (W, b, r) = trace_norm_LogisticRegression(X_train, y_train.squeeze(), i * scale)
        y_result = np.trace(np.dot(X_test.T, W).T) + b
        y_result[y_result>=0] = 1
        y_result[y_result<0] = 0
        rank_of_W[i] = np.linalg.matrix_rank(W)
        acc_arr[i] = np.sum(y_result == y_test)/float(y_test.shape[0])


その結果が以下の図です(赤が推定精度、青がランクです)。

f:id:YamagenSakam:20180613231839p:plain

L1ノルムロジスティック回帰と同様、\lambdaを大きくしすぎても、制約がきつくなりすぎるのか精度が徐々に下がっていきます。実際今回の条件で推定精度が最大になったのは\lambda = 1250のときで、精度が68.0%、{\bf W}のランクが4でした。以下には、この\lambda = 1250での{\bf W}特異値分解して得られる空間フィルタ{\bf U}を、y=0クラスのX_trainとy=1クラスのX_trainの平均行列に施した波形を示します。

  • y=0に対応する観測データ集合X_train1の平均に{\bf U}をかけて得られる波形

f:id:YamagenSakam:20180613232012p:plain

  • y=1に対応する観測データ集合X_train2の平均に{\bf U}をかけて得られる波形

f:id:YamagenSakam:20180613231951p:plain

この図の中で青の波形は最大特異値に対応する空間フィルタで抽出された波形、緑の波形は2番目に大きい特異値に対応する空間フィルタで抽出された波形です。この両者の注目すると、空間フィルタを施すことによって信号源に潜在しているクラス分類に有用な成分を抽出できていることがわかります。

まとめ

今回は行列データの分類アルゴリズムについて、ロジスティック回帰を例にとってトレースノルム正則化項によって低ランクな解が得られること、多次元の時系列データの分類問題などに有効であることを見てきました。トレースノルムの正則化は信号源の中でも一部の成分のみがクラス分類に有用であるとき特に威力を発揮します。
なお、時系列データの分類には今回紹介した方法以外にも隠れマルコフモデルを使った手法やリカレントニューラルネットワークの一種であるLong Short-Term Memory(LSTM)を使った手法もあります。特に最近は深層学習の流れもあってLSTMが主流になってきているとのことなので、そろそろニューラルネット辺りにも手を出していきたいところです。

参考文献

以前も紹介したこちらの本の第8章が非常に参考になります。

スパース性に基づく機械学習 (機械学習プロフェッショナルシリーズ)

スパース性に基づく機械学習 (機械学習プロフェッショナルシリーズ)

あと、同著者によるこちらのスライドもわかりやすいです。
行列およびテンソルデータに対する機械学習

*1:どうやら、前回のベクトルデータの分類においても、基本的にバイアス項は正則化には含めない方が良いようです

近接勾配法応用編その2 ~L1ノルム正則化項によるスパースなロジスティック回帰~

今回は前々回の記事で書いた近接勾配法の応用例第2段ということで、分類問題などで使われるロジスティック回帰の目的関数にL1ノルムの正則化項をつけて、スパースな解を導いてみます。スパースな解を導出するメリットとしては、クラス分類に寄与しない無駄な属性の係数が自動的に0となり、クラス間で有意差がある属性を見つけ出せ、かつ、分類モデルの複雑度も下げられるので過学習の防止も期待できるというメリットがあります。なお、ロジスティック回帰は多クラス分類にも利用できますが、今回は2クラス分類の例で説明します。

ロジスティック回帰

ロジスティック回帰について詳しいことはたくさん文献が出ているのでそちらを参照いただくとして、ここでは概略と最終的な目的関数を示します。
まず、学習データ集合 \left\{ {\bf x}_i, y_i \right\} (i = 1 \cdots N)があるとします。ここで、{\bf x}d次元のデータベクトル、yy \in \left\{0, 1\right\}のラベルです。
ここで、{\bf x}が与えられてy = 1となる事後確率は
\displaystyle p(y = 1| {\bf x}) = \frac{p(y=1, {\bf x})}{p(y=1, {\bf x}) + p(y = 0, {\bf x})} =  \frac{1}{1 + \exp(-a) }= \sigma(a) \tag{1}
と表せます。ここで、a
\displaystyle a= \ln{\frac{p(y=1, {\bf x})}{p(y=0, {\bf x})}} \tag{2}
という式であり、ロジット関数とも呼ばれます。{\bf x}がクラスy = 1に所属する確率が大きければこのロジット関数は大きな値を取るので事後確率\sigma(a)は1に近づきますし、逆にクラスy=0に所属する確率が大きければロジット関数は小さな値を取るので\sigma(a)は0(つまり、クラスy=0となる事後確率1-\sigma(a)は1)に近づきます。
そしてこのa{\bf x}の線形関数としてモデル化したのがロジスティック回帰です。具体的には、


 \displaystyle a = {\bf w}^T {\bf x} + w_0 = \tilde{{\bf w}}^T \tilde {{\bf x}} \tag{3}
ただし、 \tilde{{\bf w}}= \left[ w_0 , {\bf w} \right],  \tilde{{ \bf x}}= \left[1,  {\bf x}\right] *1
となります。
さて、これまでの内容を踏まえて、学習データ集合に対する尤度関数を考えます。上記の議論より、{\bf x}_i\tilde{{\bf w}}が与えられて、そのクラスがy_i = 1である尤度は\sigma(a) = \sigma(\tilde{{\bf w}}^T \tilde{{\bf x}}_i)です。一方で、y_i = 0である尤度は (1-\sigma(\tilde{{\bf w}}^T \tilde{{\bf x}}_i))です。このように考えると尤度関数は、
 \displaystyle p({\bf y}| \tilde{{\bf w}})= \prod_{i=1}^{N}  \sigma(\tilde{{\bf w}}^T \tilde{{\bf x}}_i)^{y_i}  (1 - \sigma(\tilde{{\bf w}}^T \tilde{{\bf x}}_i))^{(1- y_i)} \tag{4}
となります。後は式(4)を最大化する\tilde{{\bf w}}を求めればよいです。ただ、この式(4)のままだと計算しにくいので、以下のような式(4)の負の対数を取った誤差関数を最小化することを考えます。
\displaystyle E(\tilde{{\bf w}}) = - \ln { p({\bf y}| \tilde{{\bf w}}) } = -\sum_{i=1}^{N} \left\{  {y_i} \ln{\sigma(\tilde{{\bf w}}^T \tilde{{\bf x}}_i)} + (1- y_i)\ln{(1 - \sigma(\tilde{{\bf w}}^T \tilde{{\bf x}}_i))} \right\} \tag{5}
これがロジスティック回帰の目的関数です。

正則化項によるスパース化と近接勾配法の適用

ここまでで普通のロジスティック回帰の目的関数が得られました。このロジスティック回帰をスパースな解にするためには、L1ノルム正則化項をつけた式
 \displaystyle F( \tilde{{\bf w}})= -\sum_{i=1}^{N} \left\{  {y_i} \ln{\sigma(\tilde{{\bf w}}^T \tilde{{\bf x}}_i)} + (1- y_i)\ln{(1 - \sigma(\tilde{{\bf w}}^T \tilde{{\bf x}}_i))} \right\} + \| \tilde{{\bf w}} \|_1 \tag{6}
を近接勾配法により最小化すればよいです。近接勾配法を解くためには、元の目的関数E(\tilde{{\bf w}})の勾配と正則化\| \tilde{{\bf w}} \|_1のproximal operatorが必要でした。
正則化項のproximal operatorは以前の記事で導出したとおりです。また、E(\tilde{{\bf w}})の勾配は、
\displaystyle \nabla E(\tilde{{\bf w}}) = \sum_{i=1}^{N} (\sigma(\tilde{{\bf w}}^T \tilde{{\bf x}}_i) - y_i) \tilde{{\bf x}}_i \tag{7}
となります。後は、これらに基づき近接勾配法を適用するだけです。

実装と実験

これまでの内容pyhonコードで実装すると下記のようになります。なお、近接勾配法のアルゴリズム部分についても、前々回記事で記したコードから停止条件部分を修正したので、改めて近接勾配法の関数も書きます(以前の記事から変更していない関数については下記に記していません)。

import numpy as np
from numpy.random import *
import matplotlib.pyplot as plt

#近接勾配法
def proximal_gradient(grad_f, prox, gamma, objective, init_x, tol = 1e-9):
    x = init_x
    result = objective(x)
    while 1:
        x_new = prox(x - gamma * grad_f(x), gamma)
        result_new = objective(x_new)
        #停止条件を説明変数差分のノルムに変更
        #if result_new < tol or (np.abs(result - result_new)/np.abs(result) < tol) == True :
        if np.linalg.norm(x_new - x, 2) < tol:
            break;
        x = x_new
        result = result_new
    return x_new, result_new

#l1ノルム正則化ロジスティック回帰
#X:学習用データベクトルを列ベクトルとして並べた行列
#y:ラベルを並べたベクトル
def l1norm_LogisticRegression(X, y, lam):
    sigma = lambda a : 1.0/(1.0+np.exp(-a))
    p=lambda Z, w:sigma(np.dot(Z.T, w))  # ※1下記に補足あり
    X=concatenate((ones((1,size(X,1))), X), 0) #バイアス項の追加
    objective = lambda w: - np.sum( y * np.log(p(X, w)) +
                        (1 - y) *np.log(( 1 - p(X, w)))) + lam * np.sum(np.abs(w))
    grad_E = lambda w: np.dot(X, p(X, w) - y) #※2下記補足あり
    (u, l, v) = np.linalg.svd(X)
    gamma = 1.0/max(l.real*l.real)
    prox = lambda v, gamma:prox_norm1(v, gamma, lam) #proximal operator
    w_init = np.zeros(X.shape[0])
    (w_hat, result) = proximal_gradient(grad_E, prox, gamma, objective, w_init, 1e-12)
    return w_hat, result

繰り返し構文を避けるため、行列・ベクトル演算で勾配計算などを行っています。これを少し補足しておくと、まずソースコード※1に当たる部分は、p({\bf Z}, {\bf w}) = \left[\sigma({\bf w}^T {\bf z}_1) , \cdots, \sigma({\bf w}^T {\bf z}_N) \right]^Tというベクトルを返す関数として定義しています。そして、※2の部分では、
\displaystyle{\bf X}(p({\bf X}, {\bf w}) - {\bf y}) = \left[ {\bf x}_1, \cdots, {\bf x}_N \right] \left[\sigma({\bf w}^T {\bf x}_1)  - y_1, \cdots, \sigma({\bf w}^T {\bf x}_N)  - y_N  \right]^T= \sum_{i=1}^{N} {\bf x}_i (\sigma({\bf w}^T {\bf x}_i)  - y_i)
というように、行列・ベクトルの演算で式(7)の勾配計算を行う関数を定義しています。
このL1ノルム正則化つきロジスティック回帰を疑似データで実験してみました。疑似データは以下のようなコードで生成しました。

def main():
    #学習データ
    #最初の2次元だけ分類に寄与する属性
    X_train1=np.random.randn(2,200)
    X_train2=np.random.randn(2,200) + ones((2,200))
    X_train = concatenate((X_train1,X_train2),1)
    #残り123次元は分類に寄与しないランダムな属性
    X_train = concatenate((X_train, np.random.randn(123, 400)), 0)
    y_train1 = ones((200,1))
    y_train2 = zeros((200,1))
    y_train = concatenate((y_train1,y_train2),0).squeeze()

    #テストデータ
    X_test1=np.random.randn(2,200)
    X_test2=np.random.randn(2,200)+ones((2,200))
    X_test = concatenate((X_test1,X_test2),1)
    X_test = concatenate((X_test, np.random.randn(123, 400)), 0)
    y_test1=ones((200,1))
    y_test2=zeros((200,1))
    y_test=concatenate((y_test1,y_test2),0).squeeze()

疑似データの次元は125次元ですが、その内最初の2次元だけクラス分類において意味があり、それ以外の123次元は全くのランダムでノイズとなる属性です。また、y=0のクラスにのみオフセットをはかせている形なので、バイアス成分もクラス分類に寄与します。そのため、今回のアルゴリズムを適用すると、最初の2次元の係数およびバイアス項のみが非0となり、それ以外の係数は0になることが期待されます。
このように生成した疑似データに対し、以下のように\lambda=0, 1, \cdots, 39と変化させて、その時の推定精度とwの非0要素の数を求めてみました。

    non_zero_cnt_arr = np.zeros(40)
    acc_arr = np.zeros(40)
    for lam in range(0, 40):
        (w, r) = l1norm_LogisticRegression(X_train, y_train.squeeze(), lam)
        y_result = np.dot(w[1:126], X_test) + w[0]
        y_result[y_result>=0] = 1
        y_result[y_result<0] = 0
        non_zero_cnt_arr[lam] = np.sum(np.abs(w)>1e-12)
        acc_arr[lam] = np.sum(y_result == y_test)/float(y_test.shape[0])

    fig, ax1 = plt.subplots()
    ax1.plot(acc_arr, "r.-")
    ax1.set_ylabel("accuracy")
    ax2=plt.twinx()
    ax2.plot(non_zero_cnt_arr, "bx-")
    ax2.set_ylabel("non zero element count")
    plt.legend()
    plt.show()

その結果は下記のようになりました(赤線が推定精度、青線が係数の非0要素数です)。
f:id:YamagenSakam:20180603015719p:plain

\lambda = 0のときは、すべての要素が非0になりノイズである属性も判別に用いてしまっているので精度が60.5%とあまり良くありません。そこから\lambdaを大きくしていくと非0要素数は減少していくことがわかります。一方で推定精度については、適切な\lambdaであれば精度が上がっていますが、逆に\lambdaが大きすぎると制約がきつくなりすぎるためか精度は下がっていきます。面白いのは、非0要素数が3のとき(クラス分類に寄与する2次元 + バイアス項のみが非0のとき)、推定精度がピークとなると予想していましたが、予想に反して\lambda=15で非0要素数8のときに推定精度が最大(75.5%)になる結果が得られました*2。これは、適切にクラス分類に寄与する属性を見つけられるまで\lambdaを上げると、ロジスティック回帰の制約条件としては厳しくなりすぎてしまい結果、精度の低下を招くのだと考えられます。そのため、クラス分類問題とどの属性に有意差があるか見つける問題とは一緒くたにして解くのではなく、分けて解くべきなのだと思います。

まとめ

今回はロジスティック回帰にL1ノルム正則化項をつけることによりスパースな解が得られ、分類器としての精度も向上させられることを見てきました。また、分類に寄与する属性もこのL1ノルム正則化を使うことで見つけられますが、だからと言って制約条件が厳しすぎる\lambdaを用いると分類精度低下を招くので注意が必要です。一旦、変数選択の目的でL1ノルム正則化項付きのロジスティック回帰を行って、そこで係数が非0となった要素だけを使って今度は制約なしのロジスティック回帰を再学習させる方法がよいのかも知れません。なお、今回はL1ノルムの正則化を試しましたが、もちろんL1ノルム以外にもグループL1ノルムやL1/L2の混合ノルムも適用できます。
今後の予定としては、もうひとつ近接勾配法の応用として行列データの分類について紹介したり、Alternating Direction Method of Multipliers (ADMM)を調べて書いたり、ICML2018の論文(そろそろ公開?)を読んだりなんかをしたいと思います。

*1:バイアス項w_0{\bf w}に統合し、d + 1次元のベクトルとして扱います。

*2:ランダムな疑似データなので試行毎に結果は変わりますが、傾向としては毎回同じで非0要素数が3よりも多い時点で推定精度のピークを迎えます。