低スペックな頭の僕がJavaの機械学習ライブラリmahoutをH2つないでみる。

mahoutのDataModelはJDBCをつかってDBのデータから作ることが出来ます。0.4ではMySQLのみ、0.5-SNAPSHOTではMySQLPostgreSQLのみが標準でサポートされ他のDBを利用する時にはAbstractJDBCDataModelを継承して実装します。

とりあえずH2で

設定も特に必要ないからH2との接続をやってみます。プロジェクト全体はこんな感じ。

pom.xml

mahoutで使うものにH2のJDBCを追加しる

  <dependencies>
  	<dependency>
  		<groupId>org.apache.mahout</groupId>
  		<artifactId>mahout-core</artifactId>
  		<version>0.4</version>
  	</dependency>
  	<dependency>
  		<groupId>com.github.mcpat.slf4j</groupId>
  		<artifactId>slf4cldc-nop</artifactId>
  		<version>1.6.0</version>
  	</dependency>
  	<dependency>
  		<groupId>com.h2database</groupId>
  		<artifactId>h2</artifactId>
  		<version>1.3.153</version>
  	</dependency>
  </dependencies>

テーブル

columName type PK
USER_ID BIGINT PK
ITEM_ID BIGINT PK
PREFERENCE REAL
UPDATING_TIME TIMESTAMP

テーブルのカラム名はDataModelのコンストラクタで指定するので変更可能。でも、Javadoc読むとUPDATING_TIMEの部分は必要ないかもしれない。でもMySQLJDBCDataModelのソース見ると参照するSQLがあるので念のためつけておく。デフォルト値は好きに設定する。

ソース

まずはH2からDataModel作ってくれるH2JDBCDataModel.java

package org.kirino.data.model;

import java.sql.Connection;
import java.sql.PreparedStatement;
import java.sql.SQLException;

import javax.sql.DataSource;

import org.apache.mahout.cf.taste.common.TasteException;
import org.apache.mahout.cf.taste.impl.model.jdbc.AbstractJDBCDataModel;
import org.apache.mahout.common.IOUtils;

public class H2JDBCDataModel extends AbstractJDBCDataModel {

	private static final long serialVersionUID = 1L;
	private final String updatePreferenceSQL;

	public H2JDBCDataModel(DataSource dataSource, String preferenceTable,
			String userIDColumn, String itemIDColumn, String preferenceColumn,
			String timestampColumn) {
		super(dataSource, preferenceTable, userIDColumn, itemIDColumn,
				preferenceColumn,
				// getPreferenceSQL
				"SELECT " + preferenceColumn + " FROM " + preferenceTable
						+ " WHERE " + userIDColumn + "=? AND " + itemIDColumn
						+ "=?",
				// getPreferenceTimeSQL
				"SELECT " + timestampColumn + " FROM " + preferenceTable
						+ " WHERE " + userIDColumn + "=? AND " + itemIDColumn
						+ "=?",
				// getUserSQL
				"SELECT DISTINCT " + userIDColumn + ", " + itemIDColumn + ", "
						+ preferenceColumn + " FROM " + preferenceTable
						+ " WHERE " + userIDColumn + "=? ORDER BY "
						+ itemIDColumn,
				// getAllUsersSQL
				"SELECT DISTINCT " + userIDColumn + ", " + itemIDColumn + ", "
						+ preferenceColumn + " FROM " + preferenceTable
						+ " ORDER BY " + userIDColumn + ", " + itemIDColumn,
				// getNumItemsSQL
				"SELECT COUNT(DISTINCT " + itemIDColumn + ") FROM "
						+ preferenceTable,
				// getNumUsersSQL
				"SELECT COUNT(DISTINCT " + userIDColumn + ") FROM "
						+ preferenceTable,
				// setPreferenceSQL
				"INSERT INTO " + preferenceTable + '(' + userIDColumn + ','
						+ itemIDColumn + ',' + preferenceColumn
						+ ") VALUES (?,?,?)",
				// removePreference SQL
				"DELETE FROM " + preferenceTable + " WHERE " + userIDColumn
						+ "=? AND " + itemIDColumn + "=?",
				// getUsersSQL
				"SELECT DISTINCT " + userIDColumn + " FROM " + preferenceTable
						+ " ORDER BY " + userIDColumn,
				// getItemsSQL
				"SELECT DISTINCT " + itemIDColumn + " FROM " + preferenceTable
						+ " ORDER BY " + itemIDColumn,
				// getPrefsForItemSQL
				"SELECT DISTINCT " + userIDColumn + ", " + itemIDColumn + ", "
						+ preferenceColumn + " FROM " + preferenceTable
						+ " WHERE " + itemIDColumn + "=? ORDER BY "
						+ userIDColumn,
				// getNumPreferenceForItemSQL
				"SELECT COUNT(1) FROM " + preferenceTable + " WHERE "
						+ itemIDColumn + "=?",
				// getNumPreferenceForItemsSQL
				"SELECT COUNT(1) FROM " + preferenceTable + " tp1 JOIN "
						+ preferenceTable + " tp2 WHERE tp1." + itemIDColumn
						+ "=? and tp2." + itemIDColumn + "=?", "SELECT MAX("
						+ preferenceColumn + ") FROM " + preferenceTable,
				"SELECT MIN(" + preferenceColumn + ") FROM " + preferenceTable);
		updatePreferenceSQL = "UPDATE " + preferenceTable + " SET "
				+ preferenceColumn + " = ? WHERE " + userIDColumn + " = ? AND "
				+ itemIDColumn + " = ?";

	}

	@Override
	public void setPreference(long userID, long itemID, float value)
			throws TasteException {

		try {
			super.setPreference(userID, itemID, value);
		} catch (TasteException e) {
			Connection conn = null;
			PreparedStatement stmt = null;
			try {
				conn = super.getDataSource().getConnection();
				stmt = conn.prepareStatement(updatePreferenceSQL);
				stmt.setFloat(1, value);
				stmt.setLong(2, userID);
				stmt.setLong(3, itemID);
				stmt.executeUpdate();

			} catch (SQLException sqle) {
				throw new TasteException(sqle);
			} finally {
				IOUtils.quietClose(null, stmt, conn);
			}
		}

	}

}

実行する部分 MahoutSampleH2.java

package org.kirino.mahout;

import java.util.List;

import javax.naming.NamingException;

import org.apache.mahout.cf.taste.common.TasteException;
import org.apache.mahout.cf.taste.impl.neighborhood.NearestNUserNeighborhood;
import org.apache.mahout.cf.taste.impl.recommender.GenericUserBasedRecommender;
import org.apache.mahout.cf.taste.impl.similarity.PearsonCorrelationSimilarity;
import org.apache.mahout.cf.taste.model.DataModel;
import org.apache.mahout.cf.taste.neighborhood.UserNeighborhood;
import org.apache.mahout.cf.taste.recommender.RecommendedItem;
import org.apache.mahout.cf.taste.recommender.Recommender;
import org.apache.mahout.cf.taste.similarity.UserSimilarity;
import org.kirino.data.model.H2JDBCDataModel;

public class MahoutSampleH2 {

	/**
	 * @param args
	 */
	public static void main(String[] args) {
		try {
			String url = "jdbc:h2:~/hoge";
			String user = "sa";
			String password = "";
			// データの取り込み
			DataModel dataModel = new H2JDBCDataModel(
					H2Datasource.getDatasource(url, user, password),
					"TASETE_PREFERENCES", "USER_ID", "ITEM_ID", "PREFERENCE",
					"UPDATING_TIME");
			// 相関性の評価基準の設定
			UserSimilarity similarity = new PearsonCorrelationSimilarity(
					dataModel);
			// 評価の近い人を探すロジックを決めてる?
			UserNeighborhood neighborhood = new NearestNUserNeighborhood(3,
					similarity, dataModel);
			// レコメンダの作成
			Recommender recommender = new GenericUserBasedRecommender(
					dataModel, neighborhood, similarity);

			// 1番の人に対するレコメンドが1つ
			List<RecommendedItem> recommendations = recommender.recommend(1, 1);
			for (RecommendedItem recommendation : recommendations) {
				System.out.println(recommendation);
			}

			System.out.println("end");

		} catch (TasteException e) {
			e.printStackTrace();
		} catch (NamingException e) {
			// TODO Auto-generated catch block
			e.printStackTrace();
		}

	}

}

Datasource取ってる部分 H2Datasource.java

package org.kirino;

import javax.naming.NamingException;
import javax.sql.DataSource;

import org.h2.jdbcx.JdbcDataSource;

public class H2Datasource {

	public static DataSource getDatasource(String url,String user,String password) throws NamingException {
		JdbcDataSource dataSource = new JdbcDataSource();
		dataSource.setURL(url);
		dataSource.setUser(user);
		dataSource.setPassword(password);

		return dataSource;
	}

	private H2Datasource() {
		super();
	}

}

解説するよ

まずはH2からDataModel作ってくれるH2JDBCDataModel.javaはコンストラクタとsetPreferenceを実装しています。コンストラクタはそれぞれ該当するSQLをがちゃがちゃと文字列結合してる。MySQLJDBCDataModelもこんな感じの実装。違う点はデータをinsertするSQLMySQLではINSERT〜ON DUPLICATE KEY UPDATE構文でキー重複したらUPDATEということができるけどH2ではできない。だからsetPreferenceをオーバーライドしてる。setPreferenceはAbstractJDBCDataModel#setPreferenceでinsert文を発行して失敗したらH2JDBCDataModel#setPreferenceでupdate文を発行する。この辺がなんかイケてない気がする。
次に実行部分のH2MahoutSample.java csv取り込みと違う点はDataModelの作り方のちがいだけ。

DataModel dataModel = new H2JDBCDataModel(
			H2Datasource.getDatasource(url, user, password),
			"TASETE_PREFERENCES", "USER_ID", "ITEM_ID", "PREFERENCE",
			"UPDATING_TIME");

コンストラクタでDatasourceとテーブル名とそれぞれのカラム名を設定する。この辺はMySQLJDBCDataModelと同じ。カラム名変えたらここを変更するだけでいい。
最後はH2Datasource.java web.xml使わずに書く方法を忘れないようにメモ。

まとめ

今回は0.4を使ってるのでH2JDBCDataModel#setPreferenceの実装がこのような形になっていますが0.5-SNAPSHOTから各SQLのアクセサが入るのでもうちょっと綺麗な処理に書き換えることができます。この部分はDBの実装でキー重複したらUPDATEの処理があるならそちらを使う方が適切でしょう。