低スペックな頭の僕がJavaの機械学習ライブラリmahoutをH2つないでみる。
mahoutのDataModelはJDBCをつかってDBのデータから作ることが出来ます。0.4ではMySQLのみ、0.5-SNAPSHOTではMySQLとPostgreSQLのみが標準でサポートされ他のDBを利用する時にはAbstractJDBCDataModelを継承して実装します。
pom.xml
mahoutで使うものにH2のJDBCを追加しる
<dependencies> <dependency> <groupId>org.apache.mahout</groupId> <artifactId>mahout-core</artifactId> <version>0.4</version> </dependency> <dependency> <groupId>com.github.mcpat.slf4j</groupId> <artifactId>slf4cldc-nop</artifactId> <version>1.6.0</version> </dependency> <dependency> <groupId>com.h2database</groupId> <artifactId>h2</artifactId> <version>1.3.153</version> </dependency> </dependencies>
テーブル
columName | type | PK |
---|---|---|
USER_ID | BIGINT | PK |
ITEM_ID | BIGINT | PK |
PREFERENCE | REAL | |
UPDATING_TIME | TIMESTAMP |
テーブルのカラム名はDataModelのコンストラクタで指定するので変更可能。でも、Javadoc読むとUPDATING_TIMEの部分は必要ないかもしれない。でもMySQLJDBCDataModelのソース見ると参照するSQLがあるので念のためつけておく。デフォルト値は好きに設定する。
ソース
まずはH2からDataModel作ってくれるH2JDBCDataModel.java
package org.kirino.data.model; import java.sql.Connection; import java.sql.PreparedStatement; import java.sql.SQLException; import javax.sql.DataSource; import org.apache.mahout.cf.taste.common.TasteException; import org.apache.mahout.cf.taste.impl.model.jdbc.AbstractJDBCDataModel; import org.apache.mahout.common.IOUtils; public class H2JDBCDataModel extends AbstractJDBCDataModel { private static final long serialVersionUID = 1L; private final String updatePreferenceSQL; public H2JDBCDataModel(DataSource dataSource, String preferenceTable, String userIDColumn, String itemIDColumn, String preferenceColumn, String timestampColumn) { super(dataSource, preferenceTable, userIDColumn, itemIDColumn, preferenceColumn, // getPreferenceSQL "SELECT " + preferenceColumn + " FROM " + preferenceTable + " WHERE " + userIDColumn + "=? AND " + itemIDColumn + "=?", // getPreferenceTimeSQL "SELECT " + timestampColumn + " FROM " + preferenceTable + " WHERE " + userIDColumn + "=? AND " + itemIDColumn + "=?", // getUserSQL "SELECT DISTINCT " + userIDColumn + ", " + itemIDColumn + ", " + preferenceColumn + " FROM " + preferenceTable + " WHERE " + userIDColumn + "=? ORDER BY " + itemIDColumn, // getAllUsersSQL "SELECT DISTINCT " + userIDColumn + ", " + itemIDColumn + ", " + preferenceColumn + " FROM " + preferenceTable + " ORDER BY " + userIDColumn + ", " + itemIDColumn, // getNumItemsSQL "SELECT COUNT(DISTINCT " + itemIDColumn + ") FROM " + preferenceTable, // getNumUsersSQL "SELECT COUNT(DISTINCT " + userIDColumn + ") FROM " + preferenceTable, // setPreferenceSQL "INSERT INTO " + preferenceTable + '(' + userIDColumn + ',' + itemIDColumn + ',' + preferenceColumn + ") VALUES (?,?,?)", // removePreference SQL "DELETE FROM " + preferenceTable + " WHERE " + userIDColumn + "=? AND " + itemIDColumn + "=?", // getUsersSQL "SELECT DISTINCT " + userIDColumn + " FROM " + preferenceTable + " ORDER BY " + userIDColumn, // getItemsSQL "SELECT DISTINCT " + itemIDColumn + " FROM " + preferenceTable + " ORDER BY " + itemIDColumn, // getPrefsForItemSQL "SELECT DISTINCT " + userIDColumn + ", " + itemIDColumn + ", " + preferenceColumn + " FROM " + preferenceTable + " WHERE " + itemIDColumn + "=? ORDER BY " + userIDColumn, // getNumPreferenceForItemSQL "SELECT COUNT(1) FROM " + preferenceTable + " WHERE " + itemIDColumn + "=?", // getNumPreferenceForItemsSQL "SELECT COUNT(1) FROM " + preferenceTable + " tp1 JOIN " + preferenceTable + " tp2 WHERE tp1." + itemIDColumn + "=? and tp2." + itemIDColumn + "=?", "SELECT MAX(" + preferenceColumn + ") FROM " + preferenceTable, "SELECT MIN(" + preferenceColumn + ") FROM " + preferenceTable); updatePreferenceSQL = "UPDATE " + preferenceTable + " SET " + preferenceColumn + " = ? WHERE " + userIDColumn + " = ? AND " + itemIDColumn + " = ?"; } @Override public void setPreference(long userID, long itemID, float value) throws TasteException { try { super.setPreference(userID, itemID, value); } catch (TasteException e) { Connection conn = null; PreparedStatement stmt = null; try { conn = super.getDataSource().getConnection(); stmt = conn.prepareStatement(updatePreferenceSQL); stmt.setFloat(1, value); stmt.setLong(2, userID); stmt.setLong(3, itemID); stmt.executeUpdate(); } catch (SQLException sqle) { throw new TasteException(sqle); } finally { IOUtils.quietClose(null, stmt, conn); } } } }
実行する部分 MahoutSampleH2.java
package org.kirino.mahout; import java.util.List; import javax.naming.NamingException; import org.apache.mahout.cf.taste.common.TasteException; import org.apache.mahout.cf.taste.impl.neighborhood.NearestNUserNeighborhood; import org.apache.mahout.cf.taste.impl.recommender.GenericUserBasedRecommender; import org.apache.mahout.cf.taste.impl.similarity.PearsonCorrelationSimilarity; import org.apache.mahout.cf.taste.model.DataModel; import org.apache.mahout.cf.taste.neighborhood.UserNeighborhood; import org.apache.mahout.cf.taste.recommender.RecommendedItem; import org.apache.mahout.cf.taste.recommender.Recommender; import org.apache.mahout.cf.taste.similarity.UserSimilarity; import org.kirino.data.model.H2JDBCDataModel; public class MahoutSampleH2 { /** * @param args */ public static void main(String[] args) { try { String url = "jdbc:h2:~/hoge"; String user = "sa"; String password = ""; // データの取り込み DataModel dataModel = new H2JDBCDataModel( H2Datasource.getDatasource(url, user, password), "TASETE_PREFERENCES", "USER_ID", "ITEM_ID", "PREFERENCE", "UPDATING_TIME"); // 相関性の評価基準の設定 UserSimilarity similarity = new PearsonCorrelationSimilarity( dataModel); // 評価の近い人を探すロジックを決めてる? UserNeighborhood neighborhood = new NearestNUserNeighborhood(3, similarity, dataModel); // レコメンダの作成 Recommender recommender = new GenericUserBasedRecommender( dataModel, neighborhood, similarity); // 1番の人に対するレコメンドが1つ List<RecommendedItem> recommendations = recommender.recommend(1, 1); for (RecommendedItem recommendation : recommendations) { System.out.println(recommendation); } System.out.println("end"); } catch (TasteException e) { e.printStackTrace(); } catch (NamingException e) { // TODO Auto-generated catch block e.printStackTrace(); } } }
Datasource取ってる部分 H2Datasource.java
package org.kirino; import javax.naming.NamingException; import javax.sql.DataSource; import org.h2.jdbcx.JdbcDataSource; public class H2Datasource { public static DataSource getDatasource(String url,String user,String password) throws NamingException { JdbcDataSource dataSource = new JdbcDataSource(); dataSource.setURL(url); dataSource.setUser(user); dataSource.setPassword(password); return dataSource; } private H2Datasource() { super(); } }
解説するよ
まずはH2からDataModel作ってくれるH2JDBCDataModel.javaはコンストラクタとsetPreferenceを実装しています。コンストラクタはそれぞれ該当するSQLをがちゃがちゃと文字列結合してる。MySQLJDBCDataModelもこんな感じの実装。違う点はデータをinsertするSQLでMySQLではINSERT〜ON DUPLICATE KEY UPDATE構文でキー重複したらUPDATEということができるけどH2ではできない。だからsetPreferenceをオーバーライドしてる。setPreferenceはAbstractJDBCDataModel#setPreferenceでinsert文を発行して失敗したらH2JDBCDataModel#setPreferenceでupdate文を発行する。この辺がなんかイケてない気がする。
次に実行部分のH2MahoutSample.java csv取り込みと違う点はDataModelの作り方のちがいだけ。
DataModel dataModel = new H2JDBCDataModel( H2Datasource.getDatasource(url, user, password), "TASETE_PREFERENCES", "USER_ID", "ITEM_ID", "PREFERENCE", "UPDATING_TIME");
コンストラクタでDatasourceとテーブル名とそれぞれのカラム名を設定する。この辺はMySQLJDBCDataModelと同じ。カラム名変えたらここを変更するだけでいい。
最後はH2Datasource.java web.xml使わずに書く方法を忘れないようにメモ。
まとめ
今回は0.4を使ってるのでH2JDBCDataModel#setPreferenceの実装がこのような形になっていますが0.5-SNAPSHOTから各SQLのアクセサが入るのでもうちょっと綺麗な処理に書き換えることができます。この部分はDBの実装でキー重複したらUPDATEの処理があるならそちらを使う方が適切でしょう。