のんびり読書日記

日々の記録をつらつらと

COP-KMEANS(Constrained K-means)を試してみた

制約付きクラスタリング・半教師ありクラスタリングは、クラスタリングをする際に制約を与えることで精度を向上させる手法です。制約は2つのデータ間の関係を定義した、以下がよく使われるようです。

  • must-link
  • cannot-link

今回はとりあえず制約付きクラスタリングの論文で多く引用されていて、以下の論文を参考に実装してみました。手法がK-meansを少し改良しただけで一番簡単そうだったのと、最新の動向まで調べきれなかったので、まずはとっかかりとして。

  • "Constrained K-means Clustering with Background Knowledge", by Kiri Wagstaff, Claire Cardie, Seth Rogers, and Stefan Schroedl. ICML 2001
  • Java Appletのデモ

上記の論文ではCOP-KMEANS(Costrained K-means)という手法を提案しています。手法はK-meansを改良したもので、各データとクラスタ中心との距離を測定して一番中心が近いクラスタに割り当てるときにmust-link, cannot-link制約をチェックして、制約を満たすものの内で一番近いクラスタに割り当てるようにしています。制約を満たすクラスタが存在しなかった場合は、クラスタリングを中止します。

作成したソースコードgithubに置いてあります。いい加減ブログにべた貼りは止めていこうかなと^^;

実際に使用するときは以下のようにしてください。

% cat /path/data.tsv  # 入力データ
1       a       2       b       2       c       2       d       -1      e       -1      f       -1
2       a       2       b       -1      c       2       d       -1      e       -1      f       -1
3       a       2       b       2       c       -1      d       -1      e       -1      f       -1
4       a       -1      b       -1      c       -1      d       2       e       2       f       2
5       a       -1      b       -1      c       -1      d       2       e       -1      f       2
6       a       -1      b       -1      c       -1      d       2       e       2       f       -1

% cat /path/constraint.tsv  # 制約データ
1       4       m           # ID1  ID2  m(ust),c(annot)
2       3       c

% g++ cop_kmeans.cc -O3 -o cop_kmeans
% cop_kmeans 2 /path/data.tsv
kmeans loop No.0 ...
kmeans loop No.1 ...
1       0
2       0
3       0
4       1
5       1
6       1
% cop_kmeans 2 /path/data.tsv /path/constraint.tsv
kmeans loop No.0 ...
kmeans loop No.1 ...
1       1
2       1
3       0
4       1
5       0
6       1

(見やすいように出力結果を一部並べ替えてます)

制約として、{1, 4}は同じクラスタ(must-link)、{2, 3}は違うクラスタ(cannot-link)を指定しました。その結果、 制約なしでは{1, 2, 3}, {4, 5, 6}とクラスタリングされていたのが、{1, 2, 4, 6}, {3, 5}とクラスタリング結果が変化しました。一応制約は効いてるみたいですね。

ただ論文通りの手法だと、制約が満たされなかった時点ですぐにクラスタリングを中止してしまうので、データを見る順番によってクラスタリングがすぐに止まってしまう可能性が高そうです。データの順番をランダムに見るようにして、制約を満たさなかった場合は違う順番で何回か試行するようにすればいいのかな?またcannot-linkを重視した方が精度がよくなる、と他の論文にありましたので、その点を考慮してアルゴリズムを改良した方がよさそう。って、そんなことはとっくにやられてそうですけど。

また前回の記事ではOpenMPを使いましたが、今回はOpenMPは使用していません。制約を満たすように一つ一つデータを見ていくインクリメンタルなアルゴリズムなので、このままだと並列化が難しそうです。制約を毎回チェックするのではなく、すべてクラスタに割り当てた後で制約を満たすように修正するように変更すれば、クラスタに割り当てる部分は並列化できるかな。

でもまずはもっと最近の論文も読むのが先ですかね^^; これが定番、みたいな論文ないのかなぁ。

K-meansをOpenMPで並列化

