Skip to content
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright 2015 the original author or authors.
* Copyright 2015-2021 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -64,5 +64,4 @@ public interface StateMachine<S, E> extends Region<S, E> {
* @return true, if error has been set
*/
boolean hasStateMachineError();

}
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
/*
* Copyright 2019-2020 the original author or authors.
* Copyright 2019-2021 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
Expand All @@ -15,6 +15,7 @@
*/
package org.springframework.statemachine;

import java.util.Optional;
import org.springframework.messaging.Message;
import org.springframework.statemachine.region.Region;

Expand Down Expand Up @@ -59,6 +60,14 @@ public interface StateMachineEventResult<S, E> {
*/
Mono<Void> complete();

/**
* If there was an exception that caused the transition to be denied - return that
* @return Optional Throwable that caused the transition to be denied
*/
default Optional<Throwable> getDenialCause() {
return Optional.empty();
};

/**
* Enumeration of a result type indicating whether a region accepted, denied or
* deferred an event.
Expand All @@ -82,7 +91,7 @@ public enum ResultType {
*/
public static <S, E> StateMachineEventResult<S, E> from(Region<S, E> region, Message<E> message,
ResultType resultType) {
return new DefaultStateMachineEventResult<>(region, message, resultType, null);
return new DefaultStateMachineEventResult<>(region, message, resultType, null, null);
}


Expand All @@ -100,7 +109,24 @@ public static <S, E> StateMachineEventResult<S, E> from(Region<S, E> region, Mes
*/
public static <S, E> StateMachineEventResult<S, E> from(Region<S, E> region, Message<E> message,
ResultType resultType, Mono<Void> complete) {
return new DefaultStateMachineEventResult<>(region, message, resultType, complete);
return new DefaultStateMachineEventResult<>(region, message, resultType, complete, null);
}

/**
* Create a {@link StateMachineEventResult} from a {@link Region},
* {@link Message} and a {@link ResultType}.
*
* @param <S> the type of state
* @param <E> the type of event
* @param region the region
* @param message the message
* @param resultType the result type
* @param denialCause the throwable (that most likely caused transition denial)
* @return the state machine event result
*/
public static <S, E> StateMachineEventResult<S, E> from(Region<S, E> region, Message<E> message,
ResultType resultType, Throwable denialCause) {
return new DefaultStateMachineEventResult<>(region, message, resultType, null, denialCause);
}

static class DefaultStateMachineEventResult<S, E> implements StateMachineEventResult<S, E> {
Expand All @@ -109,13 +135,15 @@ static class DefaultStateMachineEventResult<S, E> implements StateMachineEventRe
private final Message<E> message;
private final ResultType resultType;
private Mono<Void> complete;
private Throwable denialCause;

DefaultStateMachineEventResult(Region<S, E> region, Message<E> message, ResultType resultType,
Mono<Void> complete) {
Mono<Void> complete, Throwable denialCause) {
this.region = region;
this.message = message;
this.resultType = resultType;
this.complete = complete != null ? complete : Mono.empty();
this.denialCause = denialCause;
}

@Override
Expand All @@ -138,6 +166,11 @@ public Mono<Void> complete() {
return complete;
}

@Override
public Optional<Throwable> getDenialCause() {
return Optional.ofNullable(denialCause);
}

@Override
public String toString() {
return "DefaultStateMachineEventResult [region=" + region + ", message=" + message + ", resultType="
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -620,12 +620,12 @@ public void setTransitionConflightPolicy(TransitionConflictPolicy transitionConf

private Flux<StateMachineEventResult<S, E>> handleEvent(Message<E> message) {
if (hasStateMachineError()) {
return Flux.just(StateMachineEventResult.<S, E>from(this, message, ResultType.DENIED));
return Flux.just(StateMachineEventResult.<S, E>from(this, message, ResultType.DENIED, currentError.getCause()));
}
return Mono.just(message)
.map(m -> getStateMachineInterceptors().preEvent(m, this))
.flatMapMany(m -> acceptEvent(m))
.onErrorResume(error -> Flux.just(StateMachineEventResult.<S, E>from(this, message, ResultType.DENIED)))
.onErrorResume(error -> Flux.just(StateMachineEventResult.<S, E>from(this, message, ResultType.DENIED, error.getCause())))
.doOnNext(notifyOnDenied());
}

Expand Down Expand Up @@ -668,7 +668,7 @@ private Flux<StateMachineEventResult<S, E>> acceptEvent(Message<E> message) {
}))
.onErrorResume(t -> {
return Mono.defer(() -> {
return Mono.just(StateMachineEventResult.<S, E>from(this, message, ResultType.DENIED));
return Mono.just(StateMachineEventResult.<S, E>from(this, message, ResultType.DENIED, t.getCause()));
});
});
} else {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
import org.springframework.beans.factory.BeanFactory;
import org.springframework.messaging.Message;
import org.springframework.messaging.support.MessageBuilder;
import org.springframework.security.access.AccessDeniedException;
import org.springframework.statemachine.StateMachineEventResult.ResultType;
import org.springframework.statemachine.action.Action;
import org.springframework.statemachine.config.StateMachineFactory;
Expand Down Expand Up @@ -125,6 +126,15 @@ public static <S, E> void doSendEventAndConsumeResultAsDenied(StateMachine<S, E>
.verifyComplete();
}

public static <S, E> void doSendEventAndConsumeResultAsDeniedWithAccessDeniedException(StateMachine<S, E> stateMachine, E event) {
StepVerifier.create(stateMachine.sendEvent(eventAsMono(event)))
.consumeNextWith(result -> {
assertThat(result.getResultType()).isEqualTo(ResultType.DENIED);
assertThat(result.getDenialCause().map(t -> t instanceof AccessDeniedException).orElse(false)).isTrue();
})
.verifyComplete();
}

public static <S, E> void doSendEventAndConsumeResultAsDenied(StateMachine<S, E> stateMachine, Message<E> event) {
StepVerifier.create(stateMachine.sendEvent(eventAsMono(event)))
.consumeNextWith(result -> {
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright 2015-2020 the original author or authors.
* Copyright 2015-2021 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand All @@ -18,10 +18,12 @@
import static org.assertj.core.api.Assertions.assertThat;
import static org.springframework.statemachine.TestUtils.doSendEventAndConsumeAll;
import static org.springframework.statemachine.TestUtils.doSendEventAndConsumeResultAsDenied;
import static org.springframework.statemachine.TestUtils.doSendEventAndConsumeResultAsDeniedWithAccessDeniedException;

import java.util.concurrent.CountDownLatch;
import java.util.concurrent.TimeUnit;

import org.springframework.security.access.AccessDeniedException;
import org.springframework.statemachine.AbstractStateMachineTests;
import org.springframework.statemachine.StateMachine;
import org.springframework.statemachine.config.StateMachineBuilder;
Expand Down Expand Up @@ -53,7 +55,7 @@ protected static void assertTransitionDenied(StateMachine<States, Events> machin
assertThat(machine.getState().getIds()).containsOnly(States.S0);

listener.reset(1);
doSendEventAndConsumeAll(machine, Events.A);
doSendEventAndConsumeResultAsDeniedWithAccessDeniedException(machine, Events.A);
assertThat(listener.stateChangedLatch.await(2, TimeUnit.SECONDS)).isFalse();
assertThat(listener.stateChangedCount).isZero();
assertThat(machine.getState().getIds()).containsOnly(States.S0);
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright 2015-2020 the original author or authors.
* Copyright 2015-2021 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -35,7 +35,7 @@ public class EventSecurityTests extends AbstractSecurityTests {
public void testNoSecurityContext() throws Exception {
TestListener listener = new TestListener();
StateMachine<States, Events> machine = buildMachine(listener, "ROLE_ANONYMOUS", ComparisonType.ANY, null);
assertTransitionDeniedResultAsDenied(machine, listener);
assertTransitionDenied(machine, listener);
}

@Test
Expand Down