ServletFilterではなくRequestCycleで実装する

DoS攻撃検知用のRequestCycle(WebRequestCycleのサブクラス)を書きました。
あるIPアドレスから同じURLへ、一定期間(expire、ミリ秒)内に閾値(threshold)を超えるアクセスがあると、DosAttackExceptionをぶん投げます。

こういったものはServletFilterで書くのも手ですが、そうするとWicketTesterでテストするのが手間になってしまいます。そのため、RequestCycleのフックメソッド(onBeginRequest()やonEndRequest())で書いています。今回使っているのはonBeginRequest()だけです。
JPAのTransaction管理も、ServletFilterではなくRequestCycle内で完結させると便利です。Transactionがcommitできなかった場合に、RequestCycle.onRuntimeException(Page page, RuntimeException e)でエラー処理できるからです。これについてはまた別の機会に詳しく書きます。

package net.nagaseyasuhito.sandbox;

import java.util.ArrayList;
import java.util.Iterator;
import java.util.LinkedList;
import java.util.List;

import org.apache.commons.collections.CollectionUtils;
import org.apache.commons.collections.Predicate;
import org.apache.wicket.Application;
import org.apache.wicket.MetaDataKey;
import org.apache.wicket.RequestCycle;
import org.apache.wicket.Response;
import org.apache.wicket.protocol.http.WebApplication;
import org.apache.wicket.protocol.http.WebRequest;
import org.apache.wicket.protocol.http.WebRequestCycle;

/**
 * A {@link WebRequestCycle} that rejects a request when the same remote address has requested the
 * same URL more than {@code threshold} times within the last {@code expire} milliseconds.
 *
 * <p>The access history lives in the {@link Application}'s metadata, so it is shared by every
 * request cycle of that application. Request cycles run concurrently on different container
 * threads. For that reason the history is only created, read, or modified while holding
 * {@link #ACCESSES_LOCK}.
 */
public class DosDetectRequestCycle extends WebRequestCycle {
	/**
	 * One recorded request: who asked for which URL, and when.
	 *
	 * <p>Declared {@code static} so that entries retained in application-wide metadata do not keep
	 * a reference to the request cycle (and its request/response) that created them.
	 */
	private static final class Access {
		private final String address;
		private final long timestamp;
		private final String url;

		Access(String address, long timestamp, String url) {
			this.address = address;
			this.timestamp = timestamp;
			this.url = url;
		}

		/** Returns whether this entry was made by {@code otherAddress} requesting {@code otherUrl}. */
		boolean isFrom(String otherAddress, String otherUrl) {
			return this.address.equals(otherAddress) && this.url.equals(otherUrl);
		}
	}

	/** Application metadata key under which the shared access history is stored. */
	private static final MetaDataKey<List<Access>> ACCESSES_HOLDER = new MetaDataKey<List<Access>>() {
		private static final long serialVersionUID = 1L;
	};

	/** Guards lazy creation and every read/modification of the shared access history. */
	private static final Object ACCESSES_LOCK = new Object();

	/** Length of the sliding window, in milliseconds. */
	private final long expire;

	/** Number of matching requests tolerated within the window; exceeding it raises the exception. */
	private final long threshold;

	/**
	 * @param application the owning application
	 * @param request     the current request
	 * @param response    the current response
	 * @param expire      sliding-window length in milliseconds
	 * @param threshold   number of requests from one address to one URL tolerated within the window
	 */
	public DosDetectRequestCycle(WebApplication application, WebRequest request, Response response, long expire, long threshold) {
		super(application, request, response);

		this.expire = expire;
		this.threshold = threshold;
	}

	/**
	 * Returns the application-wide access history, creating and registering it on first use.
	 * Callers must hold {@link #ACCESSES_LOCK}; the returned list is not itself thread-safe.
	 */
	private List<Access> getAccesses() {
		Application application = Application.get();
		List<Access> accesses = application.getMetaData(DosDetectRequestCycle.ACCESSES_HOLDER);

		if (accesses == null) {
			// LinkedList: expired entries are removed through the iterator, which is O(1) per removal here.
			accesses = new LinkedList<Access>();
			application.setMetaData(DosDetectRequestCycle.ACCESSES_HOLDER, accesses);
		}

		return accesses;
	}

	/**
	 * Records the current request and throws {@link DosAttackException} when the requesting address
	 * has hit the same URL more than {@code threshold} times within the last {@code expire}
	 * milliseconds (the current request included). Entries older than the window are discarded.
	 */
	@Override
	protected void onBeginRequest() {
		WebRequest request = (WebRequest) this.getRequest();
		String address = request.getHttpServletRequest().getRemoteAddr();
		String url = request.getURL();

		// Read the clock once so the recorded timestamp and the window border agree.
		long now = System.currentTimeMillis();
		long border = now - this.expire;

		int count = 0;
		synchronized (DosDetectRequestCycle.ACCESSES_LOCK) {
			List<Access> accesses = this.getAccesses();
			accesses.add(new Access(address, now, url));

			// Single pass: drop expired entries and count live entries matching this request.
			for (Iterator<Access> it = accesses.iterator(); it.hasNext();) {
				Access access = it.next();
				if (access.timestamp <= border) {
					it.remove();
				} else if (access.isFrom(address, url)) {
					count++;
				}
			}
		}

		// Throw outside the lock so other requests are not held up while this one fails.
		if (count > this.threshold) {
			throw new DosAttackException();
		}
	}
}

使う場合は、WebApplicationのサブクラスでnewRequestCycle(Request request, Response response)をoverrideします。下の例では、expireを30000ミリ秒(30秒)、thresholdを30回にしています。

/**
 * Creates the request cycle for each incoming request, wrapped in DoS detection.
 * A client that requests the same URL more than 30 times within 30 seconds is rejected.
 */
@Override
public RequestCycle newRequestCycle(Request request, Response response) {
	final long expireMillis = 30000; // sliding window: 30 seconds
	final long threshold = 30;       // tolerated requests per address and URL within the window

	WebRequest webRequest = (WebRequest) request;
	WebResponse webResponse = (WebResponse) response;
	return new DosDetectRequestCycle(this, webRequest, webResponse, expireMillis, threshold);
}