昨年末に「平行コンピューティング技法」を読んで勉強していたのですが、せっかくなのでK-meansOpenMPを使って高速化してみようと思います。OpenMPは簡単な構文を挿入することで、自動的にループの繰り返しを分割し、複数のスレッドにタスクを割り当ててくれます。

並行コンピューティング技法 ―実践マルチコア/マルチスレッドプログラミング

並行コンピューティング技法 ―実践マルチコア/マルチスレッドプログラミング

作成したコードは以下の通り。コンパイル・実行にはgoogle_sparsehashとOpenMPをあらかじめインストールしておく必要があります。

// K-means++ + OpenMP

#include <cassert>
#include <cstdio>
#include <ctime>
#include <fstream>
#include <vector>
#include <google/dense_hash_map>
#include <omp.h>

typedef uint64_t VecKey;
typedef size_t VecId;
typedef google::dense_hash_map<VecKey, double> Vector;
typedef google::dense_hash_map<std::string, VecKey> KeyMap;

class KMeans;

/* function prototypes */
int main(int argc, char **argv);
void usage(const char *progname);
void read_vectors(const char *filename, KMeans &kmeans);
size_t splitstring(std::string s, const std::string &delimiter,
                   std::vector<std::string> &splited);

/* constants */
const size_t MAX_ITER  = 10;
const VecKey EMPTY_KEY = 0;
const double LONG_DIST = 1000000000000000;
const std::string DELIMITER("\t");

class KMeans {
 public:
  typedef google::dense_hash_map<std::string, size_t> LabelMap;

 private:
  std::vector<Vector *> vectors_;
  std::vector<Vector *> centers_;
  LabelMap labels_;

  double euclid_distance_squared(const Vector &vec1, const Vector &vec2) const {
    google::dense_hash_map<VecKey, bool> check;
    check.set_empty_key(EMPTY_KEY);
    double dist = 0.0;
    Vector::const_iterator it1, it2;
    for (it1 = vec1.begin(); it1 != vec1.end(); ++it1) {
      double val1 = it1->second;
      double val2 = 0.0;
      it2 = vec2.find(it1->first);
      if (it2 != vec2.end()) val2 = it2->second;
      dist += (val1 - val2) * (val1 - val2);
      check[it1->first] = true;
    }
    for (it2 = vec2.begin(); it2 != vec2.end(); ++it2) {
      if (check.find(it2->first) != check.end()) continue;
      double val2 = it2->second;
      double val1 = 0.0;
      it1 = vec1.find(it2->first);
      if (it1 != vec1.end()) val1 = it1->second;
      dist += (val1 - val2) * (val1 - val2);
    }
    return dist;
  }

  void choose_random_centers(size_t ncenters) {
    centers_.clear();
    google::dense_hash_map<size_t, bool> check;
    check.set_empty_key(vectors_.size());
    size_t cnt = 0;
    while (cnt < ncenters) {
      size_t idx = rand() % vectors_.size();
      if (check.find(idx) == check.end()) {
        Vector *center = new Vector(*vectors_[idx]);
        centers_.push_back(center);
        cnt++;
        check[idx] = true;
      }
    }
  }

  void choose_smart_centers(size_t ncenters) {
    centers_.clear();
    double closest_dist[vectors_.size()];
    double potential = 0.0;
    size_t cnt = 0;

    // choose one random center
    size_t idx = rand() % vectors_.size();
    Vector *center = new Vector(*vectors_[idx]);
    centers_.push_back(center);
    cnt++;
    // update closest distance
    for (size_t i = 0; i < vectors_.size(); i++) {
      double dist = euclid_distance_squared(*vectors_[i], *centers_[0]);
      closest_dist[i] = dist;
      potential += dist;
    }
    // choose each centers
    while (cnt < ncenters) {
      double randval = static_cast<double>(rand()) / RAND_MAX * potential;
      size_t idx = 0;
      for (size_t i = 0; i < vectors_.size(); i++) {
        if (randval <= closest_dist[i]) {
          idx = i;
          break;
        } else {
          randval -= closest_dist[i];
        }
      }
      Vector *center = new Vector(*vectors_[idx]);
      double potential_new = 0.0;
      for (size_t i = 0; i < vectors_.size(); i++) {
        double dist = euclid_distance_squared(*vectors_[i], *center);
        if (closest_dist[i] > dist) closest_dist[i] = dist;
        potential_new += closest_dist[i];
      }
      centers_.push_back(center);
      cnt++;
      potential = potential_new;
    }
  }

