diff --git a/lib/helpers/model/castBulkWrite.js b/lib/helpers/model/castBulkWrite.js index 6420ad7e2e..6d7a780a81 100644 --- a/lib/helpers/model/castBulkWrite.js +++ b/lib/helpers/model/castBulkWrite.js @@ -104,7 +104,8 @@ module.exports = function castBulkWrite(originalModel, op, options) { op['updateOne']['update'] = castUpdate(model.schema, update, { strict: strict, upsert: op['updateOne'].upsert, - arrayFilters: op['updateOne'].arrayFilters + arrayFilters: op['updateOne'].arrayFilters, + overwriteDiscriminatorKey: op['updateOne'].overwriteDiscriminatorKey }, model, op['updateOne']['filter']); } catch (error) { return callback(error, null); @@ -164,7 +165,8 @@ module.exports = function castBulkWrite(originalModel, op, options) { op['updateMany']['update'] = castUpdate(model.schema, op['updateMany']['update'], { strict: strict, upsert: op['updateMany'].upsert, - arrayFilters: op['updateMany'].arrayFilters + arrayFilters: op['updateMany'].arrayFilters, + overwriteDiscriminatorKey: op['updateMany'].overwriteDiscriminatorKey }, model, op['updateMany']['filter']); } catch (error) { return callback(error, null); diff --git a/lib/helpers/query/castUpdate.js b/lib/helpers/query/castUpdate.js index d48c592755..3cf30cb0e1 100644 --- a/lib/helpers/query/castUpdate.js +++ b/lib/helpers/query/castUpdate.js @@ -8,6 +8,7 @@ const ValidationError = require('../../error/validation'); const castNumber = require('../../cast/number'); const cast = require('../../cast'); const getConstructorName = require('../getConstructorName'); +const getDiscriminatorByValue = require('../discriminator/getDiscriminatorByValue'); const getEmbeddedDiscriminatorPath = require('./getEmbeddedDiscriminatorPath'); const handleImmutable = require('./handleImmutable'); const moveImmutableProperties = require('../update/moveImmutableProperties'); @@ -62,6 +63,27 @@ module.exports = function castUpdate(schema, obj, options, context, filter) { return obj; } + if (schema != null && + filter != null && + utils.hasUserDefinedProperty(filter, schema.options.discriminatorKey) && + typeof filter[schema.options.discriminatorKey] !== 'object' && + schema.discriminators != null) { + const discriminatorValue = filter[schema.options.discriminatorKey]; + const byValue = getDiscriminatorByValue(context.model.discriminators, discriminatorValue); + schema = schema.discriminators[discriminatorValue] || + (byValue && byValue.schema) || + schema; + } else if (schema != null && + options.overwriteDiscriminatorKey && + utils.hasUserDefinedProperty(obj, schema.options.discriminatorKey) && + schema.discriminators != null) { + const discriminatorValue = obj[schema.options.discriminatorKey]; + const byValue = getDiscriminatorByValue(context.model.discriminators, discriminatorValue); + schema = schema.discriminators[discriminatorValue] || + (byValue && byValue.schema) || + schema; + } + if (options.upsert) { moveImmutableProperties(schema, obj, context); } diff --git a/lib/query.js b/lib/query.js index 32d8b1187a..a45d32fb91 100644 --- a/lib/query.js +++ b/lib/query.js @@ -4700,18 +4700,6 @@ Query.prototype._castUpdate = function _castUpdate(obj) { upsert = this.options.upsert; } - const filter = this._conditions; - if (schema != null && - utils.hasUserDefinedProperty(filter, schema.options.discriminatorKey) && - typeof filter[schema.options.discriminatorKey] !== 'object' && - schema.discriminators != null) { - const discriminatorValue = filter[schema.options.discriminatorKey]; - const byValue = getDiscriminatorByValue(this.model.discriminators, discriminatorValue); - schema = schema.discriminators[discriminatorValue] || - (byValue && byValue.schema) || - schema; - } - return castUpdate(schema, obj, { strict: this._mongooseOptions.strict, upsert: upsert, diff --git a/test/model.test.js b/test/model.test.js index 5bc1317e52..73d2e809ef 100644 --- a/test/model.test.js +++ b/test/model.test.js @@ -4174,6 +4174,51 @@ describe('Model', function() { assert.strictEqual(r2.testArray[0].nonexistentProp, undefined); }); + it('handles overwriteDiscriminatorKey (gh-15040)', async function() { + const dSchema1 = new mongoose.Schema({ + field1: String + }); + const dSchema2 = new mongoose.Schema({ + field2: String + }); + const baseSchema = new mongoose.Schema({ + field: String, + key: String + }, { discriminatorKey: 'key' }); + const type1Key = 'Type1'; + const type2Key = 'Type2'; + + baseSchema.discriminator(type1Key, dSchema1); + baseSchema.discriminator(type2Key, dSchema2); + + const TestModel = db.model('Test', baseSchema); + + const test = new TestModel({ + field: 'base field', + key: type1Key, + field1: 'field1' + }); + const r1 = await test.save(); + assert.equal(r1.field1, 'field1'); + assert.equal(r1.key, type1Key); + + const field2 = 'field2'; + await TestModel.bulkWrite([{ + updateOne: { + filter: { _id: r1._id }, + update: { + key: type2Key, + field2 + }, + overwriteDiscriminatorKey: true + } + }]); + + const r2 = await TestModel.findById(r1._id); + assert.equal(r2.key, type2Key); + assert.equal(r2.field2, field2); + }); + it('with child timestamps and array filters (gh-7032)', async function() { const childSchema = new Schema({ name: String }, { timestamps: true });