自定义Dubbo反序列化

自定义Dubbo反序列化,第1张

自定义Dubbo反序列化

最近在项目中遇到一个问题:使用 Dubbo 调用服务时,如果实体类中用 Jackson 的 JsonNode 字段来传输数据,会导致序列化失败。这里记录一下如何在 Dubbo 中自定义反序列化。

1. 自定义反序列化器

参考 Dubbo 自带的 Hessian 反序列化器 JavaDeserializer,在其源码基础上修改而来:

/**
 * Hessian deserializer for ordinary Java classes, adapted from the Hessian {@code JavaDeserializer}
 * bundled with Dubbo.
 *
 * <p>An instance is created through the cheapest declared constructor (see
 * {@link #constructorCost(Class[])}) using placeholder arguments. Every non-static, non-transient
 * field (including inherited ones) is then filled by a per-field {@link FieldDeserializer}.
 * Fields whose declared type is exactly {@link JsonNode} use {@link JsonNodeDeserializer}, which
 * reads the value with {@code in.readObject()} and casts it. If the class or a superclass declares
 * a no-arg {@code readResolve()}, its result is returned instead and replaces the Hessian
 * back-reference.
 */
public class CustomizeDataDeserializer extends AbstractMapDeserializer {

    private static final Logger LOG = Logger.getLogger(CustomizeDataDeserializer.class.getName());

    /**
     * Cap for the parameter-type part of a constructor cost. It must be a long shift:
     * {@code 1 << 48} on an int masks the shift distance to 5 bits and yields 65536.
     */
    private static final long MAX_PARAM_TYPE_COST = 1L << 48;

    // The class being deserialized.
    private final Class<?> type;
    // Field name -> deserializer. When names clash, the most-derived class's field wins.
    private final Map<String, FieldDeserializer> fieldMap;
    // No-arg readResolve() found on the class or a superclass, or null.
    private final Method readResolve;
    // Cheapest declared constructor, or null when the class declares none.
    private Constructor<?> constructor;
    // Placeholder constructor arguments: null for references, zero/false for primitives.
    private Object[] constructorArgs;

    /**
     * Names of the boxed primitive classes (plus {@code Void}). {@link #isPrimitive(Type)} uses it
     * to decide whether a generic type argument can be handed to Hessian as an expected type.
     */
    static final Map<String, Boolean> PRIMITIVE_TYPE = createPrimitiveTypeMap();

    private static Map<String, Boolean> createPrimitiveTypeMap() {
        Map<String, Boolean> names = new HashMap<>(16);
        Class<?>[] boxed = {
            Boolean.class, Character.class, Byte.class, Short.class, Integer.class,
            Long.class, Float.class, Double.class, Void.class
        };
        for (Class<?> c : boxed) {
            names.put(c.getName(), Boolean.TRUE);
        }
        return names;
    }

    /**
     * Prepares field deserializers, the optional {@code readResolve()} hook and the constructor
     * used to instantiate {@code cl}.
     *
     * @param cl the class to deserialize
     */
    public CustomizeDataDeserializer(Class<?> cl) {
        type = cl;
        // Build one deserializer per persistent field.
        this.fieldMap = new HashMap<>();
        getFieldMap(cl, this.fieldMap);

        readResolve = getReadResolve(cl);
        if (readResolve != null) {
            readResolve.setAccessible(true);
        }

        // Pick the declared constructor with the lowest cost; the first one wins ties.
        long bestCost = Long.MAX_VALUE;
        for (Constructor<?> candidate : cl.getDeclaredConstructors()) {
            long cost = constructorCost(candidate.getParameterTypes());
            if (cost < bestCost) {
                constructor = candidate;
                bestCost = cost;
            }
        }

        // Make the chosen constructor callable and precompute its placeholder arguments.
        if (constructor != null) {
            constructor.setAccessible(true);
            Class<?>[] params = constructor.getParameterTypes();
            constructorArgs = new Object[params.length];
            for (int i = 0; i < params.length; i++) {
                constructorArgs[i] = getParamArg(params[i]);
            }
        }
    }

    /**
     * Ranks a constructor by its parameter list. A constructor with fewer parameters always ranks
     * lower. Among equal counts, cheaper parameter types rank lower (see
     * {@link #paramTypeCost(Class)}).
     */
    private static long constructorCost(Class<?>[] params) {
        long cost = 0;
        for (Class<?> param : params) {
            // Saturate instead of overflowing: 4 * MAX_PARAM_TYPE_COST still fits in a long.
            cost = Math.min(4 * cost + paramTypeCost(param), MAX_PARAM_TYPE_COST);
        }
        return cost + ((long) params.length << 48);
    }

    /** Cost of one constructor parameter type: Object < String < int < long < primitive < other. */
    private static int paramTypeCost(Class<?> param) {
        if (Object.class.equals(param)) {
            return 1;
        } else if (String.class.equals(param)) {
            return 2;
        } else if (int.class.equals(param)) {
            return 3;
        } else if (long.class.equals(param)) {
            return 4;
        } else if (param.isPrimitive()) {
            return 5;
        } else {
            return 6;
        }
    }

    /**
     * Returns a placeholder constructor argument for a parameter of type {@code cl}.
     *
     * @return {@code null} for reference types, the zero/false value for primitives
     * @throws UnsupportedOperationException for {@code void.class}
     */
    protected static Object getParamArg(Class<?> cl) {
        if (!cl.isPrimitive()) {
            return null;
        } else if (boolean.class.equals(cl)) {
            return Boolean.FALSE;
        } else if (byte.class.equals(cl)) {
            return Byte.valueOf((byte) 0);
        } else if (short.class.equals(cl)) {
            return Short.valueOf((short) 0);
        } else if (char.class.equals(cl)) {
            return Character.valueOf((char) 0);
        } else if (int.class.equals(cl)) {
            return Integer.valueOf(0);
        } else if (long.class.equals(cl)) {
            return Long.valueOf(0);
        } else if (float.class.equals(cl)) {
            return Float.valueOf(0);
        } else if (double.class.equals(cl)) {
            return Double.valueOf(0);
        } else {
            // Only void.class is primitive and unmatched above.
            throw new UnsupportedOperationException();
        }
    }

    /**
     * Converts a failure while setting {@code field} into a {@link HessianFieldException}.
     * This method always throws.
     */
    static void logDeserializeError(Field field, Object obj, Object value,
                                    Throwable e)
        throws IOException {
        String fieldName = (field.getDeclaringClass().getName()
                            + "." + field.getName());

        if (e instanceof HessianFieldException) {
            throw (HessianFieldException) e;
        } else if (e instanceof IOException) {
            throw new HessianFieldException(fieldName + ": " + e.getMessage(), e);
        }
        if (value != null) {
            throw new HessianFieldException(fieldName + ": " + value.getClass().getName() + " (" + value + ")"
                                            + " cannot be assigned to '" + field.getType().getName() + "'", e);
        } else {
            throw new HessianFieldException(fieldName + ": " + field.getType().getName() + " cannot be assigned from null", e);
        }
    }

    @Override
    public Class<?> getType() {
        return type;
    }

    /** Instantiates the target class and fills it from a Hessian map encoding. */
    @Override
    public Object readMap(AbstractHessianInput in)
        throws IOException {
        try {
            Object obj = instantiate();
            return readMap(in, obj);
        } catch (IOException | RuntimeException e) {
            throw e;
        } catch (Exception e) {
            throw new IOExceptionWrapper(type.getName() + ":" + e.getMessage(), e);
        }
    }

    /** Instantiates the target class and fills it from a Hessian object encoding. */
    @Override
    public Object readObject(AbstractHessianInput in, String[] fieldNames)
        throws IOException {
        try {
            Object obj = instantiate();
            return readObject(in, obj, fieldNames);
        } catch (IOException | RuntimeException e) {
            throw e;
        } catch (Exception e) {
            throw new IOExceptionWrapper(type.getName() + ":" + e.getMessage(), e);
        }
    }

    /** Finds the first no-arg method named {@code readResolve} in {@code cl} or its superclasses. */
    protected Method getReadResolve(Class<?> cl) {
        for (; cl != null; cl = cl.getSuperclass()) {
            for (Method method : cl.getDeclaredMethods()) {
                if (method.getName().equals("readResolve") && method.getParameterTypes().length == 0) {
                    return method;
                }
            }
        }
        return null;
    }

    /**
     * Reads key/value pairs until the map end marker and assigns known fields of {@code obj}.
     * Values of unknown keys are read and discarded so the stream stays aligned.
     */
    public Object readMap(AbstractHessianInput in, Object obj)
        throws IOException {
        try {
            int ref = in.addRef(obj);

            while (!in.isEnd()) {
                Object key = in.readObject();
                FieldDeserializer deser = fieldMap.get(key);
                if (deser != null) {
                    deser.deserialize(in, obj);
                } else {
                    in.readObject();
                }
            }
            in.readMapEnd();
            Object resolve = resolve(obj);
            if (obj != resolve) {
                in.setRef(ref, resolve);
            }
            return resolve;
        } catch (IOException e) {
            throw e;
        } catch (Exception e) {
            throw new IOExceptionWrapper(e);
        }
    }

    /**
     * Reads one value per entry of {@code fieldNames}, in order, and assigns known fields of
     * {@code obj}.
     */
    public Object readObject(AbstractHessianInput in,
                             Object obj,
                             String[] fieldNames)
        throws IOException {
        try {
            // Register the instance first so back-references inside its fields resolve to it.
            int ref = in.addRef(obj);
            for (String name : fieldNames) {
                FieldDeserializer deser = fieldMap.get(name);
                if (deser != null) {
                    deser.deserialize(in, obj);
                } else {
                    // Field not present locally: consume and drop its value to stay in sync.
                    in.readObject();
                }
            }
            Object resolve = resolve(obj);
            if (obj != resolve) {
                in.setRef(ref, resolve);
            }
            return resolve;
        } catch (IOException e) {
            throw e;
        } catch (Exception e) {
            throw new IOExceptionWrapper(obj.getClass().getName() + ":" + e, e);
        }
    }

    /** Applies {@code readResolve()} when present; otherwise returns {@code obj} unchanged. */
    private Object resolve(Object obj)
        throws Exception {
        try {
            if (readResolve != null) {
                return readResolve.invoke(obj, new Object[0]);
            }
        } catch (InvocationTargetException e) {
            if (e.getTargetException() != null) {
                throw e;
            }
        }

        return obj;
    }

    /**
     * Creates a new, empty instance of the target class.
     *
     * @throws HessianProtocolException wrapping any instantiation failure
     */
    protected Object instantiate()
        throws Exception {
        try {
            if (constructor != null) {
                return constructor.newInstance(constructorArgs);
            } else {
                // Class.newInstance() is deprecated; use the no-arg constructor explicitly.
                return type.getDeclaredConstructor().newInstance();
            }
        } catch (Exception e) {
            throw new HessianProtocolException("'" + type.getName() + "' could not be instantiated", e);
        }
    }

    /**
     * Registers a deserializer for every non-static, non-transient field of {@code cl} and its
     * superclasses. Subclass fields shadow same-named superclass fields.
     */
    protected void getFieldMap(Class<?> cl, Map<String, FieldDeserializer> fieldMap) {
        for (; cl != null; cl = cl.getSuperclass()) {
            for (Field field : cl.getDeclaredFields()) {
                if (Modifier.isTransient(field.getModifiers()) || Modifier.isStatic(field.getModifiers())) {
                    continue;
                } else if (fieldMap.get(field.getName()) != null) {
                    continue;
                }
                try {
                    field.setAccessible(true);
                } catch (Throwable e) {
                    LOG.warning("字段权限设置失败");
                }
                fieldMap.putIfAbsent(field.getName(), createFieldDeserializer(field));
            }
        }
    }

    /** Chooses the deserializer matching the declared type of {@code field}. */
    private static FieldDeserializer createFieldDeserializer(Field field) {
        Class<?> fieldType = field.getType();
        // Only parameterized Map/List/Set declarations carry element types to pass to Hessian.
        boolean parameterized = field.getGenericType() != fieldType;
        if (String.class.equals(fieldType)) {
            return new StringFieldDeserializer(field);
        } else if (byte.class.equals(fieldType)) {
            return new ByteFieldDeserializer(field);
        } else if (short.class.equals(fieldType)) {
            return new ShortFieldDeserializer(field);
        } else if (int.class.equals(fieldType)) {
            return new IntFieldDeserializer(field);
        } else if (long.class.equals(fieldType)) {
            return new LongFieldDeserializer(field);
        } else if (float.class.equals(fieldType)) {
            return new FloatFieldDeserializer(field);
        } else if (double.class.equals(fieldType)) {
            return new DoubleFieldDeserializer(field);
        } else if (boolean.class.equals(fieldType)) {
            return new BooleanFieldDeserializer(field);
        } else if (java.sql.Date.class.equals(fieldType)) {
            return new SqlDateFieldDeserializer(field);
        } else if (java.sql.Timestamp.class.equals(fieldType)) {
            return new SqlTimestampFieldDeserializer(field);
        } else if (java.sql.Time.class.equals(fieldType)) {
            return new SqlTimeFieldDeserializer(field);
        } else if (Map.class.equals(fieldType) && parameterized) {
            return new ObjectMapFieldDeserializer(field);
        } else if (List.class.equals(fieldType) && parameterized) {
            return new ObjectListFieldDeserializer(field);
        } else if (Set.class.equals(fieldType) && parameterized) {
            return new ObjectSetFieldDeserializer(field);
        } else if (JsonNode.class.equals(fieldType)) {
            // JsonNode needs its own path; see JsonNodeDeserializer.
            return new JsonNodeDeserializer(field);
        } else {
            return new ObjectFieldDeserializer(field);
        }
    }

    /** Reads one field's value from the Hessian stream and stores it into the target object. */
    abstract static class FieldDeserializer {
        abstract void deserialize(AbstractHessianInput in, Object obj)
            throws IOException;
    }

    /** Fallback: reads a value using the field's declared class as the expected type. */
    static class ObjectFieldDeserializer extends FieldDeserializer {
        private final Field field;

        ObjectFieldDeserializer(Field field) {
            this.field = field;
        }

        @Override
        void deserialize(AbstractHessianInput in, Object obj)
            throws IOException {
            Object value = null;

            try {
                value = in.readObject(field.getType());
                field.set(obj, value);
            } catch (Exception e) {
                logDeserializeError(field, obj, value, e);
            }
        }
    }

    static class BooleanFieldDeserializer extends FieldDeserializer {
        private final Field field;

        BooleanFieldDeserializer(Field field) {
            this.field = field;
        }

        @Override
        void deserialize(AbstractHessianInput in, Object obj)
            throws IOException {
            boolean value = false;

            try {
                value = in.readBoolean();

                field.setBoolean(obj, value);
            } catch (Exception e) {
                logDeserializeError(field, obj, value, e);
            }
        }
    }

    /** Reads a Hessian int and narrows it to byte. */
    static class ByteFieldDeserializer extends FieldDeserializer {
        private final Field field;

        ByteFieldDeserializer(Field field) {
            this.field = field;
        }

        @Override
        void deserialize(AbstractHessianInput in, Object obj)
            throws IOException {
            int value = 0;

            try {
                value = in.readInt();

                field.setByte(obj, (byte) value);
            } catch (Exception e) {
                logDeserializeError(field, obj, value, e);
            }
        }
    }

    /** Reads a Hessian int and narrows it to short. */
    static class ShortFieldDeserializer extends FieldDeserializer {
        private final Field field;

        ShortFieldDeserializer(Field field) {
            this.field = field;
        }

        @Override
        void deserialize(AbstractHessianInput in, Object obj)
            throws IOException {
            int value = 0;

            try {
                value = in.readInt();

                field.setShort(obj, (short) value);
            } catch (Exception e) {
                logDeserializeError(field, obj, value, e);
            }
        }
    }

    /** Parameterized Map field: passes key/value classes to Hessian when they are (boxed) primitives. */
    static class ObjectMapFieldDeserializer extends FieldDeserializer {
        private final Field field;

        ObjectMapFieldDeserializer(Field field) {
            this.field = field;
        }

        @Override
        void deserialize(AbstractHessianInput in, Object obj)
            throws IOException {
            Object value = null;

            try {
                Type[] types = ((ParameterizedType) field.getGenericType()).getActualTypeArguments();
                value = in.readObject(field.getType(),
                                      isPrimitive(types[0]) ? (Class) types[0] : null,
                                      isPrimitive(types[1]) ? (Class) types[1] : null
                                     );

                field.set(obj, value);
            } catch (Exception e) {
                logDeserializeError(field, obj, value, e);
            }
        }
    }

    /** Parameterized List field: passes the element class to Hessian when it is a (boxed) primitive. */
    static class ObjectListFieldDeserializer extends FieldDeserializer {
        private final Field field;

        ObjectListFieldDeserializer(Field field) {
            this.field = field;
        }

        @Override
        void deserialize(AbstractHessianInput in, Object obj)
            throws IOException {
            Object value = null;

            try {
                Type[] types = ((ParameterizedType) field.getGenericType()).getActualTypeArguments();
                value = in.readObject(field.getType(),
                                      isPrimitive(types[0]) ? (Class) types[0] : null
                                     );

                field.set(obj, value);
            } catch (Exception e) {
                logDeserializeError(field, obj, value, e);
            }
        }
    }

    /** Parameterized Set field: passes the element class to Hessian when it is a (boxed) primitive. */
    static class ObjectSetFieldDeserializer extends FieldDeserializer {
        private final Field field;

        ObjectSetFieldDeserializer(Field field) {
            this.field = field;
        }

        @Override
        void deserialize(AbstractHessianInput in, Object obj)
            throws IOException {
            Object value = null;

            try {
                Type[] types = ((ParameterizedType) field.getGenericType()).getActualTypeArguments();
                value = in.readObject(field.getType(),
                                      isPrimitive(types[0]) ? (Class) types[0] : null
                                     );

                field.set(obj, value);
            } catch (Exception e) {
                logDeserializeError(field, obj, value, e);
            }
        }
    }

    static class IntFieldDeserializer extends FieldDeserializer {
        private final Field field;

        IntFieldDeserializer(Field field) {
            this.field = field;
        }

        @Override
        void deserialize(AbstractHessianInput in, Object obj)
            throws IOException {
            int value = 0;

            try {
                value = in.readInt();

                field.setInt(obj, value);
            } catch (Exception e) {
                logDeserializeError(field, obj, value, e);
            }
        }
    }

    static class LongFieldDeserializer extends FieldDeserializer {
        private final Field field;

        LongFieldDeserializer(Field field) {
            this.field = field;
        }

        @Override
        void deserialize(AbstractHessianInput in, Object obj)
            throws IOException {
            long value = 0;

            try {
                value = in.readLong();

                field.setLong(obj, value);
            } catch (Exception e) {
                logDeserializeError(field, obj, value, e);
            }
        }
    }

    /** Reads a Hessian double and narrows it to float. */
    static class FloatFieldDeserializer extends FieldDeserializer {
        private final Field field;

        FloatFieldDeserializer(Field field) {
            this.field = field;
        }

        @Override
        void deserialize(AbstractHessianInput in, Object obj)
            throws IOException {
            double value = 0;

            try {
                value = in.readDouble();

                field.setFloat(obj, (float) value);
            } catch (Exception e) {
                logDeserializeError(field, obj, value, e);
            }
        }
    }

    static class DoubleFieldDeserializer extends FieldDeserializer {
        private final Field field;

        DoubleFieldDeserializer(Field field) {
            this.field = field;
        }

        @Override
        void deserialize(AbstractHessianInput in, Object obj)
            throws IOException {
            double value = 0;

            try {
                value = in.readDouble();

                field.setDouble(obj, value);
            } catch (Exception e) {
                logDeserializeError(field, obj, value, e);
            }
        }
    }

    static class StringFieldDeserializer extends FieldDeserializer {
        private final Field field;

        StringFieldDeserializer(Field field) {
            this.field = field;
        }

        @Override
        void deserialize(AbstractHessianInput in, Object obj)
            throws IOException {
            String value = null;

            try {
                value = in.readString();

                field.set(obj, value);
            } catch (Exception e) {
                logDeserializeError(field, obj, value, e);
            }
        }
    }

    /** Reads a java.util.Date and converts it to java.sql.Date (null stays null). */
    static class SqlDateFieldDeserializer extends FieldDeserializer {
        private final Field field;

        SqlDateFieldDeserializer(Field field) {
            this.field = field;
        }

        @Override
        void deserialize(AbstractHessianInput in, Object obj)
            throws IOException {
            java.sql.Date value = null;

            try {
                java.util.Date date = (java.util.Date) in.readObject();
                if (date != null) {
                    value = new java.sql.Date(date.getTime());
                }

                field.set(obj, value);
            } catch (Exception e) {
                logDeserializeError(field, obj, value, e);
            }
        }
    }

    /** Reads a java.util.Date and converts it to java.sql.Timestamp (null stays null). */
    static class SqlTimestampFieldDeserializer extends FieldDeserializer {
        private final Field field;

        SqlTimestampFieldDeserializer(Field field) {
            this.field = field;
        }

        @Override
        void deserialize(AbstractHessianInput in, Object obj)
            throws IOException {
            java.sql.Timestamp value = null;

            try {
                java.util.Date date = (java.util.Date) in.readObject();
                if (date != null) {
                    value = new java.sql.Timestamp(date.getTime());
                }

                field.set(obj, value);
            } catch (Exception e) {
                logDeserializeError(field, obj, value, e);
            }
        }
    }

    /** Reads a java.util.Date and converts it to java.sql.Time (null stays null). */
    static class SqlTimeFieldDeserializer extends FieldDeserializer {
        private final Field field;

        SqlTimeFieldDeserializer(Field field) {
            this.field = field;
        }

        @Override
        void deserialize(AbstractHessianInput in, Object obj)
            throws IOException {
            java.sql.Time value = null;

            try {
                java.util.Date date = (java.util.Date) in.readObject();
                if (date != null) {
                    value = new java.sql.Time(date.getTime());
                }

                field.set(obj, value);
            } catch (Exception e) {
                logDeserializeError(field, obj, value, e);
            }
        }
    }

    /**
     * JsonNode field: reads the value without an expected type and casts it to JsonNode.
     * This relies on the concrete node class having been written into the stream by the sender
     * (assumption from the original author's note; verify for your Hessian version).
     */
    static class JsonNodeDeserializer extends FieldDeserializer {

        private final Field field;

        JsonNodeDeserializer(Field field) {
            this.field = field;
        }

        @Override
        void deserialize(AbstractHessianInput in, Object obj) throws IOException {
            JsonNode value = null;
            try {
                value = (JsonNode) in.readObject();
                field.set(obj, value);
            } catch (Exception e) {
                logDeserializeError(field, obj, value, e);
            }
        }
    }

    /** True when {@code type} is a Class that is primitive or a boxed primitive / Void. */
    private static boolean isPrimitive(Type type) {
        try {
            if (type instanceof Class) {
                Class<?> clazz = (Class<?>) type;
                return clazz.isPrimitive() || PRIMITIVE_TYPE.containsKey(clazz.getName());
            }
        } catch (Exception e) {
            // ignore exception
        }
        return false;
    }

}

2. 创建序列化工厂
/**
 * Hessian {@link SerializerFactory} whose fallback deserializer is
 * {@link CustomizeDataDeserializer}.
 */
public class DubboSerializerFactory extends SerializerFactory {

    /** Shared factory instance. */
    public static final SerializerFactory SERIALIZER_FACTORY = new DubboSerializerFactory();

    /** Creates a factory with the default {@link SerializerFactory} settings. */
    public DubboSerializerFactory() {
        super();
    }

    /**
     * Returns a {@link CustomizeDataDeserializer} for {@code cl} whenever Hessian asks for the
     * default deserializer.
     */
    @Override
    public Deserializer getDefaultDeserializer(Class cl) {
        final Deserializer fallback = new CustomizeDataDeserializer(cl);
        return fallback;
    }
}

3. 创建序列化对象输入流
/**
 * Dubbo {@link ObjectInput} that reads from a {@link Hessian2Input} configured with
 * {@link DubboSerializerFactory}. Classes without a more specific Hessian deserializer are
 * therefore read by {@link CustomizeDataDeserializer}.
 */
public class DubboSeiralizerObjectInput implements ObjectInput {

    // Underlying Hessian2 reader; every read method delegates to it.
    private final Hessian2Input h2i;

    /**
     * @param is the stream to read Hessian2-encoded data from
     */
    public DubboSeiralizerObjectInput(InputStream is) {
        h2i = new Hessian2Input(is);
        // Install the custom factory so default deserialization goes through it.
        h2i.setSerializerFactory(DubboSerializerFactory.SERIALIZER_FACTORY);
    }

    @Override
    public boolean readBool() throws IOException {
        return this.h2i.readBoolean();
    }

    /** Reads a Hessian int and narrows it to byte. */
    @Override
    public byte readByte() throws IOException {
        return (byte) h2i.readInt();
    }

    /** Reads a Hessian int and narrows it to short. */
    @Override
    public short readShort() throws IOException {
        return (short) h2i.readInt();
    }

    @Override
    public int readInt() throws IOException {
        return h2i.readInt();
    }

    @Override
    public long readLong() throws IOException {
        return h2i.readLong();
    }

    /** Reads a Hessian double and narrows it to float. */
    @Override
    public float readFloat() throws IOException {
        return (float) h2i.readDouble();
    }

    @Override
    public double readDouble() throws IOException {
        return h2i.readDouble();
    }

    @Override
    public byte[] readBytes() throws IOException {
        return h2i.readBytes();
    }

    @Override
    @SuppressWarnings({"checkstyle:LowerCamelCaseVariableNamingRule",
                       "PMD.LowerCamelCaseVariableNamingRule"})
    public String readUTF() throws IOException {
        return h2i.readString();
    }

    @Override
    public Object readObject() throws IOException {
        return h2i.readObject();
    }

    /**
     * Reads an object using {@code cls} as the expected type.
     * The method-level {@code <T>} declaration is required; without it {@code T} does not resolve.
     */
    @Override
    @SuppressWarnings("unchecked")
    public <T> T readObject(Class<T> cls) throws IOException {
        return (T) h2i.readObject(cls);
    }

    /** Same as {@link #readObject(Class)}; the generic {@code type} hint is ignored. */
    @Override
    public <T> T readObject(Class<T> cls, Type type) throws IOException {
        return readObject(cls);
    }
}

4. 自定义协议序列化器

这里只自定义了反序列化方式;序列化仍直接使用 Dubbo 自带的 Hessian2ObjectOutput。

/**
 * Dubbo {@link Serialization} extension registered under the name used in the SPI file.
 * Output is written with the stock {@code Hessian2ObjectOutput}; input is read with
 * {@link DubboSeiralizerObjectInput}.
 */
public class CustomizeDataSerialization implements Serialization {

    // Serialization id placed in the Dubbo protocol header; must not collide with other
    // registered serializations (NOTE(review): confirm 26 is unused in your Dubbo version).
    private static final byte CONTENT_TYPE_ID = 26;

    private static final String CONTENT_TYPE = "x-application/customize";

    @Override
    public byte getContentTypeId() {
        return CONTENT_TYPE_ID;
    }

    @Override
    public String getContentType() {
        return CONTENT_TYPE;
    }

    /** Serialization is unchanged: plain Hessian2 output. */
    @Override
    public ObjectOutput serialize(URL url, OutputStream output) throws IOException {
        final ObjectOutput writer = new Hessian2ObjectOutput(output);
        return writer;
    }

    /** Deserialization goes through the custom Hessian2 input. */
    @Override
    public ObjectInput deserialize(URL url, InputStream input) throws IOException {
        final ObjectInput reader = new DubboSeiralizerObjectInput(input);
        return reader;
    }
}
5. 创建Spi文件

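在 resources 目录下创建 SPI 文件 META-INF/dubbo/org.apache.dubbo.common.serialize.Serialization(若使用包名为 com.alibaba.dubbo 的旧版 Dubbo,则文件名为 com.alibaba.dubbo.common.serialize.Serialization):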

文件内容为“序列化名称=实现类的全限定类名”:
customize=com.xxx.client.dubbo.serialization.CustomizeDataSerialization

最后在配置文件中引用该序列化名称即可:

server:
  port: 18104
dubbo:
  protocol:
    port: 28033
  provider:
    serialization: customize #自定义序列化方式

欢迎分享,转载请注明来源:内存溢出

原文地址: https://outofmemory.cn/zaji/5706536.html

(0)
打赏 微信扫一扫 微信扫一扫 支付宝扫一扫 支付宝扫一扫
上一篇 2022-12-17
下一篇 2022-12-17

发表评论

登录后才能评论

评论列表(0条)

保存