  void assign_clusters(size_t *assign) const {
    int vsiz = static_cast<int>(vectors_.size());
    #pragma omp parallel for
    for (int i = 0; i < vsiz; i++) {
      size_t min_idx = 0;
      double min_dist = LONG_DIST;
      for (size_t j = 0; j < centers_.size(); j++) {
        double dist = euclid_distance_squared(*vectors_[i], *centers_[j]);
        if (dist < min_dist) {
          min_idx = j;
          min_dist = dist;
        }
      }
      assign[i] = min_idx;
    }
  }

  void move_centers(const size_t *assign) {
    for (size_t i = 0; i < centers_.size(); i++) {
      centers_[i]->clear();
    }
    std::vector<size_t> count(centers_.size());
    Vector::iterator cit;
    for (size_t i = 0; i < vectors_.size(); i++) {
      for (Vector::iterator it = vectors_[i]->begin();
           it != vectors_[i]->end(); ++it) {
        cit = centers_[assign[i]]->find(it->first);
        if (cit != centers_[assign[i]]->end()) {
          cit->second += it->second;
        } else {
          centers_[assign[i]]->insert(
            std::pair<VecKey, double>(it->first, it->second));
        }
      }
      count[assign[i]]++;
    }
    for (size_t i = 0; i < count.size(); i++) {
      if (count[i] == 0) continue;
      for (Vector::iterator it = centers_[i]->begin();
           it != centers_[i]->end(); ++it) {
        it->second /= count[i];
      }
    }
  }

  bool is_same_array(size_t *array1, size_t *array2, size_t size) {
    for (size_t i = 0; i < size; i++) {
      if (array1[i] != array2[i]) return false;
    }
    return true;
  }

 public:
  KMeans() { labels_.set_empty_key(""); }

  ~KMeans() {
    for (size_t i = 0; i < vectors_.size(); i++) {
      if (vectors_[i]) delete vectors_[i];
    }
    for (size_t i = 0; i < centers_.size(); i++) {
      if (centers_[i]) delete centers_[i];
    }
  }

  void add_vector(const std::string &label, Vector *vec) {
    assert(!label.empty() && !vec->empty());
    labels_[label] = vectors_.size();
    vectors_.push_back(vec);
  }

  void execute(size_t nclusters) {
    assert(nclusters <= vectors_.size());
    choose_random_centers(nclusters);
//    choose_smart_centers(nclusters);
    size_t assign[vectors_.size()];
    size_t prev_assign[vectors_.size()];
    memset(assign, nclusters, sizeof(nclusters) * vectors_.size());
    memset(prev_assign, nclusters, sizeof(nclusters) * vectors_.size());
    for (size_t i = 0; i < MAX_ITER; i++) {
      fprintf(stderr, "kmeans loop No.%d ...\n", i);
      assign_clusters(assign);
      move_centers(assign);
      if (is_same_array(assign, prev_assign, vectors_.size())) {
        break;
      } else {
        std::copy(assign, assign + vectors_.size(), prev_assign);
      }
    }
    // show clustering result
    for (LabelMap::iterator it = labels_.begin(); it != labels_.end(); ++it) {
      printf("%s\t%d\n", it->first.c_str(), assign[it->second]);
    }
  }

  void show_vectors() const {
    for (LabelMap::const_iterator lit = labels_.begin();
         lit != labels_.end(); ++lit) {
      printf("%s", lit->first.c_str());
      for (Vector::const_iterator vit = vectors_[lit->second]->begin();
           vit != vectors_[lit->second]->end(); ++vit) {
        printf("\t%d\t%.3f", vit->first, vit->second);
      }
      printf("\n");
    }
  }
};

int main(int argc, char **argv) {
  if (argc < 3) {
    usage(argv[0]);
  }
  //srand((unsigned int) time(NULL));
  KMeans kmeans;
  read_vectors(argv[2], kmeans);
//  kmeans.show_vectors();
  kmeans.execute(atoi(argv[1]));
  return 0;
}

void usage(const char *progname) {
  fprintf(stderr, "%s: ncluster data\n", progname);
  exit(1);
}

void read_vectors(const char *filename, KMeans &kmeans) {
  std::ifstream ifs(filename);
  if (!ifs) {
    fprintf(stderr, "cannot open %s\n", filename);
    exit(1);
  }
  KeyMap keymap;
  keymap.set_empty_key("");
  VecKey curkey = EMPTY_KEY + 1;
  std::string line;
  std::vector<std::string> splited;
  while (getline(ifs, line)) {
    splitstring(line, DELIMITER, splited);
    if (splited.size() % 2 != 1) {
      fprintf(stderr, "format error: %s\n", line.c_str());
      continue;
    }
    Vector *vec = new Vector;
    vec->set_empty_key(EMPTY_KEY);
    for (size_t i = 1; i < splited.size(); i += 2) {
      KeyMap::iterator kit = keymap.find(splited[i]);
      VecKey key;
      if (kit != keymap.end()) {
        key = kit->second;
      } else {
        key = curkey;
        keymap[splited[i]] = curkey++;
      }
      double point = 0.0;
      point = atof(splited[i+1].c_str());
      if (point != 0) {
        vec->insert(std::pair<VecKey, double>(key, point));
      }
    }
    if (!splited[0].empty() && !vec->empty()) {
      kmeans.add_vector(splited[0], vec);
    }
    splited.clear();
  }
}

size_t splitstring(std::string s, const std::string &delimiter,
                   std::vector<std::string> &splited) {
  size_t cnt = 0;
  for (size_t p = 0; (p = s.find(delimiter)) != s.npos; ) {
    splited.push_back(s.substr(0, p));
    ++cnt;
    s = s.substr(p + delimiter.size());
  }
  splited.push_back(s);
  ++cnt;
  return cnt;
}

今回OpenMPで並列化したのは、「各データと各クラスタの中心との距離を求めて、最も中心が近いクラスタにデータを割り当てる」部分(assign_clustersメソッドの中)です。OpenMPに関係あるところは、

  • OpenMPのヘッダの読み込み
#include <omp.h>
  • for文の前に以下の構文を挿入
    #pragma omp parallel for

この2つだけです。

実際に動かすときは以下のようにします。OpenMPを使用するときは"-fopenmp"オプションを加える必要があります。このオプションを加えなかった場合は、OpenMPの構文は無視されて、並列化されないまま実行されます。OpenMPを使いたくない場合でも、ソースを修正せずにそのままコンパイル・実行できるのは結構うれしいかも。入力データのフォーマットは1行1ドキュメントのタブ区切りテキストです。

% cat /path/data.tsv  # 入力データ
1       a       2       b       2       c       2       d       -1      e       -1      f       -1
2       a       2       b       -1      c       2       d       -1      e       -1      f       -1
3       a       2       b       2       c       -1      d       -1      e       -1      f       -1
4       a       -1      b       -1      c       -1      d       2       e       2       f       2
5       a       -1      b       -1      c       -1      d       2       e       -1      f       2
6       a       -1      b       -1      c       -1      d       2       e       2       f       -1

% g++ kmeanspp_mp.cc -O3 -o kmeanspp_mp -fopenmp
% ./kmeans_pp 100 /path/data.tsv > cluster.tsv
% sort -g cluster.tsv | lv

