Commit 6590d8ad authored by Him188's avatar Him188

Fix message select

parent 1af13913
......@@ -7,6 +7,8 @@
* https://github.com/mamoe/mirai/blob/master/LICENSE
*/
@file:Suppress("DuplicatedCode")
package net.mamoe.mirai.event
import kotlinx.coroutines.*
......@@ -134,10 +136,46 @@ abstract class MessageSelectBuilder<M : ContactMessage, R> @PublishedApi interna
@Deprecated("Using `reply` DSL in message selection is prohibited", level = DeprecationLevel.HIDDEN)
override infix fun MessageSelectionTimeoutChecker.reply(block: suspend () -> Any?): Nothing = error("prohibited")
@Deprecated("Using `reply` DSL in message selection is prohibited", level = DeprecationLevel.HIDDEN)
override infix fun MessageSelectionTimeoutChecker.reply(message: String): Nothing = error("prohibited")
@Deprecated("Using `reply` DSL in message selection is prohibited", level = DeprecationLevel.HIDDEN)
override infix fun MessageSelectionTimeoutChecker.reply(message: Message): Nothing = error("prohibited")
@JvmName("reply3")
@Suppress(
"INAPPLICABLE_JVM_NAME",
"unused",
"UNCHECKED_CAST",
"INVALID_CHARACTERS",
"NAME_CONTAINS_ILLEGAL_CHARS",
"FunctionName"
)
@Deprecated("Using `reply` DSL in message selection is prohibited", level = DeprecationLevel.HIDDEN)
override infix fun MessageSelectionTimeoutChecker.`->`(message: String): Nothing = error("prohibited")
@JvmName("reply3")
@Suppress(
"INAPPLICABLE_JVM_NAME",
"unused",
"UNCHECKED_CAST",
"INVALID_CHARACTERS",
"NAME_CONTAINS_ILLEGAL_CHARS",
"FunctionName"
)
@Deprecated("Using `reply` DSL in message selection is prohibited", level = DeprecationLevel.HIDDEN)
override infix fun MessageSelectionTimeoutChecker.`->`(message: Message): Nothing = error("prohibited")
@Deprecated("Using `reply` DSL in message selection is prohibited", level = DeprecationLevel.HIDDEN)
override infix fun MessageSelectionTimeoutChecker.quoteReply(block: suspend () -> Any?): Nothing =
error("prohibited")
@Deprecated("Using `reply` DSL in message selection is prohibited", level = DeprecationLevel.HIDDEN)
override infix fun MessageSelectionTimeoutChecker.quoteReply(message: String): Nothing = error("prohibited")
@Deprecated("Using `reply` DSL in message selection is prohibited", level = DeprecationLevel.HIDDEN)
override infix fun MessageSelectionTimeoutChecker.quoteReply(message: Message): Nothing = error("prohibited")
@Deprecated("Using `reply` DSL in message selection is prohibited", level = DeprecationLevel.HIDDEN)
override fun String.containsReply(reply: String): Nothing = error("prohibited")
......@@ -172,6 +210,16 @@ abstract class MessageSelectBuilder<M : ContactMessage, R> @PublishedApi interna
@Deprecated("Using `reply` DSL in message selection is prohibited", level = DeprecationLevel.HIDDEN)
override fun ListeningFilter.reply(message: Message) = error("prohibited")
@JvmName("reply3")
@Suppress("INAPPLICABLE_JVM_NAME", "INVALID_CHARACTERS", "NAME_CONTAINS_ILLEGAL_CHARS", "FunctionName")
@Deprecated("Using `reply` DSL in message selection is prohibited", level = DeprecationLevel.HIDDEN)
override fun ListeningFilter.`->`(toReply: String) = error("prohibited")
@JvmName("reply3")
@Suppress("INAPPLICABLE_JVM_NAME", "INVALID_CHARACTERS", "NAME_CONTAINS_ILLEGAL_CHARS", "FunctionName")
@Deprecated("Using `reply` DSL in message selection is prohibited", level = DeprecationLevel.HIDDEN)
override fun ListeningFilter.`->`(message: Message) = error("prohibited")
@Deprecated("Using `reply` DSL in message selection is prohibited", level = DeprecationLevel.HIDDEN)
override fun ListeningFilter.reply(replier: suspend M.(String) -> Any?) =
error("prohibited")
......@@ -221,7 +269,7 @@ abstract class MessageSelectBuilderUnit<M : ContactMessage, R> @PublishedApi int
obtainCurrentCoroutineScope().launch {
delay(timeoutMillis)
val deferred = obtainCurrentDeferred() ?: return@launch
if (deferred.isActive) {
if (deferred.isActive && !deferred.isCompleted) {
deferred.completeExceptionally(exception())
}
}
......@@ -236,7 +284,7 @@ abstract class MessageSelectBuilderUnit<M : ContactMessage, R> @PublishedApi int
obtainCurrentCoroutineScope().launch {
delay(timeoutMillis)
val deferred = obtainCurrentDeferred() ?: return@launch
if (deferred.isActive) {
if (deferred.isActive && !deferred.isCompleted) {
deferred.complete(block())
}
}
......@@ -281,6 +329,48 @@ abstract class MessageSelectBuilderUnit<M : ContactMessage, R> @PublishedApi int
}
}
@Suppress("unused", "UNCHECKED_CAST")
open infix fun MessageSelectionTimeoutChecker.reply(message: Message) {
return timeout(this.timeoutMillis) {
ownerMessagePacket.reply(message)
Unit as R
}
}
@Suppress("unused", "UNCHECKED_CAST")
open infix fun MessageSelectionTimeoutChecker.reply(message: String) {
return timeout(this.timeoutMillis) {
ownerMessagePacket.reply(message)
Unit as R
}
}
@JvmName("reply3")
@Suppress(
"INAPPLICABLE_JVM_NAME",
"unused",
"UNCHECKED_CAST",
"INVALID_CHARACTERS",
"NAME_CONTAINS_ILLEGAL_CHARS",
"FunctionName"
)
open infix fun MessageSelectionTimeoutChecker.`->`(message: Message) {
return this.reply(message)
}
@JvmName("reply3")
@Suppress(
"INAPPLICABLE_JVM_NAME",
"unused",
"UNCHECKED_CAST",
"INVALID_CHARACTERS",
"NAME_CONTAINS_ILLEGAL_CHARS",
"FunctionName"
)
open infix fun MessageSelectionTimeoutChecker.`->`(message: String) {
return this.reply(message)
}
/**
* 在超时后引用回复原消息
*
......@@ -297,6 +387,22 @@ abstract class MessageSelectBuilderUnit<M : ContactMessage, R> @PublishedApi int
}
}
@Suppress("unused", "UNCHECKED_CAST")
open infix fun MessageSelectionTimeoutChecker.quoteReply(message: Message) {
return timeout(this.timeoutMillis) {
ownerMessagePacket.quoteReply(message)
Unit as R
}
}
@Suppress("unused", "UNCHECKED_CAST")
open infix fun MessageSelectionTimeoutChecker.quoteReply(message: String) {
return timeout(this.timeoutMillis) {
ownerMessagePacket.quoteReply(message)
Unit as R
}
}
/**
* 当其他条件都不满足时回复原消息.
*
......@@ -359,16 +465,24 @@ internal suspend inline fun <R> withTimeoutOrCoroutineScope(
): R {
require(timeoutMillis == -1L || timeoutMillis > 0) { "timeoutMillis must be -1 or > 0 " }
return if (timeoutMillis == -1L) {
coroutineScope(block)
} else {
withTimeout(timeoutMillis, block)
return withContext(ExceptionHandlerIgnoringCancellationException) {
if (timeoutMillis == -1L) {
coroutineScope(block)
} else {
withTimeout(timeoutMillis, block)
}
}
}
@PublishedApi
internal val SELECT_MESSAGE_STUB = Any()
@PublishedApi
internal val ExceptionHandlerIgnoringCancellationException = CoroutineExceptionHandler { _, throwable ->
if (throwable !is CancellationException) {
throw throwable
}
}
@PublishedApi
@BuilderInference
......@@ -379,7 +493,10 @@ internal suspend inline fun <reified T : ContactMessage, R> T.selectMessagesImpl
@BuilderInference
crossinline selectBuilder: @MessageDsl MessageSelectBuilderUnit<T, R>.() -> Unit
): R = withTimeoutOrCoroutineScope(timeoutMillis) {
val deferred = CompletableDeferred<R>()
var deferred: CompletableDeferred<R>? = CompletableDeferred()
coroutineContext[Job]!!.invokeOnCompletion {
deferred?.cancel()
}
// ensure sequential invoking
val listeners: MutableList<Pair<T.(String) -> Boolean, MessageListener<T, Any?>>> = mutableListOf()
......@@ -421,14 +538,13 @@ internal suspend inline fun <reified T : ContactMessage, R> T.selectMessagesImpl
// we don't have any way to reduce duplication yet,
// until local functions are supported in inline functions
@Suppress("DuplicatedCode")
subscribeAlways<T> { event ->
@Suppress("DuplicatedCode") val subscribeAlways = subscribeAlways<T> { event ->
if (!this.isContextIdenticalWith(this@selectMessagesImpl))
return@subscribeAlways
val toString = event.message.toString()
listeners.forEach { (filter, listener) ->
if (deferred.isCompleted || !isActive)
if (deferred?.isCompleted == true || !isActive)
return@subscribeAlways
if (filter.invoke(event, toString)) {
......@@ -436,12 +552,12 @@ internal suspend inline fun <reified T : ContactMessage, R> T.selectMessagesImpl
val value = listener.invoke(event, toString)
if (value !== SELECT_MESSAGE_STUB) {
@Suppress("UNCHECKED_CAST")
deferred.complete(value as R)
deferred?.complete(value as R)
return@subscribeAlways
} else if (isUnit) { // value === stub
// unit mode: we can directly complete this selection
@Suppress("UNCHECKED_CAST")
deferred.complete(Unit as R)
deferred?.complete(Unit as R)
}
}
}
......@@ -450,17 +566,21 @@ internal suspend inline fun <reified T : ContactMessage, R> T.selectMessagesImpl
val value = listener.invoke(event, toString)
if (value !== SELECT_MESSAGE_STUB) {
@Suppress("UNCHECKED_CAST")
deferred.complete(value as R)
deferred?.complete(value as R)
return@subscribeAlways
} else if (isUnit) { // value === stub
// unit mode: we can directly complete this selection
@Suppress("UNCHECKED_CAST")
deferred.complete(Unit as R)
deferred?.complete(Unit as R)
}
}
}
deferred.await().also { coroutineContext[Job]!!.cancelChildren() }
deferred!!.await().also {
subscribeAlways.complete()
deferred = null
coroutineContext.cancelChildren()
}
}
@Suppress("unused")
......@@ -468,50 +588,43 @@ internal suspend inline fun <reified T : ContactMessage, R> T.selectMessagesImpl
internal suspend inline fun <reified T : ContactMessage> T.whileSelectMessagesImpl(
timeoutMillis: Long = -1,
crossinline selectBuilder: @MessageDsl MessageSelectBuilder<T, Boolean>.() -> Unit
) {
withTimeoutOrCoroutineScope(timeoutMillis) {
var deferred: CompletableDeferred<Boolean>? = CompletableDeferred()
) = withTimeoutOrCoroutineScope(timeoutMillis) {
var deferred: CompletableDeferred<Boolean>? = CompletableDeferred()
coroutineContext[Job]!!.invokeOnCompletion {
deferred?.cancel()
}
// ensure sequential invoking
val listeners: MutableList<Pair<T.(String) -> Boolean, MessageListener<T, Any?>>> = mutableListOf()
val defaltListeners: MutableList<MessageListener<T, Any?>> = mutableListOf()
// ensure sequential invoking
val listeners: MutableList<Pair<T.(String) -> Boolean, MessageListener<T, Any?>>> = mutableListOf()
val defaultListeners: MutableList<MessageListener<T, Any?>> = mutableListOf()
// https://youtrack.jetbrains.com/issue/KT-37716
val outside = { filter: T.(String) -> Boolean, listener: MessageListener<T, Any?> ->
listeners += filter to listener
// https://youtrack.jetbrains.com/issue/KT-37716
val outside = { filter: T.(String) -> Boolean, listener: MessageListener<T, Any?> ->
listeners += filter to listener
}
object : MessageSelectBuilder<T, Boolean>(
this@whileSelectMessagesImpl,
SELECT_MESSAGE_STUB,
outside
) {
override fun obtainCurrentCoroutineScope(): CoroutineScope = this@withTimeoutOrCoroutineScope
override fun obtainCurrentDeferred(): CompletableDeferred<Boolean>? = deferred
override fun default(onEvent: MessageListener<T, Boolean>) {
defaultListeners += onEvent
}
object : MessageSelectBuilder<T, Boolean>(
this@whileSelectMessagesImpl,
SELECT_MESSAGE_STUB,
outside
) {
override fun obtainCurrentCoroutineScope(): CoroutineScope = this@withTimeoutOrCoroutineScope
override fun obtainCurrentDeferred(): CompletableDeferred<Boolean>? = deferred
override fun default(onEvent: MessageListener<T, Boolean>) {
defaltListeners += onEvent
}
}.apply(selectBuilder)
}.apply(selectBuilder)
// ensure atomic completing
subscribeAlways<T>(concurrency = Listener.ConcurrencyKind.LOCKED) { event ->
if (!this.isContextIdenticalWith(this@whileSelectMessagesImpl))
return@subscribeAlways
// ensure atomic completing
val subscribeAlways = subscribeAlways<T>(concurrency = Listener.ConcurrencyKind.LOCKED) { event ->
if (!this.isContextIdenticalWith(this@whileSelectMessagesImpl))
return@subscribeAlways
val toString = event.message.toString()
listeners.forEach { (filter, listener) ->
if (deferred?.isCompleted != false || !isActive)
return@subscribeAlways
val toString = event.message.toString()
listeners.forEach { (filter, listener) ->
if (deferred?.isCompleted != false || !isActive)
return@subscribeAlways
if (filter.invoke(event, toString)) {
listener.invoke(event, toString).let { value ->
if (value !== SELECT_MESSAGE_STUB) {
deferred?.complete(value as Boolean)
return@subscribeAlways // accept the first value only
}
}
}
}
defaltListeners.forEach { listener ->
if (filter.invoke(event, toString)) {
listener.invoke(event, toString).let { value ->
if (value !== SELECT_MESSAGE_STUB) {
deferred?.complete(value as Boolean)
......@@ -520,11 +633,20 @@ internal suspend inline fun <reified T : ContactMessage> T.whileSelectMessagesIm
}
}
}
while (deferred?.await() == true) {
deferred = CompletableDeferred()
defaultListeners.forEach { listener ->
listener.invoke(event, toString).let { value ->
if (value !== SELECT_MESSAGE_STUB) {
deferred?.complete(value as Boolean)
return@subscribeAlways // accept the first value only
}
}
}
deferred = null
coroutineContext[Job]!!.cancelChildren()
}
while (deferred?.await() == true) {
deferred = CompletableDeferred()
}
subscribeAlways.complete()
deferred = null
coroutineContext.cancelChildren()
}
\ No newline at end of file
......@@ -331,6 +331,20 @@ open class MessageSubscribersBuilder<M : ContactMessage, out Ret, R : RR, RR>(
return content(filter) { reply(message);this@MessageSubscribersBuilder.stub }
}
@JvmName("reply3")
@Suppress("INAPPLICABLE_JVM_NAME", "INVALID_CHARACTERS", "NAME_CONTAINS_ILLEGAL_CHARS", "FunctionName")
@SinceMirai("0.33.0")
open infix fun ListeningFilter.`->`(toReply: String): Ret {
return this.reply(toReply)
}
@JvmName("reply3")
@Suppress("INAPPLICABLE_JVM_NAME", "INVALID_CHARACTERS", "NAME_CONTAINS_ILLEGAL_CHARS", "FunctionName")
@SinceMirai("0.33.0")
open infix fun ListeningFilter.`->`(message: Message): Ret {
return this.reply(message)
}
@SinceMirai("0.29.0")
open infix fun ListeningFilter.reply(replier: (@MessageDsl suspend M.(String) -> Any?)): Ret {
return content(filter) {
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment