diff --git a/mongoext/registry.go b/mongoext/registry.go index 9a5ebaf..c14eb44 100644 --- a/mongoext/registry.go +++ b/mongoext/registry.go @@ -16,6 +16,9 @@ func CreateGoExtBsonRegistry() *bsoncodec.Registry { rb.RegisterTypeDecoder(reflect.TypeOf(rfctime.RFC3339NanoTime{}), rfctime.RFC3339NanoTime{}) rb.RegisterTypeDecoder(reflect.TypeOf(&rfctime.RFC3339NanoTime{}), rfctime.RFC3339NanoTime{}) + rb.RegisterTypeDecoder(reflect.TypeOf(rfctime.Date{}), rfctime.Date{}) + rb.RegisterTypeDecoder(reflect.TypeOf(&rfctime.Date{}), rfctime.Date{}) + bsoncodec.DefaultValueEncoders{}.RegisterDefaultEncoders(rb) bsoncodec.DefaultValueDecoders{}.RegisterDefaultDecoders(rb) diff --git a/rfctime/date.go b/rfctime/date.go new file mode 100644 index 0000000..38f9ef3 --- /dev/null +++ b/rfctime/date.go @@ -0,0 +1,240 @@ +package rfctime + +import ( + "encoding/json" + "errors" + "fmt" + "go.mongodb.org/mongo-driver/bson" + "go.mongodb.org/mongo-driver/bson/bsoncodec" + "go.mongodb.org/mongo-driver/bson/bsonrw" + "go.mongodb.org/mongo-driver/bson/bsontype" + "reflect" + "time" +) + +type Date struct { + Year int + Month int + Day int +} + +func (t Date) Time(loc *time.Location) time.Time { + return time.Date(t.Year, time.Month(t.Month), t.Day, 0, 0, 0, 0, loc) +} + +func (t Date) TimeUTC() time.Time { + return time.Date(t.Year, time.Month(t.Month), t.Day, 0, 0, 0, 0, time.UTC) +} + +func (t Date) TimeLocal() time.Time { + return time.Date(t.Year, time.Month(t.Month), t.Day, 0, 0, 0, 0, time.Local) +} + +func (t Date) MarshalBinary() ([]byte, error) { + return t.TimeUTC().MarshalBinary() +} + +func (t *Date) UnmarshalBinary(data []byte) error { + nt := time.Time{} + if err := nt.UnmarshalBinary(data); err != nil { + return err + } + t.Year = nt.Year() + t.Month = int(nt.Month()) + t.Day = nt.Day() + return nil +} + +func (t Date) GobEncode() ([]byte, error) { + return t.TimeUTC().GobEncode() +} + +func (t *Date) GobDecode(data []byte) error { + nt := time.Time{} + if err := nt.GobDecode(data); err != nil { + return err + } + t.Year = nt.Year() + t.Month = int(nt.Month()) + t.Day = nt.Day() + return nil +} + +func (t *Date) UnmarshalJSON(data []byte) error { + str := "" + if err := json.Unmarshal(data, &str); err != nil { + return err + } + t0, err := time.Parse(t.FormatStr(), str) + if err != nil { + return err + } + t.Year = t0.Year() + t.Month = int(t0.Month()) + t.Day = t0.Day() + return nil +} + +func (t Date) MarshalJSON() ([]byte, error) { + str := t.TimeUTC().Format(t.FormatStr()) + return json.Marshal(str) +} + +func (t Date) MarshalText() ([]byte, error) { + b := make([]byte, 0, len(t.FormatStr())) + return t.TimeUTC().AppendFormat(b, t.FormatStr()), nil +} + +func (t *Date) UnmarshalText(data []byte) error { + var err error + v, err := time.Parse(t.FormatStr(), string(data)) + if err != nil { + return err + } + t.Year = v.Year() + t.Month = int(v.Month()) + t.Day = v.Day() + return nil +} + +func (t *Date) UnmarshalBSONValue(bt bsontype.Type, data []byte) error { + if bt == bsontype.Null { + // we can't set nil in UnmarshalBSONValue (so we use default(struct)) + // Use mongoext.CreateGoExtBsonRegistry if you need to unmarsh pointer values + // https://stackoverflow.com/questions/75167597 + // https://jira.mongodb.org/browse/GODRIVER-2252 + *t = Date{} + return nil + } + if bt != bsontype.String { + return errors.New(fmt.Sprintf("cannot unmarshal %v into Date", bt)) + } + + var tt string + err := bson.RawValue{Type: bt, Value: data}.Unmarshal(&tt) + if err != nil { + return err + } + + v, err := time.Parse(t.FormatStr(), tt) + if err != nil { + return err + } + t.Year = v.Year() + t.Month = int(v.Month()) + t.Day = v.Day() + + return nil +} + +func (t Date) MarshalBSONValue() (bsontype.Type, []byte, error) { + return bson.MarshalValue(t.TimeUTC().Format(t.FormatStr())) +} + +func (t Date) DecodeValue(dc bsoncodec.DecodeContext, vr bsonrw.ValueReader, val reflect.Value) error { + if val.Kind() == reflect.Ptr && val.IsNil() { + if !val.CanSet() { + return errors.New("ValueUnmarshalerDecodeValue") + } + val.Set(reflect.New(val.Type().Elem())) + } + + tp, src, err := bsonrw.Copier{}.CopyValueToBytes(vr) + if err != nil { + return err + } + + if val.Kind() == reflect.Ptr && len(src) == 0 { + val.Set(reflect.Zero(val.Type())) + return nil + } + + err = t.UnmarshalBSONValue(tp, src) + if err != nil { + return err + } + + if val.Kind() == reflect.Ptr { + val.Set(reflect.ValueOf(&t)) + } else { + val.Set(reflect.ValueOf(t)) + } + + return nil +} + +func (t Date) Serialize() string { + return t.TimeUTC().Format(t.FormatStr()) +} + +func (t Date) FormatStr() string { + return "2006-01-02" +} + +func (t Date) Date() (year int, month time.Month, day int) { + return t.TimeUTC().Date() +} + +func (t Date) Weekday() time.Weekday { + return t.TimeUTC().Weekday() +} + +func (t Date) ISOWeek() (year, week int) { + return t.TimeUTC().ISOWeek() +} + +func (t Date) YearDay() int { + return t.TimeUTC().YearDay() +} + +func (t Date) AddDate(years int, months int, days int) Date { + return NewDate(t.TimeUTC().AddDate(years, months, days)) +} + +func (t Date) Unix() int64 { + return t.TimeUTC().Unix() +} + +func (t Date) UnixMilli() int64 { + return t.TimeUTC().UnixMilli() +} + +func (t Date) UnixMicro() int64 { + return t.TimeUTC().UnixMicro() +} + +func (t Date) UnixNano() int64 { + return t.TimeUTC().UnixNano() +} + +func (t Date) Format(layout string) string { + return t.TimeUTC().Format(layout) +} + +func (t Date) GoString() string { + return t.TimeUTC().GoString() +} + +func (t Date) String() string { + return t.TimeUTC().String() +} + +func NewDate(t time.Time) Date { + return Date{ + Year: t.Year(), + Month: int(t.Month()), + Day: t.Day(), + } +} + +func NowDate(loc *time.Location) Date { + return NewDate(time.Now().In(loc)) +} + +func NowDateLoc() Date { + return NewDate(time.Now().In(time.UTC)) +} + +func NowDateUTC() Date { + return NewDate(time.Now().In(time.Local)) +}