OpenMPを使った場合と使わなかった場合で、どれくらい実行時間が違うかを簡単に比較してみました。入力ドキュメント数は10000、クラスタ数は100、K-meansのループ回数は最大10回、実行環境のCPUコア数は4で実験しました。また乱数のseedは固定にして実行してあります。

== OpenMPあり ==
% g++ kmeans_mp.cc -O3 -o kmeans_mp -fopenmp
% time ./kmeans_mp 100 /path/data.tsv > /dev/null
315.147u 0.060s 1:25.48 368.7%  0+0k 0+0io 0pf+0w

== OpenMPなし ==
% g++ kmeans_mp.cc -O3 -o kmeans_mp
% time ./kmeans_mp 100 /path/data.tsv > /dev/null
491.246u 0.132s 8:11.82 99.9%   0+0k 0+0io 0pf+0w

OpenMPを使用した時の実行時間は1分25秒、OpenMPを使用しなかった場合の実行時間は8分11秒になりました。たしかに高速化されてますね。でもコア数倍以上に早くなってるのは何でだろう?なにかミスがあるのかな。。。

ループ内が完全に独立ではなく複雑な構成をしている場合等では、OpenMPは使用するのが難しいこともあるようですが、これだけ簡単に使えて高速化できるならかなりうれしいですね。さて、次はpthread勉強しようかな。

追記1:OpenMPで並列化している箇所で念のためスレッド間の競合を避けるために、各スレッドが共有しているstd::vectorを配列に変更しました。各スレッドが別インデックスでvectorを読み書きするときに、vectorのサイズをあらかじめ確保している場合でも競合って発生しちゃうんですかね?vectorが中で勝手にresizeしたりするのかな…。

追記2:pragma節内の変数j, min_idx, min_dist, dist をprivate指定しないとまずいのでは?と指摘があったのですが、下記の資料を参照したところ、

並列実行領域のローカルデータはプライベートとなります。

と書いてありましたので、おそらく問題はないかな…?

Variable Byte codeを試してみた

最近転置インデックスをゴニョゴニョしているのですが、インデックスの圧縮をするためにVariable Byte codeでの数字列の圧縮部分を作ってみました。アルゴリズムはIntroduction to Information Retrievalの5章Index compressionを参考にしています。

作成したコードは以下の通りです。数字列はソートして差分をとってから圧縮するようにしています。また符号化した後のchar *のサイズを別で持っておくのは面倒なので、数字列の先頭に数字列の個数を入れてから符号化しています。復号するときはまず符号化されたchar *から数字列の個数部分だけ読んでおき、後はその個数に到達するまで復号化します。

//
// Variable Byte code
// http://nlp.stanford.edu/IR-book/html/htmledition/variable-byte-codes-1.html
//

#include <algorithm>
#include <iostream>
#include <map>
#include <vector>
#include <ctime>

void variable_byte_encode_number(uint64_t num, std::vector<uint64_t> &encoded) {
  for (;;) {
    encoded.push_back(num % 128);
    if (num < 128) break;
    num /= 128;
  }
  encoded[0] += 128;
}

char *variable_byte_encode(const std::vector<uint64_t> &numbers) {
  if (numbers.size() == 0) return NULL;
  // size of numbers
  std::vector<uint64_t> bytes, numenc;
  variable_byte_encode_number(numbers.size(), numenc);
  copy(numenc.rbegin(), numenc.rend(), std::back_inserter(bytes));
  numenc.clear();

  for (size_t i = 0; i < numbers.size(); i++) {
    variable_byte_encode_number(numbers[i], numenc);
    copy(numenc.rbegin(), numenc.rend(), std::back_inserter(bytes));
    numenc.clear();
  }
  char *buf = new char[bytes.size()];
  copy(bytes.begin(), bytes.end(), buf);
  std::cout << "Size(encoded): " << sizeof(char) * bytes.size() << std::endl;
  return buf;
}

