Commit 26b9aa10 authored by Him188's avatar Him188

Jce head fix

parent e41c9231
...@@ -8,12 +8,8 @@ import kotlinx.serialization.modules.EmptyModule ...@@ -8,12 +8,8 @@ import kotlinx.serialization.modules.EmptyModule
import kotlinx.serialization.modules.SerialModule import kotlinx.serialization.modules.SerialModule
import net.mamoe.mirai.qqandroid.io.JceStruct import net.mamoe.mirai.qqandroid.io.JceStruct
import net.mamoe.mirai.qqandroid.io.ProtoBuf import net.mamoe.mirai.qqandroid.io.ProtoBuf
import net.mamoe.mirai.qqandroid.network.protocol.packet.withUse
import net.mamoe.mirai.utils.io.readIoBuffer
import net.mamoe.mirai.utils.io.readString import net.mamoe.mirai.utils.io.readString
import net.mamoe.mirai.utils.io.toIoBuffer import net.mamoe.mirai.utils.io.toReadPacket
import kotlin.contracts.ExperimentalContracts
import kotlin.contracts.contract
@PublishedApi @PublishedApi
internal val CharsetGBK = Charset.forName("GBK") internal val CharsetGBK = Charset.forName("GBK")
...@@ -383,11 +379,10 @@ class Jce private constructor(private val charset: JceCharset, context: SerialMo ...@@ -383,11 +379,10 @@ class Jce private constructor(private val charset: JceCharset, context: SerialMo
} }
} }
if (!input.input.endOfInput) { val tag = currentTagOrNull
val tag = currentTagOrNull val jceHead = input.peakHeadOrNull()
if (tag != null && input.peakHead().tag > tag) { if (tag != null && (jceHead == null || jceHead.tag > tag)) {
return NullReader(this.input) return NullReader(this.input)
}
} }
return super.beginStructure(desc, *typeParams) return super.beginStructure(desc, *typeParams)
...@@ -402,7 +397,8 @@ class Jce private constructor(private val charset: JceCharset, context: SerialMo ...@@ -402,7 +397,8 @@ class Jce private constructor(private val charset: JceCharset, context: SerialMo
} }
fun isTagOptional(tag: Int): Boolean { fun isTagOptional(tag: Int): Boolean {
return input.input.endOfInput || input.peakHead().tag > tag val head = input.peakHeadOrNull()
return input.isEndOfInput || head == null || head.tag > tag
} }
@Suppress("UNCHECKED_CAST") @Suppress("UNCHECKED_CAST")
...@@ -492,7 +488,7 @@ class Jce private constructor(private val charset: JceCharset, context: SerialMo ...@@ -492,7 +488,7 @@ class Jce private constructor(private val charset: JceCharset, context: SerialMo
@Suppress("UNCHECKED_CAST") @Suppress("UNCHECKED_CAST")
override fun <T> decodeSerializableValue(deserializer: DeserializationStrategy<T>): T { override fun <T> decodeSerializableValue(deserializer: DeserializationStrategy<T>): T {
return decodeNullableSerializableValue(deserializer as DeserializationStrategy<Any?>) as? T return decodeNullableSerializableValue(deserializer as DeserializationStrategy<Any?>) as? T
?: error("value with tag $currentTagOrNull(by ${deserializer.getClassName()}) is not optional but cannot find") ?: error("value with tag $currentTagOrNull(by ${deserializer.getClassName()}) is not optional but cannot find. currentJceHead = ${input.currentJceHead}")
} }
} }
...@@ -500,33 +496,48 @@ class Jce private constructor(private val charset: JceCharset, context: SerialMo ...@@ -500,33 +496,48 @@ class Jce private constructor(private val charset: JceCharset, context: SerialMo
@UseExperimental(ExperimentalUnsignedTypes::class) @UseExperimental(ExperimentalUnsignedTypes::class)
internal inner class JceInput( internal inner class JceInput(
@PublishedApi @PublishedApi
internal val input: IoBuffer internal val input: ByteReadPacket,
maxReadSize: Long = input.remaining
) : Closeable { ) : Closeable {
override fun close() = IoBuffer.Pool.recycle(input) internal val leastRemaining = input.remaining - maxReadSize
internal val isEndOfInput: Boolean get() = input.remaining <= leastRemaining
@PublishedApi internal var currentJceHead: JceHead? = input.doReadHead().also { println("first jce head = $it") }
internal fun readHead(): JceHead = input.readHead() ?: error("no enough data to read head")
@PublishedApi override fun close() = input.close()
internal fun readHeadOrNull(): JceHead? = input.readHead()
internal fun peakHeadOrNull(): JceHead? = currentJceHead ?: readHeadOrNull()
internal fun peakHead(): JceHead = peakHeadOrNull() ?: error("no enough data to read head")
@PublishedApi @PublishedApi
internal fun peakHead(): JceHead = input.makeView().readHead() ?: error("no enough data to read head") internal fun readHead(): JceHead = readHeadOrNull() ?: error("no enough data to read head")
@PublishedApi @PublishedApi
internal fun peakHeadOrNull(): JceHead? = input.makeView().readHead() internal fun readHeadOrNull(): JceHead? = input.doReadHead()
@Suppress("NOTHING_TO_INLINE") // 避免 stacktrace 出现两个 readHead /**
private inline fun IoBuffer.readHead(): JceHead? { * 读取下一个 head 存储到 [currentJceHead]
if (endOfInput) return null */
private fun ByteReadPacket.doReadHead(): JceHead? {
if (isEndOfInput) {
currentJceHead = null
println("doReadHead: endOfInput")
return null
}
val var2 = readUByte() val var2 = readUByte()
val type = var2 and 15u val type = var2 and 15u
var tag = var2.toUInt() shr 4 var tag = var2.toUInt() shr 4
if (tag == 15u) { if (tag == 15u) {
if (endOfInput) return null if (isEndOfInput) {
currentJceHead = null
println("doReadHead: endOfInput2")
return null
}
tag = readUByte().toUInt() tag = readUByte().toUInt()
} }
return JceHead(tag = tag.toInt(), type = type.toByte()) currentJceHead = JceHead(tag = tag.toInt(), type = type.toByte())
println("doReadHead: $currentJceHead")
return currentJceHead
} }
fun readBoolean(tag: Int): Boolean = readBooleanOrNull(tag) ?: error("cannot find tag $tag") fun readBoolean(tag: Int): Boolean = readBooleanOrNull(tag) ?: error("cannot find tag $tag")
...@@ -583,8 +594,9 @@ class Jce private constructor(private val charset: JceCharset, context: SerialMo ...@@ -583,8 +594,9 @@ class Jce private constructor(private val charset: JceCharset, context: SerialMo
LIST -> ByteArray(readInt(0)) { readByte(0) } LIST -> ByteArray(readInt(0)) { readByte(0) }
SIMPLE_LIST -> { SIMPLE_LIST -> {
val head = readHead() val head = readHead()
readHead()
check(head.type.toInt() == 0) { "type mismatch" } check(head.type.toInt() == 0) { "type mismatch" }
input.readBytes(readInt(0)) input.readBytes(readInt(0).also { println("list size=$it") })
} }
else -> error("type mismatch") else -> error("type mismatch")
} }
...@@ -610,7 +622,7 @@ class Jce private constructor(private val charset: JceCharset, context: SerialMo ...@@ -610,7 +622,7 @@ class Jce private constructor(private val charset: JceCharset, context: SerialMo
input.readUInt().toInt().also { require(it in 1 until 104857600) { "bad string length: $it" } }, input.readUInt().toInt().also { require(it in 1 until 104857600) { "bad string length: $it" } },
charset = charset.kotlinCharset charset = charset.kotlinCharset
) )
else -> error("type mismatch: ${head.type}") else -> error("type mismatch: ${head.type}, expecting 6 or 7 (for string)")
} }
} }
...@@ -762,40 +774,51 @@ class Jce private constructor(private val charset: JceCharset, context: SerialMo ...@@ -762,40 +774,51 @@ class Jce private constructor(private val charset: JceCharset, context: SerialMo
return dumpAsPacket(serializer, obj).readBytes() return dumpAsPacket(serializer, obj).readBytes()
} }
/**
* 注意 close [packet]!!
*/
fun <T> load(deserializer: DeserializationStrategy<T>, packet: ByteReadPacket, length: Int = packet.remaining.toInt()): T { fun <T> load(deserializer: DeserializationStrategy<T>, packet: ByteReadPacket, length: Int = packet.remaining.toInt()): T {
packet.readIoBuffer(n = length).withUse { return JceDecoder(JceInput(packet, length.toLong())).decode(deserializer)
val decoder = JceDecoder(JceInput(this))
return decoder.decode(deserializer)
}
} }
override fun <T> load(deserializer: DeserializationStrategy<T>, bytes: ByteArray): T { override fun <T> load(deserializer: DeserializationStrategy<T>, bytes: ByteArray): T {
return bytes.toIoBuffer().withUse { return bytes.toReadPacket().use {
val decoder = JceDecoder(JceInput(this)) val decoder = JceDecoder(JceInput(it))
decoder.decode(deserializer) decoder.decode(deserializer)
} }
} }
} }
@UseExperimental(ExperimentalContracts::class)
internal inline fun <R> Jce.JceInput.skipToTagOrNull(tag: Int, block: (JceHead) -> R): R? { internal inline fun <R> Jce.JceInput.skipToTagOrNull(tag: Int, block: (JceHead) -> R): R? {
contract { println("skipping to $tag start")
callsInPlace(block, kotlin.contracts.InvocationKind.UNKNOWN)
}
while (true) { while (true) {
if (this.input.endOfInput) { if (isEndOfInput) { // 读不了了
currentJceHead = null
println("skipping to $tag: endOfInput")
return null return null
} }
val head = peakHead() var head = currentJceHead
if (head == null) { // 没有新的 head 了
head = readHeadOrNull() ?: return null
}
if (head.tag > tag) { if (head.tag > tag) {
println("skipping to $tag: head.tag > tag")
return null return null
} }
readHead() // readHead()
if (head.tag == tag) { if (head.tag == tag) {
// readHeadOrNull()
currentJceHead = null
println("skipping to $tag: run block")
return block(head) return block(head)
} else {
println("skipping to $tag: tag not matching")
} }
println("skipping to $tag: skipField")
this.skipField(head.type) this.skipField(head.type)
currentJceHead = readHeadOrNull()
} }
} }
......
...@@ -20,7 +20,9 @@ class JceDecoderTest { ...@@ -20,7 +20,9 @@ class JceDecoderTest {
@SerialId(3) val int: Int = 123, @SerialId(3) val int: Int = 123,
@SerialId(4) val long: Long = 123, @SerialId(4) val long: Long = 123,
@SerialId(5) val float: Float = 123f, @SerialId(5) val float: Float = 123f,
@SerialId(6) val double: Double = 123.0 @SerialId(6) val double: Double = 123.0,
@SerialId(7) val byteArray: ByteArray = byteArrayOf(1, 2, 3),
@SerialId(8) val byteArray2: ByteArray = byteArrayOf(1, 2, 3)
) : JceStruct { ) : JceStruct {
override fun writeTo(output: JceOutput) = output.run { override fun writeTo(output: JceOutput) = output.run {
writeString(string, 0) writeString(string, 0)
...@@ -30,10 +32,81 @@ class JceDecoderTest { ...@@ -30,10 +32,81 @@ class JceDecoderTest {
writeLong(long, 4) writeLong(long, 4)
writeFloat(float, 5) writeFloat(float, 5)
writeDouble(double, 6) writeDouble(double, 6)
writeFully(byteArray, 7)
writeFully(byteArray2, 8)
}
override fun equals(other: Any?): Boolean {
if (this === other) return true
if (javaClass != other?.javaClass) return false
other as TestSimpleJceStruct
if (string != other.string) return false
if (byte != other.byte) return false
if (short != other.short) return false
if (int != other.int) return false
if (long != other.long) return false
if (float != other.float) return false
if (double != other.double) return false
if (!byteArray.contentEquals(other.byteArray)) return false
if (!byteArray2.contentEquals(other.byteArray2)) return false
return true
}
override fun hashCode(): Int {
var result = string.hashCode()
result = 31 * result + byte
result = 31 * result + short
result = 31 * result + int
result = 31 * result + long.hashCode()
result = 31 * result + float.hashCode()
result = 31 * result + double.hashCode()
result = 31 * result + byteArray.contentHashCode()
result = 31 * result + byteArray2.contentHashCode()
return result
} }
} }
@Test
fun testByteArray() {
@Serializable
data class TestByteArray(
@SerialId(0) val byteArray: ByteArray = byteArrayOf(1, 2, 3)
) : JceStruct {
override fun equals(other: Any?): Boolean {
if (this === other) return true
if (javaClass != other?.javaClass) return false
other as TestByteArray
if (!byteArray.contentEquals(other.byteArray)) return false
return true
}
override fun hashCode(): Int {
return byteArray.contentHashCode()
}
}
assertEquals(
TestByteArray(),
TestByteArray().toByteArray(TestByteArray.serializer()).loadAs(TestByteArray.serializer())
)
}
@Test
fun testSimpleStruct() {
assertEquals(
TestSimpleJceStruct(),
TestSimpleJceStruct().toByteArray(TestSimpleJceStruct.serializer()).loadAs(TestSimpleJceStruct.serializer())
)
}
@Serializable @Serializable
class TestComplexJceStruct( class TestComplexJceStruct(
@SerialId(6) val string: String = "haha", @SerialId(6) val string: String = "haha",
...@@ -77,7 +150,7 @@ class JceDecoderTest { ...@@ -77,7 +150,7 @@ class JceDecoderTest {
@Test @Test
fun testNestedList() { fun testNestedList() {
@Serializable @Serializable
class TestNestedList( data class TestNestedList(
@SerialId(7) val array: List<List<Int>> = listOf(listOf(1, 2, 3), listOf(1, 2, 3), listOf(1, 2, 3)) @SerialId(7) val array: List<List<Int>> = listOf(listOf(1, 2, 3), listOf(1, 2, 3), listOf(1, 2, 3))
) : JceStruct ) : JceStruct
...@@ -133,6 +206,28 @@ class JceDecoderTest { ...@@ -133,6 +206,28 @@ class JceDecoderTest {
}.readBytes().loadAs(TestNestedMap.serializer()).map.entries.first().value.contentToString(), "{01=[0x0002(2)]}") }.readBytes().loadAs(TestNestedMap.serializer()).map.entries.first().value.contentToString(), "{01=[0x0002(2)]}")
} }
@Test
fun testMap3() {
@Serializable
class TestNestedMap(
@SerialId(7) val map: Map<Byte, ShortArray> = mapOf(1.toByte() to shortArrayOf(2))
) : JceStruct
assertEquals("{0x01(1)=[0x0002(2)]}", buildJcePacket {
writeMap(mapOf(1.toByte() to shortArrayOf(2)), 7)
}.readBytes().loadAs(TestNestedMap.serializer()).map.contentToString())
}
@Test
fun testNestedMap2() {
@Serializable
class TestNestedMap(
@SerialId(7) val map: Map<Int, Map<Byte, ShortArray>> = mapOf(1 to mapOf(1.toByte() to shortArrayOf(2)))
) : JceStruct
assertEquals(buildJcePacket {
writeMap(mapOf(1 to mapOf(1.toByte() to shortArrayOf(2))), 7)
}.readBytes().loadAs(TestNestedMap.serializer()).map.entries.first().value.contentToString(), "{0x01(1)=[0x0002(2)]}")
}
@Test @Test
fun testNullableEncode() { fun testNullableEncode() {
...@@ -186,6 +281,9 @@ class JceDecoderTest { ...@@ -186,6 +281,9 @@ class JceDecoderTest {
@SerialId(0) val innerStructList: List<TestSimpleJceStruct> @SerialId(0) val innerStructList: List<TestSimpleJceStruct>
) : JceStruct ) : JceStruct
println(buildJcePacket {
writeCollection(listOf(TestSimpleJceStruct(), TestSimpleJceStruct()), 0)
}.readBytes().loadAs(OuterStruct.serializer()).innerStructList.toString())
assertEquals( assertEquals(
buildJcePacket { buildJcePacket {
writeCollection(listOf(TestSimpleJceStruct(), TestSimpleJceStruct()), 0) writeCollection(listOf(TestSimpleJceStruct(), TestSimpleJceStruct()), 0)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment