123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211 |
- package gorm
- import (
- "errors"
- "fmt"
- "reflect"
- "strings"
- )
- type JoinTableHandlerInterface interface {
-
- Setup(relationship *Relationship, tableName string, source reflect.Type, destination reflect.Type)
-
- Table(db *DB) string
-
- Add(handler JoinTableHandlerInterface, db *DB, source interface{}, destination interface{}) error
-
- Delete(handler JoinTableHandlerInterface, db *DB, sources ...interface{}) error
-
- JoinWith(handler JoinTableHandlerInterface, db *DB, source interface{}) *DB
-
- SourceForeignKeys() []JoinTableForeignKey
-
- DestinationForeignKeys() []JoinTableForeignKey
- }
// JoinTableForeignKey describes one foreign key column of a join table.
type JoinTableForeignKey struct {
	// DBName is the column name inside the join table.
	DBName string
	// AssociationDBName is the name used to look the matching field up on
	// the related model (and as its column name in join conditions).
	AssociationDBName string
}
- type JoinTableSource struct {
- ModelType reflect.Type
- ForeignKeys []JoinTableForeignKey
- }
- type JoinTableHandler struct {
- TableName string `sql:"-"`
- Source JoinTableSource `sql:"-"`
- Destination JoinTableSource `sql:"-"`
- }
- func (s *JoinTableHandler) SourceForeignKeys() []JoinTableForeignKey {
- return s.Source.ForeignKeys
- }
- func (s *JoinTableHandler) DestinationForeignKeys() []JoinTableForeignKey {
- return s.Destination.ForeignKeys
- }
- func (s *JoinTableHandler) Setup(relationship *Relationship, tableName string, source reflect.Type, destination reflect.Type) {
- s.TableName = tableName
- s.Source = JoinTableSource{ModelType: source}
- s.Source.ForeignKeys = []JoinTableForeignKey{}
- for idx, dbName := range relationship.ForeignFieldNames {
- s.Source.ForeignKeys = append(s.Source.ForeignKeys, JoinTableForeignKey{
- DBName: relationship.ForeignDBNames[idx],
- AssociationDBName: dbName,
- })
- }
- s.Destination = JoinTableSource{ModelType: destination}
- s.Destination.ForeignKeys = []JoinTableForeignKey{}
- for idx, dbName := range relationship.AssociationForeignFieldNames {
- s.Destination.ForeignKeys = append(s.Destination.ForeignKeys, JoinTableForeignKey{
- DBName: relationship.AssociationForeignDBNames[idx],
- AssociationDBName: dbName,
- })
- }
- }
- func (s JoinTableHandler) Table(db *DB) string {
- return DefaultTableNameHandler(db, s.TableName)
- }
- func (s JoinTableHandler) updateConditionMap(conditionMap map[string]interface{}, db *DB, joinTableSources []JoinTableSource, sources ...interface{}) {
- for _, source := range sources {
- scope := db.NewScope(source)
- modelType := scope.GetModelStruct().ModelType
- for _, joinTableSource := range joinTableSources {
- if joinTableSource.ModelType == modelType {
- for _, foreignKey := range joinTableSource.ForeignKeys {
- if field, ok := scope.FieldByName(foreignKey.AssociationDBName); ok {
- conditionMap[foreignKey.DBName] = field.Field.Interface()
- }
- }
- break
- }
- }
- }
- }
- func (s JoinTableHandler) Add(handler JoinTableHandlerInterface, db *DB, source interface{}, destination interface{}) error {
- var (
- scope = db.NewScope("")
- conditionMap = map[string]interface{}{}
- )
-
- s.updateConditionMap(conditionMap, db, []JoinTableSource{s.Source}, source)
-
- s.updateConditionMap(conditionMap, db, []JoinTableSource{s.Destination}, destination)
- var assignColumns, binVars, conditions []string
- var values []interface{}
- for key, value := range conditionMap {
- assignColumns = append(assignColumns, scope.Quote(key))
- binVars = append(binVars, `?`)
- conditions = append(conditions, fmt.Sprintf("%v = ?", scope.Quote(key)))
- values = append(values, value)
- }
- for _, value := range values {
- values = append(values, value)
- }
- quotedTable := scope.Quote(handler.Table(db))
- sql := fmt.Sprintf(
- "INSERT INTO %v (%v) SELECT %v %v WHERE NOT EXISTS (SELECT * FROM %v WHERE %v)",
- quotedTable,
- strings.Join(assignColumns, ","),
- strings.Join(binVars, ","),
- scope.Dialect().SelectFromDummyTable(),
- quotedTable,
- strings.Join(conditions, " AND "),
- )
- return db.Exec(sql, values...).Error
- }
- func (s JoinTableHandler) Delete(handler JoinTableHandlerInterface, db *DB, sources ...interface{}) error {
- var (
- scope = db.NewScope(nil)
- conditions []string
- values []interface{}
- conditionMap = map[string]interface{}{}
- )
- s.updateConditionMap(conditionMap, db, []JoinTableSource{s.Source, s.Destination}, sources...)
- for key, value := range conditionMap {
- conditions = append(conditions, fmt.Sprintf("%v = ?", scope.Quote(key)))
- values = append(values, value)
- }
- return db.Table(handler.Table(db)).Where(strings.Join(conditions, " AND "), values...).Delete("").Error
- }
- func (s JoinTableHandler) JoinWith(handler JoinTableHandlerInterface, db *DB, source interface{}) *DB {
- var (
- scope = db.NewScope(source)
- tableName = handler.Table(db)
- quotedTableName = scope.Quote(tableName)
- joinConditions []string
- values []interface{}
- )
- if s.Source.ModelType == scope.GetModelStruct().ModelType {
- destinationTableName := db.NewScope(reflect.New(s.Destination.ModelType).Interface()).QuotedTableName()
- for _, foreignKey := range s.Destination.ForeignKeys {
- joinConditions = append(joinConditions, fmt.Sprintf("%v.%v = %v.%v", quotedTableName, scope.Quote(foreignKey.DBName), destinationTableName, scope.Quote(foreignKey.AssociationDBName)))
- }
- var foreignDBNames []string
- var foreignFieldNames []string
- for _, foreignKey := range s.Source.ForeignKeys {
- foreignDBNames = append(foreignDBNames, foreignKey.DBName)
- if field, ok := scope.FieldByName(foreignKey.AssociationDBName); ok {
- foreignFieldNames = append(foreignFieldNames, field.Name)
- }
- }
- foreignFieldValues := scope.getColumnAsArray(foreignFieldNames, scope.Value)
- var condString string
- if len(foreignFieldValues) > 0 {
- var quotedForeignDBNames []string
- for _, dbName := range foreignDBNames {
- quotedForeignDBNames = append(quotedForeignDBNames, tableName+"."+dbName)
- }
- condString = fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, quotedForeignDBNames), toQueryMarks(foreignFieldValues))
- keys := scope.getColumnAsArray(foreignFieldNames, scope.Value)
- values = append(values, toQueryValues(keys))
- } else {
- condString = fmt.Sprintf("1 <> 1")
- }
- return db.Joins(fmt.Sprintf("INNER JOIN %v ON %v", quotedTableName, strings.Join(joinConditions, " AND "))).
- Where(condString, toQueryValues(foreignFieldValues)...)
- }
- db.Error = errors.New("wrong source type for join table handler")
- return db
- }
|