void variable_byte_decode(const char *ptr, std::vector<uint64_t> &numbers) {
  // size of numbers
  uint64_t size = 0;
  uint64_t c;
  do {
    c = *(unsigned char *)ptr++;
    size = (c < 128) ? 128 * size + c : 128 * size + (c - 128);
  } while (c < 128);

  uint64_t cnt = 0;
  while (cnt < size) {
    uint64_t n = 0;
    do {
      c = *(unsigned char *)ptr++;
      n = (c < 128) ? 128 * n + c : 128 * n + (c - 128);
    } while (c < 128);
    numbers.push_back(n);
    cnt++;
  }
}

char *compress_diff(const std::vector<uint64_t> &numbers) {
  std::vector<uint64_t> diff;
  uint64_t prev = 0;
  for (size_t i = 0; i < numbers.size(); i++) {
    diff.push_back(numbers[i] - prev);
    prev = numbers[i];
  }
  return variable_byte_encode(diff);
}

void decompress_diff(const char *ptr, std::vector<uint64_t> &numbers) {
  variable_byte_decode(ptr, numbers);
  for (size_t i = 1; i < numbers.size(); i++) numbers[i] += numbers[i-1];
}

void random_numbers(size_t size, uint64_t max, std::vector<uint64_t> &numbers) {
  std::map<uint64_t, bool> check;
  size_t cnt = 0;
  while (cnt < size) {
    uint64_t num = static_cast<uint64_t>(rand()) % max;
    if (check.find(num) == check.end()) {
      numbers.push_back(num);
      check[num] = true;
      cnt++;
    }
  }
  std::sort(numbers.begin(), numbers.end());
}

int main(int argc, char **argv) {
  srand(static_cast<unsigned int>(time(NULL)));
  std::vector<uint64_t> numbers;
  size_t size = 10;
  random_numbers(size, size * 100, numbers);

  std::cout << "Size(input):   " << sizeof(uint64_t) * size << std::endl;
  std::cout << "Input:   ";
  for (size_t i = 0; i < numbers.size(); i++) {
    if (i != 0) std::cout << " ";
    std::cout << numbers[i];
  }
  std::cout << std::endl;

  char *encoded = compress_diff(numbers);
  //char *encoded = variable_byte_encode(numbers);

  std::vector<uint64_t> decoded;
  decompress_diff(encoded, decoded);
  //variable_byte_decode(encoded, decoded);
  delete [] encoded;

  std::cout << "Decoded: ";
  for (size_t i = 0; i < decoded.size(); i++) {
    if (i != 0) std::cout << " ";
    std::cout << decoded[i];
  }
  std::cout << std::endl;
  return 0;
}

encode部分がすごく冗長な気がするのですが、まあとりあえずはいいかなと。ちゃんと動くか試してみます。

% g++ variable_byte.cc -o vb 
% ./vb
Size(input):   80
Input:   335 383 386 421 492 649 777 793 886 915
Size(encoded): 14
Decoded: 335 383 386 421 492 649 777 793 886 915

ちゃんと復元できてそうですね。

Variable Byte code以外にもRice符号やδ符号など圧縮率がより高いものや、復元スピードがより速い符号化があるので、今後はそれも試してみます。でも個人的にはVariable Byte codeで今のところは十分かなと思っていたり…^^;

Perlでconstantを使うときの注意

この前CPANにアップしたモジュールでCPAN Testersの結果を見てたら、Perlのversion5.6.2で毎回テストに失敗してて、何でだろう?と思っていたらbug reportがきていた。

The syntax you are using to declare constants was not always supported.
perl 5.6.2 came with version 1.02 of constant.pm and there this syntax
was not supported.
...

constants.pmの古いバージョン(1.0.2以下)だとconstantsを一度に複数指定する書式をサポートしていないのが問題らしい。 なので下のように書いているところでエラーが出てたみたい。

use constant {
    AAA => 1,
    BBB => 2,
};

下のように一つ一つ定義するようにすればオッケーみたい。

use constant AAA => 1;
use constant BBB => 2;

