Commit e375d17f authored by Him188's avatar Him188

Support maps and nesting

parent 6731525e
...@@ -7,11 +7,12 @@ ...@@ -7,11 +7,12 @@
* https://github.com/mamoe/mirai/blob/master/LICENSE * https://github.com/mamoe/mirai/blob/master/LICENSE
*/ */
@file:Suppress("PrivatePropertyName")
package net.mamoe.mirai.qqandroid.io.serialization.jce package net.mamoe.mirai.qqandroid.io.serialization.jce
import kotlinx.serialization.* import kotlinx.serialization.*
import kotlinx.serialization.builtins.AbstractDecoder import kotlinx.serialization.builtins.AbstractDecoder
import kotlinx.serialization.builtins.ByteArraySerializer
import kotlinx.serialization.internal.TaggedDecoder import kotlinx.serialization.internal.TaggedDecoder
import kotlinx.serialization.modules.SerialModule import kotlinx.serialization.modules.SerialModule
import net.mamoe.mirai.qqandroid.io.serialization.Jce import net.mamoe.mirai.qqandroid.io.serialization.Jce
...@@ -38,16 +39,14 @@ internal class JceDecoder( ...@@ -38,16 +39,14 @@ internal class JceDecoder(
} }
private fun SerialDescriptor.getJceTagId(index: Int): Int { private fun SerialDescriptor.getJceTagId(index: Int): Int {
return getElementAnnotations(index).filterIsInstance<JceId>().single().id println("getTag: ${getElementName(index)}")
return getElementAnnotations(index).filterIsInstance<JceId>().singleOrNull()?.id
?: error("missing @JceId for ${getElementName(index)} in ${this.serialName}")
} }
private val SimpleByteArrayReader: SimpleByteArrayReaderImpl = SimpleByteArrayReaderImpl()
companion object { private inner class SimpleByteArrayReaderImpl : AbstractDecoder() {
private val ByteArraySerializer: KSerializer<ByteArray> = ByteArraySerializer()
}
// TODO: 2020/3/6 can be object
private inner class SimpleByteArrayReader : AbstractDecoder() {
override fun decodeSequentially(): Boolean = true override fun decodeSequentially(): Boolean = true
override fun endStructure(descriptor: SerialDescriptor) { override fun endStructure(descriptor: SerialDescriptor) {
...@@ -80,8 +79,9 @@ internal class JceDecoder( ...@@ -80,8 +79,9 @@ internal class JceDecoder(
} }
} }
// TODO: 2020/3/6 can be object private val ListReader: ListReaderImpl = ListReaderImpl()
private inner class ListReader : AbstractDecoder() {
private inner class ListReaderImpl : AbstractDecoder() {
override fun decodeSequentially(): Boolean = true override fun decodeSequentially(): Boolean = true
override fun decodeElementIndex(descriptor: SerialDescriptor): Int = error("should not be reached") override fun decodeElementIndex(descriptor: SerialDescriptor): Int = error("should not be reached")
override fun endStructure(descriptor: SerialDescriptor) { override fun endStructure(descriptor: SerialDescriptor) {
...@@ -113,33 +113,73 @@ internal class JceDecoder( ...@@ -113,33 +113,73 @@ internal class JceDecoder(
override fun endStructure(descriptor: SerialDescriptor) { override fun endStructure(descriptor: SerialDescriptor) {
println("endStructure: $descriptor") println("endStructure: $descriptor")
if (descriptor == ByteArraySerializer.descriptor) { if (currentTagOrNull?.isSimpleByteArray == true) {
jce.prepareNextHead() // list 里面没读 head jce.prepareNextHead() // read to next head
} else jce.prepareNextHead() // TODO ?? 测试这里 }
super.endStructure(descriptor) if (descriptor.kind == StructureKind.CLASS) {
if (currentTagOrNull == null) {
return
}
while (true) {
val currentHead = jce.currentHeadOrNull ?: return
if (currentHead.type == Jce.STRUCT_END) {
break
}
println("skipping")
jce.skipField(currentHead.type)
jce.prepareNextHead()
}
// pushTag(JceTag(0, true))
// skip STRUCT_END
// popTag()
}
} }
override fun beginStructure(descriptor: SerialDescriptor, vararg typeParams: KSerializer<*>): CompositeDecoder { override fun beginStructure(descriptor: SerialDescriptor, vararg typeParams: KSerializer<*>): CompositeDecoder {
println()
println("beginStructure: ${descriptor.serialName}") println("beginStructure: ${descriptor.serialName}")
return when (descriptor.kind) { return when (descriptor.kind) {
is PrimitiveKind -> this@JceDecoder
StructureKind.MAP -> { StructureKind.MAP -> {
error("map") println("!! MAP")
return jce.skipToHeadAndUseIfPossibleOrFail(popTag().id) {
it.checkType(Jce.MAP)
ListReader
}
} }
StructureKind.LIST -> { StructureKind.LIST -> {
println("!! ByteArray") println("!! ByteArray")
println("decoderTag: $currentTagOrNull") println("decoderTag: $currentTagOrNull")
println("jceHead: " + jce.currentHeadOrNull) println("jceHead: " + jce.currentHeadOrNull)
return jce.skipToHeadAndUseIfPossibleOrFail(popTag().id) { return jce.skipToHeadAndUseIfPossibleOrFail(currentTag.id) {
println("listHead: $it") println("listHead: $it")
when (it.type) { when (it.type) {
Jce.SIMPLE_LIST -> SimpleByteArrayReader().also { jce.prepareNextHead() } // 无用的元素类型 Jce.SIMPLE_LIST -> {
Jce.LIST -> ListReader() currentTag.isSimpleByteArray = true
jce.prepareNextHead() // 无用的元素类型
SimpleByteArrayReader
}
Jce.LIST -> ListReader
else -> error("type mismatch. Expected SIMPLE_LIST or LIST, got ${it.type} instead") else -> error("type mismatch. Expected SIMPLE_LIST or LIST, got ${it.type} instead")
} }
} }
} }
StructureKind.CLASS -> {
val currentTag = currentTagOrNull ?: return this@JceDecoder
println("!! CLASS")
println("decoderTag: $currentTag")
println("jceHead: " + jce.currentHeadOrNull)
return jce.skipToHeadAndUseIfPossibleOrFail(popTag().id) {
it.checkType(Jce.STRUCT_BEGIN)
this@JceDecoder
}
}
else -> this@JceDecoder StructureKind.OBJECT -> error("unsupported StructureKind.OBJECT: ${descriptor.serialName}")
is UnionKind -> error("unsupported UnionKind: ${descriptor.serialName}")
is PolymorphicKind -> error("unsupported PolymorphicKind: ${descriptor.serialName}")
} }
} }
...@@ -154,6 +194,10 @@ internal class JceDecoder( ...@@ -154,6 +194,10 @@ internal class JceDecoder(
override fun decodeSequentially(): Boolean = false override fun decodeSequentially(): Boolean = false
override fun decodeElementIndex(descriptor: SerialDescriptor): Int { override fun decodeElementIndex(descriptor: SerialDescriptor): Int {
val jceHead = jce.currentHeadOrNull ?: return CompositeDecoder.READ_DONE val jceHead = jce.currentHeadOrNull ?: return CompositeDecoder.READ_DONE
if (jceHead.type == Jce.STRUCT_END) {
return CompositeDecoder.READ_DONE
}
repeat(descriptor.elementsCount) { repeat(descriptor.elementsCount) {
val tag = descriptor.getJceTagId(it) val tag = descriptor.getJceTagId(it)
if (tag == jceHead.tag) { if (tag == jceHead.tag) {
......
...@@ -30,7 +30,9 @@ annotation class JceId(val id: Int) ...@@ -30,7 +30,9 @@ annotation class JceId(val id: Int)
internal data class JceTag( internal data class JceTag(
val id: Int, val id: Int,
val isNullable: Boolean val isNullable: Boolean
) ){
internal var isSimpleByteArray: Boolean = false
}
fun JceHead.checkType(type: Byte) { fun JceHead.checkType(type: Byte) {
check(this.type == type) {"type mismatch. Expected $type, actual ${this.type}"} check(this.type == type) {"type mismatch. Expected $type, actual ${this.type}"}
......
...@@ -41,6 +41,57 @@ internal const val ZERO_TYPE: Byte = 12 ...@@ -41,6 +41,57 @@ internal const val ZERO_TYPE: Byte = 12
@Suppress("INVISIBLE_MEMBER") // bug @Suppress("INVISIBLE_MEMBER") // bug
internal class JceInputTest { internal class JceInputTest {
@Test
fun testNestedJceStruct() {
@Serializable
data class TestSerializableClassC(
@JceId(5) val value3: Int = 123123
)
@Serializable
data class TestSerializableClassB(
@JceId(0) val value: Int,
@JceId(123) val nested2: TestSerializableClassC
)
@Serializable
data class TestSerializableClassA(
@JceId(0) val value1: Int,
@JceId(1) val nestedStruct: TestSerializableClassB,
@JceId(2) val optional: Int = 3,
@JceId(4) val notOptional: Int
)
val input = buildPacket {
writeJceHead(INT, 0)
writeInt(444)
writeJceHead(STRUCT_BEGIN, 1); // TestSerializableClassB
{
writeJceHead(INT, 0)
writeInt(123)
writeJceHead(STRUCT_BEGIN, 123); // TestSerializableClassC
{
writeJceHead(INT, 5)
writeInt(123123)
}()
writeJceHead(STRUCT_END, 0)
writeJceHead(INT, 2) // 多余
writeInt(123)
}()
writeJceHead(STRUCT_END, 0)
writeJceHead(INT, 4)
writeInt(5)
}
assertEquals(
TestSerializableClassA(444, TestSerializableClassB(123, TestSerializableClassC(123123)), notOptional = 5),
JceNew.UTF_8.load(TestSerializableClassA.serializer(), input)
)
}
@Test @Test
fun testNestedList() { fun testNestedList() {
...@@ -80,6 +131,44 @@ internal class JceInputTest { ...@@ -80,6 +131,44 @@ internal class JceInputTest {
assertEquals(TestSerializableClassA(), JceNew.UTF_8.load(TestSerializableClassA.serializer(), input)) assertEquals(TestSerializableClassA(), JceNew.UTF_8.load(TestSerializableClassA.serializer(), input))
} }
@Test
fun testMap() {
@Serializable
data class TestSerializableClassA(
@JceId(0) val byteArray: Map<Int, Int>
)
val input = buildPacket {
writeJceHead(MAP, 0)
mapOf(1 to 2, 33 to 44).let {
writeJceHead(BYTE, 0)
writeByte(it.size.toByte())
it.forEach { (key, value) ->
writeJceHead(INT, 0)
writeInt(key)
writeJceHead(INT, 1)
writeInt(value)
}
}
writeJceHead(SIMPLE_LIST, 3)
writeJceHead(BYTE, 0)
byteArrayOf(1, 2, 3, 4).let {
writeJceHead(BYTE, 0)
writeByte(it.size.toByte())
writeFully(it)
}
}
assertEquals(
TestSerializableClassA(mapOf(1 to 2, 33 to 44)),
JceNew.UTF_8.load(TestSerializableClassA.serializer(), input)
)
}
@Test @Test
fun testSimpleByteArray() { fun testSimpleByteArray() {
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment