ServletFilterではなくIRequestCycle
対DoS攻撃用のIRequestCycleを書きました。
あるIPアドレスから同じURLへ、一定期間(expire、ミリ秒)内に閾値(threshold)を超える回数のアクセスがあるとDosAttackExceptionをぶん投げます。
こういったものはServletFilterで書くのも手ですが、そうするとWicketTesterでテストするのが手間になってしまうため、IRequestCycleのonBeginRequest()やonEndRequest()といったフックで書いています(今回のDoS検知はonBeginRequest()だけで行っています)。
JPAのTransaction管理もServletFilterではなくIRequestCycle内で完結させておくと、Transactionをcommitできなかった場合にIRequestCycle.onRuntimeException(Page page, RuntimeException e)で処理できるので便利です。これについてはまた別の機会に詳しく書きます。
package net.nagaseyasuhito.sandbox; import java.util.ArrayList; import java.util.List; import org.apache.commons.collections.CollectionUtils; import org.apache.commons.collections.Predicate; import org.apache.wicket.Application; import org.apache.wicket.MetaDataKey; import org.apache.wicket.RequestCycle; import org.apache.wicket.Response; import org.apache.wicket.protocol.http.WebApplication; import org.apache.wicket.protocol.http.WebRequest; import org.apache.wicket.protocol.http.WebRequestCycle; public class DosDetectRequestCycle extends WebRequestCycle { private class Access { public String address; public long timestamp; public String url; } private static MetaDataKey<List<Access>> ACCESSES_HOLDER = new MetaDataKey<List<Access>>() { private static final long serialVersionUID = 1L; }; private long expire; private long threshold; public DosDetectRequestCycle(WebApplication application, WebRequest request, Response response, long expire, long threshold) { super(application, request, response); this.expire = expire; this.threshold = threshold; } private List<Access> getAccesses() { List<Access> accesses = Application.get().getMetaData(DosDetectRequestCycle.ACCESSES_HOLDER); if (accesses == null) { Application.get().setMetaData(DosDetectRequestCycle.ACCESSES_HOLDER, new ArrayList<Access>()); accesses = Application.get().getMetaData(DosDetectRequestCycle.ACCESSES_HOLDER); } return accesses; } @Override protected void onBeginRequest() { final String address = ((WebRequest) RequestCycle.get().getRequest()).getHttpServletRequest().getRemoteAddr(); final String url = ((WebRequest) RequestCycle.get().getRequest()).getURL(); final long border = System.currentTimeMillis() - this.expire; Access access = new Access(); access.address = address; access.timestamp = System.currentTimeMillis(); access.url = url; this.getAccesses().add(access); // cleanup CollectionUtils.filter(this.getAccesses(), new Predicate() { @Override public boolean evaluate(Object object) { Access access = 
(Access) object; return access.timestamp > border; } }); // count int count = CollectionUtils.countMatches(this.getAccesses(), new Predicate() { @Override public boolean evaluate(Object object) { Access access = (Access) object; return access.address.equals(address) && access.url.equals(url); } }); if (count > this.threshold) { throw new DosAttackException(); } } }
使う場合はWebApplication.newRequestCycle(Request request, Response response)をoverrideします。
/**
 * Creates a {@code DosDetectRequestCycle} for every incoming request.
 * The two numeric values are passed as the cycle's {@code expire} (milliseconds)
 * and {@code threshold} constructor arguments.
 */
@Override
public RequestCycle newRequestCycle(Request request, Response response) {
    final long expireMillis = 30000L;
    final long threshold = 30L;
    final WebRequest webRequest = (WebRequest) request;
    final WebResponse webResponse = (WebResponse) response;
    return new DosDetectRequestCycle(this, webRequest, webResponse, expireMillis, threshold);
}