普段書くときは複数同時指定でもいいけど、CPANモジュールみたいに古い環境でも実行される可能性がある場合は、個別に定義した方がいいのかな。これも多分常識なんだろうけど、とりあえずメモメモ。

それにしてもバグレポートなんて初めてもらったけど、すごくうれしいなぁ。あとやっぱり英語重要。丁寧に返事をしたいけど、いまいち言い回しが分からない…。

Algorithm::FuzzyCmeans と Algorithm::Kmeanspp を作った

なんとなくCPANにもっと上げてみたくなったので、昔書いたネタをパッケージングして上げてみました。とりあえず今回はFuzzy c-means clusteringと、K-means++。

使い方はpodを見れば簡単に分かると思います。

ただ実際に大きなデータで使うというよりは、ちょっと試してみてから、中身を見てどんな仕組みなのかなーと調べる資料になればいいかな、という気持ちで作ってます。その割には中身汚いかもしれませんが…^^;

freshmeatに登録してみた

せっかくオープンソースのプロダクトを作ったので、freshmeatにbayonを登録してみました。説明文をもっと分かりやすく、キャッチーにしないとダメですねぇ。

大したプロダクトでもないのでちょっと恥ずかしいですが、これ見て海外の人も使ってくれたりしたら、めちゃくちゃうれしいんでしょうねー。bayonもアップデートしつつ、もっといいプロダクトを作って、いろんな人に使ってもらえるよう頑張りたいですね!

Algorithm::BayesianSetsモジュールをアップした

前回のエントリでBayesian Setsを試してみたのですが、その時に書いたコードをAlgorithm::BayesianSetsというモジュールにまとめて、CPANにアップしました。生まれて初めてのCPANアップです。

すごいちっちゃいモジュールですが、これで僕もCPAN Authorの仲間入りかと思うとうれしいですね^^ 実用で使うのはちょっと厳しいかもしれませんが、なんとなくBayesian Setsを試してみるにはいいかと思います。

前回と同じように、タブ区切りのフォーマットのデータを入力として与える場合は、以下のようなコードで動きます。

#!/usr/bin/perl

use strict;
use warnings;
use Algorithm::BayesianSets;

use constant {
    MAX_OUTPUT => 20,
};

my $path = shift @ARGV;
my @queries = @ARGV;
if (!$path || !@queries) {
    warn "Usage $0 file query1 query2 ..\n";
    exit 1;
}

my $bs = Algorithm::BayesianSets->new();

# read input documents
# format: document_id \t key1 \t val1 \t key2 \t val2 \t ...\n
open my $fh, $path or die "cannot open file: $path";
while (my $line = <$fh>) {
    chomp $line;
    my @arr = split /\t/, $line;
    my $doc_id = shift @arr;
    my %vector = @arr;
    $bs->add_document($doc_id, \%vector);
}

$bs->calc_parameters();
my $scores = $bs->calc_similarities(\@queries);
my $count = 0;
foreach my $doc_id (sort { $scores->{$b} <=> $scores->{$a} }
    keys %{ $scores }) {
    last if ++$count > MAX_OUTPUT;
    printf "%s\t%.3f\n", $doc_id, $scores->{$doc_id};
}

ちなみにSYNOPSYSの例では入力ベクトル集合の値がすべて1になっていますが、ここは別に正の値であればなんでも大丈夫です。コンストラクタで与えた閾値(デフォルトでは0)以上のキーのみが使用されます。よく考えたら、閾値は絶対値で比較するようにして、正負両方の値を受けられるようにした方がよかったかな…?まあ負の値を使うケースはそんなに多くはなさそうですし、その場合でもユーザ側で正に揃えておいてもらえば十分ですかね。

あとBayesian SetsのC++実装版をいま作り中です。ライブラリ部分と、実際のサービスで使用できるようにサーバ部分も作っているのですが、サーバ実装まわりの知識がまったく足りてないので勉強しながらごにょごにょと開発中です。結構先になっちゃいそうですが、ある程度動く状態になったらまたブログに書きたいと思います。