Commit 6590d8ad authored by Him188's avatar Him188

Fix message select

parent 1af13913
...@@ -7,6 +7,8 @@ ...@@ -7,6 +7,8 @@
* https://github.com/mamoe/mirai/blob/master/LICENSE * https://github.com/mamoe/mirai/blob/master/LICENSE
*/ */
@file:Suppress("DuplicatedCode")
package net.mamoe.mirai.event package net.mamoe.mirai.event
import kotlinx.coroutines.* import kotlinx.coroutines.*
...@@ -134,10 +136,46 @@ abstract class MessageSelectBuilder<M : ContactMessage, R> @PublishedApi interna ...@@ -134,10 +136,46 @@ abstract class MessageSelectBuilder<M : ContactMessage, R> @PublishedApi interna
@Deprecated("Using `reply` DSL in message selection is prohibited", level = DeprecationLevel.HIDDEN) @Deprecated("Using `reply` DSL in message selection is prohibited", level = DeprecationLevel.HIDDEN)
override infix fun MessageSelectionTimeoutChecker.reply(block: suspend () -> Any?): Nothing = error("prohibited") override infix fun MessageSelectionTimeoutChecker.reply(block: suspend () -> Any?): Nothing = error("prohibited")
@Deprecated("Using `reply` DSL in message selection is prohibited", level = DeprecationLevel.HIDDEN)
override infix fun MessageSelectionTimeoutChecker.reply(message: String): Nothing = error("prohibited")
@Deprecated("Using `reply` DSL in message selection is prohibited", level = DeprecationLevel.HIDDEN)
override infix fun MessageSelectionTimeoutChecker.reply(message: Message): Nothing = error("prohibited")
@JvmName("reply3")
@Suppress(
"INAPPLICABLE_JVM_NAME",
"unused",
"UNCHECKED_CAST",
"INVALID_CHARACTERS",
"NAME_CONTAINS_ILLEGAL_CHARS",
"FunctionName"
)
@Deprecated("Using `reply` DSL in message selection is prohibited", level = DeprecationLevel.HIDDEN)
override infix fun MessageSelectionTimeoutChecker.`->`(message: String): Nothing = error("prohibited")
@JvmName("reply3")
@Suppress(
"INAPPLICABLE_JVM_NAME",
"unused",
"UNCHECKED_CAST",
"INVALID_CHARACTERS",
"NAME_CONTAINS_ILLEGAL_CHARS",
"FunctionName"
)
@Deprecated("Using `reply` DSL in message selection is prohibited", level = DeprecationLevel.HIDDEN)
override infix fun MessageSelectionTimeoutChecker.`->`(message: Message): Nothing = error("prohibited")
@Deprecated("Using `reply` DSL in message selection is prohibited", level = DeprecationLevel.HIDDEN) @Deprecated("Using `reply` DSL in message selection is prohibited", level = DeprecationLevel.HIDDEN)
override infix fun MessageSelectionTimeoutChecker.quoteReply(block: suspend () -> Any?): Nothing = override infix fun MessageSelectionTimeoutChecker.quoteReply(block: suspend () -> Any?): Nothing =
error("prohibited") error("prohibited")
@Deprecated("Using `reply` DSL in message selection is prohibited", level = DeprecationLevel.HIDDEN)
override infix fun MessageSelectionTimeoutChecker.quoteReply(message: String): Nothing = error("prohibited")
@Deprecated("Using `reply` DSL in message selection is prohibited", level = DeprecationLevel.HIDDEN)
override infix fun MessageSelectionTimeoutChecker.quoteReply(message: Message): Nothing = error("prohibited")
@Deprecated("Using `reply` DSL in message selection is prohibited", level = DeprecationLevel.HIDDEN) @Deprecated("Using `reply` DSL in message selection is prohibited", level = DeprecationLevel.HIDDEN)
override fun String.containsReply(reply: String): Nothing = error("prohibited") override fun String.containsReply(reply: String): Nothing = error("prohibited")
...@@ -172,6 +210,16 @@ abstract class MessageSelectBuilder<M : ContactMessage, R> @PublishedApi interna ...@@ -172,6 +210,16 @@ abstract class MessageSelectBuilder<M : ContactMessage, R> @PublishedApi interna
@Deprecated("Using `reply` DSL in message selection is prohibited", level = DeprecationLevel.HIDDEN) @Deprecated("Using `reply` DSL in message selection is prohibited", level = DeprecationLevel.HIDDEN)
override fun ListeningFilter.reply(message: Message) = error("prohibited") override fun ListeningFilter.reply(message: Message) = error("prohibited")
@JvmName("reply3")
@Suppress("INAPPLICABLE_JVM_NAME", "INVALID_CHARACTERS", "NAME_CONTAINS_ILLEGAL_CHARS", "FunctionName")
@Deprecated("Using `reply` DSL in message selection is prohibited", level = DeprecationLevel.HIDDEN)
override fun ListeningFilter.`->`(toReply: String) = error("prohibited")
@JvmName("reply3")
@Suppress("INAPPLICABLE_JVM_NAME", "INVALID_CHARACTERS", "NAME_CONTAINS_ILLEGAL_CHARS", "FunctionName")
@Deprecated("Using `reply` DSL in message selection is prohibited", level = DeprecationLevel.HIDDEN)
override fun ListeningFilter.`->`(message: Message) = error("prohibited")
@Deprecated("Using `reply` DSL in message selection is prohibited", level = DeprecationLevel.HIDDEN) @Deprecated("Using `reply` DSL in message selection is prohibited", level = DeprecationLevel.HIDDEN)
override fun ListeningFilter.reply(replier: suspend M.(String) -> Any?) = override fun ListeningFilter.reply(replier: suspend M.(String) -> Any?) =
error("prohibited") error("prohibited")
...@@ -221,7 +269,7 @@ abstract class MessageSelectBuilderUnit<M : ContactMessage, R> @PublishedApi int ...@@ -221,7 +269,7 @@ abstract class MessageSelectBuilderUnit<M : ContactMessage, R> @PublishedApi int
obtainCurrentCoroutineScope().launch { obtainCurrentCoroutineScope().launch {
delay(timeoutMillis) delay(timeoutMillis)
val deferred = obtainCurrentDeferred() ?: return@launch val deferred = obtainCurrentDeferred() ?: return@launch
if (deferred.isActive) { if (deferred.isActive && !deferred.isCompleted) {
deferred.completeExceptionally(exception()) deferred.completeExceptionally(exception())
} }
} }
...@@ -236,7 +284,7 @@ abstract class MessageSelectBuilderUnit<M : ContactMessage, R> @PublishedApi int ...@@ -236,7 +284,7 @@ abstract class MessageSelectBuilderUnit<M : ContactMessage, R> @PublishedApi int
obtainCurrentCoroutineScope().launch { obtainCurrentCoroutineScope().launch {
delay(timeoutMillis) delay(timeoutMillis)
val deferred = obtainCurrentDeferred() ?: return@launch val deferred = obtainCurrentDeferred() ?: return@launch
if (deferred.isActive) { if (deferred.isActive && !deferred.isCompleted) {
deferred.complete(block()) deferred.complete(block())
} }
} }
...@@ -281,6 +329,48 @@ abstract class MessageSelectBuilderUnit<M : ContactMessage, R> @PublishedApi int ...@@ -281,6 +329,48 @@ abstract class MessageSelectBuilderUnit<M : ContactMessage, R> @PublishedApi int
} }
} }
@Suppress("unused", "UNCHECKED_CAST")
open infix fun MessageSelectionTimeoutChecker.reply(message: Message) {
return timeout(this.timeoutMillis) {
ownerMessagePacket.reply(message)
Unit as R
}
}
@Suppress("unused", "UNCHECKED_CAST")
open infix fun MessageSelectionTimeoutChecker.reply(message: String) {
return timeout(this.timeoutMillis) {
ownerMessagePacket.reply(message)
Unit as R
}
}
@JvmName("reply3")
@Suppress(
"INAPPLICABLE_JVM_NAME",
"unused",
"UNCHECKED_CAST",
"INVALID_CHARACTERS",
"NAME_CONTAINS_ILLEGAL_CHARS",
"FunctionName"
)
open infix fun MessageSelectionTimeoutChecker.`->`(message: Message) {
return this.reply(message)
}
@JvmName("reply3")
@Suppress(
"INAPPLICABLE_JVM_NAME",
"unused",
"UNCHECKED_CAST",
"INVALID_CHARACTERS",
"NAME_CONTAINS_ILLEGAL_CHARS",
"FunctionName"
)
open infix fun MessageSelectionTimeoutChecker.`->`(message: String) {
return this.reply(message)
}
/** /**
* 在超时后引用回复原消息 * 在超时后引用回复原消息
* *
...@@ -297,6 +387,22 @@ abstract class MessageSelectBuilderUnit<M : ContactMessage, R> @PublishedApi int ...@@ -297,6 +387,22 @@ abstract class MessageSelectBuilderUnit<M : ContactMessage, R> @PublishedApi int
} }
} }
@Suppress("unused", "UNCHECKED_CAST")
open infix fun MessageSelectionTimeoutChecker.quoteReply(message: Message) {
return timeout(this.timeoutMillis) {
ownerMessagePacket.quoteReply(message)
Unit as R
}
}
@Suppress("unused", "UNCHECKED_CAST")
open infix fun MessageSelectionTimeoutChecker.quoteReply(message: String) {
return timeout(this.timeoutMillis) {
ownerMessagePacket.quoteReply(message)
Unit as R
}
}
/** /**
* 当其他条件都不满足时回复原消息. * 当其他条件都不满足时回复原消息.
* *
...@@ -359,16 +465,24 @@ internal suspend inline fun <R> withTimeoutOrCoroutineScope( ...@@ -359,16 +465,24 @@ internal suspend inline fun <R> withTimeoutOrCoroutineScope(
): R { ): R {
require(timeoutMillis == -1L || timeoutMillis > 0) { "timeoutMillis must be -1 or > 0 " } require(timeoutMillis == -1L || timeoutMillis > 0) { "timeoutMillis must be -1 or > 0 " }
return if (timeoutMillis == -1L) { return withContext(ExceptionHandlerIgnoringCancellationException) {
coroutineScope(block) if (timeoutMillis == -1L) {
} else { coroutineScope(block)
withTimeout(timeoutMillis, block) } else {
withTimeout(timeoutMillis, block)
}
} }
} }
@PublishedApi @PublishedApi
internal val SELECT_MESSAGE_STUB = Any() internal val SELECT_MESSAGE_STUB = Any()
@PublishedApi
internal val ExceptionHandlerIgnoringCancellationException = CoroutineExceptionHandler { _, throwable ->
if (throwable !is CancellationException) {
throw throwable
}
}
@PublishedApi @PublishedApi
@BuilderInference @BuilderInference
...@@ -379,7 +493,10 @@ internal suspend inline fun <reified T : ContactMessage, R> T.selectMessagesImpl ...@@ -379,7 +493,10 @@ internal suspend inline fun <reified T : ContactMessage, R> T.selectMessagesImpl
@BuilderInference @BuilderInference
crossinline selectBuilder: @MessageDsl MessageSelectBuilderUnit<T, R>.() -> Unit crossinline selectBuilder: @MessageDsl MessageSelectBuilderUnit<T, R>.() -> Unit
): R = withTimeoutOrCoroutineScope(timeoutMillis) { ): R = withTimeoutOrCoroutineScope(timeoutMillis) {
val deferred = CompletableDeferred<R>() var deferred: CompletableDeferred<R>? = CompletableDeferred()
coroutineContext[Job]!!.invokeOnCompletion {
deferred?.cancel()
}
// ensure sequential invoking // ensure sequential invoking
val listeners: MutableList<Pair<T.(String) -> Boolean, MessageListener<T, Any?>>> = mutableListOf() val listeners: MutableList<Pair<T.(String) -> Boolean, MessageListener<T, Any?>>> = mutableListOf()
...@@ -421,14 +538,13 @@ internal suspend inline fun <reified T : ContactMessage, R> T.selectMessagesImpl ...@@ -421,14 +538,13 @@ internal suspend inline fun <reified T : ContactMessage, R> T.selectMessagesImpl
// we don't have any way to reduce duplication yet, // we don't have any way to reduce duplication yet,
// until local functions are supported in inline functions // until local functions are supported in inline functions
@Suppress("DuplicatedCode") @Suppress("DuplicatedCode") val subscribeAlways = subscribeAlways<T> { event ->
subscribeAlways<T> { event ->
if (!this.isContextIdenticalWith(this@selectMessagesImpl)) if (!this.isContextIdenticalWith(this@selectMessagesImpl))
return@subscribeAlways return@subscribeAlways
val toString = event.message.toString() val toString = event.message.toString()
listeners.forEach { (filter, listener) -> listeners.forEach { (filter, listener) ->
if (deferred.isCompleted || !isActive) if (deferred?.isCompleted == true || !isActive)
return@subscribeAlways return@subscribeAlways
if (filter.invoke(event, toString)) { if (filter.invoke(event, toString)) {
...@@ -436,12 +552,12 @@ internal suspend inline fun <reified T : ContactMessage, R> T.selectMessagesImpl ...@@ -436,12 +552,12 @@ internal suspend inline fun <reified T : ContactMessage, R> T.selectMessagesImpl
val value = listener.invoke(event, toString) val value = listener.invoke(event, toString)
if (value !== SELECT_MESSAGE_STUB) { if (value !== SELECT_MESSAGE_STUB) {
@Suppress("UNCHECKED_CAST") @Suppress("UNCHECKED_CAST")
deferred.complete(value as R) deferred?.complete(value as R)
return@subscribeAlways return@subscribeAlways
} else if (isUnit) { // value === stub } else if (isUnit) { // value === stub
// unit mode: we can directly complete this selection // unit mode: we can directly complete this selection
@Suppress("UNCHECKED_CAST") @Suppress("UNCHECKED_CAST")
deferred.complete(Unit as R) deferred?.complete(Unit as R)
} }
} }
} }
...@@ -450,17 +566,21 @@ internal suspend inline fun <reified T : ContactMessage, R> T.selectMessagesImpl ...@@ -450,17 +566,21 @@ internal suspend inline fun <reified T : ContactMessage, R> T.selectMessagesImpl
val value = listener.invoke(event, toString) val value = listener.invoke(event, toString)
if (value !== SELECT_MESSAGE_STUB) { if (value !== SELECT_MESSAGE_STUB) {
@Suppress("UNCHECKED_CAST") @Suppress("UNCHECKED_CAST")
deferred.complete(value as R) deferred?.complete(value as R)
return@subscribeAlways return@subscribeAlways
} else if (isUnit) { // value === stub } else if (isUnit) { // value === stub
// unit mode: we can directly complete this selection // unit mode: we can directly complete this selection
@Suppress("UNCHECKED_CAST") @Suppress("UNCHECKED_CAST")
deferred.complete(Unit as R) deferred?.complete(Unit as R)
} }
} }
} }
deferred.await().also { coroutineContext[Job]!!.cancelChildren() } deferred!!.await().also {
subscribeAlways.complete()
deferred = null
coroutineContext.cancelChildren()
}
} }
@Suppress("unused") @Suppress("unused")
...@@ -468,50 +588,43 @@ internal suspend inline fun <reified T : ContactMessage, R> T.selectMessagesImpl ...@@ -468,50 +588,43 @@ internal suspend inline fun <reified T : ContactMessage, R> T.selectMessagesImpl
internal suspend inline fun <reified T : ContactMessage> T.whileSelectMessagesImpl( internal suspend inline fun <reified T : ContactMessage> T.whileSelectMessagesImpl(
timeoutMillis: Long = -1, timeoutMillis: Long = -1,
crossinline selectBuilder: @MessageDsl MessageSelectBuilder<T, Boolean>.() -> Unit crossinline selectBuilder: @MessageDsl MessageSelectBuilder<T, Boolean>.() -> Unit
) { ) = withTimeoutOrCoroutineScope(timeoutMillis) {
withTimeoutOrCoroutineScope(timeoutMillis) { var deferred: CompletableDeferred<Boolean>? = CompletableDeferred()
var deferred: CompletableDeferred<Boolean>? = CompletableDeferred() coroutineContext[Job]!!.invokeOnCompletion {
deferred?.cancel()
}
// ensure sequential invoking // ensure sequential invoking
val listeners: MutableList<Pair<T.(String) -> Boolean, MessageListener<T, Any?>>> = mutableListOf() val listeners: MutableList<Pair<T.(String) -> Boolean, MessageListener<T, Any?>>> = mutableListOf()
val defaltListeners: MutableList<MessageListener<T, Any?>> = mutableListOf() val defaultListeners: MutableList<MessageListener<T, Any?>> = mutableListOf()
// https://youtrack.jetbrains.com/issue/KT-37716 // https://youtrack.jetbrains.com/issue/KT-37716
val outside = { filter: T.(String) -> Boolean, listener: MessageListener<T, Any?> -> val outside = { filter: T.(String) -> Boolean, listener: MessageListener<T, Any?> ->
listeners += filter to listener listeners += filter to listener
}
object : MessageSelectBuilder<T, Boolean>(
this@whileSelectMessagesImpl,
SELECT_MESSAGE_STUB,
outside
) {
override fun obtainCurrentCoroutineScope(): CoroutineScope = this@withTimeoutOrCoroutineScope
override fun obtainCurrentDeferred(): CompletableDeferred<Boolean>? = deferred
override fun default(onEvent: MessageListener<T, Boolean>) {
defaultListeners += onEvent
} }
object : MessageSelectBuilder<T, Boolean>( }.apply(selectBuilder)
this@whileSelectMessagesImpl,
SELECT_MESSAGE_STUB,
outside
) {
override fun obtainCurrentCoroutineScope(): CoroutineScope = this@withTimeoutOrCoroutineScope
override fun obtainCurrentDeferred(): CompletableDeferred<Boolean>? = deferred
override fun default(onEvent: MessageListener<T, Boolean>) {
defaltListeners += onEvent
}
}.apply(selectBuilder)
// ensure atomic completing // ensure atomic completing
subscribeAlways<T>(concurrency = Listener.ConcurrencyKind.LOCKED) { event -> val subscribeAlways = subscribeAlways<T>(concurrency = Listener.ConcurrencyKind.LOCKED) { event ->
if (!this.isContextIdenticalWith(this@whileSelectMessagesImpl)) if (!this.isContextIdenticalWith(this@whileSelectMessagesImpl))
return@subscribeAlways return@subscribeAlways
val toString = event.message.toString() val toString = event.message.toString()
listeners.forEach { (filter, listener) -> listeners.forEach { (filter, listener) ->
if (deferred?.isCompleted != false || !isActive) if (deferred?.isCompleted != false || !isActive)
return@subscribeAlways return@subscribeAlways
if (filter.invoke(event, toString)) { if (filter.invoke(event, toString)) {
listener.invoke(event, toString).let { value ->
if (value !== SELECT_MESSAGE_STUB) {
deferred?.complete(value as Boolean)
return@subscribeAlways // accept the first value only
}
}
}
}
defaltListeners.forEach { listener ->
listener.invoke(event, toString).let { value -> listener.invoke(event, toString).let { value ->
if (value !== SELECT_MESSAGE_STUB) { if (value !== SELECT_MESSAGE_STUB) {
deferred?.complete(value as Boolean) deferred?.complete(value as Boolean)
...@@ -520,11 +633,20 @@ internal suspend inline fun <reified T : ContactMessage> T.whileSelectMessagesIm ...@@ -520,11 +633,20 @@ internal suspend inline fun <reified T : ContactMessage> T.whileSelectMessagesIm
} }
} }
} }
defaultListeners.forEach { listener ->
while (deferred?.await() == true) { listener.invoke(event, toString).let { value ->
deferred = CompletableDeferred() if (value !== SELECT_MESSAGE_STUB) {
deferred?.complete(value as Boolean)
return@subscribeAlways // accept the first value only
}
}
} }
deferred = null
coroutineContext[Job]!!.cancelChildren()
} }
while (deferred?.await() == true) {
deferred = CompletableDeferred()
}
subscribeAlways.complete()
deferred = null
coroutineContext.cancelChildren()
} }
\ No newline at end of file
...@@ -331,6 +331,20 @@ open class MessageSubscribersBuilder<M : ContactMessage, out Ret, R : RR, RR>( ...@@ -331,6 +331,20 @@ open class MessageSubscribersBuilder<M : ContactMessage, out Ret, R : RR, RR>(
return content(filter) { reply(message);this@MessageSubscribersBuilder.stub } return content(filter) { reply(message);this@MessageSubscribersBuilder.stub }
} }
@JvmName("reply3")
@Suppress("INAPPLICABLE_JVM_NAME", "INVALID_CHARACTERS", "NAME_CONTAINS_ILLEGAL_CHARS", "FunctionName")
@SinceMirai("0.33.0")
open infix fun ListeningFilter.`->`(toReply: String): Ret {
return this.reply(toReply)
}
@JvmName("reply3")
@Suppress("INAPPLICABLE_JVM_NAME", "INVALID_CHARACTERS", "NAME_CONTAINS_ILLEGAL_CHARS", "FunctionName")
@SinceMirai("0.33.0")
open infix fun ListeningFilter.`->`(message: Message): Ret {
return this.reply(message)
}
@SinceMirai("0.29.0") @SinceMirai("0.29.0")
open infix fun ListeningFilter.reply(replier: (@MessageDsl suspend M.(String) -> Any?)): Ret { open infix fun ListeningFilter.reply(replier: (@MessageDsl suspend M.(String) -> Any?)): Ret {
return content(filter) { return content(filter) {
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment